首页 > 编程语言 >diffusers-源码解析-六-

diffusers-源码解析-六-

时间:2024-10-22 12:36:43浏览次数:1  
标签:sample self torch diffusers channels 源码 解析 block out

diffusers 源码解析(六)

.\diffusers\models\autoencoders\autoencoder_oobleck.py

# 版权声明,表示该代码属于 HuggingFace 团队,所有权利保留
# 根据 Apache 2.0 许可证进行授权
# 用户在合规的情况下可以使用该文件
# 许可证的获取地址
# 如果没有适用的法律或书面协议,软件是按“现状”提供的
# 免责声明,表示不提供任何形式的保证或条件
import math  # 导入数学库,提供数学函数和常数
from dataclasses import dataclass  # 导入数据类装饰器,用于简化类的定义
from typing import Optional, Tuple, Union  # 导入类型提示的必要工具

import numpy as np  # 导入 NumPy 库,提供数组和矩阵操作
import torch  # 导入 PyTorch 库,提供深度学习框架
import torch.nn as nn  # 导入 PyTorch 神经网络模块
from torch.nn.utils import weight_norm  # 导入权重归一化工具

from ...configuration_utils import ConfigMixin, register_to_config  # 导入配置相关工具
from ...utils import BaseOutput  # 导入基础输出工具
from ...utils.accelerate_utils import apply_forward_hook  # 导入加速工具
from ...utils.torch_utils import randn_tensor  # 导入随机张量生成工具
from ..modeling_utils import ModelMixin  # 导入模型混合工具


class Snake1d(nn.Module):
    """
    一个 1 维的 Snake 激活函数模块。
    """

    def __init__(self, hidden_dim, logscale=True):  # 初始化方法,接收隐藏维度和对数缩放标志
        super().__init__()  # 调用父类初始化方法
        self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))  # 定义 alpha 参数,初始为 0
        self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))  # 定义 beta 参数,初始为 0

        self.alpha.requires_grad = True  # 允许 alpha 参数更新
        self.beta.requires_grad = True  # 允许 beta 参数更新
        self.logscale = logscale  # 设置对数缩放标志

    def forward(self, hidden_states):  # 前向传播方法,接收隐藏状态
        shape = hidden_states.shape  # 获取隐藏状态的形状

        alpha = self.alpha if not self.logscale else torch.exp(self.alpha)  # 计算 alpha 值
        beta = self.beta if not self.logscale else torch.exp(self.beta)  # 计算 beta 值

        hidden_states = hidden_states.reshape(shape[0], shape[1], -1)  # 重塑隐藏状态
        hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)  # 更新隐藏状态
        hidden_states = hidden_states.reshape(shape)  # 恢复隐藏状态形状
        return hidden_states  # 返回更新后的隐藏状态


class OobleckResidualUnit(nn.Module):
    """
    一个由 Snake1d 和带扩张的权重归一化 Conv1d 层组成的残差单元。
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):  # 初始化方法,接收维度和扩张因子
        super().__init__()  # 调用父类初始化方法
        pad = ((7 - 1) * dilation) // 2  # 计算填充大小

        self.snake1 = Snake1d(dimension)  # 创建第一个 Snake1d 实例
        self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))  # 创建第一个卷积层并应用权重归一化
        self.snake2 = Snake1d(dimension)  # 创建第二个 Snake1d 实例
        self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))  # 创建第二个卷积层并应用权重归一化
    # 定义前向传播函数,接收隐藏状态作为输入
    def forward(self, hidden_state):
        """
        前向传播通过残差单元。
    
        参数:
            hidden_state (`torch.Tensor` 形状为 `(batch_size, channels, time_steps)`):
                输入张量。
    
        返回:
            output_tensor (`torch.Tensor` 形状为 `(batch_size, channels, time_steps)`)
                通过残差单元处理后的输入张量。
        """
        # 将输入隐藏状态赋值给输出张量
        output_tensor = hidden_state
        # 通过第一个卷积层和激活函数处理输出张量
        output_tensor = self.conv1(self.snake1(output_tensor))
        # 通过第二个卷积层和激活函数处理输出张量
        output_tensor = self.conv2(self.snake2(output_tensor))
    
        # 计算填充量,以对齐输出张量和输入张量的时间步长
        padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
        # 如果需要填充,则对隐藏状态进行切片
        if padding > 0:
            hidden_state = hidden_state[..., padding:-padding]
        # 将处理后的输出张量与切片后的隐藏状态相加,实现残差连接
        output_tensor = hidden_state + output_tensor
        # 返回最终的输出张量
        return output_tensor
# Oobleck编码器块的定义,继承自nn.Module
class OobleckEncoderBlock(nn.Module):
    """Encoder block used in Oobleck encoder."""

    # 初始化函数,定义输入维度、输出维度和步幅
    def __init__(self, input_dim, output_dim, stride: int = 1):
        # 调用父类的初始化函数
        super().__init__()

        # 定义第一个残差单元,膨胀率为1
        self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
        # 定义第二个残差单元,膨胀率为3
        self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
        # 定义第三个残差单元,膨胀率为9
        self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
        # 定义一维蛇形结构
        self.snake1 = Snake1d(input_dim)
        # 定义卷积层,使用权重归一化
        self.conv1 = weight_norm(
            nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
        )

    # 前向传播函数
    def forward(self, hidden_state):
        # 通过第一个残差单元处理隐状态
        hidden_state = self.res_unit1(hidden_state)
        # 通过第二个残差单元处理隐状态
        hidden_state = self.res_unit2(hidden_state)
        # 通过第三个残差单元和蛇形结构处理隐状态
        hidden_state = self.snake1(self.res_unit3(hidden_state))
        # 通过卷积层处理隐状态
        hidden_state = self.conv1(hidden_state)

        # 返回处理后的隐状态
        return hidden_state


# Oobleck解码器块的定义,继承自nn.Module
class OobleckDecoderBlock(nn.Module):
    """Decoder block used in Oobleck decoder."""

    # 初始化函数,定义输入维度、输出维度和步幅
    def __init__(self, input_dim, output_dim, stride: int = 1):
        # 调用父类的初始化函数
        super().__init__()

        # 定义一维蛇形结构
        self.snake1 = Snake1d(input_dim)
        # 定义转置卷积层,使用权重归一化
        self.conv_t1 = weight_norm(
            nn.ConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        # 定义第一个残差单元,膨胀率为1
        self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
        # 定义第二个残差单元,膨胀率为3
        self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
        # 定义第三个残差单元,膨胀率为9
        self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)

    # 前向传播函数
    def forward(self, hidden_state):
        # 通过蛇形结构处理隐状态
        hidden_state = self.snake1(hidden_state)
        # 通过转置卷积层处理隐状态
        hidden_state = self.conv_t1(hidden_state)
        # 通过第一个残差单元处理隐状态
        hidden_state = self.res_unit1(hidden_state)
        # 通过第二个残差单元处理隐状态
        hidden_state = self.res_unit2(hidden_state)
        # 通过第三个残差单元处理隐状态
        hidden_state = self.res_unit3(hidden_state)

        # 返回处理后的隐状态
        return hidden_state


# Oobleck对角高斯分布的定义
class OobleckDiagonalGaussianDistribution(object):
    # 初始化函数,定义参数和确定性标志
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        # 保存输入的参数
        self.parameters = parameters
        # 将参数分解为均值和尺度
        self.mean, self.scale = parameters.chunk(2, dim=1)
        # 计算标准差,确保为正值
        self.std = nn.functional.softplus(self.scale) + 1e-4
        # 计算方差
        self.var = self.std * self.std
        # 计算对数方差
        self.logvar = torch.log(self.var)
        # 保存确定性标志
        self.deterministic = deterministic

    # 采样函数,生成高斯样本
    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # 确保生成的样本在与参数相同的设备和数据类型上
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        # 根据均值和标准差生成样本
        x = self.mean + self.std * sample
        # 返回生成的样本
        return x
    # 计算 Kullback-Leibler 散度,返回一个张量
        def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
            # 如果是确定性分布,返回零的张量
            if self.deterministic:
                return torch.Tensor([0.0])
            else:
                # 如果没有提供其他分布,计算本分布的 KL 散度
                if other is None:
                    return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
                else:
                    # 计算均值差的平方归一化
                    normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
                    # 计算方差比
                    var_ratio = self.var / other.var
                    # 计算对数方差差值
                    logvar_diff = self.logvar - other.logvar
    
                    # 计算 KL 散度的各个部分
                    kl = normalized_diff + var_ratio + logvar_diff - 1
    
                    # 计算 KL 散度的平均值
                    kl = kl.sum(1).mean()
                    return kl
    
        # 返回分布的众数
        def mode(self) -> torch.Tensor:
            return self.mean
# 定义一个数据类,表示自编码器的输出
@dataclass
class AutoencoderOobleckOutput(BaseOutput):
    """
    AutoencoderOobleck 编码方法的输出。

    Args:
        latent_dist (`OobleckDiagonalGaussianDistribution`):
            表示 `Encoder` 编码输出的均值和标准差,
            `OobleckDiagonalGaussianDistribution` 允许从分布中采样潜在变量。
    """

    latent_dist: "OobleckDiagonalGaussianDistribution"  # noqa: F821


# 定义一个数据类,表示解码器的输出
@dataclass
class OobleckDecoderOutput(BaseOutput):
    r"""
    解码方法的输出。

    Args:
        sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
            从模型最后一层解码的输出样本。
    """

    sample: torch.Tensor


# 定义 Oobleck 编码器类,继承自 nn.Module
class OobleckEncoder(nn.Module):
    """Oobleck 编码器"""

    # 初始化编码器参数
    def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
        super().__init__()

        # 设置下采样比例
        strides = downsampling_ratios
        # 为通道倍数添加一个起始值
        channel_multiples = [1] + channel_multiples

        # 创建第一个卷积层
        self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))

        self.block = []
        # 创建编码块,随着下采样通过 `stride` 加倍通道
        for stride_index, stride in enumerate(strides):
            self.block += [
                OobleckEncoderBlock(
                    input_dim=encoder_hidden_size * channel_multiples[stride_index],  # 输入维度
                    output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],  # 输出维度
                    stride=stride,  # 下采样步幅
                )
            ]

        # 将编码块转换为模块列表
        self.block = nn.ModuleList(self.block)
        # 计算模型的最终维度
        d_model = encoder_hidden_size * channel_multiples[-1]
        # 创建 Snake1d 模块
        self.snake1 = Snake1d(d_model)
        # 创建第二个卷积层
        self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))

    # 定义前向传播方法
    def forward(self, hidden_state):
        # 将输入通过第一个卷积层
        hidden_state = self.conv1(hidden_state)

        # 通过每个编码块进行处理
        for module in self.block:
            hidden_state = module(hidden_state)

        # 通过 Snake1d 模块处理
        hidden_state = self.snake1(hidden_state)
        # 将结果通过第二个卷积层
        hidden_state = self.conv2(hidden_state)

        # 返回最终的隐藏状态
        return hidden_state


# 定义 Oobleck 解码器类,继承自 nn.Module
class OobleckDecoder(nn.Module):
    """Oobleck 解码器"""
    # 初始化方法,设置模型参数
        def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
            # 调用父类的初始化方法
            super().__init__()
    
            # 将上采样比例赋值给 strides
            strides = upsampling_ratios
            # 在 channel_multiples 列表前添加 1
            channel_multiples = [1] + channel_multiples
    
            # 添加第一个卷积层
            self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
    
            # 添加上采样 + MRF 块
            block = []
            # 遍历 strides 列表,构建 OobleckDecoderBlock
            for stride_index, stride in enumerate(strides):
                block += [
                    OobleckDecoderBlock(
                        # 设置输入和输出维度
                        input_dim=channels * channel_multiples[len(strides) - stride_index],
                        output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
                        stride=stride,
                    )
                ]
    
            # 将构建的块列表转为 nn.ModuleList
            self.block = nn.ModuleList(block)
            # 设置输出维度
            output_dim = channels
            # 创建 Snake1d 实例
            self.snake1 = Snake1d(output_dim)
            # 添加第二个卷积层,不使用偏置
            self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
    
        # 前向传播方法
        def forward(self, hidden_state):
            # 通过第一个卷积层处理输入
            hidden_state = self.conv1(hidden_state)
    
            # 遍历每个块,依次处理隐藏状态
            for layer in self.block:
                hidden_state = layer(hidden_state)
    
            # 通过 Snake1d 处理隐藏状态
            hidden_state = self.snake1(hidden_state)
            # 通过第二个卷积层处理隐藏状态
            hidden_state = self.conv2(hidden_state)
    
            # 返回最终的隐藏状态
            return hidden_state
# 定义一个自动编码器类,用于将波形编码为潜在表示并解码为波形
class AutoencoderOobleck(ModelMixin, ConfigMixin):
    r"""
    自动编码器,用于将波形编码为潜在向量并将潜在表示解码为波形。首次引入于 Stable Audio。

    此模型继承自 [`ModelMixin`]。请查阅超类文档以获取所有模型的通用方法(例如下载或保存)。

    参数:
        encoder_hidden_size (`int`, *可选*, 默认值为 128):
            编码器的中间表示维度。
        downsampling_ratios (`List[int]`, *可选*, 默认值为 `[2, 4, 4, 8, 8]`):
            编码器中下采样的比率。这些比率在解码器中以反向顺序用于上采样。
        channel_multiples (`List[int]`, *可选*, 默认值为 `[1, 2, 4, 8, 16]`):
            用于确定隐藏层隐藏尺寸的倍数。
        decoder_channels (`int`, *可选*, 默认值为 128):
            解码器的中间表示维度。
        decoder_input_channels (`int`, *可选*, 默认值为 64):
            解码器的输入维度。对应于潜在维度。
        audio_channels (`int`, *可选*, 默认值为 2):
            音频数据中的通道数。1 表示单声道,2 表示立体声。
        sampling_rate (`int`, *可选*, 默认值为 44100):
            音频波形应数字化的采样率,以赫兹(Hz)表示。
    """

    # 指示模型是否支持梯度检查点
    _supports_gradient_checkpointing = False

    # 注册到配置的构造函数
    @register_to_config
    def __init__(
        self,
        encoder_hidden_size=128,
        downsampling_ratios=[2, 4, 4, 8, 8],
        channel_multiples=[1, 2, 4, 8, 16],
        decoder_channels=128,
        decoder_input_channels=64,
        audio_channels=2,
        sampling_rate=44100,
    ):
        # 调用父类的构造函数
        super().__init__()

        # 设置编码器的隐藏层大小
        self.encoder_hidden_size = encoder_hidden_size
        # 设置编码器中的下采样比率
        self.downsampling_ratios = downsampling_ratios
        # 设置解码器的通道数
        self.decoder_channels = decoder_channels
        # 将下采样比率反转以用于解码器的上采样
        self.upsampling_ratios = downsampling_ratios[::-1]
        # 计算跳长,作为下采样比率的乘积
        self.hop_length = int(np.prod(downsampling_ratios))
        # 设置音频采样率
        self.sampling_rate = sampling_rate

        # 创建编码器实例,传入必要参数
        self.encoder = OobleckEncoder(
            encoder_hidden_size=encoder_hidden_size,
            audio_channels=audio_channels,
            downsampling_ratios=downsampling_ratios,
            channel_multiples=channel_multiples,
        )

        # 创建解码器实例,传入必要参数
        self.decoder = OobleckDecoder(
            channels=decoder_channels,
            input_channels=decoder_input_channels,
            audio_channels=audio_channels,
            upsampling_ratios=self.upsampling_ratios,
            channel_multiples=channel_multiples,
        )

        # 设置是否使用切片,初始值为假
        self.use_slicing = False
    # 启用切片 VAE 解码的功能
    def enable_slicing(self):
        r"""
        启用切片 VAE 解码。当启用此选项时,VAE 将会将输入张量分割为多个切片以
        分步计算解码。这对于节省内存和允许更大的批量大小非常有用。
        """
        # 设置标志以启用切片
        self.use_slicing = True
    
    # 禁用切片 VAE 解码的功能
    def disable_slicing(self):
        r"""
        禁用切片 VAE 解码。如果之前启用了 `enable_slicing`,则该方法将恢复为一步
        计算解码。
        """
        # 设置标志以禁用切片
        self.use_slicing = False
    
    @apply_forward_hook
    # 编码函数,将一批图像编码为潜在表示
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
        """
        将一批图像编码为潜在表示。
    
        参数:
            x (`torch.Tensor`): 输入图像批。
            return_dict (`bool`, *可选*, 默认为 `True`):
                是否返回 [`~models.autoencoder_kl.AutoencoderKLOutput`] 而不是普通元组。
    
        返回:
                编码图像的潜在表示。如果 `return_dict` 为 True,则返回
                [`~models.autoencoder_kl.AutoencoderKLOutput`],否则返回普通 `tuple`。
        """
        # 检查是否启用切片且输入批量大于 1
        if self.use_slicing and x.shape[0] > 1:
            # 对输入进行切片编码
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            # 将所有编码结果连接成一个张量
            h = torch.cat(encoded_slices)
        else:
            # 对整个输入进行编码
            h = self.encoder(x)
    
        # 创建潜在分布
        posterior = OobleckDiagonalGaussianDistribution(h)
    
        # 检查是否返回字典格式
        if not return_dict:
            return (posterior,)
    
        # 返回潜在表示的输出
        return AutoencoderOobleckOutput(latent_dist=posterior)
    
    # 解码函数,将潜在向量解码为图像
    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
        # 使用解码器解码潜在向量
        dec = self.decoder(z)
    
        # 检查是否返回字典格式
        if not return_dict:
            return (dec,)
    
        # 返回解码结果的输出
        return OobleckDecoderOutput(sample=dec)
    
    @apply_forward_hook
    # 解码函数,解码一批潜在向量为图像
    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
    ) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
        """
        解码一批图像。
    
        参数:
            z (`torch.Tensor`): 输入潜在向量批。
            return_dict (`bool`, *可选*, 默认为 `True`):
                是否返回 [`~models.vae.OobleckDecoderOutput`] 而不是普通元组。
    
        返回:
            [`~models.vae.OobleckDecoderOutput`] 或 `tuple`:
                如果 return_dict 为 True,则返回 [`~models.vae.OobleckDecoderOutput`],否则返回普通 `tuple`。
        """
        # 检查是否启用切片且输入批量大于 1
        if self.use_slicing and z.shape[0] > 1:
            # 对输入进行切片解码
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            # 将所有解码结果连接成一个张量
            decoded = torch.cat(decoded_slices)
        else:
            # 对整个输入进行解码
            decoded = self._decode(z).sample
    
        # 检查是否返回字典格式
        if not return_dict:
            return (decoded,)
    
        # 返回解码结果的输出
        return OobleckDecoderOutput(sample=decoded)
    # 定义一个前向传播的方法,接受样本输入和其他参数
        def forward(
            self,
            sample: torch.Tensor,
            sample_posterior: bool = False,  # 是否从后验分布中采样的标志,默认值为 False
            return_dict: bool = True,  # 是否返回 OobleckDecoderOutput 对象而非普通元组,默认值为 True
            generator: Optional[torch.Generator] = None,  # 可选的随机数生成器
        ) -> Union[OobleckDecoderOutput, torch.Tensor]:  # 返回类型为 OobleckDecoderOutput 或 torch.Tensor
            r"""
            Args:
                sample (`torch.Tensor`): 输入样本。
                sample_posterior (`bool`, *optional*, defaults to `False`):
                    是否从后验分布中采样。
                return_dict (`bool`, *optional*, defaults to `True`):
                    是否返回一个 [`OobleckDecoderOutput`] 而不是普通元组。
            """
            x = sample  # 将输入样本赋值给变量 x
            posterior = self.encode(x).latent_dist  # 对输入样本进行编码,获取潜在分布
            if sample_posterior:  # 如果选择从后验分布中采样
                z = posterior.sample(generator=generator)  # 从潜在分布中采样
            else:
                z = posterior.mode()  # 否则取潜在分布的众数
            dec = self.decode(z).sample  # 解码采样得到的潜在变量 z,并获取样本
    
            if not return_dict:  # 如果不需要返回字典
                return (dec,)  # 返回解码样本作为元组
    
            return OobleckDecoderOutput(sample=dec)  # 返回 OobleckDecoderOutput 对象

.\diffusers\models\autoencoders\autoencoder_tiny.py

# 版权所有 2024 Ollin Boer Bohan 和 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,按照许可证分发的软件是在“按原样”基础上分发的,
# 不提供任何种类的保证或条件,无论是明示还是暗示。
# 请参阅许可证以获取有关权限和限制的具体信息。


# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入可选类型、元组和联合类型
from typing import Optional, Tuple, Union

# 导入 PyTorch 库
import torch

# 从配置相关的模块导入 ConfigMixin 和 register_to_config
from ...configuration_utils import ConfigMixin, register_to_config
# 从工具模块导入 BaseOutput
from ...utils import BaseOutput
# 从加速工具模块导入 apply_forward_hook 函数
from ...utils.accelerate_utils import apply_forward_hook
# 从模型工具模块导入 ModelMixin
from ..modeling_utils import ModelMixin
# 从 VAE 模块导入 DecoderOutput、DecoderTiny 和 EncoderTiny
from .vae import DecoderOutput, DecoderTiny, EncoderTiny


@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    AutoencoderTiny 编码方法的输出。

    参数:
        latents (`torch.Tensor`): `Encoder` 的编码输出。
    """

    # 定义编码输出的张量属性
    latents: torch.Tensor


class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    一个小型的蒸馏变分自编码器(VAE)模型,用于将图像编码为潜在表示并将潜在表示解码为图像。

    [`AutoencoderTiny`] 是对 `TAESD` 原始实现的封装。

    此模型继承自 [`ModelMixin`]。有关其为所有模型实现的通用方法的文档,请查看超类文档
    (例如下载或保存)。

    """

    # 指示该模型支持梯度检查点
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        # 定义输入通道的默认值,默认为3(RGB图像)
        self,
        in_channels: int = 3,
        # 定义输出通道的默认值,默认为3(RGB图像)
        out_channels: int = 3,
        # 定义编码器块输出通道的元组,默认为四个64
        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        # 定义解码器块输出通道的元组,默认为四个64
        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        # 定义激活函数的默认值,默认为 ReLU
        act_fn: str = "relu",
        # 定义上采样函数的默认值,默认为最近邻插值
        upsample_fn: str = "nearest",
        # 定义潜在通道的默认值
        latent_channels: int = 4,
        # 定义上采样缩放因子的默认值
        upsampling_scaling_factor: int = 2,
        # 定义编码器块数量的元组,默认为 (1, 3, 3, 3)
        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
        # 定义解码器块数量的元组,默认为 (3, 3, 3, 1)
        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
        # 定义潜在幅度的默认值
        latent_magnitude: int = 3,
        # 定义潜在偏移量的默认值
        latent_shift: float = 0.5,
        # 定义是否强制上溯的布尔值,默认为 False
        force_upcast: bool = False,
        # 定义缩放因子的默认值
        scaling_factor: float = 1.0,
        # 定义偏移因子的默认值
        shift_factor: float = 0.0,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 检查编码器块输出通道数量与编码器块数量是否一致
        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            # 抛出异常,提示编码器块输出通道数量与编码器块数量不匹配
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        # 检查解码器块输出通道数量与解码器块数量是否一致
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            # 抛出异常,提示解码器块输出通道数量与解码器块数量不匹配
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        # 创建编码器实例
        self.encoder = EncoderTiny(
            # 输入通道数
            in_channels=in_channels,
            # 潜在通道数
            out_channels=latent_channels,
            # 编码块数量
            num_blocks=num_encoder_blocks,
            # 编码块输出通道
            block_out_channels=encoder_block_out_channels,
            # 激活函数
            act_fn=act_fn,
        )

        # 创建解码器实例
        self.decoder = DecoderTiny(
            # 输入潜在通道数
            in_channels=latent_channels,
            # 输出通道数
            out_channels=out_channels,
            # 解码块数量
            num_blocks=num_decoder_blocks,
            # 解码块输出通道
            block_out_channels=decoder_block_out_channels,
            # 上采样缩放因子
            upsampling_scaling_factor=upsampling_scaling_factor,
            # 激活函数
            act_fn=act_fn,
            # 上采样函数
            upsample_fn=upsample_fn,
        )

        # 潜在幅度
        self.latent_magnitude = latent_magnitude
        # 潜在偏移量
        self.latent_shift = latent_shift
        # 缩放因子
        self.scaling_factor = scaling_factor

        # 切片使用标志
        self.use_slicing = False
        # 瓦片使用标志
        self.use_tiling = False

        # 仅在启用 VAE 瓦片时相关
        # 空间缩放因子
        self.spatial_scale_factor = 2**out_channels
        # 瓦片重叠因子
        self.tile_overlap_factor = 0.125
        # 瓦片样本最小大小
        self.tile_sample_min_size = 512
        # 瓦片潜在最小大小
        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

        # 注册解码器块输出通道到配置
        self.register_to_config(block_out_channels=decoder_block_out_channels)
        # 注册强制上升到配置
        self.register_to_config(force_upcast=False)

    # 设置梯度检查点的方法
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # 如果模块是 EncoderTiny 或 DecoderTiny 类型
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            # 设置梯度检查点标志
            module.gradient_checkpointing = value

    # 缩放潜在变量的方法
    def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
        """raw latents -> [0, 1]"""
        # 将潜在变量缩放到 [0, 1] 范围
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    # 反缩放潜在变量的方法
    def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
        """[0, 1] -> raw latents"""
        # 将 [0, 1] 范围的潜在变量反缩放回原始值
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    # 启用切片解码的方法
    def enable_slicing(self) -> None:
        r"""
        启用切片 VAE 解码。当启用此选项时,VAE 将输入张量切片以
        分几步计算解码。这有助于节省一些内存并允许更大的批量大小。
        """
        # 设置切片使用标志为真
        self.use_slicing = True

    # 禁用切片解码的方法
    def disable_slicing(self) -> None:
        r"""
        禁用切片 VAE 解码。如果之前启用了 `enable_slicing`,此方法将回到
        一步计算解码。
        """
        # 设置切片使用标志为假
        self.use_slicing = False
    # 启用分块 VAE 解码的函数,默认为启用
    def enable_tiling(self, use_tiling: bool = True) -> None:
        r"""
        启用分块 VAE 解码。启用后,VAE 将把输入张量拆分成多个块,以
        分步计算解码和编码。这有助于节省大量内存,并允许处理更大的图像。
        """
        # 将实例变量设置为传入的布尔值以启用或禁用分块
        self.use_tiling = use_tiling

    # 禁用分块 VAE 解码的函数
    def disable_tiling(self) -> None:
        r"""
        禁用分块 VAE 解码。如果之前启用了 `enable_tiling`,则该方法将
        返回到一次计算解码的方式。
        """
        # 调用 enable_tiling 方法以禁用分块
        self.enable_tiling(False)

    # 使用分块编码器对图像批次进行编码的私有方法
    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""使用分块编码器编码图像批次。

        启用该选项后,VAE 将把输入张量拆分成多个块,以分步计算编码。
        这有助于保持内存使用量在固定范围内,不受图像大小影响。为了避免
        块之间的伪影,块之间会重叠并进行混合,以形成平滑的输出。

        Args:
            x (`torch.Tensor`): 输入的图像批次。

        Returns:
            `torch.Tensor`: 编码后的图像批次。
        """
        # 编码器输出相对于输入的缩放比例
        sf = self.spatial_scale_factor
        # 分块的最小样本尺寸
        tile_size = self.tile_sample_min_size

        # 计算混合和遍历之间的像素数量
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # 计算分块的索引(上/左)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # 创建用于混合的掩码
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
        )
        # 将掩码限制在 0 到 1 之间并转移到相应的设备
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # 初始化输出数组
        out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
        # 遍历分块索引
        for i in ti:
            for j in tj:
                # 获取当前分块的输入张量
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # 获取当前分块的输出位置
                tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
                # 对输入分块进行编码
                tile = self.encoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # 将当前块的结果与输出进行混合
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                # 计算总混合掩码
                blend_mask = blend_mask_i * blend_mask_j
                # 将块和混合掩码裁剪到一致的形状
                tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
                # 更新输出数组中的当前块
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        # 返回最终的编码输出
        return out
    # 定义一个私有方法,用于对一批图像进行分块解码
    def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
        r"""使用分块编码器对图像批进行编码。

        启用此选项时,VAE 将把输入张量分割成多个块,以分步计算编码。这有助于保持内存使用在
        常数范围内,无论图像大小如何。为了避免分块伪影,块之间重叠并融合在一起形成平滑输出。

        参数:
            x (`torch.Tensor`): 输入图像批。

        返回:
            `torch.Tensor`: 编码后的图像批。
        """
        # 解码器输出相对于输入的缩放因子
        sf = self.spatial_scale_factor
        # 定义每个块的最小大小
        tile_size = self.tile_latent_min_size

        # 计算用于混合和在块之间遍历的像素数量
        blend_size = int(tile_size * self.tile_overlap_factor)
        # 计算块之间的遍历大小
        traverse_size = tile_size - blend_size

        # 创建块的索引(上/左)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # 创建混合掩码
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
        )
        # 将混合掩码限制在0到1之间,并移动到输入张量所在的设备
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # 创建输出数组,初始化为零
        out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
        # 遍历每个块的索引
        for i in ti:
            for j in tj:
                # 获取当前块的输入数据
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # 获取当前块的输出位置
                tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
                # 使用解码器对当前块进行解码
                tile = self.decoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # 将当前块的结果混合到输出中
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                # 计算最终的混合掩码
                blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
                # 将解码后的块与输出进行混合
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        # 返回最终的输出结果
        return out

    # 使用装饰器应用前向钩子,定义编码方法
    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
        # 如果使用切片且输入批量大于1
        if self.use_slicing and x.shape[0] > 1:
            # 对每个切片进行编码,使用分块编码或普通编码
            output = [
                self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
            ]
            # 将输出合并成一个张量
            output = torch.cat(output)
        else:
            # 对整个输入进行编码,使用分块编码或普通编码
            output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)

        # 如果不返回字典格式,返回输出张量的元组
        if not return_dict:
            return (output,)

        # 返回包含编码结果的自定义输出对象
        return AutoencoderTinyOutput(latents=output)

    # 使用装饰器应用前向钩子,定义解码方法
    @apply_forward_hook
    def decode(
        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
    # 函数返回类型为 DecoderOutput 或者元组,处理解码后的输出
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        # 如果使用切片并且输入的第一维大于1
        if self.use_slicing and x.shape[0] > 1:
            # 根据是否使用平铺解码,将切片解码结果存入列表中
            output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
            # 将列表中的张量沿着第0维连接成一个张量
            output = torch.cat(output)
        else:
            # 直接对输入进行解码,根据是否使用平铺解码
            output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
    
        # 如果不需要返回字典格式
        if not return_dict:
            # 返回解码结果作为元组
            return (output,)
    
        # 返回解码结果作为 DecoderOutput 对象
        return DecoderOutput(sample=output)
    
    # forward 方法定义,处理输入样本并返回解码输出
    def forward(
        self,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Args:
            sample (`torch.Tensor`): 输入样本。
            return_dict (`bool`, *optional*, defaults to `True`):
                是否返回一个 [`DecoderOutput`] 对象而不是普通元组。
        """
        # 对输入样本进行编码,提取潜在表示
        enc = self.encode(sample).latents
    
        # 将潜在表示缩放到 [0, 1] 范围,并量化为字节张量,
        # 类似于将潜在表示存储为 RGBA uint8 图像
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
    
        # 将量化后的潜在表示反量化回 [0, 1] 范围,并恢复到原始范围,
        # 类似于从 RGBA uint8 图像中加载潜在表示
        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
    
        # 对反量化后的潜在表示进行解码
        dec = self.decode(unscaled_enc)
    
        # 如果不需要返回字典格式
        if not return_dict:
            # 返回解码结果作为元组
            return (dec,)
        # 返回解码结果作为 DecoderOutput 对象
        return DecoderOutput(sample=dec)

.\diffusers\models\autoencoders\consistency_decoder_vae.py

# 版权所有 2024 The HuggingFace Team. 保留所有权利。
#
# 根据 Apache License 2.0 版(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,否则根据许可证分发的软件是按“原样”基础分发的,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 有关许可证下权限和限制的具体条款,请参见许可证。
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Dict, Optional, Tuple, Union  # 从 typing 模块导入类型提示

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 的函数式神经网络模块
from torch import nn  # 从 PyTorch 导入 nn 模块

from ...configuration_utils import ConfigMixin, register_to_config  # 从配置工具模块导入混合类和注册函数
from ...schedulers import ConsistencyDecoderScheduler  # 从调度器模块导入一致性解码器调度器
from ...utils import BaseOutput  # 从工具模块导入基础输出类
from ...utils.accelerate_utils import apply_forward_hook  # 从加速工具模块导入前向钩子应用函数
from ...utils.torch_utils import randn_tensor  # 从 PyTorch 工具模块导入随机张量函数
from ..attention_processor import (  # 从注意力处理器模块导入所需类
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入添加键值注意力处理器
    CROSS_ATTENTION_PROCESSORS,  # 导入交叉注意力处理器
    AttentionProcessor,  # 导入注意力处理器基类
    AttnAddedKVProcessor,  # 导入添加键值注意力处理器类
    AttnProcessor,  # 导入注意力处理器类
)
from ..modeling_utils import ModelMixin  # 从建模工具模块导入模型混合类
from ..unets.unet_2d import UNet2DModel  # 从 2D U-Net 模块导入 U-Net 模型类
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder  # 从 VAE 模块导入解码器输出、对角高斯分布和编码器

@dataclass  # 将该类标记为数据类
class ConsistencyDecoderVAEOutput(BaseOutput):  # 定义一致性解码器 VAE 输出类,继承自基础输出类
    """
    编码方法的输出。

    参数:
        latent_dist (`DiagonalGaussianDistribution`):
            表示为均值和对数方差的编码器输出。
            `DiagonalGaussianDistribution` 允许从分布中采样潜变量。
    """

    latent_dist: "DiagonalGaussianDistribution"  # 定义潜在分布属性,类型为对角高斯分布


class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):  # 定义一致性解码器 VAE 类,继承自模型混合类和配置混合类
    r"""
    与 DALL-E 3 一起使用的一致性解码器。

    示例:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE

        >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)  # 从预训练模型加载一致性解码器 VAE
        >>> pipe = StableDiffusionPipeline.from_pretrained(  # 从预训练模型加载稳定扩散管道
        ...     "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
        ... ).to("cuda")  # 将管道移动到 CUDA 设备

        >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]  # 生成图像
        >>> image  # 输出生成的图像
        ```py
    """

    @register_to_config  # 将该方法注册到配置中
    # 初始化方法,用于创建类的实例
    def __init__(
            # 缩放因子,默认为 0.18215
            scaling_factor: float = 0.18215,
            # 潜在通道数,默认为 4
            latent_channels: int = 4,
            # 样本尺寸,默认为 32
            sample_size: int = 32,
            # 编码器激活函数,默认为 "silu"
            encoder_act_fn: str = "silu",
            # 编码器输出通道数的元组,默认为 (128, 256, 512, 512)
            encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
            # 编码器是否使用双重 Z,默认为 True
            encoder_double_z: bool = True,
            # 编码器下采样块类型的元组,默认为多个 "DownEncoderBlock2D"
            encoder_down_block_types: Tuple[str, ...] = (
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ),
            # 编码器输入通道数,默认为 3
            encoder_in_channels: int = 3,
            # 每个编码器块的层数,默认为 2
            encoder_layers_per_block: int = 2,
            # 编码器归一化组数,默认为 32
            encoder_norm_num_groups: int = 32,
            # 编码器输出通道数,默认为 4
            encoder_out_channels: int = 4,
            # 解码器是否添加注意力机制,默认为 False
            decoder_add_attention: bool = False,
            # 解码器输出通道数的元组,默认为 (320, 640, 1024, 1024)
            decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
            # 解码器下采样块类型的元组,默认为多个 "ResnetDownsampleBlock2D"
            decoder_down_block_types: Tuple[str, ...] = (
                "ResnetDownsampleBlock2D",
                "ResnetDownsampleBlock2D",
                "ResnetDownsampleBlock2D",
                "ResnetDownsampleBlock2D",
            ),
            # 解码器下采样填充,默认为 1
            decoder_downsample_padding: int = 1,
            # 解码器输入通道数,默认为 7
            decoder_in_channels: int = 7,
            # 每个解码器块的层数,默认为 3
            decoder_layers_per_block: int = 3,
            # 解码器归一化的 epsilon 值,默认为 1e-05
            decoder_norm_eps: float = 1e-05,
            # 解码器归一化组数,默认为 32
            decoder_norm_num_groups: int = 32,
            # 解码器训练时长的时间步数,默认为 1024
            decoder_num_train_timesteps: int = 1024,
            # 解码器输出通道数,默认为 6
            decoder_out_channels: int = 6,
            # 解码器时间缩放偏移类型,默认为 "scale_shift"
            decoder_resnet_time_scale_shift: str = "scale_shift",
            # 解码器时间嵌入类型,默认为 "learned"
            decoder_time_embedding_type: str = "learned",
            # 解码器上采样块类型的元组,默认为多个 "ResnetUpsampleBlock2D"
            decoder_up_block_types: Tuple[str, ...] = (
                "ResnetUpsampleBlock2D",
                "ResnetUpsampleBlock2D",
                "ResnetUpsampleBlock2D",
                "ResnetUpsampleBlock2D",
            ),
    ):
        # 调用父类的构造函数
        super().__init__()
        # 初始化编码器,传入各类参数以配置其行为
        self.encoder = Encoder(
            act_fn=encoder_act_fn,  # 激活函数
            block_out_channels=encoder_block_out_channels,  # 编码器每个块的输出通道数
            double_z=encoder_double_z,  # 是否使用双Z向量
            down_block_types=encoder_down_block_types,  # 编码器下采样块的类型
            in_channels=encoder_in_channels,  # 输入通道数
            layers_per_block=encoder_layers_per_block,  # 每个块的层数
            norm_num_groups=encoder_norm_num_groups,  # 归一化的组数
            out_channels=encoder_out_channels,  # 输出通道数
        )

        # 初始化解码器UNet模型,配置其参数
        self.decoder_unet = UNet2DModel(
            add_attention=decoder_add_attention,  # 是否添加注意力机制
            block_out_channels=decoder_block_out_channels,  # 解码器每个块的输出通道数
            down_block_types=decoder_down_block_types,  # 解码器下采样块的类型
            downsample_padding=decoder_downsample_padding,  # 下采样的填充方式
            in_channels=decoder_in_channels,  # 输入通道数
            layers_per_block=decoder_layers_per_block,  # 每个块的层数
            norm_eps=decoder_norm_eps,  # 归一化中的epsilon值
            norm_num_groups=decoder_norm_num_groups,  # 归一化的组数
            num_train_timesteps=decoder_num_train_timesteps,  # 训练时的时间步数
            out_channels=decoder_out_channels,  # 输出通道数
            resnet_time_scale_shift=decoder_resnet_time_scale_shift,  # ResNet时间尺度偏移
            time_embedding_type=decoder_time_embedding_type,  # 时间嵌入类型
            up_block_types=decoder_up_block_types,  # 解码器上采样块的类型
        )
        # 初始化一致性解码器调度器
        self.decoder_scheduler = ConsistencyDecoderScheduler()
        # 注册编码器的输出通道数到配置中
        self.register_to_config(block_out_channels=encoder_block_out_channels)
        # 注册强制上采样的配置
        self.register_to_config(force_upcast=False)
        # 注册均值的缓冲区,形状为(1, C, 1, 1)
        self.register_buffer(
            "means",
            torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],  # 均值张量
            persistent=False,  # 不持久化保存
        )
        # 注册标准差的缓冲区,形状为(1, C, 1, 1)
        self.register_buffer(
            "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False  # 标准差张量
        )

        # 初始化量化卷积层,输入和输出通道数相同
        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)

        # 设置切片和拼接的使用标志为假
        self.use_slicing = False
        self.use_tiling = False

        # 仅在启用 VAE 切片时相关
        self.tile_sample_min_size = self.config.sample_size  # 最小样本大小
        # 判断样本大小类型,并设置样本大小
        sample_size = (
            self.config.sample_size[0]  # 如果样本大小是列表或元组,取第一个元素
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size  # 否则直接使用样本大小
        )
        # 计算最小的切片潜在大小
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        # 设置切片重叠因子
        self.tile_overlap_factor = 0.25

    # 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling 复制的方法
    def enable_tiling(self, use_tiling: bool = True):
        r"""
        启用切片 VAE 解码。启用此选项时,VAE 将输入张量拆分为切片,以
        分步骤计算解码和编码。这有助于节省大量内存,并允许处理更大图像。
        """
        # 设置是否使用切片标志
        self.use_tiling = use_tiling

    # 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling 复制的方法
    # 定义一个方法来禁用平铺的 VAE 解码
    def disable_tiling(self):
        r"""
        禁用平铺的 VAE 解码。如果之前启用了 `enable_tiling`,该方法将恢复为一步计算解码。
        """
        # 调用 enable_tiling 方法并传入 False 参数以禁用平铺
        self.enable_tiling(False)

    # 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing 复制而来
    # 定义一个方法来启用切片的 VAE 解码
    def enable_slicing(self):
        r"""
        启用切片的 VAE 解码。当启用此选项时,VAE 将把输入张量分成多个切片进行解码计算。
        这对于节省一些内存和允许更大的批处理大小非常有用。
        """
        # 设置 use_slicing 为 True,以启用切片
        self.use_slicing = True

    # 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing 复制而来
    # 定义一个方法来禁用切片的 VAE 解码
    def disable_slicing(self):
        r"""
        禁用切片的 VAE 解码。如果之前启用了 `enable_slicing`,该方法将恢复为一步计算解码。
        """
        # 设置 use_slicing 为 False,以禁用切片
        self.use_slicing = False

    @property
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制而来
    # 定义一个属性方法,用于返回注意力处理器
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回:
            `dict` 类型的注意力处理器:一个包含模型中所有注意力处理器的字典,按权重名称索引。
        """
        # 创建一个空字典用于存储处理器
        processors = {}

        # 定义一个递归函数以添加处理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否有 get_processor 方法
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,键为模块名称加上 ".processor"
                processors[f"{name}.processor"] = module.get_processor()

            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用该函数以添加子模块的处理器
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回更新后的处理器字典
            return processors

        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数以添加所有处理器
            fn_recursive_add_processors(name, module, processors)

        # 返回最终的处理器字典
        return processors

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制而来
    # 设置用于计算注意力的处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。

        参数:
            processor(`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或一个处理器类的字典,将作为所有 `Attention` 层的处理器设置。

                如果 `processor` 是一个字典,键需要定义对应的交叉注意力处理器的路径。强烈建议在设置可训练的注意力处理器时使用此方式。

        """
        # 计算当前注意力处理器的数量
        count = len(self.attn_processors.keys())

        # 检查传入的处理器字典长度是否与当前注意力层数量匹配
        if isinstance(processor, dict) and len(processor) != count:
            # 抛出值错误,提示处理器数量不匹配
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # 定义递归函数,用于设置每个子模块的处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查子模块是否有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,则直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中弹出对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍历子模块,递归调用自身
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍历当前模块的所有子模块
        for name, module in self.named_children():
            # 调用递归函数设置处理器
            fn_recursive_attn_processor(name, module, processor)

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 复制而来
    def set_default_attn_processor(self):
        """
        禁用自定义注意力处理器并设置默认的注意力实现。
        """
        # 检查所有处理器是否属于新增的 KV 注意力处理器
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 设置处理器为新增的 KV 注意力处理器
            processor = AttnAddedKVProcessor()
        # 检查所有处理器是否属于交叉注意力处理器
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 设置处理器为标准的注意力处理器
            processor = AttnProcessor()
        else:
            # 抛出值错误,提示无法设置默认处理器
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 调用设置处理器的方法
        self.set_attn_processor(processor)

    # 应用前向钩子修饰器
    @apply_forward_hook
    def encode(
        # 输入张量 x,返回字典的标志
        self, x: torch.Tensor, return_dict: bool = True
    # 定义一个方法,返回编码后的图像的潜在表示
    ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
            """
            将一批图像编码为潜在表示。
    
            参数:
                x (`torch.Tensor`): 输入图像批次。
                return_dict (`bool`, *可选*, 默认为 `True`):
                    是否返回 [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] 
                    而不是普通的元组。
    
            返回:
                    编码图像的潜在表示。如果 `return_dict` 为 True,则返回 
                    [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`],否则返回 
                    普通的 `tuple`。
            """
            # 检查是否使用分块编码,并且图像尺寸超过最小样本大小
            if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
                # 调用分块编码方法处理输入
                return self.tiled_encode(x, return_dict=return_dict)
    
            # 检查是否使用切片,并且输入批次大于1
            if self.use_slicing and x.shape[0] > 1:
                # 对输入的每个切片进行编码,并将结果收集到列表中
                encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
                # 将所有编码切片连接成一个张量
                h = torch.cat(encoded_slices)
            else:
                # 对整个输入进行编码
                h = self.encoder(x)
    
            # 通过量化卷积获取潜在表示的统计量
            moments = self.quant_conv(h)
            # 创建一个对角高斯分布作为后验分布
            posterior = DiagonalGaussianDistribution(moments)
    
            # 如果不需要返回字典格式
            if not return_dict:
                # 返回包含后验分布的元组
                return (posterior,)
    
            # 返回包含潜在分布的输出对象
            return ConsistencyDecoderVAEOutput(latent_dist=posterior)
    
        # 应用前向钩子装饰器
        @apply_forward_hook
        def decode(
            # 定义解码方法,输入潜在变量
            z: torch.Tensor,
            # 可选的随机数生成器
            generator: Optional[torch.Generator] = None,
            # 是否返回字典格式,默认为 True
            return_dict: bool = True,
            # 推理步骤的数量,默认为 2
            num_inference_steps: int = 2,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        """
        解码输入的潜在向量 `z`,使用一致性解码器 VAE 模型。

        Args:
            z (torch.Tensor): 输入的潜在向量。
            generator (Optional[torch.Generator]): 随机数生成器,默认为 None。
            return_dict (bool): 是否以字典形式返回输出,默认为 True。
            num_inference_steps (int): 推理步骤的数量,默认为 2。

        Returns:
            Union[DecoderOutput, Tuple[torch.Tensor]]: 解码后的输出。

        """
        # 对潜在向量 `z` 进行标准化处理
        z = (z * self.config.scaling_factor - self.means) / self.stds

        # 计算缩放因子,基于输出通道的数量
        scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
        # 将潜在向量 `z` 进行最近邻插值缩放
        z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)

        # 获取当前张量的批量大小、高度和宽度
        batch_size, _, height, width = z.shape

        # 设置解码器调度器的时间步长
        self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)

        # 初始化噪声张量 `x_t`,用于后续的解码过程
        x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
            (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
        )

        # 对每个时间步进行解码
        for t in self.decoder_scheduler.timesteps:
            # 将当前噪声与潜在向量 `z` 组合成模型输入
            model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
            # 获取模型输出的样本
            model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
            # 更新当前样本 `x_t`
            prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
            x_t = prev_sample

        # 将最后的样本赋值给 `x_0`
        x_0 = x_t

        # 如果不返回字典,直接返回样本
        if not return_dict:
            return (x_0,)

        # 返回解码后的输出,封装成 DecoderOutput 对象
        return DecoderOutput(sample=x_0)

    # 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v 复制的函数
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # 计算混合范围,确保不超过输入张量的维度
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        # 对于每个混合范围内的高度像素进行加权混合
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        # 返回混合后的张量
        return b

    # 从 diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h 复制的函数
    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # 计算混合范围,确保不超过输入张量的维度
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        # 对于每个混合范围内的宽度像素进行加权混合
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        # 返回混合后的张量
        return b
    # 定义一个使用平铺编码器对图像批次进行编码的方法
    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
        r"""使用平铺编码器编码一批图像。

        当启用此选项时,VAE将输入张量分割成平铺,以进行多个步骤的编码。这有助于保持内存使用量在任何图像大小下都是恒定的。平铺编码的最终结果与非平铺编码不同,因为每个平铺使用不同的编码器。为了避免平铺伪影,平铺之间会重叠并进行混合,以形成平滑的输出。您仍然可能会看到平铺大小的变化,但应该不那么明显。

        参数:
            x (`torch.Tensor`): 输入图像批次。
            return_dict (`bool`, *可选*, 默认为 `True`):
                是否返回 [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] 
                而不是一个普通的元组。

        返回:
            [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] 或 `tuple`:
                如果 return_dict 为 True,则返回一个 [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
                ,否则返回一个普通的 `tuple`。
        """
        # 计算重叠区域的大小
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        # 计算混合范围的大小
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        # 计算行限制的大小
        row_limit = self.tile_latent_min_size - blend_extent

        # 将图像分割成512x512的平铺并分别编码
        rows = []  # 存储每一行的编码结果
        for i in range(0, x.shape[2], overlap_size):  # 遍历图像的高度
            row = []  # 存储当前行的编码结果
            for j in range(0, x.shape[3], overlap_size):  # 遍历图像的宽度
                # 提取当前平铺
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                # 使用编码器对平铺进行编码
                tile = self.encoder(tile)
                # 进行量化处理
                tile = self.quant_conv(tile)
                # 将编码后的平铺添加到当前行
                row.append(tile)
            # 将当前行添加到所有行的列表中
            rows.append(row)
        
        result_rows = []  # 存储最终结果的行
        for i, row in enumerate(rows):  # 遍历每一行的编码结果
            result_row = []  # 存储当前行的最终结果
            for j, tile in enumerate(row):  # 遍历当前行的每个平铺
                # 将上方的平铺与当前平铺进行混合
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                # 将左侧的平铺与当前平铺进行混合
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                # 将处理后的平铺裁剪并添加到结果行
                result_row.append(tile[:, :, :row_limit, :row_limit])
            # 将当前行的结果合并到结果行列表中
            result_rows.append(torch.cat(result_row, dim=3))

        # 将所有结果行在高度维度上进行合并
        moments = torch.cat(result_rows, dim=2)
        # 创建对角高斯分布对象
        posterior = DiagonalGaussianDistribution(moments)

        # 如果不返回字典,则返回一个包含后验分布的元组
        if not return_dict:
            return (posterior,)

        # 返回包含后验分布的ConsistencyDecoderVAEOutput对象
        return ConsistencyDecoderVAEOutput(latent_dist=posterior)
    # 定义前向传播方法,处理输入样本并返回解码结果
    def forward(
            self,
            sample: torch.Tensor,  # 输入样本,类型为 torch.Tensor
            sample_posterior: bool = False,  # 是否从后验分布采样,默认为 False
            return_dict: bool = True,  # 是否返回 DecoderOutput 对象,默认为 True
            generator: Optional[torch.Generator] = None,  # 可选的随机数生成器,用于采样
        ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:  # 返回类型可以是 DecoderOutput 或元组
            r"""
            Args:
                sample (`torch.Tensor`): Input sample.  # 输入样本
                sample_posterior (`bool`, *optional*, defaults to `False`):  # 是否采样后验的标志
                    Whether to sample from the posterior.
                return_dict (`bool`, *optional*, defaults to `True`):  # 返回类型的标志
                    Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
                generator (`torch.Generator`, *optional*, defaults to `None`):  # 随机数生成器的说明
                    Generator to use for sampling.
    
            Returns:
                [`DecoderOutput`] or `tuple`:  # 返回类型的说明
                    If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
            """
            x = sample  # 将输入样本赋值给变量 x
            posterior = self.encode(x).latent_dist  # 使用编码器对样本进行编码,获取后验分布
            if sample_posterior:  # 检查是否需要从后验分布采样
                z = posterior.sample(generator=generator)  # 从后验分布中采样,使用指定的生成器
            else:  # 如果不从后验分布采样
                z = posterior.mode()  # 选择后验分布的众数作为 z
            dec = self.decode(z, generator=generator).sample  # 解码 z,并获取解码后的样本
    
            if not return_dict:  # 如果不需要返回字典
                return (dec,)  # 返回解码样本的元组
    
            return DecoderOutput(sample=dec)  # 返回 DecoderOutput 对象,包含解码样本

.\diffusers\models\autoencoders\vae.py

# 版权声明,2024年HuggingFace团队保留所有权利
# 
# 根据Apache许可证第2.0版(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是按“原样”基础提供的,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 请参阅许可证以获取有关特定语言的权限和限制的更多信息。
from dataclasses import dataclass  # 导入dataclass装饰器用于简化类的定义
from typing import Optional, Tuple  # 导入可选类型和元组类型,用于类型注释

import numpy as np  # 导入numpy库,用于数值计算
import torch  # 导入PyTorch库,用于构建深度学习模型
import torch.nn as nn  # 导入PyTorch的神经网络模块

from ...utils import BaseOutput, is_torch_version  # 从utils模块导入BaseOutput类和版本检查函数
from ...utils.torch_utils import randn_tensor  # 从torch_utils模块导入随机张量生成函数
from ..activations import get_activation  # 从activations模块导入获取激活函数的函数
from ..attention_processor import SpatialNorm  # 从attention_processor模块导入空间归一化类
from ..unets.unet_2d_blocks import (  # 从unet_2d_blocks模块导入多个网络块
    AutoencoderTinyBlock,  # 导入自动编码器小块
    UNetMidBlock2D,  # 导入UNet中间块
    get_down_block,  # 导入获取下采样块的函数
    get_up_block,  # 导入获取上采样块的函数
)

@dataclass  # 使用dataclass装饰器,简化类的初始化和表示
class DecoderOutput(BaseOutput):  # 定义DecoderOutput类,继承自BaseOutput
    r"""  # 文档字符串,描述解码方法的输出
    Output of decoding method.  # 解码方法的输出

    Args:  # 参数说明
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):  # 输入样本的描述
            The decoded output sample from the last layer of the model.  # 模型最后一层的解码输出样本
    """

    sample: torch.Tensor  # 定义样本属性,类型为torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None  # 定义可选的损失属性,默认为None

class Encoder(nn.Module):  # 定义Encoder类,继承自nn.Module
    r"""  # 文档字符串,描述变分自动编码器的Encoder层
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.  # 变分自动编码器的Encoder层

    Args:  # 参数说明
        in_channels (`int`, *optional*, defaults to 3):  # 输入通道的描述
            The number of input channels.  # 输入通道的数量
        out_channels (`int`, *optional*, defaults to 3):  # 输出通道的描述
            The number of output channels.  # 输出通道的数量
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):  # 下采样块类型的描述
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available  # 使用的下采样块类型,具体可查看相关文档
            options.  # 可选项
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):  # 块输出通道的描述
            The number of output channels for each block.  # 每个块的输出通道数量
        layers_per_block (`int`, *optional*, defaults to 2):  # 每个块层数的描述
            The number of layers per block.  # 每个块的层数
        norm_num_groups (`int`, *optional*, defaults to 32):  # 归一化组数的描述
            The number of groups for normalization.  # 归一化的组数
        act_fn (`str`, *optional*, defaults to `"silu"`):  # 激活函数类型的描述
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.  # 使用的激活函数,具体可查看相关文档
        double_z (`bool`, *optional*, defaults to `True`):  # 最后一个块输出通道双倍化的描述
            Whether to double the number of output channels for the last block.  # 是否将最后一个块的输出通道数量翻倍
    """
    # 初始化方法,设置网络的基本参数
        def __init__(
            self,
            in_channels: int = 3,  # 输入通道数,默认为3(RGB图像)
            out_channels: int = 3,  # 输出通道数,默认为3(RGB图像)
            down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),  # 下采样模块类型
            block_out_channels: Tuple[int, ...] = (64,),  # 每个块的输出通道数
            layers_per_block: int = 2,  # 每个下采样块的层数
            norm_num_groups: int = 32,  # 归一化时的组数
            act_fn: str = "silu",  # 激活函数类型,默认为SiLU
            double_z: bool = True,  # 是否双输出通道
            mid_block_add_attention=True,  # 中间块是否添加注意力机制
        ):
            super().__init__()  # 调用父类构造函数
            self.layers_per_block = layers_per_block  # 保存每块的层数
    
            # 定义输入卷积层
            self.conv_in = nn.Conv2d(
                in_channels,  # 输入通道数
                block_out_channels[0],  # 输出通道数
                kernel_size=3,  # 卷积核大小
                stride=1,  # 步幅
                padding=1,  # 填充
            )
    
            self.down_blocks = nn.ModuleList([])  # 初始化下采样块的模块列表
    
            # down
            output_channel = block_out_channels[0]  # 初始化输出通道数
            for i, down_block_type in enumerate(down_block_types):  # 遍历下采样块类型
                input_channel = output_channel  # 当前块的输入通道数
                output_channel = block_out_channels[i]  # 当前块的输出通道数
                is_final_block = i == len(block_out_channels) - 1  # 判断是否为最后一个块
    
                # 获取下采样块并初始化
                down_block = get_down_block(
                    down_block_type,  # 下采样块类型
                    num_layers=self.layers_per_block,  # 下采样块的层数
                    in_channels=input_channel,  # 输入通道数
                    out_channels=output_channel,  # 输出通道数
                    add_downsample=not is_final_block,  # 是否添加下采样
                    resnet_eps=1e-6,  # ResNet的epsilon值
                    downsample_padding=0,  # 下采样的填充
                    resnet_act_fn=act_fn,  # ResNet的激活函数
                    resnet_groups=norm_num_groups,  # ResNet的组数
                    attention_head_dim=output_channel,  # 注意力头的维度
                    temb_channels=None,  # 时间嵌入通道数
                )
                self.down_blocks.append(down_block)  # 将下采样块添加到模块列表中
    
            # mid
            # 定义中间块
            self.mid_block = UNetMidBlock2D(
                in_channels=block_out_channels[-1],  # 中间块的输入通道数
                resnet_eps=1e-6,  # ResNet的epsilon值
                resnet_act_fn=act_fn,  # ResNet的激活函数
                output_scale_factor=1,  # 输出缩放因子
                resnet_time_scale_shift="default",  # ResNet时间缩放偏移
                attention_head_dim=block_out_channels[-1],  # 注意力头的维度
                resnet_groups=norm_num_groups,  # ResNet的组数
                temb_channels=None,  # 时间嵌入通道数
                add_attention=mid_block_add_attention,  # 是否添加注意力机制
            )
    
            # out
            # 定义输出的归一化层
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
            self.conv_act = nn.SiLU()  # 激活函数层
    
            # 根据双输出设置卷积层的输出通道数
            conv_out_channels = 2 * out_channels if double_z else out_channels  
            # 定义输出卷积层
            self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)  
    
            self.gradient_checkpointing = False  # 设置梯度检查点为False
    # 定义 Encoder 类的前向传播方法,接收一个张量作为输入并返回一个张量
    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        r"""Encoder 类的前向方法。"""

        # 使用输入的样本通过初始卷积层进行处理
        sample = self.conv_in(sample)

        # 如果处于训练模式并且开启了梯度检查点
        if self.training and self.gradient_checkpointing:

            # 定义一个创建自定义前向传播的内部函数
            def create_custom_forward(module):
                # 定义自定义前向传播函数,接受任意输入并调用模块
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 如果 PyTorch 版本大于等于 1.11.0
            if is_torch_version(">=", "1.11.0"):
                # 遍历每个下采样块,应用检查点机制来节省内存
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, use_reentrant=False
                    )
                # 中间块应用检查点机制
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, use_reentrant=False
                )
            else:
                # 遍历每个下采样块,应用检查点机制
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
                # 中间块应用检查点机制
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)

        else:
            # 否则,直接通过下采样块处理样本
            for down_block in self.down_blocks:
                sample = down_block(sample)

            # 直接通过中间块处理样本
            sample = self.mid_block(sample)

        # 后处理步骤
        sample = self.conv_norm_out(sample)  # 应用归一化卷积层
        sample = self.conv_act(sample)        # 应用激活函数卷积层
        sample = self.conv_out(sample)        # 应用输出卷积层

        # 返回处理后的样本
        return sample
# 定义变分自编码器的解码层,将潜在表示解码为输出样本
class Decoder(nn.Module):
    r"""
    `Decoder`层的文档字符串,描述其功能及参数

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            输入通道的数量
        out_channels (`int`, *optional*, defaults to 3):
            输出通道的数量
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            使用的上采样块类型,参考可用选项
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            每个块的输出通道数量
        layers_per_block (`int`, *optional*, defaults to 2):
            每个块包含的层数
        norm_num_groups (`int`, *optional*, defaults to 32):
            归一化的组数
        act_fn (`str`, *optional*, defaults to `"silu"`):
            使用的激活函数,参考可用选项
        norm_type (`str`, *optional*, defaults to `"group"`):
            使用的归一化类型,可以是"group"或"spatial"
    """

    # 初始化解码器层
    def __init__(
        self,
        in_channels: int = 3,  # 默认输入通道为3
        out_channels: int = 3,  # 默认输出通道为3
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),  # 默认上采样块类型
        block_out_channels: Tuple[int, ...] = (64,),  # 默认块输出通道
        layers_per_block: int = 2,  # 每个块的默认层数
        norm_num_groups: int = 32,  # 默认归一化组数
        act_fn: str = "silu",  # 默认激活函数为"silu"
        norm_type: str = "group",  # 默认归一化类型为"group"
        mid_block_add_attention=True,  # 中间块是否添加注意力机制,默认为True
    # 初始化父类
        ):
            super().__init__()
            # 设置每个块的层数
            self.layers_per_block = layers_per_block
    
            # 定义输入卷积层,转换输入通道数到最后块的输出通道数
            self.conv_in = nn.Conv2d(
                in_channels,
                block_out_channels[-1],
                kernel_size=3,
                stride=1,
                padding=1,
            )
    
            # 初始化上采样块的模块列表
            self.up_blocks = nn.ModuleList([])
    
            # 根据归一化类型设置时间嵌入通道数
            temb_channels = in_channels if norm_type == "spatial" else None
    
            # 中间块
            self.mid_block = UNetMidBlock2D(
                # 设置中间块的输入通道、eps、激活函数等参数
                in_channels=block_out_channels[-1],
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                output_scale_factor=1,
                resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
                attention_head_dim=block_out_channels[-1],
                resnet_groups=norm_num_groups,
                temb_channels=temb_channels,
                add_attention=mid_block_add_attention,
            )
    
            # 上采样
            # 反转输出通道数列表以用于上采样
            reversed_block_out_channels = list(reversed(block_out_channels))
            # 设定当前输出通道数为反转列表的第一个元素
            output_channel = reversed_block_out_channels[0]
            # 遍历上采样块类型并创建相应的上采样块
            for i, up_block_type in enumerate(up_block_types):
                # 存储前一个输出通道数
                prev_output_channel = output_channel
                # 更新当前输出通道数
                output_channel = reversed_block_out_channels[i]
    
                # 判断当前块是否为最后一个块
                is_final_block = i == len(block_out_channels) - 1
    
                # 创建上采样块并传入相关参数
                up_block = get_up_block(
                    up_block_type,
                    num_layers=self.layers_per_block + 1,
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    prev_output_channel=None,
                    add_upsample=not is_final_block,
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    attention_head_dim=output_channel,
                    temb_channels=temb_channels,
                    resnet_time_scale_shift=norm_type,
                )
                # 将新创建的上采样块添加到模块列表
                self.up_blocks.append(up_block)
                # 更新前一个输出通道数
                prev_output_channel = output_channel
    
            # 输出层
            # 根据归一化类型选择输出卷积层的归一化方法
            if norm_type == "spatial":
                self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
            else:
                self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
            # 设置输出卷积层的激活函数为 SiLU
            self.conv_act = nn.SiLU()
            # 定义最终输出的卷积层,输出通道数为 out_channels
            self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
    
            # 初始化梯度检查点开关为 False
            self.gradient_checkpointing = False
    
        # 定义前向传播方法
        def forward(
            self,
            sample: torch.Tensor,
            latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # `Decoder` 类的前向方法文档字符串

        # 通过输入卷积层处理样本
        sample = self.conv_in(sample)

        # 获取上采样块参数的数据类型
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        # 如果处于训练模式且使用梯度检查点
        if self.training and self.gradient_checkpointing:

            # 创建自定义前向函数
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 检查 PyTorch 版本
            if is_torch_version(">=", "1.11.0"):
                # 中间处理
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                # 转换样本数据类型
                sample = sample.to(upscale_dtype)

                # 上采样处理
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # 中间处理
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                # 转换样本数据类型
                sample = sample.to(upscale_dtype)

                # 上采样处理
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # 中间处理
            sample = self.mid_block(sample, latent_embeds)
            # 转换样本数据类型
            sample = sample.to(upscale_dtype)

            # 上采样处理
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # 后处理
        if latent_embeds is None:
            # 如果没有潜在嵌入,则直接进行卷积归一化输出
            sample = self.conv_norm_out(sample)
        else:
            # 如果有潜在嵌入,则传入进行卷积归一化输出
            sample = self.conv_norm_out(sample, latent_embeds)
        # 应用激活函数
        sample = self.conv_act(sample)
        # 最终输出卷积层处理样本
        sample = self.conv_out(sample)

        # 返回处理后的样本
        return sample
# 定义一个名为 UpSample 的类,继承自 nn.Module
class UpSample(nn.Module):
    r"""
    `UpSample` 层用于变分自编码器,可以对输入进行上采样。

    参数:
        in_channels (`int`, *可选*, 默认为 3):
            输入通道的数量。
        out_channels (`int`, *可选*, 默认为 3):
            输出通道的数量。
    """

    # 初始化方法,接受输入和输出通道数量
    def __init__(
        self,
        in_channels: int,  # 输入通道数量
        out_channels: int,  # 输出通道数量
    ) -> None:
        super().__init__()  # 调用父类的初始化方法
        self.in_channels = in_channels  # 保存输入通道数量
        self.out_channels = out_channels  # 保存输出通道数量
        # 创建转置卷积层,用于上采样
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    # 前向传播方法
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""`UpSample` 类的前向传播方法。"""
        x = torch.relu(x)  # 对输入应用 ReLU 激活函数
        x = self.deconv(x)  # 通过转置卷积层进行上采样
        return x  # 返回上采样后的结果


# 定义一个名为 MaskConditionEncoder 的类,继承自 nn.Module
class MaskConditionEncoder(nn.Module):
    """
    用于 AsymmetricAutoencoderKL
    """

    # 初始化方法,接受多个参数以构建编码器
    def __init__(
        self,
        in_ch: int,  # 输入通道数量
        out_ch: int = 192,  # 输出通道数量,默认值为 192
        res_ch: int = 768,  # 结果通道数量,默认值为 768
        stride: int = 16,  # 步幅,默认值为 16
    ) -> None:
        super().__init__()  # 调用父类的初始化方法

        channels = []  # 初始化通道列表
        # 计算每一层的输入和输出通道数量,直到步幅小于等于 1
        while stride > 1:
            stride = stride // 2  # 将步幅减半
            in_ch_ = out_ch * 2  # 输入通道数量为输出通道的两倍
            if out_ch > res_ch:  # 如果输出通道大于结果通道
                out_ch = res_ch  # 将输出通道设置为结果通道
            if stride == 1:  # 如果步幅为 1
                in_ch_ = res_ch  # 输入通道数量设置为结果通道
            channels.append((in_ch_, out_ch))  # 将输入和输出通道对添加到列表
            out_ch *= 2  # 输出通道数量翻倍

        out_channels = []  # 初始化输出通道列表
        # 从通道列表中提取输出通道数量
        for _in_ch, _out_ch in channels:
            out_channels.append(_out_ch)  # 添加输出通道数量
        out_channels.append(channels[-1][0])  # 添加最后一层的输入通道数量

        layers = []  # 初始化层列表
        in_ch_ = in_ch  # 将输入通道数量赋值给临时变量
        # 根据输出通道数量构建卷积层
        for l in range(len(out_channels)):
            out_ch_ = out_channels[l]  # 当前输出通道数量
            if l == 0 or l == 1:  # 对于前两层
                layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))  # 添加 3x3 卷积层
            else:  # 对于后续层
                layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))  # 添加 4x4 卷积层
            in_ch_ = out_ch_  # 更新输入通道数量

        self.layers = nn.Sequential(*layers)  # 将所有层组合成一个顺序容器

    # 前向传播方法
    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        r"""`MaskConditionEncoder` 类的前向传播方法。"""
        out = {}  # 初始化输出字典
        # 遍历所有层
        for l in range(len(self.layers)):
            layer = self.layers[l]  # 获取当前层
            x = layer(x)  # 通过当前层处理输入
            out[str(tuple(x.shape))] = x  # 将当前输出的形状作为键,输出张量作为值存入字典
            x = torch.relu(x)  # 对输出应用 ReLU 激活函数
        return out  # 返回输出字典


# 定义一个名为 MaskConditionDecoder 的类,继承自 nn.Module
class MaskConditionDecoder(nn.Module):
    r"""`MaskConditionDecoder` 应与 [`AsymmetricAutoencoderKL`] 一起使用,以增强模型的
    解码器,结合掩膜和被掩膜的图像。
    # 函数参数定义部分
        Args:
            in_channels (`int`, *optional*, defaults to 3):  # 输入通道的数量,默认为3
                The number of input channels.
            out_channels (`int`, *optional*, defaults to 3):  # 输出通道的数量,默认为3
                The number of output channels.
            up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):  # 使用的上采样模块类型,默认为UpDecoderBlock2D
                The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
            block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):  # 每个模块的输出通道数量,默认为64
                The number of output channels for each block.
            layers_per_block (`int`, *optional*, defaults to 2):  # 每个模块的层数,默认为2
                The number of layers per block.
            norm_num_groups (`int`, *optional*, defaults to 32):  # 归一化的组数,默认为32
                The number of groups for normalization.
            act_fn (`str`, *optional*, defaults to `"silu"`):  # 使用的激活函数,默认为silu
                The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
            norm_type (`str`, *optional*, defaults to `"group"`):  # 归一化类型,可以是"group"或"spatial",默认为"group"
                The normalization type to use. Can be either `"group"` or `"spatial"`.
        """
    
        # 初始化方法定义
        def __init__(
            self,
            in_channels: int = 3,  # 输入通道的数量,默认为3
            out_channels: int = 3,  # 输出通道的数量,默认为3
            up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),  # 上采样模块类型,默认为UpDecoderBlock2D
            block_out_channels: Tuple[int, ...] = (64,),  # 每个模块的输出通道数量,默认为64
            layers_per_block: int = 2,  # 每个模块的层数,默认为2
            norm_num_groups: int = 32,  # 归一化的组数,默认为32
            act_fn: str = "silu",  # 激活函数,默认为silu
            norm_type: str = "group",  # 归一化类型,默认为"group"
    ):
        # 调用父类构造函数初始化
        super().__init__()
        # 设置每个块的层数
        self.layers_per_block = layers_per_block

        # 初始化输入卷积层,接收输入通道并生成块输出通道
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # 创建一个空的模块列表,用于存储上采样块
        self.up_blocks = nn.ModuleList([])

        # 根据归一化类型设置时间嵌入通道数
        temb_channels = in_channels if norm_type == "spatial" else None

        # 中间块的初始化
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],  # 使用最后一个块的输出通道作为输入
            resnet_eps=1e-6,                      # ResNet 的 epsilon 参数
            resnet_act_fn=act_fn,                 # ResNet 的激活函数
            output_scale_factor=1,                # 输出缩放因子
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,  # 时间缩放偏移
            attention_head_dim=block_out_channels[-1],  # 注意力头的维度
            resnet_groups=norm_num_groups,        # ResNet 的组数
            temb_channels=temb_channels,          # 时间嵌入通道数
        )

        # 初始化上采样块
        reversed_block_out_channels = list(reversed(block_out_channels))  # 反转输出通道列表
        output_channel = reversed_block_out_channels[0]  # 获取第一个输出通道
        for i, up_block_type in enumerate(up_block_types):  # 遍历上采样块类型
            prev_output_channel = output_channel  # 保存上一个输出通道数
            output_channel = reversed_block_out_channels[i]  # 获取当前输出通道数

            is_final_block = i == len(block_out_channels) - 1  # 判断是否为最后一个块

            # 获取上采样块
            up_block = get_up_block(
                up_block_type,  # 上采样块类型
                num_layers=self.layers_per_block + 1,  # 上采样层数
                in_channels=prev_output_channel,  # 输入通道数
                out_channels=output_channel,  # 输出通道数
                prev_output_channel=None,  # 前一个输出通道数
                add_upsample=not is_final_block,  # 是否添加上采样操作
                resnet_eps=1e-6,  # ResNet 的 epsilon 参数
                resnet_act_fn=act_fn,  # ResNet 的激活函数
                resnet_groups=norm_num_groups,  # ResNet 的组数
                attention_head_dim=output_channel,  # 注意力头的维度
                temb_channels=temb_channels,  # 时间嵌入通道数
                resnet_time_scale_shift=norm_type,  # 时间缩放偏移
            )
            self.up_blocks.append(up_block)  # 将上采样块添加到模块列表
            prev_output_channel = output_channel  # 更新前一个输出通道数

        # 条件编码器的初始化
        self.condition_encoder = MaskConditionEncoder(
            in_ch=out_channels,  # 输入通道数
            out_ch=block_out_channels[0],  # 输出通道数
            res_ch=block_out_channels[-1],  # ResNet 通道数
        )

        # 输出层的归一化处理
        if norm_type == "spatial":  # 如果归一化类型为空间归一化
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)  # 初始化空间归一化
        else:  # 否则使用组归一化
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)  # 初始化组归一化
        # 初始化激活函数为 SiLU
        self.conv_act = nn.SiLU()
        # 初始化输出卷积层
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        # 初始化梯度检查点标志为 False
        self.gradient_checkpointing = False

    def forward(
        self,
        z: torch.Tensor,  # 输入的张量 z
        image: Optional[torch.Tensor] = None,  # 可选的输入图像张量
        mask: Optional[torch.Tensor] = None,  # 可选的输入掩码张量
        latent_embeds: Optional[torch.Tensor] = None,  # 可选的潜在嵌入张量
# 定义一个向量量化器类,继承自 nn.Module
class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.
    """

    # 初始化方法,设置类的基本参数
    def __init__(
        self,
        n_e: int,  # 向量量化的嵌入数量
        vq_embed_dim: int,  # 嵌入的维度
        beta: float,  # beta 参数,用于调节量化误差
        remap=None,  # 用于重映射的可选参数
        unknown_index: str = "random",  # 未知索引的处理方式
        sane_index_shape: bool = False,  # 是否强制索引形状的合理性
        legacy: bool = True,  # 是否使用旧版本的实现
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 保存嵌入数量
        self.n_e = n_e
        # 保存嵌入维度
        self.vq_embed_dim = vq_embed_dim
        # 保存 beta 参数
        self.beta = beta
        # 保存是否使用旧版的标志
        self.legacy = legacy

        # 初始化嵌入层,随机生成嵌入权重
        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        # 将嵌入权重初始化为[-1/n_e, 1/n_e]的均匀分布
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        # 处理重映射参数
        self.remap = remap
        if self.remap is not None:  # 如果提供了重映射文件
            # 注册一个缓冲区,加载重映射的数据
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.used: torch.Tensor  # 声明用于重映射的张量
            # 重新嵌入的数量
            self.re_embed = self.used.shape[0]
            # 设置未知索引的方式
            self.unknown_index = unknown_index  # "random"、"extra"或整数
            if self.unknown_index == "extra":  # 如果未知索引为"extra"
                self.unknown_index = self.re_embed  # 设置为重新嵌入数量
                self.re_embed = self.re_embed + 1  # 增加重新嵌入的数量
            # 打印重映射信息
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            # 如果没有提供重映射,则重新嵌入数量与嵌入数量相同
            self.re_embed = n_e

        # 保存是否强制索引形状合理的标志
        self.sane_index_shape = sane_index_shape

    # 将索引映射到已使用的索引
    def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
        # 保存输入张量的形状
        ishape = inds.shape
        # 确保输入张量至少有两个维度
        assert len(ishape) > 1
        # 将输入张量重塑为二维,保持第一个维度不变
        inds = inds.reshape(ishape[0], -1)
        # 将使用的张量转换到相同的设备
        used = self.used.to(inds)
        # 检查 inds 中的元素是否与 used 中的元素匹配
        match = (inds[:, :, None] == used[None, None, ...]).long()
        # 找到匹配的索引
        new = match.argmax(-1)
        # 检查是否有未知索引
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":  # 如果未知索引为"random"
            # 随机生成未知索引
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            # 否则设置为指定的未知索引
            new[unknown] = self.unknown_index
        # 将结果重塑为原来的形状并返回
        return new.reshape(ishape)

    # 将映射到的索引还原为所有索引
    def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
        # 保存输入张量的形状
        ishape = inds.shape
        # 确保输入张量至少有两个维度
        assert len(ishape) > 1
        # 将输入张量重塑为二维,保持第一个维度不变
        inds = inds.reshape(ishape[0], -1)
        # 将使用的张量转换到相同的设备
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # 如果有额外的标记
            # 将超出已使用索引的标记设置为零
            inds[inds >= self.used.shape[0]] = 0  # 简单设置为零
        # 根据 inds 从 used 中选择对应的值
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        # 将结果重塑为原来的形状并返回
        return back.reshape(ishape)
    # 前向传播方法,接收一个张量 z,返回量化张量、损失和附加信息
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
        # 将 z 重新排列为 (batch, height, width, channel) 的形状,并展平
        z = z.permute(0, 2, 3, 1).contiguous()
        # 将 z 展平为二维张量,维度为 (batch_size * height * width, vq_embed_dim)
        z_flattened = z.view(-1, self.vq_embed_dim)

        # 计算 z 与嵌入 e_j 之间的距离,公式为 (z - e)^2 = z^2 + e^2 - 2 * e * z
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        # 根据最小编码索引从嵌入中获取量化的 z,重新调整为原始 z 的形状
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None  # 初始化困惑度为 None
        min_encodings = None  # 初始化最小编码为 None

        # 计算嵌入的损失
        if not self.legacy:
            # 计算损失时考虑 beta 权重
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            # 计算损失时考虑不同的权重
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # 保持梯度
        z_q: torch.Tensor = z + (z_q - z).detach()

        # 将 z_q 重新排列为与原始输入形状相匹配
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            # 如果存在重映射,则调整最小编码索引形状,增加批次维度
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            # 将索引映射到使用的编码
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            # 将索引展平
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            # 如果需要,调整最小编码索引的形状以匹配 z_q 的形状
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        # 返回量化的 z、损失和其他信息的元组
        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    # 获取代码簿条目,根据索引返回量化的潜在向量
    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
        # shape 指定 (batch, height, width, channel)
        if self.remap is not None:
            # 如果存在重映射,则调整索引形状,增加批次维度
            indices = indices.reshape(shape[0], -1)  # add batch axis
            # 将索引映射回所有编码
            indices = self.unmap_to_all(indices)
            # 将索引展平
            indices = indices.reshape(-1)  # flatten again

        # 获取量化的潜在向量
        z_q: torch.Tensor = self.embedding(indices)

        if shape is not None:
            # 如果形状不为空,将 z_q 重新调整为指定的形状
            z_q = z_q.view(shape)
            # 重新排列以匹配原始输入形状
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        # 返回量化的潜在向量
        return z_q
# 定义对角高斯分布类
class DiagonalGaussianDistribution(object):
    # 初始化方法,接收参数和是否为确定性分布的标志
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        # 将参数存储在实例中
        self.parameters = parameters
        # 将参数分为均值和对数方差
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # 将对数方差限制在-30到20之间
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        # 记录是否为确定性分布
        self.deterministic = deterministic
        # 计算标准差
        self.std = torch.exp(0.5 * self.logvar)
        # 计算方差
        self.var = torch.exp(self.logvar)
        # 如果是确定性分布,方差和标准差设为零
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    # 采样方法,生成符合分布的样本
    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # 确保样本与参数在相同的设备上且具有相同的数据类型
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        # 根据均值和标准差生成样本
        x = self.mean + self.std * sample
        # 返回生成的样本
        return x

    # 计算与另一个分布的KL散度
    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        # 如果是确定性分布,KL散度为0
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            # 如果没有提供另一个分布
            if other is None:
                # 计算自身的KL散度
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                # 计算与另一个分布的KL散度
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    # 计算负对数似然
    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
        # 如果是确定性分布,负对数似然为0
        if self.deterministic:
            return torch.Tensor([0.0])
        # 计算常数log(2π)
        logtwopi = np.log(2.0 * np.pi)
        # 返回负对数似然的计算结果
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    # 返回分布的众数
    def mode(self) -> torch.Tensor:
        return self.mean


# 定义EncoderTiny类,继承自nn.Module
class EncoderTiny(nn.Module):
    r"""
    `EncoderTiny`层是`Encoder`层的简化版本。

    参数:
        in_channels (`int`):
            输入通道的数量。
        out_channels (`int`):
            输出通道的数量。
        num_blocks (`Tuple[int, ...]`):
            元组中的每个值表示一个Conv2d层后跟随`value`数量的`AutoencoderTinyBlock`。
        block_out_channels (`Tuple[int, ...]`):
            每个块的输出通道数量。
        act_fn (`str`):
            使用的激活函数。请参见`~diffusers.models.activations.get_activation`以获取可用选项。
    """
    # 初始化方法,构造 EncoderTiny 类的实例
        def __init__(
            self,
            in_channels: int,  # 输入通道数
            out_channels: int,  # 输出通道数
            num_blocks: Tuple[int, ...],  # 每个层中块的数量
            block_out_channels: Tuple[int, ...],  # 每个层的输出通道数
            act_fn: str,  # 激活函数的类型
        ):
            # 调用父类的初始化方法
            super().__init__()
    
            layers = []  # 初始化空层列表
            # 遍历每个层的块数量
            for i, num_block in enumerate(num_blocks):
                num_channels = block_out_channels[i]  # 当前层的输出通道数
    
                # 如果是第一个层,创建卷积层
                if i == 0:
                    layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
                else:
                    # 创建后续卷积层,包含步幅和无偏置选项
                    layers.append(
                        nn.Conv2d(
                            num_channels,
                            num_channels,
                            kernel_size=3,
                            padding=1,
                            stride=2,
                            bias=False,
                        )
                    )
    
                # 添加指定数量的 AutoencoderTinyBlock
                for _ in range(num_block):
                    layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
    
            # 添加最后的卷积层,将最后一层输出通道映射到目标输出通道
            layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
    
            # 将所有层组合为一个顺序模块
            self.layers = nn.Sequential(*layers)
            # 初始化梯度检查点标志为 False
            self.gradient_checkpointing = False
    
        # 前向传播方法
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            r"""EncoderTiny 类的前向方法。"""
            # 如果模型处于训练状态并且启用了梯度检查点
            if self.training and self.gradient_checkpointing:
    
                # 创建自定义前向传播方法
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
    
                    return custom_forward
    
                # 根据 PyTorch 版本选择检查点方式
                if is_torch_version(">=", "1.11.0"):
                    x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
                else:
                    x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
    
            else:
                # 将图像从 [-1, 1] 线性缩放到 [0, 1],以匹配 TAESD 规范
                x = self.layers(x.add(1).div(2))
    
            # 返回前向传播的输出
            return x
# 定义一个名为 `DecoderTiny` 的类,继承自 `nn.Module`
class DecoderTiny(nn.Module):
    r"""
    `DecoderTiny` 层是 `Decoder` 层的简化版本。

    参数:
        in_channels (`int`):
            输入通道的数量。
        out_channels (`int`):
            输出通道的数量。
        num_blocks (`Tuple[int, ...]`):
            元组中的每个值表示一个 Conv2d 层后面跟着 `value` 个 `AutoencoderTinyBlock` 的数量。
        block_out_channels (`Tuple[int, ...]`):
            每个块的输出通道数量。
        upsampling_scaling_factor (`int`):
            用于上采样的缩放因子。
        act_fn (`str`):
            使用的激活函数。有关可用选项,请参见 `~diffusers.models.activations.get_activation`。
    """

    # 初始化方法,设置类的基本参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数量
        out_channels: int,  # 输出通道数量
        num_blocks: Tuple[int, ...],  # 每个块的数量
        block_out_channels: Tuple[int, ...],  # 每个块的输出通道数量
        upsampling_scaling_factor: int,  # 上采样缩放因子
        act_fn: str,  # 激活函数名称
        upsample_fn: str,  # 上采样函数名称
    ):
        super().__init__()  # 调用父类的初始化方法

        # 初始化层的列表
        layers = [
            # 添加一个 Conv2d 层,输入通道到第一个块的输出通道
            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
            # 添加指定的激活函数
            get_activation(act_fn),
        ]

        # 遍历每个块的数量
        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)  # 判断是否为最后一个块
            num_channels = block_out_channels[i]  # 获取当前块的输出通道数量

            # 对于当前块的数量,添加相应数量的 `AutoencoderTinyBlock`
            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

            # 如果不是最后一个块,则添加上采样层
            if not is_final_block:
                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))

            # 设置当前卷积的输出通道数量
            conv_out_channel = num_channels if not is_final_block else out_channels
            # 添加卷积层
            layers.append(
                nn.Conv2d(
                    num_channels,  # 输入通道数量
                    conv_out_channel,  # 输出通道数量
                    kernel_size=3,  # 卷积核大小
                    padding=1,  # 填充大小
                    bias=is_final_block,  # 如果是最后一个块,使用偏置
                )
            )

        # 将所有层组合成一个顺序模型
        self.layers = nn.Sequential(*layers)
        # 初始化梯度检查点标志为 False
        self.gradient_checkpointing = False

    # 前向传播方法
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""`DecoderTiny` 类的前向方法。"""
        # 将输入张量缩放并限制到 [-3, 3] 范围
        x = torch.tanh(x / 3) * 3

        # 如果处于训练状态并且启用了梯度检查点
        if self.training and self.gradient_checkpointing:

            # 创建自定义前向函数
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)  # 调用模块

                return custom_forward

            # 如果 PyTorch 版本大于等于 1.11.0,使用非重入的检查点
            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
            else:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)  # 使用检查点

        else:
            x = self.layers(x)  # 否则直接通过层处理输入

        # 将图像从 [0, 1] 范围缩放到 [-1, 1],以匹配 diffusers 的约定
        return x.mul(2).sub(1)  # 缩放并返回结果

标签:sample,self,torch,diffusers,channels,源码,解析,block,out
From: https://www.cnblogs.com/apachecn/p/18492371

相关文章

  • diffusers-源码解析-九-
    diffusers源码解析(九).\diffusers\models\embeddings_flax.py#Copyright2024TheHuggingFaceTeam.Allrightsreserved.##LicensedundertheApacheLicense,Version2.0(the"License");#youmaynotusethisfileexceptincompliancewiththe......
  • diffusers-源码解析-二十一-
    diffusers源码解析(二十一).\diffusers\pipelines\controlnet\pipeline_controlnet.py#版权信息,指明该代码由HuggingFace团队版权所有##根据Apache2.0许可证授权,用户需遵循许可证规定使用该文件#许可证可以在以下网址获取##http://www.apache.org/licenses/L......
  • diffusers-源码解析-二十四-
    diffusers源码解析(二十四).\diffusers\pipelines\controlnet_sd3\pipeline_stable_diffusion_3_controlnet.py#版权声明,指出版权所有者及相关信息#Copyright2024StabilityAI,TheHuggingFaceTeamandTheInstantXTeam.Allrightsreserved.##按照Apache2.0许可......
  • diffusers-源码解析-二十三-
    diffusers源码解析(二十三).\diffusers\pipelines\controlnet\pipeline_controlnet_sd_xl_img2img.py#版权所有2024HuggingFace团队。保留所有权利。##根据Apache许可证第2.0版(“许可证”)许可;#除非遵守许可证,否则您不得使用此文件。#您可以在以下网址获得许可证副......
  • diffusers-源码解析-二十六-
    diffusers源码解析(二十六).\diffusers\pipelines\deepfloyd_if\pipeline_if_inpainting_superresolution.py#导入html模块,用于处理HTML文本importhtml#导入inspect模块,用于获取对象的信息importinspect#导入re模块,用于正则表达式匹配importre#导入urllib.......
  • diffusers-源码解析-二十九-
    diffusers源码解析(二十九).\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_model_editing.py#版权信息,声明版权和许可协议#Copyright2024TIMEAuthorsandTheHuggingFaceTeam.Allrightsreserved."#根据ApacheLicense2.0......
  • diffusers-源码解析-十一-
    diffusers源码解析(十一).\diffusers\models\transformers\hunyuan_transformer_2d.py#版权所有2024HunyuanDiT作者,QixunWang和HuggingFace团队。保留所有权利。##根据Apache许可证第2.0版("许可证")进行许可;#除非符合许可证,否则您不得使用此文件。#您可以在以......
  • diffusers-源码解析-十五-
    diffusers源码解析(十五).\diffusers\models\unets\unet_3d_condition.py#版权声明,声明此代码的版权信息和所有权#Copyright2024AlibabaDAMO-VILABandTheHuggingFaceTeam.Allrightsreserved.#版权声明,声明此代码的版权信息和所有权#Copyright2024TheModelSco......
  • diffusers-源码解析-十四-
    diffusers源码解析(十四).\diffusers\models\unets\unet_2d_blocks_flax.py#版权声明,说明该文件的版权信息及相关许可协议#Copyright2024TheHuggingFaceTeam.Allrightsreserved.##许可信息,使用ApacheLicense2.0许可#LicensedundertheApacheLicense,Versi......
  • diffusers-源码解析-十三-
    diffusers源码解析(十三).\diffusers\models\unets\unet_2d.py#版权声明,表示该代码由HuggingFace团队所有##根据Apache2.0许可证进行许可;#除非遵循许可证,否则不得使用此文件。#可以在以下地址获取许可证的副本:##http://www.apache.org/licenses/LICENSE-2.......