首页 > 编程语言 >diffusers-源码解析-四-

diffusers-源码解析-四-

时间:2024-10-22 12:32:04浏览次数:6  
标签:dim self attention diffusers states 源码 attn hidden 解析

diffusers 源码解析(四)

.\diffusers\models\attention_flax.py

# 版权声明,表明该代码的版权归 HuggingFace 团队所有
# 根据 Apache 2.0 许可证授权使用该文件,未遵守许可证不得使用
# 许可证获取链接
# 指出该软件是以“现状”分发,不附带任何明示或暗示的保证
# 具体的权限和限制请参见许可证

# 导入 functools 模块,用于函数式编程工具
import functools
# 导入 math 模块,提供数学相关的功能
import math

# 导入 flax.linen 模块,作为神经网络构建的工具
import flax.linen as nn
# 导入 jax 库,用于加速计算
import jax
# 导入 jax.numpy 模块,提供类似于 NumPy 的数组功能
import jax.numpy as jnp


def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """多头点积注意力,查询数目有限的实现。"""
    # 获取 key 的维度信息,包括 key 的数量、头数和特征维度
    num_kv, num_heads, k_features = key.shape[-3:]
    # 获取 value 的特征维度
    v_features = value.shape[-1]
    # 确保 key_chunk_size 不超过 num_kv
    key_chunk_size = min(key_chunk_size, num_kv)
    # 对查询进行缩放,防止数值溢出
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        # 计算查询和键之间的注意力权重
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        # 获取每个查询的最大得分,用于数值稳定性
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        # 计算最大得分的梯度不更新
        max_score = jax.lax.stop_gradient(max_score)
        # 计算经过 softmax 的注意力权重
        exp_weights = jnp.exp(attn_weights - max_score)

        # 计算加权后的值
        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        # 获取每个查询的最大得分
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # 动态切片获取键的部分数据
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # 动态切片获取值的部分数据
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    # 对每个键块进行注意力计算
    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    # 计算全局最大得分
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    # 计算每个块与全局最大得分的差异
    max_diffs = jnp.exp(chunk_max - global_max)

    # 更新值和权重以便于归一化
    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    # 计算所有块的总值和总权重
    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    # 返回归一化后的总值
    return all_values / all_weights


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    # Flax 实现的内存高效多头点积注意力机制,相关文献链接
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    # 相关 GitHub 项目链接
    https://github.com/AminRezaei0x443/memory-efficient-attention

    # 参数说明:
    # query: 输入的查询张量,形状为 (batch..., query_length, head, query_key_depth_per_head)
    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        # key: 输入的键张量,形状为 (batch..., key_value_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        # value: 输入的值张量,形状为 (batch..., key_value_length, head, value_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        # precision: 计算时的数值精度,默认值为 jax.lax.Precision.HIGHEST
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        # query_chunk_size: 将查询数组划分的块大小,必须能整除 query_length
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array value must divide query_length equally without remainder
        # key_chunk_size: 将键和值数组划分的块大小,必须能整除 key_value_length
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array value must divide key_value_length equally without remainder

    # 返回值为形状为 (batch..., query_length, head, value_depth_per_head) 的数组
    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    # 获取查询张量的最后三个维度的大小
    num_q, num_heads, q_features = query.shape[-3:]

    # 定义一个函数,用于扫描处理每个查询块
    def chunk_scanner(chunk_idx, _):
        # 从查询数组中切片出当前块
        query_chunk = jax.lax.dynamic_slice(
            # 操作的对象是查询张量
            operand=query,
            # 起始索引,保持前面的维度不变,从 chunk_idx 开始切片
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            # 切片的大小,前面的维度不变,后面根据块大小取最小值
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            # 返回未使用的下一个块索引
            chunk_idx + query_chunk_size,  # unused ignore it
            # 调用注意力函数处理当前查询块
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    # 使用 jax.lax.scan 进行块的扫描处理
    _, res = jax.lax.scan(
        f=chunk_scanner,  # 处理函数
        init=0,  # 初始化块索引为 0
        xs=None,  # 不需要额外的输入数据
        # 根据查询块大小计算要处理的块数
        length=math.ceil(num_q / query_chunk_size),  # start counter  # stop counter
    )

    # 将所有块的结果在第 -3 维度拼接在一起
    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
# 定义一个 Flax 的多头注意力模块,遵循文献中的描述
class FlaxAttention(nn.Module):
    r"""
    Flax多头注意力模块,详见: https://arxiv.org/abs/1706.03762

    参数:
        query_dim (:obj:`int`):
            输入隐藏状态的维度
        heads (:obj:`int`, *optional*, defaults to 8):
            注意力头的数量
        dim_head (:obj:`int`, *optional*, defaults to 64):
            每个头内隐藏状态的维度
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            dropout比率
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            启用内存高效注意力 https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            是否将头维度拆分为自注意力计算的新轴。通常情况下,启用该标志可以加快Stable Diffusion 2.x和Stable Diffusion XL的计算速度。
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            参数的 `dtype`

    """

    # 定义输入参数的类型和默认值
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32

    # 设置模块的初始化函数
    def setup(self):
        # 计算内部维度为每个头的维度与头的数量的乘积
        inner_dim = self.dim_head * self.heads
        # 计算缩放因子
        self.scale = self.dim_head**-0.5

        # 创建权重矩阵,使用旧的命名 {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        # 创建键的权重矩阵
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        # 创建值的权重矩阵
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        # 创建输出的权重矩阵
        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        # 创建dropout层
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    # 将张量的头部维度重塑为批次维度
    def reshape_heads_to_batch_dim(self, tensor):
        # 解构张量的形状
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        # 重塑张量形状以分离头维度
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        # 转置张量的维度
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        # 进一步重塑为批次与头维度合并
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    # 将张量的批次维度重塑为头部维度
    def reshape_batch_dim_to_heads(self, tensor):
        # 解构张量的形状
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        # 重塑张量形状以合并批次与头维度
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        # 转置张量的维度
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        # 进一步重塑为合并批次与头维度
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

# 定义一个 Flax 基础变换器块层,使用 GLU 激活函数,详见:
class FlaxBasicTransformerBlock(nn.Module):
    r"""
    Flax 变换器块层,使用 `GLU` (门控线性单元) 激活函数,详见:
    https://arxiv.org/abs/1706.03762
    # 参数说明部分
    Parameters:
        dim (:obj:`int`):  # 内部隐藏状态的维度
            Inner hidden states dimension
        n_heads (:obj:`int`):  # 注意力头的数量
            Number of heads
        d_head (:obj:`int`):  # 每个头内部隐藏状态的维度
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):  # 随机失活率
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):  # 是否仅应用交叉注意力
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 参数数据类型
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 启用内存高效注意力
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否将头维度拆分为新轴
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """

    dim: int  # 内部隐藏状态维度的类型声明
    n_heads: int  # 注意力头数量的类型声明
    d_head: int  # 每个头的隐藏状态维度的类型声明
    dropout: float = 0.0  # 随机失活率的默认值
    only_cross_attention: bool = False  # 默认不只应用交叉注意力
    dtype: jnp.dtype = jnp.float32  # 默认数据类型为 jnp.float32
    use_memory_efficient_attention: bool = False  # 默认不启用内存高效注意力
    split_head_dim: bool = False  # 默认不拆分头维度

    def setup(self):
        # 设置自注意力(如果 only_cross_attention 为 True,则为交叉注意力)
        self.attn1 = FlaxAttention(
            self.dim,  # 传入的内部隐藏状态维度
            self.n_heads,  # 传入的注意力头数量
            self.d_head,  # 传入的每个头的隐藏状态维度
            self.dropout,  # 传入的随机失活率
            self.use_memory_efficient_attention,  # 是否使用内存高效注意力
            self.split_head_dim,  # 是否拆分头维度
            dtype=self.dtype,  # 传入的数据类型
        )
        # 设置交叉注意力
        self.attn2 = FlaxAttention(
            self.dim,  # 传入的内部隐藏状态维度
            self.n_heads,  # 传入的注意力头数量
            self.d_head,  # 传入的每个头的隐藏状态维度
            self.dropout,  # 传入的随机失活率
            self.use_memory_efficient_attention,  # 是否使用内存高效注意力
            self.split_head_dim,  # 是否拆分头维度
            dtype=self.dtype,  # 传入的数据类型
        )
        # 设置前馈网络
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)  # 前馈网络初始化
        # 设置第一个归一化层
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 归一化层初始化
        # 设置第二个归一化层
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 归一化层初始化
        # 设置第三个归一化层
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)  # 归一化层初始化
        # 设置丢弃层
        self.dropout_layer = nn.Dropout(rate=self.dropout)  # 丢弃层初始化
    # 定义可调用对象,接收隐藏状态、上下文和确定性标志
    def __call__(self, hidden_states, context, deterministic=True):
        # 保存输入的隐藏状态以供后续残差连接使用
        residual = hidden_states
        # 如果仅执行交叉注意力,进行相关的处理
        if self.only_cross_attention:
            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
        else:
            # 否则执行自注意力处理
            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        # 将自注意力的输出与输入的残差相加
        hidden_states = hidden_states + residual
    
        # 交叉注意力处理
        residual = hidden_states
        # 处理交叉注意力
        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        # 将交叉注意力的输出与输入的残差相加
        hidden_states = hidden_states + residual
    
        # 前馈网络处理
        residual = hidden_states
        # 应用前馈网络
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        # 将前馈网络的输出与输入的残差相加
        hidden_states = hidden_states + residual
    
        # 返回经过 dropout 处理的最终隐藏状态
        return self.dropout_layer(hidden_states, deterministic=deterministic)
# 定义一个二维的 Flax Transformer 模型,继承自 nn.Module
class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://arxiv.org/pdf/1506.02025.pdf

    文档字符串,描述该类的功能和参数。

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformers block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`): tbd
        only_cross_attention (`bool`, defaults to `False`): tbd
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """
    
    # 定义输入通道数
    in_channels: int
    # 定义头的数量
    n_heads: int
    # 定义每个头的隐藏状态维度
    d_head: int
    # 定义 Transformer 块的数量,默认为 1
    depth: int = 1
    # 定义 Dropout 率,默认为 0.0
    dropout: float = 0.0
    # 定义是否使用线性投影,默认为 False
    use_linear_projection: bool = False
    # 定义是否仅使用交叉注意力,默认为 False
    only_cross_attention: bool = False
    # 定义参数的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 定义是否使用内存高效注意力,默认为 False
    use_memory_efficient_attention: bool = False
    # 定义是否将头维度拆分为新的轴,默认为 False
    split_head_dim: bool = False

    # 设置模型的组件
    def setup(self):
        # 使用 Group Normalization 规范化层,分组数为 32,epsilon 为 1e-5
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        # 计算内部维度为头的数量乘以每个头的维度
        inner_dim = self.n_heads * self.d_head
        # 根据是否使用线性投影选择输入层
        if self.use_linear_projection:
            # 创建一个线性投影层,输出维度为 inner_dim,数据类型为 self.dtype
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            # 创建一个卷积层,输出维度为 inner_dim,卷积核大小为 (1, 1),步幅为 (1, 1),填充方式为 "VALID",数据类型为 self.dtype
            self.proj_in = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        # 创建一系列 Transformer 块,数量为 depth
        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
            )
            for _ in range(self.depth)  # 循环生成每个 Transformer 块
        ]

        # 根据是否使用线性投影选择输出层
        if self.use_linear_projection:
            # 创建一个线性投影层,输出维度为 inner_dim,数据类型为 self.dtype
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            # 创建一个卷积层,输出维度为 inner_dim,卷积核大小为 (1, 1),步幅为 (1, 1),填充方式为 "VALID",数据类型为 self.dtype
            self.proj_out = nn.Conv(
                inner_dim,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        # 创建一个 Dropout 层,Dropout 率为 self.dropout
        self.dropout_layer = nn.Dropout(rate=self.dropout)
    # 定义可调用对象的方法,接收隐藏状态、上下文和确定性标志
    def __call__(self, hidden_states, context, deterministic=True):
        # 解构隐藏状态的形状,获取批量大小、高度、宽度和通道数
        batch, height, width, channels = hidden_states.shape
        # 保存原始隐藏状态以用于残差连接
        residual = hidden_states
        # 对隐藏状态进行归一化处理
        hidden_states = self.norm(hidden_states)
        # 如果使用线性投影,则重塑隐藏状态
        if self.use_linear_projection:
            # 将隐藏状态重塑为(batch, height * width, channels)的形状
            hidden_states = hidden_states.reshape(batch, height * width, channels)
            # 应用输入投影
            hidden_states = self.proj_in(hidden_states)
        else:
            # 直接应用输入投影
            hidden_states = self.proj_in(hidden_states)
            # 将隐藏状态重塑为(batch, height * width, channels)的形状
            hidden_states = hidden_states.reshape(batch, height * width, channels)
    
        # 遍历每个变换块,更新隐藏状态
        for transformer_block in self.transformer_blocks:
            # 通过变换块处理隐藏状态和上下文
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
    
        # 如果使用线性投影,则先应用输出投影
        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            # 将隐藏状态重塑回原来的形状
            hidden_states = hidden_states.reshape(batch, height, width, channels)
        else:
            # 先重塑隐藏状态
            hidden_states = hidden_states.reshape(batch, height, width, channels)
            # 再应用输出投影
            hidden_states = self.proj_out(hidden_states)
    
        # 将隐藏状态与原始状态相加,实现残差连接
        hidden_states = hidden_states + residual
        # 返回经过dropout层处理后的隐藏状态
        return self.dropout_layer(hidden_states, deterministic=deterministic)
# 定义一个 Flax 的前馈神经网络模块,继承自 nn.Module
class FlaxFeedForward(nn.Module):
    r"""
    Flax 模块封装了两个线性层,中间由一个非线性激活函数分隔。它是 PyTorch 的
    [`FeedForward`] 类的对应物,具有以下简化:
    - 激活函数目前硬编码为门控线性单元,来自:
    https://arxiv.org/abs/2002.05202
    - `dim_out` 等于 `dim`。
    - 隐藏维度的数量硬编码为 `dim * 4` 在 [`FlaxGELU`] 中。

    参数:
        dim (:obj:`int`):
            内部隐藏状态的维度
        dropout (:obj:`float`, *可选*, 默认为 0.0):
            丢弃率
        dtype (:obj:`jnp.dtype`, *可选*, 默认为 jnp.float32):
            参数的数据类型
    """

    # 定义类属性 dim、dropout 和 dtype,分别表示维度、丢弃率和数据类型
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    # 设置方法,初始化网络层
    def setup(self):
        # 第二个线性层暂时称为 net_2,以匹配顺序层的索引
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)  # 初始化 FlaxGEGLU 网络
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)  # 初始化线性层

    # 定义前向传播方法
    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)  # 通过 net_0 处理隐藏状态
        hidden_states = self.net_2(hidden_states)  # 通过 net_2 处理隐藏状态
        return hidden_states  # 返回处理后的隐藏状态


# 定义 Flax 的 GEGLU 激活层,继承自 nn.Module
class FlaxGEGLU(nn.Module):
    r"""
    Flax 实现的线性层后跟门控线性单元激活函数变体,来自
    https://arxiv.org/abs/2002.05202。

    参数:
        dim (:obj:`int`):
            输入隐藏状态的维度
        dropout (:obj:`float`, *可选*, 默认为 0.0):
            丢弃率
        dtype (:obj:`jnp.dtype`, *可选*, 默认为 jnp.float32):
            参数的数据类型
    """

    # 定义类属性 dim、dropout 和 dtype
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    # 设置方法,初始化网络层
    def setup(self):
        inner_dim = self.dim * 4  # 计算内部维度
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)  # 初始化线性层
        self.dropout_layer = nn.Dropout(rate=self.dropout)  # 初始化丢弃层

    # 定义前向传播方法
    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)  # 通过线性层处理隐藏状态
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)  # 将输出分为两个部分
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)  # 返回带丢弃的激活输出

.\diffusers\models\attention_processor.py

# 版权声明,标明该文件的版权归 HuggingFace 团队所有
# 该文件根据 Apache 2.0 许可证进行许可
# 在遵守许可证的情况下,您可以使用该文件
# 许可证的副本可以在以下网址获取
# http://www.apache.org/licenses/LICENSE-2.0
# 除非法律要求或书面同意,否则软件按 "现状" 提供,不附带任何明示或暗示的担保
# 请参阅许可证以了解有关权限和限制的具体信息

import inspect  # 导入 inspect 模块,用于获取对象的信息
import math  # 导入 math 模块,提供数学函数
from typing import Callable, List, Optional, Tuple, Union  # 导入类型提示相关的类型

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 中的神经网络功能模块,并重命名为 F
from torch import nn  # 从 PyTorch 导入 nn 模块,提供神经网络的构建块

from ..image_processor import IPAdapterMaskProcessor  # 从上层模块导入 IPAdapterMaskProcessor
from ..utils import deprecate, logging  # 从上层模块导入弃用和日志记录功能
from ..utils.import_utils import is_torch_npu_available, is_xformers_available  # 导入检查 PyTorch NPU 和 xformers 可用性的工具
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph  # 导入与 PyTorch 版本和图形相关的工具

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器实例,便于记录日志信息

if is_torch_npu_available():  # 检查是否可以使用 PyTorch NPU
    import torch_npu  # 如果可用,则导入 torch_npu 模块

if is_xformers_available():  # 检查是否可以使用 xformers 库
    import xformers  # 如果可用,导入 xformers 模块
    import xformers.ops  # 导入 xformers 中的操作模块
else:  # 如果 xformers 不可用
    xformers = None  # 将 xformers 设为 None

@maybe_allow_in_graph  # 装饰器,可能允许在图中使用该类
class Attention(nn.Module):  # 定义 Attention 类,继承自 nn.Module
    r"""  # 文档字符串,描述该类是一个交叉注意力层
    A cross attention layer.
    """

    def __init__(  # 初始化方法,定义构造函数
        self,
        query_dim: int,  # 查询维度,类型为整数
        cross_attention_dim: Optional[int] = None,  # 可选的交叉注意力维度,默认为 None
        heads: int = 8,  # 注意力头的数量,默认为 8
        kv_heads: Optional[int] = None,  # 可选的键值头数量,默认为 None
        dim_head: int = 64,  # 每个头的维度,默认为 64
        dropout: float = 0.0,  # dropout 概率,默认为 0.0
        bias: bool = False,  # 是否使用偏置,默认为 False
        upcast_attention: bool = False,  # 是否上升注意力精度,默认为 False
        upcast_softmax: bool = False,  # 是否上升 softmax 精度,默认为 False
        cross_attention_norm: Optional[str] = None,  # 可选的交叉注意力归一化方式,默认为 None
        cross_attention_norm_num_groups: int = 32,  # 交叉注意力归一化的组数量,默认为 32
        qk_norm: Optional[str] = None,  # 可选的查询键归一化方式,默认为 None
        added_kv_proj_dim: Optional[int] = None,  # 可选的添加键值投影维度,默认为 None
        added_proj_bias: Optional[bool] = True,  # 是否为添加的投影使用偏置,默认为 True
        norm_num_groups: Optional[int] = None,  # 可选的归一化组数量,默认为 None
        spatial_norm_dim: Optional[int] = None,  # 可选的空间归一化维度,默认为 None
        out_bias: bool = True,  # 是否使用输出偏置,默认为 True
        scale_qk: bool = True,  # 是否缩放查询和键,默认为 True
        only_cross_attention: bool = False,  # 是否仅使用交叉注意力,默认为 False
        eps: float = 1e-5,  # 为数值稳定性引入的微小常数,默认为 1e-5
        rescale_output_factor: float = 1.0,  # 输出重标定因子,默认为 1.0
        residual_connection: bool = False,  # 是否使用残差连接,默认为 False
        _from_deprecated_attn_block: bool = False,  # 可选参数,指示是否来自弃用的注意力块,默认为 False
        processor: Optional["AttnProcessor"] = None,  # 可选的处理器,默认为 None
        out_dim: int = None,  # 输出维度,默认为 None
        context_pre_only=None,  # 上下文前处理,默认为 None
        pre_only=False,  # 是否仅进行前处理,默认为 False
    # 设置是否使用来自 `torch_npu` 的 npu flash attention
    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        r"""
        设置是否使用来自 `torch_npu` 的 npu flash attention。
        """
        # 如果选择使用 npu flash attention
        if use_npu_flash_attention:
            # 创建 NPU 注意力处理器实例
            processor = AttnProcessorNPU()
        else:
            # 设置注意力处理器
            # 默认情况下使用 AttnProcessor2_0,当使用 torch 2.x 时,
            # 它利用 torch.nn.functional.scaled_dot_product_attention 进行本地 Flash/内存高效注意力
            # 仅在其具有默认 `scale` 参数时适用。TODO: 在迁移到 torch 2.1 时移除 scale_qk 检查
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
        # 设置当前的处理器
        self.set_processor(processor)
    
    # 设置是否使用内存高效的 xformers 注意力
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        pass  # 此处可能缺少实现
    
    # 设置注意力计算的切片大小
    def set_attention_slice(self, slice_size: int) -> None:
        r"""
        设置注意力计算的切片大小。
    
        参数:
            slice_size (`int`):
                用于注意力计算的切片大小。
        """
        # 如果切片大小不为 None 且大于可切片头维度
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            # 抛出值错误,切片大小必须小于或等于可切片头维度
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
    
        # 如果切片大小不为 None 且添加的 kv 投影维度不为 None
        if slice_size is not None and self.added_kv_proj_dim is not None:
            # 创建带切片大小的 KV 处理器实例
            processor = SlicedAttnAddedKVProcessor(slice_size)
        # 如果切片大小不为 None
        elif slice_size is not None:
            # 创建带切片大小的注意力处理器实例
            processor = SlicedAttnProcessor(slice_size)
        # 如果添加的 kv 投影维度不为 None
        elif self.added_kv_proj_dim is not None:
            # 创建 KV 注意力处理器实例
            processor = AttnAddedKVProcessor()
        else:
            # 设置注意力处理器
            # 默认情况下使用 AttnProcessor2_0,当使用 torch 2.x 时,
            # 它利用 torch.nn.functional.scaled_dot_product_attention 进行本地 Flash/内存高效注意力
            # 仅在其具有默认 `scale` 参数时适用。TODO: 在迁移到 torch 2.1 时移除 scale_qk 检查
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
    
        # 设置当前的处理器
        self.set_processor(processor)
    # 设置要使用的注意力处理器
    def set_processor(self, processor: "AttnProcessor") -> None:
        r"""
        设置要使用的注意力处理器。
    
        参数:
            processor (`AttnProcessor`):
                要使用的注意力处理器。
        """
        # 如果当前处理器在 `self._modules` 中,且传入的 `processor` 不在其中,则需要从 `self._modules` 中移除当前处理器
        if (
            hasattr(self, "processor")  # 检查当前对象是否有处理器属性
            and isinstance(self.processor, torch.nn.Module)  # 确保当前处理器是一个 PyTorch 模块
            and not isinstance(processor, torch.nn.Module)  # 检查传入的处理器不是 PyTorch 模块
        ):
            # 记录日志,指出将移除已训练权重的处理器
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            # 从模块中移除当前处理器
            self._modules.pop("processor")
    
        # 设置当前对象的处理器为传入的处理器
        self.processor = processor
    
    # 获取正在使用的注意力处理器
    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        r"""
        获取正在使用的注意力处理器。
    
        参数:
            return_deprecated_lora (`bool`, *可选*, 默认为 `False`):
                设置为 `True` 以返回过时的 LoRA 注意力处理器。
    
        返回:
            "AttentionProcessor": 正在使用的注意力处理器。
        """
        # 如果不需要返回过时的 LoRA 处理器,则返回当前处理器
        if not return_deprecated_lora:
            return self.processor
    
    # 前向传播方法,处理输入的隐藏状态
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态张量
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码张量
        **cross_attention_kwargs,  # 可变参数,用于交叉注意力
    ) -> torch.Tensor:
        r"""  # 文档字符串,描述此方法的功能和参数
        The forward method of the `Attention` class.

        Args:  # 参数说明
            hidden_states (`torch.Tensor`):  # 查询的隐藏状态,类型为张量
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):  # 编码器的隐藏状态,可选参数
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):  # 注意力掩码,可选参数
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:  # 额外的关键字参数,传递给交叉注意力
                Additional keyword arguments to pass along to the cross attention.

        Returns:  # 返回值说明
            `torch.Tensor`: The output of the attention layer.  # 返回注意力层的输出
        """
        # `Attention` 类可以调用不同的注意力处理器/函数
        # 这里我们简单地将所有张量传递给所选的处理器类
        # 对于此处定义的标准处理器,`**cross_attention_kwargs` 是空的

        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())  # 获取处理器调用方法的参数名集合
        quiet_attn_parameters = {"ip_adapter_masks"}  # 定义不需要警告的参数集合
        unused_kwargs = [  # 筛选出未被使用的关键字参数
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:  # 如果存在未使用的关键字参数
            logger.warning(  # 记录警告日志
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}  # 过滤出有效的关键字参数

        return self.processor(  # 调用处理器并返回结果
            self,
            hidden_states,  # 传递隐藏状态
            encoder_hidden_states=encoder_hidden_states,  # 传递编码器的隐藏状态
            attention_mask=attention_mask,  # 传递注意力掩码
            **cross_attention_kwargs,  # 解包有效的额外关键字参数
        )

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:  # 定义方法,输入张量并返回处理后的张量
        r"""  # 文档字符串,描述此方法的功能和参数
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`  # 将张量从 `[batch_size, seq_len, dim]` 重新形状为 `[batch_size // heads, seq_len, dim * heads]`,`heads` 为初始化时的头数量
        is the number of heads initialized while constructing the `Attention` class.

        Args:  # 参数说明
            tensor (`torch.Tensor`): The tensor to reshape.  # 要重新形状的张量

        Returns:  # 返回值说明
            `torch.Tensor`: The reshaped tensor.  # 返回重新形状后的张量
        """
        head_size = self.heads  # 获取头的数量
        batch_size, seq_len, dim = tensor.shape  # 解包输入张量的形状
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)  # 重新调整张量的形状
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)  # 调整维度顺序并重新形状
        return tensor  # 返回处理后的张量
    # 将输入张量从形状 `[batch_size, seq_len, dim]` 转换为 `[batch_size, seq_len, heads, dim // heads]`
    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
            r"""
            将张量从 `[batch_size, seq_len, dim]` 重塑为 `[batch_size, seq_len, heads, dim // heads]`,其中 `heads` 是
            在构造 `Attention` 类时初始化的头数。
    
            参数:
                tensor (`torch.Tensor`): 要重塑的张量。
                out_dim (`int`, *可选*, 默认值为 `3`): 张量的输出维度。如果为 `3`,则张量被
                    重塑为 `[batch_size * heads, seq_len, dim // heads]`。
    
            返回:
                `torch.Tensor`: 重塑后的张量。
            """
            # 获取头的数量
            head_size = self.heads
            # 检查输入张量的维度,如果是三维则提取形状信息
            if tensor.ndim == 3:
                batch_size, seq_len, dim = tensor.shape
                extra_dim = 1
            else:
                # 如果不是三维,提取四维形状信息
                batch_size, extra_dim, seq_len, dim = tensor.shape
            # 重塑张量为 `[batch_size, seq_len * extra_dim, head_size, dim // head_size]`
            tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
            # 调整张量维度顺序为 `[batch_size, heads, seq_len * extra_dim, dim // heads]`
            tensor = tensor.permute(0, 2, 1, 3)
    
            # 如果输出维度为 3,进一步重塑张量为 `[batch_size * heads, seq_len * extra_dim, dim // heads]`
            if out_dim == 3:
                tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
    
            # 返回重塑后的张量
            return tensor
    
    # 计算注意力得分的函数
    def get_attention_scores(
            self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            r"""
            计算注意力得分。
    
            参数:
                query (`torch.Tensor`): 查询张量。
                key (`torch.Tensor`): 键张量。
                attention_mask (`torch.Tensor`, *可选*): 使用的注意力掩码。如果为 `None`,则不应用掩码。
    
            返回:
                `torch.Tensor`: 注意力概率/得分。
            """
            # 获取查询张量的数据类型
            dtype = query.dtype
            # 如果需要上升类型,将查询和键张量转换为浮点型
            if self.upcast_attention:
                query = query.float()
                key = key.float()
    
            # 如果没有提供注意力掩码,创建空的输入张量
            if attention_mask is None:
                baddbmm_input = torch.empty(
                    query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
                )
                # 设置 beta 为 0
                beta = 0
            else:
                # 如果有注意力掩码,将其用作输入
                baddbmm_input = attention_mask
                # 设置 beta 为 1
                beta = 1
    
            # 计算注意力得分
            attention_scores = torch.baddbmm(
                baddbmm_input,
                query,
                key.transpose(-1, -2),
                beta=beta,
                alpha=self.scale,
            )
            # 删除临时的输入张量
            del baddbmm_input
    
            # 如果需要上升类型,将注意力得分转换为浮点型
            if self.upcast_softmax:
                attention_scores = attention_scores.float()
    
            # 计算注意力概率
            attention_probs = attention_scores.softmax(dim=-1)
            # 删除注意力得分张量
            del attention_scores
    
            # 将注意力概率转换回原始数据类型
            attention_probs = attention_probs.to(dtype)
    
            # 返回注意力概率
            return attention_probs
    
    # 准备注意力掩码的函数
    def prepare_attention_mask(
            self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:  # 定义一个函数的返回类型为 torch.Tensor
        r"""  # 开始文档字符串,描述函数的作用和参数
        Prepare the attention mask for the attention computation.  # 准备注意力计算的注意力掩码
        Args:  # 参数说明
            attention_mask (`torch.Tensor`):  # 输入参数,注意力掩码,类型为 torch.Tensor
                The attention mask to prepare.  # 待准备的注意力掩码
            target_length (`int`):  # 输入参数,目标长度,类型为 int
                The target length of the attention mask. This is the length of the attention mask after padding.  # 注意力掩码的目标长度,经过填充后的长度
            batch_size (`int`):  # 输入参数,批处理大小,类型为 int
                The batch size, which is used to repeat the attention mask.  # 批处理大小,用于重复注意力掩码
            out_dim (`int`, *optional*, defaults to `3`):  # 可选参数,输出维度,类型为 int,默认为 3
                The output dimension of the attention mask. Can be either `3` or `4`.  # 注意力掩码的输出维度,可以是 3 或 4
        Returns:  # 返回说明
            `torch.Tensor`: The prepared attention mask.  # 返回准备好的注意力掩码,类型为 torch.Tensor
        """  # 结束文档字符串
        head_size = self.heads  # 获取头部大小,来自类的属性 heads
        if attention_mask is None:  # 检查注意力掩码是否为 None
            return attention_mask  # 如果是 None,直接返回

        current_length: int = attention_mask.shape[-1]  # 获取当前注意力掩码的长度
        if current_length != target_length:  # 检查当前长度是否与目标长度不匹配
            if attention_mask.device.type == "mps":  # 如果设备类型是 "mps"
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.  # HACK: MPS 不支持填充超过输入张量的维度
                # Instead, we can manually construct the padding tensor.  # 所以我们手动构建填充张量
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)  # 定义填充张量的形状
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)  # 创建全零填充张量
                attention_mask = torch.cat([attention_mask, padding], dim=2)  # 在最后一个维度上拼接填充张量
            else:  # 如果不是 "mps" 设备
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:  # TODO: 对于如 stable-diffusion 的管道,填充交叉注意力掩码
                #       we want to instead pad by (0, remaining_length), where remaining_length is:  # 我们希望用 (0, remaining_length) 填充,其中 remaining_length 是
                #       remaining_length: int = target_length - current_length  # remaining_length 的计算
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding  # TODO: 重新启用相关测试
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)  # 用零填充注意力掩码到目标长度

        if out_dim == 3:  # 如果输出维度是 3
            if attention_mask.shape[0] < batch_size * head_size:  # 检查注意力掩码的第一维是否小于批处理大小乘以头部大小
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)  # 在第一维上重复注意力掩码
        elif out_dim == 4:  # 如果输出维度是 4
            attention_mask = attention_mask.unsqueeze(1)  # 在第一维增加一个维度
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)  # 在第二维上重复注意力掩码

        return attention_mask  # 返回准备好的注意力掩码
    # 定义一个函数用于规范化编码器的隐藏状态,接受一个张量作为输入并返回一个张量
    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        规范化编码器隐藏状态。构造 `Attention` 类时需要指定 `self.norm_cross`。

        参数:
            encoder_hidden_states (`torch.Tensor`): 编码器的隐藏状态。

        返回:
            `torch.Tensor`: 规范化后的编码器隐藏状态。
        """
        # 确保在调用此方法之前已定义 `self.norm_cross`
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        # 检查 `self.norm_cross` 是否为 LayerNorm 类型
        if isinstance(self.norm_cross, nn.LayerNorm):
            # 对编码器隐藏状态进行层归一化
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        # 检查 `self.norm_cross` 是否为 GroupNorm 类型
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # GroupNorm 沿通道维度进行归一化,并期望输入形状为 (N, C, *)。
            # 此时我们希望沿隐藏维度进行归一化,因此需要调整形状
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)  # 转置张量以调整维度顺序
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)  # 对转置后的张量进行归一化
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)  # 再次转置回原始顺序
        else:
            # 如果 `self.norm_cross` 既不是 LayerNorm 也不是 GroupNorm,则触发断言失败
            assert False

        # 返回规范化后的编码器隐藏状态
        return encoder_hidden_states

    # 该装饰器在计算图中禁止梯度计算,以节省内存和加快推理速度
    @torch.no_grad()
    # 定义一个融合投影的方法,默认参数 fuse 为 True
    def fuse_projections(self, fuse=True):
        # 获取 to_q 权重的设备信息
        device = self.to_q.weight.data.device
        # 获取 to_q 权重的数据类型
        dtype = self.to_q.weight.data.dtype

        # 如果不是交叉注意力
        if not self.is_cross_attention:
            # 获取权重矩阵的拼接
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            # 输入特征数为拼接后权重的列数
            in_features = concatenated_weights.shape[1]
            # 输出特征数为拼接后权重的行数
            out_features = concatenated_weights.shape[0]

            # 创建一个新的线性投影层并复制权重
            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            # 复制拼接后的权重到新的层
            self.to_qkv.weight.copy_(concatenated_weights)
            # 如果使用偏置
            if self.use_bias:
                # 拼接 q、k、v 的偏置
                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
                # 复制拼接后的偏置到新的层
                self.to_qkv.bias.copy_(concatenated_bias)

        # 如果是交叉注意力
        else:
            # 获取 k 和 v 权重的拼接
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            # 输入特征数为拼接后权重的列数
            in_features = concatenated_weights.shape[1]
            # 输出特征数为拼接后权重的行数
            out_features = concatenated_weights.shape[0]

            # 创建一个新的线性投影层并复制权重
            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            # 复制拼接后的权重到新的层
            self.to_kv.weight.copy_(concatenated_weights)
            # 如果使用偏置
            if self.use_bias:
                # 拼接 k 和 v 的偏置
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                # 复制拼接后的偏置到新的层
                self.to_kv.bias.copy_(concatenated_bias)

        # 处理 SD3 和其他添加的投影
        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
            # 获取额外投影的权重拼接
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            # 输入特征数为拼接后权重的列数
            in_features = concatenated_weights.shape[1]
            # 输出特征数为拼接后权重的行数
            out_features = concatenated_weights.shape[0]

            # 创建一个新的线性投影层并复制权重
            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            # 复制拼接后的权重到新的层
            self.to_added_qkv.weight.copy_(concatenated_weights)
            # 如果使用偏置
            if self.added_proj_bias:
                # 拼接额外投影的偏置
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                # 复制拼接后的偏置到新的层
                self.to_added_qkv.bias.copy_(concatenated_bias)

        # 将融合状态存储到属性中
        self.fused_projections = fuse
# 定义一个处理器类,用于执行与注意力相关的计算
class AttnProcessor:
    r"""
    默认处理器,用于执行与注意力相关的计算。
    """

    # 实现可调用方法,处理注意力计算
    def __call__(
        self,
        attn: Attention,  # 注意力对象
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态(可选)
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码(可选)
        temb: Optional[torch.Tensor] = None,  # 额外的时间嵌入(可选)
        *args,  # 额外的位置参数
        **kwargs,  # 额外的关键字参数
    ) -> torch.Tensor:  # 返回处理后的张量
        # 检查是否有额外参数或已弃用的 scale 参数
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 构建弃用警告消息
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 调用弃用处理函数
            deprecate("scale", "1.0.0", deprecation_message)

        # 初始化残差为隐藏状态
        residual = hidden_states

        # 如果空间归一化存在,则应用于隐藏状态
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 获取输入张量的维度
        input_ndim = hidden_states.ndim

        # 如果输入是四维的,则调整形状
        if input_ndim == 4:
            # 解包隐藏状态的形状
            batch_size, channel, height, width = hidden_states.shape
            # 重新调整形状为(batch_size, channel, height*width)并转置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # 根据编码器隐藏状态的存在与否,获取批次大小和序列长度
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # 准备注意力掩码
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # 如果组归一化存在,则应用于隐藏状态
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # 将隐藏状态转换为查询向量
        query = attn.to_q(hidden_states)

        # 如果没有编码器隐藏状态,使用隐藏状态作为编码器隐藏状态
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要规范化编码器隐藏状态,则应用规范化
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 从编码器隐藏状态中获取键和值
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # 将查询、键和值转换为批次维度
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # 计算注意力分数
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # 通过注意力分数加权求值
        hidden_states = torch.bmm(attention_probs, value)
        # 将隐藏状态转换回头维度
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 线性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 应用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 如果输入是四维的,调整回原始形状
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # 如果存在残差连接,则将残差加回隐藏状态
        if attn.residual_connection:
            hidden_states = hidden_states + residual

        # 将隐藏状态归一化到输出因子
        hidden_states = hidden_states / attn.rescale_output_factor

        # 返回最终的隐藏状态
        return hidden_states


# 定义一个处理器类,用于实现自定义扩散方法的注意力
class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    实现自定义扩散方法的注意力处理器。
    # 定义参数说明
    Args:
        train_kv (`bool`, defaults to `True`):  # 是否重新训练对应于文本特征的键值矩阵
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):  # 是否重新训练对应于潜在图像特征的查询矩阵
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):  # 注意力层的隐藏大小
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):  # 编码器隐藏状态中的通道数量
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):  # 是否在 `train_q_out` 中包含偏置参数
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):  # 使用的 dropout 概率
            The dropout probability to use.
    """

    # 初始化方法
    def __init__(
        self,  # 初始化方法的第一个参数,表示对象本身
        train_kv: bool = True,  # 设置键值矩阵训练的默认值为 True
        train_q_out: bool = True,  # 设置查询矩阵训练的默认值为 True
        hidden_size: Optional[int] = None,  # 隐藏层大小,默认为 None
        cross_attention_dim: Optional[int] = None,  # 跨注意力维度,默认为 None
        out_bias: bool = True,  # 输出偏置参数的默认值为 True
        dropout: float = 0.0,  # 默认的 dropout 概率为 0.0
    ):
        super().__init__()  # 调用父类的初始化方法
        self.train_kv = train_kv  # 保存键值训练标志
        self.train_q_out = train_q_out  # 保存查询输出训练标志

        self.hidden_size = hidden_size  # 保存隐藏层大小
        self.cross_attention_dim = cross_attention_dim  # 保存跨注意力维度

        # `_custom_diffusion` id 方便序列化和加载
        if self.train_kv:  # 如果需要训练键值
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)  # 创建键的线性层
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)  # 创建值的线性层
        if self.train_q_out:  # 如果需要训练查询输出
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)  # 创建查询的线性层
            self.to_out_custom_diffusion = nn.ModuleList([])  # 初始化输出层的模块列表
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))  # 添加线性输出层
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))  # 添加 dropout 层

    # 可调用方法
    def __call__(  # 定义对象被调用时的行为
        self,  # 第一个参数,表示对象本身
        attn: Attention,  # 注意力对象
        hidden_states: torch.Tensor,  # 隐藏状态张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态,默认为 None
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,默认为 None
    # 返回类型为 torch.Tensor
    ) -> torch.Tensor:
        # 获取隐藏状态的批量大小和序列长度
        batch_size, sequence_length, _ = hidden_states.shape
        # 准备注意力掩码以适应当前批量和序列长度
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # 如果需要训练查询输出,则使用自定义扩散进行转换
        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
        else:
            # 否则使用标准的查询转换
            query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
    
        # 检查编码器隐藏状态是否为 None
        if encoder_hidden_states is None:
            # 如果是,则不进行交叉注意力
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            # 否则,启用交叉注意力
            crossattn = True
            # 如果需要归一化编码器隐藏状态,则进行归一化
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    
        # 如果需要训练键值对
        if self.train_kv:
            # 使用自定义扩散获取键和值
            key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
            value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
            # 将键和值转换为查询的权重数据类型
            key = key.to(attn.to_q.weight.dtype)
            value = value.to(attn.to_q.weight.dtype)
        else:
            # 否则使用标准的键和值转换
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)
    
        # 如果进行交叉注意力
        if crossattn:
            # 创建与键相同形状的张量以进行detach操作
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            # 应用detach逻辑以阻止梯度流动
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()
    
        # 将查询、键和值转换为批次维度
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
    
        # 计算注意力分数
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # 使用注意力分数和值进行批量矩阵乘法
        hidden_states = torch.bmm(attention_probs, value)
        # 将隐藏状态转换回头维度
        hidden_states = attn.batch_to_head_dim(hidden_states)
    
        # 如果需要训练查询输出
        if self.train_q_out:
            # 线性投影
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # 应用dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # 否则使用标准的线性投影
            hidden_states = attn.to_out[0](hidden_states)
            # 应用dropout
            hidden_states = attn.to_out[1](hidden_states)
    
        # 返回最终的隐藏状态
        return hidden_states
# 定义一个带有额外可学习的键和值矩阵的注意力处理器类
class AttnAddedKVProcessor:
    r"""
    处理器,用于执行与文本编码器相关的注意力计算
    """

    # 定义调用方法,以实现注意力计算
    def __call__(
        self,
        attn: Attention,  # 注意力对象
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器的隐藏状态(可选)
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码(可选)
        *args,  # 其他位置参数
        **kwargs,  # 其他关键字参数
    ) -> torch.Tensor:  # 返回类型为张量
        # 检查是否传递了多余的参数或已弃用的 scale 参数
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 发出弃用警告
            deprecate("scale", "1.0.0", deprecation_message)

        # 将隐藏状态赋值给残差
        residual = hidden_states

        # 重塑隐藏状态的形状,并转置维度
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        # 获取批大小和序列长度
        batch_size, sequence_length, _ = hidden_states.shape

        # 准备注意力掩码
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # 如果没有编码器隐藏状态,则使用输入的隐藏状态
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要进行归一化处理
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 对隐藏状态进行分组归一化处理
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # 将隐藏状态转换为查询
        query = attn.to_q(hidden_states)
        # 将查询从头维度转换为批维度
        query = attn.head_to_batch_dim(query)

        # 将编码器隐藏状态投影为键和值
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        # 将投影结果转换为批维度
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        # 如果不是仅进行交叉注意力
        if not attn.only_cross_attention:
            # 将隐藏状态转换为键和值
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            # 转换为批维度
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            # 将编码器键和值与当前键和值拼接
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            # 仅使用编码器的键和值
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # 获取注意力概率
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # 计算隐藏状态的新值
        hidden_states = torch.bmm(attention_probs, value)
        # 将隐藏状态转换回头维度
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 线性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 应用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 重塑隐藏状态,并将残差加回
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        # 返回处理后的隐藏状态
        return hidden_states


# 定义另一个注意力处理器类
class AttnAddedKVProcessor2_0:
    r"""
    # 处理缩放点积注意力的处理器(如果使用 PyTorch 2.0,默认启用),
    # 其中为文本编码器添加了额外的可学习的键和值矩阵。
        """
    
        # 初始化方法
        def __init__(self):
            # 检查 F 中是否有 "scaled_dot_product_attention" 属性
            if not hasattr(F, "scaled_dot_product_attention"):
                # 如果没有,抛出 ImportError,提示用户需要升级到 PyTorch 2.0
                raise ImportError(
                    "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
                )
    
        # 定义调用方法
        def __call__(
            self,
            attn: Attention,  # 输入的注意力机制对象
            hidden_states: torch.Tensor,  # 隐藏状态张量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态张量
            attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码张量
            *args,  # 额外的位置参数
            **kwargs,  # 额外的关键字参数
    ) -> torch.Tensor:  # 指定函数返回类型为 torch.Tensor
        # 检查参数是否存在或 scale 参数是否被提供
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 设置弃用消息,告知 scale 参数将被忽略
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 调用 deprecate 函数发出弃用警告
            deprecate("scale", "1.0.0", deprecation_message)

        # 将输入的 hidden_states 赋值给 residual
        residual = hidden_states

        # 调整 hidden_states 的形状并进行转置
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        # 获取 batch_size 和 sequence_length
        batch_size, sequence_length, _ = hidden_states.shape

        # 准备注意力掩码
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        # 如果没有提供 encoder_hidden_states,则使用 hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        # 如果需要归一化交叉隐藏状态
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # 对 hidden_states 进行分组归一化
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # 计算查询向量
        query = attn.to_q(hidden_states)
        # 将查询向量转换为批次维度
        query = attn.head_to_batch_dim(query, out_dim=4)

        # 生成 encoder_hidden_states 的键和值的投影
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        # 将键和值转换为批次维度
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        # 如果不是只进行交叉注意力
        if not attn.only_cross_attention:
            # 计算当前 hidden_states 的键和值
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            # 转换为批次维度
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            # 将键和值与 encoder 的键和值连接
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            # 如果只进行交叉注意力,使用 encoder 的键和值
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # 计算缩放点积注意力的输出,形状为 (batch, num_heads, seq_len, head_dim)
        # TODO: 在迁移到 Torch 2.1 时添加对 attn.scale 的支持
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # 转置并重塑 hidden_states
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # 进行线性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 进行 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 转置并重塑回 residual 的形状
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        # 将 residual 加到 hidden_states 上
        hidden_states = hidden_states + residual

        # 返回最终的 hidden_states
        return hidden_states
# 定义一个名为 JointAttnProcessor2_0 的类,用于处理自注意力投影
class JointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 检查 F 是否有 scaled_dot_product_attention 属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果没有,抛出导入错误,提示需要升级 PyTorch 到 2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定义调用方法,接受多个参数
    def __call__(
        self,
        attn: Attention,  # 自注意力对象
        hidden_states: torch.FloatTensor,  # 当前隐藏状态的张量
        encoder_hidden_states: torch.FloatTensor = None,  # 编码器的隐藏状态,默认为 None
        attention_mask: Optional[torch.FloatTensor] = None,  # 可选的注意力掩码,默认为 None
        *args,  # 额外的位置参数
        **kwargs,  # 额外的关键字参数
    # 返回一个浮点张量
    ) -> torch.FloatTensor:
        # 保存输入的隐藏状态,以便后续使用
        residual = hidden_states
    
        # 获取隐藏状态的维度
        input_ndim = hidden_states.ndim
        # 如果隐藏状态是四维的
        if input_ndim == 4:
            # 解包隐藏状态的形状为批大小、通道、高度和宽度
            batch_size, channel, height, width = hidden_states.shape
            # 将隐藏状态重塑为三维,并进行转置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        # 获取编码器隐藏状态的维度
        context_input_ndim = encoder_hidden_states.ndim
        # 如果编码器隐藏状态是四维的
        if context_input_ndim == 4:
            # 解包编码器隐藏状态的形状为批大小、通道、高度和宽度
            batch_size, channel, height, width = encoder_hidden_states.shape
            # 将编码器隐藏状态重塑为三维,并进行转置
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
    
        # 获取编码器隐藏状态的批大小
        batch_size = encoder_hidden_states.shape[0]
    
        # 计算 `sample` 投影
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
    
        # 计算 `context` 投影
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
    
        # 合并注意力查询、键和值
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
    
        # 获取键的最后一维大小
        inner_dim = key.shape[-1]
        # 计算每个头的维度
        head_dim = inner_dim // attn.heads
        # 重塑查询、键和值以适应多个头
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    
        # 计算缩放点积注意力
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        # 转置并重塑隐藏状态
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # 转换为查询的类型
        hidden_states = hidden_states.to(query.dtype)
    
        # 拆分注意力输出
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],  # 获取原隐藏状态的部分
            hidden_states[:, residual.shape[1] :],  # 获取编码器隐藏状态的部分
        )
    
        # 进行线性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 进行 dropout
        hidden_states = attn.to_out[1](hidden_states)
        # 如果上下文不是仅限于编码器
        if not attn.context_pre_only:
            # 对编码器隐藏状态进行额外处理
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
    
        # 如果输入是四维的,进行转置和重塑
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        # 如果上下文输入是四维的,进行转置和重塑
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
    
        # 返回处理后的隐藏状态和编码器隐藏状态
        return hidden_states, encoder_hidden_states
# 定义一个类,PAGJointAttnProcessor2_0,用于处理自注意力投影
class PAGJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 检查是否存在名为"scaled_dot_product_attention"的属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不存在,则抛出导入错误,提示需要升级PyTorch到2.0
            raise ImportError(
                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 可调用方法,接受注意力对象和隐藏状态
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        # 其他可选参数
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
# 定义另一个类,PAGCFGJointAttnProcessor2_0,类似于PAGJointAttnProcessor2_0
class PAGCFGJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 检查是否存在名为"scaled_dot_product_attention"的属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不存在,则抛出导入错误,提示需要升级PyTorch到2.0
            raise ImportError(
                "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 可调用方法,接受注意力对象和隐藏状态
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        # 其他可选参数
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
# 定义第三个类,FusedJointAttnProcessor2_0,处理自注意力投影
class FusedJointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 检查是否存在名为"scaled_dot_product_attention"的属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不存在,则抛出导入错误,提示需要升级PyTorch到2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 可调用方法,接受注意力对象和隐藏状态
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        # 其他可选参数
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # 将隐藏状态赋值给残差变量
        residual = hidden_states

        # 获取隐藏状态的维度
        input_ndim = hidden_states.ndim
        # 如果隐藏状态是四维的,进行维度变换
        if input_ndim == 4:
            # 解包隐藏状态的形状
            batch_size, channel, height, width = hidden_states.shape
            # 将隐藏状态变形为(batch_size, channel, height * width)并转置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        # 获取编码器隐藏状态的维度
        context_input_ndim = encoder_hidden_states.ndim
        # 如果编码器隐藏状态是四维的,进行维度变换
        if context_input_ndim == 4:
            # 解包编码器隐藏状态的形状
            batch_size, channel, height, width = encoder_hidden_states.shape
            # 将编码器隐藏状态变形为(batch_size, channel, height * width)并转置
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # 获取编码器隐藏状态的批量大小
        batch_size = encoder_hidden_states.shape[0]

        # `sample` 进行投影
        qkv = attn.to_qkv(hidden_states)
        # 计算每个分量的大小
        split_size = qkv.shape[-1] // 3
        # 将qkv拆分为query、key和value
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # `context` 进行投影
        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
        # 计算编码器qkv的分量大小
        split_size = encoder_qkv.shape[-1] // 3
        # 将编码器qkv拆分为查询、键和值的投影
        (
            encoder_hidden_states_query_proj,
            encoder_hidden_states_key_proj,
            encoder_hidden_states_value_proj,
        ) = torch.split(encoder_qkv, split_size, dim=-1)

        # 进行注意力计算
        # 将query、key、value进行连接
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        # 获取key的最后一维大小
        inner_dim = key.shape[-1]
        # 计算每个头的维度
        head_dim = inner_dim // attn.heads
        # 调整query的形状以适应多头注意力
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # 调整key的形状以适应多头注意力
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # 调整value的形状以适应多头注意力
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # 进行缩放点积注意力计算
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        # 调整hidden_states的形状
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # 将hidden_states转换为与query相同的数据类型
        hidden_states = hidden_states.to(query.dtype)

        # 拆分注意力输出
        hidden_states, encoder_hidden_states = (
            # 保留残差形状的部分
            hidden_states[:, : residual.shape[1]],
            # 剩余的部分
            hidden_states[:, residual.shape[1] :],
        )

        # 线性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 进行dropout
        hidden_states = attn.to_out[1](hidden_states)
        # 如果不是只使用上下文,进行编码器输出的投影
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # 如果输入是四维的,调整hidden_states的形状
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        # 如果上下文输入是四维的,调整encoder_hidden_states的形状
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        # 返回hidden_states和encoder_hidden_states
        return hidden_states, encoder_hidden_states
# 定义一个用于处理 Aura Flow 的注意力处理器类
class AuraFlowAttnProcessor2_0:
    """Attention processor used typically in processing Aura Flow."""

    # 初始化方法
    def __init__(self):
        # 检查 F 是否具有 scaled_dot_product_attention 属性,并确保 PyTorch 版本符合要求
        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
            # 如果不满足条件,抛出导入错误,提示用户升级 PyTorch
            raise ImportError(
                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    # 可调用方法,用于处理输入的注意力和隐藏状态
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
# 定义一个用于处理 Aura Flow 的融合投影注意力处理器类
class FusedAuraFlowAttnProcessor2_0:
    """Attention processor used typically in processing Aura Flow with fused projections."""

    # 初始化方法
    def __init__(self):
        # 检查 F 是否具有 scaled_dot_product_attention 属性,并确保 PyTorch 版本符合要求
        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
            # 如果不满足条件,抛出导入错误,提示用户升级 PyTorch
            raise ImportError(
                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
            )

    # 可调用方法,用于处理输入的注意力和隐藏状态
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        *args,
        **kwargs,
# YiYi 待办事项:重构与 rope 相关的函数/类
def apply_rope(xq, xk, freqs_cis):
    # 将 xq 转换为浮点型,并重新调整形状以便处理
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    # 将 xk 转换为浮点型,并重新调整形状以便处理
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # 计算 xq 的输出,结合频率复数
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    # 计算 xk 的输出,结合频率复数
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    # 返回调整形状后的 xq_out 和 xk_out,并确保与原始类型匹配
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


# 定义一个实现缩放点积注意力的处理器类
class FluxSingleAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 是否具有 scaled_dot_product_attention 属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果不满足条件,抛出导入错误,提示用户升级 PyTorch
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 可调用方法,用于处理输入的注意力和隐藏状态
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    # 定义函数的返回类型为 torch.Tensor
    ) -> torch.Tensor:
        # 获取 hidden_states 的维度数量
        input_ndim = hidden_states.ndim
    
        # 如果输入的维度为 4
        if input_ndim == 4:
            # 解包 hidden_states 的形状为 batch_size, channel, height, width
            batch_size, channel, height, width = hidden_states.shape
            # 将 hidden_states 视图调整为 (batch_size, channel, height * width) 并转置
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
    
        # 如果 encoder_hidden_states 为 None,则获取 hidden_states 的形状
        # 否则获取 encoder_hidden_states 的形状
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    
        # 将 hidden_states 转换为查询向量
        query = attn.to_q(hidden_states)
        # 如果 encoder_hidden_states 为 None,将其设置为 hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
    
        # 将 encoder_hidden_states 转换为键向量
        key = attn.to_k(encoder_hidden_states)
        # 将 encoder_hidden_states 转换为值向量
        value = attn.to_v(encoder_hidden_states)
    
        # 获取键的最后一个维度的大小
        inner_dim = key.shape[-1]
        # 计算每个头的维度
        head_dim = inner_dim // attn.heads
    
        # 将查询向量调整视图为 (batch_size, -1, attn.heads, head_dim) 并转置
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    
        # 将键向量调整视图为 (batch_size, -1, attn.heads, head_dim) 并转置
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # 将值向量调整视图为 (batch_size, -1, attn.heads, head_dim) 并转置
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    
        # 如果存在规范化查询的层,则对查询进行规范化
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        # 如果存在规范化键的层,则对键进行规范化
        if attn.norm_k is not None:
            key = attn.norm_k(key)
    
        # 如果需要应用 RoPE
        if image_rotary_emb is not None:
            # 应用旋转嵌入到查询和键上
            query, key = apply_rope(query, key, image_rotary_emb)
    
        # 计算缩放点积注意力,输出形状为 (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
    
        # 转置并调整 hidden_states 的形状为 (batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        # 将 hidden_states 转换为与查询相同的数据类型
        hidden_states = hidden_states.to(query.dtype)
    
        # 如果输入维度为 4,将 hidden_states 转置并调整形状回原始维度
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
    
        # 返回处理后的 hidden_states
        return hidden_states
# 定义一个名为 FluxAttnProcessor2_0 的类,通常用于处理 SD3 类自注意力投影
class FluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    # 初始化方法
    def __init__(self):
        # 检查 F 是否有 scaled_dot_product_attention 属性,如果没有则抛出 ImportError
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定义调用方法,使类实例可被调用
    def __call__(
        self,
        attn: Attention,  # 接收 Attention 对象
        hidden_states: torch.FloatTensor,  # 接收隐藏状态张量
        encoder_hidden_states: torch.FloatTensor = None,  # 可选的编码器隐藏状态张量
        attention_mask: Optional[torch.FloatTensor] = None,  # 可选的注意力掩码张量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 可选的图像旋转嵌入张量
    ):
        # 此处将实现自注意力的具体处理逻辑

# 定义一个名为 CogVideoXAttnProcessor2_0 的类,专用于 CogVideoX 模型的缩放点积注意力处理
class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 是否有 scaled_dot_product_attention 属性,如果没有则抛出 ImportError
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定义调用方法,使类实例可被调用
    def __call__(
        self,
        attn: Attention,  # 接收 Attention 对象
        hidden_states: torch.Tensor,  # 接收隐藏状态张量
        encoder_hidden_states: torch.Tensor,  # 接收编码器隐藏状态张量
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码张量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 可选的图像旋转嵌入张量
    ):
        # 此处将实现自注意力的具体处理逻辑
    ) -> torch.Tensor:  # 函数返回一个张量,表示隐藏状态
        text_seq_length = encoder_hidden_states.size(1)  # 获取编码器隐藏状态的序列长度

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)  # 在维度1上连接编码器隐藏状态和当前隐藏状态

        batch_size, sequence_length, _ = (  # 解包 batch_size 和 sequence_length
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape  # 根据编码器隐藏状态的存在性决定形状
        )

        if attention_mask is not None:  # 如果存在注意力掩码
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)  # 准备注意力掩码
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])  # 调整注意力掩码的形状以适应头数

        query = attn.to_q(hidden_states)  # 将隐藏状态转换为查询向量
        key = attn.to_k(hidden_states)  # 将隐藏状态转换为键向量
        value = attn.to_v(hidden_states)  # 将隐藏状态转换为值向量

        inner_dim = key.shape[-1]  # 获取键向量的最后一个维度大小
        head_dim = inner_dim // attn.heads  # 计算每个头的维度

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 调整查询向量形状并转置以适应多头注意力
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 调整键向量形状并转置以适应多头注意力
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # 调整值向量形状并转置以适应多头注意力

        if attn.norm_q is not None:  # 如果查询归一化层存在
            query = attn.norm_q(query)  # 对查询向量进行归一化
        if attn.norm_k is not None:  # 如果键归一化层存在
            key = attn.norm_k(key)  # 对键向量进行归一化

        # Apply RoPE if needed  # 如果需要应用旋转位置编码
        if image_rotary_emb is not None:  # 如果图像旋转嵌入存在
            from .embeddings import apply_rotary_emb  # 导入应用旋转嵌入的函数

            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)  # 应用旋转嵌入到查询向量的后半部分
            if not attn.is_cross_attention:  # 如果不是交叉注意力
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)  # 应用旋转嵌入到键向量的后半部分

        hidden_states = F.scaled_dot_product_attention(  # 计算缩放点积注意力
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False  # 输入查询、键和值,以及注意力掩码
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)  # 转置和重塑隐藏状态以合并头维度

        # linear proj  # 线性投影
        hidden_states = attn.to_out[0](hidden_states)  # 对隐藏状态应用输出线性变换
        # dropout  # 进行dropout操作
        hidden_states = attn.to_out[1](hidden_states)  # 对隐藏状态应用dropout

        encoder_hidden_states, hidden_states = hidden_states.split(  # 将隐藏状态分割为编码器和当前隐藏状态
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1  # 根据文本序列长度和剩余部分进行分割
        )
        return hidden_states, encoder_hidden_states  # 返回当前隐藏状态和编码器隐藏状态
# 定义一个用于实现 CogVideoX 模型的缩放点积注意力的处理器类
class FusedCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 是否具有 scaled_dot_product_attention 属性,如果没有则抛出导入错误
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定义可调用方法,处理注意力计算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 获取编码器隐藏状态的序列长度
        text_seq_length = encoder_hidden_states.size(1)

        # 将编码器和当前隐藏状态按维度 1 连接
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 获取批次大小和序列长度,依据编码器隐藏状态是否为 None
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # 如果提供了注意力掩码,则准备掩码
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # 将掩码调整为适当的形状
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        # 将隐藏状态转换为查询、键、值
        qkv = attn.to_qkv(hidden_states)
        # 计算每个部分的大小
        split_size = qkv.shape[-1] // 3
        # 分割成查询、键和值
        query, key, value = torch.split(qkv, split_size, dim=-1)

        # 获取键的内部维度
        inner_dim = key.shape[-1]
        # 计算每个头的维度
        head_dim = inner_dim // attn.heads

        # 调整查询、键和值的形状以适应多头注意力
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # 如果存在查询的归一化,则应用归一化
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        # 如果存在键的归一化,则应用归一化
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # 如果需要应用 RoPE
        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            # 对查询的特定部分应用旋转嵌入
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            # 如果不是交叉注意力,则对键的特定部分应用旋转嵌入
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        # 计算缩放点积注意力
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # 调整隐藏状态的形状以便输出
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # 线性投影
        hidden_states = attn.to_out[0](hidden_states)
        # 应用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 将隐藏状态拆分为编码器隐藏状态和当前隐藏状态
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        # 返回当前隐藏状态和编码器隐藏状态
        return hidden_states, encoder_hidden_states


# 定义用于实现内存高效注意力的处理器类
class XFormersAttnAddedKVProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers.
    # 文档字符串,说明可选参数 attention_op 的作用
        Args:
            attention_op (`Callable`, *optional*, defaults to `None`):
                使用的基本注意力操作符,推荐设置为 `None` 让 xFormers 选择最佳操作符
        """
    
        # 构造函数,初始化注意力操作符
        def __init__(self, attention_op: Optional[Callable] = None):
            # 将传入的注意力操作符赋值给实例变量
            self.attention_op = attention_op
    
        # 可调用方法,用于执行注意力计算
        def __call__(
            self,
            attn: Attention,  # 注意力对象
            hidden_states: torch.Tensor,  # 隐藏状态张量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态,默认为 None
            attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,默认为 None
        ) -> torch.Tensor:
            # 将当前隐藏状态保存为残差以便后续使用
            residual = hidden_states
            # 调整隐藏状态的形状并转置
            hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
            # 获取批次大小和序列长度
            batch_size, sequence_length, _ = hidden_states.shape
    
            # 准备注意力掩码
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
    
            # 如果没有编码器隐藏状态,则将其设置为当前的隐藏状态
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            # 如果需要,则对编码器隐藏状态进行归一化处理
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
    
            # 对隐藏状态进行分组归一化处理
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
    
            # 生成查询向量
            query = attn.to_q(hidden_states)
            # 将查询向量从头部维度转换为批次维度
            query = attn.head_to_batch_dim(query)
    
            # 对编码器隐藏状态进行键和值的投影
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
            # 将编码器隐藏状态的键和值转换为批次维度
            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
    
            # 如果不是仅使用交叉注意力
            if not attn.only_cross_attention:
                # 生成当前隐藏状态的键和值
                key = attn.to_k(hidden_states)
                value = attn.to_v(hidden_states)
                # 转换键和值到批次维度
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)
                # 将编码器的键和值与当前的键和值连接起来
                key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
                value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
            else:
                # 如果仅使用交叉注意力,则直接使用编码器的键和值
                key = encoder_hidden_states_key_proj
                value = encoder_hidden_states_value_proj
    
            # 计算高效的注意力
            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )
            # 将结果转换为查询的 dtype
            hidden_states = hidden_states.to(query.dtype)
            # 将隐藏状态从批次维度转换回头部维度
            hidden_states = attn.batch_to_head_dim(hidden_states)
    
            # 线性变换
            hidden_states = attn.to_out[0](hidden_states)
            # 应用 dropout
            hidden_states = attn.to_out[1](hidden_states)
    
            # 调整隐藏状态的形状以匹配残差
            hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
            # 将当前隐藏状态与残差相加
            hidden_states = hidden_states + residual
    
            # 返回最终的隐藏状态
            return hidden_states
# 定义一个用于实现基于 xFormers 的内存高效注意力的处理器类
class XFormersAttnProcessor:
    r"""
    处理器,用于实现基于 xFormers 的内存高效注意力。

    参数:
        attention_op (`Callable`, *可选*, 默认为 `None`):
            基础
            [操作符](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase),
            用作注意力操作符。建议将其设置为 `None`,并让 xFormers 选择最佳操作符。
    """

    # 初始化方法,接受一个可选的注意力操作符
    def __init__(self, attention_op: Optional[Callable] = None):
        # 将传入的注意力操作符赋值给实例变量
        self.attention_op = attention_op

    # 定义可调用方法,用于执行注意力计算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
# 定义一个用于实现 flash attention 的处理器类,使用 torch_npu
class AttnProcessorNPU:
    r"""
    处理器,用于使用 torch_npu 实现 flash attention。torch_npu 仅支持 fp16 和 bf16 数据类型。如果
    使用 fp32,将使用 F.scaled_dot_product_attention 进行计算,但在 NPU 上加速效果不明显。

    """

    # 初始化方法
    def __init__(self):
        # 检查是否可用 torch_npu,如果不可用则抛出异常
        if not is_torch_npu_available():
            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")

    # 定义可调用方法,用于执行注意力计算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
# 定义一个用于实现 scaled dot-product attention 的处理器类,默认在 PyTorch 2.0 中启用
class AttnProcessor2_0:
    r"""
    处理器,用于实现 scaled dot-product attention(如果您使用的是 PyTorch 2.0,默认启用)。
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 中是否有 scaled_dot_product_attention 属性,如果没有则抛出异常
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定义可调用方法,用于执行注意力计算
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
# 定义一个用于实现 scaled dot-product attention 的处理器类,适用于稳定音频模型
class StableAudioAttnProcessor2_0:
    r"""
    处理器,用于实现 scaled dot-product attention(如果您使用的是 PyTorch 2.0,默认启用)。此处理器用于
    稳定音频模型。它在查询和键向量上应用旋转嵌入,并允许 MHA、GQA 或 MQA。
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 中是否有 scaled_dot_product_attention 属性,如果没有则抛出异常
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 定义方法,用于应用部分旋转嵌入
    def apply_partial_rotary_emb(
        self,
        x: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor],
    # 定义返回类型为 torch.Tensor 的函数
    ) -> torch.Tensor:
        # 从当前模块导入 apply_rotary_emb 函数
        from .embeddings import apply_rotary_emb
    
        # 获取频率余弦的最后一个维度大小,用于旋转
        rot_dim = freqs_cis[0].shape[-1]
        # 将输入张量 x 划分为需要旋转和不需要旋转的部分
        x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
    
        # 应用旋转嵌入到需要旋转的部分
        x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
    
        # 将旋转后的部分与未旋转的部分在最后一个维度上连接
        out = torch.cat((x_rotated, x_unrotated), dim=-1)
        # 返回连接后的输出张量
        return out
    
    # 定义可调用方法,接收注意力和隐藏状态
    def __call__(
        self,
        # 输入的注意力对象
        attn: Attention,
        # 隐藏状态的张量
        hidden_states: torch.Tensor,
        # 可选的编码器隐藏状态张量
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可选的注意力掩码张量
        attention_mask: Optional[torch.Tensor] = None,
        # 可选的旋转嵌入张量
        rotary_emb: Optional[torch.Tensor] = None,
# 定义 HunyuanAttnProcessor2_0 类,处理缩放的点积注意力
class HunyuanAttnProcessor2_0:
    r"""
    处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用)。这是
    HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 中是否有 scaled_dot_product_attention 属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果没有,则抛出导入错误,提示需要升级 PyTorch 到 2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定义调用方法
    def __call__(
        self,
        attn: Attention,  # 注意力机制实例
        hidden_states: torch.Tensor,  # 当前隐藏状态的张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态的可选张量
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码的可选张量
        temb: Optional[torch.Tensor] = None,  # 时间嵌入的可选张量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入的可选张量
class FusedHunyuanAttnProcessor2_0:
    r"""
    处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用),带有融合的
    投影层。这是 HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 中是否有 scaled_dot_product_attention 属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果没有,则抛出导入错误,提示需要升级 PyTorch 到 2.0
            raise ImportError(
                "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 定义调用方法
    def __call__(
        self,
        attn: Attention,  # 注意力机制实例
        hidden_states: torch.Tensor,  # 当前隐藏状态的张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态的可选张量
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码的可选张量
        temb: Optional[torch.Tensor] = None,  # 时间嵌入的可选张量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入的可选张量
class PAGHunyuanAttnProcessor2_0:
    r"""
    处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用)。这是
    HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。该处理器
    变体采用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。
    """

    # 初始化方法
    def __init__(self):
        # 检查 F 中是否有 scaled_dot_product_attention 属性
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果没有,则抛出导入错误,提示需要升级 PyTorch 到 2.0
            raise ImportError(
                "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    # 定义调用方法
    def __call__(
        self,
        attn: Attention,  # 注意力机制实例
        hidden_states: torch.Tensor,  # 当前隐藏状态的张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器隐藏状态的可选张量
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码的可选张量
        temb: Optional[torch.Tensor] = None,  # 时间嵌入的可选张量
        image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入的可选张量
class PAGCFGHunyuanAttnProcessor2_0:
    r"""
    处理器用于实现缩放的点积注意力(如果使用 PyTorch 2.0,默认启用)。这是
    HunyuanDiT 模型中使用的。它在查询和键向量上应用归一化层和旋转嵌入。该处理器
    变体采用了 [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377)。
    """
    # 初始化方法,用于创建类的实例
        def __init__(self):
            # 检查模块 F 是否具有属性 "scaled_dot_product_attention"
            if not hasattr(F, "scaled_dot_product_attention"):
                # 如果没有该属性,则抛出 ImportError,提示用户升级 PyTorch
                raise ImportError(
                    "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
                )
    
    # 可调用方法,允许类的实例像函数一样被调用
        def __call__(
            self,
            attn: Attention,  # 注意力机制对象
            hidden_states: torch.Tensor,  # 当前隐藏状态的张量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 编码器的隐藏状态,可选参数
            attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,可选参数
            temb: Optional[torch.Tensor] = None,  # 时间嵌入,可选参数
            image_rotary_emb: Optional[torch.Tensor] = None,  # 图像旋转嵌入,可选参数
# 定义一个用于实现缩放点积注意力的处理器类
class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
    """

    # 初始化方法
    def __init__(self):
        # 检查 PyTorch 是否具有缩放点积注意力功能
        if not hasattr(F, "scaled_dot_product_attention"):
            # 如果没有,抛出导入错误,提示用户升级 PyTorch 到 2.0
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    # 定义调用方法,使类实例可调用
    def __call__(
        self,
        # 接收注意力对象
        attn: Attention,
        # 接收隐藏状态张量
        hidden_states: torch.Tensor,
        # 接收编码器隐藏状态张量
        encoder_hidden_states: torch.Tensor,
        # 可选的注意力掩码张量
        attention_mask: Optional[torch.Tensor] = None,
        # 可选的查询旋转嵌入张量
        query_rotary_emb: Optional[torch.Tensor] = None,
        # 可选的键旋转嵌入张量
        key_rotary_emb: Optional[torch.Tensor] = None,
        # 可选的基本序列长度
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:  # 函数返回一个张量,表示处理后的隐藏状态
        from .embeddings import apply_rotary_emb  # 从当前包导入应用旋转嵌入的函数

        input_ndim = hidden_states.ndim  # 获取隐藏状态的维度数

        if input_ndim == 4:  # 如果隐藏状态是四维张量
            batch_size, channel, height, width = hidden_states.shape  # 解包出批次大小、通道、高度和宽度
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)  # 重塑并转置隐藏状态

        batch_size, sequence_length, _ = hidden_states.shape  # 解包出批次大小和序列长度

        # Get Query-Key-Value Pair  # 获取查询、键、值对
        query = attn.to_q(hidden_states)  # 将隐藏状态转换为查询张量
        key = attn.to_k(encoder_hidden_states)  # 将编码器的隐藏状态转换为键张量
        value = attn.to_v(encoder_hidden_states)  # 将编码器的隐藏状态转换为值张量

        query_dim = query.shape[-1]  # 获取查询的最后一个维度(特征维度)
        inner_dim = key.shape[-1]  # 获取键的最后一个维度
        head_dim = query_dim // attn.heads  # 计算每个头的维度
        dtype = query.dtype  # 获取查询张量的数据类型

        # Get key-value heads  # 获取键值头的数量
        kv_heads = inner_dim // head_dim  # 计算每个头的键值数量

        # Apply Query-Key Norm if needed  # 如果需要,应用查询-键归一化
        if attn.norm_q is not None:  # 如果定义了查询的归一化
            query = attn.norm_q(query)  # 对查询进行归一化
        if attn.norm_k is not None:  # 如果定义了键的归一化
            key = attn.norm_k(key)  # 对键进行归一化

        query = query.view(batch_size, -1, attn.heads, head_dim)  # 重塑查询张量以适应头的维度

        key = key.view(batch_size, -1, kv_heads, head_dim)  # 重塑键张量以适应头的维度
        value = value.view(batch_size, -1, kv_heads, head_dim)  # 重塑值张量以适应头的维度

        # Apply RoPE if needed  # 如果需要,应用旋转位置嵌入
        if query_rotary_emb is not None:  # 如果定义了查询的旋转嵌入
            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)  # 应用旋转嵌入到查询
        if key_rotary_emb is not None:  # 如果定义了键的旋转嵌入
            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)  # 应用旋转嵌入到键

        query, key = query.to(dtype), key.to(dtype)  # 将查询和键转换为相同的数据类型

        # Apply proportional attention if true  # 如果为真,应用比例注意力
        if key_rotary_emb is None:  # 如果没有键的旋转嵌入
            softmax_scale = None  # 设置缩放因子为 None
        else:  # 如果有键的旋转嵌入
            if base_sequence_length is not None:  # 如果定义了基础序列长度
                softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale  # 计算缩放因子
            else:  # 如果没有定义基础序列长度
                softmax_scale = attn.scale  # 使用注意力的缩放因子

        # perform Grouped-query Attention (GQA)  # 执行分组查询注意力
        n_rep = attn.heads // kv_heads  # 计算每个键值头的重复数量
        if n_rep >= 1:  # 如果重复数量大于等于 1
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # 扩展并重复键
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # 扩展并重复值

        # scaled_dot_product_attention expects attention_mask shape to be  # 缩放点积注意力期望的注意力掩码形状
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)  # 将注意力掩码转换为布尔值并调整形状
        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)  # 扩展注意力掩码以匹配头的数量

        query = query.transpose(1, 2)  # 转置查询张量
        key = key.transpose(1, 2)  # 转置键张量
        value = value.transpose(1, 2)  # 转置值张量

        # the output of sdp = (batch, num_heads, seq_len, head_dim)  # 缩放点积注意力的输出形状
        # TODO: add support for attn.scale when we move to Torch 2.1  # TODO: 在迁移到 Torch 2.1 时支持 attn.scale
        hidden_states = F.scaled_dot_product_attention(  # 计算缩放点积注意力
            query, key, value, attn_mask=attention_mask, scale=softmax_scale  # 输入查询、键、值及注意力掩码和缩放因子
        )
        hidden_states = hidden_states.transpose(1, 2).to(dtype)  # 转置输出并转换为相应的数据类型

        return hidden_states  # 返回处理后的隐藏状态
# 定义一个用于实现缩放点积注意力的处理器类,默认启用(如果使用 PyTorch 2.0)
class FusedAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
    For cross-attention modules, key and value projection matrices are fused.

    <Tip warning={true}>

    This API is currently 

标签:dim,self,attention,diffusers,states,源码,attn,hidden,解析
From: https://www.cnblogs.com/apachecn/p/18492386

相关文章

  • C语言使用指针作为函数参数,并利用函数嵌套求输入三个整数,将它们按大到小的顺序输出。(
    输入三个整数,要求从大到小的顺序向他们输出,用函数实现。   本代码使用到了指针和函数嵌套。   调用指针做函数ex,并嵌套调用指针函数exx在函数ex中。(代码在下面哦!)一、关于函数 ex  1. 这个函数接受三个指针参数 int*p1 、 int*p2 和 int*p3 ,分别指......
  • 【QT速成】半小时入门QT6之QT前置知识扫盲(超详细QT工程解析)
    目录一.QT工程介绍1.创建工程ModelDefineBuildSystemClassInformation BaseclassKit二.工程构成三.类汇总一.QT工程介绍1.创建工程Model    QT创建工程时首先会让我们选择项目模板,对应的英文解释很详尽,这里我们也可做一下简单介绍。应用程序(A......
  • 【关注可白嫖源码】计算机等级考试在线刷题小程序,不会的看过来
    设计一个计算机等级考试在线刷题小程序,需要确保系统能够提供高效的刷题功能,帮助用户随时随地练习。以下是系统的设计思路:一、系统设计总体思路该小程序需要包含用户端、题库管理系统、后台管理系统三大部分。用户可以通过小程序在线刷题、查看答案解析、查看个人练习情况,而......
  • 【可白嫖源码】基于SSM的在线点餐系统(案例分析)
    摘  要   当前高速发展的经济模式下,人们工作和生活都处于高压下,没时间做饭,在哪做饭成了人们的难题,传统下班回家做饭的生活习俗渐渐地变得难以实现。在社会驱动下,我国在餐饮方面的收入额,逐年成上升趋势。餐饮方面带来的收入拉高了社会消费品的零售总额。不得不说,餐饮......
  • 深入理解华为鸿蒙的 Context —— 应用上下文解析
    本文旨在深入探讨华为鸿蒙HarmonyOSNext系统(截止目前API12)的技术细节,基于实际开发实践进行总结。主要作为技术分享与交流载体,难免错漏,欢迎各位同仁提出宝贵意见和问题,以便共同进步。本文为原创内容,任何形式的转载必须注明出处及原作者。在华为鸿蒙(HarmonyOS)开发中,Context是......
  • 华为鸿蒙Next:应用启动框架AppStartup的解析与实战应用
    本文旨在深入探讨华为鸿蒙HarmonyOSNext系统(截止目前API12)的技术细节,基于实际开发实践进行总结。主要作为技术分享与交流载体,难免错漏,欢迎各位同仁提出宝贵意见和问题,以便共同进步。本文为原创内容,任何形式的转载必须注明出处及原作者。在华为鸿蒙(HarmonyOS)开发领域,应用的启......
  • 深入解析Apache DolphinScheduler容错机制
    简述ApacheDolphinschedulerMaster和Worker都是支持多节点部署,无中心化的设计。Master主要负责是流程DAG的切分,最终通过RPC将任务分发到Worker节点上以及Worker上任务状态的处理Worker主要负责是真正任务的执行,最后将任务状态汇报给Master,Master进行状态处理那问题来了:M......
  • GLM-4v-9B-源码解析-五-
    GLM-4v-9B源码解析(五)license:otherlicense_name:glm-4license_link:https://huggingface.co/THUDM/glm-4v-9b/blob/main/LICENSElanguage:zhentags:glmchatglmthudminference:falseGLM-4V-9B源码解析ReadthisinEnglish2024/08/12,本仓库代码已更新并使......
  • GLM-4v-9B-源码解析-四-
    GLM-4v-9B源码解析(四)GLM-4-9BChatdialoguemodelfine-tuningInthisdemo,youwillexperiencehowtofine-tunetheGLM-4-9B-Chatopensourcemodel(visualunderstandingmodelisnotsupported).Pleasestrictlyfollowthestepsinthedocumenttoavoidunne......
  • GLM-4v-9B-源码解析-三-
    GLM-4v-9B源码解析(三).\chatglm4-finetune\composite_demo\src\tools\browser.py#简单的浏览器工具说明"""Simplebrowsertool.#UsagePleasestartthebackendbrowserserveraccordingtotheinstructionsintheREADME."""#导入用于格式......