diffusers 源码解析(六)
.\diffusers\models\autoencoders\autoencoder_oobleck.py
# 版权声明,表示该代码属于 HuggingFace 团队,所有权利保留
# 根据 Apache 2.0 许可证进行授权
# 用户在合规的情况下可以使用该文件
# 许可证的获取地址
# 如果没有适用的法律或书面协议,软件是按“现状”提供的
# 免责声明,表示不提供任何形式的保证或条件
import math # 导入数学库,提供数学函数和常数
from dataclasses import dataclass # 导入数据类装饰器,用于简化类的定义
from typing import Optional, Tuple, Union # 导入类型提示的必要工具
import numpy as np # 导入 NumPy 库,提供数组和矩阵操作
import torch # 导入 PyTorch 库,提供深度学习框架
import torch.nn as nn # 导入 PyTorch 神经网络模块
from torch.nn.utils import weight_norm # 导入权重归一化工具
from ...configuration_utils import ConfigMixin, register_to_config # 导入配置相关工具
from ...utils import BaseOutput # 导入基础输出工具
from ...utils.accelerate_utils import apply_forward_hook # 导入加速工具
from ...utils.torch_utils import randn_tensor # 导入随机张量生成工具
from ..modeling_utils import ModelMixin # 导入模型混合工具
class Snake1d(nn.Module):
"""
一个 1 维的 Snake 激活函数模块。
"""
def __init__(self, hidden_dim, logscale=True): # 初始化方法,接收隐藏维度和对数缩放标志
super().__init__() # 调用父类初始化方法
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1)) # 定义 alpha 参数,初始为 0
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1)) # 定义 beta 参数,初始为 0
self.alpha.requires_grad = True # 允许 alpha 参数更新
self.beta.requires_grad = True # 允许 beta 参数更新
self.logscale = logscale # 设置对数缩放标志
def forward(self, hidden_states): # 前向传播方法,接收隐藏状态
shape = hidden_states.shape # 获取隐藏状态的形状
alpha = self.alpha if not self.logscale else torch.exp(self.alpha) # 计算 alpha 值
beta = self.beta if not self.logscale else torch.exp(self.beta) # 计算 beta 值
hidden_states = hidden_states.reshape(shape[0], shape[1], -1) # 重塑隐藏状态
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) # 更新隐藏状态
hidden_states = hidden_states.reshape(shape) # 恢复隐藏状态形状
return hidden_states # 返回更新后的隐藏状态
class OobleckResidualUnit(nn.Module):
"""
一个由 Snake1d 和带扩张的权重归一化 Conv1d 层组成的残差单元。
"""
def __init__(self, dimension: int = 16, dilation: int = 1): # 初始化方法,接收维度和扩张因子
super().__init__() # 调用父类初始化方法
pad = ((7 - 1) * dilation) // 2 # 计算填充大小
self.snake1 = Snake1d(dimension) # 创建第一个 Snake1d 实例
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)) # 创建第一个卷积层并应用权重归一化
self.snake2 = Snake1d(dimension) # 创建第二个 Snake1d 实例
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1)) # 创建第二个卷积层并应用权重归一化
# 定义前向传播函数,接收隐藏状态作为输入
def forward(self, hidden_state):
"""
前向传播通过残差单元。
参数:
hidden_state (`torch.Tensor` 形状为 `(batch_size, channels, time_steps)`):
输入张量。
返回:
output_tensor (`torch.Tensor` 形状为 `(batch_size, channels, time_steps)`)
通过残差单元处理后的输入张量。
"""
# 将输入隐藏状态赋值给输出张量
output_tensor = hidden_state
# 通过第一个卷积层和激活函数处理输出张量
output_tensor = self.conv1(self.snake1(output_tensor))
# 通过第二个卷积层和激活函数处理输出张量
output_tensor = self.conv2(self.snake2(output_tensor))
# 计算填充量,以对齐输出张量和输入张量的时间步长
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
# 如果需要填充,则对隐藏状态进行切片
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
# 将处理后的输出张量与切片后的隐藏状态相加,实现残差连接
output_tensor = hidden_state + output_tensor
# 返回最终的输出张量
return output_tensor
# Oobleck编码器块的定义,继承自nn.Module
class OobleckEncoderBlock(nn.Module):
"""Encoder block used in Oobleck encoder."""
# 初始化函数,定义输入维度、输出维度和步幅
def __init__(self, input_dim, output_dim, stride: int = 1):
# 调用父类的初始化函数
super().__init__()
# 定义第一个残差单元,膨胀率为1
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
# 定义第二个残差单元,膨胀率为3
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
# 定义第三个残差单元,膨胀率为9
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
# 定义一维蛇形结构
self.snake1 = Snake1d(input_dim)
# 定义卷积层,使用权重归一化
self.conv1 = weight_norm(
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
)
# 前向传播函数
def forward(self, hidden_state):
# 通过第一个残差单元处理隐状态
hidden_state = self.res_unit1(hidden_state)
# 通过第二个残差单元处理隐状态
hidden_state = self.res_unit2(hidden_state)
# 通过第三个残差单元和蛇形结构处理隐状态
hidden_state = self.snake1(self.res_unit3(hidden_state))
# 通过卷积层处理隐状态
hidden_state = self.conv1(hidden_state)
# 返回处理后的隐状态
return hidden_state
# Oobleck解码器块的定义,继承自nn.Module
class OobleckDecoderBlock(nn.Module):
"""Decoder block used in Oobleck decoder."""
# 初始化函数,定义输入维度、输出维度和步幅
def __init__(self, input_dim, output_dim, stride: int = 1):
# 调用父类的初始化函数
super().__init__()
# 定义一维蛇形结构
self.snake1 = Snake1d(input_dim)
# 定义转置卷积层,使用权重归一化
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
)
# 定义第一个残差单元,膨胀率为1
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
# 定义第二个残差单元,膨胀率为3
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
# 定义第三个残差单元,膨胀率为9
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
# 前向传播函数
def forward(self, hidden_state):
# 通过蛇形结构处理隐状态
hidden_state = self.snake1(hidden_state)
# 通过转置卷积层处理隐状态
hidden_state = self.conv_t1(hidden_state)
# 通过第一个残差单元处理隐状态
hidden_state = self.res_unit1(hidden_state)
# 通过第二个残差单元处理隐状态
hidden_state = self.res_unit2(hidden_state)
# 通过第三个残差单元处理隐状态
hidden_state = self.res_unit3(hidden_state)
# 返回处理后的隐状态
return hidden_state
# Oobleck对角高斯分布的定义
class OobleckDiagonalGaussianDistribution(object):
# 初始化函数,定义参数和确定性标志
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
# 保存输入的参数
self.parameters = parameters
# 将参数分解为均值和尺度
self.mean, self.scale = parameters.chunk(2, dim=1)
# 计算标准差,确保为正值
self.std = nn.functional.softplus(self.scale) + 1e-4
# 计算方差
self.var = self.std * self.std
# 计算对数方差
self.logvar = torch.log(self.var)
# 保存确定性标志
self.deterministic = deterministic
# 采样函数,生成高斯样本
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# 确保生成的样本在与参数相同的设备和数据类型上
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
# 根据均值和标准差生成样本
x = self.mean + self.std * sample
# 返回生成的样本
return x
# 计算 Kullback-Leibler 散度,返回一个张量
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
# 如果是确定性分布,返回零的张量
if self.deterministic:
return torch.Tensor([0.0])
else:
# 如果没有提供其他分布,计算本分布的 KL 散度
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
# 计算均值差的平方归一化
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
# 计算方差比
var_ratio = self.var / other.var
# 计算对数方差差值
logvar_diff = self.logvar - other.logvar
# 计算 KL 散度的各个部分
kl = normalized_diff + var_ratio + logvar_diff - 1
# 计算 KL 散度的平均值
kl = kl.sum(1).mean()
return kl
# 返回分布的众数
def mode(self) -> torch.Tensor:
return self.mean
# 定义一个数据类,表示自编码器的输出
@dataclass
class AutoencoderOobleckOutput(BaseOutput):
"""
AutoencoderOobleck 编码方法的输出。
Args:
latent_dist (`OobleckDiagonalGaussianDistribution`):
表示 `Encoder` 编码输出的均值和标准差,
`OobleckDiagonalGaussianDistribution` 允许从分布中采样潜在变量。
"""
latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821
# 定义一个数据类,表示解码器的输出
@dataclass
class OobleckDecoderOutput(BaseOutput):
r"""
解码方法的输出。
Args:
sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
从模型最后一层解码的输出样本。
"""
sample: torch.Tensor
# 定义 Oobleck 编码器类,继承自 nn.Module
class OobleckEncoder(nn.Module):
"""Oobleck 编码器"""
# 初始化编码器参数
def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
super().__init__()
# 设置下采样比例
strides = downsampling_ratios
# 为通道倍数添加一个起始值
channel_multiples = [1] + channel_multiples
# 创建第一个卷积层
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
self.block = []
# 创建编码块,随着下采样通过 `stride` 加倍通道
for stride_index, stride in enumerate(strides):
self.block += [
OobleckEncoderBlock(
input_dim=encoder_hidden_size * channel_multiples[stride_index], # 输入维度
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1], # 输出维度
stride=stride, # 下采样步幅
)
]
# 将编码块转换为模块列表
self.block = nn.ModuleList(self.block)
# 计算模型的最终维度
d_model = encoder_hidden_size * channel_multiples[-1]
# 创建 Snake1d 模块
self.snake1 = Snake1d(d_model)
# 创建第二个卷积层
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
# 定义前向传播方法
def forward(self, hidden_state):
# 将输入通过第一个卷积层
hidden_state = self.conv1(hidden_state)
# 通过每个编码块进行处理
for module in self.block:
hidden_state = module(hidden_state)
# 通过 Snake1d 模块处理
hidden_state = self.snake1(hidden_state)
# 将结果通过第二个卷积层
hidden_state = self.conv2(hidden_state)
# 返回最终的隐藏状态
return hidden_state
# 定义 Oobleck 解码器类,继承自 nn.Module
class OobleckDecoder(nn.Module):
"""Oobleck 解码器"""
# 初始化方法,设置模型参数
def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
# 调用父类的初始化方法
super().__init__()
# 将上采样比例赋值给 strides
strides = upsampling_ratios
# 在 channel_multiples 列表前添加 1
channel_multiples = [1] + channel_multiples
# 添加第一个卷积层
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
# 添加上采样 + MRF 块
block = []
# 遍历 strides 列表,构建 OobleckDecoderBlock
for stride_index, stride in enumerate(strides):
block += [
OobleckDecoderBlock(
# 设置输入和输出维度
input_dim=channels * channel_multiples[len(strides) - stride_index],
output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
stride=stride,
)
]
# 将构建的块列表转为 nn.ModuleList
self.block = nn.ModuleList(block)
# 设置输出维度
output_dim = channels
# 创建 Snake1d 实例
self.snake1 = Snake1d(output_dim)
# 添加第二个卷积层,不使用偏置
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
# 前向传播方法
def forward(self, hidden_state):
# 通过第一个卷积层处理输入
hidden_state = self.conv1(hidden_state)
# 遍历每个块,依次处理隐藏状态
for layer in self.block:
hidden_state = layer(hidden_state)
# 通过 Snake1d 处理隐藏状态
hidden_state = self.snake1(hidden_state)
# 通过第二个卷积层处理隐藏状态
hidden_state = self.conv2(hidden_state)
# 返回最终的隐藏状态
return hidden_state
# 定义一个自动编码器类,用于将波形编码为潜在表示并解码为波形
class AutoencoderOobleck(ModelMixin, ConfigMixin):
r"""
自动编码器,用于将波形编码为潜在向量并将潜在表示解码为波形。首次引入于 Stable Audio。
此模型继承自 [`ModelMixin`]。请查阅超类文档以获取所有模型的通用方法(例如下载或保存)。
参数:
encoder_hidden_size (`int`, *可选*, 默认值为 128):
编码器的中间表示维度。
downsampling_ratios (`List[int]`, *可选*, 默认值为 `[2, 4, 4, 8, 8]`):
编码器中下采样的比率。这些比率在解码器中以反向顺序用于上采样。
channel_multiples (`List[int]`, *可选*, 默认值为 `[1, 2, 4, 8, 16]`):
用于确定隐藏层隐藏尺寸的倍数。
decoder_channels (`int`, *可选*, 默认值为 128):
解码器的中间表示维度。
decoder_input_channels (`int`, *可选*, 默认值为 64):
解码器的输入维度。对应于潜在维度。
audio_channels (`int`, *可选*, 默认值为 2):
音频数据中的通道数。1 表示单声道,2 表示立体声。
sampling_rate (`int`, *可选*, 默认值为 44100):
音频波形应数字化的采样率,以赫兹(Hz)表示。
"""
# 指示模型是否支持梯度检查点
_supports_gradient_checkpointing = False
# 注册到配置的构造函数
@register_to_config
def __init__(
self,
encoder_hidden_size=128,
downsampling_ratios=[2, 4, 4, 8, 8],
channel_multiples=[1, 2, 4, 8, 16],
decoder_channels=128,
decoder_input_channels=64,
audio_channels=2,
sampling_rate=44100,
):
# 调用父类的构造函数
super().__init__()
# 设置编码器的隐藏层大小
self.encoder_hidden_size = encoder_hidden_size
# 设置编码器中的下采样比率
self.downsampling_ratios = downsampling_ratios
# 设置解码器的通道数
self.decoder_channels = decoder_channels
# 将下采样比率反转以用于解码器的上采样
self.upsampling_ratios = downsampling_ratios[::-1]
# 计算跳长,作为下采样比率的乘积
self.hop_length = int(np.prod(downsampling_ratios))
# 设置音频采样率
self.sampling_rate = sampling_rate
# 创建编码器实例,传入必要参数
self.encoder = OobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
# 创建解码器实例,传入必要参数
self.decoder = OobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=self.upsampling_ratios,
channel_multiples=channel_multiples,
)
# 设置是否使用切片,初始值为假
self.use_slicing = False
# 启用切片 VAE 解码的功能
def enable_slicing(self):
r"""
启用切片 VAE 解码。当启用此选项时,VAE 将会将输入张量分割为多个切片以
分步计算解码。这对于节省内存和允许更大的批量大小非常有用。
"""
# 设置标志以启用切片
self.use_slicing = True
# 禁用切片 VAE 解码的功能
def disable_slicing(self):
r"""
禁用切片 VAE 解码。如果之前启用了 `enable_slicing`,则该方法将恢复为一步
计算解码。
"""
# 设置标志以禁用切片
self.use_slicing = False
@apply_forward_hook
# 编码函数,将一批图像编码为潜在表示
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
"""
将一批图像编码为潜在表示。
参数:
x (`torch.Tensor`): 输入图像批。
return_dict (`bool`, *可选*, 默认为 `True`):
是否返回 [`~models.autoencoder_kl.AutoencoderKLOutput`] 而不是普通元组。
返回:
编码图像的潜在表示。如果 `return_dict` 为 True,则返回
[`~models.autoencoder_kl.AutoencoderKLOutput`],否则返回普通 `tuple`。
"""
# 检查是否启用切片且输入批量大于 1
if self.use_slicing and x.shape[0] > 1:
# 对输入进行切片编码
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
# 将所有编码结果连接成一个张量
h = torch.cat(encoded_slices)
else:
# 对整个输入进行编码
h = self.encoder(x)
# 创建潜在分布
posterior = OobleckDiagonalGaussianDistribution(h)
# 检查是否返回字典格式
if not return_dict:
return (posterior,)
# 返回潜在表示的输出
return AutoencoderOobleckOutput(latent_dist=posterior)
# 解码函数,将潜在向量解码为图像
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
# 使用解码器解码潜在向量
dec = self.decoder(z)
# 检查是否返回字典格式
if not return_dict:
return (dec,)
# 返回解码结果的输出
return OobleckDecoderOutput(sample=dec)
@apply_forward_hook
# 解码函数,解码一批潜在向量为图像
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
"""
解码一批图像。
参数:
z (`torch.Tensor`): 输入潜在向量批。
return_dict (`bool`, *可选*, 默认为 `True`):
是否返回 [`~models.vae.OobleckDecoderOutput`] 而不是普通元组。
返回:
[`~models.vae.OobleckDecoderOutput`] 或 `tuple`:
如果 return_dict 为 True,则返回 [`~models.vae.OobleckDecoderOutput`],否则返回普通 `tuple`。
"""
# 检查是否启用切片且输入批量大于 1
if self.use_slicing and z.shape[0] > 1:
# 对输入进行切片解码
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
# 将所有解码结果连接成一个张量
decoded = torch.cat(decoded_slices)
else:
# 对整个输入进行解码
decoded = self._decode(z).sample
# 检查是否返回字典格式
if not return_dict:
return (decoded,)
# 返回解码结果的输出
return OobleckDecoderOutput(sample=decoded)
# 定义一个前向传播的方法,接受样本输入和其他参数
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False, # 是否从后验分布中采样的标志,默认值为 False
return_dict: bool = True, # 是否返回 OobleckDecoderOutput 对象而非普通元组,默认值为 True
generator: Optional[torch.Generator] = None, # 可选的随机数生成器
) -> Union[OobleckDecoderOutput, torch.Tensor]: # 返回类型为 OobleckDecoderOutput 或 torch.Tensor
r"""
Args:
sample (`torch.Tensor`): 输入样本。
sample_posterior (`bool`, *optional*, defaults to `False`):
是否从后验分布中采样。
return_dict (`bool`, *optional*, defaults to `True`):
是否返回一个 [`OobleckDecoderOutput`] 而不是普通元组。
"""
x = sample # 将输入样本赋值给变量 x
posterior = self.encode(x).latent_dist # 对输入样本进行编码,获取潜在分布
if sample_posterior: # 如果选择从后验分布中采样
z = posterior.sample(generator=generator) # 从潜在分布中采样
else:
z = posterior.mode() # 否则取潜在分布的众数
dec = self.decode(z).sample # 解码采样得到的潜在变量 z,并获取样本
if not return_dict: # 如果不需要返回字典
return (dec,) # 返回解码样本作为元组
return OobleckDecoderOutput(sample=dec) # 返回 OobleckDecoderOutput 对象
.\diffusers\models\autoencoders\autoencoder_tiny.py
# 版权所有 2024 Ollin Boer Bohan 和 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,按照许可证分发的软件是在“按原样”基础上分发的,
# 不提供任何种类的保证或条件,无论是明示还是暗示。
# 请参阅许可证以获取有关权限和限制的具体信息。
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入可选类型、元组和联合类型
from typing import Optional, Tuple, Union
# 导入 PyTorch 库
import torch
# 从配置相关的模块导入 ConfigMixin 和 register_to_config
from ...configuration_utils import ConfigMixin, register_to_config
# 从工具模块导入 BaseOutput
from ...utils import BaseOutput
# 从加速工具模块导入 apply_forward_hook 函数
from ...utils.accelerate_utils import apply_forward_hook
# 从模型工具模块导入 ModelMixin
from ..modeling_utils import ModelMixin
# 从 VAE 模块导入 DecoderOutput、DecoderTiny 和 EncoderTiny
from .vae import DecoderOutput, DecoderTiny, EncoderTiny
@dataclass
class AutoencoderTinyOutput(BaseOutput):
"""
AutoencoderTiny 编码方法的输出。
参数:
latents (`torch.Tensor`): `Encoder` 的编码输出。
"""
# 定义编码输出的张量属性
latents: torch.Tensor
class AutoencoderTiny(ModelMixin, ConfigMixin):
r"""
一个小型的蒸馏变分自编码器(VAE)模型,用于将图像编码为潜在表示并将潜在表示解码为图像。
[`AutoencoderTiny`] 是对 `TAESD` 原始实现的封装。
此模型继承自 [`ModelMixin`]。有关其为所有模型实现的通用方法的文档,请查看超类文档
(例如下载或保存)。
"""
# 指示该模型支持梯度检查点
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
# 定义输入通道的默认值,默认为3(RGB图像)
self,
in_channels: int = 3,
# 定义输出通道的默认值,默认为3(RGB图像)
out_channels: int = 3,
# 定义编码器块输出通道的元组,默认为四个64
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
# 定义解码器块输出通道的元组,默认为四个64
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
# 定义激活函数的默认值,默认为 ReLU
act_fn: str = "relu",
# 定义上采样函数的默认值,默认为最近邻插值
upsample_fn: str = "nearest",
# 定义潜在通道的默认值
latent_channels: int = 4,
# 定义上采样缩放因子的默认值
upsampling_scaling_factor: int = 2,
# 定义编码器块数量的元组,默认为 (1, 3, 3, 3)
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
# 定义解码器块数量的元组,默认为 (3, 3, 3, 1)
num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
# 定义潜在幅度的默认值
latent_magnitude: int = 3,
# 定义潜在偏移量的默认值
latent_shift: float = 0.5,
# 定义是否强制上溯的布尔值,默认为 False
force_upcast: bool = False,
# 定义缩放因子的默认值
scaling_factor: float = 1.0,
# 定义偏移因子的默认值
shift_factor: float = 0.0,
):
# 调用父类的初始化方法
super().__init__()
# 检查编码器块输出通道数量与编码器块数量是否一致
if len(encoder_block_out_channels) != len(num_encoder_blocks):
# 抛出异常,提示编码器块输出通道数量与编码器块数量不匹配
raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
# 检查解码器块输出通道数量与解码器块数量是否一致
if len(decoder_block_out_channels) != len(num_decoder_blocks):
# 抛出异常,提示解码器块输出通道数量与解码器块数量不匹配
raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
# 创建编码器实例
self.encoder = EncoderTiny(
# 输入通道数
in_channels=in_channels,
# 潜在通道数
out_channels=latent_channels,
# 编码块数量
num_blocks=num_encoder_blocks,
# 编码块输出通道
block_out_channels=encoder_block_out_channels,
# 激活函数
act_fn=act_fn,
)
# 创建解码器实例
self.decoder = DecoderTiny(
# 输入潜在通道数
in_channels=latent_channels,
# 输出通道数
out_channels=out_channels,
# 解码块数量
num_blocks=num_decoder_blocks,
# 解码块输出通道
block_out_channels=decoder_block_out_channels,
# 上采样缩放因子
upsampling_scaling_factor=upsampling_scaling_factor,
# 激活函数
act_fn=act_fn,
# 上采样函数
upsample_fn=upsample_fn,
)
# 潜在幅度
self.latent_magnitude = latent_magnitude
# 潜在偏移量
self.latent_shift = latent_shift
# 缩放因子
self.scaling_factor = scaling_factor
# 切片使用标志
self.use_slicing = False
# 瓦片使用标志
self.use_tiling = False
# 仅在启用 VAE 瓦片时相关
# 空间缩放因子
self.spatial_scale_factor = 2**out_channels
# 瓦片重叠因子
self.tile_overlap_factor = 0.125
# 瓦片样本最小大小
self.tile_sample_min_size = 512
# 瓦片潜在最小大小
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
# 注册解码器块输出通道到配置
self.register_to_config(block_out_channels=decoder_block_out_channels)
# 注册强制上升到配置
self.register_to_config(force_upcast=False)
# 设置梯度检查点的方法
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
# 如果模块是 EncoderTiny 或 DecoderTiny 类型
if isinstance(module, (EncoderTiny, DecoderTiny)):
# 设置梯度检查点标志
module.gradient_checkpointing = value
# 缩放潜在变量的方法
def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
"""raw latents -> [0, 1]"""
# 将潜在变量缩放到 [0, 1] 范围
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
# 反缩放潜在变量的方法
def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
"""[0, 1] -> raw latents"""
# 将 [0, 1] 范围的潜在变量反缩放回原始值
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
# 启用切片解码的方法
def enable_slicing(self) -> None:
r"""
启用切片 VAE 解码。当启用此选项时,VAE 将输入张量切片以
分几步计算解码。这有助于节省一些内存并允许更大的批量大小。
"""
# 设置切片使用标志为真
self.use_slicing = True
# 禁用切片解码的方法
def disable_slicing(self) -> None:
r"""
禁用切片 VAE 解码。如果之前启用了 `enable_slicing`,此方法将回到
一步计算解码。
"""
# 设置切片使用标志为假
self.use_slicing = False
# 启用分块 VAE 解码的函数,默认为启用
def enable_tiling(self, use_tiling: bool = True) -> None:
r"""
启用分块 VAE 解码。启用后,VAE 将把输入张量拆分成多个块,以
分步计算解码和编码。这有助于节省大量内存,并允许处理更大的图像。
"""
# 将实例变量设置为传入的布尔值以启用或禁用分块
self.use_tiling = use_tiling
# 禁用分块 VAE 解码的函数
def disable_tiling(self) -> None:
r"""
禁用分块 VAE 解码。如果之前启用了 `enable_tiling`,则该方法将
返回到一次计算解码的方式。
"""
# 调用 enable_tiling 方法以禁用分块
self.enable_tiling(False)
# 使用分块编码器对图像批次进行编码的私有方法
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""使用分块编码器编码图像批次。
启用该选项后,VAE 将把输入张量拆分成多个块,以分步计算编码。
这有助于保持内存使用量在固定范围内,不受图像大小影响。为了避免
块之间的伪影,块之间会重叠并进行混合,以形成平滑的输出。
Args:
x (`torch.Tensor`): 输入的图像批次。
Returns:
`torch.Tensor`: 编码后的图像批次。
"""
# 编码器输出相对于输入的缩放比例
sf = self.spatial_scale_factor
# 分块的最小样本尺寸
tile_size = self.tile_sample_min_size
# 计算混合和遍历之间的像素数量
blend_size = int(tile_size * self.tile_overlap_factor)
traverse_size = tile_size - blend_size
# 计算分块的索引(上/左)
ti = range(0, x.shape[-2], traverse_size)
tj = range(0, x.shape[-1], traverse_size)
# 创建用于混合的掩码
blend_masks = torch.stack(
torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
)
# 将掩码限制在 0 到 1 之间并转移到相应的设备
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# 初始化输出数组
out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
# 遍历分块索引
for i in ti:
for j in tj:
# 获取当前分块的输入张量
tile_in = x[..., i : i + tile_size, j : j + tile_size]
# 获取当前分块的输出位置
tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
# 对输入分块进行编码
tile = self.encoder(tile_in)
h, w = tile.shape[-2], tile.shape[-1]
# 将当前块的结果与输出进行混合
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
# 计算总混合掩码
blend_mask = blend_mask_i * blend_mask_j
# 将块和混合掩码裁剪到一致的形状
tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
# 更新输出数组中的当前块
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
# 返回最终的编码输出
return out
# 定义一个私有方法,用于对一批图像进行分块解码
def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
r"""使用分块编码器对图像批进行编码。
启用此选项时,VAE 将把输入张量分割成多个块,以分步计算编码。这有助于保持内存使用在
常数范围内,无论图像大小如何。为了避免分块伪影,块之间重叠并融合在一起形成平滑输出。
参数:
x (`torch.Tensor`): 输入图像批。
返回:
`torch.Tensor`: 编码后的图像批。
"""
# 解码器输出相对于输入的缩放因子
sf = self.spatial_scale_factor
# 定义每个块的最小大小
tile_size = self.tile_latent_min_size
# 计算用于混合和在块之间遍历的像素数量
blend_size = int(tile_size * self.tile_overlap_factor)
# 计算块之间的遍历大小
traverse_size = tile_size - blend_size
# 创建块的索引(上/左)
ti = range(0, x.shape[-2], traverse_size)
tj = range(0, x.shape[-1], traverse_size)
# 创建混合掩码
blend_masks = torch.stack(
torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
)
# 将混合掩码限制在0到1之间,并移动到输入张量所在的设备
blend_masks = blend_masks.clamp(0, 1).to(x.device)
# 创建输出数组,初始化为零
out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
# 遍历每个块的索引
for i in ti:
for j in tj:
# 获取当前块的输入数据
tile_in = x[..., i : i + tile_size, j : j + tile_size]
# 获取当前块的输出位置
tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
# 使用解码器对当前块进行解码
tile = self.decoder(tile_in)
h, w = tile.shape[-2], tile.shape[-1]
# 将当前块的结果混合到输出中
blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
# 计算最终的混合掩码
blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
# 将解码后的块与输出进行混合
tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
# 返回最终的输出结果
return out
# 使用装饰器应用前向钩子,定义编码方法
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
# 如果使用切片且输入批量大于1
if self.use_slicing and x.shape[0] > 1:
# 对每个切片进行编码,使用分块编码或普通编码
output = [
self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
]
# 将输出合并成一个张量
output = torch.cat(output)
else:
# 对整个输入进行编码,使用分块编码或普通编码
output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
# 如果不返回字典格式,返回输出张量的元组
if not return_dict:
return (output,)
# 返回包含编码结果的自定义输出对象
return AutoencoderTinyOutput(latents=output)
# 使用装饰器应用前向钩子,定义解码方法
@apply_forward_hook
def decode(
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
# 函数返回类型为 DecoderOutput 或者元组,处理解码后的输出
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
# 如果使用切片并且输入的第一维大于1
if self.use_slicing and x.shape[0] > 1:
# 根据是否使用平铺解码,将切片解码结果存入列表中
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
# 将列表中的张量沿着第0维连接成一个张量
output = torch.cat(output)
else:
# 直接对输入进行解码,根据是否使用平铺解码
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
# 如果不需要返回字典格式
if not return_dict:
# 返回解码结果作为元组
return (output,)
# 返回解码结果作为 DecoderOutput 对象
return DecoderOutput(sample=output)
# forward 方法定义,处理输入样本并返回解码输出
def forward(
self,
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
r"""
Args:
sample (`torch.Tensor`): 输入样本。
return_dict (`bool`, *optional*, defaults to `True`):
是否返回一个 [`DecoderOutput`] 对象而不是普通元组。
"""
# 对输入样本进行编码,提取潜在表示
enc = self.encode(sample).latents
# 将潜在表示缩放到 [0, 1] 范围,并量化为字节张量,
# 类似于将潜在表示存储为 RGBA uint8 图像
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
# 将量化后的潜在表示反量化回 [0, 1] 范围,并恢复到原始范围,
# 类似于从 RGBA uint8 图像中加载潜在表示
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
# 对反量化后的潜在表示进行解码
dec = self.decode(unscaled_enc)
# 如果不需要返回字典格式
if not return_dict:
# 返回解码结果作为元组
return (dec,)
# 返回解码结果作为 DecoderOutput 对象
return DecoderOutput(sample=dec)
.\diffusers\models\autoencoders\consistency_decoder_vae.py
# 版权所有 2024 The HuggingFace Team. 保留所有权利。
#
# 根据 Apache License 2.0 版(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,否则根据许可证分发的软件是按“原样”基础分发的,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 有关许可证下权限和限制的具体条款,请参见许可证。
from dataclasses import dataclass # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Dict, Optional, Tuple, Union # 从 typing 模块导入类型提示
import torch # 导入 PyTorch 库
import torch.nn.functional as F # 导入 PyTorch 的函数式神经网络模块
from torch import nn # 从 PyTorch 导入 nn 模块
from ...configuration_utils import ConfigMixin, register_to_config # 从配置工具模块导入混合类和注册函数
from ...schedulers import ConsistencyDecoderScheduler # 从调度器模块导入一致性解码器调度器
from ...utils import BaseOutput # 从工具模块导入基础输出类
from ...utils.accelerate_utils import apply_forward_hook # 从加速工具模块导入前向钩子应用函数
from ...utils.torch_utils import randn_tensor # 从 PyTorch 工具模块导入随机张量函数
from ..attention_processor import ( # 从注意力处理器模块导入所需类
ADDED_KV_ATTENTION_PROCESSORS, # 导入添加键值注意力处理器
CROSS_ATTENTION_PROCESSORS, # 导入交叉注意力处理器
AttentionProcessor, # 导入注意力处理器基类
AttnAddedKVProcessor, # 导入添加键值注意力处理器类
AttnProcessor, # 导入注意力处理器类
)
from ..modeling_utils import ModelMixin # 从建模工具模块导入模型混合类
from ..unets.unet_2d import UNet2DModel # 从 2D U-Net 模块导入 U-Net 模型类
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder # 从 VAE 模块导入解码器输出、对角高斯分布和编码器
@dataclass # 将该类标记为数据类
class ConsistencyDecoderVAEOutput(BaseOutput): # 定义一致性解码器 VAE 输出类,继承自基础输出类
"""
编码方法的输出。
参数:
latent_dist (`DiagonalGaussianDistribution`):
表示为均值和对数方差的编码器输出。
`DiagonalGaussianDistribution` 允许从分布中采样潜变量。
"""
latent_dist: "DiagonalGaussianDistribution" # 定义潜在分布属性,类型为对角高斯分布
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): # 定义一致性解码器 VAE 类,继承自模型混合类和配置混合类
r"""
与 DALL-E 3 一起使用的一致性解码器。
示例:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16) # 从预训练模型加载一致性解码器 VAE
>>> pipe = StableDiffusionPipeline.from_pretrained( # 从预训练模型加载稳定扩散管道
... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
... ).to("cuda") # 将管道移动到 CUDA 设备
>>> image = pipe("horse", generator=torch.manual_seed(0)).images[0] # 生成图像
>>> image # 输出生成的图像
```py
"""
@register_to_config # 将该方法注册到配置中
# 初始化方法,用于创建类的实例
def __init__(
# 缩放因子,默认为 0.18215
scaling_factor: float = 0.18215,
# 潜在通道数,默认为 4
latent_channels: int = 4,
# 样本尺寸,默认为 32
sample_size: int = 32,
# 编码器激活函数,默认为 "silu"
encoder_act_fn: str = "silu",
# 编码器输出通道数的元组,默认为 (128, 256, 512, 512)
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
# 编码器是否使用双重 Z,默认为 True
encoder_double_z: bool = True,
# 编码器下采样块类型的元组,默认为多个 "DownEncoderBlock2D"
encoder_down_block_types: Tuple[str, ...] = (
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
),
# 编码器输入通道数,默认为 3
encoder_in_channels: int = 3,
# 每个编码器块的层数,默认为 2
encoder_layers_per_block: int = 2,
# 编码器归一化组数,默认为 32
encoder_norm_num_groups: int = 32,
# 编码器输出通道数,默认为 4
encoder_out_channels: int = 4,
# 解码器是否添加注意力机制,默认为 False
decoder_add_attention: bool = False,
# 解码器输出通道数的元组,默认为 (320, 640, 1024, 1024)
decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
# 解码器下采样块类型的元组,默认为多个 "ResnetDownsampleBlock2D"
decoder_down_block_types: Tuple[str, ...] = (
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
),
# 解码器下采样填充,默认为 1
decoder_downsample_padding: int = 1,
# 解码器输入通道数,默认为 7
decoder_in_channels: int = 7,
# 每个解码器块的层数,默认为 3
decoder_layers_per_block: int = 3,
# 解码器归一化的 epsilon 值,默认为 1e-05
decoder_norm_eps: float = 1e-05,
# 解码器归一化组数,默认为 32
decoder_norm_num_groups: int = 32,
# 解码器训练时长的时间步数,默认为 1024
decoder_num_train_timesteps: int = 1024,
# 解码器输出通道数,默认为 6
decoder_out_channels: int = 6,
# 解码器时间缩放偏移类型,默认为 "scale_shift"
decoder_resnet_time_scale_shift: str = "scale_shift",
# 解码器时间嵌入类型,默认为 "learned"
decoder_time_embedding_type: str = "learned",
# 解码器上采样块类型的元组,默认为多个 "ResnetUpsampleBlock2D"
decoder_up_block_types: Tuple[str, ...] = (
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
),
):
# 调用父类的构造函数
super().__init__()
# 初始化编码器,传入各类参数以配置其行为
self.encoder = Encoder(
act_fn=encoder_act_fn, # 激活函数
block_out_channels=encoder_block_out_channels, # 编码器每个块的输出通道数
double_z=encoder_double_z, # 是否使用双Z向量
down_block_types=encoder_down_block_types, # 编码器下采样块的类型
in_channels=encoder_in_channels, # 输入通道数
layers_per_block=encoder_layers_per_block, # 每个块的层数
norm_num_groups=encoder_norm_num_groups, # 归一化的组数
out_channels=encoder_out_channels, # 输出通道数
)
# 初始化解码器UNet模型,配置其参数
self.decoder_unet = UNet2DModel(
add_attention=decoder_add_attention, # 是否添加注意力机制
block_out_channels=decoder_block_out_channels, # 解码器每个块的输出通道数
down_block_types=decoder_down_block_types, # 解码器下采样块的类型
downsample_padding=decoder_downsample_padding, # 下采样的填充方式
in_channels=decoder_in_channels, # 输入通道数
layers_per_block=decoder_layers_per_block, # 每个块的层数
norm_eps=decoder_norm_eps, # 归一化中的epsilon值
norm_num_groups=decoder_norm_num_groups, # 归一化的组数
num_train_timesteps=decoder_num_train_timesteps, # 训练时的时间步数
out_channels=decoder_out_channels, # 输出通道数
resnet_time_scale_shift=decoder_resnet_time_scale_shift, # ResNet时间尺度偏移
time_embedding_type=decoder_time_embedding_type, # 时间嵌入类型
up_block_types=decoder_up_block_types, # 解码器上采样块的类型
)
# 初始化一致性解码器调度器
self.decoder_scheduler = ConsistencyDecoderScheduler()
# 注册编码器的输出通道数到配置中
self.register_to_config(block_out_channels=encoder_block_out_channels)
# 注册强制上采样的配置
self.register_to_config(force_upcast=False)
# 注册均值的缓冲区,形状为(1, C, 1, 1)
self.register_buffer(
"means",
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], # 均值张量
persistent=False, # 不持久化保存
)
# 注册标准差的缓冲区,形状为(1, C, 1, 1)
self.register_buffer(
"stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False # 标准差张量
)
# 初始化量化卷积层,输入和输出通道数相同
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
# 设置切片和拼接的使用标志为假
self.use_slicing = False
self.use_tiling = False
# 仅在启用 VAE 切片时相关
self.tile_sample_min_size = self.config.sample_size # 最小样本大小
# 判断样本大小类型,并设置样本大小
sample_size = (
self.config.sample_size[0] # 如果样本大小是列表或元组,取第一个元素
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size # 否则直接使用样本大小
)
# 计算最小的切片潜在大小
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
# 设置切片重叠因子
self.tile_overlap_factor = 0.25
# 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling 复制的方法
def enable_tiling(self, use_tiling: bool = True):
r"""
启用切片 VAE 解码。启用此选项时,VAE 将输入张量拆分为切片,以
分步骤计算解码和编码。这有助于节省大量内存,并允许处理更大图像。
"""
# 设置是否使用切片标志
self.use_tiling = use_tiling
# 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling 复制的方法
# 定义一个方法来禁用平铺的 VAE 解码
def disable_tiling(self):
r"""
禁用平铺的 VAE 解码。如果之前启用了 `enable_tiling`,该方法将恢复为一步计算解码。
"""
# 调用 enable_tiling 方法并传入 False 参数以禁用平铺
self.enable_tiling(False)
# 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing 复制而来
# 定义一个方法来启用切片的 VAE 解码
def enable_slicing(self):
r"""
启用切片的 VAE 解码。当启用此选项时,VAE 将把输入张量分成多个切片进行解码计算。
这对于节省一些内存和允许更大的批处理大小非常有用。
"""
# 设置 use_slicing 为 True,以启用切片
self.use_slicing = True
# 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing 复制而来
# 定义一个方法来禁用切片的 VAE 解码
def disable_slicing(self):
r"""
禁用切片的 VAE 解码。如果之前启用了 `enable_slicing`,该方法将恢复为一步计算解码。
"""
# 设置 use_slicing 为 False,以禁用切片
self.use_slicing = False
@property
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制而来
# 定义一个属性方法,用于返回注意力处理器
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回:
`dict` 类型的注意力处理器:一个包含模型中所有注意力处理器的字典,按权重名称索引。
"""
# 创建一个空字典用于存储处理器
processors = {}
# 定义一个递归函数以添加处理器
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 检查模块是否有 get_processor 方法
if hasattr(module, "get_processor"):
# 将处理器添加到字典中,键为模块名称加上 ".processor"
processors[f"{name}.processor"] = module.get_processor()
# 遍历模块的子模块
for sub_name, child in module.named_children():
# 递归调用该函数以添加子模块的处理器
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新后的处理器字典
return processors
# 遍历当前对象的所有子模块
for name, module in self.named_children():
# 调用递归函数以添加所有处理器
fn_recursive_add_processors(name, module, processors)
# 返回最终的处理器字典
return processors
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制而来
# 设置用于计算注意力的处理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的处理器。
参数:
processor(`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
实例化的处理器类或一个处理器类的字典,将作为所有 `Attention` 层的处理器设置。
如果 `processor` 是一个字典,键需要定义对应的交叉注意力处理器的路径。强烈建议在设置可训练的注意力处理器时使用此方式。
"""
# 计算当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 检查传入的处理器字典长度是否与当前注意力层数量匹配
if isinstance(processor, dict) and len(processor) != count:
# 抛出值错误,提示处理器数量不匹配
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# 定义递归函数,用于设置每个子模块的处理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 检查子模块是否有设置处理器的方法
if hasattr(module, "set_processor"):
# 如果处理器不是字典,则直接设置
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中弹出对应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历子模块,递归调用自身
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前模块的所有子模块
for name, module in self.named_children():
# 调用递归函数设置处理器
fn_recursive_attn_processor(name, module, processor)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 复制而来
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器并设置默认的注意力实现。
"""
# 检查所有处理器是否属于新增的 KV 注意力处理器
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 设置处理器为新增的 KV 注意力处理器
processor = AttnAddedKVProcessor()
# 检查所有处理器是否属于交叉注意力处理器
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 设置处理器为标准的注意力处理器
processor = AttnProcessor()
else:
# 抛出值错误,提示无法设置默认处理器
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 调用设置处理器的方法
self.set_attn_processor(processor)
# 应用前向钩子修饰器
@apply_forward_hook
def encode(
# 输入张量 x,返回字典的标志
self, x: torch.Tensor, return_dict: bool = True
# 定义一个方法,返回编码后的图像的潜在表示
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
"""
将一批图像编码为潜在表示。
参数:
x (`torch.Tensor`): 输入图像批次。
return_dict (`bool`, *可选*, 默认为 `True`):
是否返回 [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
而不是普通的元组。
返回:
编码图像的潜在表示。如果 `return_dict` 为 True,则返回
[`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`],否则返回
普通的 `tuple`。
"""
# 检查是否使用分块编码,并且图像尺寸超过最小样本大小
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
# 调用分块编码方法处理输入
return self.tiled_encode(x, return_dict=return_dict)
# 检查是否使用切片,并且输入批次大于1
if self.use_slicing and x.shape[0] > 1:
# 对输入的每个切片进行编码,并将结果收集到列表中
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
# 将所有编码切片连接成一个张量
h = torch.cat(encoded_slices)
else:
# 对整个输入进行编码
h = self.encoder(x)
# 通过量化卷积获取潜在表示的统计量
moments = self.quant_conv(h)
# 创建一个对角高斯分布作为后验分布
posterior = DiagonalGaussianDistribution(moments)
# 如果不需要返回字典格式
if not return_dict:
# 返回包含后验分布的元组
return (posterior,)
# 返回包含潜在分布的输出对象
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
# 应用前向钩子装饰器
@apply_forward_hook
def decode(
# 定义解码方法,输入潜在变量
z: torch.Tensor,
# 可选的随机数生成器
generator: Optional[torch.Generator] = None,
# 是否返回字典格式,默认为 True
return_dict: bool = True,
# 推理步骤的数量,默认为 2
num_inference_steps: int = 2,
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
"""
解码输入的潜在向量 `z`,使用一致性解码器 VAE 模型。
Args:
z (torch.Tensor): 输入的潜在向量。
generator (Optional[torch.Generator]): 随机数生成器,默认为 None。
return_dict (bool): 是否以字典形式返回输出,默认为 True。
num_inference_steps (int): 推理步骤的数量,默认为 2。
Returns:
Union[DecoderOutput, Tuple[torch.Tensor]]: 解码后的输出。
"""
# 对潜在向量 `z` 进行标准化处理
z = (z * self.config.scaling_factor - self.means) / self.stds
# 计算缩放因子,基于输出通道的数量
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
# 将潜在向量 `z` 进行最近邻插值缩放
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
# 获取当前张量的批量大小、高度和宽度
batch_size, _, height, width = z.shape
# 设置解码器调度器的时间步长
self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
# 初始化噪声张量 `x_t`,用于后续的解码过程
x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
(batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
)
# 对每个时间步进行解码
for t in self.decoder_scheduler.timesteps:
# 将当前噪声与潜在向量 `z` 组合成模型输入
model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
# 获取模型输出的样本
model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
# 更新当前样本 `x_t`
prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
x_t = prev_sample
# 将最后的样本赋值给 `x_0`
x_0 = x_t
# 如果不返回字典,直接返回样本
if not return_dict:
return (x_0,)
# 返回解码后的输出,封装成 DecoderOutput 对象
return DecoderOutput(sample=x_0)
# 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v 复制的函数
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
# 计算混合范围,确保不超过输入张量的维度
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
# 对于每个混合范围内的高度像素进行加权混合
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
# 返回混合后的张量
return b
# 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h 复制的函数
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
# 计算混合范围,确保不超过输入张量的维度
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
# 对于每个混合范围内的宽度像素进行加权混合
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
# 返回混合后的张量
return b
# 定义一个使用平铺编码器对图像批次进行编码的方法
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
r"""使用平铺编码器编码一批图像。
当启用此选项时,VAE将输入张量分割成平铺,以进行多个步骤的编码。这有助于保持内存使用量在任何图像大小下都是恒定的。平铺编码的最终结果与非平铺编码不同,因为每个平铺使用不同的编码器。为了避免平铺伪影,平铺之间会重叠并进行混合,以形成平滑的输出。您仍然可能会看到平铺大小的变化,但应该不那么明显。
参数:
x (`torch.Tensor`): 输入图像批次。
return_dict (`bool`, *可选*, 默认为 `True`):
是否返回 [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
而不是一个普通的元组。
返回:
[`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] 或 `tuple`:
如果 return_dict 为 True,则返回一个 [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
,否则返回一个普通的 `tuple`。
"""
# 计算重叠区域的大小
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
# 计算混合范围的大小
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
# 计算行限制的大小
row_limit = self.tile_latent_min_size - blend_extent
# 将图像分割成512x512的平铺并分别编码
rows = [] # 存储每一行的编码结果
for i in range(0, x.shape[2], overlap_size): # 遍历图像的高度
row = [] # 存储当前行的编码结果
for j in range(0, x.shape[3], overlap_size): # 遍历图像的宽度
# 提取当前平铺
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
# 使用编码器对平铺进行编码
tile = self.encoder(tile)
# 进行量化处理
tile = self.quant_conv(tile)
# 将编码后的平铺添加到当前行
row.append(tile)
# 将当前行添加到所有行的列表中
rows.append(row)
result_rows = [] # 存储最终结果的行
for i, row in enumerate(rows): # 遍历每一行的编码结果
result_row = [] # 存储当前行的最终结果
for j, tile in enumerate(row): # 遍历当前行的每个平铺
# 将上方的平铺与当前平铺进行混合
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
# 将左侧的平铺与当前平铺进行混合
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
# 将处理后的平铺裁剪并添加到结果行
result_row.append(tile[:, :, :row_limit, :row_limit])
# 将当前行的结果合并到结果行列表中
result_rows.append(torch.cat(result_row, dim=3))
# 将所有结果行在高度维度上进行合并
moments = torch.cat(result_rows, dim=2)
# 创建对角高斯分布对象
posterior = DiagonalGaussianDistribution(moments)
# 如果不返回字典,则返回一个包含后验分布的元组
if not return_dict:
return (posterior,)
# 返回包含后验分布的ConsistencyDecoderVAEOutput对象
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
# 定义前向传播方法,处理输入样本并返回解码结果
def forward(
self,
sample: torch.Tensor, # 输入样本,类型为 torch.Tensor
sample_posterior: bool = False, # 是否从后验分布采样,默认为 False
return_dict: bool = True, # 是否返回 DecoderOutput 对象,默认为 True
generator: Optional[torch.Generator] = None, # 可选的随机数生成器,用于采样
) -> Union[DecoderOutput, Tuple[torch.Tensor]]: # 返回类型可以是 DecoderOutput 或元组
r"""
Args:
sample (`torch.Tensor`): Input sample. # 输入样本
sample_posterior (`bool`, *optional*, defaults to `False`): # 是否采样后验的标志
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`): # 返回类型的标志
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*, defaults to `None`): # 随机数生成器的说明
Generator to use for sampling.
Returns:
[`DecoderOutput`] or `tuple`: # 返回类型的说明
If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
"""
x = sample # 将输入样本赋值给变量 x
posterior = self.encode(x).latent_dist # 使用编码器对样本进行编码,获取后验分布
if sample_posterior: # 检查是否需要从后验分布采样
z = posterior.sample(generator=generator) # 从后验分布中采样,使用指定的生成器
else: # 如果不从后验分布采样
z = posterior.mode() # 选择后验分布的众数作为 z
dec = self.decode(z, generator=generator).sample # 解码 z,并获取解码后的样本
if not return_dict: # 如果不需要返回字典
return (dec,) # 返回解码样本的元组
return DecoderOutput(sample=dec) # 返回 DecoderOutput 对象,包含解码样本
.\diffusers\models\autoencoders\vae.py
# 版权声明,2024年HuggingFace团队保留所有权利
#
# 根据Apache许可证第2.0版(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是按“原样”基础提供的,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 请参阅许可证以获取有关特定语言的权限和限制的更多信息。
from dataclasses import dataclass # 导入dataclass装饰器用于简化类的定义
from typing import Optional, Tuple # 导入可选类型和元组类型,用于类型注释
import numpy as np # 导入numpy库,用于数值计算
import torch # 导入PyTorch库,用于构建深度学习模型
import torch.nn as nn # 导入PyTorch的神经网络模块
from ...utils import BaseOutput, is_torch_version # 从utils模块导入BaseOutput类和版本检查函数
from ...utils.torch_utils import randn_tensor # 从torch_utils模块导入随机张量生成函数
from ..activations import get_activation # 从activations模块导入获取激活函数的函数
from ..attention_processor import SpatialNorm # 从attention_processor模块导入空间归一化类
from ..unets.unet_2d_blocks import ( # 从unet_2d_blocks模块导入多个网络块
AutoencoderTinyBlock, # 导入自动编码器小块
UNetMidBlock2D, # 导入UNet中间块
get_down_block, # 导入获取下采样块的函数
get_up_block, # 导入获取上采样块的函数
)
@dataclass # 使用dataclass装饰器,简化类的初始化和表示
class DecoderOutput(BaseOutput): # 定义DecoderOutput类,继承自BaseOutput
r""" # 文档字符串,描述解码方法的输出
Output of decoding method. # 解码方法的输出
Args: # 参数说明
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): # 输入样本的描述
The decoded output sample from the last layer of the model. # 模型最后一层的解码输出样本
"""
sample: torch.Tensor # 定义样本属性,类型为torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None # 定义可选的损失属性,默认为None
class Encoder(nn.Module): # 定义Encoder类,继承自nn.Module
r""" # 文档字符串,描述变分自动编码器的Encoder层
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. # 变分自动编码器的Encoder层
Args: # 参数说明
in_channels (`int`, *optional*, defaults to 3): # 输入通道的描述
The number of input channels. # 输入通道的数量
out_channels (`int`, *optional*, defaults to 3): # 输出通道的描述
The number of output channels. # 输出通道的数量
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): # 下采样块类型的描述
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available # 使用的下采样块类型,具体可查看相关文档
options. # 可选项
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): # 块输出通道的描述
The number of output channels for each block. # 每个块的输出通道数量
layers_per_block (`int`, *optional*, defaults to 2): # 每个块层数的描述
The number of layers per block. # 每个块的层数
norm_num_groups (`int`, *optional*, defaults to 32): # 归一化组数的描述
The number of groups for normalization. # 归一化的组数
act_fn (`str`, *optional*, defaults to `"silu"`): # 激活函数类型的描述
The activation function to use. See `~diffusers.models.activations.get_activation` for available options. # 使用的激活函数,具体可查看相关文档
double_z (`bool`, *optional*, defaults to `True`): # 最后一个块输出通道双倍化的描述
Whether to double the number of output channels for the last block. # 是否将最后一个块的输出通道数量翻倍
"""
# 初始化方法,设置网络的基本参数
def __init__(
self,
in_channels: int = 3, # 输入通道数,默认为3(RGB图像)
out_channels: int = 3, # 输出通道数,默认为3(RGB图像)
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), # 下采样模块类型
block_out_channels: Tuple[int, ...] = (64,), # 每个块的输出通道数
layers_per_block: int = 2, # 每个下采样块的层数
norm_num_groups: int = 32, # 归一化时的组数
act_fn: str = "silu", # 激活函数类型,默认为SiLU
double_z: bool = True, # 是否双输出通道
mid_block_add_attention=True, # 中间块是否添加注意力机制
):
super().__init__() # 调用父类构造函数
self.layers_per_block = layers_per_block # 保存每块的层数
# 定义输入卷积层
self.conv_in = nn.Conv2d(
in_channels, # 输入通道数
block_out_channels[0], # 输出通道数
kernel_size=3, # 卷积核大小
stride=1, # 步幅
padding=1, # 填充
)
self.down_blocks = nn.ModuleList([]) # 初始化下采样块的模块列表
# down
output_channel = block_out_channels[0] # 初始化输出通道数
for i, down_block_type in enumerate(down_block_types): # 遍历下采样块类型
input_channel = output_channel # 当前块的输入通道数
output_channel = block_out_channels[i] # 当前块的输出通道数
is_final_block = i == len(block_out_channels) - 1 # 判断是否为最后一个块
# 获取下采样块并初始化
down_block = get_down_block(
down_block_type, # 下采样块类型
num_layers=self.layers_per_block, # 下采样块的层数
in_channels=input_channel, # 输入通道数
out_channels=output_channel, # 输出通道数
add_downsample=not is_final_block, # 是否添加下采样
resnet_eps=1e-6, # ResNet的epsilon值
downsample_padding=0, # 下采样的填充
resnet_act_fn=act_fn, # ResNet的激活函数
resnet_groups=norm_num_groups, # ResNet的组数
attention_head_dim=output_channel, # 注意力头的维度
temb_channels=None, # 时间嵌入通道数
)
self.down_blocks.append(down_block) # 将下采样块添加到模块列表中
# mid
# 定义中间块
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1], # 中间块的输入通道数
resnet_eps=1e-6, # ResNet的epsilon值
resnet_act_fn=act_fn, # ResNet的激活函数
output_scale_factor=1, # 输出缩放因子
resnet_time_scale_shift="default", # ResNet时间缩放偏移
attention_head_dim=block_out_channels[-1], # 注意力头的维度
resnet_groups=norm_num_groups, # ResNet的组数
temb_channels=None, # 时间嵌入通道数
add_attention=mid_block_add_attention, # 是否添加注意力机制
)
# out
# 定义输出的归一化层
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU() # 激活函数层
# 根据双输出设置卷积层的输出通道数
conv_out_channels = 2 * out_channels if double_z else out_channels
# 定义输出卷积层
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
self.gradient_checkpointing = False # 设置梯度检查点为False
# 定义 Encoder 类的前向传播方法,接收一个张量作为输入并返回一个张量
def forward(self, sample: torch.Tensor) -> torch.Tensor:
r"""Encoder 类的前向方法。"""
# 使用输入的样本通过初始卷积层进行处理
sample = self.conv_in(sample)
# 如果处于训练模式并且开启了梯度检查点
if self.training and self.gradient_checkpointing:
# 定义一个创建自定义前向传播的内部函数
def create_custom_forward(module):
# 定义自定义前向传播函数,接受任意输入并调用模块
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 如果 PyTorch 版本大于等于 1.11.0
if is_torch_version(">=", "1.11.0"):
# 遍历每个下采样块,应用检查点机制来节省内存
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample, use_reentrant=False
)
# 中间块应用检查点机制
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, use_reentrant=False
)
else:
# 遍历每个下采样块,应用检查点机制
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
# 中间块应用检查点机制
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
else:
# 否则,直接通过下采样块处理样本
for down_block in self.down_blocks:
sample = down_block(sample)
# 直接通过中间块处理样本
sample = self.mid_block(sample)
# 后处理步骤
sample = self.conv_norm_out(sample) # 应用归一化卷积层
sample = self.conv_act(sample) # 应用激活函数卷积层
sample = self.conv_out(sample) # 应用输出卷积层
# 返回处理后的样本
return sample
# 定义变分自编码器的解码层,将潜在表示解码为输出样本
class Decoder(nn.Module):
r"""
`Decoder`层的文档字符串,描述其功能及参数
Args:
in_channels (`int`, *optional*, defaults to 3):
输入通道的数量
out_channels (`int`, *optional*, defaults to 3):
输出通道的数量
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
使用的上采样块类型,参考可用选项
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
每个块的输出通道数量
layers_per_block (`int`, *optional*, defaults to 2):
每个块包含的层数
norm_num_groups (`int`, *optional*, defaults to 32):
归一化的组数
act_fn (`str`, *optional*, defaults to `"silu"`):
使用的激活函数,参考可用选项
norm_type (`str`, *optional*, defaults to `"group"`):
使用的归一化类型,可以是"group"或"spatial"
"""
# 初始化解码器层
def __init__(
self,
in_channels: int = 3, # 默认输入通道为3
out_channels: int = 3, # 默认输出通道为3
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), # 默认上采样块类型
block_out_channels: Tuple[int, ...] = (64,), # 默认块输出通道
layers_per_block: int = 2, # 每个块的默认层数
norm_num_groups: int = 32, # 默认归一化组数
act_fn: str = "silu", # 默认激活函数为"silu"
norm_type: str = "group", # 默认归一化类型为"group"
mid_block_add_attention=True, # 中间块是否添加注意力机制,默认为True
# 初始化父类
):
super().__init__()
# 设置每个块的层数
self.layers_per_block = layers_per_block
# 定义输入卷积层,转换输入通道数到最后块的输出通道数
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
# 初始化上采样块的模块列表
self.up_blocks = nn.ModuleList([])
# 根据归一化类型设置时间嵌入通道数
temb_channels = in_channels if norm_type == "spatial" else None
# 中间块
self.mid_block = UNetMidBlock2D(
# 设置中间块的输入通道、eps、激活函数等参数
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# 上采样
# 反转输出通道数列表以用于上采样
reversed_block_out_channels = list(reversed(block_out_channels))
# 设定当前输出通道数为反转列表的第一个元素
output_channel = reversed_block_out_channels[0]
# 遍历上采样块类型并创建相应的上采样块
for i, up_block_type in enumerate(up_block_types):
# 存储前一个输出通道数
prev_output_channel = output_channel
# 更新当前输出通道数
output_channel = reversed_block_out_channels[i]
# 判断当前块是否为最后一个块
is_final_block = i == len(block_out_channels) - 1
# 创建上采样块并传入相关参数
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
)
# 将新创建的上采样块添加到模块列表
self.up_blocks.append(up_block)
# 更新前一个输出通道数
prev_output_channel = output_channel
# 输出层
# 根据归一化类型选择输出卷积层的归一化方法
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
# 设置输出卷积层的激活函数为 SiLU
self.conv_act = nn.SiLU()
# 定义最终输出的卷积层,输出通道数为 out_channels
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
# 初始化梯度检查点开关为 False
self.gradient_checkpointing = False
# 定义前向传播方法
def forward(
self,
sample: torch.Tensor,
latent_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# `Decoder` 类的前向方法文档字符串
# 通过输入卷积层处理样本
sample = self.conv_in(sample)
# 获取上采样块参数的数据类型
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
# 如果处于训练模式且使用梯度检查点
if self.training and self.gradient_checkpointing:
# 创建自定义前向函数
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 检查 PyTorch 版本
if is_torch_version(">=", "1.11.0"):
# 中间处理
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
# 转换样本数据类型
sample = sample.to(upscale_dtype)
# 上采样处理
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# 中间处理
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds
)
# 转换样本数据类型
sample = sample.to(upscale_dtype)
# 上采样处理
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
else:
# 中间处理
sample = self.mid_block(sample, latent_embeds)
# 转换样本数据类型
sample = sample.to(upscale_dtype)
# 上采样处理
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
# 后处理
if latent_embeds is None:
# 如果没有潜在嵌入,则直接进行卷积归一化输出
sample = self.conv_norm_out(sample)
else:
# 如果有潜在嵌入,则传入进行卷积归一化输出
sample = self.conv_norm_out(sample, latent_embeds)
# 应用激活函数
sample = self.conv_act(sample)
# 最终输出卷积层处理样本
sample = self.conv_out(sample)
# 返回处理后的样本
return sample
# 定义一个名为 UpSample 的类,继承自 nn.Module
class UpSample(nn.Module):
r"""
`UpSample` 层用于变分自编码器,可以对输入进行上采样。
参数:
in_channels (`int`, *可选*, 默认为 3):
输入通道的数量。
out_channels (`int`, *可选*, 默认为 3):
输出通道的数量。
"""
# 初始化方法,接受输入和输出通道数量
def __init__(
self,
in_channels: int, # 输入通道数量
out_channels: int, # 输出通道数量
) -> None:
super().__init__() # 调用父类的初始化方法
self.in_channels = in_channels # 保存输入通道数量
self.out_channels = out_channels # 保存输出通道数量
# 创建转置卷积层,用于上采样
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
# 前向传播方法
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""`UpSample` 类的前向传播方法。"""
x = torch.relu(x) # 对输入应用 ReLU 激活函数
x = self.deconv(x) # 通过转置卷积层进行上采样
return x # 返回上采样后的结果
# 定义一个名为 MaskConditionEncoder 的类,继承自 nn.Module
class MaskConditionEncoder(nn.Module):
"""
用于 AsymmetricAutoencoderKL
"""
# 初始化方法,接受多个参数以构建编码器
def __init__(
self,
in_ch: int, # 输入通道数量
out_ch: int = 192, # 输出通道数量,默认值为 192
res_ch: int = 768, # 结果通道数量,默认值为 768
stride: int = 16, # 步幅,默认值为 16
) -> None:
super().__init__() # 调用父类的初始化方法
channels = [] # 初始化通道列表
# 计算每一层的输入和输出通道数量,直到步幅小于等于 1
while stride > 1:
stride = stride // 2 # 将步幅减半
in_ch_ = out_ch * 2 # 输入通道数量为输出通道的两倍
if out_ch > res_ch: # 如果输出通道大于结果通道
out_ch = res_ch # 将输出通道设置为结果通道
if stride == 1: # 如果步幅为 1
in_ch_ = res_ch # 输入通道数量设置为结果通道
channels.append((in_ch_, out_ch)) # 将输入和输出通道对添加到列表
out_ch *= 2 # 输出通道数量翻倍
out_channels = [] # 初始化输出通道列表
# 从通道列表中提取输出通道数量
for _in_ch, _out_ch in channels:
out_channels.append(_out_ch) # 添加输出通道数量
out_channels.append(channels[-1][0]) # 添加最后一层的输入通道数量
layers = [] # 初始化层列表
in_ch_ = in_ch # 将输入通道数量赋值给临时变量
# 根据输出通道数量构建卷积层
for l in range(len(out_channels)):
out_ch_ = out_channels[l] # 当前输出通道数量
if l == 0 or l == 1: # 对于前两层
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)) # 添加 3x3 卷积层
else: # 对于后续层
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)) # 添加 4x4 卷积层
in_ch_ = out_ch_ # 更新输入通道数量
self.layers = nn.Sequential(*layers) # 将所有层组合成一个顺序容器
# 前向传播方法
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
r"""`MaskConditionEncoder` 类的前向传播方法。"""
out = {} # 初始化输出字典
# 遍历所有层
for l in range(len(self.layers)):
layer = self.layers[l] # 获取当前层
x = layer(x) # 通过当前层处理输入
out[str(tuple(x.shape))] = x # 将当前输出的形状作为键,输出张量作为值存入字典
x = torch.relu(x) # 对输出应用 ReLU 激活函数
return out # 返回输出字典
# 定义一个名为 MaskConditionDecoder 的类,继承自 nn.Module
class MaskConditionDecoder(nn.Module):
r"""`MaskConditionDecoder` 应与 [`AsymmetricAutoencoderKL`] 一起使用,以增强模型的
解码器,结合掩膜和被掩膜的图像。
# 函数参数定义部分
Args:
in_channels (`int`, *optional*, defaults to 3): # 输入通道的数量,默认为3
The number of input channels.
out_channels (`int`, *optional*, defaults to 3): # 输出通道的数量,默认为3
The number of output channels.
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): # 使用的上采样模块类型,默认为UpDecoderBlock2D
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): # 每个模块的输出通道数量,默认为64
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): # 每个模块的层数,默认为2
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32): # 归一化的组数,默认为32
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`): # 使用的激活函数,默认为silu
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`): # 归一化类型,可以是"group"或"spatial",默认为"group"
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
# 初始化方法定义
def __init__(
self,
in_channels: int = 3, # 输入通道的数量,默认为3
out_channels: int = 3, # 输出通道的数量,默认为3
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), # 上采样模块类型,默认为UpDecoderBlock2D
block_out_channels: Tuple[int, ...] = (64,), # 每个模块的输出通道数量,默认为64
layers_per_block: int = 2, # 每个模块的层数,默认为2
norm_num_groups: int = 32, # 归一化的组数,默认为32
act_fn: str = "silu", # 激活函数,默认为silu
norm_type: str = "group", # 归一化类型,默认为"group"
):
# 调用父类构造函数初始化
super().__init__()
# 设置每个块的层数
self.layers_per_block = layers_per_block
# 初始化输入卷积层,接收输入通道并生成块输出通道
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[-1],
kernel_size=3,
stride=1,
padding=1,
)
# 创建一个空的模块列表,用于存储上采样块
self.up_blocks = nn.ModuleList([])
# 根据归一化类型设置时间嵌入通道数
temb_channels = in_channels if norm_type == "spatial" else None
# 中间块的初始化
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1], # 使用最后一个块的输出通道作为输入
resnet_eps=1e-6, # ResNet 的 epsilon 参数
resnet_act_fn=act_fn, # ResNet 的激活函数
output_scale_factor=1, # 输出缩放因子
resnet_time_scale_shift="default" if norm_type == "group" else norm_type, # 时间缩放偏移
attention_head_dim=block_out_channels[-1], # 注意力头的维度
resnet_groups=norm_num_groups, # ResNet 的组数
temb_channels=temb_channels, # 时间嵌入通道数
)
# 初始化上采样块
reversed_block_out_channels = list(reversed(block_out_channels)) # 反转输出通道列表
output_channel = reversed_block_out_channels[0] # 获取第一个输出通道
for i, up_block_type in enumerate(up_block_types): # 遍历上采样块类型
prev_output_channel = output_channel # 保存上一个输出通道数
output_channel = reversed_block_out_channels[i] # 获取当前输出通道数
is_final_block = i == len(block_out_channels) - 1 # 判断是否为最后一个块
# 获取上采样块
up_block = get_up_block(
up_block_type, # 上采样块类型
num_layers=self.layers_per_block + 1, # 上采样层数
in_channels=prev_output_channel, # 输入通道数
out_channels=output_channel, # 输出通道数
prev_output_channel=None, # 前一个输出通道数
add_upsample=not is_final_block, # 是否添加上采样操作
resnet_eps=1e-6, # ResNet 的 epsilon 参数
resnet_act_fn=act_fn, # ResNet 的激活函数
resnet_groups=norm_num_groups, # ResNet 的组数
attention_head_dim=output_channel, # 注意力头的维度
temb_channels=temb_channels, # 时间嵌入通道数
resnet_time_scale_shift=norm_type, # 时间缩放偏移
)
self.up_blocks.append(up_block) # 将上采样块添加到模块列表
prev_output_channel = output_channel # 更新前一个输出通道数
# 条件编码器的初始化
self.condition_encoder = MaskConditionEncoder(
in_ch=out_channels, # 输入通道数
out_ch=block_out_channels[0], # 输出通道数
res_ch=block_out_channels[-1], # ResNet 通道数
)
# 输出层的归一化处理
if norm_type == "spatial": # 如果归一化类型为空间归一化
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) # 初始化空间归一化
else: # 否则使用组归一化
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) # 初始化组归一化
# 初始化激活函数为 SiLU
self.conv_act = nn.SiLU()
# 初始化输出卷积层
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
def forward(
self,
z: torch.Tensor, # 输入的张量 z
image: Optional[torch.Tensor] = None, # 可选的输入图像张量
mask: Optional[torch.Tensor] = None, # 可选的输入掩码张量
latent_embeds: Optional[torch.Tensor] = None, # 可选的潜在嵌入张量
# 定义一个向量量化器类,继承自 nn.Module
class VectorQuantizer(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# 初始化方法,设置类的基本参数
def __init__(
self,
n_e: int, # 向量量化的嵌入数量
vq_embed_dim: int, # 嵌入的维度
beta: float, # beta 参数,用于调节量化误差
remap=None, # 用于重映射的可选参数
unknown_index: str = "random", # 未知索引的处理方式
sane_index_shape: bool = False, # 是否强制索引形状的合理性
legacy: bool = True, # 是否使用旧版本的实现
):
# 调用父类的初始化方法
super().__init__()
# 保存嵌入数量
self.n_e = n_e
# 保存嵌入维度
self.vq_embed_dim = vq_embed_dim
# 保存 beta 参数
self.beta = beta
# 保存是否使用旧版的标志
self.legacy = legacy
# 初始化嵌入层,随机生成嵌入权重
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
# 将嵌入权重初始化为[-1/n_e, 1/n_e]的均匀分布
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
# 处理重映射参数
self.remap = remap
if self.remap is not None: # 如果提供了重映射文件
# 注册一个缓冲区,加载重映射的数据
self.register_buffer("used", torch.tensor(np.load(self.remap)))
self.used: torch.Tensor # 声明用于重映射的张量
# 重新嵌入的数量
self.re_embed = self.used.shape[0]
# 设置未知索引的方式
self.unknown_index = unknown_index # "random"、"extra"或整数
if self.unknown_index == "extra": # 如果未知索引为"extra"
self.unknown_index = self.re_embed # 设置为重新嵌入数量
self.re_embed = self.re_embed + 1 # 增加重新嵌入的数量
# 打印重映射信息
print(
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices."
)
else:
# 如果没有提供重映射,则重新嵌入数量与嵌入数量相同
self.re_embed = n_e
# 保存是否强制索引形状合理的标志
self.sane_index_shape = sane_index_shape
# 将索引映射到已使用的索引
def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
# 保存输入张量的形状
ishape = inds.shape
# 确保输入张量至少有两个维度
assert len(ishape) > 1
# 将输入张量重塑为二维,保持第一个维度不变
inds = inds.reshape(ishape[0], -1)
# 将使用的张量转换到相同的设备
used = self.used.to(inds)
# 检查 inds 中的元素是否与 used 中的元素匹配
match = (inds[:, :, None] == used[None, None, ...]).long()
# 找到匹配的索引
new = match.argmax(-1)
# 检查是否有未知索引
unknown = match.sum(2) < 1
if self.unknown_index == "random": # 如果未知索引为"random"
# 随机生成未知索引
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
else:
# 否则设置为指定的未知索引
new[unknown] = self.unknown_index
# 将结果重塑为原来的形状并返回
return new.reshape(ishape)
# 将映射到的索引还原为所有索引
def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
# 保存输入张量的形状
ishape = inds.shape
# 确保输入张量至少有两个维度
assert len(ishape) > 1
# 将输入张量重塑为二维,保持第一个维度不变
inds = inds.reshape(ishape[0], -1)
# 将使用的张量转换到相同的设备
used = self.used.to(inds)
if self.re_embed > self.used.shape[0]: # 如果有额外的标记
# 将超出已使用索引的标记设置为零
inds[inds >= self.used.shape[0]] = 0 # 简单设置为零
# 根据 inds 从 used 中选择对应的值
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
# 将结果重塑为原来的形状并返回
return back.reshape(ishape)
# 前向传播方法,接收一个张量 z,返回量化张量、损失和附加信息
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
# 将 z 重新排列为 (batch, height, width, channel) 的形状,并展平
z = z.permute(0, 2, 3, 1).contiguous()
# 将 z 展平为二维张量,维度为 (batch_size * height * width, vq_embed_dim)
z_flattened = z.view(-1, self.vq_embed_dim)
# 计算 z 与嵌入 e_j 之间的距离,公式为 (z - e)^2 = z^2 + e^2 - 2 * e * z
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
# 根据最小编码索引从嵌入中获取量化的 z,重新调整为原始 z 的形状
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None # 初始化困惑度为 None
min_encodings = None # 初始化最小编码为 None
# 计算嵌入的损失
if not self.legacy:
# 计算损失时考虑 beta 权重
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
else:
# 计算损失时考虑不同的权重
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# 保持梯度
z_q: torch.Tensor = z + (z_q - z).detach()
# 将 z_q 重新排列为与原始输入形状相匹配
z_q = z_q.permute(0, 3, 1, 2).contiguous()
if self.remap is not None:
# 如果存在重映射,则调整最小编码索引形状,增加批次维度
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
# 将索引映射到使用的编码
min_encoding_indices = self.remap_to_used(min_encoding_indices)
# 将索引展平
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
# 如果需要,调整最小编码索引的形状以匹配 z_q 的形状
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
# 返回量化的 z、损失和其他信息的元组
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
# 获取代码簿条目,根据索引返回量化的潜在向量
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
# shape 指定 (batch, height, width, channel)
if self.remap is not None:
# 如果存在重映射,则调整索引形状,增加批次维度
indices = indices.reshape(shape[0], -1) # add batch axis
# 将索引映射回所有编码
indices = self.unmap_to_all(indices)
# 将索引展平
indices = indices.reshape(-1) # flatten again
# 获取量化的潜在向量
z_q: torch.Tensor = self.embedding(indices)
if shape is not None:
# 如果形状不为空,将 z_q 重新调整为指定的形状
z_q = z_q.view(shape)
# 重新排列以匹配原始输入形状
z_q = z_q.permute(0, 3, 1, 2).contiguous()
# 返回量化的潜在向量
return z_q
# 定义对角高斯分布类
class DiagonalGaussianDistribution(object):
# 初始化方法,接收参数和是否为确定性分布的标志
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
# 将参数存储在实例中
self.parameters = parameters
# 将参数分为均值和对数方差
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
# 将对数方差限制在-30到20之间
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
# 记录是否为确定性分布
self.deterministic = deterministic
# 计算标准差
self.std = torch.exp(0.5 * self.logvar)
# 计算方差
self.var = torch.exp(self.logvar)
# 如果是确定性分布,方差和标准差设为零
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
# 采样方法,生成符合分布的样本
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# 确保样本与参数在相同的设备上且具有相同的数据类型
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
# 根据均值和标准差生成样本
x = self.mean + self.std * sample
# 返回生成的样本
return x
# 计算与另一个分布的KL散度
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
# 如果是确定性分布,KL散度为0
if self.deterministic:
return torch.Tensor([0.0])
else:
# 如果没有提供另一个分布
if other is None:
# 计算自身的KL散度
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
# 计算与另一个分布的KL散度
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
# 计算负对数似然
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
# 如果是确定性分布,负对数似然为0
if self.deterministic:
return torch.Tensor([0.0])
# 计算常数log(2π)
logtwopi = np.log(2.0 * np.pi)
# 返回负对数似然的计算结果
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
# 返回分布的众数
def mode(self) -> torch.Tensor:
return self.mean
# 定义EncoderTiny类,继承自nn.Module
class EncoderTiny(nn.Module):
r"""
`EncoderTiny`层是`Encoder`层的简化版本。
参数:
in_channels (`int`):
输入通道的数量。
out_channels (`int`):
输出通道的数量。
num_blocks (`Tuple[int, ...]`):
元组中的每个值表示一个Conv2d层后跟随`value`数量的`AutoencoderTinyBlock`。
block_out_channels (`Tuple[int, ...]`):
每个块的输出通道数量。
act_fn (`str`):
使用的激活函数。请参见`~diffusers.models.activations.get_activation`以获取可用选项。
"""
# 初始化方法,构造 EncoderTiny 类的实例
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
num_blocks: Tuple[int, ...], # 每个层中块的数量
block_out_channels: Tuple[int, ...], # 每个层的输出通道数
act_fn: str, # 激活函数的类型
):
# 调用父类的初始化方法
super().__init__()
layers = [] # 初始化空层列表
# 遍历每个层的块数量
for i, num_block in enumerate(num_blocks):
num_channels = block_out_channels[i] # 当前层的输出通道数
# 如果是第一个层,创建卷积层
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
else:
# 创建后续卷积层,包含步幅和无偏置选项
layers.append(
nn.Conv2d(
num_channels,
num_channels,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
)
# 添加指定数量的 AutoencoderTinyBlock
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
# 添加最后的卷积层,将最后一层输出通道映射到目标输出通道
layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
# 将所有层组合为一个顺序模块
self.layers = nn.Sequential(*layers)
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 前向传播方法
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""EncoderTiny 类的前向方法。"""
# 如果模型处于训练状态并且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 创建自定义前向传播方法
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# 根据 PyTorch 版本选择检查点方式
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
else:
# 将图像从 [-1, 1] 线性缩放到 [0, 1],以匹配 TAESD 规范
x = self.layers(x.add(1).div(2))
# 返回前向传播的输出
return x
# 定义一个名为 `DecoderTiny` 的类,继承自 `nn.Module`
class DecoderTiny(nn.Module):
r"""
`DecoderTiny` 层是 `Decoder` 层的简化版本。
参数:
in_channels (`int`):
输入通道的数量。
out_channels (`int`):
输出通道的数量。
num_blocks (`Tuple[int, ...]`):
元组中的每个值表示一个 Conv2d 层后面跟着 `value` 个 `AutoencoderTinyBlock` 的数量。
block_out_channels (`Tuple[int, ...]`):
每个块的输出通道数量。
upsampling_scaling_factor (`int`):
用于上采样的缩放因子。
act_fn (`str`):
使用的激活函数。有关可用选项,请参见 `~diffusers.models.activations.get_activation`。
"""
# 初始化方法,设置类的基本参数
def __init__(
self,
in_channels: int, # 输入通道数量
out_channels: int, # 输出通道数量
num_blocks: Tuple[int, ...], # 每个块的数量
block_out_channels: Tuple[int, ...], # 每个块的输出通道数量
upsampling_scaling_factor: int, # 上采样缩放因子
act_fn: str, # 激活函数名称
upsample_fn: str, # 上采样函数名称
):
super().__init__() # 调用父类的初始化方法
# 初始化层的列表
layers = [
# 添加一个 Conv2d 层,输入通道到第一个块的输出通道
nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
# 添加指定的激活函数
get_activation(act_fn),
]
# 遍历每个块的数量
for i, num_block in enumerate(num_blocks):
is_final_block = i == (len(num_blocks) - 1) # 判断是否为最后一个块
num_channels = block_out_channels[i] # 获取当前块的输出通道数量
# 对于当前块的数量,添加相应数量的 `AutoencoderTinyBlock`
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
# 如果不是最后一个块,则添加上采样层
if not is_final_block:
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))
# 设置当前卷积的输出通道数量
conv_out_channel = num_channels if not is_final_block else out_channels
# 添加卷积层
layers.append(
nn.Conv2d(
num_channels, # 输入通道数量
conv_out_channel, # 输出通道数量
kernel_size=3, # 卷积核大小
padding=1, # 填充大小
bias=is_final_block, # 如果是最后一个块,使用偏置
)
)
# 将所有层组合成一个顺序模型
self.layers = nn.Sequential(*layers)
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 前向传播方法
def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""`DecoderTiny` 类的前向方法。"""
# 将输入张量缩放并限制到 [-3, 3] 范围
x = torch.tanh(x / 3) * 3
# 如果处于训练状态并且启用了梯度检查点
if self.training and self.gradient_checkpointing:
# 创建自定义前向函数
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs) # 调用模块
return custom_forward
# 如果 PyTorch 版本大于等于 1.11.0,使用非重入的检查点
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) # 使用检查点
else:
x = self.layers(x) # 否则直接通过层处理输入
# 将图像从 [0, 1] 范围缩放到 [-1, 1],以匹配 diffusers 的约定
return x.mul(2).sub(1) # 缩放并返回结果
标签:sample,self,torch,diffusers,channels,源码,解析,block,out
From: https://www.cnblogs.com/apachecn/p/18492371