首页 > 编程语言 >CogVideo---CogVideoX-微调代码源码解析-十-

CogVideo---CogVideoX-微调代码源码解析-十-

时间:2024-10-23 09:24:40浏览次数:6  
标签:None CogVideoX self torch --- 源码 return sigma def

CogVideo & CogVideoX 微调代码源码解析(十)

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\sampling.py

# 部分移植自 https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
"""

# 从类型提示模块导入字典和联合类型
from typing import Dict, Union

# 导入 PyTorch 库
import torch
# 从 OmegaConf 导入列表配置和 OmegaConf 类
from omegaconf import ListConfig, OmegaConf
# 从 tqdm 导入进度条功能
from tqdm import tqdm

# 从自定义模块导入必要的函数
from ...modules.diffusionmodules.sampling_utils import (
    get_ancestral_step,  # 获取祖先步骤的函数
    linear_multistep_coeff,  # 线性多步骤系数的函数
    to_d,  # 转换到 d 的函数
    to_neg_log_sigma,  # 转换为负对数 sigma 的函数
    to_sigma,  # 转换为 sigma 的函数
)
# 从 util 模块导入辅助函数
from ...util import append_dims, default, instantiate_from_config
from ...util import SeededNoise  # 导入带种子的噪声生成器

# 从 guiders 模块导入动态 CFG 类
from .guiders import DynamicCFG

# 默认引导器配置
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

# 定义基础扩散采样器类
class BaseDiffusionSampler:
    # 初始化方法,设置相关配置
    def __init__(
        self,
        discretization_config: Union[Dict, ListConfig, OmegaConf],  # 离散化配置
        num_steps: Union[int, None] = None,  # 步骤数
        guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,  # 引导器配置
        verbose: bool = False,  # 是否详细输出
        device: str = "cuda",  # 使用的设备
    ):
        self.num_steps = num_steps  # 保存步骤数
        # 根据配置实例化离散化对象
        self.discretization = instantiate_from_config(discretization_config)
        # 根据配置实例化引导器对象,使用默认值如果未提供
        self.guider = instantiate_from_config(
            default(
                guider_config,
                DEFAULT_GUIDER,
            )
        )
        self.verbose = verbose  # 设置详细输出标志
        self.device = device  # 设置设备

    # 准备采样循环的方法
    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        # 计算 sigma 值
        sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
        # 默认情况下使用条件输入
        uc = default(uc, cond)

        # 对输入进行缩放
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)  # 获取 sigma 数量

        # 创建与输入样本数相同的全 1 张量
        s_in = x.new_ones([x.shape[0]]).float()

        # 返回准备好的参数
        return x, s_in, sigmas, num_sigmas, cond, uc

    # 去噪声的方法
    def denoise(self, x, denoiser, sigma, cond, uc):
        # 准备输入并进行去噪声处理
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
        # 使用引导器进一步处理去噪声结果
        denoised = self.guider(denoised, sigma)
        # 返回去噪声后的结果
        return denoised

    # 获取 sigma 生成器的方法
    def get_sigma_gen(self, num_sigmas):
        sigma_generator = range(num_sigmas - 1)  # 创建 sigma 生成器范围
        if self.verbose:  # 如果启用详细输出
            print("#" * 30, " Sampling setting ", "#" * 30)  # 输出分隔符
            print(f"Sampler: {self.__class__.__name__}")  # 输出采样器类名
            print(f"Discretization: {self.discretization.__class__.__name__}")  # 输出离散化类名
            print(f"Guider: {self.guider.__class__.__name__}")  # 输出引导器类名
            # 包装 sigma 生成器为 tqdm 对象,以便显示进度条
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas,
                desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
            )
        # 返回 sigma 生成器
        return sigma_generator

# 定义单步扩散采样器类,继承自基础扩散采样器
class SingleStepDiffusionSampler(BaseDiffusionSampler):
    # 定义采样步骤的抽象方法
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
        raise NotImplementedError  # 抛出未实现错误

    # 欧拉步骤的方法
    def euler_step(self, x, d, dt):
        # 计算下一个状态
        return x + dt * d

# 定义 EDM 采样器类,继承自单步扩散采样器
class EDMSampler(SingleStepDiffusionSampler):
    # 初始化方法,设置相关参数
    def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)  # 调用父类初始化

        self.s_churn = s_churn  # 设置 churn 参数
        self.s_tmin = s_tmin  # 设置最小时间
        self.s_tmax = s_tmax  # 设置最大时间
        self.s_noise = s_noise  # 设置噪声参数
    # 定义sampler_step函数,用于执行采样步骤
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
        # 计算sigma_hat,用于计算噪声
        sigma_hat = sigma * (gamma + 1.0)
        # 如果gamma大于0,生成服从标准正态分布的随机数eps,并将其乘以s_noise,再加到x上
        if gamma > 0:
            eps = torch.randn_like(x) * self.s_noise
            x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        # 使用denoiser对x进行去噪,得到denoised
        denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
        # 计算d,用于后续步骤
        d = to_d(x, sigma_hat, denoised)
        # 计算dt,用于后续步骤
        dt = append_dims(next_sigma - sigma_hat, x.ndim)

        # 使用欧拉步骤更新x
        euler_step = self.euler_step(x, d, dt)
        # 使用可能的修正步骤更新x
        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
        # 返回更新后的x
        return x

    # 定义__call__函数,用于执行整个采样过程
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        # 准备采样循环所需的变量
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

        # 遍历sigmas,执行采样步骤
        for i in self.get_sigma_gen(num_sigmas):
            # 计算gamma,用于控制噪声的大小
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
            )
            # 执行采样步骤
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                gamma,
            )

        # 返回最终的x
        return x
# 定义 DDIMSampler 类,继承自 SingleStepDiffusionSampler 类
class DDIMSampler(SingleStepDiffusionSampler):
    # 初始化方法,接受 s_noise 参数,默认值为 0.1
    def __init__(self, s_noise=0.1, *args, **kwargs):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

        # 设置实例属性 s_noise 为传入的 s_noise 参数值
        self.s_noise = s_noise

    # sampler_step 方法,接受 sigma、next_sigma、denoiser、x、cond、uc 和 s_noise 参数
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0):
        # 使用 denoiser 对 x 进行去噪,得到 denoised
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        # 计算 d,使用 to_d 函数
        d = to_d(x, sigma, denoised)
        # 计算 dt
        dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim)
        # 计算 euler_step
        euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x)
        # 调用 possible_correction_step 方法,得到 x
        x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
        # 返回 x
        return x

    # __call__ 方法,接受 denoiser、x、cond、uc 和 num_steps 参数
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        # 调用 prepare_sampling_loop 方法,得到 x、s_in、sigmas、num_sigmas、cond 和 uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

        # 遍历 sigmas
        for i in self.get_sigma_gen(num_sigmas):
            # 调用 sampler_step 方法,得到 x
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
                self.s_noise,
            )

        # 返回 x
        return x


# 定义 AncestralSampler 类,继承自 SingleStepDiffusionSampler 类
class AncestralSampler(SingleStepDiffusionSampler):
    # 初始化方法,接受 eta 和 s_noise 参数,默认值分别为 1.0 和 1.0
    def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

        # 设置实例属性 eta 为传入的 eta 参数值
        self.eta = eta
        # 设置实例属性 s_noise 为传入的 s_noise 参数值
        self.s_noise = s_noise
        # 设置实例属性 noise_sampler 为一个 lambda 函数,用于生成噪声
        self.noise_sampler = lambda x: torch.randn_like(x)

    # ancestral_euler_step 方法,接受 x、denoised、sigma 和 sigma_down 参数
    def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
        # 计算 d,使用 to_d 函数
        d = to_d(x, sigma, denoised)
        # 计算 dt
        dt = append_dims(sigma_down - sigma, x.ndim)
        # 调用 euler_step 方法,得到结果
        return self.euler_step(x, d, dt)

    # ancestral_step 方法,接受 x、sigma、next_sigma 和 sigma_up 参数
    def ancestral_step(self, x, sigma, next_sigma, sigma_up):
        # 根据条件进行赋值操作
        x = torch.where(
            append_dims(next_sigma, x.ndim) > 0.0,
            x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
            x,
        )
        # 返回结果
        return x

    # __call__ 方法,接受 denoiser、x、cond、uc 和 num_steps 参数
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
        # 调用 prepare_sampling_loop 方法,得到 x、s_in、sigmas、num_sigmas、cond 和 uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

        # 遍历 sigmas
        for i in self.get_sigma_gen(num_sigmas):
            # 调用 sampler_step 方法,得到 x
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc,
            )

        # 返回 x
        return x


# 定义 LinearMultistepSampler 类,继承自 BaseDiffusionSampler 类
class LinearMultistepSampler(BaseDiffusionSampler):
    # 初始化方法,接受 order 参数,默认值为 4
    def __init__(
        self,
        order=4,
        *args,
        **kwargs,
    ):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

        # 设置实例属性 order 为传入的 order 参数值
        self.order = order
    # 定义可调用方法,用于执行去噪操作
        def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
            # 准备采样循环,返回处理后的输入、sigma、以及条件信息
            x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
    
            # 初始化去噪结果列表
            ds = []
            # 将 sigma 从计算图中分离并转移到 CPU,然后转换为 NumPy 数组
            sigmas_cpu = sigmas.detach().cpu().numpy()
            # 遍历生成的 sigma 值
            for i in self.get_sigma_gen(num_sigmas):
                # 计算当前 sigma
                sigma = s_in * sigmas[i]
                # 使用去噪器处理输入数据,获取去噪结果
                denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
                # 进一步处理去噪结果
                denoised = self.guider(denoised, sigma)
                # 将当前输入、sigma 和去噪结果转换为目标格式
                d = to_d(x, sigma, denoised)
                # 将去噪结果添加到结果列表
                ds.append(d)
                # 如果结果列表超过预设顺序,则移除最旧的结果
                if len(ds) > self.order:
                    ds.pop(0)
                # 计算当前顺序
                cur_order = min(i + 1, self.order)
                # 计算线性多步法的系数
                coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
                # 更新输入 x,使用加权和的方式结合去噪结果
                x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    
            # 返回最终的去噪结果
            return x
# 定义一个基于 Euler 方法的 EDM 采样器类,继承自 EDMSampler
class EulerEDMSampler(EDMSampler):
    # 可能的修正步骤,接受多个参数
    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
        # 直接返回 euler_step,没有进行任何修正
        return euler_step


# 定义一个基于 Heun 方法的 EDM 采样器类,继承自 EDMSampler
class HeunEDMSampler(EDMSampler):
    # 可能的修正步骤,接受多个参数
    def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
        # 检查 next_sigma 的总和是否小于 1e-14
        if torch.sum(next_sigma) < 1e-14:
            # 如果所有噪声水平为 0,则返回 euler_step,避免网络评估
            return euler_step
        else:
            # 使用 denoiser 去噪 euler_step
            denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
            # 将 euler_step 和去噪结果转换为新的 d 值
            d_new = to_d(euler_step, next_sigma, denoised)
            # 计算 d 的新值,取 d 和 d_new 的平均
            d_prime = (d + d_new) / 2.0

            # 如果噪声水平不为 0,则应用修正
            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
            # 返回修正后的 x
            return x


# 定义一个基于 Euler 的祖先采样器类,继承自 AncestralSampler
class EulerAncestralSampler(AncestralSampler):
    # 进行采样步骤,接受多个参数
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
        # 获取祖先步骤的上下界
        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
        # 去噪 x
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        # 进行 Euler 采样步骤
        x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
        # 进行祖先步骤
        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)

        # 返回最终的 x
        return x


# 定义一个基于 DPM++ 的祖先采样器类,继承自 AncestralSampler
class DPMPP2SAncestralSampler(AncestralSampler):
    # 获取变量,接受 sigma 和 sigma_down
    def get_variables(self, sigma, sigma_down):
        # 将 sigma 和 sigma_down 转换为负对数
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
        # 计算时间差 h
        h = t_next - t
        # 计算 s 值
        s = t + 0.5 * h
        # 返回 h, s, t, t_next
        return h, s, t, t_next

    # 计算多重值,接受 h, s, t, t_next
    def get_mult(self, h, s, t, t_next):
        # 计算多个乘数
        mult1 = to_sigma(s) / to_sigma(t)
        mult2 = (-0.5 * h).expm1()
        mult3 = to_sigma(t_next) / to_sigma(t)
        mult4 = (-h).expm1()

        # 返回所有计算出的乘数
        return mult1, mult2, mult3, mult4

    # 进行采样步骤,接受多个参数
    def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
        # 获取祖先步骤的上下界
        sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
        # 去噪 x
        denoised = self.denoise(x, denoiser, sigma, cond, uc)
        # 进行 Euler 采样步骤
        x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)

        # 检查 sigma_down 的总和是否小于 1e-14
        if torch.sum(sigma_down) < 1e-14:
            # 如果所有噪声水平为 0,则返回 x_euler,避免网络评估
            x = x_euler
        else:
            # 获取变量
            h, s, t, t_next = self.get_variables(sigma, sigma_down)
            # 计算多重值并调整维度
            mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]

            # 计算新的 x2
            x2 = mult[0] * x - mult[1] * denoised
            # 对 x2 进行去噪
            denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
            # 计算最终的 x_dpmpp2s
            x_dpmpp2s = mult[2] * x - mult[3] * denoised2

            # 如果噪声水平不为 0,则应用修正
            x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)

        # 进行祖先步骤
        x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
        # 返回最终的 x
        return x


# 定义一个基于 DPM++ 的采样器类,继承自 BaseDiffusionSampler
class DPMPP2MSampler(BaseDiffusionSampler):
    # 定义一个获取变量的函数,接受当前和下一个噪声级别,以及可选的上一个噪声级别
    def get_variables(self, sigma, next_sigma, previous_sigma=None):
        # 将 sigma 和 next_sigma 转换为负对数形式
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
        # 计算两个时间点之间的差值
        h = t_next - t

        # 如果上一个噪声级别存在
        if previous_sigma is not None:
            # 计算当前时间与上一个时间的差值
            h_last = t - to_neg_log_sigma(previous_sigma)
            # 计算当前和上一个时间差的比值
            r = h_last / h
            # 返回差值 h、比值 r、当前和下一个时间
            return h, r, t, t_next
        else:
            # 返回差值 h 和当前、下一个时间,但不返回比值 r
            return h, None, t, t_next

    # 定义一个获取乘数的函数,接受多个参数
    def get_mult(self, h, r, t, t_next, previous_sigma):
        # 计算当前和下一个时间的 sigma 乘数
        mult1 = to_sigma(t_next) / to_sigma(t)
        # 计算 h 的负值的指数减一
        mult2 = (-h).expm1()

        # 如果上一个噪声级别存在
        if previous_sigma is not None:
            # 计算与 r 相关的乘数
            mult3 = 1 + 1 / (2 * r)
            mult4 = 1 / (2 * r)
            # 返回所有乘数
            return mult1, mult2, mult3, mult4
        else:
            # 返回前两个乘数
            return mult1, mult2

    # 定义采样步骤函数,接受多个参数
    def sampler_step(
        self,
        old_denoised,
        previous_sigma,
        sigma,
        next_sigma,
        denoiser,
        x,
        cond,
        uc=None,
    ):
        # 使用去噪器对输入 x 进行去噪
        denoised = self.denoise(x, denoiser, sigma, cond, uc)

        # 获取变量 h、r、t、t_next
        h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
        # 获取乘数,并将维度调整以匹配 x 的维度
        mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]

        # 计算标准化后的 x
        x_standard = mult[0] * x - mult[1] * denoised
        # 如果没有旧的去噪结果或下一个噪声级别接近零
        if old_denoised is None or torch.sum(next_sigma) < 1e-14:
            # 保存网络评估,如果所有噪声级别为 0 或为第一步
            return x_standard, denoised
        else:
            # 计算去噪后的旧结果
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            # 计算高级 x
            x_advanced = mult[0] * x - mult[1] * denoised_d

            # 如果噪声级别不为 0 且不是第一步,则应用修正
            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)

        # 返回最终的 x 和去噪结果
        return x, denoised

    # 定义调用函数,接受去噪器、输入 x、条件和其他可选参数
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
        # 准备采样循环的参数,包括对输入 x、条件、噪声级别等的处理
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

        # 初始化旧去噪结果
        old_denoised = None
        # 遍历生成的噪声级别
        for i in self.get_sigma_gen(num_sigmas):
            # 执行采样步骤,更新 x 和旧去噪结果
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * sigmas[i - 1],
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
            )

        # 返回最终的 x
        return x
# 定义 SDEDPMPP2MSampler 类,继承自 BaseDiffusionSampler
class SDEDPMPP2MSampler(BaseDiffusionSampler):
    # 获取变量 h、r 和时间参数 t、t_next
    def get_variables(self, sigma, next_sigma, previous_sigma=None):
        # 将 sigma 和 next_sigma 转换为负对数形式
        t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
        # 计算 h 为 t_next 和 t 的差值
        h = t_next - t

        # 如果 previous_sigma 不为 None
        if previous_sigma is not None:
            # 计算上一个 sigma 的负对数值
            h_last = t - to_neg_log_sigma(previous_sigma)
            # 计算 r 为 h_last 和 h 的比值
            r = h_last / h
            # 返回 h、r、t 和 t_next
            return h, r, t, t_next
        else:
            # 返回 h 和 None(无 r),以及 t 和 t_next
            return h, None, t, t_next

    # 计算乘数值
    def get_mult(self, h, r, t, t_next, previous_sigma):
        # 计算 mult1 为 t_next 和 t 的 sigma 比值乘以 h 的负指数
        mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp()
        # 计算 mult2 为 (-2*h) 的 expm1 值
        mult2 = (-2 * h).expm1()

        # 如果 previous_sigma 不为 None
        if previous_sigma is not None:
            # 计算 mult3 为 1 + 1/(2*r)
            mult3 = 1 + 1 / (2 * r)
            # 计算 mult4 为 1/(2*r)
            mult4 = 1 / (2 * r)
            # 返回 mult1、mult2、mult3 和 mult4
            return mult1, mult2, mult3, mult4
        else:
            # 返回 mult1 和 mult2
            return mult1, mult2

    # 执行采样步骤
    def sampler_step(
        self,
        old_denoised,
        previous_sigma,
        sigma,
        next_sigma,
        denoiser,
        x,
        cond,
        uc=None,
    ):
        # 使用 denoiser 对 x 进行去噪处理
        denoised = self.denoise(x, denoiser, sigma, cond, uc)

        # 获取 h、r、t 和 t_next
        h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
        # 计算乘数,并调整维度
        mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
        # 计算噪声乘数并调整维度
        mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim)

        # 计算标准化后的 x
        x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
        # 如果 old_denoised 为 None 或 next_sigma 的和小于 1e-14
        if old_denoised is None or torch.sum(next_sigma) < 1e-14:
            # 返回标准化后的 x 和去噪后的结果
            return x_standard, denoised
        else:
            # 计算去噪的差异
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            # 计算高级 x
            x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)

            # 如果噪声水平不为 0 且不是第一步,应用修正
            x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)

        # 返回最终的 x 和去噪后的结果
        return x, denoised

    # 调用采样器,执行采样循环
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
        # 准备采样循环,初始化输入和 sigma
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)

        # 初始化 old_denoised 为 None
        old_denoised = None
        # 遍历 sigma 生成器
        for i in self.get_sigma_gen(num_sigmas):
            # 执行采样步骤并更新 x 和 old_denoised
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * sigmas[i - 1],
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
            )

        # 返回最终的 x
        return x


# 定义 SdeditEDMSampler 类,继承自 EulerEDMSampler
class SdeditEDMSampler(EulerEDMSampler):
    # 初始化函数,设置编辑比例
    def __init__(self, edit_ratio=0.5, *args, **kwargs):
        # 调用父类的初始化函数
        super().__init__(*args, **kwargs)

        # 设置编辑比例
        self.edit_ratio = edit_ratio
    # 定义一个可调用的方法,接受多个参数进行图像去噪
    def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None):
        # 克隆 randn,创建 randn_unit 用于后续计算
        randn_unit = randn.clone()
        # 准备采样循环,处理 randn、条件、未条件和步骤数
        randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps)
    
        # 如果未指定 num_steps,则使用对象的默认步骤数
        if num_steps is None:
            num_steps = self.num_steps
        # 如果未指定 edit_ratio,则使用对象的默认编辑比例
        if edit_ratio is None:
            edit_ratio = self.edit_ratio
        # 初始化 x 为 None,用于后续存储结果
        x = None
    
        # 遍历 sigma 生成器,获取每个 sigma 的值
        for i in self.get_sigma_gen(num_sigmas):
            # 如果当前步骤比例小于 edit_ratio,则跳过此次循环
            if i / num_steps < edit_ratio:
                continue
            # 如果 x 为 None,则初始化 x 为图像与噪声的组合
            if x is None:
                x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape))
    
            # 计算 gamma 值,依据条件限制调整
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
            )
            # 进行一次采样步骤,更新 x 的值
            x = self.sampler_step(
                s_in * sigmas[i],     # 当前 sigma 的输入
                s_in * sigmas[i + 1], # 下一个 sigma 的输入
                denoiser,             # 去噪器
                x,                    # 当前图像
                cond,                 # 条件信息
                uc,                   # 未条件信息
                gamma,                # gamma 值
            )
    
        # 返回最终处理后的图像
        return x
# 定义一个名为 VideoDDIMSampler 的类,继承自 BaseDiffusionSampler
class VideoDDIMSampler(BaseDiffusionSampler):
    # 初始化函数,接受固定帧数和 sdedit 标志,及其他参数
    def __init__(self, fixed_frames=0, sdedit=False, **kwargs):
        # 调用父类的初始化函数
        super().__init__(**kwargs)
        # 设置固定帧数
        self.fixed_frames = fixed_frames
        # 设置 sdedit 标志
        self.sdedit = sdedit

    # 准备采样循环,接受输入数据和条件
    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        # 进行离散化,计算 alpha 的平方根累积乘积和时间步
        alpha_cumprod_sqrt, timesteps = self.discretization(
            self.num_steps if num_steps is None else num_steps,  # 使用给定的步数或默认步数
            device=self.device,  # 指定设备
            return_idx=True,  # 返回索引
            do_append_zero=False,  # 不追加零
        )
        # 在 alpha_cumprod_sqrt 末尾添加一个值为 1 的新张量
        alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])])
        # 创建一个新的时间步张量,并在开头添加一个值为 -1 的零张量
        timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))])

        # 如果 uc 为空,使用 cond 作为默认值
        uc = default(uc, cond)

        # 计算 alpha_cumprod_sqrt 的元素数量
        num_sigmas = len(alpha_cumprod_sqrt)

        # 创建一个新的张量 s_in,初始值为 1
        s_in = x.new_ones([x.shape[0]])

        # 返回多个变量
        return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps

    # 去噪函数,接受多个输入参数
    def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None):
        # 初始化额外模型输入的字典
        additional_model_inputs = {}

        # 检查 scale 是否为张量且不为 1
        if isinstance(scale, torch.Tensor) == False and scale == 1:
            # 为额外模型输入添加当前时间步的索引
            additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep
            # 如果 scale_emb 不为 None,添加到额外输入
            if scale_emb is not None:
                additional_model_inputs["scale_emb"] = scale_emb
            # 调用去噪器进行去噪,并转换为 float32 类型
            denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32)
        else:
            # 创建一个新的索引张量,包含当前时间步的重复值
            additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2)
            # 调用去噪器进行去噪,准备输入并转换为 float32 类型
            denoised = denoiser(
                *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs
            ).to(torch.float32)
            # 如果 guider 是 DynamicCFG 的实例,进行动态调整
            if isinstance(self.guider, DynamicCFG):
                denoised = self.guider(
                    denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale
                )
            else:
                # 否则,进行普通的调整
                denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale)
        # 返回去噪后的结果
        return denoised

    # 采样步骤函数,接受多个输入参数
    def sampler_step(
        self,
        alpha_cumprod_sqrt,
        next_alpha_cumprod_sqrt,
        denoiser,
        x,
        cond,
        uc=None,
        idx=None,
        timestep=None,
        scale=None,
        scale_emb=None,
    ):
        # 调用 denoise 方法获取去噪结果
        denoised = self.denoise(
            x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
        ).to(torch.float32)

        # 计算 a_t 和 b_t 值
        a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
        b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t

        # 更新 x 的值,通过加权当前值和去噪后的值
        x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised
        # 返回更新后的 x
        return x
    # 定义一个可调用的方法,用于处理去噪过程
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
        # 准备采样循环的输入,返回处理后的数据和相关参数
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
    
        # 遍历生成的 sigma 值
        for i in self.get_sigma_gen(num_sigmas):
            # 执行采样步骤,更新输入数据
            x = self.sampler_step(
                s_in * alpha_cumprod_sqrt[i],  # 当前 sigma 的缩放输入
                s_in * alpha_cumprod_sqrt[i + 1],  # 下一个 sigma 的缩放输入
                denoiser,  # 去噪器对象
                x,  # 当前输入
                cond,  # 条件输入
                uc,  # 可选的额外条件
                idx=self.num_steps - i,  # 当前步骤索引
                timestep=timesteps[-(i + 1)],  # 当前时间步
                scale=scale,  # 缩放因子
                scale_emb=scale_emb,  # 嵌入的缩放因子
            )
    
        # 返回处理后的结果
        return x
# 定义 VPSDEDPMPP2M 采样器类,继承自 VideoDDIMSampler
class VPSDEDPMPP2MSampler(VideoDDIMSampler):
    # 获取变量,计算多个参数
    def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
        # 计算 alpha 的累积乘积
        alpha_cumprod = alpha_cumprod_sqrt**2
        # 计算 lamb 的对数值
        lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
        # 计算下一个 alpha 的累积乘积
        next_alpha_cumprod = next_alpha_cumprod_sqrt**2
        # 计算下一个 lamb 的对数值
        lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
        # 计算 h 值
        h = lamb_next - lamb

        # 如果存在前一个 alpha 的累积乘积
        if previous_alpha_cumprod_sqrt is not None:
            # 计算前一个 alpha 的累积乘积
            previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
            # 计算前一个 lamb 的对数值
            lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
            # 计算 h_last 值
            h_last = lamb - lamb_previous
            # 计算 r 值
            r = h_last / h
            # 返回 h、r、lamb 和 lamb_next
            return h, r, lamb, lamb_next
        else:
            # 返回 h、None、lamb 和 lamb_next
            return h, None, lamb, lamb_next

    # 计算乘数
    def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
        # 计算第一个乘数
        mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp()
        # 计算第二个乘数
        mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt

        # 如果存在前一个 alpha 的累积乘积
        if previous_alpha_cumprod_sqrt is not None:
            # 计算第三个乘数
            mult3 = 1 + 1 / (2 * r)
            # 计算第四个乘数
            mult4 = 1 / (2 * r)
            # 返回所有乘数
            return mult1, mult2, mult3, mult4
        else:
            # 返回前两个乘数
            return mult1, mult2

    # 执行采样步骤
    def sampler_step(
        self,
        old_denoised,
        previous_alpha_cumprod_sqrt,
        alpha_cumprod_sqrt,
        next_alpha_cumprod_sqrt,
        denoiser,
        x,
        cond,
        uc=None,
        idx=None,
        timestep=None,
        scale=None,
        scale_emb=None,
    ):
        # 使用去噪器处理输入,得到去噪后的结果
        denoised = self.denoise(
            x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb
        ).to(torch.float32)
        # 如果索引为 1,返回去噪结果
        if idx == 1:
            return denoised, denoised

        # 获取相关变量
        h, r, lamb, lamb_next = self.get_variables(
            alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
        )
        # 获取乘数
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
        ]
        # 计算噪声乘数
        mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim)

        # 计算标准化 x
        x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x)
        # 如果 old_denoised 为 None 或者下一个 alpha 的累积乘积小于阈值
        if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
            # 返回标准化的 x 和去噪后的结果
            return x_standard, denoised
        else:
            # 计算去噪后的差异
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            # 计算高级 x
            x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x)

            # 更新 x
            x = x_advanced

        # 返回最终的 x 和去噪结果
        return x, denoised
    # 定义可调用方法,接收去噪器、输入数据、条件、上采样因子及其它参数
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None):
        # 准备采样循环所需的输入及参数
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
    
        # 如果固定帧数大于0,提取前固定帧数的图像
        if self.fixed_frames > 0:
            prefix_frames = x[:, : self.fixed_frames]
        # 初始化去噪后的图像为 None
        old_denoised = None
        # 遍历生成的 sigma 值
        for i in self.get_sigma_gen(num_sigmas):
            # 如果固定帧数大于0,进行处理
            if self.fixed_frames > 0:
                # 如果启用 SD 编辑模式
                if self.sdedit:
                    # 生成与前缀帧同形状的随机噪声
                    rd = torch.randn_like(prefix_frames)
                    # 计算带噪声的前缀帧
                    noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims(
                        s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape)
                    )
                    # 将带噪声的前缀帧与剩余帧连接
                    x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1)
                else:
                    # 直接将前缀帧与剩余帧连接
                    x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
            # 执行去噪步骤
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],
                s_in * alpha_cumprod_sqrt[i],
                s_in * alpha_cumprod_sqrt[i + 1],
                denoiser,
                x,
                cond,
                uc=uc,
                idx=self.num_steps - i,
                timestep=timesteps[-(i + 1)],
                scale=scale,
                scale_emb=scale_emb,
            )
    
        # 如果固定帧数大于0,重构最终输出
        if self.fixed_frames > 0:
            x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1)
    
        # 返回最终的去噪结果
        return x
# 定义 VPODEDPMPP2MSampler 类,继承自 VideoDDIMSampler
class VPODEDPMPP2MSampler(VideoDDIMSampler):
    # 获取变量,计算当前和下一个 alpha 的平方根
    def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None):
        # 计算 alpha 的平方
        alpha_cumprod = alpha_cumprod_sqrt**2
        # 计算 lambda 值并取其对数
        lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log()
        # 计算下一个 alpha 的平方
        next_alpha_cumprod = next_alpha_cumprod_sqrt**2
        # 计算下一个 lambda 值并取其对数
        lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log()
        # 计算 h 值
        h = lamb_next - lamb

        # 如果提供了上一个 alpha 的平方根
        if previous_alpha_cumprod_sqrt is not None:
            # 计算上一个 alpha 的平方
            previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2
            # 计算上一个 lambda 值并取其对数
            lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log()
            # 计算上一个 h 值
            h_last = lamb - lamb_previous
            # 计算 r 值
            r = h_last / h
            # 返回 h, r, lamb, lamb_next
            return h, r, lamb, lamb_next
        else:
            # 如果没有上一个 alpha,返回 h 和其他计算值
            return h, None, lamb, lamb_next

    # 获取乘数
    def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt):
        # 计算第一个乘数
        mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5
        # 计算第二个乘数
        mult2 = (-h).expm1() * next_alpha_cumprod_sqrt

        # 如果提供了上一个 alpha 的平方根
        if previous_alpha_cumprod_sqrt is not None:
            # 计算第三个乘数
            mult3 = 1 + 1 / (2 * r)
            # 计算第四个乘数
            mult4 = 1 / (2 * r)
            # 返回所有乘数
            return mult1, mult2, mult3, mult4
        else:
            # 返回前两个乘数
            return mult1, mult2

    # 采样步骤
    def sampler_step(
        self,
        old_denoised,
        previous_alpha_cumprod_sqrt,
        alpha_cumprod_sqrt,
        next_alpha_cumprod_sqrt,
        denoiser,
        x,
        cond,
        uc=None,
        idx=None,
        timestep=None,
    ):
        # 使用去噪器对输入 x 进行去噪处理
        denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32)
        # 如果索引为 1,返回去噪结果
        if idx == 1:
            return denoised, denoised

        # 获取变量
        h, r, lamb, lamb_next = self.get_variables(
            alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt
        )
        # 获取乘数并调整维度
        mult = [
            append_dims(mult, x.ndim)
            for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt)
        ]

        # 计算标准化的 x
        x_standard = mult[0] * x - mult[1] * denoised
        # 如果没有旧的去噪结果或下一个 alpha 的平方根总和接近 0
        if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14:
            # 返回标准化的 x 和去噪结果
            return x_standard, denoised
        else:
            # 计算去噪后的结果
            denoised_d = mult[2] * denoised - mult[3] * old_denoised
            # 计算高级的 x
            x_advanced = mult[0] * x - mult[1] * denoised_d

            # 更新 x
            x = x_advanced

        # 返回最终的 x 和去噪结果
        return x, denoised
    # 定义可调用对象,接受去噪器、输入数据、条件及其他参数
    def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs):
        # 准备采样循环所需的输入数据和参数
        x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
    
        # 初始化旧去噪结果为 None
        old_denoised = None
        # 遍历生成的 sigma 值,进行采样步骤
        for i in self.get_sigma_gen(num_sigmas):
            # 执行单步采样,并更新当前输入和旧去噪结果
            x, old_denoised = self.sampler_step(
                old_denoised,
                None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1],  # 第一步不使用旧去噪
                s_in * alpha_cumprod_sqrt[i],  # 当前 sigma 的值
                s_in * alpha_cumprod_sqrt[i + 1],  # 下一个 sigma 的值
                denoiser,  # 去噪器
                x,  # 当前输入数据
                cond,  # 条件输入
                uc=uc,  # 额外条件(可选)
                idx=self.num_steps - i,  # 当前步骤索引
                timestep=timesteps[-(i + 1)],  # 当前时间步
            )
    
        # 返回最终生成的输入数据
        return x

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\sampling_utils.py

# 导入 PyTorch 库
import torch
# 从 scipy 导入积分功能
from scipy import integrate

# 从上层模块导入 append_dims 函数
from ...util import append_dims
# 从 einops 导入 rearrange 函数,用于重排张量
from einops import rearrange


# 定义无动态阈值类
class NoDynamicThresholding:
    # 定义调用方法,接受无条件输入、条件输入和缩放因子
    def __call__(self, uncond, cond, scale):
        # 如果 scale 是张量,则调整其维度以匹配条件的维度
        scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale
        # 返回无条件输入和缩放后的条件与无条件差值之和
        return uncond + scale * (cond - uncond)


# 定义静态阈值类
class StaticThresholding:
    # 定义调用方法,接受无条件输入、条件输入和缩放因子
    def __call__(self, uncond, cond, scale):
        # 计算无条件输入和缩放后的条件与无条件差值之和
        result = uncond + scale * (cond - uncond)
        # 将结果限制在 -1.0 到 1.0 之间
        result = torch.clamp(result, min=-1.0, max=1.0)
        # 返回处理后的结果
        return result


# 定义动态阈值函数
def dynamic_threshold(x, p=0.95):
    # 获取输入张量的维度
    N, T, C, H, W = x.shape
    # 将张量重排为适合计算的形状
    x = rearrange(x, "n t c h w -> n c (t h w)")
    # 计算给定分位数的左侧和右侧阈值
    l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True)
    # 计算阈值的最大值
    s = torch.maximum(-l, r)
    # 创建阈值掩码,用于过滤
    threshold_mask = (s > 1).expand(-1, -1, H * W * T)
    # 如果阈值掩码中有任何 True 值,则应用阈值处理
    if threshold_mask.any():
        x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x)
    # 恢复张量的原始形状
    x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W)
    # 返回处理后的张量
    return x


# 定义第二种动态阈值处理函数
def dynamic_thresholding2(x0):
    p = 0.995  # 参考论文“Imagen”中的超参数
    # 保存输入张量的原始数据类型
    origin_dtype = x0.dtype
    # 将输入张量转换为浮点数
    x0 = x0.to(torch.float32)
    # 计算输入张量绝对值的分位数
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    # 将分位数与 1 进行比较并调整维度
    s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
    # 限制输入张量的值在 -s 和 s 之间
    x0 = torch.clamp(x0, -s, s)  # / s
    # 返回转换为原始数据类型的张量
    return x0.to(origin_dtype)


# 定义潜在动态阈值处理函数
def latent_dynamic_thresholding(x0):
    p = 0.9995  # 参考论文中的超参数
    # 保存输入张量的原始数据类型
    origin_dtype = x0.dtype
    # 将输入张量转换为浮点数
    x0 = x0.to(torch.float32)
    # 计算输入张量绝对值的分位数
    s = torch.quantile(torch.abs(x0), p, dim=2)
    # 调整分位数的维度以匹配输入张量
    s = append_dims(s, x0.dim())
    # 限制输入张量的值在 -s 和 s 之间,并进行归一化处理
    x0 = torch.clamp(x0, -s, s) / s
    # 返回转换为原始数据类型的张量
    return x0.to(origin_dtype)


# 定义第三种动态阈值处理函数
def dynamic_thresholding3(x0):
    p = 0.995  # 参考论文“Imagen”中的超参数
    # 保存输入张量的原始数据类型
    origin_dtype = x0.dtype
    # 将输入张量转换为浮点数
    x0 = x0.to(torch.float32)
    # 计算输入张量绝对值的分位数
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    # 将分位数与 1 进行比较并调整维度
    s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim())
    # 限制输入张量的值在 -s 和 s 之间
    x0 = torch.clamp(x0, -s, s)  # / s
    # 返回转换为原始数据类型的张量
    return x0.to(origin_dtype)


# 定义动态阈值类
class DynamicThresholding:
    # 定义调用方法,接受无条件输入、条件输入和缩放因子
    def __call__(self, uncond, cond, scale):
        # 计算无条件输入的均值和标准差
        mean = uncond.mean()
        std = uncond.std()
        # 计算无条件输入和缩放后的条件与无条件差值之和
        result = uncond + scale * (cond - uncond)
        # 计算结果的均值和标准差
        result_mean, result_std = result.mean(), result.std()
        # 标准化结果,使其具有相同的标准差
        result = (result - result_mean) / result_std * std
        # result = dynamic_thresholding3(result)  # 可选的进一步处理
        # 返回处理后的结果
        return result


# 定义动态阈值版本1类
class DynamicThresholdingV1:
    # 初始化时接收缩放因子
    def __init__(self, scale_factor):
        self.scale_factor = scale_factor
    # 定义一个函数,接受三个参数 uncond, cond, scale
    def __call__(self, uncond, cond, scale):
        # 计算结果,根据公式 uncond + scale * (cond - uncond)
        result = uncond + scale * (cond - uncond)
        # 对结果进行反缩放,除以缩放因子
        unscaled_result = result / self.scale_factor
        # 获取结果的形状信息,分别为 Batch size, Time steps, Channels, Height, Width
        B, T, C, H, W = unscaled_result.shape
        # 将结果重新排列成 "b t c h w" 的形式
        flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)")
        # 计算每个通道的均值,并在第二维度上增加一个维度
        means = flattened.mean(dim=2).unsqueeze(2)
        # 对结果进行重新中心化,减去均值
        recentered = flattened - means
        # 计算每个通道的绝对值的最大值
        magnitudes = recentered.abs().max()
        # 对结果进行归一化,除以最大值
        normalized = recentered / magnitudes
        # 对结果进行动态阈值处理
        thresholded = latent_dynamic_thresholding(normalized)
        # 对结果进行反归一化,乘以最大值
        denormalized = thresholded * magnitudes
        # 对结果进行重新中心化,加上均值
        uncentered = denormalized + means
        # 将结果重新排列成 "b c (t h w)" 的形式
        unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W)
        # 对结果进行缩放,乘以缩放因子
        scaled_result = unflattened * self.scale_factor
        # 返回缩放后的结果
        return scaled_result
# 定义一个动态阈值处理类
class DynamicThresholdingV2:
    # 定义类的调用方法,接受无条件值、条件值和缩放比例作为参数
    def __call__(self, uncond, cond, scale):
        # 获取无条件值的形状信息
        B, T, C, H, W = uncond.shape
        # 计算条件值和无条件值的差值
        diff = cond - uncond
        # 计算最小目标值
        mim_target = uncond + diff * 4.0
        # 计算配置目标值
        cfg_target = uncond + diff * 8.0

        # 将最小目标值展平为二维数组
        mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)")
        # 将配置目标值展平为二维数组
        cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)")
        # 计算最小目标值的均值
        mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
        # 计算配置目标值的均值
        cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
        # 计算最小目标值的中心化值
        mim_centered = mim_flattened - mim_means
        # 计算配置目标值的中心化值
        cfg_centered = cfg_flattened - cfg_means

        # 计算最小目标值的标准差
        mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
        # 计算配置目标值的标准差
        cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)

        # 对配置目标值进行重新归一化
        cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref

        # 将结果还原为原始形状
        result = cfg_renormalized + cfg_means
        unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W)

        return unflattened


# 定义一个线性多步系数函数
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    # 如果阶数大于i,则抛出异常
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

    # 定义一个内部函数,接受tau作为参数
    def fn(tau):
        prod = 1.0
        # 遍历阶数
        for k in range(order):
            # 如果j等于k,则跳过
            if j == k:
                continue
            # 计算乘积
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod

    # 对内部函数进行积分
    return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]


# 定义一个获取祖先步长的函数
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    # 如果eta为假,则返回sigma_to和0.0
    if not eta:
        return sigma_to, 0.0
    # 计算上行步长
    sigma_up = torch.minimum(
        sigma_to,
        eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
    )
    # 计算下行步长
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


# 定义一个转换为d的函数
def to_d(x, sigma, denoised):
    return (x - denoised) / append_dims(sigma, x.ndim)


# 定义一个转换为负对数sigma的函数
def to_neg_log_sigma(sigma):
    return sigma.log().neg()


# 定义一个转换为sigma的函数
def to_sigma(neg_log_sigma):
    return neg_log_sigma.neg().exp()

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\sigma_sampling.py

import torch
import torch.distributed

from sat import mpu

from ...util import default, instantiate_from_config


class EDMSampling:
    def __init__(self, p_mean=-1.2, p_std=1.2):
        # 初始化函数,设置默认的 p_mean 和 p_std 值
        self.p_mean = p_mean
        self.p_std = p_std

    def __call__(self, n_samples, rand=None):
        # 调用函数,生成 n_samples 个样本
        # 通过随机数生成 log_sigma,其均值为 p_mean,标准差为 p_std
        log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
        # 返回 log_sigma 的指数形式
        return log_sigma.exp()


class DiscreteSampling:
    def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False):
        # 初始化函数,设置离散化配置、索引数量、是否追加零、是否翻转和是否均匀采样
        self.num_idx = num_idx
        # 实例化离散化配置,得到 sigmas
        self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
        # 获取数据并行的世界大小
        world_size = mpu.get_data_parallel_world_size()
        self.uniform_sampling = uniform_sampling
        if self.uniform_sampling:
            # 如果进行均匀采样
            i = 1
            while True:
                if world_size % i != 0 or num_idx % (world_size // i) != 0:
                    i += 1
                else:
                    # 计算组数
                    self.group_num = world_size // i
                    break

            assert self.group_num > 0
            assert world_size % self.group_num == 0
            # 计算每组的宽度,即每组中的排名数量
            self.group_width = world_size // self.group_num  # the number of rank in one group
            # 计算每个组的 sigma 区间
            self.sigma_interval = self.num_idx // self.group_num

    def idx_to_sigma(self, idx):
        # 根据索引获取 sigma
        return self.sigmas[idx]

    def __call__(self, n_samples, rand=None, return_idx=False):
        # 调用函数,生成 n_samples 个样本
        if self.uniform_sampling:
            # 如果进行均匀采样
            rank = mpu.get_data_parallel_rank()
            # 计算组索引
            group_index = rank // self.group_width
            # 生成索引,范围为 group_index * sigma_interval 到 (group_index + 1) * sigma_interval
            idx = default(
                rand,
                torch.randint(
                    group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,)
                ),
            )
        else:
            # 如果不进行均匀采样
            # 生成索引,范围为 0 到 num_idx
            idx = default(
                rand,
                torch.randint(0, self.num_idx, (n_samples,)),
            )
        if return_idx:
            # 如果返回索引
            # 返回索引对应的 sigma 和索引
            return self.idx_to_sigma(idx), idx
        else:
            # 如果不返回索引
            # 返回索引对应的 sigma
            return self.idx_to_sigma(idx)


class PartialDiscreteSampling:
    def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
        # 初始化函数,设置离散化配置、总索引数量、部分索引数量、是否追加零和是否翻转
        self.total_num_idx = total_num_idx
        self.partial_num_idx = partial_num_idx
        # 实例化离散化配置,得到 sigmas
        self.sigmas = instantiate_from_config(discretization_config)(
            total_num_idx, do_append_zero=do_append_zero, flip=flip
        )

    def idx_to_sigma(self, idx):
        # 根据索引获取 sigma
        return self.sigmas[idx]

    def __call__(self, n_samples, rand=None):
        # 调用函数,生成 n_samples 个样本
        # 生成索引,范围为 0 到 partial_num_idx
        idx = default(
            rand,
            torch.randint(0, self.partial_num_idx, (n_samples,)),
        )
        # 返回索引对应的 sigma
        return self.idx_to_sigma(idx)

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\util.py

# 该模块源自多个开源项目,感谢它们的贡献
"""
adopted from
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
and
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
and
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py

thanks!
"""

# 导入数学库
import math
# 从 typing 导入可选类型
from typing import Optional

# 导入 PyTorch 库
import torch
import torch.nn as nn
# 从 einops 导入重排和重复函数
from einops import rearrange, repeat


# 创建 beta 调度的函数
def make_beta_schedule(
    schedule,
    n_timestep,
    linear_start=1e-4,
    linear_end=2e-2,
):
    # 如果调度方式为线性,生成 beta 值的序列
    if schedule == "linear":
        betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
    # 返回 numpy 格式的 beta 值
    return betas.numpy()


# 从张量中提取特定的值并调整形状
def extract_into_tensor(a, t, x_shape):
    # 获取 t 的第一个维度的大小
    b, *_ = t.shape
    # 根据索引 t 从 a 中提取数据
    out = a.gather(-1, t)
    # 返回调整形状后的输出
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


# 混合检查点函数,用于在不缓存中间激活的情况下评估函数
def mixed_checkpoint(func, inputs: dict, params, flag):
    """
    在不缓存中间激活的情况下评估函数,减少内存消耗,但会增加反向传播的计算量。
    该实现允许非张量输入。
    :param func: 要评估的函数。
    :param inputs: 传递给 `func` 的参数字典。
    :param params: func 依赖但不作为参数的参数序列。
    :param flag: 如果为 False,禁用梯度检查点。
    """
    # 如果启用标志,处理张量输入
    if flag:
        # 获取所有张量类型的输入键
        tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
        # 获取所有张量类型的输入值
        tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
        # 获取所有非张量类型的输入键
        non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
        # 获取所有非张量类型的输入值
        non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
        # 构建参数元组
        args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
        # 应用混合检查点函数
        return MixedCheckpointFunction.apply(
            func,
            len(tensor_inputs),
            len(non_tensor_inputs),
            tensor_keys,
            non_tensor_keys,
            *args,
        )
    else:
        # 如果禁用标志,直接调用函数
        return func(**inputs)


# 定义混合检查点函数的类
class MixedCheckpointFunction(torch.autograd.Function):
    @staticmethod
    # 定义前向传播方法
    def forward(
        ctx,
        run_function,
        length_tensors,
        length_non_tensors,
        tensor_keys,
        non_tensor_keys,
        *args,
    ):
        # 将长度张量赋值给上下文对象的属性
        ctx.end_tensors = length_tensors
        # 将长度非张量赋值给上下文对象的属性
        ctx.end_non_tensors = length_tensors + length_non_tensors
        # 创建包含 GPU 自动混合精度参数的字典
        ctx.gpu_autocast_kwargs = {
            # 检查自动混合精度是否启用
            "enabled": torch.is_autocast_enabled(),
            # 获取自动混合精度的 GPU 数据类型
            "dtype": torch.get_autocast_gpu_dtype(),
            # 检查自动混合精度缓存是否启用
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        # 断言张量键和非张量键的长度与预期相符
        assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors

        # 创建输入张量字典,将键与对应的值配对
        ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
        # 创建输入非张量字典,将键与对应的值配对
        ctx.input_non_tensors = {
            key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
        }
        # 将运行函数赋值给上下文对象的属性
        ctx.run_function = run_function
        # 将输入参数赋值为剩余的参数
        ctx.input_params = list(args[ctx.end_non_tensors :])

        # 在无梯度计算上下文中执行
        with torch.no_grad():
            # 调用运行函数,并传入输入张量和非张量
            output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors)
        # 返回输出张量
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # 创建额外参数的字典(已注释掉)
        # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
        # 将输入张量设为不跟踪梯度并要求梯度
        ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}

        # 在启用梯度计算和混合精度上下文中执行
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # 创建输入张量的浅拷贝以避免修改原张量
            shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
            # shallow_copies.update(additional_args)  # 更新额外参数(已注释掉)
            # 调用运行函数,并传入浅拷贝和非张量
            output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
        # 计算输出张量相对于输入张量和输入参数的梯度
        input_grads = torch.autograd.grad(
            output_tensors,
            list(ctx.input_tensors.values()) + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # 删除输入张量
        del ctx.input_tensors
        # 删除输入参数
        del ctx.input_params
        # 删除输出张量
        del output_tensors
        # 返回梯度和占位符
        return (
            (None, None, None, None, None)
            + input_grads[: ctx.end_tensors]
            + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
            + input_grads[ctx.end_tensors :]
        )
# 定义一个检查点函数,允许不缓存中间激活以降低内存使用
def checkpoint(func, inputs, params, flag):
    # 文档字符串,解释函数的用途和参数
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    # 如果标志为真,执行检查点功能
    if flag:
        # 将输入和参数组合为一个元组
        args = tuple(inputs) + tuple(params)
        # 应用检查点功能并返回结果
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        # 直接调用函数并返回结果
        return func(*inputs)

# 定义检查点功能类,继承自torch的自动梯度功能
class CheckpointFunction(torch.autograd.Function):
    # 静态方法,用于前向传播
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # 将运行的函数和输入张量保存到上下文中
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # 保存GPU自动混合精度的相关设置
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        # 在无梯度计算模式下执行函数
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        # 返回输出张量
        return output_tensors

    # 静态方法,用于反向传播
    @staticmethod
    def backward(ctx, *output_grads):
        # 将输入张量标记为需要梯度计算
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        # 启用梯度计算和混合精度模式
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # 创建输入张量的浅拷贝,以避免存储修改问题
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            # 执行反向传播,获取输出张量
            output_tensors = ctx.run_function(*shallow_copies)
        # 计算输入梯度
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # 清理上下文中的临时变量
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # 返回None和输入梯度
        return (None, None) + input_grads

# 定义时间步嵌入函数,创建正弦时间步嵌入
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32):
    # 文档字符串,解释函数的用途和参数
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    # 检查是否只重复使用嵌入
        if not repeat_only:
            # 计算维度的一半
            half = dim // 2
            # 计算频率值,使用指数衰减函数生成频率序列
            freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
                device=timesteps.device
            )
            # 计算时间步长与频率的乘积,形成角度参数
            args = timesteps[:, None].float() * freqs[None]
            # 计算余弦和正弦值并在最后一维合并
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            # 如果维度为奇数,添加一列全零以保持维度一致
            if dim % 2:
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        else:
            # 如果只重复,直接生成与时间步长相同的嵌入
            embedding = repeat(timesteps, "b -> b d", d=dim)
        # 将嵌入转换为指定的数据类型并返回
        return embedding.to(dtype)
# 定义一个将模块参数归零的函数,并返回该模块
def zero_module(module):
    """
    将模块的参数置为零,并返回该模块。
    """
    # 遍历模块的所有参数
    for p in module.parameters():
        # 将参数分离并置为零
        p.detach().zero_()
    # 返回处理后的模块
    return module


# 定义一个对模块参数进行缩放的函数,并返回该模块
def scale_module(module, scale):
    """
    将模块的参数进行缩放,并返回该模块。
    """
    # 遍历模块的所有参数
    for p in module.parameters():
        # 将参数分离并进行缩放
        p.detach().mul_(scale)
    # 返回处理后的模块
    return module


# 定义一个计算非批次维度均值的函数
def mean_flat(tensor):
    """
    对所有非批次维度进行求均值。
    """
    # 计算并返回张量的均值,忽略第一维(批次维度)
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


# 定义一个创建标准归一化层的函数
def normalization(channels):
    """
    创建一个标准归一化层。
    :param channels: 输入通道的数量。
    :return: 一个用于归一化的 nn.Module。
    """
    # 返回一个具有指定输入通道数的 GroupNorm32 对象
    return GroupNorm32(32, channels)


# 定义一个 SiLU 激活函数类,支持 PyTorch 1.5
class SiLU(nn.Module):
    # 定义前向传播方法
    def forward(self, x):
        # 返回 x 与其 sigmoid 值的乘积
        return x * torch.sigmoid(x)


# 定义一个扩展自 nn.GroupNorm 的 GroupNorm32 类
class GroupNorm32(nn.GroupNorm):
    # 定义前向传播方法
    def forward(self, x):
        # 调用父类的 forward 方法,并返回与输入相同数据类型的结果
        return super().forward(x).type(x.dtype)


# 定义一个创建卷积模块的函数,支持 1D、2D 和 3D 卷积
def conv_nd(dims, *args, **kwargs):
    """
    创建一个 1D、2D 或 3D 卷积模块。
    """
    # 判断维度是否为 1
    if dims == 1:
        # 返回 1D 卷积模块
        return nn.Conv1d(*args, **kwargs)
    # 判断维度是否为 2
    elif dims == 2:
        # 返回 2D 卷积模块
        return nn.Conv2d(*args, **kwargs)
    # 判断维度是否为 3
    elif dims == 3:
        # 返回 3D 卷积模块
        return nn.Conv3d(*args, **kwargs)
    # 抛出不支持的维度错误
    raise ValueError(f"unsupported dimensions: {dims}")


# 定义一个创建线性模块的函数
def linear(*args, **kwargs):
    """
    创建一个线性模块。
    """
    # 返回一个线性模块
    return nn.Linear(*args, **kwargs)


# 定义一个创建平均池化模块的函数,支持 1D、2D 和 3D 池化
def avg_pool_nd(dims, *args, **kwargs):
    """
    创建一个 1D、2D 或 3D 平均池化模块。
    """
    # 判断维度是否为 1
    if dims == 1:
        # 返回 1D 平均池化模块
        return nn.AvgPool1d(*args, **kwargs)
    # 判断维度是否为 2
    elif dims == 2:
        # 返回 2D 平均池化模块
        return nn.AvgPool2d(*args, **kwargs)
    # 判断维度是否为 3
    elif dims == 3:
        # 返回 3D 平均池化模块
        return nn.AvgPool3d(*args, **kwargs)
    # 抛出不支持的维度错误
    raise ValueError(f"unsupported dimensions: {dims}")


# 定义一个 AlphaBlender 类,用于实现不同的混合策略
class AlphaBlender(nn.Module):
    # 定义支持的混合策略
    strategies = ["learned", "fixed", "learned_with_images"]

    # 初始化方法,设置混合因子和重排模式
    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        # 保存混合策略和重排模式
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        # 确保混合策略在支持的策略中
        assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"

        # 根据混合策略注册混合因子
        if self.merge_strategy == "fixed":
            # 注册一个固定的混合因子
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
            # 注册一个可学习的混合因子
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
        else:
            # 抛出未知混合策略错误
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
    # 定义获取 alpha 值的函数,输入为图像指示器,返回一个张量
        def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
            # 根据合并策略选择 alpha 值的计算方式
            if self.merge_strategy == "fixed":
                # 如果合并策略为固定值,则直接使用 mix_factor 作为 alpha
                alpha = self.mix_factor
            elif self.merge_strategy == "learned":
                # 如果合并策略为学习得到的值,则对 mix_factor 应用 sigmoid 函数
                alpha = torch.sigmoid(self.mix_factor)
            elif self.merge_strategy == "learned_with_images":
                # 如果合并策略为图像学习,需要确保提供图像指示器
                assert image_only_indicator is not None, "need image_only_indicator ..."
                # 根据图像指示器选择 alpha 值,真值对应 1,假值则使用 mix_factor 的 sigmoid 结果
                alpha = torch.where(
                    image_only_indicator.bool(),  # 将图像指示器转换为布尔值
                    torch.ones(1, 1, device=image_only_indicator.device),  # 为真值创建全 1 的张量
                    rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),  # 为假值应用 sigmoid 并调整维度
                )
                # 根据 rearrange_pattern 重新排列 alpha 的维度
                alpha = rearrange(alpha, self.rearrange_pattern)
            else:
                # 如果合并策略不在已知范围内,抛出未实现错误
                raise NotImplementedError
            # 返回计算得到的 alpha 值
            return alpha
    
    # 定义前向传播函数,输入为空间和时间张量,以及可选的图像指示器
        def forward(
            self,
            x_spatial: torch.Tensor,  # 空间输入张量
            x_temporal: torch.Tensor,  # 时间输入张量
            image_only_indicator: Optional[torch.Tensor] = None,  # 可选的图像指示器
        ) -> torch.Tensor:
            # 调用 get_alpha 函数获取 alpha 值
            alpha = self.get_alpha(image_only_indicator)
            # 计算最终输出张量,结合空间和时间输入,使用 alpha 进行加权
            x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
            # 返回计算得到的输出张量
            return x

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\wrappers.py

# 导入 PyTorch 库
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 从 packaging 库导入 version 模块以便进行版本比较
from packaging import version

# 定义 OpenAIWrapper 的模块路径
OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"


# 定义身份包装器类,继承自 nn.Module
class IdentityWrapper(nn.Module):
    # 初始化函数,接收扩散模型、是否编译模型的标志和数据类型
    def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
        # 调用父类的初始化函数
        super().__init__()
        # 判断 PyTorch 版本,选择是否编译模型
        compile = (
            torch.compile
            if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model
            else lambda x: x  # 如果不满足条件,返回原始输入
        )
        # 编译扩散模型并赋值给实例变量
        self.diffusion_model = compile(diffusion_model)
        # 设置数据类型
        self.dtype = dtype

    # 前向传播函数,接收任意数量的位置参数和关键字参数
    def forward(self, *args, **kwargs):
        # 调用扩散模型的前向传播,并返回结果
        return self.diffusion_model(*args, **kwargs)


# 定义 OpenAIWrapper 类,继承自 IdentityWrapper
class OpenAIWrapper(IdentityWrapper):
    # 重写前向传播函数,接收输入张量 x、时间步 t、上下文字典 c 以及其他关键字参数
    def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor:
        # 将上下文字典中的每个张量转换为指定的数据类型
        for key in c:
            c[key] = c[key].to(self.dtype)

        # 如果输入张量是 4 维,按维度 1 拼接上下文中的 "concat" 张量
        if x.dim() == 4:
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        # 如果输入张量是 5 维,按维度 2 拼接上下文中的 "concat" 张量
        elif x.dim() == 5:
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2)
        # 如果输入张量的维度不符合要求,抛出值错误
        else:
            raise ValueError("Input tensor must be 4D or 5D")

        # 调用扩散模型的前向传播,传入处理后的张量、时间步和上下文等
        return self.diffusion_model(
            x,
            timesteps=t,
            context=c.get("crossattn", None),
            y=c.get("vector", None),
            **kwargs,
        )

.\cogvideo-finetune\sat\sgm\modules\diffusionmodules\__init__.py

# 从当前包的 denoiser 模块导入 Denoiser 类
from .denoiser import Denoiser
# 从当前包的 discretizer 模块导入 Discretization 类
from .discretizer import Discretization
# 从当前包的 model 模块导入 Decoder、Encoder 和 Model 类
from .model import Decoder, Encoder, Model
# 从当前包的 openaimodel 模块导入 UNetModel 类
from .openaimodel import UNetModel
# 从当前包的 sampling 模块导入 BaseDiffusionSampler 类
from .sampling import BaseDiffusionSampler
# 从当前包的 wrappers 模块导入 OpenAIWrapper 类
from .wrappers import OpenAIWrapper

.\cogvideo-finetune\sat\sgm\modules\distributions\distributions.py

# 导入 numpy 库并简化为 np
import numpy as np
# 导入 PyTorch 库
import torch


# 定义抽象类 AbstractDistribution
class AbstractDistribution:
    # 抽象方法,生成样本
    def sample(self):
        raise NotImplementedError()

    # 抽象方法,返回分布的众数
    def mode(self):
        raise NotImplementedError()


# 定义 DiracDistribution 类,继承自 AbstractDistribution
class DiracDistribution(AbstractDistribution):
    # 初始化 DiracDistribution,接受一个值
    def __init__(self, value):
        self.value = value

    # 重写 sample 方法,返回固定值
    def sample(self):
        return self.value

    # 重写 mode 方法,返回固定值
    def mode(self):
        return self.value


# 定义 DiagonalGaussianDistribution 类
class DiagonalGaussianDistribution(object):
    # 初始化,接受参数和一个确定性标志
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        # 将参数分为均值和对数方差
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # 将对数方差限制在指定范围内
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        # 计算标准差和方差
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        # 如果是确定性模式,将方差和标准差设为零
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    # 生成样本
    def sample(self):
        # 使用均值和标准差生成随机样本
        x = self.mean + self.std * torch.randn_like(self.mean).to(device=self.parameters.device)
        return x

    # 计算 KL 散度
    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])  # 确定性时 KL 散度为 0
        else:
            if other is None:
                # 计算与标准正态分布的 KL 散度
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                # 计算两个分布间的 KL 散度
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    # 计算负对数似然
    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])  # 确定性时 NLL 为 0
        logtwopi = np.log(2.0 * np.pi)  # 计算 2π 的对数
        # 计算负对数似然值
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    # 返回均值
    def mode(self):
        return self.mean


# 定义 normal_kl 函数,计算两个高斯分布之间的 KL 散度
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    计算两个高斯分布之间的 KL 散度。
    形状会自动广播,支持批量比较和标量等用例。
    """
    tensor = None
    # 找到第一个是 Tensor 的对象
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    # 确保至少有一个参数是 Tensor
    assert tensor is not None, "at least one argument must be a Tensor"

    # 强制将方差转换为 Tensor
    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
    # 计算并返回某种损失值,具体公式由多个部分组成
        return 0.5 * (  # 返回整个计算结果的一半
            -1.0 + logvar2 - logvar1 +  # 计算对数方差差值,并减去常数项
            torch.exp(logvar1 - logvar2) +  # 计算指数项,表示方差的相对差异
            ((mean1 - mean2) ** 2) * torch.exp(-logvar2)  # 计算均值差的平方乘以方差的倒数
        )

.\cogvideo-finetune\sat\sgm\modules\distributions\__init__.py

请提供需要注释的代码。

.\cogvideo-finetune\sat\sgm\modules\ema.py

# 导入 PyTorch 库
import torch
# 从 PyTorch 导入神经网络模块
from torch import nn


# 定义一个名为 LitEma 的类,继承自 nn.Module
class LitEma(nn.Module):
    # 初始化函数,接受模型、衰减因子和更新次数的使用标志
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        # 调用父类构造函数
        super().__init__()
        # 检查衰减因子是否在有效范围内
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        # 初始化模型参数名称到阴影参数名称的映射字典
        self.m_name2s_name = {}
        # 注册衰减因子为一个缓冲区
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        # 根据是否使用更新次数注册相应的缓冲区
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
        )

        # 遍历模型的每个命名参数
        for name, p in model.named_parameters():
            # 只处理需要梯度的参数
            if p.requires_grad:
                # 将参数名称中的点替换为字符,以便注册为缓冲区
                s_name = name.replace(".", "")
                # 更新名称映射字典
                self.m_name2s_name.update({name: s_name})
                # 注册参数的副本为缓冲区
                self.register_buffer(s_name, p.clone().detach().data)

        # 初始化存储的参数列表
        self.collected_params = []

    # 重置更新次数的函数
    def reset_num_updates(self):
        # 删除当前的更新次数缓冲区
        del self.num_updates
        # 注册新的更新次数缓冲区,初始值为 0
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    # 前向传播函数,更新阴影参数
    def forward(self, model):
        # 获取当前的衰减因子
        decay = self.decay

        # 如果更新次数有效,更新衰减因子
        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        # 计算 1 减去衰减因子
        one_minus_decay = 1.0 - decay

        # 在不计算梯度的情况下执行以下操作
        with torch.no_grad():
            # 获取模型的参数字典
            m_param = dict(model.named_parameters())
            # 获取当前缓冲区中的阴影参数字典
            shadow_params = dict(self.named_buffers())

            # 遍历模型参数
            for key in m_param:
                # 只处理需要梯度的参数
                if m_param[key].requires_grad:
                    # 获取阴影参数名称
                    sname = self.m_name2s_name[key]
                    # 将阴影参数转换为与模型参数相同的类型
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # 更新阴影参数
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    # 确保该参数不在名称映射中
                    assert not key in self.m_name2s_name

    # 将阴影参数复制回模型的函数
    def copy_to(self, model):
        # 获取模型参数和阴影参数的字典
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        # 遍历模型参数
        for key in m_param:
            # 只处理需要梯度的参数
            if m_param[key].requires_grad:
                # 将阴影参数的数据复制到模型参数
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                # 确保该参数不在名称映射中
                assert not key in self.m_name2s_name

    # 存储当前参数以便稍后恢复的函数
    def store(self, parameters):
        """
        保存当前参数以便稍后恢复。
        参数:
          parameters: 可迭代的 `torch.nn.Parameter`;要临时存储的参数。
        """
        # 将当前参数的副本存储在列表中
        self.collected_params = [param.clone() for param in parameters]
    # 定义一个恢复参数的方法
    def restore(self, parameters):
        # 文档字符串,说明此方法的作用及参数
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        # 遍历收集的参数和传入的参数,进行一一对应
        for c_param, param in zip(self.collected_params, parameters):
            # 将收集的参数数据复制到当前参数
            param.data.copy_(c_param.data)

.\cogvideo-finetune\sat\sgm\modules\encoders\modules.py

# 导入数学库
import math
# 导入上下文管理器的空上下文
from contextlib import nullcontext
# 导入部分函数的工具
from functools import partial
# 导入类型提示相关的类型
from typing import Dict, List, Optional, Tuple, Union

# 导入 Kornia 图像处理库
import kornia
# 导入 NumPy 库
import numpy as np
# 导入 PyTorch 库
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 导入 einops 库的重排列和重复函数
from einops import rearrange, repeat
# 导入 OmegaConf 库的 ListConfig
from omegaconf import ListConfig
# 导入 PyTorch 的检查点工具
from torch.utils.checkpoint import checkpoint
# 导入 Hugging Face 的 T5 模型和分词器
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
)

# 从自定义工具模块导入多个函数
from ...util import (
    append_dims,
    autocast,
    count_params,
    default,
    disabled_train,
    expand_dims_like,
    instantiate_from_config,
)


# 定义一个抽象的嵌入模型类,继承自 PyTorch 的 nn.Module
class AbstractEmbModel(nn.Module):
    # 初始化方法
    def __init__(self):
        # 调用父类初始化方法
        super().__init__()
        # 初始化是否可训练的标志
        self._is_trainable = None
        # 初始化 UCG 速率
        self._ucg_rate = None
        # 初始化输入键
        self._input_key = None

    # 返回是否可训练的属性
    @property
    def is_trainable(self) -> bool:
        return self._is_trainable

    # 返回 UCG 速率属性,可能是浮点数或张量
    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate

    # 返回输入键属性
    @property
    def input_key(self) -> str:
        return self._input_key

    # 设置是否可训练的属性
    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    # 设置 UCG 速率属性
    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    # 设置输入键属性
    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    # 删除是否可训练的属性
    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    # 删除 UCG 速率属性
    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    # 删除输入键属性
    @input_key.deleter
    def input_key(self):
        del self._input_key


# 定义一个通用条件器类,继承自 PyTorch 的 nn.Module
class GeneralConditioner(nn.Module):
    # 定义输出维度到键的映射
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    # 定义键到拼接维度的映射
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
    # 初始化方法,接受嵌入模型配置及相关参数
        def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]):
            # 调用父类初始化方法
            super().__init__()
            # 用于存储嵌入模型实例的列表
            embedders = []
            # 遍历每个嵌入模型的配置
            for n, embconfig in enumerate(emb_models):
                # 根据配置实例化嵌入模型
                embedder = instantiate_from_config(embconfig)
                # 确保实例是 AbstractEmbModel 的子类
                assert isinstance(
                    embedder, AbstractEmbModel
                ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
                # 设置嵌入模型是否可训练
                embedder.is_trainable = embconfig.get("is_trainable", False)
                # 设置嵌入模型的 ucg_rate 参数
                embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
                # 如果模型不可训练
                if not embedder.is_trainable:
                    # 禁用训练方法
                    embedder.train = disabled_train
                    # 将模型参数的 requires_grad 属性设为 False
                    for param in embedder.parameters():
                        param.requires_grad = False
                    # 将模型设置为评估模式
                    embedder.eval()
                # 打印嵌入模型的初始化信息
                print(
                    f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                    f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
                )
    
                # 检查配置中是否有 input_key,并赋值给嵌入模型
                if "input_key" in embconfig:
                    embedder.input_key = embconfig["input_key"]
                # 检查配置中是否有 input_keys,并赋值给嵌入模型
                elif "input_keys" in embconfig:
                    embedder.input_keys = embconfig["input_keys"]
                # 如果都没有,抛出 KeyError
                else:
                    raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")
    
                # 设置嵌入模型的 legacy_ucg_value 参数
                embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
                # 如果有 legacy_ucg_val,则初始化随机状态
                if embedder.legacy_ucg_val is not None:
                    embedder.ucg_prng = np.random.RandomState()
    
                # 将嵌入模型添加到列表中
                embedders.append(embedder)
            # 将嵌入模型列表转换为 nn.ModuleList
            self.embedders = nn.ModuleList(embedders)
    
            # 如果有 cor_embs,确保 cor_p 的长度正确
            if len(cor_embs) > 0:
                assert len(cor_p) == 2 ** len(cor_embs)
            # 设置相关嵌入和参数
            self.cor_embs = cor_embs
            self.cor_p = cor_p
    
        # 获取 UCG 值的方法,可能会基于概率进行赋值
        def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
            # 确保 legacy_ucg_val 不为 None
            assert embedder.legacy_ucg_val is not None
            # 获取 ucg_rate 参数
            p = embedder.ucg_rate
            # 获取 legacy_ucg_val 值
            val = embedder.legacy_ucg_val
            # 遍历 batch 中的输入数据
            for i in range(len(batch[embedder.input_key])):
                # 根据概率选择是否替换为 legacy_ucg_val
                if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                    batch[embedder.input_key][i] = val
            # 返回更新后的 batch
            return batch
    
        # 获取 UCG 值的方法,基于条件进行赋值
        def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict:
            # 确保 legacy_ucg_val 不为 None
            assert embedder.legacy_ucg_val is not None
            # 获取 legacy_ucg_val 值
            val = embedder.legacy_ucg_val
            # 遍历 batch 中的输入数据
            for i in range(len(batch[embedder.input_key])):
                # 如果条件为真,则替换为 legacy_ucg_val
                if cond_or_not[i]:
                    batch[embedder.input_key][i] = val
            # 返回更新后的 batch
            return batch
    
        # 获取单个嵌入的方法
        def get_single_embedding(
            self,
            embedder,
            batch,
            output,
            cond_or_not: Optional[np.ndarray] = None,
            force_zero_embeddings: Optional[List] = None,
    ):
        # 根据 embedder 是否可训练选择适当的上下文管理器
        embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
        # 进入上下文管理器以控制梯度计算
        with embedding_context():
            # 检查 embedder 是否有输入键属性并且不为 None
            if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                # 检查 embedder 的遗留 UCG 值是否不为 None
                if embedder.legacy_ucg_val is not None:
                    # 如果条件为 None,获取可能的 UCG 值
                    if cond_or_not is None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    # 否则,确保获取 UCG 值
                    else:
                        batch = self.surely_get_ucg_val(embedder, batch, cond_or_not)
                # 使用指定的输入键从 batch 中获取嵌入输出
                emb_out = embedder(batch[embedder.input_key])
            # 检查 embedder 是否有输入键列表
            elif hasattr(embedder, "input_keys"):
                # 解包 batch 中的输入键以获取嵌入输出
                emb_out = embedder(*[batch[k] for k in embedder.input_keys])
        # 确保嵌入输出是张量或序列类型
        assert isinstance(
            emb_out, (torch.Tensor, list, tuple)
        ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
        # 如果嵌入输出不是列表或元组,则将其转换为列表
        if not isinstance(emb_out, (list, tuple)):
            emb_out = [emb_out]
        # 遍历嵌入输出中的每个嵌入
        for emb in emb_out:
            # 根据嵌入的维度获取对应的输出键
            out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
            # 如果 UCG 率大于 0 且没有遗留 UCG 值
            if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                # 如果条件为 None,随机应用 UCG 率
                if cond_or_not is None:
                    emb = (
                        expand_dims_like(
                            torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)),
                            emb,
                        )
                        * emb
                    )
                # 否则,基于条件应用 UCG
                else:
                    emb = (
                        expand_dims_like(
                            torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device),
                            emb,
                        )
                        * emb
                    )
            # 如果输入键在强制零嵌入列表中,将嵌入置为零
            if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
                emb = torch.zeros_like(emb)
            # 如果输出中已有该输出键,则拼接嵌入
            if out_key in output:
                output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
            # 否则,直接保存嵌入
            else:
                output[out_key] = emb
        # 返回最终输出字典
        return output
    # 定义一个前向传播函数,接收批次数据和强制零嵌入列表
        def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
            # 初始化输出字典
            output = dict()
            # 如果没有强制零嵌入,初始化为空列表
            if force_zero_embeddings is None:
                force_zero_embeddings = []
    
            # 如果相关嵌入列表不为空
            if len(self.cor_embs) > 0:
                # 获取批次大小
                batch_size = len(batch[list(batch.keys())[0]])
                # 根据相关概率随机选择索引
                rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p)
                # 遍历相关嵌入索引
                for emb_idx in self.cor_embs:
                    # 计算条件是否满足
                    cond_or_not = rand_idx % 2
                    # 更新随机索引
                    rand_idx //= 2
                    # 获取单个嵌入并更新输出字典
                    output = self.get_single_embedding(
                        self.embedders[emb_idx],
                        batch,
                        output=output,
                        cond_or_not=cond_or_not,
                        force_zero_embeddings=force_zero_embeddings,
                    )
    
            # 遍历所有嵌入
            for i, embedder in enumerate(self.embedders):
                # 如果当前索引在相关嵌入列表中,则跳过
                if i in self.cor_embs:
                    continue
                # 获取单个嵌入并更新输出字典
                output = self.get_single_embedding(
                    embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings
                )
            # 返回最终输出字典
            return output
    
        # 定义获取无条件条件的函数
        def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None):
            # 如果没有强制无条件嵌入,初始化为空列表
            if force_uc_zero_embeddings is None:
                force_uc_zero_embeddings = []
            # 初始化无条件生成率列表
            ucg_rates = list()
            # 遍历所有嵌入,保存其生成率并将生成率设置为零
            for embedder in self.embedders:
                ucg_rates.append(embedder.ucg_rate)
                embedder.ucg_rate = 0.0
            # 保存当前相关嵌入和概率
            cor_embs = self.cor_embs
            cor_p = self.cor_p
            # 清空相关嵌入和概率
            self.cor_embs = []
            self.cor_p = []
    
            # 计算条件输出
            c = self(batch_c)
            # 计算无条件输出
            uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
    
            # 恢复每个嵌入的生成率
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
            # 恢复相关嵌入和概率
            self.cor_embs = cor_embs
            self.cor_p = cor_p
    
            # 返回条件输出和无条件输出
            return c, uc
# 定义一个名为 FrozenT5Embedder 的类,继承自 AbstractEmbModel
class FrozenT5Embedder(AbstractEmbModel):
    """使用 T5 变换器编码器处理文本"""

    # 初始化方法,设置模型的基本参数
    def __init__(
        self,
        model_dir="google/t5-v1_1-xxl",  # 模型目录,默认是 T5 模型
        device="cuda",                   # 设备设置,默认使用 GPU
        max_length=77,                   # 输入文本的最大长度
        freeze=True,                     # 是否冻结模型参数,默认是冻结
        cache_dir=None,                  # 缓存目录,默认无
    ):
        super().__init__()               # 调用父类的初始化方法
        # 检查模型目录是否为默认 T5 模型
        if model_dir is not "google/t5-v1_1-xxl":
            # 从指定目录加载 tokenizer 和 transformer 模型
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir)
        else:
            # 从指定目录加载 tokenizer 和 transformer 模型,同时指定缓存目录
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir)
        # 设置设备
        self.device = device
        # 设置最大输入长度
        self.max_length = max_length
        # 如果需要冻结模型参数,调用 freeze 方法
        if freeze:
            self.freeze()

    # 定义冻结方法
    def freeze(self):
        # 将 transformer 设置为评估模式
        self.transformer = self.transformer.eval()

        # 遍历模型参数,设置为不需要梯度更新
        for param in self.parameters():
            param.requires_grad = False

    # @autocast 注解(注释掉的),用于自动混合精度计算
    def forward(self, text):
        # 使用 tokenizer 对输入文本进行编码,返回批处理编码结果
        batch_encoding = self.tokenizer(
            text,
            truncation=True,                # 超出最大长度时进行截断
            max_length=self.max_length,     # 设置最大长度
            return_length=True,             # 返回编码长度
            return_overflowing_tokens=False, # 不返回溢出的 tokens
            padding="max_length",           # 填充到最大长度
            return_tensors="pt",           # 返回 PyTorch 张量
        )
        # 将输入 id 转移到指定设备
        tokens = batch_encoding["input_ids"].to(self.device)
        # 在禁用自动混合精度的上下文中进行前向传播
        with torch.autocast("cuda", enabled=False):
            # 使用 transformer 进行前向传播,获取输出
            outputs = self.transformer(input_ids=tokens)
        # 获取 transformer 输出的最后隐藏状态
        z = outputs.last_hidden_state
        # 返回最后隐藏状态
        return z

    # 定义编码方法,直接调用 forward 方法
    def encode(self, text):
        return self(text)  # 将输入文本传递给 forward 方法进行编码

标签:None,CogVideoX,self,torch,---,源码,return,sigma,def
From: https://www.cnblogs.com/apachecn/p/18494408

相关文章

  • CogVideo---CogVideoX-微调代码源码解析-七-
    CogVideo&CogVideoX微调代码源码解析(七).\cogvideo-finetune\sat\sgm\modules\autoencoding\regularizers\base.py#导入抽象方法和类型注解fromabcimportabstractmethod#导入任意类型和元组类型fromtypingimportAny,Tuple#导入PyTorch和功能模块importtorc......
  • CogVideo---CogVideoX-微调代码源码解析-六-
    CogVideo&CogVideoX微调代码源码解析(六).\cogvideo-finetune\sat\sgm\modules\autoencoding\losses\discriminator_loss.py#导入所需的类型定义fromtypingimportDict,Iterator,List,Optional,Tuple,Union#导入NumPy库importnumpyasnp#导入PyTorch库im......
  • CogVideo---CogVideoX-微调代码源码解析-九-
    CogVideo&CogVideoX微调代码源码解析(九).\cogvideo-finetune\sat\sgm\modules\diffusionmodules\guiders.py#导入logging模块以进行日志记录importlogging#从abc模块导入ABC和abstractmethod以定义抽象基类和抽象方法fromabcimportABC,abstractmethod#从......
  • CogVideo---CogVideoX-微调代码源码解析-二-
    CogVideo&CogVideoX微调代码源码解析(二).\cogvideo-finetune\inference\cli_demo.py#脚本说明:演示如何使用CogVideoX模型通过HuggingFace`diffusers`管道生成视频"""ThisscriptdemonstrateshowtogenerateavideousingtheCogVideoXmodelwiththeHuggingF......
  • CogVideo---CogVideoX-微调代码源码解析-八-
    CogVideo&CogVideoX微调代码源码解析(八).\cogvideo-finetune\sat\sgm\modules\autoencoding\vqvae\movq_enc_3d.py#pytorch_diffusion+derivedencoderdecoderimportmath#导入数学库,提供数学函数importtorch#导入PyTorch库,用于深度学习importtorch.nnasnn......
  • CogView3---CogView-3Plus-微调代码源码解析-五-
    CogView3&CogView-3Plus微调代码源码解析(五).\cogview3-finetune\sat\sgm\modules\encoders\__init__.py请提供需要注释的代码。.\cogview3-finetune\sat\sgm\modules\__init__.py#从当前包的编码器模块导入GeneralConditioner类from.encoders.modulesimportGeneral......
  • CogView3---CogView-3Plus-微调代码源码解析-四-
    CogView3&CogView-3Plus微调代码源码解析(四).\cogview3-finetune\sat\sgm\modules\diffusionmodules\sampling_utils.py#导入数学库以进行数学运算importmath#导入PyTorch库以进行张量操作importtorch#从SciPy库导入积分函数fromscipyimportintegrate#从......
  • CogView3---CogView-3Plus-微调代码源码解析-三-
    CogView3&CogView-3Plus微调代码源码解析(三).\cogview3-finetune\sat\sgm\modules\diffusionmodules\guiders.py#导入logging模块,用于记录日志信息importlogging#从abc模块导入ABC类和abstractmethod装饰器,用于定义抽象基类和抽象方法fromabcimportABC,abst......
  • CogView3---CogView-3Plus-微调代码源码解析-二-
    CogView3&CogView-3Plus微调代码源码解析(二).\cogview3-finetune\sat\sgm\models\__init__.py#从同一模块导入AutoencodingEngine类,用于后续的自动编码器操作from.autoencoderimportAutoencodingEngine#注释文本(可能是无关信息或标识符)#XuDwndGaCFo.\cogview3-fi......
  • selenium单例模式下 docker-chrome 多线程并发代码
    最近需要写爬虫,在解决docker-standalone-chrome发现只能有一个chrome被执行。所以写了这个多线程并发控制类来管理。当模板记录下。#!/usr/bin/envpython3importthreadingimporttracebackfromloguruimportloggerfromseleniumimportwebdriverfromselenium.comm......