首页 > 编程语言 >diffusers-源码解析-十五-

diffusers-源码解析-十五-

时间:2024-10-22 12:33:33浏览次数:1  
标签:dim None int self attention diffusers 源码 hidden 解析

diffusers 源码解析(十五)

.\diffusers\models\unets\unet_3d_condition.py

# 版权声明,声明此代码的版权信息和所有权
# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# 版权声明,声明此代码的版权信息和所有权
# Copyright 2024 The ModelScope Team.
#
# 许可声明,声明本代码使用的 Apache 许可证 2.0 版本
# Licensed under the Apache License, Version 2.0 (the "License");
# 使用此文件前需遵守许可证规定
# you may not use this file except in compliance with the License.
# 可在以下网址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 免责声明,说明软件在许可下按 "原样" 提供,不附加任何明示或暗示的保证
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 许可证中规定的权限和限制说明
# See the License for the specific language governing permissions and
# limitations under the License.

# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入所需的类型提示
from typing import Any, Dict, List, Optional, Tuple, Union

# 导入 PyTorch 库
import torch
# 导入 PyTorch 神经网络模块
import torch.nn as nn
# 导入 PyTorch 的检查点工具
import torch.utils.checkpoint

# 导入配置相关的工具类和函数
from ...configuration_utils import ConfigMixin, register_to_config
# 导入 UNet2D 条件加载器混合类
from ...loaders import UNet2DConditionLoadersMixin
# 导入基本输出类和日志工具
from ...utils import BaseOutput, logging
# 导入激活函数获取工具
from ..activations import get_activation
# 导入各种注意力处理器相关组件
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入添加键值对注意力处理器
    CROSS_ATTENTION_PROCESSORS,      # 导入交叉注意力处理器
    Attention,                       # 导入注意力类
    AttentionProcessor,              # 导入注意力处理器基类
    AttnAddedKVProcessor,            # 导入添加键值对的注意力处理器
    AttnProcessor,                   # 导入普通注意力处理器
    FusedAttnProcessor2_0,           # 导入融合注意力处理器
)
# 导入时间步嵌入和时间步类
from ..embeddings import TimestepEmbedding, Timesteps
# 导入模型混合类
from ..modeling_utils import ModelMixin
# 导入时间变换器模型
from ..transformers.transformer_temporal import TransformerTemporalModel
# 导入 3D UNet 相关的块
from .unet_3d_blocks import (
    CrossAttnDownBlock3D,          # 导入交叉注意力下采样块
    CrossAttnUpBlock3D,            # 导入交叉注意力上采样块
    DownBlock3D,                   # 导入下采样块
    UNetMidBlock3DCrossAttn,      # 导入 UNet 中间交叉注意力块
    UpBlock3D,                     # 导入上采样块
    get_down_block,                # 导入获取下采样块的函数
    get_up_block,                  # 导入获取上采样块的函数
)

# 创建日志记录器,使用当前模块的名称
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义 UNet3DConditionOutput 数据类,继承自 BaseOutput
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    [`UNet3DConditionModel`] 的输出类。

    参数:
        sample (`torch.Tensor` 的形状为 `(batch_size, num_channels, num_frames, height, width)`):
            基于 `encoder_hidden_states` 输入的隐藏状态输出。模型最后一层的输出。
    """

    sample: torch.Tensor  # 定义样本输出,类型为 PyTorch 张量

# 定义 UNet3DConditionModel 类,继承自多个混合类
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    条件 3D UNet 模型,接受噪声样本、条件状态和时间步,并返回形状为样本的输出。

    此模型继承自 [`ModelMixin`]。有关其通用方法的文档,请参阅超类文档(如下载或保存)。
    # 参数说明部分
    Parameters:
        # 输入/输出样本的高度和宽度,类型可以为整数或元组,默认为 None
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        # 输入样本的通道数,默认为 4
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        # 输出的通道数,默认为 4
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        # 使用的下采样块类型的元组,默认为指定的四种块
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        # 使用的上采样块类型的元组,默认为指定的四种块
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        # 每个块的输出通道数的元组,默认为 (320, 640, 1280, 1280)
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        # 每个块的层数,默认为 2
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        # 下采样卷积使用的填充,默认为 1
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        # 中间块使用的缩放因子,默认为 1.0
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        # 使用的激活函数,默认为 "silu"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        # 用于归一化的组数,默认为 32;如果为 None,则跳过归一化和激活层
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        # 归一化使用的 epsilon 值,默认为 1e-5
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        # 交叉注意力特征的维度,默认为 1024
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        # 注意力头的维度,默认为 64
        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
        # 注意力头的数量,类型为整数,默认为 None
        num_attention_heads (`int`, *optional*): The number of attention heads.
        # 时间条件投影层的维度,默认为 None
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
    """

    # 是否支持梯度检查点,默认为 False
    _supports_gradient_checkpointing = False

    # 将此类注册到配置中
    @register_to_config
    # 初始化方法,用于创建类的实例
        def __init__(
            # 样本大小,默认为 None
            self,
            sample_size: Optional[int] = None,
            # 输入通道数量,默认为 4
            in_channels: int = 4,
            # 输出通道数量,默认为 4
            out_channels: int = 4,
            # 下采样块类型的元组,定义模型的下采样结构
            down_block_types: Tuple[str, ...] = (
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "DownBlock3D",
            ),
            # 上采样块类型的元组,定义模型的上采样结构
            up_block_types: Tuple[str, ...] = (
                "UpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
            ),
            # 每个块的输出通道数量,定义模型每个层的通道设置
            block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
            # 每个块的层数,默认为 2
            layers_per_block: int = 2,
            # 下采样时的填充大小,默认为 1
            downsample_padding: int = 1,
            # 中间块的缩放因子,默认为 1
            mid_block_scale_factor: float = 1,
            # 激活函数类型,默认为 "silu"
            act_fn: str = "silu",
            # 归一化组的数量,默认为 32
            norm_num_groups: Optional[int] = 32,
            # 归一化的 epsilon 值,默认为 1e-5
            norm_eps: float = 1e-5,
            # 跨注意力维度,默认为 1024
            cross_attention_dim: int = 1024,
            # 注意力头的维度,可以是单一整数或整数元组,默认为 64
            attention_head_dim: Union[int, Tuple[int]] = 64,
            # 注意力头的数量,可选参数
            num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
            # 时间条件投影维度,可选参数
            time_cond_proj_dim: Optional[int] = None,
        @property
        # 从 UNet2DConditionModel 复制的属性,获取注意力处理器
        # 返回所有注意力处理器的字典,以权重名称为索引
        def attn_processors(self) -> Dict[str, AttentionProcessor]:
            r"""
            Returns:
                `dict` of attention processors: A dictionary containing all attention processors used in the model with
                indexed by its weight name.
            """
            # 初始化处理器字典
            processors = {}
    
            # 递归添加处理器的函数
            def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
                # 如果模块有获取处理器的方法,添加到处理器字典中
                if hasattr(module, "get_processor"):
                    processors[f"{name}.processor"] = module.get_processor()
    
                # 遍历子模块,递归调用该函数
                for sub_name, child in module.named_children():
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
                # 返回处理器字典
                return processors
    
            # 遍历当前类的子模块,调用递归添加处理器的函数
            for name, module in self.named_children():
                fn_recursive_add_processors(name, module, processors)
    
            # 返回所有处理器
            return processors
    
        # 从 UNet2DConditionModel 复制的设置注意力切片的方法
        # 从 UNet2DConditionModel 复制的设置注意力处理器的方法
    # 定义一个方法用于设置注意力处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。
    
        参数:
            processor(`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或一个处理器类的字典,将作为所有 `Attention` 层的处理器。
    
                如果 `processor` 是一个字典,键需要定义相应的交叉注意力处理器的路径。
                在设置可训练的注意力处理器时,强烈推荐这样做。
    
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 如果传入的处理器是字典,且数量不等于注意力层数量,抛出错误
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了一个处理器字典,但处理器的数量 {len(processor)} 与"
                f" 注意力层的数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
            )
    
        # 定义一个递归函数来设置每个模块的处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置处理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中获取相应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历子模块并递归调用处理器设置
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的所有子模块,并调用递归设置函数
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    # 定义一个方法来启用前馈层的分块处理
        def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
            """
            设置注意力处理器以使用 [前馈分块](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
    
            参数:
                chunk_size (`int`, *可选*):
                    前馈层的分块大小。如果未指定,将对维度为`dim`的每个张量单独运行前馈层。
                dim (`int`, *可选*, 默认为`0`):
                    应对哪个维度进行前馈计算的分块。可以选择 dim=0(批次)或 dim=1(序列长度)。
            """
            # 确保 dim 参数为 0 或 1
            if dim not in [0, 1]:
                raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
    
            # 默认的分块大小为 1
            chunk_size = chunk_size or 1
    
            # 定义一个递归函数来设置每个模块的分块前馈处理
            def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
                # 如果模块具有设置分块前馈的属性,则设置它
                if hasattr(module, "set_chunk_feed_forward"):
                    module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
                # 遍历子模块,递归调用函数
                for child in module.children():
                    fn_recursive_feed_forward(child, chunk_size, dim)
    
            # 遍历当前实例的子模块,应用递归函数
            for module in self.children():
                fn_recursive_feed_forward(module, chunk_size, dim)
    
        # 定义一个方法来禁用前馈层的分块处理
        def disable_forward_chunking(self):
            # 定义一个递归函数来禁用分块前馈处理
            def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
                # 如果模块具有设置分块前馈的属性,则设置为 None
                if hasattr(module, "set_chunk_feed_forward"):
                    module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
                # 遍历子模块,递归调用函数
                for child in module.children():
                    fn_recursive_feed_forward(child, chunk_size, dim)
    
            # 遍历当前实例的子模块,应用递归函数,禁用分块
            for module in self.children():
                fn_recursive_feed_forward(module, None, 0)
    
        # 从 diffusers.models.unets.unet_2d_condition 中复制的方法,设置默认注意力处理器
        def set_default_attn_processor(self):
            """
            禁用自定义注意力处理器并设置默认注意力实现。
            """
            # 检查所有注意力处理器是否为添加的 KV 处理器
            if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnAddedKVProcessor()  # 设置为添加的 KV 处理器
            # 检查所有注意力处理器是否为交叉注意力处理器
            elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnProcessor()  # 设置为普通注意力处理器
            else:
                # 抛出异常,若注意力处理器类型不符合预期
                raise ValueError(
                    f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
                )
    
            # 设置选定的注意力处理器
            self.set_attn_processor(processor)
    
        # 定义一个私有方法来设置模块的梯度检查点
        def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
            # 检查模块是否属于特定类型
            if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
                module.gradient_checkpointing = value  # 设置梯度检查点值
    
        # 从 diffusers.models.unets.unet_2d_condition 中复制的方法,启用自由度
    # 启用 FreeU 机制,参数为两个缩放因子和两个增强因子的值
    def enable_freeu(self, s1, s2, b1, b2):
        r"""从 https://arxiv.org/abs/2309.11497 启用 FreeU 机制。

        缩放因子的后缀表示它们应用的阶段块。

        请参考 [官方仓库](https://github.com/ChenyangSi/FreeU) 以获取在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中已知效果良好的值组合。

        Args:
            s1 (`float`):
                第1阶段的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            s2 (`float`):
                第2阶段的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            b1 (`float`): 第1阶段的缩放因子,用于增强骨干特征的贡献。
            b2 (`float`): 第2阶段的缩放因子,用于增强骨干特征的贡献。
        """
        # 遍历上采样块,给每个块设置缩放因子和增强因子
        for i, upsample_block in enumerate(self.up_blocks):
            # 设置第1阶段的缩放因子
            setattr(upsample_block, "s1", s1)
            # 设置第2阶段的缩放因子
            setattr(upsample_block, "s2", s2)
            # 设置第1阶段的增强因子
            setattr(upsample_block, "b1", b1)
            # 设置第2阶段的增强因子
            setattr(upsample_block, "b2", b2)

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu 复制
    # 禁用 FreeU 机制
    def disable_freeu(self):
        """禁用 FreeU 机制。"""
        # 定义 FreeU 机制的关键属性
        freeu_keys = {"s1", "s2", "b1", "b2"}
        # 遍历上采样块
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍历 FreeU 关键属性
            for k in freeu_keys:
                # 如果上采样块有该属性,或者该属性值不为 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    # 将属性值设置为 None,禁用 FreeU
                    setattr(upsample_block, k, None)

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 复制
    # 启用融合的 QKV 投影
    def fuse_qkv_projections(self):
        """
        启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。对于交叉注意力模块,键和值投影矩阵被融合。

        <Tip warning={true}>

        此 API 是 

标签:dim,None,int,self,attention,diffusers,源码,hidden,解析
From: https://www.cnblogs.com/apachecn/p/18492384

相关文章

  • diffusers-源码解析-十四-
    diffusers源码解析(十四).\diffusers\models\unets\unet_2d_blocks_flax.py#版权声明,说明该文件的版权信息及相关许可协议#Copyright2024TheHuggingFaceTeam.Allrightsreserved.##许可信息,使用ApacheLicense2.0许可#LicensedundertheApacheLicense,Versi......
  • diffusers-源码解析-十三-
    diffusers源码解析(十三).\diffusers\models\unets\unet_2d.py#版权声明,表示该代码由HuggingFace团队所有##根据Apache2.0许可证进行许可;#除非遵循许可证,否则不得使用此文件。#可以在以下地址获取许可证的副本:##http://www.apache.org/licenses/LICENSE-2.......
  • diffusers-源码解析-五-
    diffusers源码解析(五).\diffusers\models\autoencoders\autoencoder_asym_kl.py#版权声明,标识该文件的所有权和使用条款#Copyright2024TheHuggingFaceTeam.Allrightsreserved.##根据Apache许可证第2.0版(“许可证”)进行授权;#除非遵循许可证,否则您不得使用此文......
  • diffusers-源码解析-四-
    diffusers源码解析(四).\diffusers\models\attention_flax.py#版权声明,表明该代码的版权归HuggingFace团队所有#根据Apache2.0许可证授权使用该文件,未遵守许可证不得使用#许可证获取链接#指出该软件是以“现状”分发,不附带任何明示或暗示的保证#具体的权限和限制请......
  • C语言使用指针作为函数参数,并利用函数嵌套求输入三个整数,将它们按大到小的顺序输出。(
    输入三个整数,要求从大到小的顺序向他们输出,用函数实现。   本代码使用到了指针和函数嵌套。   调用指针做函数ex,并嵌套调用指针函数exx在函数ex中。(代码在下面哦!)一、关于函数 ex  1. 这个函数接受三个指针参数 int*p1 、 int*p2 和 int*p3 ,分别指......
  • 【QT速成】半小时入门QT6之QT前置知识扫盲(超详细QT工程解析)
    目录一.QT工程介绍1.创建工程ModelDefineBuildSystemClassInformation BaseclassKit二.工程构成三.类汇总一.QT工程介绍1.创建工程Model    QT创建工程时首先会让我们选择项目模板,对应的英文解释很详尽,这里我们也可做一下简单介绍。应用程序(A......
  • 【关注可白嫖源码】计算机等级考试在线刷题小程序,不会的看过来
    设计一个计算机等级考试在线刷题小程序,需要确保系统能够提供高效的刷题功能,帮助用户随时随地练习。以下是系统的设计思路:一、系统设计总体思路该小程序需要包含用户端、题库管理系统、后台管理系统三大部分。用户可以通过小程序在线刷题、查看答案解析、查看个人练习情况,而......
  • 【可白嫖源码】基于SSM的在线点餐系统(案例分析)
    摘  要   当前高速发展的经济模式下,人们工作和生活都处于高压下,没时间做饭,在哪做饭成了人们的难题,传统下班回家做饭的生活习俗渐渐地变得难以实现。在社会驱动下,我国在餐饮方面的收入额,逐年成上升趋势。餐饮方面带来的收入拉高了社会消费品的零售总额。不得不说,餐饮......
  • 深入理解华为鸿蒙的 Context —— 应用上下文解析
    本文旨在深入探讨华为鸿蒙HarmonyOSNext系统(截止目前API12)的技术细节,基于实际开发实践进行总结。主要作为技术分享与交流载体,难免错漏,欢迎各位同仁提出宝贵意见和问题,以便共同进步。本文为原创内容,任何形式的转载必须注明出处及原作者。在华为鸿蒙(HarmonyOS)开发中,Context是......
  • 华为鸿蒙Next:应用启动框架AppStartup的解析与实战应用
    本文旨在深入探讨华为鸿蒙HarmonyOSNext系统(截止目前API12)的技术细节,基于实际开发实践进行总结。主要作为技术分享与交流载体,难免错漏,欢迎各位同仁提出宝贵意见和问题,以便共同进步。本文为原创内容,任何形式的转载必须注明出处及原作者。在华为鸿蒙(HarmonyOS)开发领域,应用的启......