首页 > 编程语言 >CogVideo---CogVideoX-微调代码源码解析-八-

CogVideo---CogVideoX-微调代码源码解析-八-

时间:2024-10-23 09:23:08浏览次数:1  
标签:channels CogVideoX nn conv self torch --- 源码 sigma

CogVideo & CogVideoX 微调代码源码解析(八)

.\cogvideo-finetune\sat\sgm\modules\autoencoding\vqvae\movq_enc_3d.py

# pytorch_diffusion + derived encoder decoder
import math  # 导入数学库,提供数学函数
import torch  # 导入 PyTorch 库,用于深度学习
import torch.nn as nn  # 导入 PyTorch 的神经网络模块
import torch.nn.functional as F  # 导入 PyTorch 的功能性操作模块
import numpy as np  # 导入 NumPy 库,用于数组和数学操作

from beartype import beartype  # 导入 beartype 库,用于类型检查
from beartype.typing import Union, Tuple, Optional, List  # 导入类型注解
from einops import rearrange  # 导入 einops 库,用于数组重排


def cast_tuple(t, length=1):  # 定义将输入转换为元组的函数
    return t if isinstance(t, tuple) else ((t,) * length)  # 如果是元组,返回原值,否则返回长度为 length 的元组


def divisible_by(num, den):  # 定义判断是否整除的函数
    return (num % den) == 0  # 返回 num 是否能被 den 整除


def is_odd(n):  # 定义判断数字是否为奇数的函数
    return not divisible_by(n, 2)  # 返回 n 是否为奇数


def get_timestep_embedding(timesteps, embedding_dim):  # 定义生成时间步嵌入的函数
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1  # 确保 timesteps 是一维的

    half_dim = embedding_dim // 2  # 计算嵌入维度的一半
    emb = math.log(10000) / (half_dim - 1)  # 计算正弦嵌入的缩放因子
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)  # 生成指数衰减的嵌入
    emb = emb.to(device=timesteps.device)  # 将嵌入移动到 timesteps 的设备上
    emb = timesteps.float()[:, None] * emb[None, :]  # 扩展 timesteps 和嵌入的维度以便相乘
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # 将正弦和余弦嵌入在一起
    if embedding_dim % 2 == 1:  # 如果嵌入维度为奇数,进行零填充
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))  # 在最后填充一个零
    return emb  # 返回生成的嵌入


def nonlinearity(x):  # 定义非线性激活函数
    # swish
    return x * torch.sigmoid(x)  # 返回 Swish 激活值


class CausalConv3d(nn.Module):  # 定义因果三维卷积层的类
    @beartype  # 应用类型检查装饰器
    def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], pad_mode="constant", **kwargs):  # 初始化函数
        super().__init__()  # 调用父类构造函数
        kernel_size = cast_tuple(kernel_size, 3)  # 将 kernel_size 转换为三元组

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size  # 解包卷积核的大小

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)  # 确保高度和宽度的卷积核为奇数

        dilation = kwargs.pop("dilation", 1)  # 从参数中弹出膨胀值,默认为1
        stride = kwargs.pop("stride", 1)  # 从参数中弹出步幅值,默认为1

        self.pad_mode = pad_mode  # 保存填充模式
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)  # 计算时间维度的填充量
        height_pad = height_kernel_size // 2  # 计算高度维度的填充量
        width_pad = width_kernel_size // 2  # 计算宽度维度的填充量

        self.height_pad = height_pad  # 保存高度填充量
        self.width_pad = width_pad  # 保存宽度填充量
        self.time_pad = time_pad  # 保存时间填充量
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)  # 定义因果卷积的填充格式

        stride = (stride, 1, 1)  # 将步幅转换为三维元组
        dilation = (dilation, 1, 1)  # 将膨胀转换为三维元组
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)  # 初始化三维卷积层
    # 定义前向传播函数,接收输入张量 x
    def forward(self, x):
        # 检查填充模式是否为常数填充
        if self.pad_mode == "constant":
            # 设置三维的因果填充参数
            causal_padding_3d = (self.time_pad, 0, self.width_pad, self.width_pad, self.height_pad, self.height_pad)
            # 对输入张量 x 进行常数填充
            x = F.pad(x, causal_padding_3d, mode="constant", value=0)
        # 检查填充模式是否为首元素填充
        elif self.pad_mode == "first":
            # 复制输入张量的首元素并在时间维度上进行填充
            pad_x = torch.cat([x[:, :, :1]] * self.time_pad, dim=2)
            # 将填充的张量与原张量拼接
            x = torch.cat([pad_x, x], dim=2)
            # 设置二维的因果填充参数
            causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
            # 对拼接后的张量进行常数填充
            x = F.pad(x, causal_padding_2d, mode="constant", value=0)
        # 检查填充模式是否为反射填充
        elif self.pad_mode == "reflect":
            # 进行反射填充,获取输入张量的反向切片
            reflect_x = x[:, :, 1 : self.time_pad + 1, :, :].flip(dims=[2])
            # 如果反射张量的时间维度小于所需的填充时间,则进行零填充
            if reflect_x.shape[2] < self.time_pad:
                reflect_x = torch.cat(
                    [torch.zeros_like(x[:, :, :1, :, :])] * (self.time_pad - reflect_x.shape[2]) + [reflect_x], dim=2
                )
            # 将反射填充的张量与原张量拼接
            x = torch.cat([reflect_x, x], dim=2)
            # 设置二维的因果填充参数
            causal_padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
            # 对拼接后的张量进行常数填充
            x = F.pad(x, causal_padding_2d, mode="constant", value=0)
        # 如果填充模式无效,则引发错误
        else:
            raise ValueError("Invalid pad mode")
        # 返回经过卷积操作的结果
        return self.conv(x)
# 定义一个用于归一化的函数,适用于3D和2D数据
def Normalize3D(in_channels):  # same for 3D and 2D
    # 返回一个GroupNorm层,用于归一化输入通道
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

# 定义一个3D上采样的类,继承自nn.Module
class Upsample3D(nn.Module):
    # 初始化方法,设置输入通道、是否使用卷积和时间压缩参数
    def __init__(self, in_channels, with_conv, compress_time=False):
        super().__init__()  # 调用父类初始化方法
        self.with_conv = with_conv  # 保存是否使用卷积的标志
        if self.with_conv:
            # 如果使用卷积,创建一个卷积层,输入输出通道相同
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.compress_time = compress_time  # 保存时间压缩的标志

    # 前向传播方法
    def forward(self, x):
        if self.compress_time:  # 如果启用时间压缩
            if x.shape[2] > 1:  # 如果时间维度大于1
                # 分离第一帧和其余帧
                x_first, x_rest = x[:, :, 0], x[:, :, 1:]

                # 将第一帧上采样到原来的2倍
                x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest")
                # 将其余帧上采样到原来的2倍
                x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest")
                # 将第一帧和其余帧在时间维度上连接
                x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
            else:  # 如果时间维度为1
                x = x.squeeze(2)  # 去除时间维度
                # 将数据上采样到原来的2倍
                x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
                x = x[:, :, None, :, :]  # 添加时间维度
        else:  # 如果不进行时间压缩
            # 只对2D进行上采样
            t = x.shape[2]  # 获取时间维度大小
            # 重新排列数据,合并批次和时间维度
            x = rearrange(x, "b c t h w -> (b t) c h w")
            # 将数据上采样到原来的2倍
            x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
            # 恢复数据排列
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.with_conv:  # 如果使用卷积
            t = x.shape[2]  # 获取时间维度大小
            # 重新排列数据,合并批次和时间维度
            x = rearrange(x, "b c t h w -> (b t) c h w")
            # 通过卷积层处理数据
            x = self.conv(x)
            # 恢复数据排列
            x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x  # 返回处理后的数据


# 定义一个3D下采样的类,继承自nn.Module
class DownSample3D(nn.Module):
    # 初始化方法,设置输入通道、是否使用卷积、时间压缩和输出通道
    def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None):
        super().__init__()  # 调用父类初始化方法
        self.with_conv = with_conv  # 保存是否使用卷积的标志
        if out_channels is None:  # 如果没有指定输出通道
            out_channels = in_channels  # 设置输出通道为输入通道
        if self.with_conv:  # 如果使用卷积
            # 在torch的卷积中没有不对称填充,必须手动处理
            self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
        self.compress_time = compress_time  # 保存时间压缩的标志
    # 定义前向传播函数,接受输入 x
        def forward(self, x):
            # 如果设置了 compress_time 为 True
            if self.compress_time:
                # 获取输入张量 x 的高度和宽度
                h, w = x.shape[-2:]
                # 重新排列输入张量 x 的维度,将 (batch, channel, time, height, width) 转换为 ((batch*height*width), channel, time)
                x = rearrange(x, "b c t h w -> (b h w) c t")
    
                # 分离出第一帧和其余帧
                x_first, x_rest = x[..., 0], x[..., 1:]
    
                # 如果剩余帧存在
                if x_rest.shape[-1] > 0:
                    # 对剩余帧进行平均池化,减少时间维度
                    x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
                # 将第一帧和池化后的剩余帧拼接在一起
                x = torch.cat([x_first[..., None], x_rest], dim=-1)
                # 重新排列回原始维度,恢复为 (batch, channel, time, height, width)
                x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w)
    
            # 如果设置了 with_conv 为 True
            if self.with_conv:
                # 定义填充参数,确保维度匹配
                pad = (0, 1, 0, 1)
                # 对输入张量 x 进行零填充
                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
                # 获取时间维度的大小
                t = x.shape[2]
                # 重新排列输入张量 x 的维度,为卷积操作准备
                x = rearrange(x, "b c t h w -> (b t) c h w")
                # 通过卷积层处理输入张量
                x = self.conv(x)
                # 重新排列输出张量,恢复维度
                x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
            else:
                # 获取时间维度的大小
                t = x.shape[2]
                # 重新排列输入张量 x 的维度,准备进行池化操作
                x = rearrange(x, "b c t h w -> (b t) c h w")
                # 对输入张量进行平均池化,减小空间维度
                x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
                # 重新排列输出张量,恢复维度
                x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
            # 返回处理后的输出张量
            return x
# 定义一个 3D 残差块,继承自 nn.Module
class ResnetBlock3D(nn.Module):
    # 初始化方法,设置输入、输出通道和其他参数
    def __init__(
        self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, pad_mode="constant"
    ):
        # 调用父类构造函数
        super().__init__()
        # 保存输入通道数
        self.in_channels = in_channels
        # 如果未指定输出通道,则与输入通道数相同
        out_channels = in_channels if out_channels is None else out_channels
        # 保存输出通道数
        self.out_channels = out_channels
        # 保存是否使用卷积捷径的标志
        self.use_conv_shortcut = conv_shortcut

        # 初始化输入通道的归一化层
        self.norm1 = Normalize3D(in_channels)
        # 使用因果卷积初始化第一层卷积
        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
        # 如果时间嵌入通道数大于0,初始化时间嵌入投影层
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        # 初始化输出通道的归一化层
        self.norm2 = Normalize3D(out_channels)
        # 初始化 dropout 层
        self.dropout = torch.nn.Dropout(dropout)
        # 使用因果卷积初始化第二层卷积
        self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
        # 如果输入和输出通道不相同
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                # 使用因果卷积初始化捷径卷积
                self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
            else:
                # 初始化 1x1 卷积作为捷径
                self.nin_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    # 前向传播方法
    def forward(self, x, temb):
        # 将输入赋值给 h
        h = x
        # 对 h 进行归一化
        h = self.norm1(h)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 通过第一层卷积处理 h
        h = self.conv1(h)

        # 如果时间嵌入存在,将其投影并添加到 h
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]

        # 对 h 进行归一化
        h = self.norm2(h)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 应用 dropout
        h = self.dropout(h)
        # 通过第二层卷积处理 h
        h = self.conv2(h)

        # 如果输入和输出通道不相同
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                # 使用捷径卷积处理 x
                x = self.conv_shortcut(x)
            else:
                # 使用 1x1 卷积处理 x
                x = self.nin_shortcut(x)

        # 返回输入和处理后的 h 的和
        return x + h
    # 初始化方法,接收输入通道数
        def __init__(self, in_channels):
            # 调用父类的初始化方法
            super().__init__()
            # 保存输入通道数
            self.in_channels = in_channels
    
            # 创建 3D 归一化层
            self.norm = Normalize3D(in_channels)
            # 创建查询卷积层,使用 1x1 卷积
            self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            # 创建键卷积层,使用 1x1 卷积
            self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            # 创建值卷积层,使用 1x1 卷积
            self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
            # 创建输出投影卷积层,使用 1x1 卷积
            self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
    
        # 前向传播方法,接收输入 x
        def forward(self, x):
            # 将输入赋值给 h_
            h_ = x
            # 对输入进行归一化
            h_ = self.norm(h_)
    
            # 获取时间维度的大小
            t = h_.shape[2]
            # 调整 h_ 的形状以便于后续计算
            h_ = rearrange(h_, "b c t h w -> (b t) c h w")
    
            # 计算查询向量
            q = self.q(h_)
            # 计算键向量
            k = self.k(h_)
            # 计算值向量
            v = self.v(h_)
    
            # 计算注意力
            b, c, h, w = q.shape
            # 调整查询向量的形状
            q = q.reshape(b, c, h * w)
            # 调整查询向量的维度顺序
            q = q.permute(0, 2, 1)  # b,hw,c
            # 调整键向量的形状
            k = k.reshape(b, c, h * w)  # b,c,hw
    
            # # 原始版本,fp16 中出现 nan
            # w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
            # w_ = w_ * (int(c)**(-0.5))
            # 在查询向量上实现 c**-0.5 的缩放
            q = q * (int(c) ** (-0.5))
            # 计算注意力权重
            w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
    
            # 对注意力权重进行 softmax 归一化
            w_ = torch.nn.functional.softmax(w_, dim=2)
    
            # 处理值向量
            v = v.reshape(b, c, h * w)
            # 调整注意力权重的维度顺序
            w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
            # 计算最终输出
            h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
            # 将输出调整回原始形状
            h_ = h_.reshape(b, c, h, w)
    
            # 通过输出投影层处理结果
            h_ = self.proj_out(h_)
    
            # 调整输出的形状以匹配输入
            h_ = rearrange(h_, "(b t) c h w -> b c t h w", t=t)
    
            # 返回输入与输出的相加结果
            return x + h_
# 定义一个3D编码器类,继承自PyTorch的nn.Module
class Encoder3D(nn.Module):
    # 初始化方法,接收多个参数用于配置编码器
    def __init__(
        self,
        *,
        ch,  # 输入通道数
        out_ch,  # 输出通道数
        ch_mult=(1, 2, 4, 8),  # 通道数倍增因子
        num_res_blocks,  # 每个分辨率的残差块数量
        attn_resolutions,  # 注意力分辨率
        dropout=0.0,  # dropout比率
        resamp_with_conv=True,  # 是否使用卷积进行重采样
        in_channels,  # 输入数据的通道数
        resolution,  # 输入数据的分辨率
        z_channels,  # 潜在变量的通道数
        double_z=True,  # 是否使用双重潜在变量
        pad_mode="first",  # 填充模式
        temporal_compress_times=4,  # 时间压缩次数
        **ignore_kwargs,  # 额外参数,忽略处理
    # 前向传播方法,定义数据如何通过网络流动
    def forward(self, x, use_cp=False):
        # 确保输入x的高度和宽度等于预期分辨率
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
        # 初始化时间步嵌入为None
        temb = None

        # 下采样过程,初始化特征列表
        hs = [self.conv_in(x)]
        # 遍历每个分辨率级别
        for i_level in range(self.num_resolutions):
            # 遍历每个残差块
            for i_block in range(self.num_res_blocks):
                # 通过下采样块处理当前特征图
                h = self.down[i_level].block[i_block](hs[-1], temb)
                # 如果当前级别有注意力层,则应用该层
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                # 将处理后的特征图添加到特征列表中
                hs.append(h)
            # 如果不是最后一个分辨率级别,则进行下采样
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # 处理最后的特征图
        h = hs[-1]
        # 通过中间块1处理特征图
        h = self.mid.block_1(h, temb)
        # h = self.mid.attn_1(h)  # 注意力层(已注释)
        # 通过中间块2处理特征图
        h = self.mid.block_2(h, temb)

        # 最终处理步骤
        h = self.norm_out(h)  # 归一化输出
        h = nonlinearity(h)  # 应用非线性激活函数
        h = self.conv_out(h)  # 最终卷积操作
        # 返回处理后的特征图
        return h

.\cogvideo-finetune\sat\sgm\modules\autoencoding\vqvae\movq_modules.py

# pytorch_diffusion + derived encoder decoder
# 导入数学库
import math
# 导入 PyTorch 库
import torch
# 导入 PyTorch 神经网络模块
import torch.nn as nn
# 导入 NumPy 库
import numpy as np


def get_timestep_embedding(timesteps, embedding_dim):
    """
    该函数实现了去噪扩散概率模型中的时间步嵌入构建。
    来自 Fairseq。
    构建正弦嵌入。
    该实现与 tensor2tensor 中的实现匹配,但与“Attention Is All You Need”第 3.5 节中的描述略有不同。
    """
    # 确保 timesteps 是一维张量
    assert len(timesteps.shape) == 1

    # 计算嵌入维度的一半
    half_dim = embedding_dim // 2
    # 计算嵌入的指数缩放因子
    emb = math.log(10000) / (half_dim - 1)
    # 创建半维度的指数衰减张量
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    # 将嵌入移动到与 timesteps 相同的设备上
    emb = emb.to(device=timesteps.device)
    # 计算时间步嵌入,使用广播机制
    emb = timesteps.float()[:, None] * emb[None, :]
    # 将正弦和余弦嵌入连接在一起
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    # 如果嵌入维度是奇数,则进行零填充
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    # 返回生成的嵌入
    return emb


def nonlinearity(x):
    # 定义非线性激活函数:swish
    return x * torch.sigmoid(x)


class SpatialNorm(nn.Module):
    # 定义空间归一化模块
    def __init__(
        self,
        f_channels,
        zq_channels,
        norm_layer=nn.GroupNorm,
        freeze_norm_layer=False,
        add_conv=False,
        **norm_layer_params,
    ):
        # 初始化父类
        super().__init__()
        # 创建归一化层
        self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
        # 如果需要冻结归一化层的参数
        if freeze_norm_layer:
            # 将归一化层的所有参数设置为不需要梯度
            for p in self.norm_layer.parameters:
                p.requires_grad = False
        # 是否添加卷积层
        self.add_conv = add_conv
        # 如果添加卷积层,定义卷积层
        if self.add_conv:
            self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
        # 定义用于处理 zq 的卷积层
        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
        # 定义另一个用于处理 zq 的卷积层
        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq):
        # 获取 f 的空间尺寸
        f_size = f.shape[-2:]
        # 将 zq 进行上采样以匹配 f 的尺寸
        zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
        # 如果需要添加卷积层,则对 zq 进行卷积处理
        if self.add_conv:
            zq = self.conv(zq)
        # 对 f 应用归一化层
        norm_f = self.norm_layer(f)
        # 计算新的 f,结合归一化后的 f 和 zq 的卷积结果
        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        # 返回新的 f
        return new_f


def Normalize(in_channels, zq_ch, add_conv):
    # 创建并返回一个 SpatialNorm 实例
    return SpatialNorm(
        in_channels,
        zq_ch,
        norm_layer=nn.GroupNorm,
        freeze_norm_layer=False,
        add_conv=add_conv,
        num_groups=32,
        eps=1e-6,
        affine=True,
    )


class Upsample(nn.Module):
    # 定义上采样模块
    def __init__(self, in_channels, with_conv):
        # 初始化父类
        super().__init__()
        # 是否使用卷积层
        self.with_conv = with_conv
        # 如果使用卷积层,定义卷积层
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # 将输入 x 上采样,放大两倍
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        # 如果使用卷积层,对上采样后的 x 进行卷积处理
        if self.with_conv:
            x = self.conv(x)
        # 返回处理后的 x
        return x


class Downsample(nn.Module):
    # 该类的定义未完成
    # 初始化方法,设置输入通道和是否使用卷积
        def __init__(self, in_channels, with_conv):
            # 调用父类初始化方法
            super().__init__()
            # 存储是否使用卷积的标志
            self.with_conv = with_conv
            # 如果选择使用卷积
            if self.with_conv:
                # 创建一个 2D 卷积层,卷积核大小为 3,步幅为 2,填充为 0
                # 注意:在 torch 卷积中没有非对称填充,必须手动处理
                self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
    
        # 前向传播方法,接收输入张量 x
        def forward(self, x):
            # 如果选择使用卷积
            if self.with_conv:
                # 定义填充参数,添加零填充以匹配卷积输入要求
                pad = (0, 1, 0, 1)
                # 对输入 x 进行常量填充
                x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
                # 通过卷积层处理输入 x
                x = self.conv(x)
            else:
                # 如果不使用卷积,执行平均池化操作,降低维度
                x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
            # 返回处理后的输出 x
            return x
# 定义一个残差块类,继承自 nn.Module
class ResnetBlock(nn.Module):
    # 初始化方法,接受多个参数配置残差块
    def __init__(
        self,
        *,
        in_channels,  # 输入通道数
        out_channels=None,  # 输出通道数,默认为 None
        conv_shortcut=False,  # 是否使用卷积短接
        dropout,  # dropout 的比率
        temb_channels=512,  # 时间嵌入的通道数,默认值为 512
        zq_ch=None,  # zq 的通道数,默认为 None
        add_conv=False,  # 是否添加卷积
    ):
        # 调用父类初始化方法
        super().__init__()
        self.in_channels = in_channels  # 设置输入通道数
        # 确定输出通道数,如果未提供,则等于输入通道数
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels  # 设置输出通道数
        self.use_conv_shortcut = conv_shortcut  # 设置是否使用卷积短接

        # 初始化归一化层
        self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
        # 初始化第一个卷积层
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 如果时间嵌入通道数大于 0,初始化时间嵌入投影层
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        # 初始化第二个归一化层
        self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
        # 初始化 dropout 层
        self.dropout = torch.nn.Dropout(dropout)
        # 初始化第二个卷积层
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 如果输入通道数与输出通道数不同,设置短接卷积
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                # 使用卷积短接
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                # 使用 1x1 卷积短接
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    # 前向传播方法,接受输入 x、时间嵌入 temb 和 zq
    def forward(self, x, temb, zq):
        h = x  # 将输入赋值给 h
        h = self.norm1(h, zq)  # 对 h 进行第一次归一化
        h = nonlinearity(h)  # 应用非线性激活函数
        h = self.conv1(h)  # 通过第一个卷积层

        # 如果时间嵌入存在,则进行相应处理
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]  # 加入时间嵌入投影

        h = self.norm2(h, zq)  # 对 h 进行第二次归一化
        h = nonlinearity(h)  # 应用非线性激活函数
        h = self.dropout(h)  # 应用 dropout
        h = self.conv2(h)  # 通过第二个卷积层

        # 如果输入和输出通道数不同,进行短接处理
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)  # 使用卷积短接
            else:
                x = self.nin_shortcut(x)  # 使用 1x1 卷积短接

        # 返回输入和 h 的和
        return x + h  # 残差连接


# 定义一个注意力块类,继承自 nn.Module
class AttnBlock(nn.Module):
    # 初始化方法,接受输入通道数和可选参数
    def __init__(self, in_channels, zq_ch=None, add_conv=False):
        # 调用父类初始化方法
        super().__init__()
        self.in_channels = in_channels  # 设置输入通道数

        # 初始化归一化层
        self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
        # 初始化查询卷积层
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 初始化键卷积层
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 初始化值卷积层
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 初始化输出卷积层
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
    # 定义前向传播函数,接受输入 x 和 zq
        def forward(self, x, zq):
            # 将输入 x 赋值给 h_
            h_ = x
            # 对 h_ 进行归一化处理,依据 zq
            h_ = self.norm(h_, zq)
            # 通过 q 层对 h_ 进行变换,得到查询 q
            q = self.q(h_)
            # 通过 k 层对 h_ 进行变换,得到键 k
            k = self.k(h_)
            # 通过 v 层对 h_ 进行变换,得到值 v
            v = self.v(h_)
    
            # 计算注意力
            # 获取 q 的形状,b 是批大小,c 是通道数,h 和 w 是高和宽
            b, c, h, w = q.shape
            # 将 q 重塑为 (b, c, h*w) 的形状
            q = q.reshape(b, c, h * w)
            # 变换 q 的维度顺序为 (b, hw, c)
            q = q.permute(0, 2, 1)  # b,hw,c
            # 将 k 重塑为 (b, c, h*w) 的形状
            k = k.reshape(b, c, h * w)  # b,c,hw
            # 计算 q 和 k 的乘积,得到注意力权重 w_
            w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
            # 对权重进行缩放
            w_ = w_ * (int(c) ** (-0.5))
            # 对权重应用 softmax 函数,进行归一化
            w_ = torch.nn.functional.softmax(w_, dim=2)
    
            # 对值进行注意力计算
            # 将 v 重塑为 (b, c, h*w) 的形状
            v = v.reshape(b, c, h * w)
            # 变换 w_ 的维度顺序为 (b, hw, hw)
            w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
            # 计算 v 和 w_ 的乘积,得到最终的 h_
            h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
            # 将 h_ 重塑为 (b, c, h, w) 的形状
            h_ = h_.reshape(b, c, h, w)
    
            # 对 h_ 进行投影处理
            h_ = self.proj_out(h_)
    
            # 返回输入 x 和 h_ 的和
            return x + h_
# 定义一个名为 MOVQDecoder 的类,继承自 nn.Module
class MOVQDecoder(nn.Module):
    # 初始化方法,用于创建 MOVQDecoder 的实例
    def __init__(
        # 接受关键字参数 ch,表示输入通道数
        *,
        ch,
        # 接受关键字参数 out_ch,表示输出通道数
        out_ch,
        # 接受关键字参数 ch_mult,表示通道倍增的因子,默认为 (1, 2, 4, 8)
        ch_mult=(1, 2, 4, 8),
        # 接受关键字参数 num_res_blocks,表示残差块的数量
        num_res_blocks,
        # 接受关键字参数 attn_resolutions,表示注意力分辨率
        attn_resolutions,
        # 接受关键字参数 dropout,表示 dropout 的比例,默认为 0.0
        dropout=0.0,
        # 接受关键字参数 resamp_with_conv,表示是否使用卷积进行重采样,默认为 True
        resamp_with_conv=True,
        # 接受关键字参数 in_channels,表示输入数据的通道数
        in_channels,
        # 接受关键字参数 resolution,表示输入数据的分辨率
        resolution,
        # 接受关键字参数 z_channels,表示潜在空间的通道数
        z_channels,
        # 接受关键字参数 give_pre_end,表示是否提供前置结束标志,默认为 False
        give_pre_end=False,
        # 接受关键字参数 zq_ch,表示潜在空间的量化通道数,默认为 None
        zq_ch=None,
        # 接受关键字参数 add_conv,表示是否添加额外的卷积层,默认为 False
        add_conv=False,
        # 接受其他未指定的关键字参数,使用 **ignorekwargs 收集
        **ignorekwargs,
    # 定义构造函数的结束部分
        ):
            # 调用父类构造函数
            super().__init__()
            # 存储输入参数 ch
            self.ch = ch
            # 初始化 temb_ch 为 0
            self.temb_ch = 0
            # 计算分辨率的数量
            self.num_resolutions = len(ch_mult)
            # 存储残差块的数量
            self.num_res_blocks = num_res_blocks
            # 存储分辨率
            self.resolution = resolution
            # 存储输入通道数
            self.in_channels = in_channels
            # 存储是否给出预处理结束标志
            self.give_pre_end = give_pre_end
    
            # 计算输入通道数乘法,块输入和当前分辨率
            in_ch_mult = (1,) + tuple(ch_mult)
            # 计算当前块的输入通道数
            block_in = ch * ch_mult[self.num_resolutions - 1]
            # 计算当前分辨率
            curr_res = resolution // 2 ** (self.num_resolutions - 1)
            # 定义 z_shape,表示 z 的形状
            self.z_shape = (1, z_channels, curr_res, curr_res)
            # 打印 z 的形状信息
            print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
    
            # 定义从 z 到块输入的卷积层
            self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
    
            # 创建中间模块
            self.mid = nn.Module()
            # 定义中间块 1
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                zq_ch=zq_ch,
                add_conv=add_conv,
            )
            # 定义中间注意力块 1
            self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
            # 定义中间块 2
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in,
                out_channels=block_in,
                temb_channels=self.temb_ch,
                dropout=dropout,
                zq_ch=zq_ch,
                add_conv=add_conv,
            )
    
            # 初始化上采样模块列表
            self.up = nn.ModuleList()
            # 遍历分辨率,从高到低
            for i_level in reversed(range(self.num_resolutions)):
                # 初始化块和注意力模块列表
                block = nn.ModuleList()
                attn = nn.ModuleList()
                # 计算当前输出块的通道数
                block_out = ch * ch_mult[i_level]
                # 遍历每个残差块
                for i_block in range(self.num_res_blocks + 1):
                    # 添加残差块到块列表
                    block.append(
                        ResnetBlock(
                            in_channels=block_in,
                            out_channels=block_out,
                            temb_channels=self.temb_ch,
                            dropout=dropout,
                            zq_ch=zq_ch,
                            add_conv=add_conv,
                        )
                    )
                    # 更新块输入通道数
                    block_in = block_out
                    # 如果当前分辨率需要注意力模块,添加注意力模块
                    if curr_res in attn_resolutions:
                        attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
                # 创建上采样模块
                up = nn.Module()
                # 存储块和注意力模块
                up.block = block
                up.attn = attn
                # 如果不是最低分辨率,添加上采样层
                if i_level != 0:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                    # 更新当前分辨率
                    curr_res = curr_res * 2
                # 将上采样模块插入到列表前面,确保顺序一致
                self.up.insert(0, up)  # prepend to get consistent order
    
            # 创建输出的归一化层
            self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
            # 定义输出的卷积层
            self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
    # 前向传播方法,接受输入 z 和条件 zq
    def forward(self, z, zq):
        # 断言 z 的形状与预期形状相符(已注释掉)
        # assert z.shape[1:] == self.z_shape[1:]
        # 保存输入 z 的形状
        self.last_z_shape = z.shape
    
        # 时间步嵌入初始化
        temb = None
    
        # 将 z 输入到卷积层
        h = self.conv_in(z)
    
        # 中间处理层
        h = self.mid.block_1(h, temb, zq)
        h = self.mid.attn_1(h, zq)
        h = self.mid.block_2(h, temb, zq)
    
        # 上采样过程
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, zq)
                # 如果当前层有注意力模块,则应用注意力模块
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, zq)
            # 如果不是最后一层,则进行上采样
            if i_level != 0:
                h = self.up[i_level].upsample(h)
    
        # 结束处理,条件性返回结果
        if self.give_pre_end:
            return h
    
        # 输出层归一化处理
        h = self.norm_out(h, zq)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 最后卷积层处理
        h = self.conv_out(h)
        return h
    
    # 带特征输出的前向传播方法,接受输入 z 和条件 zq
    def forward_with_features_output(self, z, zq):
        # 断言 z 的形状与预期形状相符(已注释掉)
        # assert z.shape[1:] == self.z_shape[1:]
        # 保存输入 z 的形状
        self.last_z_shape = z.shape
    
        # 时间步嵌入初始化
        temb = None
        output_features = {}
    
        # 将 z 输入到卷积层
        h = self.conv_in(z)
        # 保存卷积层输出特征
        output_features["conv_in"] = h
    
        # 中间处理层
        h = self.mid.block_1(h, temb, zq)
        # 保存中间块 1 的输出特征
        output_features["mid_block_1"] = h
        h = self.mid.attn_1(h, zq)
        # 保存中间注意力 1 的输出特征
        output_features["mid_attn_1"] = h
        h = self.mid.block_2(h, temb, zq)
        # 保存中间块 2 的输出特征
        output_features["mid_block_2"] = h
    
        # 上采样过程
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb, zq)
                # 保存每个上采样块的输出特征
                output_features[f"up_{i_level}_block_{i_block}"] = h
                # 如果当前层有注意力模块,则应用注意力模块并保存特征
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, zq)
                    output_features[f"up_{i_level}_attn_{i_block}"] = h
            # 如果不是最后一层,则进行上采样并保存特征
            if i_level != 0:
                h = self.up[i_level].upsample(h)
                output_features[f"up_{i_level}_upsample"] = h
    
        # 结束处理,条件性返回结果
        if self.give_pre_end:
            return h
    
        # 输出层归一化处理
        h = self.norm_out(h, zq)
        # 保存归一化后的特征
        output_features["norm_out"] = h
        # 应用非线性激活函数并保存特征
        h = nonlinearity(h)
        output_features["nonlinearity"] = h
        # 最后卷积层处理并保存特征
        h = self.conv_out(h)
        output_features["conv_out"] = h
    
        return h, output_features

.\cogvideo-finetune\sat\sgm\modules\autoencoding\vqvae\quantize.py

# 导入 PyTorch 和其他必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange

# 定义一个改进版的向量量化器类,继承自 nn.Module
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # 初始化方法,接受多个参数用于配置向量量化器
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
        # 调用父类构造函数
        super().__init__()
        # 设置向量数量
        self.n_e = n_e
        # 设置嵌入维度
        self.e_dim = e_dim
        # 设置 beta 值
        self.beta = beta
        # 设置是否使用旧版兼容性
        self.legacy = legacy

        # 创建一个嵌入层,尺寸为 (n_e, e_dim)
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        # 将嵌入权重初始化为均匀分布
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        # 如果提供了 remap 参数
        self.remap = remap
        if self.remap is not None:
            # 从文件加载已使用的索引并注册为缓冲区
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            # 设置重新嵌入的大小
            self.re_embed = self.used.shape[0]
            # 设置未知索引,默认为 "random"
            self.unknown_index = unknown_index  # "random" 或 "extra" 或整数
            if self.unknown_index == "extra":
                # 如果未知索引是 "extra",调整索引值
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            # 打印重映射信息
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            # 如果没有重映射,则重新嵌入大小等于 n_e
            self.re_embed = n_e

        # 设置是否使用合理的索引形状
        self.sane_index_shape = sane_index_shape

    # 将索引映射到已使用的索引
    def remap_to_used(self, inds):
        # 记录输入索引的形状
        ishape = inds.shape
        # 确保输入至少有两个维度
        assert len(ishape) > 1
        # 将索引重塑为 (批量大小, -1) 的形状
        inds = inds.reshape(ishape[0], -1)
        # 将已使用的索引移到当前设备
        used = self.used.to(inds)
        # 检查 inds 是否在 used 中匹配
        match = (inds[:, :, None] == used[None, None, ...]).long()
        # 找到匹配的最大值索引
        new = match.argmax(-1)
        # 查找未知索引
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            # 随机生成未知索引
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            # 用指定的未知索引填充未知位置
            new[unknown] = self.unknown_index
        # 返回重塑后的新索引
        return new.reshape(ishape)

    # 将已使用的索引映射回所有索引
    def unmap_to_all(self, inds):
        # 记录输入索引的形状
        ishape = inds.shape
        # 确保输入至少有两个维度
        assert len(ishape) > 1
        # 将索引重塑为 (批量大小, -1) 的形状
        inds = inds.reshape(ishape[0], -1)
        # 将已使用的索引移到当前设备
        used = self.used.to(inds)
        # 如果重新嵌入的大小大于已使用的索引数量,处理额外的标记
        if self.re_embed > self.used.shape[0]:  # extra token
            # 将超出范围的索引设为零
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        # 使用索引反向收集所有标记
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        # 返回重塑后的标记
        return back.reshape(ishape)
    # 前向传播函数,处理输入 z,选择温度和日志缩放参数
    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        # 确保温度参数为 None 或 1.0,适用于 Gumbel 接口
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        # 确保日志缩放参数为 False,适用于 Gumbel 接口
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        # 确保返回日志参数为 False,适用于 Gumbel 接口
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # 将 z 变形为 (batch, height, width, channel) 并展平
        z = rearrange(z, "b c h w -> b h w c").contiguous()
        # 将 z 展平为二维张量,形状为 (batch_size, embedding_dim)
        z_flattened = z.view(-1, self.e_dim)
        # 计算 z 到嵌入 e_j 的距离 (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (
            # 计算 z_flattened 的平方和,保留维度
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            # 加上嵌入权重的平方和
            + torch.sum(self.embedding.weight**2, dim=1)
            # 减去 2 * z_flattened 和嵌入权重的内积
            - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
        )
        # 找到距离最近的编码索引
        min_encoding_indices = torch.argmin(d, dim=1)
        # 根据最小编码索引获取量化的 z,并调整为原始形状
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None  # 初始化困惑度为 None
        min_encodings = None  # 初始化最小编码为 None

        # 计算嵌入的损失
        if not self.legacy:
            # 计算损失,考虑 beta 和两个均方差项
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            # 计算损失,考虑 beta 和两个均方差项(顺序不同)
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # 保持梯度
        z_q = z + (z_q - z).detach()

        # 重新调整 z_q 的形状以匹配原始输入形状
        z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()

        if self.remap is not None:
            # 如果需要重映射,将最小编码索引展平并添加批次维度
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            # 使用重映射函数
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            # 再次展平
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            # 确保最小编码索引形状合理,重新调整形状
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        # 返回量化的 z、损失及其他信息
        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    # 根据索引获取代码簿条目,返回形状为 (batch, height, width, channel)
    def get_codebook_entry(self, indices, shape):
        # shape 指定 (batch, height, width, channel)
        if self.remap is not None:
            # 如果需要重映射,展平索引并添加批次维度
            indices = indices.reshape(shape[0], -1)  # add batch axis
            # 使用反重映射函数
            indices = self.unmap_to_all(indices)
            # 再次展平
            indices = indices.reshape(-1)  # flatten again

        # 获取量化的潜在向量
        z_q = self.embedding(indices)

        if shape is not None:
            # 如果形状不为 None,重新调整 z_q 的形状
            z_q = z_q.view(shape)
            # 重新调整以匹配原始输入形状
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        # 返回量化后的潜在向量
        return z_q
# 定义 GumbelQuantize 类,继承自 nn.Module
class GumbelQuantize(nn.Module):
    """
    归功于 @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (谢谢!)
    Gumbel Softmax 技巧量化器
    使用 Gumbel-Softmax 的分类重参数化,Jang 等人 2016
    https://arxiv.org/abs/1611.01144
    """

    # 初始化方法,定义类的属性
    def __init__(
        self,
        num_hiddens,  # 隐藏层的神经元数量
        embedding_dim,  # 嵌入维度
        n_embed,  # 嵌入的数量
        straight_through=True,  # 是否使用直通估计
        kl_weight=5e-4,  # KL 散度的权重
        temp_init=1.0,  # 初始温度
        use_vqinterface=True,  # 是否使用 VQ 接口
        remap=None,  # 重新映射参数
        unknown_index="random",  # 未知索引的处理方式
    ):
        super().__init__()  # 调用父类初始化

        self.embedding_dim = embedding_dim  # 设置嵌入维度
        self.n_embed = n_embed  # 设置嵌入数量

        self.straight_through = straight_through  # 保存直通估计的状态
        self.temperature = temp_init  # 设置温度
        self.kl_weight = kl_weight  # 设置 KL 权重

        # 定义卷积层,将隐藏层映射到嵌入空间
        self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
        # 定义嵌入层,将索引映射到嵌入向量
        self.embed = nn.Embedding(n_embed, embedding_dim)

        self.use_vqinterface = use_vqinterface  # 保存 VQ 接口使用状态

        self.remap = remap  # 保存重新映射参数
        if self.remap is not None:  # 如果存在重新映射
            # 注册用于存储重新映射的张量
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]  # 重新映射的嵌入数量
            self.unknown_index = unknown_index  # 保存未知索引
            if self.unknown_index == "extra":  # 如果未知索引为“extra”
                self.unknown_index = self.re_embed  # 设置未知索引为重新映射数量
                self.re_embed = self.re_embed + 1  # 增加重新映射数量
            # 打印重新映射的信息
            print(
                f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_embed  # 否则,重新映射数量等于嵌入数量

    # 将索引重新映射到已使用的索引
    def remap_to_used(self, inds):
        ishape = inds.shape  # 获取输入的形状
        assert len(ishape) > 1  # 确保输入维度大于1
        inds = inds.reshape(ishape[0], -1)  # 将索引重塑为二维形状
        used = self.used.to(inds)  # 将已使用的张量移动到索引的设备上
        match = (inds[:, :, None] == used[None, None, ...]).long()  # 计算匹配矩阵
        new = match.argmax(-1)  # 获取匹配的最大值索引
        unknown = match.sum(2) < 1  # 检查未知索引
        if self.unknown_index == "random":  # 如果未知索引为随机
            # 随机生成未知索引
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index  # 否则设置为未知索引
        return new.reshape(ishape)  # 返回重塑后的新索引

    # 将索引映射回所有索引
    def unmap_to_all(self, inds):
        ishape = inds.shape  # 获取输入的形状
        assert len(ishape) > 1  # 确保输入维度大于1
        inds = inds.reshape(ishape[0], -1)  # 将索引重塑为二维形状
        used = self.used.to(inds)  # 将已使用的张量移动到索引的设备上
        if self.re_embed > self.used.shape[0]:  # 如果有额外的标记
            inds[inds >= self.used.shape[0]] = 0  # 将超过范围的索引设置为零
        # 根据输入索引从已使用张量中收集值
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)  # 返回重塑后的结果
    # 定义前向传播方法,接受潜在变量 z、温度 temp 和返回 logits 的标志
    def forward(self, z, temp=None, return_logits=False):
        # 在评估模式下强制硬性为 True,因为必须进行量化。实际上,总是设为 True 似乎也可以
        hard = self.straight_through if self.training else True
        # 如果未提供温度,则使用类的温度属性
        temp = self.temperature if temp is None else temp
    
        # 将输入 z 投影到 logits 空间
        logits = self.proj(z)
        if self.remap is not None:
            # 仅继续使用的 logits
            full_zeros = torch.zeros_like(logits)  # 创建与 logits 同形状的零张量
            logits = logits[:, self.used, ...]  # 只保留使用的 logits
    
        # 使用 Gumbel-softmax 生成软 one-hot 编码
        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        if self.remap is not None:
            # 返回到所有条目,但未使用的设置为零
            full_zeros[:, self.used, ...] = soft_one_hot  # 将使用的编码放入全零张量中
            soft_one_hot = full_zeros  # 更新为全零张量
    
        # 计算量化后的表示
        z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
    
        # 加上对先验损失的 KL 散度
        qy = F.softmax(logits, dim=1)  # 对 logits 进行 softmax 操作
        # 计算 KL 散度
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
    
        # 找到软 one-hot 编码中最大值的索引
        ind = soft_one_hot.argmax(dim=1)
        if self.remap is not None:
            ind = self.remap_to_used(ind)  # 进行重映射
        if self.use_vqinterface:
            if return_logits:
                # 如果需要返回 logits,则返回量化后的表示、KL 散度和索引
                return z_q, diff, (None, None, ind), logits
            # 如果不需要返回 logits,则返回量化后的表示、KL 散度和索引
            return z_q, diff, (None, None, ind)
        # 返回量化后的表示、KL 散度和索引
        return z_q, diff, ind
    
    # 定义获取代码本条目的方法,接受索引和形状
    def get_codebook_entry(self, indices, shape):
        b, h, w, c = shape  # 解构形状信息
        assert b * h * w == indices.shape[0]  # 确保索引数量与形状一致
        indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)  # 重新排列索引形状
        if self.remap is not None:
            indices = self.unmap_to_all(indices)  # 如果需要,进行反映射
        # 创建 one-hot 编码并调整维度顺序
        one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
        # 计算量化后的表示
        z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
        return z_q  # 返回量化后的表示

.\cogvideo-finetune\sat\sgm\modules\autoencoding\vqvae\vqvae_blocks.py

# pytorch_diffusion + derived encoder decoder
import math  # 导入数学库,用于数学计算
import torch  # 导入 PyTorch 库,用于深度学习
import torch.nn as nn  # 导入 PyTorch 的神经网络模块
import numpy as np  # 导入 NumPy 库,用于数组处理

def get_timestep_embedding(timesteps, embedding_dim):
    """
    该函数实现了 Denoising Diffusion Probabilistic Models 中的嵌入构建
    从 Fairseq。
    构建正弦波嵌入。
    该实现与 tensor2tensor 中的实现匹配,但与 "Attention Is All You Need" 的
    第 3.5 节中的描述略有不同。
    """
    # 确保 timesteps 是一维数组
    assert len(timesteps.shape) == 1

    # 计算嵌入维度的一半
    half_dim = embedding_dim // 2
    # 计算用于正弦和余弦的频率
    emb = math.log(10000) / (half_dim - 1)
    # 生成正弦和余弦嵌入所需的频率向量
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    # 将频率向量移动到与 timesteps 相同的设备
    emb = emb.to(device=timesteps.device)
    # 根据时间步生成最终的嵌入
    emb = timesteps.float()[:, None] * emb[None, :]
    # 将正弦和余弦嵌入合并
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    # 如果嵌入维度是奇数,进行零填充
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    # 返回生成的嵌入
    return emb

def nonlinearity(x):
    # 定义 Swish 非线性激活函数
    return x * torch.sigmoid(x)  # 返回 Swish 激活值

def Normalize(in_channels):
    # 定义归一化层,使用 GroupNorm
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        # 初始化 Upsample 类
        super().__init__()  # 调用父类构造函数
        self.with_conv = with_conv  # 根据参数决定是否使用卷积层
        if self.with_conv:
            # 如果使用卷积,则定义卷积层
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # 定义前向传播方法
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")  # 进行上采样
        if self.with_conv:
            x = self.conv(x)  # 如果需要,经过卷积层处理
        return x  # 返回处理后的张量

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        # 初始化 Downsample 类
        super().__init__()  # 调用父类构造函数
        self.with_conv = with_conv  # 根据参数决定是否使用卷积层
        if self.with_conv:
            # 定义卷积层,步幅为 2,填充为 0
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # 定义前向传播方法
        if self.with_conv:
            pad = (0, 1, 0, 1)  # 定义填充参数
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)  # 对输入张量进行填充
            x = self.conv(x)  # 经过卷积层处理
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)  # 进行平均池化
        return x  # 返回处理后的张量

class ResnetBlock(nn.Module):
    # 初始化方法,设置输入输出通道、卷积快捷连接、丢弃率和时间嵌入通道数
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        # 调用父类的初始化方法
        super().__init__()
        # 设置输入通道数
        self.in_channels = in_channels
        # 如果未提供输出通道数,则设置为输入通道数
        out_channels = in_channels if out_channels is None else out_channels
        # 设置输出通道数
        self.out_channels = out_channels
        # 设置是否使用卷积快捷连接的标志
        self.use_conv_shortcut = conv_shortcut

        # 初始化归一化层,处理输入通道
        self.norm1 = Normalize(in_channels)
        # 初始化第一层卷积,输入输出通道数相应,使用3x3卷积核
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 如果时间嵌入通道数大于0,则初始化时间嵌入投影层
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        # 初始化第二层归一化,处理输出通道
        self.norm2 = Normalize(out_channels)
        # 初始化丢弃层,设置丢弃率
        self.dropout = torch.nn.Dropout(dropout)
        # 初始化第二层卷积,输入输出通道数相应,使用3x3卷积核
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 如果输入通道数与输出通道数不相同
        if self.in_channels != self.out_channels:
            # 根据是否使用卷积快捷连接,初始化相应的卷积层
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    # 前向传播方法,处理输入数据和时间嵌入
    def forward(self, x, temb):
        # 将输入赋值给中间变量
        h = x
        # 对输入进行归一化处理
        h = self.norm1(h)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 经过第一层卷积
        h = self.conv1(h)

        # 如果时间嵌入不为 None,则将其添加到中间变量中
        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        # 对中间变量进行第二次归一化
        h = self.norm2(h)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 应用丢弃层
        h = self.dropout(h)
        # 经过第二层卷积
        h = self.conv2(h)

        # 如果输入通道数与输出通道数不相同
        if self.in_channels != self.out_channels:
            # 根据是否使用卷积快捷连接处理输入
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        # 返回输入和中间变量的和
        return x + h
# 定义注意力模块类,继承自 nn.Module
class AttnBlock(nn.Module):
    # 初始化方法,接受输入通道数
    def __init__(self, in_channels):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入通道数
        self.in_channels = in_channels

        # 初始化归一化层
        self.norm = Normalize(in_channels)
        # 初始化用于生成查询(q)的卷积层
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 初始化用于生成键(k)的卷积层
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 初始化用于生成值(v)的卷积层
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 初始化输出投影的卷积层
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    # 前向传播方法
    def forward(self, x):
        # 将输入赋值给 h_
        h_ = x
        # 对输入进行归一化处理
        h_ = self.norm(h_)
        # 通过查询卷积层得到查询向量 q
        q = self.q(h_)
        # 通过键卷积层得到键向量 k
        k = self.k(h_)
        # 通过值卷积层得到值向量 v
        v = self.v(h_)

        # 计算注意力
        # 获取查询向量的形状参数
        b, c, h, w = q.shape
        # 将 q 进行重塑,改变形状为 (b, c, h*w)
        q = q.reshape(b, c, h * w)
        # 变换 q 的维度顺序,变为 (b, hw, c)
        q = q.permute(0, 2, 1)  # b,hw,c
        # 将 k 进行重塑,改变形状为 (b, c, hw)
        k = k.reshape(b, c, h * w)  # b,c,hw

        # # 原始版本,在 fp16 中会出现 nan
        # w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        # w_ = w_ * (int(c)**(-0.5))
        # # 实现 c**-0.5 在 q 上
        # 将查询向量 q 乘以 c 的倒数平方根
        q = q * (int(c) ** (-0.5))
        # 计算权重 w_,使用批量矩阵相乘
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]

        # 对权重 w_ 进行 softmax 归一化
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # 根据值向量进行加权
        # 将值向量 v 进行重塑,改变形状为 (b, c, h*w)
        v = v.reshape(b, c, h * w)
        # 变换 w_ 的维度顺序,变为 (b, hw, hw) 
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        # 通过批量矩阵相乘计算加权后的结果 h_
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        # 将 h_ 进行重塑,改变形状回到 (b, c, h, w)
        h_ = h_.reshape(b, c, h, w)

        # 通过输出投影层生成最终结果
        h_ = self.proj_out(h_)

        # 返回输入和经过注意力模块处理后的结果相加
        return x + h_


# 定义编码器类,继承自 nn.Module
class Encoder(nn.Module):
    # 初始化方法,接收多个参数
    def __init__(
        self,
        *,
        ch,  # 通道数
        out_ch,  # 输出通道数
        ch_mult=(1, 2, 4, 8),  # 通道数的倍率
        num_res_blocks,  # 残差块的数量
        attn_resolutions,  # 注意力计算的分辨率
        dropout=0.0,  # dropout 概率
        resamp_with_conv=True,  # 是否使用卷积进行重采样
        in_channels,  # 输入通道数
        resolution,  # 输入分辨率
        z_channels,  # z 维度通道数
        double_z=True,  # 是否使用双 z
        **ignore_kwargs,  # 其他忽略的关键字参数
    # 定义类的初始化方法
        ):
            # 调用父类的初始化方法
            super().__init__()
            # 初始化通道数
            self.ch = ch
            # 初始化时间嵌入通道数
            self.temb_ch = 0
            # 获取分辨率数量
            self.num_resolutions = len(ch_mult)
            # 获取残差块数量
            self.num_res_blocks = num_res_blocks
            # 设置分辨率
            self.resolution = resolution
            # 输入通道数
            self.in_channels = in_channels
    
            # downsampling
            # 初始化卷积层,输入通道为in_channels,输出通道为self.ch
            self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
    
            # 当前分辨率
            curr_res = resolution
            # 设置输入通道的倍数
            in_ch_mult = (1,) + tuple(ch_mult)
            # 初始化一个模块列表来存储下采样模块
            self.down = nn.ModuleList()
            # 遍历每个分辨率级别
            for i_level in range(self.num_resolutions):
                # 初始化块和注意力模块的模块列表
                block = nn.ModuleList()
                attn = nn.ModuleList()
                # 计算当前块的输入和输出通道数
                block_in = ch * in_ch_mult[i_level]
                block_out = ch * ch_mult[i_level]
                # 遍历每个残差块
                for i_block in range(self.num_res_blocks):
                    # 添加残差块到块列表
                    block.append(
                        ResnetBlock(
                            in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                        )
                    )
                    # 更新输入通道数为输出通道数
                    block_in = block_out
                    # 如果当前分辨率在注意力分辨率中,添加注意力块
                    if curr_res in attn_resolutions:
                        attn.append(AttnBlock(block_in))
                # 创建下采样模块
                down = nn.Module()
                down.block = block
                down.attn = attn
                # 如果不是最后一个分辨率级别,添加下采样层
                if i_level != self.num_resolutions - 1:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                    # 将当前分辨率减半
                    curr_res = curr_res // 2
                # 将下采样模块添加到列表中
                self.down.append(down)
    
            # middle
            # 创建中间模块
            self.mid = nn.Module()
            # 添加第一个残差块
            self.mid.block_1 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
            )
            # 添加中间注意力块
            self.mid.attn_1 = AttnBlock(block_in)
            # 添加第二个残差块
            self.mid.block_2 = ResnetBlock(
                in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
            )
    
            # end
            # 规范化输出
            self.norm_out = Normalize(block_in)
            # 初始化输出卷积层,输出通道为2 * z_channels或z_channels
            self.conv_out = torch.nn.Conv2d(
                block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
            )
    
        # 定义前向传播方法
        def forward(self, x):
            # 确保输入张量的宽和高等于设定的分辨率,若不匹配则抛出异常
            # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
    
            # timestep embedding
            # 初始化时间嵌入
            temb = None
    
            # downsampling
            # 通过输入数据进行卷积操作
            hs = [self.conv_in(x)]
            # 遍历每个分辨率级别
            for i_level in range(self.num_resolutions):
                # 遍历每个残差块
                for i_block in range(self.num_res_blocks):
                    # 通过残差块处理前一层的输出
                    h = self.down[i_level].block[i_block](hs[-1], temb)
                    # 如果存在注意力模块,则进行注意力处理
                    if len(self.down[i_level].attn) > 0:
                        h = self.down[i_level].attn[i_block](h)
                    # 将当前层输出添加到列表中
                    hs.append(h)
                # 如果不是最后一个分辨率级别,则进行下采样
                if i_level != self.num_resolutions - 1:
                    hs.append(self.down[i_level].downsample(hs[-1]))
    
            # middle
            # 获取最后一层的输出
            h = hs[-1]
            # 通过中间块进行处理
            h = self.mid.block_1(h, temb)
            # 进行注意力处理
            h = self.mid.attn_1(h)
            # 通过第二个中间块进行处理
            h = self.mid.block_2(h, temb)
    
            # end
            # 规范化处理
            h = self.norm_out(h)
            # 应用非线性激活函数
            h = nonlinearity(h)
            # 通过输出卷积层得到最终输出
            h = self.conv_out(h)
            # 返回处理后的输出
            return h
    # 定义一个前向传播函数,输出特征
    def forward_with_features_output(self, x):
        # 断言输入张量的高和宽等于预设的分辨率
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # 时间步嵌入初始化为空
        temb = None
        # 用于存储各层输出特征的字典
        output_features = {}

        # 下采样阶段,首先通过输入卷积层处理输入 x
        hs = [self.conv_in(x)]
        # 将输入卷积的输出保存到输出特征字典中
        output_features["conv_in"] = hs[-1]
        # 遍历每个分辨率层级
        for i_level in range(self.num_resolutions):
            # 遍历每个残差块
            for i_block in range(self.num_res_blocks):
                # 通过当前层级的当前块进行处理
                h = self.down[i_level].block[i_block](hs[-1], temb)
                # 将当前块的输出保存到输出特征字典中
                output_features["down{}_block{}".format(i_level, i_block)] = h
                # 如果当前层级有注意力机制,应用注意力机制
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                    # 将注意力机制的输出保存到输出特征字典中
                    output_features["down{}_attn{}".format(i_level, i_block)] = h
                # 将当前块的输出加入历史输出列表
                hs.append(h)
            # 如果不是最后一个分辨率层级,进行下采样
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))
                # 将下采样的输出保存到输出特征字典中
                output_features["down{}_downsample".format(i_level)] = hs[-1]

        # 中间层处理
        h = hs[-1]
        # 通过中间块1进行处理
        h = self.mid.block_1(h, temb)
        # 将中间块1的输出保存到输出特征字典中
        output_features["mid_block_1"] = h
        # 应用中间层的注意力机制
        h = self.mid.attn_1(h)
        # 将中间层注意力机制的输出保存到输出特征字典中
        output_features["mid_attn_1"] = h
        # 通过中间块2进行处理
        h = self.mid.block_2(h, temb)
        # 将中间块2的输出保存到输出特征字典中
        output_features["mid_block_2"] = h

        # 结束处理阶段
        h = self.norm_out(h)  # 进行归一化处理
        output_features["norm_out"] = h  # 保存归一化输出
        h = nonlinearity(h)  # 应用非线性激活函数
        output_features["nonlinearity"] = h  # 保存非线性输出
        h = self.conv_out(h)  # 通过输出卷积层处理
        output_features["conv_out"] = h  # 保存输出卷积的结果

        # 返回最终输出和特征字典
        return h, output_features
# 定义一个名为Decoder的类,继承自nn.Module
class Decoder(nn.Module):
    # 初始化函数,接收一系列参数
    def __init__(
        self,
        *,
        ch,  # 输入通道数
        out_ch,  # 输出通道数
        ch_mult=(1, 2, 4, 8),  # 通道数的倍数
        num_res_blocks,  # 残差块的数量
        attn_resolutions,  # 注意力机制的分辨率
        dropout=0.0,  # dropout的比例
        resamp_with_conv=True,  # 是否使用卷积进行重采样
        in_channels,  # 输入的通道数
        resolution,  # 分辨率
        z_channels,  # z的通道数
        give_pre_end=False,  # 是否给出预处理结果
        **ignorekwargs,  # 忽略的关键字参数
    ):
        super().__init__()
        # 将参数赋值给类的属性
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # 计算最低分辨率下的in_ch_mult、block_in和curr_res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # 打印z的形状
        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # 将z转换为block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # 中间部分
        self.mid = nn.Module()
        # 中间部分的第一个残差块
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        # 中间部分的注意力机制
        self.mid.attn_1 = AttnBlock(block_in)
        # 中间部分的第二个残差块
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # 上采样部分
        self.up = nn.ModuleList()
        # 从最高分辨率开始逆序遍历
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            # 根据残差块的数量创建残差块和注意力机制
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                # 如果当前分辨率在attn_resolutions中,则添加注意力机制
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            # 如果不是最低分辨率,则添加上采样层
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            # 将up插入到up列表的开头,以保持一致的顺序
            self.up.insert(0, up)

        # 结束部分
        # 归一化层
        self.norm_out = Normalize(block_in)
        # 输出卷积层
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
    # 定义前向传播函数,接收输入 z
    def forward(self, z):
        # 检查 z 的形状是否与期望的 z_shape 相匹配(被注释掉的断言)
        # assert z.shape[1:] == self.z_shape[1:]
        # 保存当前输入 z 的形状
        self.last_z_shape = z.shape

        # 初始化时间步嵌入为 None
        temb = None

        # 将输入 z 通过第一层卷积层进行处理,得到 block_in
        h = self.conv_in(z)

        # 中间处理阶段
        # 将 h 传入第一块中间模块,结合时间步嵌入 temb
        h = self.mid.block_1(h, temb)
        # 经过第一个注意力层
        h = self.mid.attn_1(h)
        # 将 h 传入第二块中间模块,结合时间步嵌入 temb
        h = self.mid.block_2(h, temb)

        # 上采样阶段
        # 从高到低的分辨率进行迭代
        for i_level in reversed(range(self.num_resolutions)):
            # 在每个分辨率下遍历所有的块
            for i_block in range(self.num_res_blocks + 1):
                # 将 h 传入当前上采样块进行处理,结合时间步嵌入 temb
                h = self.up[i_level].block[i_block](h, temb)
                # 如果当前块有注意力层,进行注意力处理
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            # 如果不是最后一层,进行上采样
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # 结束阶段
        # 如果给定了预结束标志,直接返回 h
        if self.give_pre_end:
            return h

        # 对 h 进行归一化处理
        h = self.norm_out(h)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 通过输出卷积层得到最终结果
        h = self.conv_out(h)
        # 返回处理后的结果 h
        return h

.\cogvideo-finetune\sat\sgm\modules\autoencoding\__init__.py


.\cogvideo-finetune\sat\sgm\modules\cp_enc_dec.py

# 导入数学库
import math
# 导入 PyTorch 库
import torch
# 导入 PyTorch 分布式计算库
import torch.distributed
# 导入 PyTorch 神经网络模块
import torch.nn as nn
# 从上级模块导入实用函数
from ..util import (
    get_context_parallel_group,  # 获取并行计算组
    get_context_parallel_rank,   # 获取当前并行计算的排名
    get_context_parallel_world_size,  # 获取并行计算的世界大小
)

# 设置使用的计算模式为 CP (Context Parallel)
_USE_CP = True

# 将输入转换为元组,确保长度为指定值
def cast_tuple(t, length=1):
    # 如果 t 已经是元组,直接返回;否则返回重复的元组
    return t if isinstance(t, tuple) else ((t,) * length)

# 检查 num 是否可以被 den 整除
def divisible_by(num, den):
    # 返回 num 除以 den 的余数是否为 0
    return (num % den) == 0

# 检查 n 是否为奇数
def is_odd(n):
    # 返回 n 除以 2 的结果是否为偶数的相反值
    return not divisible_by(n, 2)

# 检查值 v 是否存在(不为 None)
def exists(v):
    # 返回 v 是否不为 None
    return v is not None

# 将输入 t 转换为成对的元组
def pair(t):
    # 如果 t 是元组,直接返回;否则返回重复的元组
    return t if isinstance(t, tuple) else (t, t)

# 获取时间步嵌入
def get_timestep_embedding(timesteps, embedding_dim):
    """
    该实现与 Denoising Diffusion Probabilistic Models 中的实现相匹配:
    来自 Fairseq。
    构建正弦嵌入。
    与 tensor2tensor 中的实现相匹配,但与“Attention Is All You Need”第 3.5 节的描述略有不同。
    """
    # 确保 timesteps 是一维数组
    assert len(timesteps.shape) == 1

    # 计算一半维度
    half_dim = embedding_dim // 2
    # 计算正弦嵌入的缩放因子
    emb = math.log(10000) / (half_dim - 1)
    # 生成正弦嵌入
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    # 将嵌入移动到与 timesteps 相同的设备
    emb = emb.to(device=timesteps.device)
    # 根据时间步生成嵌入
    emb = timesteps.float()[:, None] * emb[None, :]
    # 拼接正弦和余弦嵌入
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    # 如果嵌入维度为奇数,则填充一个零
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    # 返回嵌入结果
    return emb

# 应用非线性函数(swish)
def nonlinearity(x):
    # swish 激活函数
    return x * torch.sigmoid(x)

# 创建 LeakyReLU 激活函数
def leaky_relu(p=0.1):
    # 返回带有给定负斜率的 LeakyReLU 实例
    return nn.LeakyReLU(p)

# 在并行计算中拆分输入
def _split(input_, dim):
    # 获取并行计算的世界大小
    cp_world_size = get_context_parallel_world_size()

    # 如果并行计算的世界大小为 1,直接返回输入
    if cp_world_size == 1:
        return input_

    # 获取当前并行计算的排名
    cp_rank = get_context_parallel_rank()

    # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)

    # 获取输入的第一帧,并保持连续性
    inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    # 更新输入,去掉第一帧
    input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
    # 计算每个并行计算的维度大小
    dim_size = input_.size()[dim] // cp_world_size

    # 按指定维度拆分输入
    input_list = torch.split(input_, dim_size, dim=dim)
    # 获取当前排名对应的输出
    output = input_list[cp_rank]

    # 如果当前排名为 0,拼接第一帧
    if cp_rank == 0:
        output = torch.cat([inpu_first_frame_, output], dim=dim)
    # 确保输出是连续的
    output = output.contiguous()

    # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)

    # 返回拆分后的输出
    return output

# 在并行计算中收集输入
def _gather(input_, dim):
    # 获取并行计算的世界大小
    cp_world_size = get_context_parallel_world_size()

    # 如果并行计算的世界大小为 1,直接返回输入
    if cp_world_size == 1:
        return input_

    # 获取并行计算组
    group = get_context_parallel_group()
    # 获取当前并行计算的排名
    cp_rank = get_context_parallel_rank()

    # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)

    # 获取输入的第一帧,并保持连续性
    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous()
    # 如果当前排名为 0,更新输入,去掉第一帧
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()

    # 创建一个包含空张量的列表,用于收集输入
    tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]

    # 如果当前排名为 0,拼接第一帧到输入中
    if cp_rank == 0:
        input_ = torch.cat([input_first_frame_, input_], dim=dim)
    # 将输入张量存入指定的 tensor_list 中,索引由 cp_rank 确定
    tensor_list[cp_rank] = input_
    # 从所有进程中收集输入张量,并将结果存入 tensor_list 中,使用指定的分组
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # 将 tensor_list 中的所有张量在指定维度上连接成一个新的张量,并确保内存是连续的
    output = torch.cat(tensor_list, dim=dim).contiguous()

    # 调试输出当前进程的 cp_rank 和输出张量的尺寸(此行已被注释)
    # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)

    # 返回连接后的输出张量
    return output
# 定义函数 _conv_split,接收输入张量、维度和卷积核大小
def _conv_split(input_, dim, kernel_size):
    # 获取当前并行上下文的进程数量
    cp_world_size = get_context_parallel_world_size()

    # 如果并行上下文进程数为 1,则直接返回输入
    if cp_world_size == 1:
        return input_

    # 获取当前进程在并行上下文中的排名
    cp_rank = get_context_parallel_rank()

    # 计算每个进程处理的维度大小
    dim_size = (input_.size()[dim] - kernel_size) // cp_world_size

    # 如果当前进程是 0 号进程,处理输入的前一部分
    if cp_rank == 0:
        output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0)
    else:
        # 其他进程处理输入的对应部分
        output = input_.transpose(dim, 0)[cp_rank * dim_size + 1 : (cp_rank + 1) * dim_size + kernel_size].transpose(
            dim, 0
        )
    # 确保输出张量在内存中的连续性
    output = output.contiguous()

    # 返回处理后的输出
    return output


# 定义函数 _conv_gather,接收输入张量、维度和卷积核大小
def _conv_gather(input_, dim, kernel_size):
    # 获取当前并行上下文的进程数量
    cp_world_size = get_context_parallel_world_size()

    # 如果并行上下文进程数为 1,则直接返回输入
    if cp_world_size == 1:
        return input_

    # 获取当前的并行组
    group = get_context_parallel_group()
    # 获取当前进程在并行上下文中的排名
    cp_rank = get_context_parallel_rank()

    # 处理输入的第一部分卷积核
    input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous()
    # 如果当前进程是 0 号进程,处理剩余的输入
    if cp_rank == 0:
        input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous()
    else:
        # 其他进程处理相应的输入部分
        input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim).contiguous()

    # 创建张量列表以存储各进程的输出
    tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [
        torch.empty_like(input_) for _ in range(cp_world_size - 1)
    ]
    # 如果当前进程是 0 号进程,合并输入
    if cp_rank == 0:
        input_ = torch.cat([input_first_kernel_, input_], dim=dim)

    # 将当前进程的输入保存到张量列表中
    tensor_list[cp_rank] = input_
    # 收集所有进程的输入到张量列表
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # 合并张量列表中的所有输出,确保输出在内存中连续
    output = torch.cat(tensor_list, dim=dim).contiguous()

    # 返回处理后的输出
    return output

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\denoiser.py

# 从 typing 模块导入字典和联合类型的支持
from typing import Dict, Union

# 导入 PyTorch 库
import torch
import torch.nn as nn

# 从上级模块的 util 中导入 append_dims 和 instantiate_from_config 函数
from ...util import append_dims, instantiate_from_config


# 定义去噪器类,继承自 nn.Module
class Denoiser(nn.Module):
    # 初始化方法,接受权重配置和缩放配置
    def __init__(self, weighting_config, scaling_config):
        # 调用父类的初始化方法
        super().__init__()

        # 根据权重配置实例化权重处理对象
        self.weighting = instantiate_from_config(weighting_config)
        # 根据缩放配置实例化缩放处理对象
        self.scaling = instantiate_from_config(scaling_config)

    # 可能对 sigma 进行量化的函数,当前仅返回原值
    def possibly_quantize_sigma(self, sigma):
        return sigma

    # 可能对 c_noise 进行量化的函数,当前仅返回原值
    def possibly_quantize_c_noise(self, c_noise):
        return c_noise

    # 计算权重的函数,返回处理后的 sigma
    def w(self, sigma):
        return self.weighting(sigma)

    # 前向传播方法,定义了模型的计算过程
    def forward(
        self,
        network: nn.Module,  # 网络模块
        input: torch.Tensor,  # 输入张量
        sigma: torch.Tensor,  # sigma 张量
        cond: Dict,  # 条件字典
        **additional_model_inputs,  # 其他模型输入参数
    ) -> torch.Tensor:  # 返回一个张量
        # 量化 sigma
        sigma = self.possibly_quantize_sigma(sigma)
        # 获取 sigma 的形状
        sigma_shape = sigma.shape
        # 将 sigma 的维度扩展到输入的维度
        sigma = append_dims(sigma, input.ndim)
        # 通过缩放处理得到多个输出
        c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs)
        # 量化 c_noise,并恢复原来的形状
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
        # 返回通过网络处理的结果,加上输入与 c_skip 的加权和
        return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip


# 定义离散去噪器类,继承自 Denoiser
class DiscreteDenoiser(Denoiser):
    # 初始化方法,接受多个参数进行配置
    def __init__(
        self,
        weighting_config,  # 权重配置
        scaling_config,  # 缩放配置
        num_idx,  # 索引数量
        discretization_config,  # 离散化配置
        do_append_zero=False,  # 是否添加零
        quantize_c_noise=True,  # 是否量化 c_noise
        flip=True,  # 是否翻转
    ):
        # 调用父类的初始化方法
        super().__init__(weighting_config, scaling_config)
        # 根据离散化配置实例化 sigma 对象
        sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
        # 保存 sigma 对象
        self.sigmas = sigmas
        # self.register_buffer("sigmas", sigmas)  # 可选,注册为缓冲区
        # 保存是否量化 c_noise 的配置
        self.quantize_c_noise = quantize_c_noise

    # 将 sigma 转换为索引的函数
    def sigma_to_idx(self, sigma):
        # 计算 sigma 与每个 sigma 值的距离
        dists = sigma - self.sigmas.to(sigma.device)[:, None]
        # 返回距离最小的索引
        return dists.abs().argmin(dim=0).view(sigma.shape)

    # 将索引转换为 sigma 的函数
    def idx_to_sigma(self, idx):
        # 根据索引返回对应的 sigma 值
        return self.sigmas.to(idx.device)[idx]

    # 可能对 sigma 进行量化的函数
    def possibly_quantize_sigma(self, sigma):
        # 通过索引转换函数进行量化
        return self.idx_to_sigma(self.sigma_to_idx(sigma))

    # 可能对 c_noise 进行量化的函数
    def possibly_quantize_c_noise(self, c_noise):
        # 如果配置为量化,则返回 c_noise 的索引
        if self.quantize_c_noise:
            return self.sigma_to_idx(c_noise)
        else:
            # 否则返回原值
            return c_noise

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\denoiser_scaling.py

# 从 abc 模块导入 ABC 类和 abstractmethod 装饰器
from abc import ABC, abstractmethod
# 从 typing 模块导入 Any 和 Tuple 类型
from typing import Any, Tuple

# 导入 torch 库
import torch


# 定义一个抽象基类 DenoiserScaling,继承自 ABC
class DenoiserScaling(ABC):
    # 定义一个抽象方法 __call__,接受一个 torch.Tensor 参数并返回一个四元组
    @abstractmethod
    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pass


# 定义 EDMScaling 类
class EDMScaling:
    # 初始化方法,接受一个 sigma_data 参数,默认值为 0.5
    def __init__(self, sigma_data: float = 0.5):
        # 设置实例变量 sigma_data
        self.sigma_data = sigma_data

    # 定义 __call__ 方法,接受一个 torch.Tensor 参数并返回一个四元组
    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 计算 c_skip,使用 sigma 和 sigma_data
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        # 计算 c_out,结合 sigma 和 sigma_data
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        # 计算 c_in,涉及 sigma 和 sigma_data
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        # 计算 c_noise,基于 sigma 的对数
        c_noise = 0.25 * sigma.log()
        # 返回计算结果
        return c_skip, c_out, c_in, c_noise


# 定义 EpsScaling 类
class EpsScaling:
    # 定义 __call__ 方法,接受一个 torch.Tensor 参数并返回一个四元组
    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 创建与 sigma 形状相同的全 1 张量,设备与 sigma 相同
        c_skip = torch.ones_like(sigma, device=sigma.device)
        # c_out 为 sigma 的负值
        c_out = -sigma
        # 计算 c_in,涉及 sigma 的平方
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        # 复制 sigma 作为 c_noise
        c_noise = sigma.clone()
        # 返回计算结果
        return c_skip, c_out, c_in, c_noise


# 定义 VScaling 类
class VScaling:
    # 定义 __call__ 方法,接受一个 torch.Tensor 参数并返回一个四元组
    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 计算 c_skip,涉及 sigma 的平方
        c_skip = 1.0 / (sigma**2 + 1.0)
        # 计算 c_out,结合 sigma 的负值
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        # 计算 c_in,涉及 sigma 的平方
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        # 复制 sigma 作为 c_noise
        c_noise = sigma.clone()
        # 返回计算结果
        return c_skip, c_out, c_in, c_noise


# 定义 VScalingWithEDMcNoise 类,继承自 DenoiserScaling
class VScalingWithEDMcNoise(DenoiserScaling):
    # 定义 __call__ 方法,接受一个 torch.Tensor 参数并返回一个四元组
    def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 计算 c_skip,涉及 sigma 的平方
        c_skip = 1.0 / (sigma**2 + 1.0)
        # 计算 c_out,结合 sigma 的负值
        c_out = -sigma / (sigma**2 + 1.0) ** 0.5
        # 计算 c_in,涉及 sigma 的平方
        c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
        # 计算 c_noise,基于 sigma 的对数
        c_noise = 0.25 * sigma.log()
        # 返回计算结果
        return c_skip, c_out, c_in, c_noise


# 定义 VideoScaling 类,类似于 VScaling
class VideoScaling:  # similar to VScaling
    # 定义 __call__ 方法,接受一个 torch.Tensor 和可选的其他模型输入
    def __call__(
        self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 将 alphas_cumprod_sqrt 赋值给 c_skip
        c_skip = alphas_cumprod_sqrt
        # 计算 c_out,涉及 alphas_cumprod_sqrt
        c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5)
        # 创建与 alphas_cumprod_sqrt 形状相同的全 1 张量,设备相同
        c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device)
        # 复制 additional_model_inputs 中的 "idx" 作为 c_noise
        c_noise = additional_model_inputs["idx"].clone()
        # 返回计算结果
        return c_skip, c_out, c_in, c_noise

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\denoiser_weighting.py

# 导入 PyTorch 库
import torch


# 定义一个单位权重类
class UnitWeighting:
    # 定义可调用方法,接受 sigma 作为参数
    def __call__(self, sigma):
        # 返回与 sigma 形状相同的全 1 张量,设备与 sigma 相同
        return torch.ones_like(sigma, device=sigma.device)


# 定义 EDM 权重类
class EDMWeighting:
    # 初始化方法,设置 sigma_data 的默认值为 0.5
    def __init__(self, sigma_data=0.5):
        # 将传入的 sigma_data 保存为实例变量
        self.sigma_data = sigma_data

    # 定义可调用方法,接受 sigma 作为参数
    def __call__(self, sigma):
        # 返回计算的权重值
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2


# 定义 V 权重类,继承自 EDMWeighting
class VWeighting(EDMWeighting):
    # 初始化方法
    def __init__(self):
        # 调用父类构造方法,设置 sigma_data 为 1.0
        super().__init__(sigma_data=1.0)


# 定义 Eps 权重类
class EpsWeighting:
    # 定义可调用方法,接受 sigma 作为参数
    def __call__(self, sigma):
        # 返回 sigma 的平方的倒数
        return sigma**-2.0

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\discretizer.py

# 从 abc 模块导入抽象方法,用于定义抽象基类
from abc import abstractmethod
# 从 functools 模块导入 partial,用于创建部分应用的函数
from functools import partial

# 导入 numpy 库,通常用于数值计算
import numpy as np
# 导入 PyTorch 库,通常用于深度学习
import torch

# 从自定义模块中导入 make_beta_schedule 函数,用于生成 beta 调度
from ...modules.diffusionmodules.util import make_beta_schedule
# 从自定义模块中导入 append_zero 函数,用于处理 sigma 数组
from ...util import append_zero


# 定义一个函数,用于生成大致均匀间隔的步数
def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
    # 使用 linspace 生成从 max_step-1 到 0 的均匀间隔的数组,反转并转换为整数
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]


# 定义一个抽象基类 Discretization
class Discretization:
    # 定义一个可调用方法,用于处理输入参数
    def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
        # 根据 return_idx 的值获取 sigma 和索引
        if return_idx:
            sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx)
        else:
            # 获取 sigma,不返回索引
            sigmas = self.get_sigmas(n, device=device, return_idx=return_idx)
        # 如果 do_append_zero 为真,则在 sigmas 末尾添加零
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        # 根据 flip 的值决定返回的 sigmas 和索引
        if return_idx:
            return sigmas if not flip else torch.flip(sigmas, (0,)), idx
        else:
            return sigmas if not flip else torch.flip(sigmas, (0,))

    # 定义一个抽象方法 get_sigmas,必须在子类中实现
    @abstractmethod
    def get_sigmas(self, n, device):
        pass


# 定义 EDMDiscretization 类,继承自 Discretization
class EDMDiscretization(Discretization):
    # 初始化方法,设置 sigma_min、sigma_max 和 rho 的默认值
    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min  # 设置最小 sigma 值
        self.sigma_max = sigma_max  # 设置最大 sigma 值
        self.rho = rho  # 设置 rho 值

    # 实现 get_sigmas 方法
    def get_sigmas(self, n, device="cpu"):
        # 在指定设备上生成从 0 到 1 的均匀分布数组
        ramp = torch.linspace(0, 1, n, device=device)
        # 计算 sigma_min 和 sigma_max 的倒数
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        # 根据公式计算 sigmas
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        return sigmas  # 返回计算出的 sigmas


# 定义 LegacyDDPMDiscretization 类,继承自 Discretization
class LegacyDDPMDiscretization(Discretization):
    # 初始化方法,设置线性开始、结束和时间步数的默认值
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()  # 调用父类的初始化方法
        self.num_timesteps = num_timesteps  # 设置时间步数
        # 生成 beta 调度并计算对应的 alpha 值
        betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
        alphas = 1.0 - betas  # 计算 alpha 值
        # 计算 alpha 值的累积乘积
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        # 创建一个部分应用的 torch.tensor 函数,指定数据类型
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    # 实现 get_sigmas 方法
    def get_sigmas(self, n, device="cpu"):
        # 如果 n 小于时间步数,生成均匀间隔的时间步
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]  # 根据时间步获取对应的 alpha 值
        # 如果 n 等于时间步数,直接使用累积 alpha 值
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        # 如果 n 超过时间步数,抛出值错误
        else:
            raise ValueError

        # 创建一个部分应用的 torch.tensor 函数,指定数据类型和设备
        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        # 计算 sigmas
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        return torch.flip(sigmas, (0,))  # 反转返回 sigmas


# 定义 ZeroSNRDDPMDiscretization 类,继承自 Discretization
class ZeroSNRDDPMDiscretization(Discretization):
    # 初始化方法,设置线性开始、结束、时间步数和其他参数的默认值
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
        shift_scale=1.0,  # 噪声调度参数
        keep_start=False,  # 是否保持起始状态
        post_shift=False,  # 是否在后续进行偏移
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 如果保留起始值且没有后移,则对线性起始值进行缩放
        if keep_start and not post_shift:
            linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
        # 设置时间步数
        self.num_timesteps = num_timesteps
        # 创建线性调度的 beta 值
        betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
        # 计算 alpha 值
        alphas = 1.0 - betas
        # 计算累积的 alpha 值
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        # 将部分功能固定为转换为 torch 张量
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

        # SNR 偏移处理
        if not post_shift:
            # 调整累积的 alpha 值
            self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)

        # 存储后移状态
        self.post_shift = post_shift
        # 存储偏移缩放因子
        self.shift_scale = shift_scale

    def get_sigmas(self, n, device="cpu", return_idx=False):
        # 如果 n 小于时间步数,则生成等间隔的时间步
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            # 取出对应的累积 alpha 值
            alphas_cumprod = self.alphas_cumprod[timesteps]
        # 如果 n 等于时间步数,直接使用全部累积 alpha 值
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        # 如果 n 超过时间步数,则抛出错误
        else:
            raise ValueError

        # 将累积 alpha 值转换为 torch 张量
        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        alphas_cumprod = to_torch(alphas_cumprod)
        # 计算累积 alpha 值的平方根
        alphas_cumprod_sqrt = alphas_cumprod.sqrt()
        # 备份初始和最终的累积 alpha 值的平方根
        alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
        alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()

        # 调整平方根的累积 alpha 值
        alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
        alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)

        # 如果开启后移,则进一步调整平方根的累积 alpha 值
        if self.post_shift:
            alphas_cumprod_sqrt = (
                alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
            ) ** 0.5

        # 根据是否返回索引来决定返回值
        if return_idx:
            return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
        else:
            # 返回反转的平方根 alpha 值
            return torch.flip(alphas_cumprod_sqrt, (0,))  # sqrt(alpha_t): 0 -> 0.99

标签:channels,CogVideoX,nn,conv,self,torch,---,源码,sigma
From: https://www.cnblogs.com/apachecn/p/18494402

相关文章

  • CogView3---CogView-3Plus-微调代码源码解析-五-
    CogView3&CogView-3Plus微调代码源码解析(五).\cogview3-finetune\sat\sgm\modules\encoders\__init__.py请提供需要注释的代码。.\cogview3-finetune\sat\sgm\modules\__init__.py#从当前包的编码器模块导入GeneralConditioner类from.encoders.modulesimportGeneral......
  • CogView3---CogView-3Plus-微调代码源码解析-四-
    CogView3&CogView-3Plus微调代码源码解析(四).\cogview3-finetune\sat\sgm\modules\diffusionmodules\sampling_utils.py#导入数学库以进行数学运算importmath#导入PyTorch库以进行张量操作importtorch#从SciPy库导入积分函数fromscipyimportintegrate#从......
  • CogView3---CogView-3Plus-微调代码源码解析-三-
    CogView3&CogView-3Plus微调代码源码解析(三).\cogview3-finetune\sat\sgm\modules\diffusionmodules\guiders.py#导入logging模块,用于记录日志信息importlogging#从abc模块导入ABC类和abstractmethod装饰器,用于定义抽象基类和抽象方法fromabcimportABC,abst......
  • CogView3---CogView-3Plus-微调代码源码解析-二-
    CogView3&CogView-3Plus微调代码源码解析(二).\cogview3-finetune\sat\sgm\models\__init__.py#从同一模块导入AutoencodingEngine类,用于后续的自动编码器操作from.autoencoderimportAutoencodingEngine#注释文本(可能是无关信息或标识符)#XuDwndGaCFo.\cogview3-fi......
  • selenium单例模式下 docker-chrome 多线程并发代码
    最近需要写爬虫,在解决docker-standalone-chrome发现只能有一个chrome被执行。所以写了这个多线程并发控制类来管理。当模板记录下。#!/usr/bin/envpython3importthreadingimporttracebackfromloguruimportloggerfromseleniumimportwebdriverfromselenium.comm......
  • 基于SpringBoot+Vue的大数据技术的宠物商品信息比价及推荐系统(源码+LW+调试文档+讲解
    在宠物经济日益繁荣的今天,为宠物主人提供一个高效的宠物商品信息比价及推荐系统至关重要。本系统基于SpringBoot+Vue并结合大数据技术,为宠物主人带来全新的购物体验。在设计上,系统广泛收集各类宠物商品的信息,包括价格、品牌、规格、用户评价等。通过大数据分析,对不同......
  • 基于SpringBoot+Vue的大数据高乐健身器材销售数据可视化系统设计与实现(源码+LW+调试
    在健身热潮持续升温的当下,健身器材销售数据的有效管理和分析至关重要。本系统基于SpringBoot+Vue并结合大数据技术,为高乐健身器材的销售管理提供强大的可视化解决方案。在设计上,系统全面收集高乐健身器材的销售数据,包括产品种类、销售数量、销售地区、销售时间等多维......
  • 从零开始实现WEB自动化 - 技术选型及简单实践
    作为程序员的我们,在工作中应该能明显感觉到,技术选型对整个开发周期尤为重要,选择合适的技术可以帮助我们更高效地完成工作,提高开发速度和质量。 本篇主要针对开发WEB自动化的技术实现探索Selenium初探Selenium是一个用于自动化浏览器操作的工具,可以模拟用户在网页上的操作,包......
  • 前端vue-接口的调用和特殊组件的封装
                 ......
  • 【开题报告】基于Springboot+vue中医古方名方信息管理系统(程序+源码+论文) 计算机毕业
    本系统(程序+源码)带文档lw万字以上文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景中医作为中华民族的传统医学,承载着丰富的历史文化底蕴与独特的医疗智慧。在历史的长河中,无数中医先辈通过临床实践,总结出了大量疗效显著的古方名方,这......