首页 > 编程语言 >diffusers-源码解析-十四-

diffusers-源码解析-十四-

时间:2024-10-22 12:33:12浏览次数:1  
标签:states int self attention diffusers channels 源码 hidden 解析

diffusers 源码解析(十四)

.\diffusers\models\unets\unet_2d_blocks_flax.py

# 版权声明,说明该文件的版权信息及相关许可协议
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 许可信息,使用 Apache License 2.0 许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件只能在符合许可证的情况下使用
# you may not use this file except in compliance with the License.
# 提供许可证获取链接
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 免责声明,表明不提供任何形式的保证或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 许可证的相关权限及限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入 flax.linen 模块,用于构建神经网络
import flax.linen as nn
# 导入 jax.numpy,用于数值计算
import jax.numpy as jnp

# 从其他模块导入特定的类,用于构建模型的各个组件
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D


# 定义 FlaxCrossAttnDownBlock2D 类,表示一个 2D 跨注意力下采样模块
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    跨注意力 2D 下采样块 - 原始架构来自 Unet transformers:
    https://arxiv.org/abs/2103.06104

    参数说明:
        in_channels (:obj:`int`):
            输入通道数
        out_channels (:obj:`int`):
            输出通道数
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout 率
        num_layers (:obj:`int`, *optional*, defaults to 1):
            注意力块层数
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            每个空间变换块的注意力头数
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            是否在每个最终输出之前添加下采样层
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            启用内存高效的注意力 https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            是否将头维度拆分为一个新的轴进行自注意力计算。在大多数情况下,
            启用此标志应加快 Stable Diffusion 2.x 和 Stable Diffusion XL 的计算速度。
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            参数的数据类型
    """

    # 定义输入通道数
    in_channels: int
    # 定义输出通道数
    out_channels: int
    # 定义 Dropout 率,默认为 0.0
    dropout: float = 0.0
    # 定义注意力块的层数,默认为 1
    num_layers: int = 1
    # 定义注意力头数,默认为 1
    num_attention_heads: int = 1
    # 定义是否添加下采样层,默认为 True
    add_downsample: bool = True
    # 定义是否使用线性投影,默认为 False
    use_linear_projection: bool = False
    # 定义是否仅使用跨注意力,默认为 False
    only_cross_attention: bool = False
    # 定义是否启用内存高效注意力,默认为 False
    use_memory_efficient_attention: bool = False
    # 定义是否拆分头维度,默认为 False
    split_head_dim: bool = False
    # 定义参数的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 定义每个块的变换器层数,默认为 1
    transformer_layers_per_block: int = 1
    # 设置模型的各个组成部分,包括残差块和注意力块
        def setup(self):
            # 初始化残差块列表
            resnets = []
            # 初始化注意力块列表
            attentions = []
    
            # 遍历每一层,构建残差块和注意力块
            for i in range(self.num_layers):
                # 第一层的输入通道为 in_channels,其他层为 out_channels
                in_channels = self.in_channels if i == 0 else self.out_channels
    
                # 创建一个 FlaxResnetBlock2D 实例
                res_block = FlaxResnetBlock2D(
                    in_channels=in_channels,  # 输入通道
                    out_channels=self.out_channels,  # 输出通道
                    dropout_prob=self.dropout,  # 丢弃率
                    dtype=self.dtype,  # 数据类型
                )
                # 将残差块添加到列表中
                resnets.append(res_block)
    
                # 创建一个 FlaxTransformer2DModel 实例
                attn_block = FlaxTransformer2DModel(
                    in_channels=self.out_channels,  # 输入通道
                    n_heads=self.num_attention_heads,  # 注意力头数
                    d_head=self.out_channels // self.num_attention_heads,  # 每个头的维度
                    depth=self.transformer_layers_per_block,  # 每个块的层数
                    use_linear_projection=self.use_linear_projection,  # 是否使用线性投影
                    only_cross_attention=self.only_cross_attention,  # 是否只使用交叉注意力
                    use_memory_efficient_attention=self.use_memory_efficient_attention,  # 是否使用内存高效的注意力
                    split_head_dim=self.split_head_dim,  # 是否拆分头的维度
                    dtype=self.dtype,  # 数据类型
                )
                # 将注意力块添加到列表中
                attentions.append(attn_block)
    
            # 将残差块列表赋值给实例变量
            self.resnets = resnets
            # 将注意力块列表赋值给实例变量
            self.attentions = attentions
    
            # 如果需要下采样,则创建下采样层
            if self.add_downsample:
                self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
    
        # 定义前向调用方法,处理隐藏状态和编码器隐藏状态
        def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
            # 初始化输出状态元组
            output_states = ()
    
            # 遍历残差块和注意力块并进行处理
            for resnet, attn in zip(self.resnets, self.attentions):
                # 通过残差块处理隐藏状态
                hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
                # 通过注意力块处理隐藏状态
                hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
                # 将当前隐藏状态添加到输出状态元组中
                output_states += (hidden_states,)
    
            # 如果需要下采样,则进行下采样
            if self.add_downsample:
                hidden_states = self.downsamplers_0(hidden_states)
                # 将下采样后的隐藏状态添加到输出状态元组中
                output_states += (hidden_states,)
    
            # 返回最终的隐藏状态和输出状态元组
            return hidden_states, output_states
# 定义 Flax 2D 降维块类,继承自 nn.Module
class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D downsizing block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    
    # 声明输入输出通道和其他参数
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    # 设置方法,用于初始化模型的层
    def setup(self):
        # 创建空列表以存储残差块
        resnets = []

        # 根据层数创建残差块
        for i in range(self.num_layers):
            # 第一个块的输入通道为 in_channels,其余为 out_channels
            in_channels = self.in_channels if i == 0 else self.out_channels

            # 创建残差块实例
            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            # 将残差块添加到列表中
            resnets.append(res_block)
        # 将列表赋值给实例属性
        self.resnets = resnets

        # 如果需要,添加降采样层
        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    # 调用方法,执行前向传播
    def __call__(self, hidden_states, temb, deterministic=True):
        # 创建空元组以存储输出状态
        output_states = ()

        # 遍历所有残差块进行前向传播
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            # 将当前隐藏状态添加到输出状态中
            output_states += (hidden_states,)

        # 如果需要,应用降采样层
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            # 将降采样后的隐藏状态添加到输出状态中
            output_states += (hidden_states,)

        # 返回最终的隐藏状态和输出状态
        return hidden_states, output_states


# 定义 Flax 交叉注意力 2D 上采样块类,继承自 nn.Module
class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104
    # 定义参数的文档字符串,描述各个参数的用途和类型
        Parameters:
            in_channels (:obj:`int`):  # 输入通道数
                Input channels
            out_channels (:obj:`int`):  # 输出通道数
                Output channels
            dropout (:obj:`float`, *optional*, defaults to 0.0):  # Dropout 率,默认值为 0.0
                Dropout rate
            num_layers (:obj:`int`, *optional*, defaults to 1):  # 注意力块的层数,默认值为 1
                Number of attention blocks layers
            num_attention_heads (:obj:`int`, *optional*, defaults to 1):  # 每个空间变换块的注意力头数量,默认值为 1
                Number of attention heads of each spatial transformer block
            add_upsample (:obj:`bool`, *optional*, defaults to `True`):  # 是否在每个最终输出前添加上采样层,默认值为 True
                Whether to add upsampling layer before each final output
            use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 启用内存高效注意力,默认值为 False
                enable memory efficient attention https://arxiv.org/abs/2112.05682
            split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否将头维度拆分为新轴以进行自注意力计算,默认值为 False
                Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
                enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
            dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 数据类型参数,默认值为 jnp.float32
                Parameters `dtype`
        """
    
        in_channels: int  # 输入通道数的声明
        out_channels: int  # 输出通道数的声明
        prev_output_channel: int  # 前一个输出通道数的声明
        dropout: float = 0.0  # Dropout 率的声明,默认值为 0.0
        num_layers: int = 1  # 注意力层数的声明,默认值为 1
        num_attention_heads: int = 1  # 注意力头数量的声明,默认值为 1
        add_upsample: bool = True  # 是否添加上采样的声明,默认值为 True
        use_linear_projection: bool = False  # 是否使用线性投影的声明,默认值为 False
        only_cross_attention: bool = False  # 是否仅使用交叉注意力的声明,默认值为 False
        use_memory_efficient_attention: bool = False  # 是否启用内存高效注意力的声明,默认值为 False
        split_head_dim: bool = False  # 是否拆分头维度的声明,默认值为 False
        dtype: jnp.dtype = jnp.float32  # 数据类型的声明,默认值为 jnp.float32
        transformer_layers_per_block: int = 1  # 每个块的变换层数的声明,默认值为 1
    # 设置方法,初始化网络结构
    def setup(self):
        # 初始化空列表以存储 ResNet 块
        resnets = []
        # 初始化空列表以存储注意力块
        attentions = []
    
        # 遍历每一层以创建相应的 ResNet 和注意力块
        for i in range(self.num_layers):
            # 设置跳跃连接的通道数,最后一层使用输入通道,否则使用输出通道
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # 设置当前 ResNet 块的输入通道,第一层使用前一层的输出通道
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
    
            # 创建 FlaxResnetBlock2D 实例
            res_block = FlaxResnetBlock2D(
                # 设置输入通道为当前 ResNet 块输入通道加跳跃连接通道
                in_channels=resnet_in_channels + res_skip_channels,
                # 设置输出通道为指定的输出通道
                out_channels=self.out_channels,
                # 设置 dropout 概率
                dropout_prob=self.dropout,
                # 设置数据类型
                dtype=self.dtype,
            )
            # 将创建的 ResNet 块添加到列表中
            resnets.append(res_block)
    
            # 创建 FlaxTransformer2DModel 实例
            attn_block = FlaxTransformer2DModel(
                # 设置输入通道为输出通道
                in_channels=self.out_channels,
                # 设置注意力头数
                n_heads=self.num_attention_heads,
                # 设置每个注意力头的维度
                d_head=self.out_channels // self.num_attention_heads,
                # 设置 transformer 块的深度
                depth=self.transformer_layers_per_block,
                # 设置是否使用线性投影
                use_linear_projection=self.use_linear_projection,
                # 设置是否仅使用交叉注意力
                only_cross_attention=self.only_cross_attention,
                # 设置是否使用内存高效的注意力机制
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                # 设置是否分割头部维度
                split_head_dim=self.split_head_dim,
                # 设置数据类型
                dtype=self.dtype,
            )
            # 将创建的注意力块添加到列表中
            attentions.append(attn_block)
    
        # 将 ResNet 列表保存到实例属性
        self.resnets = resnets
        # 将注意力列表保存到实例属性
        self.attentions = attentions
    
        # 如果需要添加上采样层,则创建相应的 FlaxUpsample2D 实例
        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
    
    # 定义调用方法,接受隐藏状态和其他参数
    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        # 遍历 ResNet 和注意力块
        for resnet, attn in zip(self.resnets, self.attentions):
            # 从跳跃连接的隐藏状态元组中取出最后一个状态
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新跳跃连接的隐藏状态元组,去掉最后一个状态
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # 将隐藏状态与跳跃连接的隐藏状态在最后一个轴上拼接
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
    
            # 使用当前的 ResNet 块处理隐藏状态
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            # 使用当前的注意力块处理隐藏状态
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
    
        # 如果需要添加上采样,则使用上采样层处理隐藏状态
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
    
        # 返回处理后的隐藏状态
        return hidden_states
# 定义一个 2D 上采样块类,继承自 nn.Module
class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D upsampling block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    # 定义输入输出通道和其他参数
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    # 设置方法用于初始化块的结构
    def setup(self):
        resnets = []  # 创建一个空列表用于存储 ResNet 块

        # 遍历每一层,创建 ResNet 块
        for i in range(self.num_layers):
            # 计算跳跃连接通道数
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # 设置输入通道数
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            # 创建一个新的 FlaxResnetBlock2D 实例
            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)  # 将块添加到列表中

        self.resnets = resnets  # 将列表赋值给实例变量

        # 如果需要上采样,初始化上采样层
        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    # 定义前向传播方法
    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        # 遍历每个 ResNet 块进行前向传播
        for resnet in self.resnets:
            # 从元组中弹出最后的残差隐藏状态
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # 更新元组,去掉最后一项
            # 连接当前隐藏状态与残差隐藏状态
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            # 通过 ResNet 块处理隐藏状态
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        # 如果需要上采样,调用上采样层
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states  # 返回处理后的隐藏状态


# 定义一个 2D 中级交叉注意力块类,继承自 nn.Module
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
    # 定义参数的文档字符串
    Parameters:
        in_channels (:obj:`int`):  # 输入通道数
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):  # Dropout比率,默认为0.0
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):  # 注意力层的数量,默认为1
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):  # 每个空间变换块的注意力头数量,默认为1
            Number of attention heads of each spatial transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 是否启用内存高效的注意力机制,默认为False
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否将头维度分割为新的轴以加速计算,默认为False
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 数据类型参数,默认为jnp.float32
            Parameters `dtype`
    """

    in_channels: int  # 输入通道数的类型
    dropout: float = 0.0  # Dropout比率的默认值
    num_layers: int = 1  # 注意力层数量的默认值
    num_attention_heads: int = 1  # 注意力头数量的默认值
    use_linear_projection: bool = False  # 是否使用线性投影的默认值
    use_memory_efficient_attention: bool = False  # 是否使用内存高效注意力的默认值
    split_head_dim: bool = False  # 是否分割头维度的默认值
    dtype: jnp.dtype = jnp.float32  # 数据类型的默认值
    transformer_layers_per_block: int = 1  # 每个块中的变换层数量的默认值

    def setup(self):  # 设置方法,用于初始化
        # 至少会有一个ResNet块
        resnets = [  # 创建ResNet块列表
            FlaxResnetBlock2D(  # 创建一个ResNet块
                in_channels=self.in_channels,  # 输入通道数
                out_channels=self.in_channels,  # 输出通道数
                dropout_prob=self.dropout,  # Dropout概率
                dtype=self.dtype,  # 数据类型
            )
        ]

        attentions = []  # 初始化注意力块列表

        for _ in range(self.num_layers):  # 遍历指定的注意力层数
            attn_block = FlaxTransformer2DModel(  # 创建一个Transformer块
                in_channels=self.in_channels,  # 输入通道数
                n_heads=self.num_attention_heads,  # 注意力头数量
                d_head=self.in_channels // self.num_attention_heads,  # 每个头的维度
                depth=self.transformer_layers_per_block,  # 变换层深度
                use_linear_projection=self.use_linear_projection,  # 是否使用线性投影
                use_memory_efficient_attention=self.use_memory_efficient_attention,  # 是否使用内存高效注意力
                split_head_dim=self.split_head_dim,  # 是否分割头维度
                dtype=self.dtype,  # 数据类型
            )
            attentions.append(attn_block)  # 将注意力块添加到列表中

            res_block = FlaxResnetBlock2D(  # 创建一个ResNet块
                in_channels=self.in_channels,  # 输入通道数
                out_channels=self.in_channels,  # 输出通道数
                dropout_prob=self.dropout,  # Dropout概率
                dtype=self.dtype,  # 数据类型
            )
            resnets.append(res_block)  # 将ResNet块添加到列表中

        self.resnets = resnets  # 将ResNet块列表赋值给实例属性
        self.attentions = attentions  # 将注意力块列表赋值给实例属性

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):  # 调用方法
        hidden_states = self.resnets[0](hidden_states, temb)  # 通过第一个ResNet块处理隐藏状态
        for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍历每个注意力块和后续ResNet块
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)  # 处理隐藏状态
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)  # 再次处理隐藏状态

        return hidden_states  # 返回处理后的隐藏状态

.\diffusers\models\unets\unet_2d_condition.py

# 版权声明,标明版权信息和使用许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache License 2.0 版本进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 你不得在未遵守许可的情况下使用此文件
# you may not use this file except in compliance with the License.
# 你可以在以下网址获得许可证的副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,软件以“按原样”方式分发,不提供任何形式的担保或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定语言所适用的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 从 dataclasses 模块导入 dataclass 装饰器,用于简化类的定义
from dataclasses import dataclass
# 导入所需的类型注释
from typing import Any, Dict, List, Optional, Tuple, Union

# 导入 PyTorch 库和相关模块
import torch
import torch.nn as nn
import torch.utils.checkpoint

# 从配置和加载器模块中导入所需的类和函数
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入与注意力机制相关的处理器
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import (
    GaussianFourierProjection,  # 导入多种嵌入方法
    GLIGENTextBoundingboxProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from ..modeling_utils import ModelMixin  # 导入模型混合类
from .unet_2d_blocks import (
    get_down_block,  # 导入下采样块的构造函数
    get_mid_block,   # 导入中间块的构造函数
    get_up_block,    # 导入上采样块的构造函数
)

# 创建一个日志记录器,用于记录模型相关信息
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义 UNet2DConditionOutput 数据类,用于存储 UNet2DConditionModel 的输出
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    UNet2DConditionModel 的输出。

    参数:
        sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`):
            基于 `encoder_hidden_states` 输入的隐藏状态输出,模型最后一层的输出。
    """

    sample: torch.Tensor = None  # 定义一个样本属性,默认为 None

# 定义 UNet2DConditionModel 类,表示一个条件 2D UNet 模型
class UNet2DConditionModel(
    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
    r"""
    一个条件 2D UNet 模型,接受一个噪声样本、条件状态和时间步,并返回样本形状的输出。

    该模型继承自 [`ModelMixin`]。查看超类文档以获取其为所有模型实现的通用方法
    (例如下载或保存)。
    """

    _supports_gradient_checkpointing = True  # 表示该模型支持梯度检查点
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]  # 不进行拆分的模块列表

    @register_to_config  # 将该方法注册到配置中
    # 初始化方法,设置类的基本属性
        def __init__(
            # 样本大小,默认为 None
            self,
            sample_size: Optional[int] = None,
            # 输入通道数,默认为 4
            in_channels: int = 4,
            # 输出通道数,默认为 4
            out_channels: int = 4,
            # 是否将输入样本中心化,默认为 False
            center_input_sample: bool = False,
            # 是否将正弦函数翻转为余弦函数,默认为 True
            flip_sin_to_cos: bool = True,
            # 频率偏移量,默认为 0
            freq_shift: int = 0,
            # 向下采样的块类型,包含多种块类型
            down_block_types: Tuple[str] = (
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            # 中间块的类型,默认为 UNet 的中间块类型
            mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
            # 向上采样的块类型,包含多种块类型
            up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
            # 是否仅使用交叉注意力,默认为 False
            only_cross_attention: Union[bool, Tuple[bool]] = False,
            # 每个块的输出通道数
            block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
            # 每个块的层数,默认为 2
            layers_per_block: Union[int, Tuple[int]] = 2,
            # 下采样时的填充大小,默认为 1
            downsample_padding: int = 1,
            # 中间块的缩放因子,默认为 1
            mid_block_scale_factor: float = 1,
            # dropout 概率,默认为 0.0
            dropout: float = 0.0,
            # 激活函数类型,默认为 "silu"
            act_fn: str = "silu",
            # 归一化的组数,默认为 32
            norm_num_groups: Optional[int] = 32,
            # 归一化的 epsilon 值,默认为 1e-5
            norm_eps: float = 1e-5,
            # 交叉注意力的维度,默认为 1280
            cross_attention_dim: Union[int, Tuple[int]] = 1280,
            # 每个块的变换层数,默认为 1
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
            # 反向变换层的块数,默认为 None
            reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
            # 编码器隐藏层的维度,默认为 None
            encoder_hid_dim: Optional[int] = None,
            # 编码器隐藏层类型,默认为 None
            encoder_hid_dim_type: Optional[str] = None,
            # 注意力头的维度,默认为 8
            attention_head_dim: Union[int, Tuple[int]] = 8,
            # 注意力头的数量,默认为 None
            num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
            # 是否使用双交叉注意力,默认为 False
            dual_cross_attention: bool = False,
            # 是否使用线性投影,默认为 False
            use_linear_projection: bool = False,
            # 类嵌入类型,默认为 None
            class_embed_type: Optional[str] = None,
            # 附加嵌入类型,默认为 None
            addition_embed_type: Optional[str] = None,
            # 附加时间嵌入维度,默认为 None
            addition_time_embed_dim: Optional[int] = None,
            # 类嵌入数量,默认为 None
            num_class_embeds: Optional[int] = None,
            # 是否上溯注意力,默认为 False
            upcast_attention: bool = False,
            # ResNet 时间缩放偏移类型,默认为 "default"
            resnet_time_scale_shift: str = "default",
            # ResNet 是否跳过时间激活,默认为 False
            resnet_skip_time_act: bool = False,
            # ResNet 输出缩放因子,默认为 1.0
            resnet_out_scale_factor: float = 1.0,
            # 时间嵌入类型,默认为 "positional"
            time_embedding_type: str = "positional",
            # 时间嵌入维度,默认为 None
            time_embedding_dim: Optional[int] = None,
            # 时间嵌入激活函数,默认为 None
            time_embedding_act_fn: Optional[str] = None,
            # 时间步后激活函数,默认为 None
            timestep_post_act: Optional[str] = None,
            # 时间条件投影维度,默认为 None
            time_cond_proj_dim: Optional[int] = None,
            # 输入卷积核大小,默认为 3
            conv_in_kernel: int = 3,
            # 输出卷积核大小,默认为 3
            conv_out_kernel: int = 3,
            # 投影类嵌入输入维度,默认为 None
            projection_class_embeddings_input_dim: Optional[int] = None,
            # 注意力类型,默认为 "default"
            attention_type: str = "default",
            # 类嵌入是否拼接,默认为 False
            class_embeddings_concat: bool = False,
            # 中间块是否仅使用交叉注意力,默认为 None
            mid_block_only_cross_attention: Optional[bool] = None,
            # 交叉注意力归一化类型,默认为 None
            cross_attention_norm: Optional[str] = None,
            # 附加嵌入类型的头数量,默认为 64
            addition_embed_type_num_heads: int = 64,
    # 定义一个私有方法,用于检查配置参数
        def _check_config(
            self,
            # 定义下行块类型的元组,表示模型的结构
            down_block_types: Tuple[str],
            # 定义上行块类型的元组,表示模型的结构
            up_block_types: Tuple[str],
            # 定义仅使用交叉注意力的标志,可以是布尔值或布尔值的元组
            only_cross_attention: Union[bool, Tuple[bool]],
            # 定义每个块的输出通道数的元组,表示层的宽度
            block_out_channels: Tuple[int],
            # 定义每个块的层数,可以是整数或整数的元组
            layers_per_block: Union[int, Tuple[int]],
            # 定义交叉注意力维度,可以是整数或整数的元组
            cross_attention_dim: Union[int, Tuple[int]],
            # 定义每个块的变换器层数,可以是整数、整数的元组或元组的元组
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
            # 定义是否反转变换器层的布尔值
            reverse_transformer_layers_per_block: bool,
            # 定义注意力头的维度,表示注意力的分辨率
            attention_head_dim: int,
            # 定义注意力头的数量,可以是可选的整数或整数的元组
            num_attention_heads: Optional[Union[int, Tuple[int]],
    ):
        # 检查 down_block_types 和 up_block_types 的长度是否相同
        if len(down_block_types) != len(up_block_types):
            # 如果不同,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        # 检查 block_out_channels 和 down_block_types 的长度是否相同
        if len(block_out_channels) != len(down_block_types):
            # 如果不同,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # 检查 only_cross_attention 是否为布尔值且长度与 down_block_types 相同
        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        # 检查 num_attention_heads 是否为整数且长度与 down_block_types 相同
        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # 检查 attention_head_dim 是否为整数且长度与 down_block_types 相同
        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # 检查 cross_attention_dim 是否为列表且长度与 down_block_types 相同
        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        # 检查 layers_per_block 是否为整数且长度与 down_block_types 相同
        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )
        # 检查 transformer_layers_per_block 是否为列表且 reverse_transformer_layers_per_block 为 None
        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
            # 遍历 transformer_layers_per_block 中的每个层
            for layer_number_per_block in transformer_layers_per_block:
                # 检查每个层是否为列表
                if isinstance(layer_number_per_block, list):
                    # 如果是,则抛出值错误,提示需要提供 reverse_transformer_layers_per_block
                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

    # 定义设置时间投影的私有方法
    def _set_time_proj(
        self,
        # 时间嵌入类型
        time_embedding_type: str,
        # 块输出通道数
        block_out_channels: int,
        # 是否翻转正弦和余弦
        flip_sin_to_cos: bool,
        # 频率偏移
        freq_shift: float,
        # 时间嵌入维度
        time_embedding_dim: int,
    # 返回时间嵌入维度和时间步输入维度的元组
    ) -> Tuple[int, int]:
        # 判断时间嵌入类型是否为傅里叶
        if time_embedding_type == "fourier":
            # 计算时间嵌入维度,默认为 block_out_channels[0] * 2
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            # 确保时间嵌入维度为偶数
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            # 初始化高斯傅里叶投影,设定相关参数
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            # 设置时间步输入维度为时间嵌入维度
            timestep_input_dim = time_embed_dim
        # 判断时间嵌入类型是否为位置编码
        elif time_embedding_type == "positional":
            # 计算时间嵌入维度,默认为 block_out_channels[0] * 4
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
            # 初始化时间步对象,设定相关参数
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            # 设置时间步输入维度为 block_out_channels[0]
            timestep_input_dim = block_out_channels[0]
        # 如果时间嵌入类型不合法,抛出错误
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )
    
        # 返回时间嵌入维度和时间步输入维度
        return time_embed_dim, timestep_input_dim
    
    # 定义设置编码器隐藏投影的方法
    def _set_encoder_hid_proj(
        self,
        encoder_hid_dim_type: Optional[str],
        cross_attention_dim: Union[int, Tuple[int]],
        encoder_hid_dim: Optional[int],
    ):
        # 如果编码器隐藏维度类型为空且隐藏维度已定义
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            # 默认将编码器隐藏维度类型设为'text_proj'
            encoder_hid_dim_type = "text_proj"
            # 注册编码器隐藏维度类型到配置中
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            # 记录信息日志
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
    
        # 如果编码器隐藏维度为空且隐藏维度类型已定义,抛出错误
        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )
    
        # 判断编码器隐藏维度类型是否为'text_proj'
        if encoder_hid_dim_type == "text_proj":
            # 初始化线性投影层,输入维度为encoder_hid_dim,输出维度为cross_attention_dim
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        # 判断编码器隐藏维度类型是否为'text_image_proj'
        elif encoder_hid_dim_type == "text_image_proj":
            # 初始化文本-图像投影对象,设定相关参数
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        # 判断编码器隐藏维度类型是否为'image_proj'
        elif encoder_hid_dim_type == "image_proj":
            # 初始化图像投影对象,设定相关参数
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        # 如果编码器隐藏维度类型不合法,抛出错误
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        # 如果都不符合,将编码器隐藏投影设为None
        else:
            self.encoder_hid_proj = None
    # 设置类嵌入的私有方法
        def _set_class_embedding(
            self,
            class_embed_type: Optional[str],  # 嵌入类型,可能为 None 或特定字符串
            act_fn: str,  # 激活函数的名称
            num_class_embeds: Optional[int],  # 类嵌入数量,可能为 None
            projection_class_embeddings_input_dim: Optional[int],  # 投影类嵌入输入维度,可能为 None
            time_embed_dim: int,  # 时间嵌入的维度
            timestep_input_dim: int,  # 时间步输入的维度
        ):
            # 如果嵌入类型为 None 且类嵌入数量不为 None
            if class_embed_type is None and num_class_embeds is not None:
                # 创建嵌入层,大小为类嵌入数量和时间嵌入维度
                self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
            # 如果嵌入类型为 "timestep"
            elif class_embed_type == "timestep":
                # 创建时间步嵌入对象
                self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
            # 如果嵌入类型为 "identity"
            elif class_embed_type == "identity":
                # 创建恒等层,输入和输出维度相同
                self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
            # 如果嵌入类型为 "projection"
            elif class_embed_type == "projection":
                # 如果投影类嵌入输入维度为 None,抛出错误
                if projection_class_embeddings_input_dim is None:
                    raise ValueError(
                        "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                    )
                # 创建投影时间步嵌入对象
                self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
            # 如果嵌入类型为 "simple_projection"
            elif class_embed_type == "simple_projection":
                # 如果投影类嵌入输入维度为 None,抛出错误
                if projection_class_embeddings_input_dim is None:
                    raise ValueError(
                        "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                    )
                # 创建线性层作为简单投影
                self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
            # 如果没有匹配的嵌入类型
            else:
                # 将类嵌入设置为 None
                self.class_embedding = None
    
        # 设置附加嵌入的私有方法
        def _set_add_embedding(
            self,
            addition_embed_type: str,  # 附加嵌入类型
            addition_embed_type_num_heads: int,  # 附加嵌入类型的头数
            addition_time_embed_dim: Optional[int],  # 附加时间嵌入维度,可能为 None
            flip_sin_to_cos: bool,  # 是否翻转正弦到余弦
            freq_shift: float,  # 频率偏移量
            cross_attention_dim: Optional[int],  # 交叉注意力维度,可能为 None
            encoder_hid_dim: Optional[int],  # 编码器隐藏维度,可能为 None
            projection_class_embeddings_input_dim: Optional[int],  # 投影类嵌入输入维度,可能为 None
            time_embed_dim: int,  # 时间嵌入维度
    ):
        # 检查附加嵌入类型是否为 "text"
        if addition_embed_type == "text":
            # 如果编码器隐藏维度不为 None,则使用该维度
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            # 否则使用交叉注意力维度
            else:
                text_time_embedding_from_dim = cross_attention_dim

            # 创建文本时间嵌入对象
            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        # 检查附加嵌入类型是否为 "text_image"
        elif addition_embed_type == "text_image":
            # text_embed_dim 和 image_embed_dim 不必是 `cross_attention_dim`,为了避免 __init__ 过于繁杂
            # 在这里设置为 `cross_attention_dim`,因为这是当前唯一使用情况的所需维度 (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        # 检查附加嵌入类型是否为 "text_time"
        elif addition_embed_type == "text_time":
            # 创建时间投影对象
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            # 创建时间嵌入对象
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        # 检查附加嵌入类型是否为 "image"
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            # 创建图像时间嵌入对象
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        # 检查附加嵌入类型是否为 "image_hint"
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            # 创建图像提示时间嵌入对象
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        # 检查附加嵌入类型是否为 None 以外的值
        elif addition_embed_type is not None:
            # 抛出值错误,提示无效的附加嵌入类型
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

    # 定义一个属性方法,用于设置位置网络
    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
        # 检查注意力类型是否为 "gated" 或 "gated-text-image"
        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768  # 默认的正向长度
            # 如果交叉注意力维度是整数,则使用该值
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            # 如果交叉注意力维度是列表或元组,则使用第一个值
            elif isinstance(cross_attention_dim, (list, tuple)):
                positive_len = cross_attention_dim[0]

            # 根据注意力类型确定特征类型
            feature_type = "text-only" if attention_type == "gated" else "text-image"
            # 创建 GLIGEN 文本边界框投影对象
            self.position_net = GLIGENTextBoundingboxProjection(
                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
            )

    # 定义一个属性
    @property
    # 定义一个方法,返回一个字典,包含模型中所有的注意力处理器
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 初始化一个空字典,用于存储注意力处理器
        processors = {}

        # 定义一个递归函数,用于添加处理器到字典
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否有获取处理器的方法
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,键为名称,值为处理器
                processors[f"{name}.processor"] = module.get_processor()

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,处理子模块
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回更新后的处理器字典
            return processors

        # 遍历当前模块的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,添加处理器
            fn_recursive_add_processors(name, module, processors)

        # 返回包含所有处理器的字典
        return processors

    # 定义一个方法,设置用于计算注意力的处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        # 获取当前处理器的数量
        count = len(self.attn_processors.keys())

        # 如果传入的是字典,且字典长度与注意力层数量不匹配,抛出错误
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # 定义一个递归函数,用于设置处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中弹出对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,设置子模块的处理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍历当前模块的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,设置处理器
            fn_recursive_attn_processor(name, module, processor)
    # 定义设置默认注意力处理器的方法
    def set_default_attn_processor(self):
        """
        禁用自定义注意力处理器并设置默认的注意力实现。
        """
        # 检查所有注意力处理器是否属于添加的键值注意力处理器类
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建添加键值注意力处理器的实例
            processor = AttnAddedKVProcessor()
        # 检查所有注意力处理器是否属于交叉注意力处理器类
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建标准注意力处理器的实例
            processor = AttnProcessor()
        else:
            # 如果注意力处理器类型不匹配,则抛出错误
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 设置选定的注意力处理器
        self.set_attn_processor(processor)

    # 定义设置梯度检查点的方法
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模块具有梯度检查点属性,则设置其值
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # 定义启用 FreeU 机制的方法
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""启用 FreeU 机制,详细信息请见 https://arxiv.org/abs/2309.11497。

        在缩放因子后面的后缀表示它们被应用的阶段块。

        请参考 [官方仓库](https://github.com/ChenyangSi/FreeU) 以获取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中效果良好的值组合。

        参数:
            s1 (`float`):
                阶段 1 的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            s2 (`float`):
                阶段 2 的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            b1 (`float`): 阶段 1 的缩放因子,用于增强骨干特征的贡献。
            b2 (`float`): 阶段 2 的缩放因子,用于增强骨干特征的贡献。
        """
        # 遍历上采样块并设置相应的缩放因子
        for i, upsample_block in enumerate(self.up_blocks):
            setattr(upsample_block, "s1", s1)  # 设置阶段 1 的缩放因子
            setattr(upsample_block, "s2", s2)  # 设置阶段 2 的缩放因子
            setattr(upsample_block, "b1", b1)  # 设置阶段 1 的骨干缩放因子
            setattr(upsample_block, "b2", b2)  # 设置阶段 2 的骨干缩放因子

    # 定义禁用 FreeU 机制的方法
    def disable_freeu(self):
        """禁用 FreeU 机制。"""
        freeu_keys = {"s1", "s2", "b1", "b2"}  # 定义 FreeU 相关的键
        # 遍历上采样块
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍历每个 FreeU 键
            for k in freeu_keys:
                # 如果上采样块具有该键的属性或其值不为 None,则将其值设置为 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    setattr(upsample_block, k, None)
    # 定义一个方法,用于启用融合的 QKV 投影
    def fuse_qkv_projections(self):
        """
        启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。
        对于交叉注意力模块,键和值的投影矩阵被融合。

        <Tip warning={true}>

        此 API 是 

标签:states,int,self,attention,diffusers,channels,源码,hidden,解析
From: https://www.cnblogs.com/apachecn/p/18492383

相关文章

  • diffusers-源码解析-十三-
    diffusers源码解析(十三).\diffusers\models\unets\unet_2d.py#版权声明,表示该代码由HuggingFace团队所有##根据Apache2.0许可证进行许可;#除非遵循许可证,否则不得使用此文件。#可以在以下地址获取许可证的副本:##http://www.apache.org/licenses/LICENSE-2.......
  • diffusers-源码解析-五-
    diffusers源码解析(五).\diffusers\models\autoencoders\autoencoder_asym_kl.py#版权声明,标识该文件的所有权和使用条款#Copyright2024TheHuggingFaceTeam.Allrightsreserved.##根据Apache许可证第2.0版(“许可证”)进行授权;#除非遵循许可证,否则您不得使用此文......
  • diffusers-源码解析-四-
    diffusers源码解析(四).\diffusers\models\attention_flax.py#版权声明,表明该代码的版权归HuggingFace团队所有#根据Apache2.0许可证授权使用该文件,未遵守许可证不得使用#许可证获取链接#指出该软件是以“现状”分发,不附带任何明示或暗示的保证#具体的权限和限制请......
  • C语言使用指针作为函数参数,并利用函数嵌套求输入三个整数,将它们按大到小的顺序输出。(
    输入三个整数,要求从大到小的顺序向他们输出,用函数实现。   本代码使用到了指针和函数嵌套。   调用指针做函数ex,并嵌套调用指针函数exx在函数ex中。(代码在下面哦!)一、关于函数 ex  1. 这个函数接受三个指针参数 int*p1 、 int*p2 和 int*p3 ,分别指......
  • 【QT速成】半小时入门QT6之QT前置知识扫盲(超详细QT工程解析)
    目录一.QT工程介绍1.创建工程ModelDefineBuildSystemClassInformation BaseclassKit二.工程构成三.类汇总一.QT工程介绍1.创建工程Model    QT创建工程时首先会让我们选择项目模板,对应的英文解释很详尽,这里我们也可做一下简单介绍。应用程序(A......
  • 【关注可白嫖源码】计算机等级考试在线刷题小程序,不会的看过来
    设计一个计算机等级考试在线刷题小程序,需要确保系统能够提供高效的刷题功能,帮助用户随时随地练习。以下是系统的设计思路:一、系统设计总体思路该小程序需要包含用户端、题库管理系统、后台管理系统三大部分。用户可以通过小程序在线刷题、查看答案解析、查看个人练习情况,而......
  • 【可白嫖源码】基于SSM的在线点餐系统(案例分析)
    摘  要   当前高速发展的经济模式下,人们工作和生活都处于高压下,没时间做饭,在哪做饭成了人们的难题,传统下班回家做饭的生活习俗渐渐地变得难以实现。在社会驱动下,我国在餐饮方面的收入额,逐年成上升趋势。餐饮方面带来的收入拉高了社会消费品的零售总额。不得不说,餐饮......
  • 深入理解华为鸿蒙的 Context —— 应用上下文解析
    本文旨在深入探讨华为鸿蒙HarmonyOSNext系统(截止目前API12)的技术细节,基于实际开发实践进行总结。主要作为技术分享与交流载体,难免错漏,欢迎各位同仁提出宝贵意见和问题,以便共同进步。本文为原创内容,任何形式的转载必须注明出处及原作者。在华为鸿蒙(HarmonyOS)开发中,Context是......
  • 华为鸿蒙Next:应用启动框架AppStartup的解析与实战应用
    本文旨在深入探讨华为鸿蒙HarmonyOSNext系统(截止目前API12)的技术细节,基于实际开发实践进行总结。主要作为技术分享与交流载体,难免错漏,欢迎各位同仁提出宝贵意见和问题,以便共同进步。本文为原创内容,任何形式的转载必须注明出处及原作者。在华为鸿蒙(HarmonyOS)开发领域,应用的启......
  • 深入解析Apache DolphinScheduler容错机制
    简述ApacheDolphinschedulerMaster和Worker都是支持多节点部署,无中心化的设计。Master主要负责是流程DAG的切分,最终通过RPC将任务分发到Worker节点上以及Worker上任务状态的处理Worker主要负责是真正任务的执行,最后将任务状态汇报给Master,Master进行状态处理那问题来了:M......