首页 > 编程语言 >CogView3---CogView-3Plus-微调代码源码解析-三-

CogView3---CogView-3Plus-微调代码源码解析-三-

时间:2024-10-23 09:22:19浏览次数:1  
标签:None num 3Plus CogView self channels 源码 sigma out

CogView3 & CogView-3Plus 微调代码源码解析(三)

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\guiders.py

# 导入 logging 模块,用于记录日志信息
import logging
# 从 abc 模块导入 ABC 类和 abstractmethod 装饰器,用于定义抽象基类和抽象方法
from abc import ABC, abstractmethod
# 导入类型注解,方便在函数签名中定义复杂数据结构
from typing import Dict, List, Optional, Tuple, Union
# 从 functools 模块导入 partial 函数,用于部分应用函数
from functools import partial
# 导入数学模块,提供数学函数
import math

# 导入 PyTorch 库,提供张量计算功能
import torch
# 从 einops 模块导入 rearrange 和 repeat 函数,用于张量重排和重复
from einops import rearrange, repeat

# 从上层模块导入工具函数,提供一些默认值和实例化配置的功能
from ...util import append_dims, default, instantiate_from_config

# 定义一个抽象基类 Guider,继承自 ABC
class Guider(ABC):
    # 定义一个抽象方法 __call__,接受一个张量和一个浮点数,返回一个张量
    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        pass

    # 定义准备输入的方法,接受多个参数并返回一个元组
    def prepare_inputs(
        self, x: torch.Tensor, s: float, c: Dict, uc: Dict
    ) -> Tuple[torch.Tensor, float, Dict]:
        pass


# 定义一个类 VanillaCFG,表示基本的条件生成模型
class VanillaCFG:
    """
    implements parallelized CFG
    """

    # 初始化方法,接受比例和动态阈值配置
    def __init__(self, scale, dyn_thresh_config=None):
        # 定义一个 lambda 函数,根据 sigma 返回 scale,保持独立于步数
        scale_schedule = lambda scale, sigma: scale  # independent of step
        # 使用 partial 固定 scale 参数,创建 scale_schedule 方法
        self.scale_schedule = partial(scale_schedule, scale)
        # 实例化动态阈值对象,如果没有提供配置则使用默认配置
        self.dyn_thresh = instantiate_from_config(
            default(
                dyn_thresh_config,
                {
                    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
                },
            )
        )

    # 定义 __call__ 方法,使该类可以被调用,接受多个参数
    def __call__(self, x, sigma, step = None, num_steps = None, **kwargs):
        # 将输入张量 x 拆分为两个部分 x_u 和 x_c
        x_u, x_c = x.chunk(2)
        # 根据 sigma 计算 scale_value
        scale_value = self.scale_schedule(sigma)
        # 使用动态阈值处理函数进行预测,返回预测结果
        x_pred = self.dyn_thresh(x_u, x_c, scale_value, step=step, num_steps=num_steps)
        return x_pred

    # 定义准备输入的方法,接受多个参数并返回一个元组
    def prepare_inputs(self, x, s, c, uc):
        # 初始化输出字典
        c_out = dict()

        # 遍历条件字典 c 的键
        for k in c:
            # 如果键是特定值,则将 uc 和 c 中的对应张量拼接
            if k in ["vector", "crossattn", "concat"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            # 否则确保两个字典中对应的值相等,并直接赋值
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        # 返回拼接后的张量和条件字典
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out


# 定义一个类 IdentityGuider,实现一个恒等引导器
class IdentityGuider:
    # 定义 __call__ 方法,直接返回输入张量
    def __call__(self, x, sigma, **kwargs):
        return x

    # 定义准备输入的方法,返回输入和条件字典
    def prepare_inputs(self, x, s, c, uc):
        # 初始化输出字典
        c_out = dict()

        # 遍历条件字典 c 的键
        for k in c:
            # 直接将条件字典 c 的值赋给输出字典
            c_out[k] = c[k]

        # 返回输入张量和条件字典
        return x, s, c_out


# 定义一个类 LinearPredictionGuider,继承自 Guider
class LinearPredictionGuider(Guider):
    # 初始化方法,接受多个参数
    def __init__(
        self,
        max_scale: float,
        num_frames: int,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[List[str], str]] = None,
    ):
        # 初始化最小和最大比例
        self.min_scale = min_scale
        self.max_scale = max_scale
        # 计算比例的线性变化,生成 num_frames 个值
        self.num_frames = num_frames
        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)

        # 确保 additional_cond_keys 是一个列表,如果是字符串则转换为列表
        additional_cond_keys = default(additional_cond_keys, [])
        if isinstance(additional_cond_keys, str):
            additional_cond_keys = [additional_cond_keys]
        # 保存附加条件键
        self.additional_cond_keys = additional_cond_keys
    # 定义可调用对象的方法,接收输入张量 x 和 sigma,以及其他参数 kwargs,返回一个张量
    def __call__(self, x: torch.Tensor, sigma: torch.Tensor, **kwargs) -> torch.Tensor:
        # 将输入张量 x 拆分为两部分:x_u 和 x_c
        x_u, x_c = x.chunk(2)
    
        # 重排 x_u 的维度,使其形状为 (批量大小 b, 帧数 t, ...),t 由 num_frames 指定
        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        # 重排 x_c 的维度,使其形状为 (批量大小 b, 帧数 t, ...),t 由 num_frames 指定
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        # 复制 scale 张量的维度,使其形状为 (批量大小 b, 帧数 t)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        # 将 scale 的维度扩展到与 x_u 的维度一致,并移动到 x_u 的设备上
        scale = append_dims(scale, x_u.ndim).to(x_u.device)
        # 将 scale 转换为与 x_u 相同的数据类型
        scale = scale.to(x_u.dtype)
    
        # 返回经过计算的结果,重排为 (批量大小 b * 帧数 t, ...)
        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
    
    # 定义准备输入的函数,接收输入张量 x 和 s,以及条件字典 c 和 uc,返回一个元组
    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        # 初始化一个空字典 c_out 用于存放处理后的条件
        c_out = dict()
    
        # 遍历条件字典 c 的每一个键 k
        for k in c:
            # 如果 k 是指定的条件键之一,进行拼接
            if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
                # 将 uc[k] 和 c[k] 沿第0维拼接,并存入 c_out
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                # 确保 c[k] 与 uc[k] 相等
                assert c[k] == uc[k]
                # 将 c[k] 直接存入 c_out
                c_out[k] = c[k]
        # 返回拼接后的 x 和 s 以及处理后的条件字典 c_out
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\loss.py

# 导入所需的标准库和类型提示
import os
import copy
from typing import List, Optional, Union

# 导入 NumPy 和 PyTorch 库
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入 OmegaConf 中的 ListConfig
from omegaconf import ListConfig

# 从自定义模块中导入所需的函数和类
from ...util import append_dims, instantiate_from_config
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
from ...modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from ...util import get_obj_from_str, default
from ...modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps, sub_generate_roughly_equally_spaced_steps


# 定义标准扩散损失类,继承自 nn.Module
class StandardDiffusionLoss(nn.Module):
    # 初始化方法,设置损失类型和噪声级别等参数
    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        super().__init__()

        # 确保损失类型有效
        assert type in ["l2", "l1", "lpips"]

        # 根据配置实例化 sigma 采样器
        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

        # 保存损失类型和噪声级别
        self.type = type
        self.offset_noise_level = offset_noise_level

        # 如果损失类型为 lpips,则初始化 lpips 模块
        if type == "lpips":
            self.lpips = LPIPS().eval()

        # 如果没有提供 batch2model_keys,则设置为空列表
        if not batch2model_keys:
            batch2model_keys = []

        # 如果 batch2model_keys 是字符串,则转为列表
        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        # 将 batch2model_keys 转为集合以便于后续处理
        self.batch2model_keys = set(batch2model_keys)

    # 定义调用方法,计算损失
    def __call__(self, network, denoiser, conditioner, input, batch):
        # 使用条件器处理输入批次
        cond = conditioner(batch)
        # 从批次中提取附加模型输入
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }

        # 生成 sigma 值
        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        # 生成与输入相同形状的随机噪声
        noise = torch.randn_like(input)
        # 如果设置了噪声级别,调整噪声
        if self.offset_noise_level > 0.0:
            noise = noise + append_dims(
                torch.randn(input.shape[0]).to(input.device), input.ndim
            ) * self.offset_noise_level
            # 确保噪声数据类型与输入一致
            noise = noise.to(input.dtype)
        # 将输入与噪声和 sigma 结合,生成有噪声的输入
        noised_input = input.float() + noise * append_dims(sigmas, input.ndim)
        # 使用去噪网络处理有噪声的输入
        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        # 将去噪网络的权重调整为与输入相同的维度
        w = append_dims(denoiser.w(sigmas), input.ndim)
        # 返回损失值
        return self.get_loss(model_output, input, w)

    # 定义计算损失的方法
    def get_loss(self, model_output, target, w):
        # 根据损失类型计算 l2 损失
        if self.type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        # 根据损失类型计算 l1 损失
        elif self.type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        # 根据损失类型计算 lpips 损失
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss


# 定义线性中继扩散损失类,继承自 StandardDiffusionLoss
class LinearRelayDiffusionLoss(StandardDiffusionLoss):
    # 初始化方法,设置相关参数
    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        partial_num_steps=500,
        blurring_schedule='linear',
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        # 调用父类构造函数,初始化基本参数
        super().__init__(
            sigma_sampler_config,  # sigma 采样器的配置
            type=type,  # 类型参数
            offset_noise_level=offset_noise_level,  # 偏移噪声水平
            batch2model_keys=batch2model_keys,  # 批次到模型的键映射
        )

        # 设置模糊调度参数
        self.blurring_schedule = blurring_schedule
        # 设置部分步骤数量
        self.partial_num_steps = partial_num_steps

    
    def __call__(self, network, denoiser, conditioner, input, batch):
        # 使用调节器处理批次数据,生成条件
        cond = conditioner(batch)
        # 生成额外的模型输入,筛选出与模型键对应的批次数据
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }
        # 从批次中获取低分辨率输入
        lr_input = batch["lr_input"]

        # 生成随机整数,用于选择部分步骤
        rand = torch.randint(0, self.partial_num_steps, (input.shape[0],))
        # 从 sigma 采样器生成 sigma 值,并转换为输入数据类型和设备
        sigmas = self.sigma_sampler(input.shape[0], rand).to(input.dtype).to(input.device)
        # 生成与输入形状相同的随机噪声
        noise = torch.randn_like(input)
        # 如果偏移噪声水平大于0,则添加额外噪声
        if self.offset_noise_level > 0.0:
            # 生成额外随机噪声并调整其维度,乘以偏移噪声水平
            noise = noise + append_dims(
                torch.randn(input.shape[0]).to(input.device), input.ndim
            ) * self.offset_noise_level
            # 转换噪声为输入数据类型
            noise = noise.to(input.dtype)
        # 调整 rand 的维度并转换为输入数据类型和设备
        rand = append_dims(rand, input.ndim).to(input.dtype).to(input.device)
        # 根据模糊调度的不同方式计算模糊输入
        if self.blurring_schedule == 'linear':
            # 线性模糊处理
            blurred_input = input * (1 - rand / self.partial_num_steps) + lr_input * (rand / self.partial_num_steps)
        elif self.blurring_schedule == 'sigma':
            # 使用 sigma 最大值进行模糊处理
            max_sigmas = self.sigma_sampler(input.shape[0], torch.ones(input.shape[0])*self.partial_num_steps).to(input.dtype).to(input.device)
            blurred_input = input * (1 - sigmas / max_sigmas) + lr_input * (sigmas / max_sigmas)
        elif self.blurring_schedule == 'exp':
            # 指数模糊处理
            rand_blurring = (1 - torch.exp(-(torch.sin((rand+1) / self.partial_num_steps * torch.pi / 2)**4))) / (1 - torch.exp(-torch.ones_like(rand)))
            blurred_input = input * (1 - rand_blurring) + lr_input * rand_blurring
        else:
            # 如果模糊调度不被支持,抛出未实现错误
            raise NotImplementedError
        # 将噪声添加到模糊输入中
        noised_input = blurred_input + noise * append_dims(sigmas, input.ndim)
        # 调用去噪声器处理模糊输入,获取模型输出
        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
        # 调整去噪声器权重的维度
        w = append_dims(denoiser.w(sigmas), input.ndim)
        # 返回模型输出的损失值
        return self.get_loss(model_output, input, w)
# 定义一个名为 ZeroSNRDiffusionLoss 的类,继承自 StandardDiffusionLoss
class ZeroSNRDiffusionLoss(StandardDiffusionLoss):

    # 重载调用方法,接受网络、去噪器、条件、输入和批次作为参数
    def __call__(self, network, denoiser, conditioner, input, batch):
        # 使用条件生成器处理批次,得到条件变量
        cond = conditioner(batch)
        # 从批次中提取与模型键相交的额外输入
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }

        # 生成累积的 alpha 值并获取索引
        alphas_cumprod_sqrt, idx = self.sigma_sampler(input.shape[0], return_idx=True)
        # 将 alpha 值移动到输入的设备上
        alphas_cumprod_sqrt = alphas_cumprod_sqrt.to(input.device)
        # 将索引移动到输入的数据类型和设备上
        idx = idx.to(input.dtype).to(input.device)
        # 将索引添加到额外模型输入中
        additional_model_inputs['idx'] = idx

        # 生成与输入形状相同的随机噪声
        noise = torch.randn_like(input)
        # 如果偏移噪声水平大于零,则添加额外噪声
        if self.offset_noise_level > 0.0:
            noise = noise + append_dims(
                # 生成随机噪声并调整维度,乘以偏移噪声水平
                torch.randn(input.shape[0]).to(input.device), input.ndim
            ) * self.offset_noise_level

        # 计算加入噪声的输入
        noised_input = input.float() * append_dims(alphas_cumprod_sqrt, input.ndim) + noise * append_dims((1-alphas_cumprod_sqrt**2)**0.5, input.ndim)
        # 使用去噪器处理带噪声的输入
        model_output = denoiser(
            network, noised_input, alphas_cumprod_sqrt, cond, **additional_model_inputs
        )
        # 计算 v-pred 权重
        w = append_dims(1/(1-alphas_cumprod_sqrt**2), input.ndim) 
        # 返回损失值
        return self.get_loss(model_output, input, w)
    
    # 定义一个获取损失的函数
    def get_loss(self, model_output, target, w):
        # 如果损失类型为 L2,计算 L2 损失
        if self.type == "l2":
            return torch.mean(
                # 计算每个样本的 L2 损失并调整维度
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        # 如果损失类型为 L1,计算 L1 损失
        elif self.type == "l1":
            return torch.mean(
                # 计算每个样本的 L1 损失并调整维度
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
        # 如果损失类型为 LPIPS,计算 LPIPS 损失
        elif self.type == "lpips":
            loss = self.lpips(model_output, target).reshape(-1)
            return loss

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\model.py

# pytorch_diffusion + derived encoder decoder
# 导入数学库
import math
# 导入类型注解相关
from typing import Any, Callable, Optional

# 导入 numpy 库
import numpy as np
# 导入 pytorch 库
import torch
# 导入 pytorch 神经网络模块
import torch.nn as nn
# 导入 rearrange 函数以处理张量重排列
from einops import rearrange
# 导入版本管理库
from packaging import version

# 尝试导入 xformers 模块
try:
    import xformers
    import xformers.ops

    # 如果成功导入,设置标志为 True
    XFORMERS_IS_AVAILABLE = True
except:
    # 如果导入失败,设置标志为 False,并打印提示信息
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")

# 从其他模块导入 LinearAttention 和 MemoryEfficientCrossAttention
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention


def get_timestep_embedding(timesteps, embedding_dim):
    """
    此函数与 Denoising Diffusion Probabilistic Models 中的实现相匹配:
    来自 Fairseq。
    构建正弦嵌入。
    此实现与 tensor2tensor 中的实现相匹配,但与 "Attention Is All You Need" 第 3.5 节中的描述略有不同。
    """
    # 确保时间步长是一维的
    assert len(timesteps.shape) == 1

    # 计算嵌入维度的一半
    half_dim = embedding_dim // 2
    # 计算嵌入因子的对数
    emb = math.log(10000) / (half_dim - 1)
    # 计算并生成指数衰减的嵌入
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    # 将嵌入移动到与时间步相同的设备上
    emb = emb.to(device=timesteps.device)
    # 扩展时间步并与嵌入相乘
    emb = timesteps.float()[:, None] * emb[None, :]
    # 将正弦和余弦嵌入拼接在一起
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    # 如果嵌入维度是奇数,则进行零填充
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    # 返回最终的嵌入
    return emb


def nonlinearity(x):
    # 使用 swish 激活函数
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    # 返回一个 GroupNorm 归一化层
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        # 初始化 Upsample 类
        super().__init__()
        # 记录是否使用卷积
        self.with_conv = with_conv
        # 如果使用卷积,则定义卷积层
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        # 使用最近邻插值将输入张量上采样
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        # 如果使用卷积,则应用卷积层
        if self.with_conv:
            x = self.conv(x)
        # 返回处理后的张量
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        # 初始化 Downsample 类
        super().__init__()
        # 记录是否使用卷积
        self.with_conv = with_conv
        # 如果使用卷积,则定义卷积层
        if self.with_conv:
            # 因为 pytorch 卷积不支持不对称填充,需手动处理
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        # 如果使用卷积,先进行填充再应用卷积层
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        # 否则使用平均池化进行下采样
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # 返回处理后的张量
        return x


class ResnetBlock(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入通道数
        self.in_channels = in_channels
        # 如果未指定输出通道数,则设置为输入通道数
        out_channels = in_channels if out_channels is None else out_channels
        # 保存输出通道数
        self.out_channels = out_channels
        # 保存是否使用卷积捷径的标志
        self.use_conv_shortcut = conv_shortcut

        # 初始化输入通道数的归一化层
        self.norm1 = Normalize(in_channels)
        # 定义第一层卷积,输入输出通道及卷积核参数
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        # 如果有时间嵌入通道,则定义时间嵌入投影层
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        # 初始化输出通道数的归一化层
        self.norm2 = Normalize(out_channels)
        # 定义 dropout 层
        self.dropout = torch.nn.Dropout(dropout)
        # 定义第二层卷积,输入输出通道及卷积核参数
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        # 如果输入和输出通道数不相同
        if self.in_channels != self.out_channels:
            # 如果使用卷积捷径,则定义卷积捷径层
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            # 否则定义 1x1 卷积捷径层
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    # 前向传播函数
    def forward(self, x, temb):
        # 将输入赋值给 h 变量
        h = x
        # 对 h 进行归一化
        h = self.norm1(h)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 通过第一层卷积处理 h
        h = self.conv1(h)

        # 如果时间嵌入不为 None
        if temb is not None:
            # 将时间嵌入通过非线性激活函数处理后投影到输出通道,并与 h 相加
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        # 对 h 进行第二次归一化
        h = self.norm2(h)
        # 应用非线性激活函数
        h = nonlinearity(h)
        # 通过 dropout 层处理 h
        h = self.dropout(h)
        # 通过第二层卷积处理 h
        h = self.conv2(h)

        # 如果输入和输出通道数不相同
        if self.in_channels != self.out_channels:
            # 如果使用卷积捷径,则通过卷积捷径层处理 x
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            # 否则通过 1x1 卷积捷径层处理 x
            else:
                x = self.nin_shortcut(x)

        # 返回 x 和 h 的相加结果
        return x + h
# 定义 LinAttnBlock 类,继承自 LinearAttention
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""  # 文档字符串,说明该类用于匹配 AttnBlock 的使用方式

    # 初始化方法,接受输入通道数
    def __init__(self, in_channels):
        # 调用父类的初始化方法,设置维度和头数
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


# 定义 AttnBlock 类,继承自 nn.Module
class AttnBlock(nn.Module):
    # 初始化方法,接受输入通道数
    def __init__(self, in_channels):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入通道数
        self.in_channels = in_channels

        # 初始化归一化层
        self.norm = Normalize(in_channels)
        # 初始化查询卷积层
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 初始化键卷积层
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 初始化值卷积层
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 初始化输出投影卷积层
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    # 定义注意力计算方法
    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        # 对输入进行归一化
        h_ = self.norm(h_)
        # 计算查询、键和值
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # 获取查询的形状参数
        b, c, h, w = q.shape
        # 重新排列查询、键和值的形状
        q, k, v = map(
            lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
        )
        # 计算缩放的点积注意力
        h_ = torch.nn.functional.scaled_dot_product_attention(
            q, k, v
        )  # scale is dim ** -0.5 per default
        # 计算注意力

        # 返回重新排列后的注意力结果
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    # 定义前向传播方法
    def forward(self, x, **kwargs):
        # 将输入赋值给 h_
        h_ = x
        # 计算注意力
        h_ = self.attention(h_)
        # 应用输出投影
        h_ = self.proj_out(h_)
        # 返回输入与注意力结果的和
        return x + h_


# 定义 MemoryEfficientAttnBlock 类,继承自 nn.Module
class MemoryEfficientAttnBlock(nn.Module):
    """
    Uses xformers efficient implementation,
    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    Note: this is a single-head self-attention operation
    """  # 文档字符串,说明该类使用 xformers 高效实现的单头自注意力

    # 初始化方法,接受输入通道数
    def __init__(self, in_channels):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入通道数
        self.in_channels = in_channels

        # 初始化归一化层
        self.norm = Normalize(in_channels)
        # 初始化查询卷积层
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 初始化键卷积层
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 初始化值卷积层
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 初始化输出投影卷积层
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        # 初始化注意力操作,类型为可选的任意类型
        self.attention_op: Optional[Any] = None
    # 定义注意力机制的函数,输入为一个张量,输出也是一个张量
        def attention(self, h_: torch.Tensor) -> torch.Tensor:
            # 先对输入进行归一化处理
            h_ = self.norm(h_)
            # 通过线性变换生成查询张量
            q = self.q(h_)
            # 通过线性变换生成键张量
            k = self.k(h_)
            # 通过线性变换生成值张量
            v = self.v(h_)
    
            # 计算注意力
            # 获取查询张量的形状信息
            B, C, H, W = q.shape
            # 调整张量形状,将其从四维转为二维
            q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
    
            # 对查询、键、值进行维度调整以便计算注意力
            q, k, v = map(
                lambda t: t.unsqueeze(3)  # 在最后增加一个维度
                .reshape(B, t.shape[1], 1, C)  # 调整形状
                .permute(0, 2, 1, 3)  # 变换维度顺序
                .reshape(B * 1, t.shape[1], C)  # 重新调整形状
                .contiguous(),  # 保证内存连续性
                (q, k, v),
            )
            # 使用内存高效的注意力操作
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=None, op=self.attention_op
            )
    
            # 调整输出张量的形状
            out = (
                out.unsqueeze(0)  # 增加一个维度
                .reshape(B, 1, out.shape[1], C)  # 调整形状
                .permute(0, 2, 1, 3)  # 变换维度顺序
                .reshape(B, out.shape[1], C)  # 重新调整形状
            )
            # 将输出张量的形状恢复为原来的格式
            return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
    
        # 定义前向传播函数
        def forward(self, x, **kwargs):
            # 输入数据赋值给 h_
            h_ = x
            # 通过注意力机制处理 h_
            h_ = self.attention(h_)
            # 通过输出投影处理 h_
            h_ = self.proj_out(h_)
            # 返回输入和处理后的 h_ 的和
            return x + h_
# 定义一个内存高效的交叉注意力包装类,继承自 MemoryEfficientCrossAttention
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
    # 前向传播方法,接受输入张量和可选的上下文、掩码
    def forward(self, x, context=None, mask=None, **unused_kwargs):
        # 解包输入张量的维度:批量大小、通道数、高度和宽度
        b, c, h, w = x.shape
        # 重新排列输入张量的维度,将 (b, c, h, w) 转换为 (b, h*w, c)
        x = rearrange(x, "b c h w -> b (h w) c")
        # 调用父类的 forward 方法,处理重新排列后的输入
        out = super().forward(x, context=context, mask=mask)
        # 将输出张量的维度重新排列回 (b, c, h, w)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
        # 返回输入与输出的和,进行残差连接
        return x + out


# 定义一个生成注意力模块的函数
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
    # 检查传入的注意力类型是否在支持的类型列表中
    assert attn_type in [
        "vanilla",
        "vanilla-xformers",
        "memory-efficient-cross-attn",
        "linear",
        "none",
    ], f"attn_type {attn_type} unknown"
    # 检查 PyTorch 版本,并且如果类型不是 "none",则验证是否可用 xformers
    if (
        version.parse(torch.__version__) < version.parse("2.0.0")
        and attn_type != "none"
    ):
        assert XFORMERS_IS_AVAILABLE, (
            f"We do not support vanilla attention in {torch.__version__} anymore, "
            f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
        )
        # 将注意力类型设置为 "vanilla-xformers"
        attn_type = "vanilla-xformers"
    # 根据注意力类型生成相应的注意力块
    if attn_type == "vanilla":
        # 验证注意力参数不为 None
        assert attn_kwargs is None
        # 返回标准的注意力块
        return AttnBlock(in_channels)
    elif attn_type == "vanilla-xformers":
        # 返回内存高效的注意力块
        return MemoryEfficientAttnBlock(in_channels)
    elif attn_type == "memory-efficient-cross-attn":
        # 设置查询维度为输入通道数
        attn_kwargs["query_dim"] = in_channels
        # 返回内存高效的交叉注意力包装类
        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
    elif attn_type == "none":
        # 返回一个身份映射层,不改变输入
        return nn.Identity(in_channels)
    else:
        # 返回线性注意力块
        return LinAttnBlock(in_channels)


# 定义一个模型类,继承自 nn.Module
class Model(nn.Module):
    # 初始化方法,接受多个参数进行模型构建
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    # 定义前向传播方法,接受输入 x、时间步 t 和上下文 context
        def forward(self, x, t=None, context=None):
            # 确保输入 x 的高度和宽度与设定的分辨率相等(被注释掉)
            # assert x.shape[2] == x.shape[3] == self.resolution
            # 如果上下文不为 None,沿通道维度连接输入 x 和上下文
            if context is not None:
                # 假设上下文对齐,沿通道轴拼接
                x = torch.cat((x, context), dim=1)
            # 如果使用时间步,进行时间步嵌入
            if self.use_timestep:
                # 确保时间步 t 不为 None
                assert t is not None
                # 获取时间步嵌入
                temb = get_timestep_embedding(t, self.ch)
                # 通过第一层密集层处理时间步嵌入
                temb = self.temb.dense[0](temb)
                # 应用非线性变换
                temb = nonlinearity(temb)
                # 通过第二层密集层处理
                temb = self.temb.dense[1](temb)
            else:
                # 如果不使用时间步,设置时间步嵌入为 None
                temb = None
    
            # 下采样
            hs = [self.conv_in(x)]  # 初始卷积层的输出
            for i_level in range(self.num_resolutions):
                for i_block in range(self.num_res_blocks):
                    # 通过当前下采样层和时间步嵌入处理前一层输出
                    h = self.down[i_level].block[i_block](hs[-1], temb)
                    # 如果存在注意力层,则对输出进行注意力处理
                    if len(self.down[i_level].attn) > 0:
                        h = self.down[i_level].attn[i_block](h)
                    # 将处理后的输出添加到列表
                    hs.append(h)
                # 如果不是最后一层分辨率,进行下采样
                if i_level != self.num_resolutions - 1:
                    hs.append(self.down[i_level].downsample(hs[-1]))
    
            # 中间处理
            h = hs[-1]  # 获取最后一层的输出
            h = self.mid.block_1(h, temb)  # 通过中间块处理
            h = self.mid.attn_1(h)  # 通过中间注意力层处理
            h = self.mid.block_2(h, temb)  # 再次通过中间块处理
    
            # 上采样
            for i_level in reversed(range(self.num_resolutions)):
                for i_block in range(self.num_res_blocks + 1):
                    # 拼接上层输出和当前层的输出,然后通过上采样块处理
                    h = self.up[i_level].block[i_block](
                        torch.cat([h, hs.pop()], dim=1), temb
                    )
                    # 如果存在注意力层,则对输出进行注意力处理
                    if len(self.up[i_level].attn) > 0:
                        h = self.up[i_level].attn[i_block](h)
                # 如果不是第一层分辨率,进行上采样
                if i_level != 0:
                    h = self.up[i_level].upsample(h)
    
            # 结束处理
            h = self.norm_out(h)  # 最后的归一化处理
            h = nonlinearity(h)  # 应用非线性变换
            h = self.conv_out(h)  # 通过输出卷积层处理
            return h  # 返回最终输出
    
        # 获取最后一层的卷积权重
        def get_last_layer(self):
            return self.conv_out.weight  # 返回输出卷积层的权重
# 定义一个编码器类,继承自 nn.Module
class Encoder(nn.Module):
    # 初始化方法,接收多个参数用于配置编码器
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        mid_attn=True,
        **ignore_kwargs,
    ):
        # 调用父类构造方法
        super().__init__()
        # 如果使用线性注意力,设置注意力类型为线性
        if use_linear_attn:
            attn_type = "linear"
        # 保存输入参数以供后续使用
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.attn_resolutions = attn_resolutions
        self.mid_attn = mid_attn

        # 下采样
        # 定义输入卷积层
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        # 当前分辨率初始化
        curr_res = resolution
        # 定义输入通道的倍率
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        # 初始化下采样模块列表
        self.down = nn.ModuleList()
        # 遍历每个分辨率层级
        for i_level in range(self.num_resolutions):
            # 初始化块和注意力模块列表
            block = nn.ModuleList()
            attn = nn.ModuleList()
            # 输入和输出通道数计算
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            # 遍历每个残差块
            for i_block in range(self.num_res_blocks):
                # 添加残差块到块列表中
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                # 更新输入通道数为当前块的输出通道数
                block_in = block_out
                # 如果当前分辨率在注意力分辨率列表中,添加注意力模块
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            # 创建下采样模块
            down = nn.Module()
            down.block = block
            down.attn = attn
            # 如果不是最后一个分辨率,添加下采样层
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                # 更新当前分辨率为一半
                curr_res = curr_res // 2
            # 将下采样模块添加到列表中
            self.down.append(down)

        # 中间层
        self.mid = nn.Module()
        # 添加第一个残差块
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # 如果使用中间注意力,添加注意力模块
        if mid_attn:
            self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        # 添加第二个残差块
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # 结束层
        # 定义归一化层
        self.norm_out = Normalize(block_in)
        # 定义输出卷积层,根据是否双 z 通道设置输出通道数
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
    # 定义前向传播方法,接受输入数据 x
    def forward(self, x):
        # 时间步嵌入初始化为 None
        temb = None

        # 下采样过程
        # 对输入 x 进行卷积操作,生成初始特征图 hs
        hs = [self.conv_in(x)]
        # 遍历每个分辨率层
        for i_level in range(self.num_resolutions):
            # 遍历当前分辨率层中的每个残差块
            for i_block in range(self.num_res_blocks):
                # 使用当前层的残差块处理上一个层的输出和时间步嵌入
                h = self.down[i_level].block[i_block](hs[-1], temb)
                # 如果当前层有注意力机制,则应用注意力
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                # 将当前层的输出添加到特征图列表中
                hs.append(h)
            # 如果当前层不是最后一个分辨率层,则进行下采样
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # 中间处理阶段
        h = hs[-1]  # 获取最后一层的输出
        # 通过中间块1处理输入
        h = self.mid.block_1(h, temb)
        # 如果中间层有注意力机制,则应用注意力
        if self.mid_attn:
            h = self.mid.attn_1(h)
        # 通过中间块2处理输出
        h = self.mid.block_2(h, temb)

        # 最终处理阶段
        h = self.norm_out(h)  # 应用输出归一化
        h = nonlinearity(h)   # 应用非线性激活函数
        h = self.conv_out(h)  # 通过输出卷积生成最终结果
        return h  # 返回最终输出
# 定义一个解码器类,继承自 PyTorch 的 nn.Module
class Decoder(nn.Module):
    # 初始化方法,定义解码器的参数
    def __init__(
        self,
        *,
        ch,  # 输入通道数
        out_ch,  # 输出通道数
        ch_mult=(1, 2, 4, 8),  # 通道数的倍增因子
        num_res_blocks,  # 残差块的数量
        attn_resolutions,  # 注意力机制应用的分辨率
        dropout=0.0,  # dropout 比例,默认值为 0
        resamp_with_conv=True,  # 是否使用卷积进行上采样
        in_channels,  # 输入的通道数
        resolution,  # 输入的分辨率
        z_channels,  # 潜在变量的通道数
        give_pre_end=False,  # 是否在前面给予额外的结束标志
        tanh_out=False,  # 输出是否经过 tanh 激活
        use_linear_attn=False,  # 是否使用线性注意力机制
        attn_type="vanilla",  # 注意力类型,默认为“vanilla”
        mid_attn=True,  # 是否在中间层使用注意力
        **ignorekwargs,  # 其他忽略的参数,采用关键字参数形式
    ):
        # 初始化父类
        super().__init__()
        # 如果使用线性注意力机制,设置注意力类型为线性
        if use_linear_attn:
            attn_type = "linear"
        # 设置通道数
        self.ch = ch
        # 初始化时间嵌入通道数为0
        self.temb_ch = 0
        # 计算分辨率数量
        self.num_resolutions = len(ch_mult)
        # 设置残差块数量
        self.num_res_blocks = num_res_blocks
        # 设置输入分辨率
        self.resolution = resolution
        # 设置输入通道数
        self.in_channels = in_channels
        # 设置是否给出前置结束标志
        self.give_pre_end = give_pre_end
        # 设置激活函数输出
        self.tanh_out = tanh_out
        # 设置注意力分辨率
        self.attn_resolutions = attn_resolutions
        # 设置中间注意力
        self.mid_attn = mid_attn

        # 计算输入通道倍数、块输入通道和当前最低分辨率
        in_ch_mult = (1,) + tuple(ch_mult)
        # 计算当前块的输入通道数
        block_in = ch * ch_mult[self.num_resolutions - 1]
        # 计算当前分辨率
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # 设置潜在变量的形状
        self.z_shape = (1, z_channels, curr_res, curr_res)
        # print(
        #     "Working with z of shape {} = {} dimensions.".format(
        #         self.z_shape, np.prod(self.z_shape)
        #     )
        # )

        # 创建注意力和残差块类
        make_attn_cls = self._make_attn()
        make_resblock_cls = self._make_resblock()
        make_conv_cls = self._make_conv()
        # 将潜在变量映射到块输入通道
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # 中间层
        self.mid = nn.Module()
        # 创建第一个残差块
        self.mid.block_1 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        # 如果启用中间注意力,创建注意力层
        if mid_attn:
            self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
        # 创建第二个残差块
        self.mid.block_2 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # 上采样层
        self.up = nn.ModuleList()
        # 从高到低遍历每个分辨率级别
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()  # 残差块列表
            attn = nn.ModuleList()   # 注意力层列表
            # 计算当前块的输出通道数
            block_out = ch * ch_mult[i_level]
            # 创建每个残差块
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    make_resblock_cls(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                # 更新块输入通道
                block_in = block_out
                # 如果当前分辨率在注意力分辨率中,添加注意力层
                if curr_res in attn_resolutions:
                    attn.append(make_attn_cls(block_in, attn_type=attn_type))
            up = nn.Module()  # 上采样模块
            up.block = block  # 添加残差块
            up.attn = attn   # 添加注意力层
            # 如果不是最低分辨率,添加上采样层
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                # 更新当前分辨率
                curr_res = curr_res * 2
            # 将上采样模块插入列表的开头
            self.up.insert(0, up)  # prepend to get consistent order

        # 结束层
        # 创建归一化层
        self.norm_out = Normalize(block_in)
        # 创建输出卷积层
        self.conv_out = make_conv_cls(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )
    # 定义一个私有方法,用于返回注意力机制的构造函数
    def _make_attn(self) -> Callable:
        return make_attn

    # 定义一个私有方法,用于返回残差块的构造函数
    def _make_resblock(self) -> Callable:
        return ResnetBlock

    # 定义一个私有方法,用于返回二维卷积层的构造函数
    def _make_conv(self) -> Callable:
        return torch.nn.Conv2d

    # 获取最后一层的权重
    def get_last_layer(self, **kwargs):
        return self.conv_out.weight

    # 前向传播方法,接收输入 z 和可选参数
    def forward(self, z, **kwargs):
        # 确保输入 z 的形状与预期相同(被注释掉的检查)
        # assert z.shape[1:] == self.z_shape[1:]
        # 记录输入 z 的形状
        self.last_z_shape = z.shape

        # 初始化时间步嵌入
        temb = None

        # 将输入 z 传入卷积层
        h = self.conv_in(z)

        # 中间处理
        h = self.mid.block_1(h, temb, **kwargs)  # 通过第一块中间块处理
        if self.mid_attn:  # 如果启用了中间注意力
            h = self.mid.attn_1(h, **kwargs)  # 应用中间注意力层
        h = self.mid.block_2(h, temb, **kwargs)  # 通过第二块中间块处理

        # 上采样过程
        for i_level in reversed(range(self.num_resolutions)):  # 从最高分辨率到最低分辨率
            for i_block in range(self.num_res_blocks + 1):  # 遍历每个残差块
                h = self.up[i_level].block[i_block](h, temb, **kwargs)  # 通过上采样块处理
                if len(self.up[i_level].attn) > 0:  # 如果存在注意力层
                    h = self.up[i_level].attn[i_block](h, **kwargs)  # 应用注意力层
            if i_level != 0:  # 如果不是最低分辨率
                h = self.up[i_level].upsample(h)  # 执行上采样

        # 结束处理
        if self.give_pre_end:  # 如果启用了预处理结束返回
            return h

        h = self.norm_out(h)  # 对输出进行归一化
        h = nonlinearity(h)  # 应用非线性激活函数
        h = self.conv_out(h, **kwargs)  # 通过最终卷积层处理
        if self.tanh_out:  # 如果启用了 Tanh 输出
            h = torch.tanh(h)  # 应用 Tanh 激活函数
        return h  # 返回最终输出

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\openaimodel.py

# 导入操作系统模块,用于处理文件和目录操作
import os
# 导入数学模块,提供数学函数和常量
import math
# 从 abc 模块导入抽象方法装饰器,用于定义抽象基类
from abc import abstractmethod
# 从 functools 模块导入 partial 函数,用于偏函数应用
from functools import partial
# 从 typing 模块导入类型注解,用于类型提示
from typing import Iterable, List, Optional, Tuple, Union

# 导入 numpy 库,通常用于数值计算
import numpy as np
# 导入 torch 库,通常用于深度学习
import torch as th
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 导入 PyTorch 的功能模块,提供激活函数等
import torch.nn.functional as F
# 从 einops 导入 rearrange 函数,用于重排张量
from einops import rearrange

# 导入自定义模块中的 SpatialTransformer 类
from ...modules.attention import SpatialTransformer
# 导入自定义模块中的实用函数
from ...modules.diffusionmodules.util import (
    avg_pool_nd,  # 平均池化函数
    checkpoint,   # 检查点函数
    conv_nd,      # 卷积函数
    linear,       # 线性变换函数
    normalization, # 归一化函数
    timestep_embedding, # 时间步嵌入函数
    zero_module,  # 零模块函数
)

# 导入自定义模块中的实用函数
from ...util import default, exists

# 定义一个空的占位函数,用于将模块转换为半精度浮点数
# dummy replace
def convert_module_to_f16(x):
    pass

# 定义一个空的占位函数,用于将模块转换为单精度浮点数
def convert_module_to_f32(x):
    pass


# 定义一个用于注意力池化的类,继承自 nn.Module
## go
class AttentionPool2d(nn.Module):
    """
    从 CLIP 中改编: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    # 初始化方法,设置各类参数
    def __init__(
        self,
        spacial_dim: int,  # 空间维度
        embed_dim: int,    # 嵌入维度
        num_heads_channels: int,  # 头通道数量
        output_dim: int = None,  # 输出维度(可选)
    ):
        # 调用父类初始化方法
        super().__init__()
        # 定义位置嵌入参数,初始化为正态分布
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
        )
        # 定义查询、键、值的卷积投影
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        # 定义输出的卷积投影
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        # 计算头的数量
        self.num_heads = embed_dim // num_heads_channels
        # 初始化注意力机制
        self.attention = QKVAttention(self.num_heads)

    # 前向传播方法
    def forward(self, x):
        # 获取输入的批次大小和通道数
        b, c, *_spatial = x.shape
        # 将输入重塑为 (批次, 通道, 高*宽) 的形状
        x = x.reshape(b, c, -1)  # NC(HW)
        # 在最后一维上添加均值作为额外的特征
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        # 将位置嵌入加到输入上
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        # 对输入进行查询、键、值投影
        x = self.qkv_proj(x)
        # 应用注意力机制
        x = self.attention(x)
        # 对结果进行输出投影
        x = self.c_proj(x)
        # 返回第一个通道的结果
        return x[:, :, 0]


# 定义一个时间步模块的基类,继承自 nn.Module
class TimestepBlock(nn.Module):
    """
    任何模块的 forward() 方法接受时间步嵌入作为第二个参数。
    """

    # 定义抽象的前向传播方法
    @abstractmethod
    def forward(self, x, emb):
        """
        将模块应用于 `x`,并给定 `emb` 时间步嵌入。
        """


# 定义一个时间步嵌入的顺序模块,继承自 nn.Sequential 和 TimestepBlock
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    一个顺序模块,将时间步嵌入作为额外输入传递给支持的子模块。
    """

    # 重写前向传播方法
    def forward(
        self,
        x: th.Tensor,  # 输入张量
        emb: th.Tensor,  # 时间步嵌入张量
        context: Optional[th.Tensor] = None,  # 上下文张量(可选)
    ):
        # 遍历所有子模块
        for layer in self:
            module = layer

            # 如果子模块是 TimestepBlock,则使用时间步嵌入进行计算
            if isinstance(module, TimestepBlock):
                x = layer(x, emb)
            # 如果子模块是 SpatialTransformer,则使用上下文进行计算
            elif isinstance(module, SpatialTransformer):
                x = layer(x, context)
            # 否则,仅使用输入进行计算
            else:
                x = layer(x)
        # 返回最终的输出
        return x


# 定义一个上采样模块,继承自 nn.Module
class Upsample(nn.Module):
    """
    一个可选卷积的上采样层。
    :param channels: 输入和输出的通道数。
    :param use_conv: 布尔值,确定是否应用卷积。
    :param dims: 确定信号是 1D、2D 还是 3D。如果是 3D,则在内两个维度上进行上采样。
    """
    # 初始化方法,设置类的基本属性
        def __init__(
            self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
        ):
            # 调用父类初始化方法
            super().__init__()
            # 保存输入的通道数
            self.channels = channels
            # 如果没有指定输出通道数,则默认与输入通道数相同
            self.out_channels = out_channels or channels
            # 保存是否使用卷积的标志
            self.use_conv = use_conv
            # 保存维度信息
            self.dims = dims
            # 保存是否进行第三层上采样的标志
            self.third_up = third_up
            # 如果使用卷积,初始化卷积层
            if use_conv:
                self.conv = conv_nd(
                    dims, self.channels, self.out_channels, 3, padding=padding
                )
    
    # 前向传播方法,定义输入如何通过网络进行处理
        def forward(self, x):
            # 确保输入的通道数与初始化时指定的通道数一致
            assert x.shape[1] == self.channels
            # 如果输入为三维数据
            if self.dims == 3:
                # 根据是否需要第三层上采样确定时间因子
                t_factor = 1 if not self.third_up else 2
                # 对输入进行上采样
                x = F.interpolate(
                    x,
                    (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                    mode="nearest",
                )
            else:
                # 对输入进行上采样,比例因子为2
                x = F.interpolate(x, scale_factor=2, mode="nearest")
            # 如果使用卷积,则将输入通过卷积层处理
            if self.use_conv:
                x = self.conv(x)
            # 返回处理后的输出
            return x
# 定义一个转置上采样的类,继承自 nn.Module
class TransposedUpsample(nn.Module):
    "Learned 2x upsampling without padding"  # 文档字符串,描述该类的功能

    # 初始化方法,设置输入通道、输出通道和卷积核大小
    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()  # 调用父类的初始化方法
        self.channels = channels  # 保存输入通道数量
        self.out_channels = out_channels or channels  # 如果没有指定输出通道,则与输入通道相同

        # 定义一个转置卷积层,用于上采样
        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    # 前向传播方法,执行上采样操作
    def forward(self, x):
        return self.up(x)  # 返回上采样后的结果


# 定义一个下采样层的类,继承自 nn.Module
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    # 初始化方法,设置输入通道、是否使用卷积、维度等参数
    def __init__(
        self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
    ):
        super().__init__()  # 调用父类的初始化方法
        self.channels = channels  # 保存输入通道数量
        self.out_channels = out_channels or channels  # 如果没有指定输出通道,则与输入通道相同
        self.use_conv = use_conv  # 保存是否使用卷积的标志
        self.dims = dims  # 保存信号的维度
        stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))  # 确定步幅
        if use_conv:  # 如果使用卷积
            # print(f"Building a Downsample layer with {dims} dims.")  # 打印信息,表示正在构建下采样层
            # print(
            #     f"  --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
            #     f"kernel-size: 3, stride: {stride}, padding: {padding}"
            # )  # 打印卷积层的设置参数
            # if dims == 3:
            #     print(f"  --> Downsampling third axis (time): {third_down}")  # 打印是否在第三维进行下采样
            # 定义卷积操作
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:  # 如果不使用卷积
            assert self.channels == self.out_channels  # 确保输入通道与输出通道相同
            # 定义平均池化操作
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    # 前向传播方法,执行下采样操作
    def forward(self, x):
        assert x.shape[1] == self.channels  # 确保输入的通道数匹配
        return self.op(x)  # 返回下采样后的结果


# 定义一个残差块的类,继承自 TimestepBlock
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """
    # 初始化方法,用于创建类的实例
    def __init__(
        self,
        channels,  # 输入通道数
        emb_channels,  # 嵌入通道数
        dropout,  # 丢弃率
        out_channels=None,  # 输出通道数,默认为 None
        use_conv=False,  # 是否使用卷积
        use_scale_shift_norm=False,  # 是否使用缩放位移归一化
        dims=2,  # 数据维度,默认为 2
        use_checkpoint=False,  # 是否使用检查点
        up=False,  # 是否进行上采样
        down=False,  # 是否进行下采样
        kernel_size=3,  # 卷积核大小,默认为 3
        exchange_temb_dims=False,  # 是否交换时间嵌入维度
        skip_t_emb=False,  # 是否跳过时间嵌入
    ):
        # 调用父类初始化方法
        super().__init__()
        # 设置输入通道数
        self.channels = channels
        # 设置嵌入通道数
        self.emb_channels = emb_channels
        # 设置丢弃率
        self.dropout = dropout
        # 设置输出通道数,如果未提供则默认与输入通道数相同
        self.out_channels = out_channels or channels
        # 设置是否使用卷积
        self.use_conv = use_conv
        # 设置是否使用检查点
        self.use_checkpoint = use_checkpoint
        # 设置是否使用缩放位移归一化
        self.use_scale_shift_norm = use_scale_shift_norm
        # 设置是否交换时间嵌入维度
        self.exchange_temb_dims = exchange_temb_dims

        # 如果卷积核大小是可迭代的,计算每个维度的填充大小
        if isinstance(kernel_size, Iterable):
            padding = [k // 2 for k in kernel_size]
        else:
            # 否则直接计算单个卷积核的填充大小
            padding = kernel_size // 2

        # 创建输入层的序列,包括归一化、激活函数和卷积操作
        self.in_layers = nn.Sequential(
            normalization(channels),  # 归一化
            nn.SiLU(),  # SiLU 激活函数
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),  # 卷积层
        )

        # 判断是否进行上采样或下采样
        self.updown = up or down

        # 如果进行上采样,初始化上采样层
        if up:
            self.h_upd = Upsample(channels, False, dims)  # 上采样层
            self.x_upd = Upsample(channels, False, dims)  # 上采样层
        # 如果进行下采样,初始化下采样层
        elif down:
            self.h_upd = Downsample(channels, False, dims)  # 下采样层
            self.x_upd = Downsample(channels, False, dims)  # 下采样层
        # 否则使用身份映射
        else:
            self.h_upd = self.x_upd = nn.Identity()  # 身份映射层

        # 设置是否跳过时间嵌入
        self.skip_t_emb = skip_t_emb
        # 根据是否使用缩放位移归一化计算嵌入输出通道数
        self.emb_out_channels = (
            2 * self.out_channels if use_scale_shift_norm else self.out_channels
        )
        # 如果跳过时间嵌入,输出警告并设置嵌入层为 None
        if self.skip_t_emb:
            print(f"Skipping timestep embedding in {self.__class__.__name__}")  # 警告信息
            assert not self.use_scale_shift_norm  # 确保不使用缩放位移归一化
            self.emb_layers = None  # 嵌入层设置为 None
            self.exchange_temb_dims = False  # 不交换时间嵌入维度
        # 否则创建嵌入层的序列
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),  # SiLU 激活函数
                linear(
                    emb_channels,  # 嵌入通道数
                    self.emb_out_channels,  # 嵌入输出通道数
                ),
            )

        # 创建输出层的序列,包括归一化、激活函数、丢弃层和卷积层
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),  # 归一化
            nn.SiLU(),  # SiLU 激活函数
            nn.Dropout(p=dropout),  # 丢弃层
            zero_module(
                conv_nd(
                    dims,  # 数据维度
                    self.out_channels,  # 输出通道数
                    self.out_channels,  # 输出通道数
                    kernel_size,  # 卷积核大小
                    padding=padding,  # 填充
                )
            ),  # 卷积层
        )

        # 根据输入和输出通道数设置跳过连接
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()  # 身份映射层
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding  # 卷积层
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)  # 卷积层,卷积核大小为 1
    # 定义前向传播函数,接受输入张量和时间步嵌入
    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # 调用检查点函数以保存中间计算结果,减少内存使用
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    # 定义实际的前向传播逻辑
    def _forward(self, x, emb):
        # 如果设置了 updown,则进行上采样和下采样
        if self.updown:
            # 分离输入层的最后一层和其他层
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            # 通过其他输入层处理输入 x
            h = in_rest(x)
            # 更新隐藏状态
            h = self.h_upd(h)
            # 更新输入 x
            x = self.x_upd(x)
            # 通过卷积层处理隐藏状态
            h = in_conv(h)
        else:
            # 直接通过输入层处理输入 x
            h = self.in_layers(x)

        # 如果跳过时间嵌入,则初始化嵌入输出为零张量
        if self.skip_t_emb:
            emb_out = th.zeros_like(h)
        else:
            # 通过嵌入层处理时间嵌入,确保数据类型与 h 一致
            emb_out = self.emb_layers(emb).type(h.dtype)
        # 扩展 emb_out 的形状以匹配 h 的形状
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        # 如果使用缩放和偏移规范化
        if self.use_scale_shift_norm:
            # 分离输出层中的规范化层和其他层
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # 将嵌入输出分割为缩放和偏移
            scale, shift = th.chunk(emb_out, 2, dim=1)
            # 对隐藏状态进行规范化并应用缩放和偏移
            h = out_norm(h) * (1 + scale) + shift
            # 通过剩余的输出层处理隐藏状态
            h = out_rest(h)
        else:
            # 如果交换时间嵌入的维度
            if self.exchange_temb_dims:
                # 重新排列嵌入输出的维度
                emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
            # 将嵌入输出与隐藏状态相加
            h = h + emb_out
            # 通过输出层处理隐藏状态
            h = self.out_layers(h)
        # 返回输入 x 与处理后的隐藏状态的跳跃连接
        return self.skip_connection(x) + h
# 定义一个注意力模块,允许空间位置相互关注
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    # 初始化方法,定义模块的基本参数
    def __init__(
        self,
        channels,  # 输入通道数
        num_heads=1,  # 注意力头的数量,默认为1
        num_head_channels=-1,  # 每个头的通道数,默认为-1
        use_checkpoint=False,  # 是否使用检查点
        use_new_attention_order=False,  # 是否使用新的注意力顺序
    ):
        # 调用父类初始化方法
        super().__init__()
        self.channels = channels  # 保存输入通道数
        # 判断 num_head_channels 是否为 -1
        if num_head_channels == -1:
            self.num_heads = num_heads  # 如果为 -1,直接使用 num_heads
        else:
            # 断言通道数可以被 num_head_channels 整除
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels  # 计算头的数量
        self.use_checkpoint = use_checkpoint  # 保存检查点标志
        self.norm = normalization(channels)  # 初始化归一化层
        self.qkv = conv_nd(1, channels, channels * 3, 1)  # 创建卷积层用于计算 q, k, v
        # 根据是否使用新注意力顺序选择相应的注意力类
        if use_new_attention_order:
            # 在分割头之前分割 qkv
            self.attention = QKVAttention(self.num_heads)
        else:
            # 在分割 qkv 之前分割头
            self.attention = QKVAttentionLegacy(self.num_heads)

        # 初始化输出投影层
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    # 前向传播方法
    def forward(self, x, **kwargs):
        # TODO 添加跨帧注意力并使用混合检查点
        # 使用检查点机制来调用内部前向传播函数
        return checkpoint(
            self._forward, (x,), self.parameters(), True
        )  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    # 内部前向传播方法
    def _forward(self, x):
        b, c, *spatial = x.shape  # 解包输入张量的形状
        x = x.reshape(b, c, -1)  # 将输入张量重塑为 (batch_size, channels, spatial_dim)
        qkv = self.qkv(self.norm(x))  # 计算 q, k, v
        h = self.attention(qkv)  # 应用注意力机制
        h = self.proj_out(h)  # 对注意力结果进行投影
        return (x + h).reshape(b, c, *spatial)  # 返回重塑后的结果

# 计算注意力操作的 FLOPS
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape  # 解包输入张量的形状
    num_spatial = int(np.prod(spatial))  # 计算空间维度的总数
    # 进行两个矩阵乘法,具有相同数量的操作。
    # 第一个计算权重矩阵,第二个计算值向量的组合。
    matmul_ops = 2 * b * (num_spatial**2) * c  # 计算矩阵乘法的操作数
    model.total_ops += th.DoubleTensor([matmul_ops])  # 将操作数累加到模型的总操作数中

# 旧版 QKV 注意力模块
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
    """

    # 初始化方法,设置注意力头的数量
    def __init__(self, n_heads):
        super().__init__()  # 调用父类初始化方法
        self.n_heads = n_heads  # 保存注意力头的数量
    # 定义前向传播方法,接收 QKV 张量
    def forward(self, qkv):
        """
        应用 QKV 注意力机制。
        :param qkv: 一个形状为 [N x (H * 3 * C) x T] 的张量,包含 Q、K 和 V。
        :return: 一个形状为 [N x (H * C) x T] 的张量,经过注意力处理后输出。
        """
        # 获取输入张量的批量大小、宽度和长度
        bs, width, length = qkv.shape
        # 确保宽度可以被 (3 * n_heads) 整除,以分割 Q、K 和 V
        assert width % (3 * self.n_heads) == 0
        # 计算每个头的通道数
        ch = width // (3 * self.n_heads)
        # 将 qkv 张量重塑并分割成 Q、K 和 V 三个部分
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # 计算缩放因子,用于稳定性
        scale = 1 / math.sqrt(math.sqrt(ch))
        # 使用爱因斯坦求和约定计算注意力权重,乘以缩放因子
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # 使用 f16 比后续除法更稳定
        # 对权重进行 softmax 归一化,并保持原始数据类型
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        # 根据权重和 V 计算输出张量
        a = th.einsum("bts,bcs->bct", weight, v)
        # 将输出张量重塑为原始批量大小和通道数
        return a.reshape(bs, -1, length)

    # 定义静态方法以计算模型的浮点运算数
    @staticmethod
    def count_flops(model, _x, y):
        # 调用辅助函数计算注意力层的浮点运算数
        return count_flops_attn(model, _x, y)
# 定义一个名为 QKVAttention 的类,继承自 nn.Module
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    # 初始化方法,接收注意力头的数量
    def __init__(self, n_heads):
        super().__init__()  # 调用父类的初始化方法
        self.n_heads = n_heads  # 保存注意力头的数量

    # 前向传播方法,接收 qkv 张量并执行注意力计算
    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape  # 解包 qkv 张量的维度
        assert width % (3 * self.n_heads) == 0  # 确保宽度能够被注意力头数量整除
        ch = width // (3 * self.n_heads)  # 计算每个头的通道数
        q, k, v = qkv.chunk(3, dim=1)  # 将 qkv 张量分成 Q, K, V 三部分
        scale = 1 / math.sqrt(math.sqrt(ch))  # 计算缩放因子
        weight = th.einsum(
            "bct,bcs->bts",  # 定义爱因斯坦求和约定,计算权重
            (q * scale).view(bs * self.n_heads, ch, length),  # 缩放后的 Q 重塑形状
            (k * scale).view(bs * self.n_heads, ch, length),  # 缩放后的 K 重塑形状
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)  # 计算权重的 softmax,确保其和为 1
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))  # 计算最终的注意力输出
        return a.reshape(bs, -1, length)  # 将输出重塑回原始批量形状

    @staticmethod
    # 计算 FLOPs 的静态方法
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)  # 调用函数计算注意力层的 FLOPs


# 定义一个名为 Timestep 的类,继承自 nn.Module
class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()  # 调用父类的初始化方法
        self.dim = dim  # 保存时间步的维度

    # 前向传播方法,接收时间步张量
    def forward(self, t):
        return timestep_embedding(t, self.dim)  # 调用时间步嵌入函数


# 定义一个字典,将字符串类型映射到对应的 PyTorch 数据类型
str_to_dtype = {
    "fp32": th.float32,  # fp32 对应 float32
    "fp16": th.float16,  # fp16 对应 float16
    "bf16": th.bfloat16   # bf16 对应 bfloat16
}

# 定义一个名为 UNetModel 的类,继承自 nn.Module
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    """
    # 参数 resblock_updown:是否在上采样/下采样过程中使用残差块
    # 参数 use_new_attention_order:是否使用不同的注意力模式以提高效率
    """

    # 初始化方法
    def __init__(
        # 输入通道数
        self,
        in_channels,
        # 模型通道数
        model_channels,
        # 输出通道数
        out_channels,
        # 残差块的数量
        num_res_blocks,
        # 注意力分辨率
        attention_resolutions,
        # dropout 比例,默认为 0
        dropout=0,
        # 通道的倍增因子,默认值为 (1, 2, 4, 8)
        channel_mult=(1, 2, 4, 8),
        # 是否使用卷积重采样,默认为 True
        conv_resample=True,
        # 数据维度,默认为 2
        dims=2,
        # 类别数,默认为 None
        num_classes=None,
        # 是否使用检查点,默认为 False
        use_checkpoint=False,
        # 是否使用 fp16 精度,默认为 False
        use_fp16=False,
        # 注意力头数,默认为 -1
        num_heads=-1,
        # 每个头的通道数,默认为 -1
        num_head_channels=-1,
        # 上采样时的头数,默认为 -1
        num_heads_upsample=-1,
        # 是否使用尺度偏移归一化,默认为 False
        use_scale_shift_norm=False,
        # 是否使用残差块进行上采样/下采样,默认为 False
        resblock_updown=False,
        # 是否使用新的注意力顺序,默认为 False
        use_new_attention_order=False,
        # 是否使用空间变换器,支持自定义变换器
        use_spatial_transformer=False,  # custom transformer support
        # 变换器的深度,默认为 1
        transformer_depth=1,  # custom transformer support
        # 上下文维度,默认为 None
        context_dim=None,  # custom transformer support
        # 嵌入数,默认为 None
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        # 是否使用传统模式,默认为 True
        legacy=True,
        # 是否禁用自注意力,默认为 None
        disable_self_attentions=None,
        # 注意力块的数量,默认为 None
        num_attention_blocks=None,
        # 是否禁用中间自注意力,默认为 False
        disable_middle_self_attn=False,
        # 是否在变换器中使用线性输入,默认为 False
        use_linear_in_transformer=False,
        # 空间变换器的注意力类型,默认为 "softmax"
        spatial_transformer_attn_type="softmax",
        # 输入通道数,默认为 None
        adm_in_channels=None,
        # 是否使用 Fairscale 检查点,默认为 False
        use_fairscale_checkpoint=False,
        # 是否将计算卸载到 CPU,默认为 False
        offload_to_cpu=False,
        # 中间变换器的深度,默认为 None
        transformer_depth_middle=None,
        # 配置条件嵌入维度,默认为 None
        cfg_cond_embed_dim=None,
        # 数据类型,默认为 "fp32"
        dtype="fp32",
    # 将模型的主体转换为 float16
    def convert_to_fp16(self):
        """
        将模型的主体转换为 float16。
        """
        # 对输入块应用转换模块,将其转换为 float16
        self.input_blocks.apply(convert_module_to_f16)
        # 对中间块应用转换模块,将其转换为 float16
        self.middle_block.apply(convert_module_to_f16)
        # 对输出块应用转换模块,将其转换为 float16
        self.output_blocks.apply(convert_module_to_f16)

    # 将模型的主体转换为 float32
    def convert_to_fp32(self):
        """
        将模型的主体转换为 float32。
        """
        # 对输入块应用转换模块,将其转换为 float32
        self.input_blocks.apply(convert_module_to_f32)
        # 对中间块应用转换模块,将其转换为 float32
        self.middle_block.apply(convert_module_to_f32)
        # 对输出块应用转换模块,将其转换为 float32
        self.output_blocks.apply(convert_module_to_f32)
    # 定义前向传播函数,接收输入数据和其他参数
    def forward(self, x, timesteps=None, context=None, y=None, scale_emb=None, **kwargs):
        """
        应用模型于输入批次。
        :param x: 输入张量,形状为 [N x C x ...]。
        :param timesteps: 一维时间步批次。
        :param context: 通过 crossattn 插入的条件信息。
        :param y: 标签张量,形状为 [N],如果是类条件。
        :return: 输出张量,形状为 [N x C x ...]。
        """
        # 如果输入数据类型不匹配,则转换为模型所需的数据类型
        if x.dtype != self.dtype:
            x = x.to(self.dtype)
    
        # 确保 y 的存在性与类数设置一致
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        # 初始化存储中间结果的列表
        hs = []
    
        # 生成时间步嵌入
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
        # 如果提供了缩放嵌入,则进行相应处理
        if scale_emb is not None:
            assert hasattr(self, "w_proj"), "w_proj not found in the model"
            t_emb = t_emb + self.w_proj(scale_emb.to(self.dtype))
        # 通过时间嵌入生成最终嵌入
        emb = self.time_embed(t_emb)
    
        # 如果模型是类条件,则将标签嵌入加入到最终嵌入中
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)
    
        # 将输入数据赋值给 h
        # h = x.type(self.dtype)
        h = x
        # 通过输入模块处理 h,并保存中间结果
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        # 通过中间模块进一步处理 h
        h = self.middle_block(h, emb, context)
        # 通过输出模块处理 h,并逐层合并中间结果
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        # 将 h 转换回原输入数据类型
        h = h.type(x.dtype)
        # 检查是否支持预测码本 ID
        if self.predict_codebook_ids:
            assert False, "not supported anymore. what the f*** are you doing?"
        else:
            # 返回最终输出结果
            return self.out(h)

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\sampling.py

# 部分代码移植自 https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""
    Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""

# 从 typing 模块导入字典和联合类型
from typing import Dict, Union

# 导入 PyTorch 库
import torch
# 从 omegaconf 模块导入配置相关的类
from omegaconf import ListConfig, OmegaConf
# 导入 tqdm 库用于显示进度条
from tqdm import tqdm

# 从相对路径模块导入采样相关的工具函数
from ...modules.diffusionmodules.sampling_utils import (
    get_ancestral_step,  # 获取祖先步骤
    linear_multistep_coeff,  # 线性多步骤系数
    to_d,  # 转换为 d
    to_neg_log_sigma,  # 转换为负对数sigma
    to_sigma,  # 转换为 sigma
)
# 从相对路径模块导入离散化工具
from ...modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps
# 从相对路径模块导入工具函数
from ...util import append_dims, default, instantiate_from_config

# 定义默认引导器配置
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

# 定义用于生成引导嵌入的函数
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    参考文献: https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        timesteps (`torch.Tensor`):
            在这些时间步生成嵌入向量
        embedding_dim (`int`, *可选*, 默认为 512):
            生成的嵌入的维度
        dtype:
            生成嵌入的数据类型

    Returns:
        `torch.FloatTensor`: 形状为 `(len(timesteps), embedding_dim)` 的嵌入向量
    """
    # 确保输入张量是一个一维张量
    assert len(w.shape) == 1
    # 将输入乘以 1000.0
    w = w * 1000.0

    # 计算嵌入维度的一半
    half_dim = embedding_dim // 2
    # 计算基础嵌入的系数
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # 生成嵌入基础,转换为指数形式并调整为目标设备和数据类型
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb).to(w.device).to(w.dtype)
    # 生成最终的嵌入向量
    emb = w.to(dtype)[:, None] * emb[None, :]
    # 将正弦和余弦值连接在一起
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    # 如果嵌入维度为奇数,进行零填充
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    # 确保生成的嵌入形状与预期一致
    assert emb.shape == (w.shape[0], embedding_dim)
    # 返回生成的嵌入向量
    return emb

# 定义基础扩散采样器类
class BaseDiffusionSampler:
    # 初始化采样器
    def __init__(
        self,
        discretization_config: Union[Dict, ListConfig, OmegaConf],  # 离散化配置
        num_steps: Union[int, None] = None,  # 采样步数,默认为 None
        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,  # 引导器配置,默认为 None
        cfg_cond_scale: Union[int, None] = None,  # 条件缩放参数,默认为 None
        cfg_cond_embed_dim: Union[int, None] = 256,  # 条件嵌入维度,默认为 256
        verbose: bool = False,  # 是否显示详细信息
        device: str = "cuda",  # 设备类型,默认为 CUDA
    ):
        # 设置采样步数
        self.num_steps = num_steps
        # 实例化离散化配置
        self.discretization = instantiate_from_config(discretization_config)
        # 实例化引导器配置
        self.guider = instantiate_from_config(
            default(
                guider_config,
                DEFAULT_GUIDER,
            )
        )

        # 设置条件参数
        self.cfg_cond_scale = cfg_cond_scale
        self.cfg_cond_embed_dim = cfg_cond_embed_dim
        
        # 设置详细模式和设备
        self.verbose = verbose
        self.device = device

    # 准备采样循环的函数
    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        # 生成 sigma 值
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        # 默认使用条件
        uc = default(uc, cond)

        # 根据 sigma 计算 x 的调整
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        # 获取 sigma 的数量
        num_sigmas = len(sigmas)

        # 创建新的一维张量 s_in,初始值为 1
        s_in = x.new_ones([x.shape[0]]).float()

        # 返回调整后的 x 和其他参数
        return x, s_in, sigmas, num_sigmas, cond, uc
    # 定义去噪函数,接受输入x、去噪器denoiser、噪声水平sigma、条件cond和无条件uc
        def denoise(self, x, denoiser, sigma, cond, uc):
            # 检查条件缩放系数是否不为None
            if self.cfg_cond_scale is not None:
                # 获取输入批次的大小
                batch_size = x.shape[0]
                # 创建与批次大小相同的全1张量,并乘以条件缩放系数,生成缩放嵌入
                scale_emb = guidance_scale_embedding(torch.ones(batch_size, device=x.device) * self.cfg_cond_scale, embedding_dim=self.cfg_cond_embed_dim, dtype=x.dtype)
                # 使用去噪器处理输入,传入缩放嵌入
                denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), scale_emb=scale_emb)
            else:
                # 若无条件缩放系数,直接使用去噪器处理输入
                denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
            # 对去噪后的结果进行进一步引导处理
            denoised = self.guider(denoised, sigma)
            # 返回最终去噪结果
            return denoised
    
        # 定义生成sigma的函数,接受sigma数量num_sigmas
        def get_sigma_gen(self, num_sigmas):
            # 创建一个范围生成器,从0到num_sigmas-1
            sigma_generator = range(num_sigmas - 1)
            # 如果启用了详细输出
            if self.verbose:
                # 打印分隔线和采样设置信息
                print("#" * 30, " Sampling setting ", "#" * 30)
                print(f"Sampler: {self.__class__.__name__}")
                print(f"Discretization: {self.discretization.__class__.__name__}")
                print(f"Guider: {self.guider.__class__.__name__}")
                # 使用tqdm包装生成器以显示进度条
                sigma_generator = tqdm(
                    sigma_generator,
                    total=num_sigmas,
                    desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
                )
            # 返回sigma生成器
            return sigma_generator
# 定义一个单步扩散采样器类,继承自基本扩散采样器
class SingleStepDiffusionSampler(BaseDiffusionSampler):
    # 定义采样步骤方法,未实现
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        # 抛出未实现错误,表明该方法需在子类中实现
        raise NotImplementedError

    # 定义欧拉步骤方法,用于计算下一个状态
    def euler_step(self, x, d, dt):
        # 返回更新后的状态,基于当前状态、导数和时间增量
        return x + dt * d


# 定义 EDM 采样器类,继承自单步扩散采样器
class EDMSampler(SingleStepDiffusionSampler):
    # 初始化 EDM 采样器的参数
    def __init__(
        self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
    ):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

        # 设置采样器的参数
        self.s_churn = s_churn  # 变化率
        self.s_tmin = s_tmin    # 最小时间
        self.s_tmax = s_tmax    # 最大时间
        self.s_noise = s_noise  # 噪声强度

    # 定义采样步骤方法
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
        # 计算调整后的 sigma 值
        sigma_hat = sigma * (gamma + 1.0)
        # 如果 gamma 大于 0,加入噪声
        if gamma > 0:
            # 生成与 x 形状相同的随机噪声
            eps = torch.randn_like(x) * self.s_noise
            # 更新 x 的值,加入噪声
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        # 去噪,得到去噪后的结果
        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
        # 计算导数
        d = to_d(x, sigma_hat, denoised)
        # 计算时间增量
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        # 执行欧拉步骤,更新 x
        euler_step = self.euler_step(x, d, dt)
        # 进行可能的修正步骤,得到最终的 x
        x = self.possible_correction_step(
            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
        )
        # 返回更新后的 x
        return x

    # 定义调用方法
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        # 准备采样循环所需的参数
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        # 遍历 sigma 值
        for i in self.get_sigma_gen(num_sigmas):
            # 计算 gamma 值
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            # 执行采样步骤,更新 x
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                gamma,
            )

        # 返回最终的 x
        return x


# 定义 DDIM 采样器类,继承自单步扩散采样器
class DDIMSampler(SingleStepDiffusionSampler):
    # 初始化 DDIM 采样器的参数
    def __init__(
        self, s_noise=0.1, *args, **kwargs
    ):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

        # 设置噪声强度
        self.s_noise = s_noise

    # 定义采样步骤方法
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):

        # 去噪,得到去噪后的结果
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        # 计算导数
        d = to_d(x, sigma, denoised)
        # 计算时间增量
        dt = append_dims(next_sigma * (1 - s_noise**2)**0.5 - sigma, x.ndim)

        # 计算欧拉步骤,加入噪声
        euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)

        # 进行可能的修正步骤,得到最终的 x
        x = self.possible_correction_step(
            euler_step, x, d, dt, next_sigma, denoiser, cond, uc
        )
        # 返回更新后的 x
        return x
    # 定义一个可调用的类方法,接收去噪器、输入数据、条件及其他参数
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        # 准备采样循环,返回处理后的数据和相关参数
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
    
        # 遍历生成的 sigma 值
        for i in self.get_sigma_gen(num_sigmas):
            # 执行采样步骤,更新输入数据 x
            x = self.sampler_step(
                s_in * sigmas[i],    # 当前 sigma 乘以输入信号
                s_in * sigmas[i + 1],# 下一个 sigma 乘以输入信号
                denoiser,            # 传递去噪器
                x,                   # 当前数据
                cond,                # 条件信息
                uc,                  # 可选的额外条件
                self.s_noise,        # 传递噪声信息
            )
    
        # 返回最终处理后的数据
        return x
# 定义一个继承自 SingleStepDiffusionSampler 的类 AncestralSampler
class AncestralSampler(SingleStepDiffusionSampler):
    # 初始化方法,设定默认参数 eta 和 s_noise
    def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

        # 设置 eta 属性
        self.eta = eta
        # 设置 s_noise 属性
        self.s_noise = s_noise
        # 定义噪声采样器,生成与输入形状相同的随机噪声
        self.noise_sampler = lambda x: torch.randn_like(x)

    # 定义 ancestral_euler_step 方法,用于执行欧拉步长
    def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
        # 计算偏导数 d
        d = to_d(x, sigma, denoised)
        # 将 sigma_down 和 sigma 的差值扩展到 x 的维度
        dt = append_dims(sigma_down - sigma, x.ndim)

        # 返回欧拉步长的结果
        return self.euler_step(x, d, dt)

    # 定义 ancestral_step 方法,执行采样步骤
    def ancestral_step(self, x, sigma, next_sigma, sigma_up):
        # 根据条件选择更新 x 的值
        x = torch.where(
            append_dims(next_sigma, x.ndim) > 0.0,  # 检查 next_sigma 是否大于 0
            x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),  # 更新 x 的值
            x,  # 保持原值
        )
        # 返回更新后的 x
        return x

    # 定义调用方法,使得类可以被调用
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        # 准备采样循环,获取必要的输入
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        # 遍历 sigma 生成器,进行采样步骤
        for i in self.get_sigma_gen(num_sigmas):
            x = self.sampler_step(
                s_in * sigmas[i],  # 当前 sigma 值
                s_in * sigmas[i + 1],  # 下一个 sigma 值
                denoiser,  # 去噪器
                x,  # 当前 x 值
                cond,  # 条件
                uc,  # 额外条件
            )

        # 返回最终的 x 值
        return x


# 定义一个继承自 BaseDiffusionSampler 的类 LinearMultistepSampler
class LinearMultistepSampler(BaseDiffusionSampler):
    # 初始化方法,设定默认的 order 参数
    def __init__(
        self,
        order=4,
        *args,
        **kwargs,
    ):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

        # 设置 order 属性
        self.order = order

    # 定义调用方法,使得类可以被调用
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        # 准备采样循环,获取必要的输入
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        # 初始化一个列表 ds 用于存储导数
        ds = []
        # 将 sigmas 从 GPU 移到 CPU,并转换为 numpy 数组
        sigmas_cpu = sigmas.detach().cpu().numpy()
        # 遍历 sigma 生成器
        for i in self.get_sigma_gen(num_sigmas):
            # 计算当前的 sigma
            sigma = s_in * sigmas[i]
            # 使用去噪器处理当前输入
            denoised = denoiser(
                *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
            )
            # 使用引导函数对去噪结果进行处理
            denoised = self.guider(denoised, sigma)
            # 计算导数 d
            d = to_d(x, sigma, denoised)
            # 将导数添加到列表 ds
            ds.append(d)
            # 如果 ds 的长度超过 order,移除最早的元素
            if len(ds) > self.order:
                ds.pop(0)
            # 计算当前的阶数
            cur_order = min(i + 1, self.order)
            # 计算当前阶数的线性多步系数
            coeffs = [
                linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
                for j in range(cur_order)
            ]
            # 更新 x 值
            x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))

        # 返回最终的 x 值
        return x


# 定义一个继承自 EDMSampler 的类 EulerEDMSampler
class EulerEDMSampler(EDMSampler):
    # 定义可能的校正步骤方法
    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
        # 返回 euler_step,表示不进行额外的校正
        return euler_step


# 定义一个继承自 EDMSampler 的类 HeunEDMSampler
class HeunEDMSampler(EDMSampler):
    # 定义可能的校正步骤方法
    def possible_correction_step(
        self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
    ):
    ):
        # 如果下一个噪声水平的总和小于一个非常小的阈值
        if torch.sum(next_sigma) < 1e-14:
            # 如果所有噪声水平为0,保存网络评估的结果
            return euler_step
        else:
            # 使用去噪器对当前步进行去噪处理
            denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
            # 将去噪后的结果转换为新数据
            d_new = to_d(euler_step, next_sigma, denoised)
            # 计算当前数据与新数据的平均值
            d_prime = (d + d_new) / 2.0

            # 如果噪声水平不为0,则应用修正
            x = torch.where(
                # 检查噪声水平是否大于0,决定是否修正
                append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
            )
            # 返回修正后的结果
            return x
# 定义一个 Euler 祖先采样器类,继承自 AncestralSampler
class EulerAncestralSampler(AncestralSampler):
    # 定义采样步骤的方法,接受多个参数
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
        # 获取下一个采样步的 sigma 值
        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
        # 使用去噪器对当前输入进行去噪
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        # 使用 Euler 方法更新 x 的值
        x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
        # 应用祖先步骤更新 x 的值
        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)

        # 返回更新后的 x
        return x


# 定义一个 DPMPP2S 祖先采样器类,继承自 AncestralSampler
class DPMPP2SAncestralSampler(AncestralSampler):
    # 获取变量的方法,计算相关参数
    def get_variables(self, sigma, sigma_down):
        # 将 sigma 和 sigma_down 转换为负对数形式
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
        # 计算时间间隔 h
        h = t_next - t
        # 计算 s 值
        s = t + 0.5 * h
        # 返回计算的参数
        return h, s, t, t_next

    # 获取乘法因子的方法
    def get_mult(self, h, s, t, t_next):
        # 计算各个乘法因子
        mult1 = to_sigma(s) / to_sigma(t)
        mult2 = (-0.5 * h).expm1()
        mult3 = to_sigma(t_next) / to_sigma(t)
        mult4 = (-h).expm1()

        # 返回所有乘法因子
        return mult1, mult2, mult3, mult4

    # 采样步骤的方法,执行多个计算步骤
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
        # 获取下一个采样步的 sigma 值
        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
        # 对输入进行去噪
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        # 使用 Euler 方法更新 x 的值
        x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)

        # 检查 sigma_down 是否接近于零
        if torch.sum(sigma_down) < 1e-14:
            # 如果噪声级别为 0,则保存网络评估
            x = x_euler
        else:
            # 获取变量 h, s, t, t_next
            h, s, t, t_next = self.get_variables(sigma, sigma_down)
            # 获取乘法因子,并调整维度
            mult = [
                append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
            ]

            # 更新 x 的值
            x2 = mult[0] * x - mult[1] * denoised
            # 对 x2 进行去噪
            denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
            # 计算最终的 x 值
            x_dpmpp2s = mult[2] * x - mult[3] * denoised2

            # 如果噪声级别不为 0,则应用校正
            x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)

        # 最终应用祖先步骤更新 x
        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
        # 返回更新后的 x
        return x


# 定义一个 DPMPP2M 采样器类,继承自 BaseDiffusionSampler
class DPMPP2MSampler(BaseDiffusionSampler):
    # 获取变量的方法,计算相关参数
    def get_variables(self, sigma, next_sigma, previous_sigma=None):
        # 将 sigma 和 next_sigma 转换为负对数形式
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
        # 计算时间间隔 h
        h = t_next - t

        # 如果提供了 previous_sigma,则进行额外计算
        if previous_sigma is not None:
            h_last = t - to_neg_log_sigma(previous_sigma)
            r = h_last / h
            return h, r, t, t_next
        else:
            # 如果没有提供,则返回 h 和 t 值
            return h, None, t, t_next

    # 获取乘法因子的方法
    def get_mult(self, h, r, t, t_next, previous_sigma):
        # 计算基础乘法因子
        mult1 = to_sigma(t_next) / to_sigma(t)
        mult2 = (-h).expm1()

        # 如果提供了 previous_sigma,则计算额外的乘法因子
        if previous_sigma is not None:
            mult3 = 1 + 1 / (2 * r)
            mult4 = 1 / (2 * r)
            return mult1, mult2, mult3, mult4
        else:
            # 返回基本的乘法因子
            return mult1, mult2

    # 采样步骤的方法,执行多个计算步骤
    def sampler_step(
        self,
        old_denoised,
        previous_sigma,
        sigma,
        next_sigma,
        denoiser,
        x,
        cond,
        uc=None,
    ):
        # 使用去噪器对输入数据进行去噪,返回去噪后的结果
        denoised = self.denoise(x, denoiser, sigma, cond, uc)

        # 获取当前和下一个噪声级别相关的变量
        h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
        # 计算多重系数,扩展维度以匹配输入数据的维度
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, t, t_next, previous_sigma)
        ]

        # 计算标准化后的输出
        x_standard = mult[0] * x - mult[1] * denoised
        # 检查之前的去噪结果是否存在或下一噪声级别是否接近零
        if old_denoised is None or torch.sum(next_sigma) < 1e-14:
            # 如果噪声级别为零或处于第一步,返回标准化结果和去噪结果
            return x_standard, denoised
        else:
            # 计算去噪后的数据修正值
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            # 计算高级输出
            x_advanced = mult[0] * x - mult[1] * denoised_d

            # 如果噪声级别不为零且不是第一步,应用修正
            x = torch.where(
                append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
            )

        # 返回最终输出和去噪结果
        return x, denoised

    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        # 准备采样循环,包括输入数据和条件信息的处理
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )

        old_denoised = None
        # 遍历噪声级别生成器
        for i in self.get_sigma_gen(num_sigmas):
            # 在每个步骤中执行采样,更新去噪结果
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * sigmas[i - 1],
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
            )

        # 返回最终的去噪结果
        return x
# 定义一个将输入信号传递到去噪器的函数
def relay_to_d(x, sigma, denoised, image, step, total_step):
    # 计算模糊度的变化量
    blurring_d = (denoised - image) / total_step
    # 根据模糊度和当前步长更新去噪图像
    blurring_denoised = image + blurring_d * step
    # 计算当前信号与去噪信号的差异,标准化为 sigma 的维度
    d = (x - blurring_denoised) / append_dims(sigma, x.ndim)
    # 返回计算得到的差异和模糊度变化
    return d, blurring_d
    

# 定义一个线性中继EDM采样器,继承自EulerEDMSampler
class LinearRelayEDMSampler(EulerEDMSampler):
    # 初始化函数,设定部分步数
    def __init__(self, partial_num_steps=20, *args, **kwargs):
        # 调用父类初始化方法
        super().__init__(*args, **kwargs)
        # 设置部分步数
        self.partial_num_steps = partial_num_steps

    # 定义采样调用方法
    def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None):
        # 克隆随机数以保持不变
        randn_unit = randn.clone()
        # 准备采样循环,获取相关参数
        randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            randn, cond, uc, num_steps
        )
        # 初始化 x 为 None
        x = None

        # 遍历生成的 sigma 值
        for i in self.get_sigma_gen(num_sigmas):
            # 如果当前步数小于总步数减去部分步数,继续下一次循环
            if i < self.num_steps - self.partial_num_steps:
                continue
            # 如果 x 还未初始化,则根据图像和随机数计算初始值
            if x is None:
                x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))

            # 计算 gamma 值,控制采样过程中的噪声
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            # 进行一次采样步骤
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                gamma,
                step=i - self.num_steps + self.partial_num_steps,
                image=image,
                index=self.num_steps - i,
            )

        # 返回最终的图像
        return x

    # 定义欧拉步骤的计算方法
    def euler_step(self, x, d, dt, blurring_d):
        # 更新 x 的值
        return x + dt * d + blurring_d

    # 定义采样步骤的计算方法
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, step=None, image=None, index=None):
        # 计算 sigma_hat,考虑 gamma 的影响
        sigma_hat = sigma * (gamma + 1.0)
        # 如果 gamma 大于 0,添加噪声
        if gamma > 0:
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        # 使用去噪器去噪当前图像
        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
        # 计算 beta_t,控制去噪过程
        beta_t = next_sigma / sigma_hat * index / self.partial_num_steps - (index - 1) / self.partial_num_steps
        # 更新 x 的值,结合去噪结果
        x = x * append_dims(next_sigma / sigma_hat, x.ndim) + denoised * append_dims(1 - next_sigma / sigma_hat + beta_t, x.ndim) - image * append_dims(beta_t, x.ndim)
        # 返回更新后的图像
        return x
    

# 定义零信噪比DDIM采样器,继承自SingleStepDiffusionSampler
class ZeroSNRDDIMSampler(SingleStepDiffusionSampler):
    # 初始化函数,设定是否使用条件生成
    def __init__(
        self,
        do_cfg=True,
        *args,
        **kwargs,
    ):
        # 调用父类初始化方法
        super().__init__(*args, **kwargs)
        # 设置条件生成标志
        self.do_cfg = do_cfg

    # 准备采样循环的参数
    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        # 计算累积的 alpha 值,并获取对应的索引
        alpha_cumprod_sqrt, indices = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device, return_idx=True
        )
        # 如果 uc 为 None,则使用 cond
        uc = default(uc, cond)

        # 获取 sigma 的数量
        num_sigmas = len(alpha_cumprod_sqrt)

        # 初始化 s_in 为全 1 向量
        s_in = x.new_ones([x.shape[0]])

        # 返回准备好的参数
        return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, indices
    # 定义去噪函数,接受输入数据和其他参数
        def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, i=None, idx=None):
            # 初始化额外的模型输入字典
            additional_model_inputs = {}
            # 如果启用 CFG,准备包含索引的输入
            if self.do_cfg:
                additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * idx] * 2)
            # 否则只准备单个索引输入
            else:
                additional_model_inputs['idx'] = torch.cat([x.new_ones([x.shape[0]]) * idx])
            # 使用去噪器处理准备好的输入和额外参数,得到去噪后的结果
            denoised = denoiser(*self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs)
            # 使用引导器进一步处理去噪后的结果
            denoised = self.guider(denoised, alpha_cumprod_sqrt, step=i, num_steps=self.num_steps)
            # 返回去噪后的结果
            return denoised
    
    # 定义采样步骤函数,执行去噪和更新过程
        def sampler_step(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, denoiser, x, cond, uc=None, i=None, idx=None, return_denoised=False):
            # 调用去噪函数,并转换结果为浮点型
            denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, i, idx).to(torch.float32)
            # 如果达到最后一步,返回去噪结果
            if i == self.num_steps - 1:
                if return_denoised:
                    return denoised, denoised
                return denoised
    
            # 计算当前步骤的 a_t 值
            a_t = ((1-next_alpha_cumprod_sqrt**2)/(1-alpha_cumprod_sqrt**2))**0.5
            # 计算当前步骤的 b_t 值
            b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t
    
            # 更新 x 的值,结合去噪后的结果
            x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
            # 根据需要返回去噪结果
            if return_denoised:
                return x, denoised
            return x
    
    # 定义可调用函数,用于处理采样和去噪流程
        def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
            # 准备采样循环所需的输入数据
            x, s_in, alpha_cumprod_sqrts, num_sigmas, cond, uc, indices = self.prepare_sampling_loop(
                x, cond, uc, num_steps
            )
    
            # 根据 sigma 生成器逐步执行采样
            for i in self.get_sigma_gen(num_sigmas):
                x = self.sampler_step(
                    s_in * alpha_cumprod_sqrts[i],
                    s_in * alpha_cumprod_sqrts[i + 1],
                    denoiser,
                    x,
                    cond,
                    uc,
                    i=i,
                    idx=indices[self.num_steps-i-1],
                )
    
            # 返回最终的结果
            return x

标签:None,num,3Plus,CogView,self,channels,源码,sigma,out
From: https://www.cnblogs.com/apachecn/p/18494398

相关文章