diffusers 源码解析(十三)
.\diffusers\models\unets\unet_2d.py
# 版权声明,表示该代码由 HuggingFace 团队所有
#
# 根据 Apache 2.0 许可证进行许可;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下地址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件按 "原样" 分发,
# 不提供任何形式的保证或条件,无论是明示或暗示。
# 请参阅许可证以了解有关权限和
# 限制的具体信息。
from dataclasses import dataclass # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Optional, Tuple, Union # 从 typing 模块导入类型提示相关的类
import torch # 导入 PyTorch 库
import torch.nn as nn # 导入 PyTorch 的神经网络模块
from ...configuration_utils import ConfigMixin, register_to_config # 从配置工具导入配置混合类和注册函数
from ...utils import BaseOutput # 从工具模块导入基础输出类
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps # 从嵌入模块导入相关类
from ..modeling_utils import ModelMixin # 从建模工具导入模型混合类
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block # 从 UNet 2D 块导入相关组件
@dataclass
class UNet2DOutput(BaseOutput): # 定义 UNet2DOutput 类,继承自 BaseOutput
"""
[`UNet2DModel`] 的输出类。
参数:
sample (`torch.Tensor` 形状为 `(batch_size, num_channels, height, width)`):
从模型最后一层输出的隐藏状态。
"""
sample: torch.Tensor # 定义输出样本的属性,类型为 torch.Tensor
class UNet2DModel(ModelMixin, ConfigMixin): # 定义 UNet2DModel 类,继承自 ModelMixin 和 ConfigMixin
r"""
一个 2D UNet 模型,接收一个有噪声的样本和一个时间步,返回一个样本形状的输出。
此模型继承自 [`ModelMixin`]。请查看超类文档以了解其为所有模型实现的通用方法
(例如下载或保存)。
"""
@register_to_config # 使用装饰器将该方法注册到配置中
def __init__( # 定义初始化方法
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None, # 可选的样本大小,可以是整数或整数元组
in_channels: int = 3, # 输入通道数,默认为 3
out_channels: int = 3, # 输出通道数,默认为 3
center_input_sample: bool = False, # 是否将输入样本居中,默认为 False
time_embedding_type: str = "positional", # 时间嵌入类型,默认为 "positional"
freq_shift: int = 0, # 频率偏移量,默认为 0
flip_sin_to_cos: bool = True, # 是否将正弦函数翻转为余弦函数,默认为 True
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), # 下采样块类型
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), # 上采样块类型
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), # 各块输出通道数
layers_per_block: int = 2, # 每个块的层数,默认为 2
mid_block_scale_factor: float = 1, # 中间块缩放因子,默认为 1
downsample_padding: int = 1, # 下采样的填充大小,默认为 1
downsample_type: str = "conv", # 下采样类型,默认为卷积
upsample_type: str = "conv", # 上采样类型,默认为卷积
dropout: float = 0.0, # dropout 概率,默认为 0.0
act_fn: str = "silu", # 激活函数类型,默认为 "silu"
attention_head_dim: Optional[int] = 8, # 注意力头维度,默认为 8
norm_num_groups: int = 32, # 规范化组数量,默认为 32
attn_norm_num_groups: Optional[int] = None, # 注意力规范化组数量,可选
norm_eps: float = 1e-5, # 规范化的 epsilon 值,默认为 1e-5
resnet_time_scale_shift: str = "default", # ResNet 时间缩放偏移类型,默认为 "default"
add_attention: bool = True, # 是否添加注意力机制,默认为 True
class_embed_type: Optional[str] = None, # 类别嵌入类型,可选
num_class_embeds: Optional[int] = None, # 类别嵌入数量,可选
num_train_timesteps: Optional[int] = None, # 训练时间步数量,可选
# 定义一个名为 forward 的方法
def forward(
# 输入参数 sample,类型为 torch.Tensor,表示样本数据
self,
sample: torch.Tensor,
# 输入参数 timestep,可以是 torch.Tensor、float 或 int,表示时间步
timestep: Union[torch.Tensor, float, int],
# 可选参数 class_labels,类型为 torch.Tensor,表示分类标签,默认为 None
class_labels: Optional[torch.Tensor] = None,
# 可选参数 return_dict,类型为 bool,表示是否以字典形式返回结果,默认为 True
return_dict: bool = True,
.\diffusers\models\unets\unet_2d_blocks.py
# 版权所有 2024 HuggingFace 团队,保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,否则根据许可证分发的软件按“现状”基础分发,
# 不提供任何形式的保证或条件,无论是明示还是暗示。
# 请参阅许可证以了解特定语言的权限和
# 限制。
from typing import Any, Dict, Optional, Tuple, Union # 导入类型注解相关的模块
import numpy as np # 导入 NumPy 库用于数值计算
import torch # 导入 PyTorch 库用于深度学习
import torch.nn.functional as F # 导入 PyTorch 的功能性 API
from torch import nn # 导入 PyTorch 的神经网络模块
from ...utils import deprecate, is_torch_version, logging # 从工具模块导入日志和版本检测功能
from ...utils.torch_utils import apply_freeu # 从 PyTorch 工具模块导入 apply_freeu 函数
from ..activations import get_activation # 从激活函数模块导入 get_activation 函数
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 # 导入注意力处理器相关的类
from ..normalization import AdaGroupNorm # 从归一化模块导入 AdaGroupNorm 类
from ..resnet import ( # 从 ResNet 模块导入多个下采样和上采样类
Downsample2D,
FirDownsample2D,
FirUpsample2D,
KDownsample2D,
KUpsample2D,
ResnetBlock2D,
ResnetBlockCondNorm2D,
Upsample2D,
)
from ..transformers.dual_transformer_2d import DualTransformer2DModel # 导入双重变换器模型
from ..transformers.transformer_2d import Transformer2DModel # 导入二维变换器模型
logger = logging.get_logger(__name__) # 获取当前模块的日志记录器实例
def get_down_block( # 定义获取下采样块的函数
down_block_type: str, # 下采样块的类型
num_layers: int, # 下采样层的数量
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
add_downsample: bool, # 是否添加下采样标志
resnet_eps: float, # ResNet 中的 epsilon 参数
resnet_act_fn: str, # ResNet 使用的激活函数
transformer_layers_per_block: int = 1, # 每个块中的变换器层数,默认为 1
num_attention_heads: Optional[int] = None, # 注意力头的数量,默认为 None
resnet_groups: Optional[int] = None, # ResNet 中的组数,默认为 None
cross_attention_dim: Optional[int] = None, # 交叉注意力维度,默认为 None
downsample_padding: Optional[int] = None, # 下采样填充参数,默认为 None
dual_cross_attention: bool = False, # 是否使用双重交叉注意力标志
use_linear_projection: bool = False, # 是否使用线性投影标志
only_cross_attention: bool = False, # 是否仅使用交叉注意力标志
upcast_attention: bool = False, # 是否上升注意力标志
resnet_time_scale_shift: str = "default", # ResNet 时间缩放移位,默认为“default”
attention_type: str = "default", # 注意力类型,默认为“default”
resnet_skip_time_act: bool = False, # ResNet 跳过时间激活标志
resnet_out_scale_factor: float = 1.0, # ResNet 输出缩放因子,默认为 1.0
cross_attention_norm: Optional[str] = None, # 交叉注意力归一化,默认为 None
attention_head_dim: Optional[int] = None, # 注意力头维度,默认为 None
downsample_type: Optional[str] = None, # 下采样类型,默认为 None
dropout: float = 0.0, # dropout 比例,默认为 0.0
):
# 如果没有定义注意力头维度,默认设置为头的数量
if attention_head_dim is None:
logger.warning( # 记录警告信息
f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." # 提醒用户使用默认的注意力头维度
)
attention_head_dim = num_attention_heads # 将注意力头维度设置为头的数量
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type # 处理下采样块类型的字符串
# 检查下行块的类型是否为 "DownBlock2D"
if down_block_type == "DownBlock2D":
# 返回 DownBlock2D 实例,传入相关参数
return DownBlock2D(
# 传入层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 比率
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
# ResNet 的分组数
resnet_groups=resnet_groups,
# 下采样的填充
downsample_padding=downsample_padding,
# ResNet 的时间尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 检查下行块的类型是否为 "ResnetDownsampleBlock2D"
elif down_block_type == "ResnetDownsampleBlock2D":
# 返回 ResnetDownsampleBlock2D 实例,传入相关参数
return ResnetDownsampleBlock2D(
# 传入层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 比率
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
# ResNet 的分组数
resnet_groups=resnet_groups,
# ResNet 的时间尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# ResNet 的时间激活跳过标志
skip_time_act=resnet_skip_time_act,
# ResNet 的输出缩放因子
output_scale_factor=resnet_out_scale_factor,
)
# 检查下行块的类型是否为 "AttnDownBlock2D"
elif down_block_type == "AttnDownBlock2D":
# 如果不添加下采样,则将下采样类型设为 None
if add_downsample is False:
downsample_type = None
else:
# 如果添加下采样,则默认下采样类型为 'conv'
downsample_type = downsample_type or "conv" # default to 'conv'
# 返回 AttnDownBlock2D 实例,传入相关参数
return AttnDownBlock2D(
# 传入层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 比率
dropout=dropout,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
# ResNet 的分组数
resnet_groups=resnet_groups,
# 下采样的填充
downsample_padding=downsample_padding,
# 注意力头的维度
attention_head_dim=attention_head_dim,
# ResNet 的时间尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 下采样类型
downsample_type=downsample_type,
)
# 检查下行块类型是否为 CrossAttnDownBlock2D
elif down_block_type == "CrossAttnDownBlock2D":
# 如果 cross_attention_dim 未指定,则抛出错误
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
# 返回 CrossAttnDownBlock2D 实例,使用提供的参数进行初始化
return CrossAttnDownBlock2D(
# 设置层的数量
num_layers=num_layers,
# 每个块的变换层数量
transformer_layers_per_block=transformer_layers_per_block,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 比例
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 中的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 激活函数
resnet_act_fn=resnet_act_fn,
# ResNet 分组数
resnet_groups=resnet_groups,
# 下采样填充
downsample_padding=downsample_padding,
# 跨注意力维度
cross_attention_dim=cross_attention_dim,
# 注意力头数
num_attention_heads=num_attention_heads,
# 是否使用双向跨注意力
dual_cross_attention=dual_cross_attention,
# 是否使用线性投影
use_linear_projection=use_linear_projection,
# 是否仅使用跨注意力
only_cross_attention=only_cross_attention,
# 是否上溯注意力
upcast_attention=upcast_attention,
# ResNet 时间尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 注意力类型
attention_type=attention_type,
)
# 检查下行块类型是否为 SimpleCrossAttnDownBlock2D
elif down_block_type == "SimpleCrossAttnDownBlock2D":
# 如果 cross_attention_dim 未指定,则抛出错误
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
# 返回 SimpleCrossAttnDownBlock2D 实例,使用提供的参数进行初始化
return SimpleCrossAttnDownBlock2D(
# 设置层的数量
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 比例
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 中的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 激活函数
resnet_act_fn=resnet_act_fn,
# ResNet 分组数
resnet_groups=resnet_groups,
# 跨注意力维度
cross_attention_dim=cross_attention_dim,
# 注意力头的维度
attention_head_dim=attention_head_dim,
# ResNet 时间尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 是否跳过时间激活
skip_time_act=resnet_skip_time_act,
# 输出缩放因子
output_scale_factor=resnet_out_scale_factor,
# 是否仅使用跨注意力
only_cross_attention=only_cross_attention,
# 跨注意力规范化
cross_attention_norm=cross_attention_norm,
)
# 检查下行块类型是否为 SkipDownBlock2D
elif down_block_type == "SkipDownBlock2D":
# 返回 SkipDownBlock2D 实例,使用提供的参数进行初始化
return SkipDownBlock2D(
# 设置层的数量
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 比例
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 中的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 激活函数
resnet_act_fn=resnet_act_fn,
# 下采样填充
downsample_padding=downsample_padding,
# ResNet 时间尺度偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 检查下采样块的类型是否为 "AttnSkipDownBlock2D"
elif down_block_type == "AttnSkipDownBlock2D":
# 返回一个 AttnSkipDownBlock2D 对象,初始化参数传入
return AttnSkipDownBlock2D(
# 设置层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 率
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
# 注意力头的维度
attention_head_dim=attention_head_dim,
# ResNet 时间缩放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 检查下采样块的类型是否为 "DownEncoderBlock2D"
elif down_block_type == "DownEncoderBlock2D":
# 返回一个 DownEncoderBlock2D 对象,初始化参数传入
return DownEncoderBlock2D(
# 设置层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# dropout 率
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
# ResNet 的组数
resnet_groups=resnet_groups,
# 下采样填充
downsample_padding=downsample_padding,
# ResNet 时间缩放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 检查下采样块的类型是否为 "AttnDownEncoderBlock2D"
elif down_block_type == "AttnDownEncoderBlock2D":
# 返回一个 AttnDownEncoderBlock2D 对象,初始化参数传入
return AttnDownEncoderBlock2D(
# 设置层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# dropout 率
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
# ResNet 的组数
resnet_groups=resnet_groups,
# 下采样填充
downsample_padding=downsample_padding,
# 注意力头的维度
attention_head_dim=attention_head_dim,
# ResNet 时间缩放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
)
# 检查下采样块的类型是否为 "KDownBlock2D"
elif down_block_type == "KDownBlock2D":
# 返回一个 KDownBlock2D 对象,初始化参数传入
return KDownBlock2D(
# 设置层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 率
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
)
# 检查下采样块的类型是否为 "KCrossAttnDownBlock2D"
elif down_block_type == "KCrossAttnDownBlock2D":
# 返回一个 KCrossAttnDownBlock2D 对象,初始化参数传入
return KCrossAttnDownBlock2D(
# 设置层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# dropout 率
dropout=dropout,
# 是否添加下采样
add_downsample=add_downsample,
# ResNet 的 epsilon 值
resnet_eps=resnet_eps,
# ResNet 的激活函数
resnet_act_fn=resnet_act_fn,
# 跨注意力的维度
cross_attention_dim=cross_attention_dim,
# 注意力头的维度
attention_head_dim=attention_head_dim,
# 是否添加自注意力
add_self_attention=True if not add_downsample else False,
)
# 如果下采样块类型不匹配,则抛出异常
raise ValueError(f"{down_block_type} does not exist.")
# 根据给定参数生成中间块(mid block)
def get_mid_block(
# 中间块的类型
mid_block_type: str,
# 嵌入通道数
temb_channels: int,
# 输入通道数
in_channels: int,
# ResNet 的 epsilon 值
resnet_eps: float,
# ResNet 的激活函数类型
resnet_act_fn: str,
# ResNet 的组数
resnet_groups: int,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 每个块的变换层数,默认为 1
transformer_layers_per_block: int = 1,
# 注意力头的数量,默认为 None
num_attention_heads: Optional[int] = None,
# 跨注意力的维度,默认为 None
cross_attention_dim: Optional[int] = None,
# 是否使用双重跨注意力,默认为 False
dual_cross_attention: bool = False,
# 是否使用线性投影,默认为 False
use_linear_projection: bool = False,
# 是否仅使用跨注意力作为中间块,默认为 False
mid_block_only_cross_attention: bool = False,
# 是否提升注意力精度,默认为 False
upcast_attention: bool = False,
# ResNet 的时间缩放偏移,默认为 "default"
resnet_time_scale_shift: str = "default",
# 注意力类型,默认为 "default"
attention_type: str = "default",
# ResNet 是否跳过时间激活,默认为 False
resnet_skip_time_act: bool = False,
# 跨注意力的归一化类型,默认为 None
cross_attention_norm: Optional[str] = None,
# 注意力头的维度,默认为 1
attention_head_dim: Optional[int] = 1,
# dropout 概率,默认为 0.0
dropout: float = 0.0,
):
# 根据中间块的类型生成对应的对象
if mid_block_type == "UNetMidBlock2DCrossAttn":
# 创建 UNet 的 2D 跨注意力中间块
return UNetMidBlock2DCrossAttn(
# 设置变换层数
transformer_layers_per_block=transformer_layers_per_block,
# 设置输入通道数
in_channels=in_channels,
# 设置嵌入通道数
temb_channels=temb_channels,
# 设置 dropout 概率
dropout=dropout,
# 设置 ResNet epsilon 值
resnet_eps=resnet_eps,
# 设置 ResNet 激活函数
resnet_act_fn=resnet_act_fn,
# 设置输出缩放因子
output_scale_factor=output_scale_factor,
# 设置时间缩放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 设置跨注意力维度
cross_attention_dim=cross_attention_dim,
# 设置注意力头数量
num_attention_heads=num_attention_heads,
# 设置 ResNet 组数
resnet_groups=resnet_groups,
# 设置是否使用双重跨注意力
dual_cross_attention=dual_cross_attention,
# 设置是否使用线性投影
use_linear_projection=use_linear_projection,
# 设置是否提升注意力精度
upcast_attention=upcast_attention,
# 设置注意力类型
attention_type=attention_type,
)
# 检查是否为简单跨注意力中间块
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
# 创建 UNet 的 2D 简单跨注意力中间块
return UNetMidBlock2DSimpleCrossAttn(
# 设置输入通道数
in_channels=in_channels,
# 设置嵌入通道数
temb_channels=temb_channels,
# 设置 dropout 概率
dropout=dropout,
# 设置 ResNet epsilon 值
resnet_eps=resnet_eps,
# 设置 ResNet 激活函数
resnet_act_fn=resnet_act_fn,
# 设置输出缩放因子
output_scale_factor=output_scale_factor,
# 设置跨注意力维度
cross_attention_dim=cross_attention_dim,
# 设置注意力头的维度
attention_head_dim=attention_head_dim,
# 设置 ResNet 组数
resnet_groups=resnet_groups,
# 设置时间缩放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 设置是否跳过时间激活
skip_time_act=resnet_skip_time_act,
# 设置是否仅使用跨注意力
only_cross_attention=mid_block_only_cross_attention,
# 设置跨注意力的归一化类型
cross_attention_norm=cross_attention_norm,
)
# 检查是否为标准的 2D 中间块
elif mid_block_type == "UNetMidBlock2D":
# 创建 UNet 的 2D 中间块
return UNetMidBlock2D(
# 设置输入通道数
in_channels=in_channels,
# 设置嵌入通道数
temb_channels=temb_channels,
# 设置 dropout 概率
dropout=dropout,
# 设置层数为 0
num_layers=0,
# 设置 ResNet epsilon 值
resnet_eps=resnet_eps,
# 设置 ResNet 激活函数
resnet_act_fn=resnet_act_fn,
# 设置输出缩放因子
output_scale_factor=output_scale_factor,
# 设置 ResNet 组数
resnet_groups=resnet_groups,
# 设置时间缩放偏移
resnet_time_scale_shift=resnet_time_scale_shift,
# 不添加注意力
add_attention=False,
)
# 检查中间块类型是否为 None
elif mid_block_type is None:
# 返回 None
return None
# 抛出未知类型的异常
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# 输出通道的数量
out_channels: int,
# 前一层输出通道的数量
prev_output_channel: int,
# 嵌入层通道的数量
temb_channels: int,
# 是否添加上采样层
add_upsample: bool,
# ResNet 中的 epsilon 值,用于数值稳定性
resnet_eps: float,
# ResNet 中的激活函数类型
resnet_act_fn: str,
# 分辨率索引,默认为 None
resolution_idx: Optional[int] = None,
# 每个块中的变换层数量
transformer_layers_per_block: int = 1,
# 注意力头的数量,默认为 None
num_attention_heads: Optional[int] = None,
# ResNet 中的组数量,默认为 None
resnet_groups: Optional[int] = None,
# 交叉注意力的维度,默认为 None
cross_attention_dim: Optional[int] = None,
# 是否使用双重交叉注意力
dual_cross_attention: bool = False,
# 是否使用线性投影
use_linear_projection: bool = False,
# 是否仅使用交叉注意力
only_cross_attention: bool = False,
# 是否上采样时提高注意力精度
upcast_attention: bool = False,
# ResNet 时间缩放偏移的类型,默认为 "default"
resnet_time_scale_shift: str = "default",
# 注意力类型,默认为 "default"
attention_type: str = "default",
# ResNet 中跳过时间激活的标志
resnet_skip_time_act: bool = False,
# ResNet 输出缩放因子,默认为 1.0
resnet_out_scale_factor: float = 1.0,
# 交叉注意力的归一化方式,默认为 None
cross_attention_norm: Optional[str] = None,
# 注意力头的维度,默认为 None
attention_head_dim: Optional[int] = None,
# 上采样类型,默认为 None
upsample_type: Optional[str] = None,
# 丢弃率,默认为 0.0
dropout: float = 0.0,
) -> nn.Module: # 指定该函数的返回类型为 nn.Module
# 如果未定义注意力头的维度,默认设置为注意力头的数量
if attention_head_dim is None:
logger.warning( # 记录警告信息
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." # 提示用户提供 attention_head_dim
)
attention_head_dim = num_attention_heads # 将 attention_head_dim 设置为 num_attention_heads
# 如果 up_block_type 以 "UNetRes" 开头,则去掉前缀
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
# 检查 up_block_type 是否为 "UpBlock2D"
if up_block_type == "UpBlock2D":
return UpBlock2D( # 返回 UpBlock2D 对象
num_layers=num_layers, # 设置网络层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
prev_output_channel=prev_output_channel, # 设置前一个输出通道
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 参数
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_groups=resnet_groups, # 设置 ResNet 的组数
resnet_time_scale_shift=resnet_time_scale_shift, # 设置 ResNet 的时间缩放偏移
)
# 检查 up_block_type 是否为 "ResnetUpsampleBlock2D"
elif up_block_type == "ResnetUpsampleBlock2D":
return ResnetUpsampleBlock2D( # 返回 ResnetUpsampleBlock2D 对象
num_layers=num_layers, # 设置网络层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
prev_output_channel=prev_output_channel, # 设置前一个输出通道
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 参数
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_groups=resnet_groups, # 设置 ResNet 的组数
resnet_time_scale_shift=resnet_time_scale_shift, # 设置 ResNet 的时间缩放偏移
skip_time_act=resnet_skip_time_act, # 设置是否跳过时间激活
output_scale_factor=resnet_out_scale_factor, # 设置输出缩放因子
)
# 检查 up_block_type 是否为 "CrossAttnUpBlock2D"
elif up_block_type == "CrossAttnUpBlock2D":
# 如果未定义交叉注意力维度,则抛出异常
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") # 抛出值错误
return CrossAttnUpBlock2D( # 返回 CrossAttnUpBlock2D 对象
num_layers=num_layers, # 设置网络层数
transformer_layers_per_block=transformer_layers_per_block, # 设置每个块的变换层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
prev_output_channel=prev_output_channel, # 设置前一个输出通道
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 参数
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_groups=resnet_groups, # 设置 ResNet 的组数
cross_attention_dim=cross_attention_dim, # 设置交叉注意力维度
num_attention_heads=num_attention_heads, # 设置注意力头的数量
dual_cross_attention=dual_cross_attention, # 设置双重交叉注意力
use_linear_projection=use_linear_projection, # 设置是否使用线性投影
only_cross_attention=only_cross_attention, # 设置是否仅使用交叉注意力
upcast_attention=upcast_attention, # 设置是否提升注意力
resnet_time_scale_shift=resnet_time_scale_shift, # 设置 ResNet 的时间缩放偏移
attention_type=attention_type, # 设置注意力类型
)
# 检查上采样块的类型是否为 SimpleCrossAttnUpBlock2D
elif up_block_type == "SimpleCrossAttnUpBlock2D":
# 如果未指定 cross_attention_dim,则抛出值错误
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
# 返回 SimpleCrossAttnUpBlock2D 实例,使用相关参数初始化
return SimpleCrossAttnUpBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
prev_output_channel=prev_output_channel, # 设置前一层的输出通道数
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 概率
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_groups=resnet_groups, # 设置 ResNet 的组数
cross_attention_dim=cross_attention_dim, # 设置交叉注意力维度
attention_head_dim=attention_head_dim, # 设置注意力头维度
resnet_time_scale_shift=resnet_time_scale_shift, # 设置 ResNet 的时间缩放偏移
skip_time_act=resnet_skip_time_act, # 设置是否跳过时间激活
output_scale_factor=resnet_out_scale_factor, # 设置输出缩放因子
only_cross_attention=only_cross_attention, # 设置是否仅使用交叉注意力
cross_attention_norm=cross_attention_norm, # 设置交叉注意力的归一化方式
)
# 检查上采样块的类型是否为 AttnUpBlock2D
elif up_block_type == "AttnUpBlock2D":
# 如果未添加上采样,则将上采样类型设为 None
if add_upsample is False:
upsample_type = None
else:
# 默认将上采样类型设为 'conv'
upsample_type = upsample_type or "conv"
# 返回 AttnUpBlock2D 实例,使用相关参数初始化
return AttnUpBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
prev_output_channel=prev_output_channel, # 设置前一层的输出通道数
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 概率
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_groups=resnet_groups, # 设置 ResNet 的组数
attention_head_dim=attention_head_dim, # 设置注意力头维度
resnet_time_scale_shift=resnet_time_scale_shift, # 设置 ResNet 的时间缩放偏移
upsample_type=upsample_type, # 设置上采样类型
)
# 检查上采样块的类型是否为 SkipUpBlock2D
elif up_block_type == "SkipUpBlock2D":
# 返回 SkipUpBlock2D 实例,使用相关参数初始化
return SkipUpBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
prev_output_channel=prev_output_channel, # 设置前一层的输出通道数
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 概率
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_time_scale_shift=resnet_time_scale_shift, # 设置 ResNet 的时间缩放偏移
)
# 检查上采样块的类型是否为 AttnSkipUpBlock2D
elif up_block_type == "AttnSkipUpBlock2D":
# 返回 AttnSkipUpBlock2D 实例,使用相关参数初始化
return AttnSkipUpBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
prev_output_channel=prev_output_channel, # 设置前一层的输出通道数
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 概率
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
attention_head_dim=attention_head_dim, # 设置注意力头维度
resnet_time_scale_shift=resnet_time_scale_shift, # 设置 ResNet 的时间缩放偏移
)
# 检查上采样块类型是否为 UpDecoderBlock2D
elif up_block_type == "UpDecoderBlock2D":
# 返回 UpDecoderBlock2D 的实例,传入相应参数
return UpDecoderBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 比例
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_groups=resnet_groups, # 设置 ResNet 的组数
resnet_time_scale_shift=resnet_time_scale_shift, # 设置时间尺度偏移
temb_channels=temb_channels, # 设置时间嵌入通道数
)
# 检查上采样块类型是否为 AttnUpDecoderBlock2D
elif up_block_type == "AttnUpDecoderBlock2D":
# 返回 AttnUpDecoderBlock2D 的实例,传入相应参数
return AttnUpDecoderBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 比例
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
resnet_groups=resnet_groups, # 设置 ResNet 的组数
attention_head_dim=attention_head_dim, # 设置注意力头维度
resnet_time_scale_shift=resnet_time_scale_shift, # 设置时间尺度偏移
temb_channels=temb_channels, # 设置时间嵌入通道数
)
# 检查上采样块类型是否为 KUpBlock2D
elif up_block_type == "KUpBlock2D":
# 返回 KUpBlock2D 的实例,传入相应参数
return KUpBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 比例
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
)
# 检查上采样块类型是否为 KCrossAttnUpBlock2D
elif up_block_type == "KCrossAttnUpBlock2D":
# 返回 KCrossAttnUpBlock2D 的实例,传入相应参数
return KCrossAttnUpBlock2D(
num_layers=num_layers, # 设置层数
in_channels=in_channels, # 设置输入通道数
out_channels=out_channels, # 设置输出通道数
temb_channels=temb_channels, # 设置时间嵌入通道数
resolution_idx=resolution_idx, # 设置分辨率索引
dropout=dropout, # 设置 dropout 比例
add_upsample=add_upsample, # 设置是否添加上采样
resnet_eps=resnet_eps, # 设置 ResNet 的 epsilon 值
resnet_act_fn=resnet_act_fn, # 设置 ResNet 的激活函数
cross_attention_dim=cross_attention_dim, # 设置交叉注意力维度
attention_head_dim=attention_head_dim, # 设置注意力头维度
)
# 如果未匹配到任何上采样块类型,抛出值错误
raise ValueError(f"{up_block_type} does not exist.")
# 定义一个小型自编码器块,继承自 nn.Module
class AutoencoderTinyBlock(nn.Module):
"""
Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
blocks.
Args:
in_channels (`int`): The number of input channels.
out_channels (`int`): The number of output channels.
act_fn (`str`):
` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
Returns:
`torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
`out_channels`.
"""
# 初始化函数,接受输入通道数、输出通道数和激活函数类型
def __init__(self, in_channels: int, out_channels: int, act_fn: str):
# 调用父类初始化
super().__init__()
# 获取指定的激活函数
act_fn = get_activation(act_fn)
# 定义一个序列,包括多个卷积层和激活函数
self.conv = nn.Sequential(
# 第一层卷积,输入通道数、输出通道数、卷积核大小和填充方式
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
# 添加激活函数
act_fn,
# 第二层卷积,保持输出通道数
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
# 添加激活函数
act_fn,
# 第三层卷积,保持输出通道数
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
)
# 判断输入和输出通道是否相同,决定使用卷积或身份映射
self.skip = (
# 如果通道数不一致,使用 1x1 卷积进行跳跃连接
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
if in_channels != out_channels
else nn.Identity()
)
# 使用 ReLU 进行特征融合
self.fuse = nn.ReLU()
# 定义前向传播函数,接受输入张量 x
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 返回卷积输出和跳跃连接的和,经过融合激活函数处理
return self.fuse(self.conv(x) + self.skip(x))
# 定义一个 2D UNet 中间块,继承自 nn.Module
class UNetMidBlock2D(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
# 参数说明
Args:
in_channels (`int`): 输入通道的数量。
temb_channels (`int`): 时间嵌入通道的数量。
dropout (`float`, *optional*, defaults to 0.0): dropout 比率,用于防止过拟合。
num_layers (`int`, *optional*, defaults to 1): 残差块的数量。
resnet_eps (`float`, *optional*, 1e-6 ): 残差块的 epsilon 值,用于数值稳定性。
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
应用于时间嵌入的归一化类型。可以改善模型在长时间依赖任务上的表现。
resnet_act_fn (`str`, *optional*, defaults to `swish`): 残差块的激活函数类型。
resnet_groups (`int`, *optional*, defaults to 32):
残差块中组归一化层使用的组数量。
attn_groups (`Optional[int]`, *optional*, defaults to None): 注意力块的组数量。
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
是否在残差块中使用预归一化。
add_attention (`bool`, *optional*, defaults to `True`): 是否添加注意力块。
attention_head_dim (`int`, *optional*, defaults to 1):
单个注意力头的维度。注意力头的数量由该值和输入通道的数量决定。
output_scale_factor (`float`, *optional*, defaults to 1.0): 输出的缩放因子。
# 返回值说明
Returns:
`torch.Tensor`: 最后一个残差块的输出,形状为 `(batch_size, in_channels,
height, width)`。
"""
# 初始化函数,设置模型的各种参数
def __init__(
self,
in_channels: int, # 输入通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout比率,默认为0.0
num_layers: int = 1, # 残差块数量,默认为1
resnet_eps: float = 1e-6, # 残差块的epsilon值
resnet_time_scale_shift: str = "default", # 时间尺度归一化的类型
resnet_act_fn: str = "swish", # 残差块的激活函数
resnet_groups: int = 32, # 残差块的组数量
attn_groups: Optional[int] = None, # 注意力块的组数量
resnet_pre_norm: bool = True, # 是否使用预归一化
add_attention: bool = True, # 是否添加注意力块
attention_head_dim: int = 1, # 注意力头的维度
output_scale_factor: float = 1.0, # 输出缩放因子
# 前向传播函数,接受隐藏状态和时间嵌入作为输入
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 将输入的隐藏状态通过第一个残差块进行处理
hidden_states = self.resnets[0](hidden_states, temb)
# 遍历剩余的注意力块和残差块进行处理
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# 如果当前存在注意力块,则进行处理
if attn is not None:
hidden_states = attn(hidden_states, temb=temb)
# 将处理后的隐藏状态通过当前的残差块进行处理
hidden_states = resnet(hidden_states, temb)
# 返回最终的隐藏状态
return hidden_states
# 定义一个继承自 nn.Module 的类 UNetMidBlock2DCrossAttn
class UNetMidBlock2DCrossAttn(nn.Module):
# 初始化方法,定义各个参数
def __init__(
# 输入通道数
in_channels: int,
# 时间嵌入通道数
temb_channels: int,
# 输出通道数(可选,默认为 None)
out_channels: Optional[int] = None,
# dropout 概率(默认为 0.0)
dropout: float = 0.0,
# 层数(默认为 1)
num_layers: int = 1,
# 每个块的变换器层数(默认为 1)
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值(默认为 1e-6)
resnet_eps: float = 1e-6,
# ResNet 的时间缩放偏移方式(默认为 "default")
resnet_time_scale_shift: str = "default",
# ResNet 的激活函数类型(默认为 "swish")
resnet_act_fn: str = "swish",
# ResNet 的组数(默认为 32)
resnet_groups: int = 32,
# 输出的 ResNet 组数(可选,默认为 None)
resnet_groups_out: Optional[int] = None,
# 是否进行 ResNet 预归一化(默认为 True)
resnet_pre_norm: bool = True,
# 注意力头的数量(默认为 1)
num_attention_heads: int = 1,
# 输出缩放因子(默认为 1.0)
output_scale_factor: float = 1.0,
# 交叉注意力维度(默认为 1280)
cross_attention_dim: int = 1280,
# 是否使用双重交叉注意力(默认为 False)
dual_cross_attention: bool = False,
# 是否使用线性投影(默认为 False)
use_linear_projection: bool = False,
# 是否提升注意力计算精度(默认为 False)
upcast_attention: bool = False,
# 注意力类型(默认为 "default")
attention_type: str = "default",
# 前向传播方法
def forward(
# 隐藏状态的张量
hidden_states: torch.Tensor,
# 时间嵌入的张量(可选,默认为 None)
temb: Optional[torch.Tensor] = None,
# 编码器隐藏状态的张量(可选,默认为 None)
encoder_hidden_states: Optional[torch.Tensor] = None,
# 注意力掩码(可选,默认为 None)
attention_mask: Optional[torch.Tensor] = None,
# 交叉注意力参数(可选,默认为 None)
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 编码器注意力掩码(可选,默认为 None)
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: # 定义返回类型为 torch.Tensor
if cross_attention_kwargs is not None: # 检查交叉注意力参数是否存在
if cross_attention_kwargs.get("scale", None) is not None: # 检查是否有 scale 参数
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # 记录警告,提示 scale 参数已弃用
hidden_states = self.resnets[0](hidden_states, temb) # 使用第一个残差网络处理隐藏状态和时间嵌入
for attn, resnet in zip(self.attentions, self.resnets[1:]): # 遍历注意力层和残差网络(跳过第一个)
if self.training and self.gradient_checkpointing: # 如果处于训练模式并且启用了梯度检查点
def create_custom_forward(module, return_dict=None): # 定义自定义前向传播函数
def custom_forward(*inputs): # 定义实际的前向传播实现
if return_dict is not None: # 如果返回字典不为 None
return module(*inputs, return_dict=return_dict) # 调用模块并返回字典
else: # 如果返回字典为 None
return module(*inputs) # 直接调用模块并返回结果
return custom_forward # 返回自定义前向传播函数
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # 设置检查点的关键字参数
hidden_states = attn( # 调用注意力层处理隐藏状态
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 编码器隐藏状态
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力参数
attention_mask=attention_mask, # 注意力掩码
encoder_attention_mask=encoder_attention_mask, # 编码器注意力掩码
return_dict=False, # 不返回字典格式
)[0] # 取处理结果的第一个元素
hidden_states = torch.utils.checkpoint.checkpoint( # 使用梯度检查点
create_custom_forward(resnet), # 创建残差网络的自定义前向函数
hidden_states, # 输入隐藏状态
temb, # 输入时间嵌入
**ckpt_kwargs, # 解包关键字参数
)
else: # 如果不处于训练模式或不启用梯度检查点
hidden_states = attn( # 调用注意力层处理隐藏状态
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 编码器隐藏状态
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力参数
attention_mask=attention_mask, # 注意力掩码
encoder_attention_mask=encoder_attention_mask, # 编码器注意力掩码
return_dict=False, # 不返回字典格式
)[0] # 取处理结果的第一个元素
hidden_states = resnet(hidden_states, temb) # 使用残差网络处理隐藏状态和时间嵌入
return hidden_states # 返回最终的隐藏状态
# 定义一个 UNet 中间块类,继承自 nn.Module
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
# 初始化方法,设置类的参数
def __init__(
# 输入通道数
in_channels: int,
# 时间嵌入通道数
temb_channels: int,
# Dropout 比例,默认为 0.0
dropout: float = 0.0,
# 层数,默认为 1
num_layers: int = 1,
# ResNet 的小 epsilon 值,默认为 1e-6
resnet_eps: float = 1e-6,
# ResNet 时间尺度偏移,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 激活函数类型,默认为 "swish"
resnet_act_fn: str = "swish",
# ResNet 组数,默认为 32
resnet_groups: int = 32,
# 是否使用预归一化,默认为 True
resnet_pre_norm: bool = True,
# 注意力头维度,默认为 1
attention_head_dim: int = 1,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 交叉注意力维度,默认为 1280
cross_attention_dim: int = 1280,
# 是否跳过时间激活,默认为 False
skip_time_act: bool = False,
# 是否仅使用交叉注意力,默认为 False
only_cross_attention: bool = False,
# 交叉注意力的归一化方法,可选参数
cross_attention_norm: Optional[str] = None,
):
# 调用父类的初始化方法
super().__init__()
# 设置是否使用交叉注意力
self.has_cross_attention = True
# 设置注意力头的维度
self.attention_head_dim = attention_head_dim
# 计算 ResNet 的组数,如果未提供,则取输入通道数的四分之一和 32 的最小值
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# 计算注意力头的数量
self.num_heads = in_channels // self.attention_head_dim
# 至少存在一个 ResNet 块
resnets = [
# 创建一个 ResNet 块
ResnetBlock2D(
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=in_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 正则化参数
eps=resnet_eps,
# ResNet 组数
groups=resnet_groups,
# dropout 概率
dropout=dropout,
# 时间嵌入归一化方法
time_embedding_norm=resnet_time_scale_shift,
# 非线性激活函数
non_linearity=resnet_act_fn,
# 输出缩放因子
output_scale_factor=output_scale_factor,
# 是否进行预归一化
pre_norm=resnet_pre_norm,
# 是否跳过时间激活
skip_time_act=skip_time_act,
)
]
# 初始化注意力层列表
attentions = []
# 循环创建指定数量的层
for _ in range(num_layers):
# 根据是否具有缩放点积注意力,选择处理器
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# 添加注意力层到列表中
attentions.append(
Attention(
# 查询的维度
query_dim=in_channels,
# 交叉注意力的维度
cross_attention_dim=in_channels,
# 注意力头的数量
heads=self.num_heads,
# 每个头的维度
dim_head=self.attention_head_dim,
# 额外的 KV 投影维度
added_kv_proj_dim=cross_attention_dim,
# 归一化的组数
norm_num_groups=resnet_groups,
# 是否使用偏置
bias=True,
# 是否上cast softmax
upcast_softmax=True,
# 是否仅使用交叉注意力
only_cross_attention=only_cross_attention,
# 交叉注意力的归一化方法
cross_attention_norm=cross_attention_norm,
# 设置处理器
processor=processor,
)
)
# 添加另一个 ResNet 块到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
# 将注意力层转为可训练模块列表
self.attentions = nn.ModuleList(attentions)
# 将 ResNet 块转为可训练模块列表
self.resnets = nn.ModuleList(resnets)
def forward(
# 前向传播函数的定义
self,
# 输入的隐状态
hidden_states: torch.Tensor,
# 可选的时间嵌入
temb: Optional[torch.Tensor] = None,
# 可选的编码器隐状态
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的注意力掩码
attention_mask: Optional[torch.Tensor] = None,
# 可选的交叉注意力参数
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的编码器注意力掩码
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 如果 cross_attention_kwargs 为 None,则初始化为空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 检查 cross_attention_kwargs 中是否有 "scale" 参数,如果有则记录警告信息
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 如果 attention_mask 为 None
if attention_mask is None:
# 如果 encoder_hidden_states 已定义,表示正在进行交叉注意力,因此使用交叉注意力掩码
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# 当 attention_mask 已定义时,不检查 encoder_attention_mask
# 这是为了与 UnCLIP 兼容,UnCLIP 使用 'attention_mask' 参数作为交叉注意力掩码
# TODO: UnCLIP 应该通过 encoder_attention_mask 参数而不是 attention_mask 参数来表达交叉注意力掩码
# 然后可以简化整个 if/else 块为:
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
# 通过第一个残差网络处理隐藏状态
hidden_states = self.resnets[0](hidden_states, temb)
# 遍历注意力层和后续的残差网络
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# 使用当前的注意力层处理隐藏状态
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states, # 传递编码器的隐藏状态
attention_mask=mask, # 传递注意力掩码
**cross_attention_kwargs, # 解包交叉注意力参数
)
# 通过当前的残差网络处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个名为 AttnDownBlock2D 的类,继承自 nn.Module
class AttnDownBlock2D(nn.Module):
# 初始化方法,接受多个参数以设置层的属性
def __init__(
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 时间嵌入通道数
temb_channels: int,
# dropout 概率,默认为 0.0
dropout: float = 0.0,
# 层数,默认为 1
num_layers: int = 1,
# ResNet 的 epsilon 值,防止除零错误,默认为 1e-6
resnet_eps: float = 1e-6,
# ResNet 时间尺度偏移,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 激活函数,默认为 "swish"
resnet_act_fn: str = "swish",
# ResNet 组数,默认为 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用预归一化,默认为 True
resnet_pre_norm: bool = True,
# 注意力头的维度,默认为 1
attention_head_dim: int = 1,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 下采样的填充大小,默认为 1
downsample_padding: int = 1,
# 下采样类型,默认为 "conv"
downsample_type: str = "conv",
):
# 调用父类构造函数初始化
super().__init__()
# 初始化空列表,用于存储 ResNet 块
resnets = []
# 初始化空列表,用于存储注意力机制模块
attentions = []
# 保存下采样类型
self.downsample_type = downsample_type
# 如果未指定注意力头的维度,则发出警告,并默认使用输出通道数
if attention_head_dim is None:
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
# 将注意力头维度设置为输出通道数
attention_head_dim = out_channels
# 遍历层数以构建 ResNet 块和注意力模块
for i in range(num_layers):
# 确定输入通道数,第一层使用初始通道数,其余层使用输出通道数
in_channels = in_channels if i == 0 else out_channels
# 创建并添加 ResNet 块到 resnets 列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 防止除零的epsilon值
groups=resnet_groups, # 分组数
dropout=dropout, # dropout 比率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入的归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否在前面进行归一化
)
)
# 创建并添加注意力模块到 attentions 列表
attentions.append(
Attention(
out_channels, # 输出通道数
heads=out_channels // attention_head_dim, # 注意力头的数量
dim_head=attention_head_dim, # 每个注意力头的维度
rescale_output_factor=output_scale_factor, # 输出缩放因子
eps=resnet_eps, # 防止除零的epsilon值
norm_num_groups=resnet_groups, # 归一化的组数
residual_connection=True, # 是否使用残差连接
bias=True, # 是否使用偏置
upcast_softmax=True, # 是否上溯 softmax
_from_deprecated_attn_block=True, # 是否来自于已弃用的注意力块
)
)
# 将注意力模块列表转换为 nn.ModuleList,便于管理
self.attentions = nn.ModuleList(attentions)
# 将 ResNet 块列表转换为 nn.ModuleList,便于管理
self.resnets = nn.ModuleList(resnets)
# 根据下采样类型选择相应的下采样方法
if downsample_type == "conv":
# 创建卷积下采样模块并存储
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 输出通道数
use_conv=True, # 是否使用卷积
out_channels=out_channels, # 输出通道数
padding=downsample_padding, # 填充
name="op" # 模块名称
)
]
)
elif downsample_type == "resnet":
# 创建 ResNet 下采样块并存储
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 防止除零的epsilon值
groups=resnet_groups, # 分组数
dropout=dropout, # dropout 比率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入的归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否在前面进行归一化
down=True, # 指示为下采样
)
]
)
else:
# 如果没有匹配的下采样类型,则将下采样模块设置为 None
self.downsamplers = None
# 前向传播方法,处理隐状态和可选的其他参数,返回处理后的隐状态和输出状态
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐状态张量
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
upsample_size: Optional[int] = None, # 可选的上采样尺寸
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可选的交叉注意力参数字典
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: # 返回隐状态和输出状态的元组
# 如果没有提供交叉注意力参数,则初始化为空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 检查是否传入了 scale 参数,如果有,则发出警告,因为这个参数已被弃用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
output_states = () # 初始化输出状态为一个空元组
# 遍历自定义的残差网络和注意力层
for resnet, attn in zip(self.resnets, self.attentions):
# 将隐状态传递给残差网络并更新隐状态
hidden_states = resnet(hidden_states, temb)
# 将隐状态传递给注意力层,并更新隐状态
hidden_states = attn(hidden_states, **cross_attention_kwargs)
# 将当前隐状态添加到输出状态元组中
output_states = output_states + (hidden_states,)
# 检查是否存在下采样层
if self.downsamplers is not None:
# 遍历每个下采样层
for downsampler in self.downsamplers:
# 根据下采样类型选择不同的处理方式
if self.downsample_type == "resnet":
hidden_states = downsampler(hidden_states, temb=temb) # 使用时间嵌入处理
else:
hidden_states = downsampler(hidden_states) # 不使用时间嵌入处理
# 将最后的隐状态添加到输出状态中
output_states += (hidden_states,)
# 返回处理后的隐状态和输出状态
return hidden_states, output_states
# 定义一个名为 CrossAttnDownBlock2D 的类,继承自 nn.Module
class CrossAttnDownBlock2D(nn.Module):
# 初始化方法,接收多个参数以配置模块
def __init__(
# 输入通道数
self,
in_channels: int,
# 输出通道数
out_channels: int,
# 时间嵌入通道数
temb_channels: int,
# dropout 比例,默认为 0.0
dropout: float = 0.0,
# 层数,默认为 1
num_layers: int = 1,
# 每个块的变换器层数,可以是单个整数或整数元组
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,默认为 1e-6
resnet_eps: float = 1e-6,
# ResNet 时间缩放偏移的类型,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的激活函数,默认为 "swish"
resnet_act_fn: str = "swish",
# ResNet 的分组数,默认为 32
resnet_groups: int = 32,
# 是否使用预归一化,默认为 True
resnet_pre_norm: bool = True,
# 注意力头的数量,默认为 1
num_attention_heads: int = 1,
# 交叉注意力维度,默认为 1280
cross_attention_dim: int = 1280,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 下采样填充的大小,默认为 1
downsample_padding: int = 1,
# 是否添加下采样,默认为 True
add_downsample: bool = True,
# 是否使用双重交叉注意力,默认为 False
dual_cross_attention: bool = False,
# 是否使用线性投影,默认为 False
use_linear_projection: bool = False,
# 是否仅使用交叉注意力,默认为 False
only_cross_attention: bool = False,
# 是否提升注意力精度,默认为 False
upcast_attention: bool = False,
# 注意力类型,默认为 "default"
attention_type: str = "default",
# 初始化父类
):
super().__init__()
# 初始化残差块列表
resnets = []
# 初始化注意力机制列表
attentions = []
# 设置是否使用交叉注意力
self.has_cross_attention = True
# 设置注意力头的数量
self.num_attention_heads = num_attention_heads
# 如果每个块的变换层是整数,则扩展为列表
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍历每一层
for i in range(num_layers):
# 确定输入通道数
in_channels = in_channels if i == 0 else out_channels
# 添加残差块到列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon值
groups=resnet_groups, # 组数
dropout=dropout, # dropout比率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否使用预归一化
)
)
# 检查是否使用双重交叉注意力
if not dual_cross_attention:
# 添加普通的变换模型到列表
attentions.append(
Transformer2DModel(
num_attention_heads, # 注意力头数量
out_channels // num_attention_heads, # 每个头的输出通道数
in_channels=out_channels, # 输入通道数
num_layers=transformer_layers_per_block[i], # 层数
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
norm_num_groups=resnet_groups, # 归一化组数
use_linear_projection=use_linear_projection, # 是否使用线性投影
only_cross_attention=only_cross_attention, # 是否仅使用交叉注意力
upcast_attention=upcast_attention, # 是否向上投射注意力
attention_type=attention_type, # 注意力类型
)
)
else:
# 添加双重变换模型到列表
attentions.append(
DualTransformer2DModel(
num_attention_heads, # 注意力头数量
out_channels // num_attention_heads, # 每个头的输出通道数
in_channels=out_channels, # 输入通道数
num_layers=1, # 层数固定为1
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
norm_num_groups=resnet_groups, # 归一化组数
)
)
# 将注意力模型列表转换为nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 将残差块列表转换为nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 检查是否添加下采样层
if add_downsample:
# 创建下采样层列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 输出通道数
use_conv=True, # 是否使用卷积
out_channels=out_channels, # 输出通道数
padding=downsample_padding, # 填充
name="op" # 操作名称
)
]
)
else:
# 如果不添加下采样层,则设置为None
self.downsamplers = None
# 设置梯度检查点开关为关闭
self.gradient_checkpointing = False
# 定义前向传播函数,接收多个参数
def forward(
self,
# 隐藏状态张量,表示模型的内部状态
hidden_states: torch.Tensor,
# 可选的时间嵌入张量,用于控制生成的时间步
temb: Optional[torch.Tensor] = None,
# 可选的编码器隐藏状态张量,表示编码器的输出
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的注意力掩码,用于屏蔽输入中不需要关注的部分
attention_mask: Optional[torch.Tensor] = None,
# 可选的交叉注意力参数字典,用于传递其他配置
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的编码器注意力掩码,控制编码器的注意力机制
encoder_attention_mask: Optional[torch.Tensor] = None,
# 可选的附加残差张量,作为额外的信息传递
additional_residuals: Optional[torch.Tensor] = None,
# 返回类型为元组,包含张量和一个张量元组
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 检查交叉注意力参数是否不为空
if cross_attention_kwargs is not None:
# 如果"scale"参数存在,则发出警告,说明其已被弃用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化输出状态为空元组
output_states = ()
# 将残差网络和注意力层配对
blocks = list(zip(self.resnets, self.attentions))
# 遍历每一对残差网络和注意力层
for i, (resnet, attn) in enumerate(blocks):
# 如果在训练中并且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义创建自定义前向传播的函数
def create_custom_forward(module, return_dict=None):
# 定义自定义前向传播逻辑
def custom_forward(*inputs):
# 根据是否返回字典调用模块
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 根据 PyTorch 版本设置检查点参数
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用检查点机制计算隐藏状态
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
# 通过注意力层处理隐藏状态
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
# 直接通过残差网络处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 通过注意力层处理隐藏状态
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
# 如果是最后一对块并且有额外的残差
if i == len(blocks) - 1 and additional_residuals is not None:
# 将额外的残差添加到隐藏状态
hidden_states = hidden_states + additional_residuals
# 更新输出状态,添加当前隐藏状态
output_states = output_states + (hidden_states,)
# 如果下采样器不为空
if self.downsamplers is not None:
# 遍历每个下采样器
for downsampler in self.downsamplers:
# 使用下采样器处理隐藏状态
hidden_states = downsampler(hidden_states)
# 更新输出状态,添加当前隐藏状态
output_states = output_states + (hidden_states,)
# 返回最终的隐藏状态和输出状态
return hidden_states, output_states
# 定义一个二维向下块,继承自 nn.Module
class DownBlock2D(nn.Module):
# 初始化方法,定义各个参数及其默认值
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # ResNet 层数
resnet_eps: float = 1e-6, # ResNet 中的 epsilon
resnet_time_scale_shift: str = "default", # 时间缩放偏移设置
resnet_act_fn: str = "swish", # ResNet 的激活函数
resnet_groups: int = 32, # ResNet 的分组数
resnet_pre_norm: bool = True, # 是否使用预归一化
output_scale_factor: float = 1.0, # 输出缩放因子
add_downsample: bool = True, # 是否添加下采样
downsample_padding: int = 1, # 下采样填充
):
# 调用父类构造函数
super().__init__()
# 初始化空的 ResNet 块列表
resnets = []
# 根据层数循环创建 ResNet 块
for i in range(num_layers):
# 确定当前层的输入通道数
in_channels = in_channels if i == 0 else out_channels
# 添加 ResNet 块到列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 当前层的输入通道数
out_channels=out_channels, # 当前层的输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 参数
groups=resnet_groups, # 分组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 预归一化标志
)
)
# 将 ResNet 块列表转换为 ModuleList
self.resnets = nn.ModuleList(resnets)
# 根据标志决定是否添加下采样层
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D( # 创建下采样层
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
# 如果不添加下采样,则将其设置为 None
self.downsamplers = None
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 定义前向传播方法
def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 检查位置参数是否大于0,或关键字参数中的 "scale" 是否不为 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用消息,提示用户 "scale" 参数已弃用,未来将引发错误
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数,记录 "scale" 参数的弃用
deprecate("scale", "1.0.0", deprecation_message)
# 初始化输出状态元组
output_states = ()
# 遍历自定义的 ResNet 模块列表
for resnet in self.resnets:
# 如果在训练模式下并且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义用于创建自定义前向传播的函数
def create_custom_forward(module):
# 定义自定义前向传播函数
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 检查 PyTorch 版本是否大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用检查点技术来计算隐藏状态,防止内存泄漏
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 在较早的版本中调用检查点计算隐藏状态
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 在非训练模式下直接调用 ResNet 模块以计算隐藏状态
hidden_states = resnet(hidden_states, temb)
# 将计算出的隐藏状态添加到输出状态元组中
output_states = output_states + (hidden_states,)
# 如果存在下采样器
if self.downsamplers is not None:
# 遍历下采样器并计算隐藏状态
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 将下采样后的隐藏状态添加到输出状态元组中
output_states = output_states + (hidden_states,)
# 返回当前的隐藏状态和输出状态元组
return hidden_states, output_states
# 定义一个 2D 下采样编码块的类,继承自 nn.Module
class DownEncoderBlock2D(nn.Module):
# 初始化方法,设置各类参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
dropout: float = 0.0, # dropout 概率,默认 0
num_layers: int = 1, # 层数,默认 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值,默认 1e-6
resnet_time_scale_shift: str = "default", # 时间尺度偏移方式,默认值为 "default"
resnet_act_fn: str = "swish", # ResNet 使用的激活函数,默认是 "swish"
resnet_groups: int = 32, # ResNet 的组数,默认 32
resnet_pre_norm: bool = True, # 是否在前面进行归一化,默认 True
output_scale_factor: float = 1.0, # 输出缩放因子,默认 1.0
add_downsample: bool = True, # 是否添加下采样层,默认 True
downsample_padding: int = 1, # 下采样的填充大小,默认 1
):
# 调用父类的初始化方法
super().__init__()
# 初始化一个空的列表,用于存放 ResNet 块
resnets = []
# 根据层数创建 ResNet 块
for i in range(num_layers):
# 如果是第一层,使用输入通道数;否则使用输出通道数
in_channels = in_channels if i == 0 else out_channels
# 根据时间尺度偏移方式选择相应的 ResNet 块
if resnet_time_scale_shift == "spatial":
# 创建一个带条件归一化的 ResNet 块,并添加到 resnets 列表
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=None, # 时间嵌入通道数,默认 None
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 概率
time_embedding_norm="spatial", # 时间嵌入的归一化方式
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
)
)
else:
# 创建一个标准的 ResNet 块,并添加到 resnets 列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=None, # 时间嵌入通道数,默认 None
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入的归一化方式
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否前归一化
)
)
# 将 ResNet 块列表转为 nn.ModuleList,便于管理
self.resnets = nn.ModuleList(resnets)
# 根据 add_downsample 参数决定是否添加下采样层
if add_downsample:
# 创建下采样层,并添加到 nn.ModuleList
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 输入通道数
use_conv=True, # 是否使用卷积进行下采样
out_channels=out_channels, # 输出通道数
padding=downsample_padding, # 填充大小
name="op" # 下采样层的名称
)
]
)
else:
# 如果不添加下采样层,设置为 None
self.downsamplers = None
# 定义前向传播函数,接受隐藏状态和可选参数
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
# 检查是否传入多余的参数或已弃用的 `scale` 参数
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 设置弃用信息,提醒用户移除 `scale` 参数
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数,记录警告信息
deprecate("scale", "1.0.0", deprecation_message)
# 遍历每个 ResNet 层,更新隐藏状态
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
# 如果存在下采样层,则逐个应用下采样
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 返回最终的隐藏状态
return hidden_states
# 定义一个二维注意力下采样编码器块的类,继承自 nn.Module
class AttnDownEncoderBlock2D(nn.Module):
# 初始化方法,接收多个参数用于配置编码器块
def __init__(
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 丢弃率,控制神经元随机失活的比例
dropout: float = 0.0,
# 编码器块的层数
num_layers: int = 1,
# ResNet 的小常量,用于防止除零错误
resnet_eps: float = 1e-6,
# ResNet 的时间尺度偏移,默认配置
resnet_time_scale_shift: str = "default",
# ResNet 使用的激活函数类型,默认为 swish
resnet_act_fn: str = "swish",
# ResNet 中的分组数
resnet_groups: int = 32,
# 是否在前面进行归一化,默认为 True
resnet_pre_norm: bool = True,
# 注意力头的维度
attention_head_dim: int = 1,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 是否添加下采样层,默认为 True
add_downsample: bool = True,
# 下采样的填充大小,默认为 1
downsample_padding: int = 1,
):
# 调用父类构造函数
super().__init__()
# 初始化空列表以存储残差块
resnets = []
# 初始化空列表以存储注意力模块
attentions = []
# 检查是否传入注意力头维度
if attention_head_dim is None:
# 记录警告信息,默认设置注意力头维度为输出通道数
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
# 将注意力头维度设置为输出通道数
attention_head_dim = out_channels
# 遍历层数以构建残差块和注意力模块
for i in range(num_layers):
# 第一层输入通道为 in_channels,其余层为 out_channels
in_channels = in_channels if i == 0 else out_channels
# 根据时间缩放偏移类型构建不同类型的残差块
if resnet_time_scale_shift == "spatial":
# 添加条件归一化的残差块到列表
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm="spatial",
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
)
)
else:
# 添加普通残差块到列表
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
# 添加注意力模块到列表
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
residual_connection=True,
bias=True,
upcast_softmax=True,
_from_deprecated_attn_block=True,
)
)
# 将注意力模块列表转换为 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 将残差块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根据标志决定是否添加下采样层
if add_downsample:
# 创建下采样模块并添加到列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
# 如果不添加下采样层,将其设置为 None
self.downsamplers = None
# 定义前向传播方法,接收隐藏状态和其他参数
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
# 检查是否有额外的参数或已弃用的 scale 参数
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 构建弃用警告信息
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用 deprecate 函数显示弃用警告
deprecate("scale", "1.0.0", deprecation_message)
# 遍历自定义的 ResNet 和注意力层进行处理
for resnet, attn in zip(self.resnets, self.attentions):
# 通过 ResNet 层处理隐藏状态
hidden_states = resnet(hidden_states, temb=None)
# 通过注意力层处理更新后的隐藏状态
hidden_states = attn(hidden_states)
# 如果有下采样层,则依次处理隐藏状态
if self.downsamplers is not None:
for downsampler in self.downsamplers:
# 通过下采样层处理隐藏状态
hidden_states = downsampler(hidden_states)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个名为 AttnSkipDownBlock2D 的类,继承自 nn.Module
class AttnSkipDownBlock2D(nn.Module):
# 初始化方法,定义类的构造函数
def __init__(
# 输入通道数,整型
in_channels: int,
# 输出通道数,整型
out_channels: int,
# 嵌入通道数,整型
temb_channels: int,
# dropout 率,浮点型,默认为 0.0
dropout: float = 0.0,
# 网络层数,整型,默认为 1
num_layers: int = 1,
# ResNet 的 epsilon 值,浮点型,默认为 1e-6
resnet_eps: float = 1e-6,
# ResNet 的时间尺度偏移方式,字符串,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的激活函数类型,字符串,默认为 "swish"
resnet_act_fn: str = "swish",
# 是否在前面进行规范化,布尔值,默认为 True
resnet_pre_norm: bool = True,
# 注意力头的维度,整型,默认为 1
attention_head_dim: int = 1,
# 输出缩放因子,浮点型,默认为平方根2
output_scale_factor: float = np.sqrt(2.0),
# 是否添加下采样层,布尔值,默认为 True
add_downsample: bool = True,
):
# 调用父类的初始化方法
super().__init__()
# 初始化一个空的模块列表用于存储注意力层
self.attentions = nn.ModuleList([])
# 初始化一个空的模块列表用于存储残差块
self.resnets = nn.ModuleList([])
# 检查 attention_head_dim 是否为 None
if attention_head_dim is None:
# 如果为 None,记录警告信息,并将其设置为输出通道数
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
attention_head_dim = out_channels
# 根据层数创建残差块和注意力层
for i in range(num_layers):
# 设置当前层的输入通道数,如果是第一层则使用 in_channels,否则使用 out_channels
in_channels = in_channels if i == 0 else out_channels
# 添加一个 ResnetBlock2D 到 resnets 列表中
self.resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 小常数以防止除零
groups=min(in_channels // 4, 32), # 分组数
groups_out=min(out_channels // 4, 32), # 输出分组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否在残差块前进行归一化
)
)
# 添加一个 Attention 层到 attentions 列表中
self.attentions.append(
Attention(
out_channels, # 输出通道数
heads=out_channels // attention_head_dim, # 注意力头的数量
dim_head=attention_head_dim, # 每个注意力头的维度
rescale_output_factor=output_scale_factor, # 输出缩放因子
eps=resnet_eps, # 小常数以防止除零
norm_num_groups=32, # 归一化分组数
residual_connection=True, # 是否使用残差连接
bias=True, # 是否使用偏置
upcast_softmax=True, # 是否使用上溢出 softmax
_from_deprecated_attn_block=True, # 是否来自过时的注意力块
)
)
# 检查是否需要添加下采样层
if add_downsample:
# 创建一个 ResnetBlock2D 作为下采样层
self.resnet_down = ResnetBlock2D(
in_channels=out_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 小常数以防止除零
groups=min(out_channels // 4, 32), # 分组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否在残差块前进行归一化
use_in_shortcut=True, # 是否在快捷连接中使用
down=True, # 是否进行下采样
kernel="fir", # 卷积核类型
)
# 创建下采样模块列表
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
# 创建跳跃连接卷积层
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
# 如果不添加下采样层,则将相关属性设置为 None
self.resnet_down = None
self.downsamplers = None
self.skip_conv = None
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入
skip_sample: Optional[torch.Tensor] = None, # 可选的跳跃样本
*args, # 额外的位置参数
**kwargs, # 额外的关键字参数
# 定义返回类型为元组,包含张量和多个张量的元组
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
# 检查传入的参数是否存在,或 kwargs 中的 scale 是否不为 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用信息,说明 scale 参数将被忽略并将来会引发错误
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数,记录 scale 参数的弃用信息
deprecate("scale", "1.0.0", deprecation_message)
# 初始化输出状态为一个空元组
output_states = ()
# 遍历 resnet 和 attention 的组合
for resnet, attn in zip(self.resnets, self.attentions):
# 使用 resnet 处理隐藏状态和时间嵌入
hidden_states = resnet(hidden_states, temb)
# 使用 attention 处理更新后的隐藏状态
hidden_states = attn(hidden_states)
# 将当前隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 检查是否存在下采样器
if self.downsamplers is not None:
# 使用下采样网络处理隐藏状态
hidden_states = self.resnet_down(hidden_states, temb)
# 遍历每个下采样器并处理跳跃样本
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
# 结合跳跃样本和隐藏状态,更新隐藏状态
hidden_states = self.skip_conv(skip_sample) + hidden_states
# 将当前隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 返回更新后的隐藏状态、输出状态元组和跳跃样本
return hidden_states, output_states, skip_sample
# 定义一个二维跳过块的类,继承自 nn.Module
class SkipDownBlock2D(nn.Module):
# 初始化方法,设置输入和输出通道等参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # 层数
resnet_eps: float = 1e-6, # ResNet 的 epsilon 参数
resnet_time_scale_shift: str = "default", # 时间缩放偏移方式
resnet_act_fn: str = "swish", # ResNet 激活函数
resnet_pre_norm: bool = True, # 是否在前面进行归一化
output_scale_factor: float = np.sqrt(2.0), # 输出缩放因子
add_downsample: bool = True, # 是否添加下采样层
downsample_padding: int = 1, # 下采样时的填充
):
super().__init__() # 调用父类构造函数
self.resnets = nn.ModuleList([]) # 初始化 ResNet 块列表
# 循环创建每一层的 ResNet 块
for i in range(num_layers):
# 第一层使用输入通道,后续层使用输出通道
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
ResnetBlock2D( # 添加 ResNet 块
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 参数
groups=min(in_channels // 4, 32), # 输入通道组数
groups_out=min(out_channels // 4, 32), # 输出通道组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 前归一化
)
)
# 如果需要添加下采样层
if add_downsample:
self.resnet_down = ResnetBlock2D( # 创建下采样 ResNet 块
in_channels=out_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 参数
groups=min(out_channels // 4, 32), # 输出通道组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 前归一化
use_in_shortcut=True, # 使用短接
down=True, # 启用下采样
kernel="fir", # 指定卷积核类型
)
# 创建下采样模块列表
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
# 创建跳过连接卷积层
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else: # 如果不添加下采样层
self.resnet_down = None # 不使用下采样 ResNet 块
self.downsamplers = None # 不使用下采样模块列表
self.skip_conv = None # 不使用跳过连接卷积层
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入
skip_sample: Optional[torch.Tensor] = None, # 可选的跳过样本
*args, # 可变位置参数
**kwargs, # 可变关键字参数
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]: # 定义返回类型为元组,包含一个张量、一个张量元组和另一个张量
if len(args) > 0 or kwargs.get("scale", None) is not None: # 检查是否传入了位置参数或名为“scale”的关键字参数
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." # 定义废弃警告信息
deprecate("scale", "1.0.0", deprecation_message) # 调用 deprecate 函数,记录“scale”参数的废弃信息
output_states = () # 初始化输出状态为一个空元组
for resnet in self.resnets: # 遍历自定义的 ResNet 模型列表
hidden_states = resnet(hidden_states, temb) # 将当前的隐藏状态和时间嵌入传递给 ResNet,获取更新后的隐藏状态
output_states += (hidden_states,) # 将当前的隐藏状态添加到输出状态元组中
if self.downsamplers is not None: # 检查是否存在下采样模块
hidden_states = self.resnet_down(hidden_states, temb) # 使用 ResNet 下采样隐藏状态和时间嵌入
for downsampler in self.downsamplers: # 遍历下采样模块
skip_sample = downsampler(skip_sample) # 对跳过连接样本进行下采样
hidden_states = self.skip_conv(skip_sample) + hidden_states # 通过跳过卷积处理下采样样本,并与当前的隐藏状态相加
output_states += (hidden_states,) # 将更新后的隐藏状态添加到输出状态元组中
return hidden_states, output_states, skip_sample # 返回更新后的隐藏状态、输出状态元组和跳过样本
# 定义一个 2D ResNet 下采样块,继承自 nn.Module
class ResnetDownsampleBlock2D(nn.Module):
# 初始化函数,定义各层参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # ResNet 层数
resnet_eps: float = 1e-6, # ResNet 中的 epsilon 值
resnet_time_scale_shift: str = "default", # 时间缩放偏移方式
resnet_act_fn: str = "swish", # 激活函数
resnet_groups: int = 32, # 组数
resnet_pre_norm: bool = True, # 是否使用预归一化
output_scale_factor: float = 1.0, # 输出缩放因子
add_downsample: bool = True, # 是否添加下采样层
skip_time_act: bool = False, # 是否跳过时间激活
):
# 调用父类构造函数
super().__init__()
resnets = [] # 初始化 ResNet 层列表
# 根据指定的层数构建 ResNet 块
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels # 确定输入通道数
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化方式
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否使用预归一化
skip_time_act=skip_time_act, # 是否跳过时间激活
)
)
self.resnets = nn.ModuleList(resnets) # 将 ResNet 层转换为模块列表
# 如果需要,添加下采样层
if add_downsample:
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化方式
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否使用预归一化
skip_time_act=skip_time_act, # 是否跳过时间激活
down=True, # 指定为下采样层
)
]
)
else:
self.downsamplers = None # 如果不需要下采样层,则为 None
self.gradient_checkpointing = False # 初始化梯度检查点为 False
# 定义前向传播函数
def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs # 前向传播输入
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 检查参数是否存在,或者是否提供了已弃用的 `scale` 参数
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用消息,告知用户 `scale` 参数已弃用并将在未来的版本中引发错误
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用 deprecate 函数记录弃用信息
deprecate("scale", "1.0.0", deprecation_message)
# 初始化输出状态元组
output_states = ()
# 遍历所有的 ResNet 模型
for resnet in self.resnets:
# 如果处于训练模式并启用梯度检查点
if self.training and self.gradient_checkpointing:
# 定义一个函数用于创建自定义前向传播
def create_custom_forward(module):
# 定义自定义前向传播函数,调用传入的模块
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 检查 PyTorch 版本是否大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用梯度检查点机制计算隐藏状态,禁用重入
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 在较旧版本的 PyTorch 中使用梯度检查点机制计算隐藏状态
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 在非训练模式下直接调用 ResNet 计算隐藏状态
hidden_states = resnet(hidden_states, temb)
# 将当前的隐藏状态添加到输出状态元组中
output_states = output_states + (hidden_states,)
# 检查是否存在下采样器
if self.downsamplers is not None:
# 遍历所有的下采样器
for downsampler in self.downsamplers:
# 使用下采样器计算隐藏状态
hidden_states = downsampler(hidden_states, temb)
# 将当前的隐藏状态添加到输出状态元组中
output_states = output_states + (hidden_states,)
# 返回最终的隐藏状态和输出状态元组
return hidden_states, output_states
# 定义一个简单的二维交叉注意力下采样块类,继承自 nn.Module
class SimpleCrossAttnDownBlock2D(nn.Module):
# 初始化方法,设置输入、输出通道等参数
def __init__(
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 时间嵌入通道数
temb_channels: int,
# dropout 概率,默认值为 0.0
dropout: float = 0.0,
# 层数,默认值为 1
num_layers: int = 1,
# ResNet 中的 epsilon 值,默认值为 1e-6
resnet_eps: float = 1e-6,
# ResNet 的时间尺度偏移设置,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 的激活函数类型,默认为 "swish"
resnet_act_fn: str = "swish",
# ResNet 的分组数,默认为 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用预归一化,默认为 True
resnet_pre_norm: bool = True,
# 注意力头的维度,默认为 1
attention_head_dim: int = 1,
# 交叉注意力的维度,默认为 1280
cross_attention_dim: int = 1280,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 是否添加下采样层,默认为 True
add_downsample: bool = True,
# 是否跳过时间激活,默认为 False
skip_time_act: bool = False,
# 是否仅使用交叉注意力,默认为 False
only_cross_attention: bool = False,
# 交叉注意力的归一化方式,默认为 None
cross_attention_norm: Optional[str] = None,
):
# 调用父类的初始化方法
super().__init__()
# 初始化是否有交叉注意力标志
self.has_cross_attention = True
# 初始化残差网络和注意力模块的列表
resnets = []
attentions = []
# 设置注意力头的维度
self.attention_head_dim = attention_head_dim
# 计算注意力头的数量
self.num_heads = out_channels // self.attention_head_dim
# 根据层数创建残差块
for i in range(num_layers):
# 设置输入通道,第一层使用给定的输入通道,其余层使用输出通道
in_channels = in_channels if i == 0 else out_channels
# 将残差块添加到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 残差网络中的 epsilon 值
groups=resnet_groups, # 分组数量
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入规范化
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否预归一化
skip_time_act=skip_time_act, # 是否跳过时间激活
)
)
# 根据是否有缩放点积注意力创建处理器
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# 将注意力模块添加到列表中
attentions.append(
Attention(
query_dim=out_channels, # 查询维度
cross_attention_dim=out_channels, # 交叉注意力维度
heads=self.num_heads, # 注意力头数量
dim_head=attention_head_dim, # 每个头的维度
added_kv_proj_dim=cross_attention_dim, # 额外的键值投影维度
norm_num_groups=resnet_groups, # 规范化的组数量
bias=True, # 是否使用偏置
upcast_softmax=True, # 是否上调 softmax
only_cross_attention=only_cross_attention, # 是否仅使用交叉注意力
cross_attention_norm=cross_attention_norm, # 交叉注意力的规范化
processor=processor, # 使用的处理器
)
)
# 将注意力模块列表转换为可训练模块
self.attentions = nn.ModuleList(attentions)
# 将残差块列表转换为可训练模块
self.resnets = nn.ModuleList(resnets)
# 如果需要添加下采样
if add_downsample:
# 创建下采样的残差块
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 残差网络中的 epsilon 值
groups=resnet_groups, # 分组数量
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入规范化
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否预归一化
skip_time_act=skip_time_act, # 是否跳过时间激活
down=True, # 表示这是下采样
)
]
)
else:
# 如果不需要下采样,将下采样设置为 None
self.downsamplers = None
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 定义一个前向传播方法,接受多个输入参数
def forward(
self,
# 输入的隐藏状态张量
hidden_states: torch.Tensor,
# 可选的时间嵌入张量
temb: Optional[torch.Tensor] = None,
# 可选的编码器隐藏状态张量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的注意力掩码张量
attention_mask: Optional[torch.Tensor] = None,
# 可选的交叉注意力参数字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的编码器注意力掩码张量
encoder_attention_mask: Optional[torch.Tensor] = None,
# 返回值类型为元组,包含一个张量和一个张量元组
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 如果未提供 cross_attention_kwargs,则使用空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 检查是否传入了 scale 参数,若有则发出弃用警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化输出状态为一个空元组
output_states = ()
# 检查 attention_mask 是否为 None
if attention_mask is None:
# 如果 encoder_hidden_states 已定义,则进行交叉注意力,使用交叉注意力掩码
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# 如果已定义 attention_mask,则直接使用,不检查 encoder_attention_mask
# 这为 UnCLIP 兼容性提供支持
# TODO: UnCLIP 应该通过 encoder_attention_mask 参数表达交叉注意力掩码
mask = attention_mask
# 遍历 ResNet 和注意力层
for resnet, attn in zip(self.resnets, self.attentions):
# 在训练中且开启了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义一个自定义前向传播函数
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
# 根据 return_dict 决定返回方式
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 使用检查点进行前向传播,节省内存
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
# 执行注意力层
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
else:
# 否则直接使用 ResNet 进行前向传播
hidden_states = resnet(hidden_states, temb)
# 执行注意力层
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
# 将当前隐藏状态添加到输出状态元组中
output_states = output_states + (hidden_states,)
# 如果存在下采样层
if self.downsamplers is not None:
# 遍历所有下采样层
for downsampler in self.downsamplers:
# 执行下采样
hidden_states = downsampler(hidden_states, temb)
# 将下采样后的隐藏状态添加到输出状态元组中
output_states = output_states + (hidden_states,)
# 返回最终的隐藏状态和输出状态元组
return hidden_states, output_states
# 定义一个二维下采样的神经网络模块,继承自 nn.Module
class KDownBlock2D(nn.Module):
# 初始化方法,定义输入输出通道数、时间嵌入通道、dropout 概率等参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率,默认为 0
num_layers: int = 4, # 残差层的数量,默认为 4
resnet_eps: float = 1e-5, # 残差层中的 epsilon 值,防止除零错误
resnet_act_fn: str = "gelu", # 残差层使用的激活函数,默认为 GELU
resnet_group_size: int = 32, # 残差层中组的大小
add_downsample: bool = False, # 是否添加下采样层的标志
):
# 调用父类的初始化方法
super().__init__()
resnets = [] # 初始化一个空列表用于存储残差块
# 根据层数构建残差块
for i in range(num_layers):
# 第一层使用输入通道,其他层使用输出通道
in_channels = in_channels if i == 0 else out_channels
# 计算组的数量
groups = in_channels // resnet_group_size
# 计算输出组的数量
groups_out = out_channels // resnet_group_size
# 创建残差块并添加到列表中
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 当前层的输入通道数
out_channels=out_channels, # 当前层的输出通道数
dropout=dropout, # 当前层的 dropout 概率
temb_channels=temb_channels, # 时间嵌入通道数
groups=groups, # 当前层的组数量
groups_out=groups_out, # 输出层的组数量
eps=resnet_eps, # 残差层中的 epsilon 值
non_linearity=resnet_act_fn, # 残差层的激活函数
time_embedding_norm="ada_group", # 时间嵌入的归一化方式
conv_shortcut_bias=False, # 卷积快捷连接是否使用偏置
)
)
# 将残差块列表转换为 nn.ModuleList,以便于参数管理
self.resnets = nn.ModuleList(resnets)
# 根据标志决定是否添加下采样层
if add_downsample:
# 如果需要,创建一个下采样层并添加到列表中
self.downsamplers = nn.ModuleList([KDownsample2D()])
else:
# 如果不需要下采样,设置为 None
self.downsamplers = None
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 前向传播方法,接收隐藏状态和时间嵌入(可选)
def forward(
self, hidden_states: torch.Tensor, # 输入的隐藏状态张量
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
*args, **kwargs # 其他可选参数
# 函数返回一个包含张量和元组的元组
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 检查是否有额外参数或“scale”关键字参数不为 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义关于“scale”参数的弃用消息
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数记录“scale”的弃用
deprecate("scale", "1.0.0", deprecation_message)
# 初始化输出状态为一个空元组
output_states = ()
# 遍历所有的 ResNet 模块
for resnet in self.resnets:
# 如果处于训练模式且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义一个创建自定义前向传播的函数
def create_custom_forward(module):
# 定义自定义前向传播函数,接受任意输入并返回模块的输出
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 如果 PyTorch 版本大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用检查点功能计算隐藏状态,传递自定义前向函数
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 否则使用检查点功能计算隐藏状态
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 如果不满足条件,则直接通过 ResNet 模块计算隐藏状态
hidden_states = resnet(hidden_states, temb)
# 将当前隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 如果存在下采样器
if self.downsamplers is not None:
# 遍历每个下采样器
for downsampler in self.downsamplers:
# 通过下采样器计算隐藏状态
hidden_states = downsampler(hidden_states)
# 返回最终的隐藏状态和输出状态
return hidden_states, output_states
# 定义一个名为 KCrossAttnDownBlock2D 的类,继承自 nn.Module
class KCrossAttnDownBlock2D(nn.Module):
# 初始化方法,接受多个参数以设置模型的结构
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
cross_attention_dim: int, # 跨注意力维度
dropout: float = 0.0, # dropout 概率,默认为 0
num_layers: int = 4, # 层数,默认为 4
resnet_group_size: int = 32, # ResNet 组的大小,默认为 32
add_downsample: bool = True, # 是否添加下采样,默认为 True
attention_head_dim: int = 64, # 注意力头维度,默认为 64
add_self_attention: bool = False, # 是否添加自注意力,默认为 False
resnet_eps: float = 1e-5, # ResNet 的 epsilon 值,默认为 1e-5
resnet_act_fn: str = "gelu", # ResNet 的激活函数类型,默认为 "gelu"
):
# 调用父类的初始化方法
super().__init__()
# 初始化空列表以存放 ResNet 块
resnets = []
# 初始化空列表以存放注意力块
attentions = []
# 设置是否包含跨注意力标志
self.has_cross_attention = True
# 创建指定数量的层
for i in range(num_layers):
# 第一层的输入通道数为 in_channels,之后的层使用 out_channels
in_channels = in_channels if i == 0 else out_channels
# 计算组数
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
# 将 ResnetBlockCondNorm2D 添加到 resnets 列表
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
dropout=dropout, # dropout 概率
temb_channels=temb_channels, # 时间嵌入通道数
groups=groups, # 组数
groups_out=groups_out, # 输出组数
eps=resnet_eps, # epsilon 值
non_linearity=resnet_act_fn, # 激活函数
time_embedding_norm="ada_group", # 时间嵌入归一化类型
conv_shortcut_bias=False, # 是否使用卷积快捷连接偏置
)
)
# 将 KAttentionBlock 添加到 attentions 列表
attentions.append(
KAttentionBlock(
out_channels, # 输出通道数
out_channels // attention_head_dim, # 注意力头数量
attention_head_dim, # 注意力头维度
cross_attention_dim=cross_attention_dim, # 跨注意力维度
temb_channels=temb_channels, # 时间嵌入通道数
attention_bias=True, # 是否使用注意力偏置
add_self_attention=add_self_attention, # 是否添加自注意力
cross_attention_norm="layer_norm", # 跨注意力归一化类型
group_size=resnet_group_size, # 组大小
)
)
# 将 resnets 列表转换为 nn.ModuleList,以便可以在模型中使用
self.resnets = nn.ModuleList(resnets)
# 将 attentions 列表转换为 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 根据参数决定是否添加下采样层
if add_downsample:
# 添加下采样模块
self.downsamplers = nn.ModuleList([KDownsample2D()])
else:
# 如果不添加下采样,设置为 None
self.downsamplers = None
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 前向传播方法,定义输入和输出
def forward(
self,
hidden_states: torch.Tensor, # 隐藏状态输入
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入
encoder_hidden_states: Optional[torch.Tensor] = None, # 可选的编码器隐藏状态
attention_mask: Optional[torch.Tensor] = None, # 可选的注意力掩码
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 可选的跨注意力参数
encoder_attention_mask: Optional[torch.Tensor] = None, # 可选的编码器注意力掩码
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 如果没有传入 cross_attention_kwargs,则初始化为空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 检查 cross_attention_kwargs 中是否存在 "scale" 参数,并发出警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 初始化输出状态为一个空元组
output_states = ()
# 遍历 resnets 和 attentions 的对应元素
for resnet, attn in zip(self.resnets, self.attentions):
# 如果处于训练模式且开启了梯度检查点
if self.training and self.gradient_checkpointing:
# 创建自定义前向传播函数
def create_custom_forward(module, return_dict=None):
# 定义自定义前向传播逻辑
def custom_forward(*inputs):
# 如果指定了返回字典,则返回包含字典的结果
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
# 否则返回普通结果
return module(*inputs)
return custom_forward
# 设置检查点参数,针对 PyTorch 版本进行不同处理
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用检查点机制计算隐藏状态
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 传入自定义前向函数
hidden_states, # 输入隐藏状态
temb, # 传入时间嵌入
**ckpt_kwargs, # 传入检查点参数
)
# 使用注意力机制更新隐藏状态
hidden_states = attn(
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 编码器隐藏状态
emb=temb, # 时间嵌入
attention_mask=attention_mask, # 注意力掩码
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力参数
encoder_attention_mask=encoder_attention_mask, # 编码器注意力掩码
)
else:
# 如果不是训练模式或没有使用梯度检查点,直接通过 ResNet 更新隐藏状态
hidden_states = resnet(hidden_states, temb)
# 使用注意力机制更新隐藏状态
hidden_states = attn(
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 编码器隐藏状态
emb=temb, # 时间嵌入
attention_mask=attention_mask, # 注意力掩码
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力参数
encoder_attention_mask=encoder_attention_mask, # 编码器注意力掩码
)
# 如果没有下采样层,输出状态添加 None
if self.downsamplers is None:
output_states += (None,)
else:
# 否则将当前隐藏状态添加到输出状态
output_states += (hidden_states,)
# 如果存在下采样层,则依次对隐藏状态进行下采样
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
# 返回最终的隐藏状态和输出状态
return hidden_states, output_states
# 定义一个名为 AttnUpBlock2D 的类,继承自 nn.Module
class AttnUpBlock2D(nn.Module):
# 初始化方法,接受多个参数以配置该模块
def __init__(
# 输入通道数
self,
in_channels: int,
# 前一层输出通道数
prev_output_channel: int,
# 输出通道数
out_channels: int,
# 嵌入通道数
temb_channels: int,
# 分辨率索引,默认为 None
resolution_idx: int = None,
# dropout 率,默认为 0.0
dropout: float = 0.0,
# 层数,默认为 1
num_layers: int = 1,
# ResNet 中的小常数,避免除零
resnet_eps: float = 1e-6,
# ResNet 时间缩放偏移,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 激活函数,默认为 "swish"
resnet_act_fn: str = "swish",
# ResNet 组数,默认为 32
resnet_groups: int = 32,
# 是否在 ResNet 中使用预归一化,默认为 True
resnet_pre_norm: bool = True,
# 注意力头维度,默认为 1
attention_head_dim: int = 1,
# 输出缩放因子,默认为 1.0
output_scale_factor: float = 1.0,
# 上采样类型,默认为 "conv"
upsample_type: str = "conv",
):
# 调用父类的初始化方法
super().__init__()
# 初始化空列表用于存储 ResNet 块
resnets = []
# 初始化空列表用于存储注意力层
attentions = []
# 设置上采样类型
self.upsample_type = upsample_type
# 如果没有传入注意力头维度
if attention_head_dim is None:
# 记录警告,建议使用默认的头维度
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
)
# 将注意力头维度设置为输出通道数
attention_head_dim = out_channels
# 遍历每一层
for i in range(num_layers):
# 设置残差跳过通道数,最后一层使用输入通道
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 设置当前 ResNet 块的输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 创建 ResNet 块并添加到列表
resnets.append(
ResnetBlock2D(
# 输入通道数为当前 ResNet 的输入通道加上跳过的通道
in_channels=resnet_in_channels + res_skip_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 余弦相似性小常数
eps=resnet_eps,
# 分组数
groups=resnet_groups,
# dropout 概率
dropout=dropout,
# 时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 非线性激活函数
non_linearity=resnet_act_fn,
# 输出缩放因子
output_scale_factor=output_scale_factor,
# 是否进行预归一化
pre_norm=resnet_pre_norm,
)
)
# 创建注意力层并添加到列表
attentions.append(
Attention(
# 输出通道数
out_channels,
# 注意力头的数量
heads=out_channels // attention_head_dim,
# 每个头的维度
dim_head=attention_head_dim,
# 输出重缩放因子
rescale_output_factor=output_scale_factor,
# 余弦相似性小常数
eps=resnet_eps,
# 归一化的组数
norm_num_groups=resnet_groups,
# 是否使用残差连接
residual_connection=True,
# 是否使用偏置
bias=True,
# 是否上升 softmax 的精度
upcast_softmax=True,
# 是否从已弃用的注意力块中获取
_from_deprecated_attn_block=True,
)
)
# 将注意力层列表转换为 nn.ModuleList,以便于管理
self.attentions = nn.ModuleList(attentions)
# 将 ResNet 块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根据上采样类型选择上采样方法
if upsample_type == "conv":
# 使用卷积上采样,并创建 ModuleList
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
elif upsample_type == "resnet":
# 使用 ResNet 块进行上采样,并创建 ModuleList
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
# 输入通道数
in_channels=out_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 余弦相似性小常数
eps=resnet_eps,
# 分组数
groups=resnet_groups,
# dropout 概率
dropout=dropout,
# 时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 非线性激活函数
non_linearity=resnet_act_fn,
# 输出缩放因子
output_scale_factor=output_scale_factor,
# 是否进行预归一化
pre_norm=resnet_pre_norm,
# 表示这是一个上采样的块
up=True,
)
]
)
else:
# 如果上采样类型无效,则设为 None
self.upsamplers = None
# 存储当前分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播函数,接收隐藏状态及其他参数
def forward(
self,
hidden_states: torch.Tensor, # 当前的隐藏状态张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前的隐藏状态元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
upsample_size: Optional[int] = None, # 可选的上采样大小
*args, # 额外的位置参数
**kwargs, # 额外的关键字参数
) -> torch.Tensor: # 返回类型为张量
# 检查是否传入了多余的参数或已弃用的 scale 参数
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 设置弃用警告信息
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数发出警告
deprecate("scale", "1.0.0", deprecation_message)
# 遍历每一对残差网络和注意力层
for resnet, attn in zip(self.resnets, self.attentions):
# 从元组中弹出最后一个残差隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新残差隐藏状态元组,去掉最后一个元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 将当前隐藏状态和残差隐藏状态在维度1上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 将拼接后的隐藏状态传入残差网络
hidden_states = resnet(hidden_states, temb)
# 将输出的隐藏状态传入注意力层
hidden_states = attn(hidden_states)
# 检查是否存在上采样器
if self.upsamplers is not None:
# 遍历每个上采样器
for upsampler in self.upsamplers:
# 根据上采样类型选择处理方式
if self.upsample_type == "resnet":
# 将隐藏状态传入上采样器并提供时间嵌入
hidden_states = upsampler(hidden_states, temb=temb)
else:
# 将隐藏状态传入上采样器
hidden_states = upsampler(hidden_states)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个名为 CrossAttnUpBlock2D 的类,继承自 nn.Module
class CrossAttnUpBlock2D(nn.Module):
# 初始化方法,定义该类的构造函数
def __init__(
# 输入通道数
self,
in_channels: int,
# 输出通道数
out_channels: int,
# 前一层输出通道数
prev_output_channel: int,
# 时间嵌入通道数
temb_channels: int,
# 可选的分辨率索引
resolution_idx: Optional[int] = None,
# Dropout 概率
dropout: float = 0.0,
# 层数
num_layers: int = 1,
# 每个块的 Transformer 层数,可以是单个整数或元组
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,避免除零错误
resnet_eps: float = 1e-6,
# ResNet 的时间尺度偏移参数
resnet_time_scale_shift: str = "default",
# ResNet 的激活函数类型
resnet_act_fn: str = "swish",
# ResNet 的组数
resnet_groups: int = 32,
# 是否使用预归一化
resnet_pre_norm: bool = True,
# 注意力头的数量
num_attention_heads: int = 1,
# 跨注意力的维度
cross_attention_dim: int = 1280,
# 输出缩放因子
output_scale_factor: float = 1.0,
# 是否添加上采样层
add_upsample: bool = True,
# 是否使用双重跨注意力
dual_cross_attention: bool = False,
# 是否使用线性投影
use_linear_projection: bool = False,
# 是否仅使用跨注意力
only_cross_attention: bool = False,
# 是否提升注意力计算精度
upcast_attention: bool = False,
# 注意力类型
attention_type: str = "default",
# 继承父类的初始化方法
):
super().__init__()
# 初始化空列表用于存储 ResNet 和注意力模块
resnets = []
attentions = []
# 设置是否有交叉注意力的标志
self.has_cross_attention = True
# 设置注意力头的数量
self.num_attention_heads = num_attention_heads
# 如果输入的是整数,则将其转换为包含多个相同值的列表
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍历每一层,构建 ResNet 和注意力模块
for i in range(num_layers):
# 设置残差跳跃通道数量
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 设置 ResNet 输入通道
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 将 ResNet 模块添加到列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 小常数用于数值稳定性
groups=resnet_groups, # 分组数
dropout=dropout, # Dropout 率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入的归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否进行预归一化
)
)
# 根据是否启用双重交叉注意力,选择不同的注意力模块
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads, # 注意力头数量
out_channels // num_attention_heads, # 每个头的输出通道数
in_channels=out_channels, # 输入通道数
num_layers=transformer_layers_per_block[i], # 当前层的变换器层数
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
norm_num_groups=resnet_groups, # 归一化分组数
use_linear_projection=use_linear_projection, # 是否使用线性投影
only_cross_attention=only_cross_attention, # 是否仅使用交叉注意力
upcast_attention=upcast_attention, # 是否上调注意力精度
attention_type=attention_type, # 注意力类型
)
)
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads, # 注意力头数量
out_channels // num_attention_heads, # 每个头的输出通道数
in_channels=out_channels, # 输入通道数
num_layers=1, # 仅使用一层
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
norm_num_groups=resnet_groups, # 归一化分组数
)
)
# 将注意力模块和 ResNet 模块转换为 nn.ModuleList 以便于管理
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
# 根据是否添加上采样层初始化上采样模块
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
# 初始化梯度检查点标志
self.gradient_checkpointing = False
# 设置分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播函数,接收多个参数
def forward(
self,
# 当前隐藏状态的张量
hidden_states: torch.Tensor,
# 元组,包含残差隐藏状态的张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可选的时间嵌入张量
temb: Optional[torch.Tensor] = None,
# 可选的编码器隐藏状态张量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的跨注意力参数字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的上采样大小
upsample_size: Optional[int] = None,
# 可选的注意力掩码张量
attention_mask: Optional[torch.Tensor] = None,
# 可选的编码器注意力掩码张量
encoder_attention_mask: Optional[torch.Tensor] = None,
# 定义一个名为 UpBlock2D 的类,继承自 nn.Module
class UpBlock2D(nn.Module):
# 初始化方法,接受多个参数来构造 UpBlock2D 对象
def __init__(
self,
in_channels: int, # 输入通道数
prev_output_channel: int, # 前一层输出的通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
resolution_idx: Optional[int] = None, # 分辨率索引,默认为 None
dropout: float = 0.0, # dropout 概率,默认为 0.0
num_layers: int = 1, # 层数,默认为 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值,默认为 1e-6
resnet_time_scale_shift: str = "default", # ResNet 的时间缩放偏移,默认为 "default"
resnet_act_fn: str = "swish", # ResNet 的激活函数,默认为 "swish"
resnet_groups: int = 32, # ResNet 的组数,默认为 32
resnet_pre_norm: bool = True, # 是否进行预归一化,默认为 True
output_scale_factor: float = 1.0, # 输出缩放因子,默认为 1.0
add_upsample: bool = True, # 是否添加上采样层,默认为 True
):
# 调用父类的初始化方法
super().__init__()
resnets = [] # 初始化一个空列表用于存储 ResNet 块
# 根据 num_layers 创建 ResNet 块
for i in range(num_layers):
# 设置跳过通道数,如果是最后一层,则使用 in_channels,否则使用 out_channels
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 设置 ResNet 输入通道数,第一层使用 prev_output_channel,其余层使用 out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 将 ResNet 块添加到列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 输入通道数加上跳过通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 预归一化
)
)
# 将 ResNet 块列表转换为 nn.ModuleList 以便于管理
self.resnets = nn.ModuleList(resnets)
# 如果需要添加上采样层,则初始化上采样模块
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None # 不添加上采样层,设置为 None
self.gradient_checkpointing = False # 初始化梯度检查点标志为 False
self.resolution_idx = resolution_idx # 保存分辨率索引
# 定义前向传播方法,接受输入的隐藏状态及其他参数
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前层的隐藏状态元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
upsample_size: Optional[int] = None, # 可选的上采样大小
*args, # 其他位置参数
**kwargs, # 其他关键字参数
) -> torch.Tensor:
# 检查是否传入参数或 'scale' 参数
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 设置弃用消息,提示用户 'scale' 参数已被弃用
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用 deprecate 函数记录弃用
deprecate("scale", "1.0.0", deprecation_message)
# 检查 FreeU 是否启用
is_freeu_enabled = (
getattr(self, "s1", None) # 获取属性 s1
and getattr(self, "s2", None) # 获取属性 s2
and getattr(self, "b1", None) # 获取属性 b1
and getattr(self, "b2", None) # 获取属性 b2
)
# 遍历所有 ResNet 模型
for resnet in self.resnets:
# 弹出最后的 ResNet 隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] # 更新元组,去掉最后一个元素
# 如果 FreeU 被启用,则仅对前两个阶段进行操作
if is_freeu_enabled:
# 调用 apply_freeu 函数处理隐藏状态
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx, # 当前分辨率索引
hidden_states, # 当前隐藏状态
res_hidden_states, # ResNet 隐藏状态
s1=self.s1, # s1 属性
s2=self.s2, # s2 属性
b1=self.b1, # b1 属性
b2=self.b2, # b2 属性
)
# 连接当前隐藏状态和 ResNet 隐藏状态
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果处于训练模式且启用梯度检查点
if self.training and self.gradient_checkpointing:
# 创建自定义前向传播函数
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs) # 调用模块的前向传播
return custom_forward
# 根据 PyTorch 版本选择检查点方式
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定义前向函数
hidden_states, # 当前隐藏状态
temb, # 时间嵌入
use_reentrant=False # 不使用可重入检查点
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定义前向函数
hidden_states, # 当前隐藏状态
temb # 时间嵌入
)
else:
# 如果不启用检查点,直接调用 ResNet
hidden_states = resnet(hidden_states, temb)
# 如果存在上采样器,则遍历进行上采样
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size) # 调用上采样器
# 返回最终的隐藏状态
return hidden_states
# 定义一个 2D 上采样解码块类,继承自 nn.Module
class UpDecoderBlock2D(nn.Module):
# 初始化方法,接受多个参数用于构造解码块
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
resolution_idx: Optional[int] = None, # 分辨率索引,默认为 None
dropout: float = 0.0, # dropout 概率,默认为 0
num_layers: int = 1, # 层数,默认为 1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值,默认为 1e-6
resnet_time_scale_shift: str = "default", # 时间尺度偏移类型,默认为 "default"
resnet_act_fn: str = "swish", # ResNet 的激活函数,默认为 "swish"
resnet_groups: int = 32, # ResNet 的分组数,默认为 32
resnet_pre_norm: bool = True, # 是否在 ResNet 中预先归一化,默认为 True
output_scale_factor: float = 1.0, # 输出缩放因子,默认为 1.0
add_upsample: bool = True, # 是否添加上采样层,默认为 True
temb_channels: Optional[int] = None, # 时间嵌入通道数,默认为 None
):
# 调用父类构造方法
super().__init__()
# 初始化一个空的 ResNet 列表
resnets = []
# 根据层数创建相应数量的 ResNet 层
for i in range(num_layers):
# 第一个层使用输入通道数,其余层使用输出通道数
input_channels = in_channels if i == 0 else out_channels
# 根据时间尺度偏移类型创建不同的 ResNet 块
if resnet_time_scale_shift == "spatial":
resnets.append(
ResnetBlockCondNorm2D( # 添加条件归一化的 ResNet 块
in_channels=input_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 分组数
dropout=dropout, # dropout 概率
time_embedding_norm="spatial", # 时间嵌入归一化类型
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
)
)
else:
resnets.append(
ResnetBlock2D( # 添加普通的 ResNet 块
in_channels=input_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 分组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化类型
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否预先归一化
)
)
# 将创建的 ResNet 块存储在 ModuleList 中,以便于管理
self.resnets = nn.ModuleList(resnets)
# 根据是否添加上采样层初始化上采样层列表
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None # 如果不添加,则设为 None
# 存储分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播方法
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 遍历所有 ResNet 层进行前向传播
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb) # 更新隐藏状态
# 如果存在上采样层,则遍历进行上采样
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states) # 更新隐藏状态
# 返回最终的隐藏状态
return hidden_states
# 定义一个注意力上采样解码块类,继承自 nn.Module
class AttnUpDecoderBlock2D(nn.Module):
# 初始化类的构造函数,定义各参数
def __init__(
# 输入通道数,决定输入数据的特征维度
self,
in_channels: int,
# 输出通道数,决定输出数据的特征维度
out_channels: int,
# 分辨率索引,选择特定分辨率(可选)
resolution_idx: Optional[int] = None,
# dropout比率,用于防止过拟合,默认为0
dropout: float = 0.0,
# 网络层数,决定模型的深度,默认为1
num_layers: int = 1,
# ResNet的epsilon值,防止分母为0的情况,默认为1e-6
resnet_eps: float = 1e-6,
# ResNet的时间尺度偏移设置,默认为"default"
resnet_time_scale_shift: str = "default",
# ResNet的激活函数类型,默认为"swish"
resnet_act_fn: str = "swish",
# ResNet中的组数,影响计算和模型复杂度,默认为32
resnet_groups: int = 32,
# 是否在ResNet中使用预归一化,默认为True
resnet_pre_norm: bool = True,
# 注意力头的维度,影响注意力机制的计算,默认为1
attention_head_dim: int = 1,
# 输出缩放因子,调整输出的大小,默认为1.0
output_scale_factor: float = 1.0,
# 是否添加上采样层,影响模型的结构,默认为True
add_upsample: bool = True,
# 时间嵌入通道数(可选),用于特定的时间信息表示
temb_channels: Optional[int] = None,
):
# 调用父类的初始化方法
super().__init__()
# 初始化存储残差块的列表
resnets = []
# 初始化存储注意力层的列表
attentions = []
# 如果未指定注意力头维度,则发出警告并使用输出通道数作为默认值
if attention_head_dim is None:
logger.warning(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
)
# 将注意力头维度设置为输出通道数
attention_head_dim = out_channels
# 遍历每一层以构建残差块和注意力层
for i in range(num_layers):
# 如果是第一层,则输入通道为输入通道数,否则为输出通道数
input_channels = in_channels if i == 0 else out_channels
# 如果时间尺度偏移为"spatial",则使用条件归一化的残差块
if resnet_time_scale_shift == "spatial":
resnets.append(
ResnetBlockCondNorm2D(
# 输入通道数
in_channels=input_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 残差块的epsilon参数
eps=resnet_eps,
# 组归一化的组数
groups=resnet_groups,
# dropout比例
dropout=dropout,
# 时间嵌入的归一化方式
time_embedding_norm="spatial",
# 非线性激活函数
non_linearity=resnet_act_fn,
# 输出缩放因子
output_scale_factor=output_scale_factor,
)
)
else:
# 否则使用普通的2D残差块
resnets.append(
ResnetBlock2D(
# 输入通道数
in_channels=input_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 残差块的epsilon参数
eps=resnet_eps,
# 组归一化的组数
groups=resnet_groups,
# dropout比例
dropout=dropout,
# 时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 非线性激活函数
non_linearity=resnet_act_fn,
# 输出缩放因子
output_scale_factor=output_scale_factor,
# 是否使用预归一化
pre_norm=resnet_pre_norm,
)
)
# 添加注意力层
attentions.append(
Attention(
# 输出通道数
out_channels,
# 计算注意力头数
heads=out_channels // attention_head_dim,
# 注意力头维度
dim_head=attention_head_dim,
# 输出缩放因子
rescale_output_factor=output_scale_factor,
# epsilon参数
eps=resnet_eps,
# 组归一化的组数(如果不是空间归一化)
norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
# 空间归一化维度(如果是空间归一化)
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
# 是否使用残差连接
residual_connection=True,
# 是否使用偏置
bias=True,
# 是否使用上采样的softmax
upcast_softmax=True,
# 从过时的注意力块创建
_from_deprecated_attn_block=True,
)
)
# 将注意力层转换为模块列表
self.attentions = nn.ModuleList(attentions)
# 将残差块转换为模块列表
self.resnets = nn.ModuleList(resnets)
# 如果需要添加上采样层
if add_upsample:
# 创建上采样层的模块列表
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 如果不需要上采样,则将其设置为None
self.upsamplers = None
# 设置分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播函数,接收隐藏状态和可选的时间嵌入,返回处理后的隐藏状态
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 遍历残差网络和注意力模块的组合
for resnet, attn in zip(self.resnets, self.attentions):
# 将当前隐藏状态输入到残差网络中,可能包含时间嵌入
hidden_states = resnet(hidden_states, temb=temb)
# 将残差网络的输出输入到注意力模块中,可能包含时间嵌入
hidden_states = attn(hidden_states, temb=temb)
# 如果上采样模块不为空,则执行上采样操作
if self.upsamplers is not None:
# 遍历所有上采样模块
for upsampler in self.upsamplers:
# 将当前隐藏状态输入到上采样模块中
hidden_states = upsampler(hidden_states)
# 返回最终处理后的隐藏状态
return hidden_states
# 定义一个名为 AttnSkipUpBlock2D 的类,继承自 nn.Module
class AttnSkipUpBlock2D(nn.Module):
# 初始化方法,定义该类的属性
def __init__(
# 输入通道数
in_channels: int,
# 前一个输出通道数
prev_output_channel: int,
# 输出通道数
out_channels: int,
# 时间嵌入通道数
temb_channels: int,
# 可选的分辨率索引
resolution_idx: Optional[int] = None,
# dropout 概率
dropout: float = 0.0,
# 层数
num_layers: int = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 时间缩放偏移设置
resnet_time_scale_shift: str = "default",
# ResNet 激活函数类型
resnet_act_fn: str = "swish",
# 是否使用 ResNet 预归一化
resnet_pre_norm: bool = True,
# 注意力头的维度
attention_head_dim: int = 1,
# 输出缩放因子
output_scale_factor: float = np.sqrt(2.0),
# 是否添加上采样
add_upsample: bool = True,
):
# 初始化父类
super().__init__()
# 此处应有具体的初始化代码(如层的定义),略去
# 定义前向传播方法
def forward(
# 隐藏状态输入
hidden_states: torch.Tensor,
# 之前的隐藏状态元组
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可选的时间嵌入
temb: Optional[torch.Tensor] = None,
# 可选的跳跃样本
skip_sample=None,
# 额外参数
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 检查 args 和 kwargs 是否包含 scale 参数
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用信息
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数
deprecate("scale", "1.0.0", deprecation_message)
# 遍历 ResNet 层
for resnet in self.resnets:
# 从元组中提取最近的隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隐藏状态元组,去掉最近的隐藏状态
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 将当前隐藏状态与提取的隐藏状态拼接在一起
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 通过 ResNet 层处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 通过注意力层处理隐藏状态
hidden_states = self.attentions[0](hidden_states)
# 检查跳跃样本是否为 None
if skip_sample is not None:
# 如果不是,则通过上采样层处理跳跃样本
skip_sample = self.upsampler(skip_sample)
else:
# 如果是,则将跳跃样本设为 0
skip_sample = 0
# 检查 ResNet 上采样层是否存在
if self.resnet_up is not None:
# 通过跳跃归一化层处理隐藏状态
skip_sample_states = self.skip_norm(hidden_states)
# 应用激活函数
skip_sample_states = self.act(skip_sample_states)
# 通过跳跃卷积层处理状态
skip_sample_states = self.skip_conv(skip_sample_states)
# 更新跳跃样本
skip_sample = skip_sample + skip_sample_states
# 通过 ResNet 上采样层处理隐藏状态
hidden_states = self.resnet_up(hidden_states, temb)
# 返回处理后的隐藏状态和跳跃样本
return hidden_states, skip_sample
# 定义一个名为 SkipUpBlock2D 的类,继承自 nn.Module
class SkipUpBlock2D(nn.Module):
# 初始化方法,定义该类的属性
def __init__(
# 输入通道数
in_channels: int,
# 前一个输出通道数
prev_output_channel: int,
# 输出通道数
out_channels: int,
# 时间嵌入通道数
temb_channels: int,
# 可选的分辨率索引
resolution_idx: Optional[int] = None,
# dropout 概率
dropout: float = 0.0,
# 层数
num_layers: int = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 时间缩放偏移设置
resnet_time_scale_shift: str = "default",
# ResNet 激活函数类型
resnet_act_fn: str = "swish",
# 是否使用 ResNet 预归一化
resnet_pre_norm: bool = True,
# 输出缩放因子
output_scale_factor: float = np.sqrt(2.0),
# 是否添加上采样
add_upsample: bool = True,
# 上采样填充大小
upsample_padding: int = 1,
):
# 初始化父类
super().__init__()
# 此处应有具体的初始化代码(如层的定义),略去
):
# 调用父类的初始化方法
super().__init__()
# 创建一个空的 ModuleList 用于存储 ResnetBlock2D 层
self.resnets = nn.ModuleList([])
# 根据 num_layers 的数量来添加 ResnetBlock2D 层
for i in range(num_layers):
# 计算跳过通道数,如果是最后一层则使用 in_channels,否则使用 out_channels
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 确定当前 ResNet 块的输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 向 resnets 列表中添加一个新的 ResnetBlock2D 实例
self.resnets.append(
ResnetBlock2D(
# 输入通道数为当前层输入通道加跳过通道数
in_channels=resnet_in_channels + res_skip_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 归一化的 epsilon 值
eps=resnet_eps,
# 分组数为输入通道数的一部分,最多为 32
groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
# 输出分组数,同样最多为 32
groups_out=min(out_channels // 4, 32),
# dropout 概率
dropout=dropout,
# 时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 激活函数类型
non_linearity=resnet_act_fn,
# 输出缩放因子
output_scale_factor=output_scale_factor,
# 是否使用预归一化
pre_norm=resnet_pre_norm,
)
)
# 初始化上采样层
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
# 如果需要添加上采样层
if add_upsample:
# 添加一个上采样的 ResnetBlock2D
self.resnet_up = ResnetBlock2D(
# 输入通道数
in_channels=out_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 归一化的 epsilon 值
eps=resnet_eps,
# 分组数,最多为 32
groups=min(out_channels // 4, 32),
# 输出分组数,最多为 32
groups_out=min(out_channels // 4, 32),
# dropout 概率
dropout=dropout,
# 时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 激活函数类型
non_linearity=resnet_act_fn,
# 输出缩放因子
output_scale_factor=output_scale_factor,
# 是否使用预归一化
pre_norm=resnet_pre_norm,
# 是否在快捷路径中使用
use_in_shortcut=True,
# 标记为上采样
up=True,
# 使用 FIR 卷积核
kernel="fir",
)
# 定义跳过连接的卷积层,将输出通道数映射到 3 通道
self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# 定义跳过连接的归一化层
self.skip_norm = torch.nn.GroupNorm(
num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
)
# 定义激活函数为 SiLU
self.act = nn.SiLU()
else:
# 如果不添加上采样层,则将相关属性设为 None
self.resnet_up = None
self.skip_conv = None
self.skip_norm = None
self.act = None
# 保存分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播方法
def forward(
# 定义前向传播输入参数
hidden_states: torch.Tensor,
# 存储残差隐藏状态的元组
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可选的时间嵌入张量
temb: Optional[torch.Tensor] = None,
# 跳过采样的可选参数
skip_sample=None,
# 可变位置参数
*args,
# 可变关键字参数
**kwargs,
# 函数返回两个张量,表示隐藏状态和跳过的样本
) -> Tuple[torch.Tensor, torch.Tensor]:
# 检查参数长度或关键字参数 "scale" 是否存在
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 创建弃用消息,提醒用户 "scale" 参数将被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用弃用函数,传递参数信息
deprecate("scale", "1.0.0", deprecation_message)
# 遍历自定义的 ResNet 模块
for resnet in self.resnets:
# 从隐藏状态元组中弹出最后一个 ResNet 隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隐藏状态元组,去掉最后一个元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 将当前的隐藏状态与 ResNet 隐藏状态连接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 通过 ResNet 模块更新隐藏状态
hidden_states = resnet(hidden_states, temb)
# 检查跳过样本是否存在
if skip_sample is not None:
# 如果存在,使用上采样器处理跳过样本
skip_sample = self.upsampler(skip_sample)
else:
# 否则,将跳过样本初始化为 0
skip_sample = 0
# 检查是否存在 ResNet 上采样模块
if self.resnet_up is not None:
# 对隐藏状态应用归一化
skip_sample_states = self.skip_norm(hidden_states)
# 对归一化结果应用激活函数
skip_sample_states = self.act(skip_sample_states)
# 对激活结果应用卷积操作
skip_sample_states = self.skip_conv(skip_sample_states)
# 将跳过样本与处理后的状态相加
skip_sample = skip_sample + skip_sample_states
# 通过 ResNet 上采样模块更新隐藏状态
hidden_states = self.resnet_up(hidden_states, temb)
# 返回最终的隐藏状态和跳过样本
return hidden_states, skip_sample
# 定义一个 2D 上采样的 ResNet 块,继承自 nn.Module
class ResnetUpsampleBlock2D(nn.Module):
# 初始化方法,设置网络参数
def __init__(
self,
in_channels: int, # 输入通道数
prev_output_channel: int, # 前一层输出的通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入的通道数
resolution_idx: Optional[int] = None, # 分辨率索引(可选)
dropout: float = 0.0, # dropout 比例
num_layers: int = 1, # ResNet 层数
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
resnet_time_scale_shift: str = "default", # 时间缩放偏移方式
resnet_act_fn: str = "swish", # 激活函数类型
resnet_groups: int = 32, # 组数
resnet_pre_norm: bool = True, # 是否使用预归一化
output_scale_factor: float = 1.0, # 输出缩放因子
add_upsample: bool = True, # 是否添加上采样层
skip_time_act: bool = False, # 是否跳过时间激活
):
# 调用父类构造函数
super().__init__()
# 初始化一个空的 ResNet 列表
resnets = []
# 遍历层数,创建每一层的 ResNet 块
for i in range(num_layers):
# 确定跳过通道数,如果是最后一层,则使用输入通道数,否则使用输出通道数
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 确定当前层的输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 将 ResNet 块添加到列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 比例
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否使用预归一化
skip_time_act=skip_time_act, # 是否跳过时间激活
)
)
# 将 ResNet 列表转换为 nn.ModuleList,便于 PyTorch 管理
self.resnets = nn.ModuleList(resnets)
# 如果需要添加上采样层,则创建上采样模块
if add_upsample:
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 比例
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否使用预归一化
skip_time_act=skip_time_act, # 是否跳过时间激活
up=True, # 标记为上采样块
)
]
)
else:
# 如果不需要上采样,则将其设为 None
self.upsamplers = None
# 初始化梯度检查点为 False
self.gradient_checkpointing = False
# 设置分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 额外的隐藏状态元组
temb: Optional[torch.Tensor] = None, # 时间嵌入(可选)
upsample_size: Optional[int] = None, # 上采样大小(可选)
*args, # 额外的位置参数
**kwargs, # 额外的关键字参数
) -> torch.Tensor: # 定义一个返回 torch.Tensor 的函数
# 检查参数列表是否包含参数或 "scale" 是否不为 None
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用信息,说明 "scale" 参数将被忽略
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用 deprecate 函数以记录 "scale" 参数的弃用
deprecate("scale", "1.0.0", deprecation_message)
# 遍历存储的 ResNet 模型列表
for resnet in self.resnets:
# 从隐藏状态元组中弹出最后一个 ResNet 隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隐藏状态元组,去掉最后一个元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 将当前的隐藏状态与 ResNet 隐藏状态在指定维度上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果处于训练模式且开启了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义一个创建自定义前向传播函数的内部函数
def create_custom_forward(module):
# 定义自定义前向传播函数,调用模块
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 检查 PyTorch 版本是否大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用梯度检查点功能进行前向传播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 使用梯度检查点功能进行前向传播,不使用可重入选项
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 直接调用 ResNet 模型进行前向传播
hidden_states = resnet(hidden_states, temb)
# 如果存在上采样器
if self.upsamplers is not None:
# 遍历每个上采样器
for upsampler in self.upsamplers:
# 调用上采样器进行上采样处理
hidden_states = upsampler(hidden_states, temb)
# 返回最终的隐藏状态
return hidden_states
# 定义一个简单的二维交叉注意力上采样模块,继承自 nn.Module
class SimpleCrossAttnUpBlock2D(nn.Module):
# 初始化函数,接受多个参数来配置模块
def __init__(
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 上一个输出通道数
prev_output_channel: int,
# 时间嵌入通道数
temb_channels: int,
# 可选的分辨率索引
resolution_idx: Optional[int] = None,
# Dropout 概率
dropout: float = 0.0,
# 层数
num_layers: int = 1,
# ResNet 的 epsilon 值
resnet_eps: float = 1e-6,
# ResNet 时间缩放偏移
resnet_time_scale_shift: str = "default",
# ResNet 激活函数
resnet_act_fn: str = "swish",
# ResNet 中的组数
resnet_groups: int = 32,
# 是否在 ResNet 中使用预归一化
resnet_pre_norm: bool = True,
# 注意力头的维度
attention_head_dim: int = 1,
# 交叉注意力的维度
cross_attention_dim: int = 1280,
# 输出缩放因子
output_scale_factor: float = 1.0,
# 是否添加上采样
add_upsample: bool = True,
# 是否跳过时间激活
skip_time_act: bool = False,
# 是否仅使用交叉注意力
only_cross_attention: bool = False,
# 可选的交叉注意力归一化方式
cross_attention_norm: Optional[str] = None,
# 初始化函数
):
# 调用父类的初始化方法
super().__init__()
# 创建一个空列表用于存储 ResNet 模块
resnets = []
# 创建一个空列表用于存储 Attention 模块
attentions = []
# 设置是否使用交叉注意力
self.has_cross_attention = True
# 设置每个注意力头的维度
self.attention_head_dim = attention_head_dim
# 计算注意力头的数量
self.num_heads = out_channels // self.attention_head_dim
# 遍历每一层以构建 ResNet 模块
for i in range(num_layers):
# 设置跳跃连接通道数
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 设置当前 ResNet 输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 添加 ResNet 块到列表
resnets.append(
ResnetBlock2D(
# 设置输入通道数为 ResNet 输入通道加上跳跃连接通道
in_channels=resnet_in_channels + res_skip_channels,
# 设置输出通道数
out_channels=out_channels,
# 设置时间嵌入通道数
temb_channels=temb_channels,
# 设置小常数用于数值稳定性
eps=resnet_eps,
# 设置分组数
groups=resnet_groups,
# 设置 dropout 比例
dropout=dropout,
# 设置时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 设置激活函数
non_linearity=resnet_act_fn,
# 设置输出缩放因子
output_scale_factor=output_scale_factor,
# 设置是否预归一化
pre_norm=resnet_pre_norm,
# 设置是否跳过时间激活
skip_time_act=skip_time_act,
)
)
# 根据是否支持缩放点积注意力选择处理器
processor = (
AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
)
# 添加 Attention 模块到列表
attentions.append(
Attention(
# 设置查询维度
query_dim=out_channels,
# 设置交叉注意力维度
cross_attention_dim=out_channels,
# 设置头的数量
heads=self.num_heads,
# 设置每个头的维度
dim_head=self.attention_head_dim,
# 设置额外的 KV 投影维度
added_kv_proj_dim=cross_attention_dim,
# 设置归一化的组数
norm_num_groups=resnet_groups,
# 设置是否使用偏置
bias=True,
# 设置是否上溯 softmax
upcast_softmax=True,
# 设置是否仅使用交叉注意力
only_cross_attention=only_cross_attention,
# 设置交叉注意力的归一化方式
cross_attention_norm=cross_attention_norm,
# 设置处理器
processor=processor,
)
)
# 将 Attention 模块列表转换为 ModuleList
self.attentions = nn.ModuleList(attentions)
# 将 ResNet 模块列表转换为 ModuleList
self.resnets = nn.ModuleList(resnets)
# 如果需要添加上采样模块
if add_upsample:
# 创建一个上采样的 ResNet 模块列表
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
# 设置上采样的输入和输出通道数
in_channels=out_channels,
out_channels=out_channels,
# 设置时间嵌入通道数
temb_channels=temb_channels,
# 设置小常数用于数值稳定性
eps=resnet_eps,
# 设置分组数
groups=resnet_groups,
# 设置 dropout 比例
dropout=dropout,
# 设置时间嵌入的归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 设置激活函数
non_linearity=resnet_act_fn,
# 设置输出缩放因子
output_scale_factor=output_scale_factor,
# 设置是否预归一化
pre_norm=resnet_pre_norm,
# 设置是否跳过时间激活
skip_time_act=skip_time_act,
# 设置为上采样模式
up=True,
)
]
)
else:
# 如果不需要上采样,则将上采样模块设置为 None
self.upsamplers = None
# 初始化梯度检查点设置为 False
self.gradient_checkpointing = False
# 设置分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播方法,接收多个输入参数
def forward(
self,
# 当前隐藏状态的张量
hidden_states: torch.Tensor,
# 包含残差隐藏状态的元组
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可选的时间嵌入张量
temb: Optional[torch.Tensor] = None,
# 可选的编码器隐藏状态张量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的上采样大小
upsample_size: Optional[int] = None,
# 可选的注意力掩码张量
attention_mask: Optional[torch.Tensor] = None,
# 可选的跨注意力参数字典
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的编码器注意力掩码张量
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: # 定义函数的返回类型为 torch.Tensor
# 如果 cross_attention_kwargs 为 None,则初始化为空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 如果 cross_attention_kwargs 中的 "scale" 存在,发出警告,提示已弃用
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 如果 attention_mask 为 None
if attention_mask is None:
# 如果 encoder_hidden_states 已定义,则进行交叉注意力,使用 encoder_attention_mask
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# 如果 attention_mask 已定义,不检查 encoder_attention_mask
# 这样做是为了兼容 UnCLIP,后者使用 'attention_mask' 参数作为交叉注意力掩码
# TODO: UnCLIP 应该通过 encoder_attention_mask 参数表达交叉注意力掩码,而不是通过 attention_mask
# 那样可以简化整个 if/else 语句块
mask = attention_mask # 使用提供的 attention_mask
# 遍历 self.resnets 和 self.attentions 的元素
for resnet, attn in zip(self.resnets, self.attentions):
# 获取最后一项的残差隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新 res_hidden_states_tuple,去掉最后一项
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 将当前的 hidden_states 和残差隐藏状态在维度 1 上连接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果处于训练模式并且开启了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义创建自定义前向传播函数的内部函数
def create_custom_forward(module, return_dict=None):
# 定义自定义前向传播函数
def custom_forward(*inputs):
# 如果 return_dict 不为 None,则返回字典形式的结果
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs) # 否则直接返回结果
return custom_forward # 返回自定义前向传播函数
# 使用检查点机制进行前向传播,节省内存
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
# 进行注意力计算
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
else:
# 直接进行前向传播计算
hidden_states = resnet(hidden_states, temb)
# 进行注意力计算
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=mask,
**cross_attention_kwargs,
)
# 如果存在上采样器
if self.upsamplers is not None:
# 遍历所有上采样器
for upsampler in self.upsamplers:
# 进行上采样操作
hidden_states = upsampler(hidden_states, temb)
# 返回最终的隐藏状态
return hidden_states
# 定义一个名为 KUpBlock2D 的神经网络模块,继承自 nn.Module
class KUpBlock2D(nn.Module):
# 初始化函数,设置网络层的参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
resolution_idx: int, # 分辨率索引
dropout: float = 0.0, # dropout 比例,默认 0
num_layers: int = 5, # 网络层数,默认 5
resnet_eps: float = 1e-5, # ResNet 中的 epsilon 值
resnet_act_fn: str = "gelu", # ResNet 中使用的激活函数,默认 "gelu"
resnet_group_size: Optional[int] = 32, # ResNet 中的组大小,默认 32
add_upsample: bool = True, # 是否添加上采样层,默认 True
):
# 调用父类初始化函数
super().__init__()
# 创建一个空的列表,用于存放 ResNet 模块
resnets = []
# 定义输入通道的数量,设置为输出通道的两倍
k_in_channels = 2 * out_channels
# 定义输出通道的数量
k_out_channels = in_channels
# 减少层数,以适应后续的循环
num_layers = num_layers - 1
# 创建指定层数的 ResNet 模块
for i in range(num_layers):
# 第一层的输入通道为 k_in_channels,其余层为 out_channels
in_channels = k_in_channels if i == 0 else out_channels
# 计算组的数量
groups = in_channels // resnet_group_size
# 计算输出组的数量
groups_out = out_channels // resnet_group_size
# 将 ResNet 模块添加到列表中
resnets.append(
ResnetBlockCondNorm2D(
in_channels=in_channels, # 输入通道数
out_channels=k_out_channels if (i == num_layers - 1) else out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 值
groups=groups, # 输入组数量
groups_out=groups_out, # 输出组数量
dropout=dropout, # dropout 比例
non_linearity=resnet_act_fn, # 激活函数
time_embedding_norm="ada_group", # 时间嵌入规范化方式
conv_shortcut_bias=False, # 是否使用卷积快捷连接的偏置
)
)
# 将 ResNet 模块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根据是否添加上采样层来初始化上采样模块
if add_upsample:
# 如果添加上采样,创建上采样层列表
self.upsamplers = nn.ModuleList([KUpsample2D()])
else:
# 如果不添加上采样,设置为 None
self.upsamplers = None
# 初始化梯度检查点为 False
self.gradient_checkpointing = False
# 存储分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播函数
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 传入的隐藏状态元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
upsample_size: Optional[int] = None, # 可选的上采样大小
*args, # 额外的位置参数
**kwargs, # 额外的关键字参数
# 定义返回值为 torch.Tensor 的函数结束部分
) -> torch.Tensor:
# 检查传入参数是否包含 args 或者 kwargs 中的 "scale" 参数
if len(args) > 0 or kwargs.get("scale", None) is not None:
# 定义弃用的提示信息,告知用户应删除 "scale" 参数
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# 调用 deprecate 函数记录弃用信息
deprecate("scale", "1.0.0", deprecation_message)
# 取 res_hidden_states_tuple 的最后一个元素
res_hidden_states_tuple = res_hidden_states_tuple[-1]
# 如果 res_hidden_states_tuple 不为 None,则将其与 hidden_states 拼接
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
# 遍历 self.resnets 列表中的每个 resnet 模块
for resnet in self.resnets:
# 如果处于训练模式且开启梯度检查点功能
if self.training and self.gradient_checkpointing:
# 定义一个创建自定义前向函数的内部函数
def create_custom_forward(module):
# 定义自定义前向函数,调用模块处理输入
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 检查 PyTorch 版本是否大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 使用检查点功能进行前向传播,避免计算图保存
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
else:
# 使用检查点功能进行前向传播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
else:
# 在非训练模式下直接通过 resnet 处理 hidden_states
hidden_states = resnet(hidden_states, temb)
# 如果存在 upsamplers,则遍历每个 upsampler
if self.upsamplers is not None:
for upsampler in self.upsamplers:
# 通过 upsampler 处理 hidden_states
hidden_states = upsampler(hidden_states)
# 返回处理后的 hidden_states
return hidden_states
# 定义一个 KCrossAttnUpBlock2D 类,继承自 nn.Module
class KCrossAttnUpBlock2D(nn.Module):
# 初始化方法,定义该类的属性
def __init__(
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 额外的嵌入通道数
temb_channels: int,
# 当前分辨率索引
resolution_idx: int,
# dropout 概率,默认为 0.0
dropout: float = 0.0,
# 残差网络的层数,默认为 4
num_layers: int = 4,
# 残差网络的 epsilon 值,默认为 1e-5
resnet_eps: float = 1e-5,
# 残差网络的激活函数类型,默认为 "gelu"
resnet_act_fn: str = "gelu",
# 残差网络的分组大小,默认为 32
resnet_group_size: int = 32,
# 注意力的维度,默认为 1
attention_head_dim: int = 1, # attention dim_head
# 交叉注意力的维度,默认为 768
cross_attention_dim: int = 768,
# 是否添加上采样,默认为 True
add_upsample: bool = True,
# 是否上溢注意力,默认为 False
upcast_attention: bool = False,
):
# 调用父类的构造函数
super().__init__()
# 初始化一个空列表,用于存储 ResNet 块
resnets = []
# 初始化一个空列表,用于存储注意力块
attentions = []
# 判断是否为第一个块:输入、输出和时间嵌入通道是否相等
is_first_block = in_channels == out_channels == temb_channels
# 判断是否为中间块:输入和输出通道是否不相等
is_middle_block = in_channels != out_channels
# 如果是第一个块,设置为 True 以添加自注意力
add_self_attention = True if is_first_block else False
# 设置跨注意力的标志为 True
self.has_cross_attention = True
# 存储注意力头的维度
self.attention_head_dim = attention_head_dim
# 定义当前块的输入通道,若是第一个块则使用输出通道,否则使用两倍的输出通道
k_in_channels = out_channels if is_first_block else 2 * out_channels
# 当前块的输出通道为输入通道
k_out_channels = in_channels
# 减少层数以计算循环中的层数
num_layers = num_layers - 1
# 根据层数循环创建 ResNet 块和注意力块
for i in range(num_layers):
# 第一个层使用 k_in_channels,后续层使用 out_channels
in_channels = k_in_channels if i == 0 else out_channels
# 计算组数,以便在 ResNet 中分组
groups = in_channels // resnet_group_size
groups_out = out_channels // resnet_group_size
# 判断是否为中间块并且是最后一层,设置卷积的输出通道
if is_middle_block and (i == num_layers - 1):
conv_2d_out_channels = k_out_channels
else:
# 如果不是,设置为 None
conv_2d_out_channels = None
# 创建并添加 ResNet 块到 resnets 列表
resnets.append(
ResnetBlockCondNorm2D(
# 输入通道
in_channels=in_channels,
# 输出通道
out_channels=out_channels,
# 卷积输出通道
conv_2d_out_channels=conv_2d_out_channels,
# 时间嵌入通道
temb_channels=temb_channels,
# 设定 epsilon 值
eps=resnet_eps,
# 输入组数
groups=groups,
# 输出组数
groups_out=groups_out,
# dropout 概率
dropout=dropout,
# 非线性激活函数
non_linearity=resnet_act_fn,
# 时间嵌入的归一化方式
time_embedding_norm="ada_group",
# 是否使用卷积快捷连接的偏置
conv_shortcut_bias=False,
)
)
# 创建并添加注意力块到 attentions 列表
attentions.append(
KAttentionBlock(
# 最后一个层使用 k_out_channels,否则使用 out_channels
k_out_channels if (i == num_layers - 1) else out_channels,
# 最后一个层注意力维度
k_out_channels // attention_head_dim
if (i == num_layers - 1)
else out_channels // attention_head_dim,
# 注意力头的维度
attention_head_dim,
# 跨注意力维度
cross_attention_dim=cross_attention_dim,
# 时间嵌入通道
temb_channels=temb_channels,
# 是否添加注意力偏置
attention_bias=True,
# 是否添加自注意力
add_self_attention=add_self_attention,
# 跨注意力归一化方式
cross_attention_norm="layer_norm",
# 是否上溯注意力
upcast_attention=upcast_attention,
)
)
# 将 ResNet 块列表转为 PyTorch 的 ModuleList
self.resnets = nn.ModuleList(resnets)
# 将注意力块列表转为 PyTorch 的 ModuleList
self.attentions = nn.ModuleList(attentions)
# 如果需要上采样,则创建一个包含上采样块的 ModuleList
if add_upsample:
self.upsamplers = nn.ModuleList([KUpsample2D()])
else:
# 否则将上采样块设置为 None
self.upsamplers = None
# 初始化梯度检查点为 False
self.gradient_checkpointing = False
# 存储当前分辨率的索引
self.resolution_idx = resolution_idx
# 定义前向传播函数,接受多个输入参数,返回处理后的张量
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 先前隐藏状态的元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可选的编码器隐藏状态
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 交叉注意力的可选参数
upsample_size: Optional[int] = None, # 可选的上采样尺寸
attention_mask: Optional[torch.Tensor] = None, # 可选的注意力掩码
encoder_attention_mask: Optional[torch.Tensor] = None, # 可选的编码器注意力掩码
) -> torch.Tensor: # 函数返回类型为张量
res_hidden_states_tuple = res_hidden_states_tuple[-1] # 获取最后一个隐藏状态
if res_hidden_states_tuple is not None: # 检查是否存在先前的隐藏状态
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) # 拼接当前和先前的隐藏状态
for resnet, attn in zip(self.resnets, self.attentions): # 遍历每个残差网络和注意力层
if self.training and self.gradient_checkpointing: # 检查是否在训练模式且使用梯度检查点
def create_custom_forward(module, return_dict=None): # 定义创建自定义前向函数的内部函数
def custom_forward(*inputs): # 自定义前向函数
if return_dict is not None: # 检查是否需要返回字典
return module(*inputs, return_dict=return_dict) # 使用返回字典的方式调用模块
else:
return module(*inputs) # 普通调用模块
return custom_forward # 返回自定义前向函数
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # 根据Torch版本设置检查点参数
hidden_states = torch.utils.checkpoint.checkpoint( # 使用检查点进行前向传播以节省内存
create_custom_forward(resnet), # 创建自定义前向函数
hidden_states, # 输入隐藏状态
temb, # 输入时间嵌入
**ckpt_kwargs, # 传递检查点参数
)
hidden_states = attn( # 通过注意力层处理隐藏状态
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 输入编码器隐藏状态
emb=temb, # 输入时间嵌入
attention_mask=attention_mask, # 输入注意力掩码
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力参数
encoder_attention_mask=encoder_attention_mask, # 编码器注意力掩码
)
else: # 如果不使用梯度检查点
hidden_states = resnet(hidden_states, temb) # 直接通过残差网络处理隐藏状态
hidden_states = attn( # 通过注意力层处理隐藏状态
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 输入编码器隐藏状态
emb=temb, # 输入时间嵌入
attention_mask=attention_mask, # 输入注意力掩码
cross_attention_kwargs=cross_attention_kwargs, # 交叉注意力参数
encoder_attention_mask=encoder_attention_mask, # 编码器注意力掩码
)
if self.upsamplers is not None: # 检查是否有上采样层
for upsampler in self.upsamplers: # 遍历每个上采样层
hidden_states = upsampler(hidden_states) # 通过上采样层处理隐藏状态
return hidden_states # 返回处理后的隐藏状态
# 可以潜在地更名为 `No-feed-forward` 注意力
class KAttentionBlock(nn.Module):
r"""
基本的 Transformer 块。
参数:
dim (`int`): 输入和输出的通道数。
num_attention_heads (`int`): 用于多头注意力的头数。
attention_head_dim (`int`): 每个头的通道数。
dropout (`float`, *可选*, 默认为 0.0): 使用的丢弃概率。
cross_attention_dim (`int`, *可选*): 用于交叉注意力的 encoder_hidden_states 向量的大小。
attention_bias (`bool`, *可选*, 默认为 `False`):
配置注意力层是否应该包含偏置参数。
upcast_attention (`bool`, *可选*, 默认为 `False`):
设置为 `True` 以将注意力计算上调为 `float32`。
temb_channels (`int`, *可选*, 默认为 768):
令牌嵌入中的通道数。
add_self_attention (`bool`, *可选*, 默认为 `False`):
设置为 `True` 以将自注意力添加到块中。
cross_attention_norm (`str`, *可选*, 默认为 `None`):
用于交叉注意力的规范化类型。可以是 `None`、`layer_norm` 或 `group_norm`。
group_size (`int`, *可选*, 默认为 32):
用于组规范化将通道分成的组数。
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
upcast_attention: bool = False,
temb_channels: int = 768, # 用于 ada_group_norm
add_self_attention: bool = False,
cross_attention_norm: Optional[str] = None,
group_size: int = 32,
):
# 调用父类构造函数,初始化 nn.Module
super().__init__()
# 设置是否添加自注意力的标志
self.add_self_attention = add_self_attention
# 1. 自注意力
if add_self_attention:
# 初始化自注意力的归一化层
self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
# 初始化自注意力机制
self.attn1 = Attention(
query_dim=dim, # 查询向量的维度
heads=num_attention_heads, # 注意力头数
dim_head=attention_head_dim, # 每个头的维度
dropout=dropout, # 丢弃率
bias=attention_bias, # 是否使用偏置
cross_attention_dim=None, # 交叉注意力维度
cross_attention_norm=None, # 交叉注意力的归一化
)
# 2. 交叉注意力
# 初始化交叉注意力的归一化层
self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
# 初始化交叉注意力机制
self.attn2 = Attention(
query_dim=dim, # 查询向量的维度
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
heads=num_attention_heads, # 注意力头数
dim_head=attention_head_dim, # 每个头的维度
dropout=dropout, # 丢弃率
bias=attention_bias, # 是否使用偏置
upcast_attention=upcast_attention, # 是否上调注意力计算
cross_attention_norm=cross_attention_norm, # 交叉注意力的归一化
)
# 将隐藏状态转换为 3D 张量,包含 batch size, height*weight 和通道数
def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
# 重新排列维度,并调整形状为 (batch size, height*weight, -1)
return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
# 将隐藏状态转换为 4D 张量,包含 batch size, 通道数, height 和 weight
def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
# 重新排列维度,并调整形状为 (batch size, -1, height, weight)
return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
# TODO: 将 emb 标记为非可选 (self.norm2 需要它)。
# 需要评估对位置参数接口更改的影响。
emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 如果 cross_attention_kwargs 为空,则初始化为一个空字典
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 检查 "scale" 参数是否存在,如果存在则发出警告
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# 1. 自注意力
if self.add_self_attention:
# 使用 norm1 对隐藏状态进行归一化处理
norm_hidden_states = self.norm1(hidden_states, emb)
# 获取归一化后状态的高度和宽度
height, weight = norm_hidden_states.shape[2:]
# 将归一化后的隐藏状态转换为 3D 张量
norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
# 执行自注意力操作
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
# 将自注意力输出转换为 4D 张量
attn_output = self._to_4d(attn_output, height, weight)
# 将自注意力输出与原始隐藏状态相加
hidden_states = attn_output + hidden_states
# 2. 交叉注意力或无交叉注意力
# 使用 norm2 对隐藏状态进行归一化处理
norm_hidden_states = self.norm2(hidden_states, emb)
# 获取归一化后状态的高度和宽度
height, weight = norm_hidden_states.shape[2:]
# 将归一化后的隐藏状态转换为 3D 张量
norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
# 执行交叉注意力操作
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
**cross_attention_kwargs,
)
# 将交叉注意力输出转换为 4D 张量
attn_output = self._to_4d(attn_output, height, weight)
# 将交叉注意力输出与隐藏状态相加
hidden_states = attn_output + hidden_states
# 返回最终的隐藏状态
return hidden_states
标签:states,None,attention,resnet,diffusers,channels,源码,hidden,解析
From: https://www.cnblogs.com/apachecn/p/18492382