首页 > 编程语言 >CogView3---CogView-3Plus-微调代码源码解析-四-

CogView3---CogView-3Plus-微调代码源码解析-四-

时间:2024-10-23 09:22:40浏览次数:1  
标签:__ return 3Plus CogView self torch 源码 key def

CogView3 & CogView-3Plus 微调代码源码解析(四)

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\sampling_utils.py

# 导入数学库以进行数学运算
import math
# 导入 PyTorch 库以进行张量操作
import torch
# 从 SciPy 库导入积分函数
from scipy import integrate

# 从上级目录的 util 模块导入 append_dims 函数
from ...util import append_dims

# 定义一个不进行动态阈值处理的类
class NoDynamicThresholding:
    # 定义类的调用方法
    def __call__(self, uncond, cond, scale, **kwargs):
        # 返回无条件和有条件的加权差值
        return uncond + scale * (cond - uncond)

# 定义一个重新缩放阈值处理的类
class RescaleThresholding:
    # 初始化类,设置 phi 的默认值
    def __init__(self, phi=0.7):
        # 保存 phi 参数
        self.phi = phi

    # 定义类的调用方法
    def __call__(self, uncond, cond, scale, **kwargs):
        # 计算去噪后的配置
        denoised_cfg = uncond + scale * (cond - uncond)
        # 计算条件和去噪配置的标准差
        sigma_pos, sigma_cfg = cond.std(), denoised_cfg.std()
        # 计算缩放因子
        factor = self.phi * sigma_pos / sigma_cfg + (1 - self.phi)
        # 根据因子调整去噪结果
        denoised_final = denoised_cfg * factor
        # 返回最终去噪结果
        return denoised_final

# 定义一个动态阈值处理的类
class DynamicThresholding:
    # 定义可选模式的列表
    Modes = ["Constant", "Linear Up", "Linear Down", "Half Cosine Up", "Half Cosine Down", "Power Up", "Power Down", "Cosine Down","Cosine Up"]
    # 初始化类并设置参数
    def __init__(self, interpret_mode, 
                 scale_min = 3,
                 mimic_interpret_mode = 'Constant',
                 mimic_scale = 3, 
                 mimic_scale_min = 3, 
                 threshold_percentile = 1.0,
                 phi = 1.0,
                 separate_feature_channels = True,
                 measure = 'AD',
                 scaling_startpoint = 'ZERO',
                 ):
        # 验证解释模式是否在可选模式中
        assert interpret_mode in self.Modes
        # 验证模仿解释模式是否在可选模式中
        assert mimic_interpret_mode in self.Modes
        # 验证测量方法是否合法
        assert measure in ['AD', 'STD']
        # 验证缩放起点是否合法
        assert scaling_startpoint in ['ZERO', 'MEAN']
        # 保存各种初始化参数
        self.mode = interpret_mode
        self.mimic_mode = mimic_interpret_mode
        self.scale_min = scale_min
        self.mimic_scale = mimic_scale
        self.mimic_scale_min = mimic_scale_min
        self.threshold_percentile = threshold_percentile
        self.phi = phi
        self.separate_feature_channels = separate_feature_channels
        self.measure = measure
        self.scaling_startpoint = scaling_startpoint
    
    # 定义解释缩放的方法
    def interpret_scale(self, mode, scale, scale_min, step, num_steps):
        """
        num_steps = 50
        step from 0 to 50.
        """
        # 将缩放值减去最小缩放值
        scale -= scale_min
        # 计算当前步骤的比例
        frac = step / num_steps
        # 根据模式调整缩放值
        if mode == 'Constant':
            pass
        elif mode == "Linear Up":
            scale *= frac
        elif mode == "Linear Down":
            scale *= 1.0 - frac
        elif mode == "Half Cosine Up":
            scale *= 1.0 - math.cos(frac)
        elif mode == "Half Cosine Down":
            scale *= math.cos(frac)
        elif mode == "Cosine Down":
            scale *= math.cos(frac * 1.5707)
        elif mode == "Cosine Up":
            scale *= 1.0 - math.cos(frac * 1.5707)
        elif mode == "Power Up":
            scale *= math.pow(frac, 2.0)
        elif mode == "Power Down":
            scale *= 1.0 - math.pow(frac, 2.0)
        # 将调整后的缩放值加回最小缩放值
        scale += scale_min
        # 返回最终的缩放值
        return scale
    # 定义调用方法,接受无条件和条件输入,以及缩放和步骤参数
    def __call__(self, uncond, cond, scale, step, num_steps):
        # 根据当前模式解释缩放参数,计算 cfg_scale
        cfg_scale = self.interpret_scale(self.mode, scale, self.scale_min, step, num_steps)
        # 根据模拟模式解释缩放参数,计算 mimic_cfg_scale
        mimic_cfg_scale = self.interpret_scale(self.mimic_mode, self.mimic_scale, self.mimic_scale_min, step, num_steps)
    
        # 计算 x,作为无条件输入和条件输入之间的线性插值
        x = uncond + cfg_scale*(cond - uncond)
        # 计算 mimic_x,作为无条件输入和条件输入之间的线性插值,使用 mimic_cfg_scale
        mimic_x = uncond + mimic_cfg_scale*(cond - uncond)  
    
        # 将 x 展平,以便于后续操作
        x_flattened = x.flatten(2)
        # 将 mimic_x 展平,以便于后续操作
        mimic_x_flattened = mimic_x.flatten(2)
        
        # 根据缩放起始点的选择,计算均值并中心化
        if self.scaling_startpoint == 'MEAN':
            # 计算 x 的均值,保留维度
            x_means = x_flattened.mean(dim=2, keepdim = True)
            # 计算 mimic_x 的均值,保留维度
            mimic_x_means = mimic_x_flattened.mean(dim=2, keepdim = True)
            # 通过均值中心化 x
            x_centered = x_flattened - x_means
            # 通过均值中心化 mimic_x
            mimic_x_centered = mimic_x_flattened - mimic_x_means
        else:
            # 如果不使用均值中心化,直接赋值
            x_centered = x_flattened
            mimic_x_centered = mimic_x_flattened
                
        # 根据是否分开特征通道的选项,计算尺度参考
        if self.separate_feature_channels:
            # 如果测量方式为绝对差异
            if self.measure == 'AD':
                # 计算 x 的绝对差异的分位数作为尺度参考
                x_scaleref = torch.quantile(x_centered.abs(), self.threshold_percentile, dim=2, keepdim = True)
                # 计算 mimic_x 的绝对差异的最大值作为尺度参考
                mimic_x_scaleref = mimic_x_centered.abs().max(dim=2, keepdim = True).values
            # 如果测量方式为标准差
            elif self.measure == 'STD':
                # 计算 x 的标准差作为尺度参考
                x_scaleref = x_centered.std(dim=2, keepdim = True)
                # 计算 mimic_x 的标准差作为尺度参考
                mimic_x_scaleref = mimic_x_centered.std(dim=2, keepdim = True)
        else:
            # 如果不分开特征通道
            if self.measure == 'AD':
                # 计算 x 的绝对差异的分位数作为尺度参考
                x_scaleref = torch.quantile(x_centered.abs(), self.threshold_percentile)
                # 计算 mimic_x 的绝对差异的最大值作为尺度参考
                mimic_x_scaleref = mimic_x_centered.abs().max()
            # 如果测量方式为标准差
            elif self.measure == 'STD':
                # 计算 x 的标准差作为尺度参考
                x_scaleref = x_centered.std()
                # 计算 mimic_x 的标准差作为尺度参考
                mimic_x_scaleref = mimic_x_centered.std()
            
        # 根据测量方式调整 x 的尺度
        if self.measure == 'AD':
            # 计算 x_scaleref 和 mimic_x_scaleref 的最大值
            max_scaleref = torch.maximum(x_scaleref, mimic_x_scaleref)
            # 限制 x_centered 的值在 [-max_scaleref, max_scaleref] 范围内
            x_clamped = x_centered.clamp(-max_scaleref, max_scaleref)
            # 重新归一化 x
            x_renormed = x_clamped * (mimic_x_scaleref / max_scaleref)
        elif self.measure == 'STD':
            # 重新归一化 x
            x_renormed = x_centered * (mimic_x_scaleref / x_scaleref)
        
        # 根据缩放起始点选择调整 x_dyn 的值
        if self.scaling_startpoint == 'MEAN':
            # 将均值与重新归一化的结果相加
            x_dyn = x_means + x_renormed
        else:
            # 直接使用重新归一化的结果
            x_dyn = x_renormed
        # 反展平 x_dyn,恢复原始形状
        x_dyn = x_dyn.unflatten(2, x.shape[2:])
        # 返回加权和结果
        return self.phi*x_dyn + (1-self.phi)*x
# 定义一个线性多步系数函数,接受阶数、时间点和索引等参数
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    # 检查阶数是否超出当前步骤的限制,超出则抛出错误
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

    # 定义内部函数,用于计算特定tau下的乘积
    def fn(tau):
        prod = 1.0  # 初始化乘积为1
        # 遍历从0到阶数的每个k
        for k in range(order):
            # 跳过与j相等的k
            if j == k:
                continue
            # 计算乘积,基于给定的tau和时间点
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod  # 返回计算得到的乘积

    # 使用数值积分计算fn在时间点[i]到[i+1]之间的积分,返回积分值
    return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]


# 定义一个函数,用于计算祖先步骤的sigma值
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    # 如果eta为0,则返回sigma_to和0.0
    if not eta:
        return sigma_to, 0.0
    # 计算sigma_up,限制为sigma_to与特定表达式的较小值
    sigma_up = torch.minimum(
        sigma_to,
        eta
        * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
    )
    # 计算sigma_down,基于sigma_to和sigma_up
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    # 返回计算得到的sigma_down和sigma_up
    return sigma_down, sigma_up


# 定义一个函数,将去噪后的图像与输入图像进行处理,返回标准化结果
def to_d(x, sigma, denoised):
    # 计算去噪后的结果与输入的差异,并除以sigma的扩展维度
    return (x - denoised) / append_dims(sigma, x.ndim)


# 定义一个函数,将sigma转换为负对数形式
def to_neg_log_sigma(sigma):
    # 计算sigma的对数并取负值
    return sigma.log().neg()


# 定义一个函数,将负对数sigma转换为sigma
def to_sigma(neg_log_sigma):
    # 取负值并计算指数,得到sigma
    return neg_log_sigma.neg().exp()

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\sigma_sampling.py

# 导入 PyTorch 库
import torch

# 从相对路径导入 default 和 instantiate_from_config 函数
from ...util import default, instantiate_from_config


# 定义 EDMSampling 类
class EDMSampling:
    # 初始化方法,设置均值和标准差
    def __init__(self, p_mean=-1.2, p_std=1.2):
        # 保存均值到实例变量
        self.p_mean = p_mean
        # 保存标准差到实例变量
        self.p_std = p_std

    # 定义调用方法,允许类实例像函数一样被调用
    def __call__(self, n_samples, rand=None):
        # 计算对数标准差,根据随机数生成 log_sigma
        log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
        # 返回 log_sigma 的指数值
        return log_sigma.exp()


# 定义 DiscreteSampling 类
class DiscreteSampling:
    # 初始化方法,设置离散化配置和其他参数
    def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, low_bound=0, up_bound=1):
        # 保存索引数量到实例变量
        self.num_idx = num_idx
        # 根据配置实例化 sigma 对象
        self.sigmas = instantiate_from_config(discretization_config)(
            num_idx, do_append_zero=do_append_zero, flip=flip
        )
        # 计算并保存下界和上界
        self.low_bound = int(low_bound * num_idx)
        self.up_bound = int(up_bound * num_idx)
        # 打印采样范围
        print(f'sigma sampling from {self.low_bound} to {self.up_bound}')

    # 将索引转换为 sigma 值
    def idx_to_sigma(self, idx):
        # 返回对应索引的 sigma
        return self.sigmas[idx]

    # 定义调用方法
    def __call__(self, n_samples, rand=None, return_idx=False):
        # 生成随机索引,如果没有提供随机数则使用默认随机数
        idx = default(
            rand,
            torch.randint(self.low_bound, self.up_bound, (n_samples,)),
        )
        # 根据 return_idx 参数决定返回 sigma 值或索引
        if return_idx:
            return self.idx_to_sigma(idx), idx
        else:
            return self.idx_to_sigma(idx)

# 定义 PartialDiscreteSampling 类
class PartialDiscreteSampling:
    # 初始化方法,设置完整和部分索引数量
    def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True):
        # 保存总索引数量到实例变量
        self.total_num_idx = total_num_idx
        # 保存部分索引数量到实例变量
        self.partial_num_idx = partial_num_idx
        # 根据配置实例化 sigma 对象
        self.sigmas = instantiate_from_config(discretization_config)(
            total_num_idx, do_append_zero=do_append_zero, flip=flip
        )

    # 将索引转换为 sigma 值
    def idx_to_sigma(self, idx):
        # 返回对应索引的 sigma
        return self.sigmas[idx]

    # 定义调用方法
    def __call__(self, n_samples, rand=None):
        # 生成随机索引,根据部分索引数量限制随机范围
        idx = default(
            rand,
            # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)),
            torch.randint(0, self.partial_num_idx, (n_samples,)),
        )
        # 返回对应的 sigma 值
        return self.idx_to_sigma(idx)

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\util.py

"""
引用自
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
和
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
和
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py

感谢!
"""

# 导入数学库以执行数学运算
import math
# 从 typing 模块导入可选类型
from typing import Optional

# 导入 PyTorch 库
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 从 einops 库导入 rearrange 和 repeat 函数,用于张量操作
from einops import rearrange, repeat


# 创建一个 beta 调度函数
def make_beta_schedule(
    schedule,  # 调度类型
    n_timestep,  # 时间步数
    linear_start=1e-4,  # 线性调度的起始值
    linear_end=2e-2,  # 线性调度的结束值
):
    # 如果调度类型为线性
    if schedule == "linear":
        # 生成从起始值到结束值的线性空间并平方
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )
    # 返回生成的 beta 值作为 NumPy 数组
    return betas.numpy()


# 将张量中的值提取到指定形状的输出张量中
def extract_into_tensor(a, t, x_shape):
    # 解构 t 的形状,获取批大小
    b, *_ = t.shape
    # 从 a 中根据 t 的索引提取值
    out = a.gather(-1, t)
    # 将输出调整为指定的形状
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


# 混合检查点功能的定义
def mixed_checkpoint(func, inputs: dict, params, flag):
    """
    评估函数而不缓存中间激活,从而减少内存使用,但在反向传播中增加计算量。
    与原始的检查点函数不同,它也可以处理非张量输入。
    :param func: 要评估的函数。
    :param inputs: 传递给 `func` 的参数字典。
    :param params: `func` 依赖但不明确作为参数传递的参数序列。
    :param flag: 如果为 False,则禁用梯度检查点。
    """
    # 如果标志为真
    if flag:
        # 获取所有张量键
        tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
        # 获取所有张量输入
        tensor_inputs = [
            inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
        ]
        # 获取所有非张量键
        non_tensor_keys = [
            key for key in inputs if not isinstance(inputs[key], torch.Tensor)
        ]
        # 获取所有非张量输入
        non_tensor_inputs = [
            inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
        ]
        # 将所有输入和参数组合成元组
        args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
        # 调用混合检查点功能并返回结果
        return MixedCheckpointFunction.apply(
            func,
            len(tensor_inputs),
            len(non_tensor_inputs),
            tensor_keys,
            non_tensor_keys,
            *args,
        )
    else:
        # 直接调用函数并返回结果
        return func(**inputs)


# 定义混合检查点功能的类
class MixedCheckpointFunction(torch.autograd.Function):
    @staticmethod
    # 定义前向传播方法
    def forward(
        ctx,  # 上下文对象
        run_function,  # 要运行的函数
        length_tensors,  # 张量的长度
        length_non_tensors,  # 非张量的长度
        tensor_keys,  # 张量的键
        non_tensor_keys,  # 非张量的键
        *args,  # 其他参数
    ):
        # 设置结束张量的数量
        ctx.end_tensors = length_tensors
        # 设置结束非张量的数量
        ctx.end_non_tensors = length_tensors + length_non_tensors
        # 初始化 GPU 自动混合精度参数
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),  # 检查自动混合精度是否启用
            "dtype": torch.get_autocast_gpu_dtype(),  # 获取当前自动混合精度数据类型
            "cache_enabled": torch.is_autocast_cache_enabled(),  # 检查缓存是否启用
        }
        # 确保张量键和非张量键的数量与传入长度一致
        assert (
            len(tensor_keys) == length_tensors
            and len(non_tensor_keys) == length_non_tensors
        )

        # 将输入张量映射到字典,键为 tensor_keys,值为相应的 args
        ctx.input_tensors = {
            key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
        }
        # 将输入非张量映射到字典,键为 non_tensor_keys,值为相应的 args
        ctx.input_non_tensors = {
            key: val
            for (key, val) in zip(
                non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
            )
        }
        # 保存运行函数
        ctx.run_function = run_function
        # 获取剩余输入参数
        ctx.input_params = list(args[ctx.end_non_tensors :])

        # 在不计算梯度的上下文中运行
        with torch.no_grad():
            # 调用运行函数并传入输入张量和非张量
            output_tensors = ctx.run_function(
                **ctx.input_tensors, **ctx.input_non_tensors
            )
        # 返回输出张量
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # 将输入张量中的所有张量设为需要梯度
        ctx.input_tensors = {
            key: ctx.input_tensors[key].detach().requires_grad_(True)
            for key in ctx.input_tensors
        }

        # 启用梯度计算并设置自动混合精度上下文
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # 创建输入张量的浅拷贝以避免原地修改
            shallow_copies = {
                key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
                for key in ctx.input_tensors
            }
            # shallow_copies.update(additional_args)
            # 调用运行函数并传入浅拷贝和非张量输入
            output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
        # 计算输出张量相对于输入张量和参数的梯度
        input_grads = torch.autograd.grad(
            output_tensors,
            list(ctx.input_tensors.values()) + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # 删除上下文中的输入张量和参数
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # 返回梯度和填充的 None 值
        return (
            (None, None, None, None, None)
            + input_grads[: ctx.end_tensors]
            + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
            + input_grads[ctx.end_tensors :]
        )
# 定义一个检查点函数,用于评估给定的函数,降低内存使用,同时增加计算开销
def checkpoint(func, inputs, params, flag):
    """
    评估一个函数,避免缓存中间激活,减少内存使用,但在反向传播时增加计算量。
    :param func: 要评估的函数。
    :param inputs: 传递给 `func` 的参数序列。
    :param params: `func` 依赖的参数序列,但并不显式作为参数接受。
    :param flag: 如果为 False,禁用梯度检查点。
    """
    # 如果 flag 为 True,启用梯度检查点
    if flag:
        # 将输入参数和依赖参数组合成一个元组
        args = tuple(inputs) + tuple(params)
        # 应用检查点函数,返回计算结果
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        # 如果 flag 为 False,直接调用 func 函数
        return func(*inputs)


# 定义检查点函数的类,继承自 torch.autograd.Function
class CheckpointFunction(torch.autograd.Function):
    # 定义前向传播的静态方法
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # 将运行的函数保存到上下文中
        ctx.run_function = run_function
        # 保存输入张量(前 length 个参数)
        ctx.input_tensors = list(args[:length])
        # 保存输入参数(后面的参数)
        ctx.input_params = list(args[length:])
        # 获取当前的 GPU 自动混合精度设置
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        # 在不计算梯度的情况下运行函数
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        # 返回输出张量
        return output_tensors

    # 定义反向传播的静态方法
    @staticmethod
    def backward(ctx, *output_grads):
        # 将输入张量分离,并设置为需要梯度
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        # 启用梯度计算并设置自动混合精度
        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
            # 解决一个 bug,确保运行函数的第一个操作不会修改分离的张量存储
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            # 使用浅拷贝运行函数以获取输出张量
            output_tensors = ctx.run_function(*shallow_copies)
        # 计算输入梯度
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # 删除上下文中的输入张量
        del ctx.input_tensors
        # 删除上下文中的输入参数
        del ctx.input_params
        # 删除输出张量
        del output_tensors
        # 返回 None 和输入梯度
        return (None, None) + input_grads


# 定义时间步嵌入的函数
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32):
    """
    创建正弦时间步嵌入。
    :param timesteps: 一个 1-D 张量,包含每个批次元素的 N 个索引。
                      这些索引可以是小数。
    :param dim: 输出的维度。
    :param max_period: 控制嵌入的最小频率。
    :return: 一个形状为 [N x dim] 的位置嵌入张量。
    """
    # 如果不只是重复时序
        if not repeat_only:
            # 计算半个维度,用于频率计算
            half = dim // 2
            # 生成频率数组,基于最大周期和半个维度
            freqs = torch.exp(
                -math.log(max_period)  # 计算最大周期的对数
                * torch.arange(start=0, end=half, dtype=torch.float32)  # 生成从0到half的浮点数数组
                / half  # 归一化,使频率在0到1之间
            ).to(device=timesteps.device)  # 将频率数组移动到与timesteps相同的设备
            # 计算时序与频率的乘积,准备嵌入向量
            args = timesteps[:, None].float() * freqs[None]
            # 通过计算余弦和正弦生成嵌入向量
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            # 如果维度是奇数,添加额外的零向量
            if dim % 2:
                embedding = torch.cat(
                    [embedding, torch.zeros_like(embedding[:, :1])], dim=-1  # 在最后一维追加零向量
                )
        # 如果是重复时序,生成与时序相同的嵌入
        else:
            embedding = repeat(timesteps, "b -> b d", d=dim)  # 根据时序生成重复的嵌入向量
        # 返回嵌入向量并设置数据类型
        return embedding.to(dtype)  # 将嵌入向量转换为指定的数据类型
# 将模块的参数归零并返回该模块
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    # 遍历模块的所有参数
    for p in module.parameters():
        # 分离参数并将其值归零
        p.detach().zero_()
    # 返回修改后的模块
    return module


# 对模块的参数进行缩放并返回该模块
def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    # 遍历模块的所有参数
    for p in module.parameters():
        # 分离参数并按比例缩放
        p.detach().mul_(scale)
    # 返回修改后的模块
    return module


# 计算张量中所有非批次维度的平均值
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    # 计算并返回张量在非批次维度上的平均值
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


# 创建标准化层
def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # 返回一个具有指定输入通道数的GroupNorm32标准化层
    return GroupNorm32(32, channels)


# SiLU激活函数类,兼容PyTorch 1.5
class SiLU(nn.Module):
    # 定义前向传播方法
    def forward(self, x):
        # 返回输入与其Sigmoid值的乘积
        return x * torch.sigmoid(x)


# 自定义GroupNorm类,继承自nn.GroupNorm
class GroupNorm32(nn.GroupNorm):
    # 定义前向传播方法
    def forward(self, x):
        # 调用父类的前向方法并返回与输入相同的数据类型
        return super().forward(x).type(x.dtype)


# 创建1D、2D或3D卷积模块
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    # 根据维度选择对应的卷积层
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    # 如果维度不支持,抛出错误
    raise ValueError(f"unsupported dimensions: {dims}")


# 创建线性模块
def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    # 返回一个线性层
    return nn.Linear(*args, **kwargs)


# 创建1D、2D或3D平均池化模块
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    # 根据维度选择对应的平均池化层
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    # 如果维度不支持,抛出错误
    raise ValueError(f"unsupported dimensions: {dims}")


# AlphaBlender类,用于合并不同策略
class AlphaBlender(nn.Module):
    # 支持的合并策略
    strategies = ["learned", "fixed", "learned_with_images"]

    # 初始化方法
    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()  # 调用父类初始化
        # 保存合并策略和重排模式
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern

        # 确保合并策略是支持的选项之一
        assert (
            merge_strategy in self.strategies
        ), f"merge_strategy needs to be in {self.strategies}"

        # 根据合并策略注册混合因子
        if self.merge_strategy == "fixed":
            # 注册固定混合因子
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif (
            self.merge_strategy == "learned"
            or self.merge_strategy == "learned_with_images"
        ):
            # 注册可学习的混合因子
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            # 抛出不支持的合并策略错误
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")
    # 定义获取 alpha 值的函数,接受一个图像指示器作为输入,返回一个张量
    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        # 根据合并策略选择 alpha 值
        if self.merge_strategy == "fixed":
            # 如果策略是固定,则 alpha 等于混合因子
            alpha = self.mix_factor
        elif self.merge_strategy == "learned":
            # 如果策略是学习的,则通过 sigmoid 函数计算 alpha
            alpha = torch.sigmoid(self.mix_factor)
        elif self.merge_strategy == "learned_with_images":
            # 如果策略是基于图像学习的,确保提供了图像指示器
            assert image_only_indicator is not None, "need image_only_indicator ..."
            # 根据图像指示器的布尔值,决定 alpha 的值
            alpha = torch.where(
                image_only_indicator.bool(),
                # 如果为真,则 alpha 为全 1 的张量
                torch.ones(1, 1, device=image_only_indicator.device),
                # 否则通过 sigmoid 函数处理混合因子,并调整维度
                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
            )
            # 根据给定的重排模式调整 alpha 的形状
            alpha = rearrange(alpha, self.rearrange_pattern)
        else:
            # 如果没有匹配的合并策略,抛出未实现的错误
            raise NotImplementedError
        # 返回计算得到的 alpha 值
        return alpha
    
    # 定义前向传播的函数,接受空间和时间的输入张量,返回一个张量
    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 调用 get_alpha 函数获取 alpha 值
        alpha = self.get_alpha(image_only_indicator)
        # 根据 alpha 值和输入张量进行加权求和
        x = (
            # x_spatial 乘以 alpha,转换为相同的数据类型
            alpha.to(x_spatial.dtype) * x_spatial
            + 
            # (1.0 - alpha) 乘以 x_temporal,转换为相同的数据类型
            (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        )
        # 返回加权后的结果张量
        return x

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\wrappers.py

# 导入 PyTorch 库及其神经网络模块
import torch
import torch.nn as nn
# 导入版本控制模块
from packaging import version

# 定义一个字符串,表示 OpenAI Wrapper 的路径
OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"


# 定义一个身份包装器类,继承自 nn.Module
class IdentityWrapper(nn.Module):
    # 初始化方法,接受扩散模型、是否编译模型的标志和数据类型
    def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32):
        # 调用父类的初始化方法
        super().__init__()
        # 根据 PyTorch 版本决定是否编译模型
        compile = (
            torch.compile
            if (version.parse(torch.__version__) >= version.parse("2.0.0"))  # 检查 PyTorch 版本
            and compile_model  # 仅当 compile_model 为 True 时
            else lambda x: x  # 否则返回原始模型
        )
        # 对扩散模型进行编译
        self.diffusion_model = compile(diffusion_model)
        # 保存数据类型
        self.dtype = dtype

    # 前向传播方法,接受任意数量的位置和关键字参数
    def forward(self, *args, **kwargs):
        # 调用扩散模型并返回结果
        return self.diffusion_model(*args, **kwargs)


# 定义 OpenAI Wrapper 类,继承自 IdentityWrapper
class OpenAIWrapper(IdentityWrapper):
    # 重写前向传播方法,接受输入张量、时间步、上下文字典和其他参数
    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        # 将上下文字典中的每个值转换为指定的数据类型
        for key in c:
            c[key] = c[key].to(self.dtype)

        # 检查输入张量的形状是否为 3 维
        if len(x.shape) == 3:
            # 在最后一个维度拼接上下文中的 "concat" 数据
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=-1)
        else:
            # 在第一个维度拼接上下文中的 "concat" 数据
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)

        # 调用扩散模型进行前向传播,并返回结果
        return self.diffusion_model(
            x,  # 输入张量
            timesteps=t,  # 时间步
            context=c.get("crossattn", None),  # 上下文中的交叉注意力
            y=c.get("vector", None),  # 上下文中的向量
            **kwargs,  # 其他参数
        )

.\cogview3-finetune\sat\sgm\modules\diffusionmodules\__init__.py

# 从当前包导入去噪声器类
from .denoiser import Denoiser
# 从当前包导入离散化类
from .discretizer import Discretization
# 从当前包导入标准扩散损失类
from .loss import StandardDiffusionLoss
# 从当前包导入解码器、编码器和模型类
from .model import Decoder, Encoder, Model
# 从当前包导入 UNet 模型类
from .openaimodel import UNetModel
# 从当前包导入基础扩散采样器类
from .sampling import BaseDiffusionSampler
# 从当前包导入 OpenAI 封装器类
from .wrappers import OpenAIWrapper

.\cogview3-finetune\sat\sgm\modules\distributions\distributions.py

# 导入 NumPy 库,通常用于数值计算
import numpy as np
# 导入 PyTorch 库,主要用于深度学习计算
import torch


# 定义抽象分布类,继承自基础类
class AbstractDistribution:
    # 抽象方法,样本生成
    def sample(self):
        # 抛出未实现错误,子类需实现此方法
        raise NotImplementedError()

    # 抽象方法,返回分布的众数
    def mode(self):
        # 抛出未实现错误,子类需实现此方法
        raise NotImplementedError()


# 定义 Dirac 分布类,继承自抽象分布类
class DiracDistribution(AbstractDistribution):
    # 初始化方法,接收一个值作为分布的值
    def __init__(self, value):
        # 将传入的值存储为实例变量
        self.value = value

    # 实现样本生成方法
    def sample(self):
        # 返回 Dirac 分布的值
        return self.value

    # 实现众数方法
    def mode(self):
        # 返回 Dirac 分布的值
        return self.value


# 定义对角高斯分布类
class DiagonalGaussianDistribution(object):
    # 初始化方法,接收参数和一个决定是否确定性的标志
    def __init__(self, parameters, deterministic=False):
        # 将参数存储为实例变量
        self.parameters = parameters
        # 将参数分割成均值和对数方差
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # 限制对数方差在 -30 到 20 之间
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        # 存储确定性标志
        self.deterministic = deterministic
        # 计算标准差
        self.std = torch.exp(0.5 * self.logvar)
        # 计算方差
        self.var = torch.exp(self.logvar)
        # 如果是确定性,方差和标准差设为零
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(
                device=self.parameters.device
            )

    # 实现样本生成方法
    def sample(self):
        # x = self.mean + self.std * torch.randn(self.mean.shape).to(
        #     device=self.parameters.device
        # )
        # 生成样本,遵循均值和标准差的分布
        x = self.mean + self.std * torch.randn_like(self.mean).to(
            device=self.parameters.device
        )
        # 返回生成的样本
        return x

    # 实现 KL 散度计算方法
    def kl(self, other=None):
        # 如果是确定性,返回 0.0 的张量
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            # 如果没有其他分布,计算与标准正态分布的 KL 散度
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                # 计算与另一分布的 KL 散度
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    # 实现负对数似然计算方法
    def nll(self, sample, dims=[1, 2, 3]):
        # 如果是确定性,返回 0.0 的张量
        if self.deterministic:
            return torch.Tensor([0.0])
        # 计算 2π 的对数值
        logtwopi = np.log(2.0 * np.pi)
        # 计算负对数似然并返回
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    # 实现众数计算方法
    def mode(self):
        # 返回均值作为众数
        return self.mean


# 定义计算两个高斯分布之间 KL 散度的函数
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    来源: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    计算两个高斯分布之间的 KL 散度。
    形状会自动广播,因此批次可以与标量等进行比较,
    适用于其他用例。
    """
    # 初始化张量变量
    tensor = None
    # 遍历四个输入对象,找到第一个张量
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    # 确保至少一个输入是张量
    assert tensor is not None, "at least one argument must be a Tensor"

    # 强制方差为张量类型。广播帮助将标量转换为
    # 张量,但对 torch.exp() 不起作用。
    # 将 logvar1 和 logvar2 进行处理,确保它们都是张量类型
        logvar1, logvar2 = [
            # 如果 x 是张量,保持不变;否则,将 x 转换为张量并移动到指定设备
            x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
            for x in (logvar1, logvar2)  # 遍历 logvar1 和 logvar2
        ]
    
        # 计算并返回一个值,公式由多个部分组成
        return 0.5 * (
            # -1.0 表示计算的偏移量
            -1.0
            + logvar2  # 加上 logvar2 的值
            - logvar1  # 减去 logvar1 的值
            + torch.exp(logvar1 - logvar2)  # 加上 logvar1 和 logvar2 之差的指数
            + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)  # 加上均值差的平方乘以 logvar2 的负指数
        )

.\cogview3-finetune\sat\sgm\modules\distributions\__init__.py

# 代码段为空,无法进行注释

.\cogview3-finetune\sat\sgm\modules\ema.py

# 导入 PyTorch 库
import torch
# 从 PyTorch 导入神经网络模块
from torch import nn


# 定义一个继承自 nn.Module 的类 LitEma
class LitEma(nn.Module):
    # 初始化函数,接收模型、衰减率和是否使用更新计数
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        # 调用父类初始化
        super().__init__()
        # 检查衰减率是否在 0 到 1 之间
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        # 初始化模型参数名到阴影参数名的映射
        self.m_name2s_name = {}
        # 注册衰减率的缓冲区
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        # 根据是否使用更新计数注册更新次数的缓冲区
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        # 遍历模型的命名参数
        for name, p in model.named_parameters():
            # 如果参数需要梯度更新
            if p.requires_grad:
                # 移除参数名中的 '.' 字符
                s_name = name.replace(".", "")
                # 更新模型参数名到阴影参数名的映射
                self.m_name2s_name.update({name: s_name})
                # 注册参数的缓冲区
                self.register_buffer(s_name, p.clone().detach().data)

        # 初始化收集的参数列表
        self.collected_params = []

    # 重置更新计数的方法
    def reset_num_updates(self):
        # 删除 num_updates 的缓冲区
        del self.num_updates
        # 注册更新计数的缓冲区,初始化为 0
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    # 前向传播方法,接收模型作为输入
    def forward(self, model):
        # 获取当前衰减率
        decay = self.decay

        # 如果更新计数有效
        if self.num_updates >= 0:
            # 增加更新计数
            self.num_updates += 1
            # 计算新的衰减率
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        # 计算 1 减去衰减率
        one_minus_decay = 1.0 - decay

        # 在不跟踪梯度的情况下执行操作
        with torch.no_grad():
            # 获取模型的命名参数字典
            m_param = dict(model.named_parameters())
            # 获取阴影参数的命名缓冲区字典
            shadow_params = dict(self.named_buffers())

            # 遍历模型参数字典
            for key in m_param:
                # 如果参数需要梯度更新
                if m_param[key].requires_grad:
                    # 获取对应的阴影参数名
                    sname = self.m_name2s_name[key]
                    # 将阴影参数转换为与模型参数相同的数据类型
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # 更新阴影参数的值
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    # 确保此参数不在映射中
                    assert not key in self.m_name2s_name

    # 将阴影参数复制到模型参数的方法
    def copy_to(self, model):
        # 获取模型的命名参数字典
        m_param = dict(model.named_parameters())
        # 获取阴影参数的命名缓冲区字典
        shadow_params = dict(self.named_buffers())
        # 遍历模型参数字典
        for key in m_param:
            # 如果参数需要梯度更新
            if m_param[key].requires_grad:
                # 复制阴影参数的数据到模型参数
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                # 确保此参数不在映射中
                assert not key in self.m_name2s_name

    # 存储当前参数以供稍后恢复的方法
    def store(self, parameters):
        """
        保存当前参数以供稍后恢复。
        参数:
          parameters: 可迭代的 `torch.nn.Parameter`;需要临时存储的参数。
        """
        # 克隆参数并存储在收集的参数列表中
        self.collected_params = [param.clone() for param in parameters]
    # 定义一个恢复方法,用于恢复存储的参数
    def restore(self, parameters):
        """
        恢复通过 `store` 方法存储的参数。
        这对于在不影响原始优化过程的情况下使用 EMA 参数验证模型很有用。
        在调用 `copy_to` 方法之前存储参数。
        验证(或保存模型)后,使用此方法恢复先前的参数。
        Args:
          parameters: 可迭代的 `torch.nn.Parameter`;需要用存储的参数更新的参数。
        """
        # 遍历已收集的参数和输入参数,成对处理
        for c_param, param in zip(self.collected_params, parameters):
            # 将已收集参数的数据复制到输入参数的数据中
            param.data.copy_(c_param.data)

.\cogview3-finetune\sat\sgm\modules\encoders\modules.py

# 导入数学库,用于数学计算
import math
# 从上下文管理库导入 nullcontext,用于创建一个不执行任何操作的上下文管理器
from contextlib import nullcontext
# 从 functools 导入 partial,用于创建偏函数
from functools import partial
# 从 typing 导入各种类型注解,用于类型检查
from typing import Dict, List, Optional, Tuple, Union

# 导入 kornia 库,用于计算机视觉的操作
import kornia
# 导入 numpy 库,用于数组和数值计算
import numpy as np
# 导入 open_clip 库,用于处理 CLIP 模型
import open_clip
# 导入 PyTorch 库,深度学习框架
import torch
# 导入 PyTorch 的分布式模块
import torch.distributed
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 从 einops 导入 rearrange 和 repeat,用于重排和复制张量
from einops import rearrange, repeat
# 从 omegaconf 导入 ListConfig,用于处理配置文件
from omegaconf import ListConfig
# 从 torch.utils.checkpoint 导入 checkpoint,用于节省内存的检查点机制
from torch.utils.checkpoint import checkpoint
# 从 transformers 导入各种模型和分词器
from transformers import (
    ByT5Tokenizer,  # 导入 ByT5 的分词器
    CLIPTextModel,  # 导入 CLIP 的文本模型
    CLIPTokenizer,  # 导入 CLIP 的分词器
    T5EncoderModel,  # 导入 T5 的编码器模型
    T5Tokenizer,  # 导入 T5 的分词器
    AutoModel,  # 导入自动模型类,用于加载预训练模型
    AutoTokenizer  # 导入自动分词器类,用于加载预训练分词器
)

# 从模块中导入正则化器、编码器和时间步等工具
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
from ...modules.diffusionmodules.model import Encoder
from ...modules.diffusionmodules.openaimodel import Timestep
from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
from ...modules.distributions.distributions import DiagonalGaussianDistribution
from ...util import (
    append_dims,  # 导入函数,用于向张量追加维度
    autocast,  # 导入函数,用于自动混合精度训练
    count_params,  # 导入函数,用于计算模型参数数量
    default,  # 导入函数,用于获取默认值
    disabled_train,  # 导入函数,用于禁用训练模式
    expand_dims_like,  # 导入函数,用于扩展张量维度以匹配另一个张量
    instantiate_from_config,  # 导入函数,从配置实例化对象
)


# 定义一个抽象的嵌入模型类,继承自 nn.Module
class AbstractEmbModel(nn.Module):
    # 初始化方法
    def __init__(self):
        super().__init__()  # 调用父类构造函数
        self._is_trainable = None  # 初始化可训练标志
        self._ucg_rate = None  # 初始化 UCG 率
        self._input_key = None  # 初始化输入键

    # 定义 is_trainable 属性的 getter 方法
    @property
    def is_trainable(self) -> bool:
        return self._is_trainable  # 返回可训练标志

    # 定义 ucg_rate 属性的 getter 方法
    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        return self._ucg_rate  # 返回 UCG 率

    # 定义 input_key 属性的 getter 方法
    @property
    def input_key(self) -> str:
        return self._input_key  # 返回输入键

    # 定义 is_trainable 属性的 setter 方法
    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value  # 设置可训练标志

    # 定义 ucg_rate 属性的 setter 方法
    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value  # 设置 UCG 率

    # 定义 input_key 属性的 setter 方法
    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value  # 设置输入键

    # 定义 is_trainable 属性的 deleter 方法
    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable  # 删除可训练标志

    # 定义 ucg_rate 属性的 deleter 方法
    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate  # 删除 UCG 率

    # 定义 input_key 属性的 deleter 方法
    @input_key.deleter
    def input_key(self):
        del self._input_key  # 删除输入键


# 定义通用条件器类,继承自 nn.Module
class GeneralConditioner(nn.Module):
    # 输出维度到键的映射
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    # 键到拼接维度的映射
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
    # 初始化函数,接收嵌入模型配置及其他参数
        def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]):
            # 调用父类的初始化方法
            super().__init__()
            # 存储嵌入模型的列表
            embedders = []
            # 遍历嵌入模型配置
            for n, embconfig in enumerate(emb_models):
                # 从配置中实例化嵌入模型
                embedder = instantiate_from_config(embconfig)
                # 确保嵌入模型继承自 AbstractEmbModel
                assert isinstance(
                    embedder, AbstractEmbModel
                ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
                # 获取是否可训练的标志,默认为 False
                embedder.is_trainable = embconfig.get("is_trainable", False)
                # 获取 UCG 比率,默认为 0.0
                embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
                # 如果不可训练,禁用训练
                if not embedder.is_trainable:
                    embedder.train = disabled_train
                    # 将模型参数的梯度要求设为 False
                    for param in embedder.parameters():
                        param.requires_grad = False
                    # 将模型设置为评估模式
                    embedder.eval()
                # print(
                #     f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                #     f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
                # )
    
                # 检查是否有单一输入键
                if "input_key" in embconfig:
                    embedder.input_key = embconfig["input_key"]
                # 检查是否有多个输入键
                elif "input_keys" in embconfig:
                    embedder.input_keys = embconfig["input_keys"]
                # 如果没有输入键,则引发异常
                else:
                    raise KeyError(
                        f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                    )
    
                # 获取遗留 UCG 值
                embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
                # 如果遗留 UCG 值存在,初始化随机状态
                if embedder.legacy_ucg_val is not None:
                    embedder.ucg_prng = np.random.RandomState()
    
                # 将嵌入模型添加到列表中
                embedders.append(embedder)
            # 将嵌入模型列表存储为模块列表
            self.embedders = nn.ModuleList(embedders)
    
            # 如果存在条件嵌入,确保条件概率长度匹配
            if len(cor_embs) > 0:
                assert len(cor_p) == 2**len(cor_embs)
            # 存储条件嵌入和概率
            self.cor_embs = cor_embs
            self.cor_p = cor_p
    
        # 根据嵌入模型和批量数据获取 UCG 值(可能)
        def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
            # 确保遗留 UCG 值存在
            assert embedder.legacy_ucg_val is not None
            # 获取 UCG 比率
            p = embedder.ucg_rate
            # 获取遗留 UCG 值
            val = embedder.legacy_ucg_val
            # 遍历批量数据的输入键
            for i in range(len(batch[embedder.input_key])):
                # 根据概率选择是否替换值
                if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                    batch[embedder.input_key][i] = val
            # 返回更新后的批量数据
            return batch
        
        # 根据嵌入模型和条件获取 UCG 值(必定)
        def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict:
            # 确保遗留 UCG 值存在
            assert embedder.legacy_ucg_val is not None
            # 获取遗留 UCG 值
            val = embedder.legacy_ucg_val
            # 遍历批量数据的输入键
            for i in range(len(batch[embedder.input_key])):
                # 如果条件满足,替换值
                if cond_or_not[i]:
                    batch[embedder.input_key][i] = val
            # 返回更新后的批量数据
            return batch
    # 定义获取单个嵌入的方法
    def get_single_embedding(self, embedder, batch, output, cond_or_not: Optional[np.ndarray] = None, force_zero_embeddings: Optional[List] = None):
        # 根据嵌入器是否可训练选择上下文管理器
        embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
        # 使用选定的上下文管理器
        with embedding_context():
            # 检查嵌入器是否有输入键,并且输入键不为 None
            if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                # 如果嵌入器的 legacy_ucg_val 不为 None
                if embedder.legacy_ucg_val is not None:
                    # 如果条件不为 None
                    if cond_or_not is None:
                        # 可能获取 ucg_val 的值
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    else:
                        # 确定获取 ucg_val 的值
                        batch = self.surely_get_ucg_val(embedder, batch, cond_or_not)
                # 从批次中获取嵌入输出
                emb_out = embedder(batch[embedder.input_key])
            # 如果嵌入器有多个输入键
            elif hasattr(embedder, "input_keys"):
                # 从批次中解包并获取嵌入输出
                emb_out = embedder(*[batch[k] for k in embedder.input_keys])
        # 确保嵌入输出是张量、列表或元组
        assert isinstance(
            emb_out, (torch.Tensor, list, tuple)
        ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
        # 如果嵌入输出不是列表或元组,将其转为列表
        if not isinstance(emb_out, (list, tuple)):
            emb_out = [emb_out]    
        # 遍历嵌入输出
        for emb in emb_out:
            # 获取输出键
            out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
            # 如果嵌入器的 ucg_rate 大于 0 且 legacy_ucg_val 为 None
            if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                # 如果条件不为 None
                if cond_or_not is None:
                    # 扩展嵌入维度并应用伯努利分布
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                else:
                    # 根据条件扩展嵌入维度
                    emb = (
                        expand_dims_like(
                            torch.tensor(1-cond_or_not, dtype=emb.dtype, device=emb.device),
                            emb,
                        )
                        * emb
                    )
            # 如果嵌入器有输入键且在强制零嵌入列表中
            if (
                hasattr(embedder, "input_key")
                and embedder.input_key in force_zero_embeddings
            ):
                # 将嵌入设置为全零
                emb = torch.zeros_like(emb)
            # 如果输出中已有该键
            if out_key in output:
                # 将新的嵌入与已有输出拼接
                output[out_key] = torch.cat(
                    (output[out_key], emb), self.KEY2CATDIM[out_key]
                )
            else:
                # 否则,直接赋值新的嵌入
                output[out_key] = emb
        # 返回更新后的输出
        return output
    
    # 定义前向传播的方法
    def forward(
        self, batch: Dict, force_zero_embeddings: Optional[List] = None
    ) -> Dict:  # 定义函数的返回类型为字典
        output = dict()  # 初始化一个空字典用于存储输出结果
        if force_zero_embeddings is None:  # 检查是否提供强制零嵌入参数
            force_zero_embeddings = []  # 如果没有,初始化为空列表

        if len(self.cor_embs) > 0:  # 如果相关嵌入存在
            batch_size = len(batch[list(batch.keys())[0]])  # 获取批次中第一个键的大小
            rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p)  # 根据相关概率随机选择索引
            for emb_idx in self.cor_embs:  # 遍历相关嵌入索引
                cond_or_not = rand_idx % 2  # 计算条件标志(0或1)
                rand_idx //= 2  # 更新随机索引
                embedder = self.embedders[emb_idx]  # 获取对应的嵌入器
                output = self.get_single_embedding(self.embedders[emb_idx], batch, output=output, cond_or_not=cond_or_not, force_zero_embeddings=force_zero_embeddings)  # 获取单个嵌入并更新输出

        for i, embedder in enumerate(self.embedders):  # 遍历所有嵌入器及其索引
            if i in self.cor_embs:  # 如果索引在相关嵌入中,则跳过
                continue  # 继续下一个循环
            output = self.get_single_embedding(embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings)  # 获取单个嵌入并更新输出
        return output  # 返回最终的输出字典

    def get_unconditional_conditioning(  # 定义获取无条件调节的函数
        self, batch_c, batch_uc=None, force_uc_zero_embeddings=None  # 输入批次及可选参数
    ):
        if force_uc_zero_embeddings is None:  # 检查强制无条件零嵌入参数
            force_uc_zero_embeddings = []  # 如果没有,初始化为空列表
        ucg_rates = list()  # 初始化列表用于存储原有的无条件生成率
        for embedder in self.embedders:  # 遍历所有嵌入器
            ucg_rates.append(embedder.ucg_rate)  # 保存当前的无条件生成率
            embedder.ucg_rate = 0.0  # 将无条件生成率设置为0

        cor_embs = self.cor_embs  # 保存当前相关嵌入
        cor_p = self.cor_p  # 保存当前相关概率
        self.cor_embs = []  # 清空相关嵌入
        self.cor_p = []  # 清空相关概率

        c = self(batch_c)  # 计算输入批次的输出
        uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)  # 计算无条件输出

        for embedder, rate in zip(self.embedders, ucg_rates):  # 恢复每个嵌入器的无条件生成率
            embedder.ucg_rate = rate  # 将原有的无条件生成率重新赋值
        self.cor_embs = cor_embs  # 恢复相关嵌入
        self.cor_p = cor_p  # 恢复相关概率

        return c, uc  # 返回有条件和无条件的输出
# 定义一个名为 InceptionV3 的类,继承自 nn.Module
class InceptionV3(nn.Module):
    """对 https://github.com/mseitzer/pytorch-fid 的 Inception 
    端口进行包装,并在末尾增加一个 squeeze 操作"""

    # 初始化函数,接受 normalize_input 参数和其他可选参数
    def __init__(self, normalize_input=False, **kwargs):
        # 调用父类的初始化函数
        super().__init__()
        # 从 pytorch_fid 导入 inception 模块
        from pytorch_fid import inception

        # 设置输入调整标志为 True
        kwargs["resize_input"] = True
        # 创建 InceptionV3 模型实例,并传入参数
        self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)

    # 前向传播函数,接受输入张量 inp
    def forward(self, inp):
        # 对输入进行尺寸调整(已注释)
        # inp = kornia.geometry.resize(inp, (299, 299),
        #                              interpolation='bicubic',
        #                              align_corners=False,
        #                              antialias=True)
        # 将输入值限制在 -1 到 1 之间(已注释)
        # inp = inp.clamp(min=-1, max=1)

        # 使用模型对输入进行处理,获得输出
        outp = self.model(inp)

        # 如果输出只有一个元素,去掉维度并返回
        if len(outp) == 1:
            return outp[0].squeeze()

        # 返回原始输出
        return outp


# 定义一个名为 IdentityEncoder 的类,继承自 AbstractEmbModel
class IdentityEncoder(AbstractEmbModel):
    # 编码函数,直接返回输入
    def encode(self, x):
        return x

    # 前向传播函数,直接返回输入
    def forward(self, x):
        return x


# 定义一个名为 ClassEmbedder 的类,继承自 AbstractEmbModel
class ClassEmbedder(AbstractEmbModel):
    # 初始化函数,接受嵌入维度、类数和是否添加序列维度的参数
    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        # 调用父类的初始化函数
        super().__init__()
        # 创建嵌入层,映射类到嵌入维度
        self.embedding = nn.Embedding(n_classes, embed_dim)
        # 保存类数和是否添加序列维度的标志
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    # 前向传播函数,接受类的输入
    def forward(self, c):
        # 获取类的嵌入表示
        c = self.embedding(c)
        # 如果需要,添加序列维度
        if self.add_sequence_dim:
            c = c[:, None, :]
        # 返回嵌入表示
        return c

    # 获取无条件的条件信息
    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = (
            self.n_classes - 1
        )  # 1000 类 --> 0 ... 999,额外的类用于无条件生成
        # 创建一个全为 uc_class 的张量
        uc = torch.ones((bs,), device=device) * uc_class
        # 将类信息包装成字典
        uc = {self.key: uc.long()}
        # 返回字典
        return uc


# 定义一个名为 ClassEmbedderForMultiCond 的类,继承自 ClassEmbedder
class ClassEmbedderForMultiCond(ClassEmbedder):
    # 前向传播函数,接受批量数据、键和是否禁用丢弃的标志
    def forward(self, batch, key=None, disable_dropout=False):
        # 将输出初始化为输入批次
        out = batch
        # 如果未提供键,使用默认键
        key = default(key, self.key)
        # 检查批次中的值是否为列表
        islist = isinstance(batch[key], list)
        # 如果是列表,则取第一个元素
        if islist:
            batch[key] = batch[key][0]
        # 调用父类的前向传播
        c_out = super().forward(batch, key, disable_dropout)
        # 根据是否为列表,更新输出
        out[key] = [c_out] if islist else c_out
        # 返回更新后的输出
        return out


# 定义一个名为 FrozenT5Embedder 的类,继承自 AbstractEmbModel
class FrozenT5Embedder(AbstractEmbModel):
    """使用 T5 转换器编码器进行文本处理"""

    # 初始化函数,接受模型目录、设备、最大长度、是否冻结和缓存目录的参数
    def __init__(
        self,
        model_dir="google/t5-v1_1-xxl",
        device="cuda",
        max_length=77,
        freeze=True,
        cache_dir=None,
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 如果模型目录不是默认的,加载相应的分词器和模型
        if model_dir is not "google/t5-v1_1-xxl":
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir)
        else:
            # 否则,使用缓存目录加载分词器和模型
            self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir)
            self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir)
        # 保存设备信息
        self.device = device
        # 保存最大长度
        self.max_length = max_length
        # 如果需要冻结,调用冻结函数
        if freeze:
            self.freeze()
    # 定义冻结模型参数的方法
        def freeze(self):
            # 将转换器设置为评估模式,以禁用训练时的行为(如 Dropout)
            self.transformer = self.transformer.eval()
    
            # 遍历所有模型参数
            for param in self.parameters():
                # 禁用参数的梯度计算,以减少内存使用和提高推理速度
                param.requires_grad = False
    
        # @autocast  # 可选装饰器,用于自动混合精度
        def forward(self, text):
            # 使用分词器对输入文本进行编码,返回编码后的批次信息
            batch_encoding = self.tokenizer(
                text,
                # 截断超出最大长度的文本
                truncation=True,
                # 设置最大长度
                max_length=self.max_length,
                # 返回编码后的文本长度
                return_length=True,
                # 不返回溢出的令牌
                return_overflowing_tokens=False,
                # 填充到最大长度
                padding="max_length",
                # 返回 PyTorch 张量
                return_tensors="pt",
            )
            # 将输入ID移动到指定的设备(CPU或GPU)
            tokens = batch_encoding["input_ids"].to(self.device)
            # 使用上下文管理器禁用混合精度计算
            with torch.autocast("cuda", enabled=False):
                # 将令牌输入到转换器中,获取输出
                outputs = self.transformer(input_ids=tokens)
            # 获取最后一个隐藏状态,作为编码的表示
            z = outputs.last_hidden_state
            # 返回编码结果
            return z
    
        # 定义编码文本的方法
        def encode(self, text):
            # 调用当前对象的 forward 方法进行编码
            return self(text)
# 定义一个名为 FrozenByT5Embedder 的类,继承自 AbstractEmbModel
class FrozenByT5Embedder(AbstractEmbModel):
    """
    使用 ByT5 转换器编码器处理文本,具备字符意识。
    """

    # 初始化方法,设置模型的版本、设备、最大长度和是否冻结参数
    def __init__(
        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
    ):  # 其他可用版本为 google/t5-v1_1-xl 和 google/t5-v1_1-xxl
        # 调用父类构造函数
        super().__init__()
        # 加载预训练的 ByT5 分词器
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        # 加载预训练的 T5 编码器模型
        self.transformer = T5EncoderModel.from_pretrained(version)
        # 设置设备类型(如 CUDA)
        self.device = device
        # 设置输入文本的最大长度
        self.max_length = max_length
        # 如果需要冻结参数,则调用冻结方法
        if freeze:
            self.freeze()

    # 冻结模型的参数,以避免训练时更新
    def freeze(self):
        # 将变换器设置为评估模式
        self.transformer = self.transformer.eval()
        # 遍历所有参数并设置为不可更新
        for param in self.parameters():
            param.requires_grad = False

    # 定义前向传播方法
    def forward(self, text):
        # 对输入文本进行编码,返回批次编码
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # 将输入的 token 移动到指定设备
        tokens = batch_encoding["input_ids"].to(self.device)
        # 在不启用自动混合精度的情况下进行前向传播
        with torch.autocast("cuda", enabled=False):
            # 获取模型的输出
            outputs = self.transformer(input_ids=tokens)
        # 取出最后一层的隐藏状态
        z = outputs.last_hidden_state
        # 返回最后的隐藏状态
        return z

    # 定义编码方法,直接调用前向传播
    def encode(self, text):
        return self(text)


# 定义一个名为 FrozenCLIPEmbedder 的类,继承自 AbstractEmbModel
class FrozenCLIPEmbedder(AbstractEmbModel):
    """使用 CLIP 转换器编码器处理文本(来自 huggingface)"""

    # 定义可用的层类型
    LAYERS = ["last", "pooled", "hidden"]

    # 初始化方法,设置模型的版本、设备、最大长度、冻结状态和层类型
    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
    ):  # clip-vit-base-patch32
        # 调用父类构造函数
        super().__init__()
        # 确保层类型在可用层中
        assert layer in self.LAYERS
        # 加载预训练的 CLIP 分词器
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        # 加载预训练的 CLIP 文本模型
        self.transformer = CLIPTextModel.from_pretrained(version)
        # 设置设备类型(如 CUDA)
        self.device = device
        # 设置输入文本的最大长度
        self.max_length = max_length
        # 如果需要冻结参数,则调用冻结方法
        if freeze:
            self.freeze()
        # 设置所用层的类型
        self.layer = layer
        # 设置所用层的索引
        self.layer_idx = layer_idx
        # 设置是否总是返回池化结果
        self.return_pooled = always_return_pooled
        # 如果层为隐藏层,确保层索引有效
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    # 冻结模型的参数,以避免训练时更新
    def freeze(self):
        # 将变换器设置为评估模式
        self.transformer = self.transformer.eval()
        # 遍历所有参数并设置为不可更新
        for param in self.parameters():
            param.requires_grad = False

    # 这里缺少方法体,可能是注释或未完成代码
    @autocast
    # 定义前向传播函数,接收文本输入
    def forward(self, text):
        # 对输入文本进行编码,生成批量编码,设置各种参数以控制编码行为
        batch_encoding = self.tokenizer(
            text,
            truncation=True,  # 超出最大长度时截断文本
            max_length=self.max_length,  # 最大长度限制
            return_length=True,  # 返回编码后每个文本的长度
            return_overflowing_tokens=False,  # 不返回溢出的标记
            padding="max_length",  # 填充到最大长度
            return_tensors="pt",  # 返回 PyTorch 张量格式
        )
        # 获取编码后的输入标记,并将其移动到指定设备上
        tokens = batch_encoding["input_ids"].to(self.device)
        # 使用 transformer 模型进行前向传播,获取输出
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"  # 根据条件决定是否返回隐藏状态
        )
        # 根据层级选择相应的输出
        if self.layer == "last":
            # 选择最后一层的隐藏状态
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            # 选择池化后的输出,并增加一个维度
            z = outputs.pooler_output[:, None, :]
        else:
            # 选择指定索引的隐藏状态
            z = outputs.hidden_states[self.layer_idx]
        # 根据是否需要池化输出,返回相应结果
        if self.return_pooled:
            return z, outputs.pooler_output  # 返回输出和池化结果
        return z  # 返回仅隐藏状态

    # 定义编码函数,简化调用前向传播
    def encode(self, text):
        # 调用前向传播函数并返回结果
        return self(text)
# 定义一个名为 FrozenOpenCLIPEmbedder2 的类,继承自 AbstractEmbModel
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    使用 OpenCLIP 变换器编码器进行文本处理
    """

    # 定义可用的层名称
    LAYERS = ["pooled", "last", "penultimate"]

    # 初始化方法,设置模型的基本参数
    def __init__(
        self,
        arch="ViT-H-14",  # 模型架构
        version="laion2b_s32b_b79k",  # 版本信息
        device="cuda",  # 设备类型
        max_length=77,  # 最大输入长度
        freeze=True,  # 是否冻结模型参数
        layer="last",  # 选择的层
        always_return_pooled=False,  # 是否始终返回池化结果
        legacy=True,  # 是否使用遗留模式
    ):
        super().__init__()  # 调用父类构造函数
        assert layer in self.LAYERS  # 确保指定的层有效
        # 创建模型和转换器,并将其移动到 CPU
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
        )
        del model.visual  # 删除视觉部分
        self.model = model  # 保存模型

        self.device = device  # 设置设备
        self.max_length = max_length  # 设置最大长度
        self.return_pooled = always_return_pooled  # 设置是否返回池化
        if freeze:  # 如果需要冻结模型
            self.freeze()  # 调用冻结方法
        self.layer = layer  # 设置层
        # 根据选择的层更新层索引
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()  # 不支持的层选择
        self.legacy = legacy  # 设置遗留模式

    # 冻结模型参数的方法
    def freeze(self):
        self.model = self.model.eval()  # 设置模型为评估模式
        for param in self.parameters():  # 遍历所有参数
            param.requires_grad = False  # 禁止梯度更新

    # 前向传播的方法,处理输入文本
    @autocast
    def forward(self, text):
        tokens = open_clip.tokenize(text)  # 将文本转换为标记
        z = self.encode_with_transformer(tokens.to(self.device))  # 编码处理
        # 根据条件返回不同结果
        if not self.return_pooled and self.legacy:
            return z
        if self.return_pooled:
            assert not self.legacy  # 确保不在遗留模式下
            return z[self.layer], z["pooled"]  # 返回选定层和池化结果
        return z[self.layer]  # 返回选定层结果

    # 使用变换器进行编码的方法
    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # 获取标记嵌入
        x = x + self.model.positional_embedding  # 加入位置嵌入
        x = x.permute(1, 0, 2)  # 转换维度顺序
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)  # 通过变换器前向传播
        if self.legacy:  # 如果在遗留模式
            x = x[self.layer]  # 获取选定层结果
            x = self.model.ln_final(x)  # 最终归一化
            return x  # 返回结果
        else:
            # x 为字典,将保持为字典
            o = x["last"]  # 获取最后一层输出
            o = self.model.ln_final(o)  # 最终归一化
            pooled = self.pool(o, text)  # 进行池化处理
            x["pooled"] = pooled  # 将池化结果存入字典
            return x  # 返回字典

    # 池化处理的方法
    def pool(self, x, text):
        # 从 eot 嵌入中获取特征(eot_token 为每个序列中的最大值)
        x = (
            x[torch.arange(x.shape[0]), text.argmax(dim=-1)]  # 获取 eot 嵌入
            @ self.model.text_projection  # 应用文本投影
        )
        return x  # 返回池化结果
    # 定义文本转换器的前向传播方法,接受输入张量 x 和可选的注意力掩码
        def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
            # 创建一个空字典以存储输出
            outputs = {}
            # 遍历转换器的每个残差块
            for i, r in enumerate(self.model.transformer.resblocks):
                # 如果是最后一个残差块,将输入张量进行维度变换并存储
                if i == len(self.model.transformer.resblocks) - 1:
                    outputs["penultimate"] = x.permute(1, 0, 2)  # 将维度从 LND 转换为 NLD
                # 如果启用了梯度检查点并且不是脚本模式
                if (
                    self.model.transformer.grad_checkpointing
                    and not torch.jit.is_scripting()
                ):
                    # 使用检查点技术进行前向传播以节省内存
                    x = checkpoint(r, x, attn_mask)
                else:
                    # 正常执行残差块的前向传播
                    x = r(x, attn_mask=attn_mask)
            # 将最后输出张量的维度进行转换并存储
            outputs["last"] = x.permute(1, 0, 2)  # 将维度从 LND 转换为 NLD
            # 返回包含倒数第二层和最后一层输出的字典
            return outputs
    
        # 定义编码方法,接受文本输入
        def encode(self, text):
            # 调用文本转换器的前向传播方法并返回结果
            return self(text)
# 定义一个名为 FrozenOpenCLIPEmbedder 的类,继承自 AbstractEmbModel
class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    # 定义一个类属性 LAYERS,包含模型中可用的层
    LAYERS = [
        # "pooled",  # 注释掉的层选项
        "last",  # 最后一个层
        "penultimate",  # 倒数第二个层
    ]

    # 初始化方法,用于设置实例的基本参数
    def __init__(
        self,
        arch="ViT-H-14",  # 模型架构
        version="laion2b_s32b_b79k",  # 预训练模型版本
        device="cuda",  # 设备类型,默认为 GPU
        max_length=77,  # 最大输入文本长度
        freeze=True,  # 是否冻结模型参数
        layer="last",  # 选择使用的层
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 确保选择的层在可用层中
        assert layer in self.LAYERS
        # 创建模型及其转换,使用指定的架构和设备
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version
        )
        # 删除视觉部分以冻结模型
        del model.visual
        # 将模型赋值给实例变量
        self.model = model

        # 设置设备属性
        self.device = device
        # 设置最大长度属性
        self.max_length = max_length
        # 如果需要冻结,则调用冻结方法
        if freeze:
            self.freeze()
        # 设置所选层
        self.layer = layer
        # 根据所选层设置层索引
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            # 如果层不在可用选项中,则抛出异常
            raise NotImplementedError()

    # 冻结模型参数的方法
    def freeze(self):
        # 将模型设置为评估模式
        self.model = self.model.eval()
        # 将所有参数的 requires_grad 属性设置为 False,停止梯度计算
        for param in self.parameters():
            param.requires_grad = False

    # 前向传播方法,接受文本输入
    def forward(self, text):
        # 对文本进行分词处理
        tokens = open_clip.tokenize(text)
        # 使用变换器进行编码,并将结果传送到设备
        z = self.encode_with_transformer(tokens.to(self.device))
        # 返回编码结果
        return z

    # 使用变换器编码文本的方法
    def encode_with_transformer(self, text):
        # 获取文本的嵌入表示,形状为 [batch_size, n_ctx, d_model]
        x = self.model.token_embedding(text)  
        # 加上位置嵌入
        x = x + self.model.positional_embedding
        # 重新排列维度,将形状从 NLD 转换为 LND
        x = x.permute(1, 0, 2)  
        # 执行变换器前向传播
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        # 重新排列维度回到 NLD
        x = x.permute(1, 0, 2)  
        # 通过最终层归一化
        x = self.model.ln_final(x)
        # 返回处理后的结果
        return x

    # 执行文本变换器前向传播的方法
    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        # 遍历变换器的残差块
        for i, r in enumerate(self.model.transformer.resblocks):
            # 如果达到所需层索引则停止
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            # 检查是否使用梯度检查点
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                # 使用检查点方式更新输入
                x = checkpoint(r, x, attn_mask)
            else:
                # 否则直接通过残差块处理输入
                x = r(x, attn_mask=attn_mask)
        # 返回变换后的结果
        return x

    # 编码文本的简化方法,直接调用前向传播
    def encode(self, text):
        return self(text)


# 定义一个名为 FrozenOpenCLIPImageEmbedder 的类,继承自 AbstractEmbModel
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    使用 OpenCLIP 视觉变换器编码器处理图像
    """

    # 初始化方法,用于设置实例的基本参数
    def __init__(
        self,
        arch="ViT-H-14",  # 模型架构
        version="laion2b_s32b_b79k",  # 预训练模型版本
        device="cuda",  # 设备类型,默认为 GPU
        max_length=77,  # 最大输入图像长度
        freeze=True,  # 是否冻结模型参数
        antialias=True,  # 是否使用抗锯齿
        ucg_rate=0.0,  # 用户定义的裁剪率
        unsqueeze_dim=False,  # 是否增加维度
        repeat_to_max_len=False,  # 是否重复到最大长度
        num_image_crops=0,  # 图像裁剪数量
        output_tokens=False,  # 是否输出 tokens
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 创建模型和转换器,使用指定的架构、设备和预训练版本
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),  # 使用 CPU 作为计算设备
            pretrained=version,  # 指定预训练版本
        )
        # 删除模型中的变换器部分
        del model.transformer
        # 将创建的模型赋值给实例变量
        self.model = model
        # 设置最大图像裁剪数量
        self.max_crops = num_image_crops
        # 检查是否需要填充到最大长度
        self.pad_to_max_len = self.max_crops > 0
        # 检查是否需要重复到最大长度
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        # 设置设备类型
        self.device = device
        # 设置最大长度
        self.max_length = max_length
        # 如果需要冻结模型参数,则调用冻结方法
        if freeze:
            self.freeze()

        # 设置抗锯齿参数
        self.antialias = antialias

        # 注册均值张量作为缓冲区,设置为非持久性
        self.register_buffer(
            "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
        )
        # 注册标准差张量作为缓冲区,设置为非持久性
        self.register_buffer(
            "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
        )
        # 设置 UCG 速率
        self.ucg_rate = ucg_rate
        # 设置需要扩展的维度
        self.unsqueeze_dim = unsqueeze_dim
        # 存储的批次初始化为 None
        self.stored_batch = None
        # 设置视觉模型的输出标记
        self.model.visual.output_tokens = output_tokens
        # 保存输出标记的状态
        self.output_tokens = output_tokens

    def preprocess(self, x):
        # 将输入归一化到 [0,1] 范围
        x = kornia.geometry.resize(
            x,
            (224, 224),  # 将图像大小调整为 224x224
            interpolation="bicubic",  # 使用双三次插值
            align_corners=True,  # 对齐角点
            antialias=self.antialias,  # 使用抗锯齿
        )
        # 将图像数据从 [-1,1] 范围转换到 [0,1]
        x = (x + 1.0) / 2.0
        # 根据 CLIP 模型的均值和标准差重新归一化图像
        x = kornia.enhance.normalize(x, self.mean, self.std)
        # 返回处理后的图像
        return x

    def freeze(self):
        # 将模型设置为评估模式
        self.model = self.model.eval()
        # 禁用所有参数的梯度计算
        for param in self.parameters():
            param.requires_grad = False

    @autocast  # 启用自动混合精度计算
    # 前向传播方法,处理输入图像并返回特征或标记
        def forward(self, image, no_dropout=False):
            # 使用视觉变换器对输入图像进行编码,得到特征 z
            z = self.encode_with_vision_transformer(image)
            # 初始化 tokens 为 None
            tokens = None
            # 如果输出标记为真,分离特征和标记
            if self.output_tokens:
                z, tokens = z[0], z[1]
            # 将特征 z 转换为与图像相同的数据类型
            z = z.to(image.dtype)
            # 如果 ucg_rate 大于 0,且没有进行 dropout,且没有最大裁剪
            if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
                # 根据 Bernoulli 分布随机丢弃特征
                z = (
                    torch.bernoulli(
                        (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                    )[:, None]
                    * z
                )
                # 如果 tokens 不为 None,应用相同的丢弃逻辑
                if tokens is not None:
                    tokens = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - self.ucg_rate)
                                * torch.ones(tokens.shape[0], device=tokens.device)
                            ),
                            tokens,
                        )
                        * tokens
                    )
            # 如果需要扩展维度,将特征 z 变为三维
            if self.unsqueeze_dim:
                z = z[:, None, :]
            # 如果输出标记为真,检查标记和特征的重复与填充条件
            if self.output_tokens:
                assert not self.repeat_to_max_len
                assert not self.pad_to_max_len
                # 返回标记和特征
                return tokens, z
            # 如果需要重复到最大长度
            if self.repeat_to_max_len:
                # 将二维特征扩展为三维
                if z.dim() == 2:
                    z_ = z[:, None, :]
                else:
                    z_ = z
                # 返回重复的特征
                return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
            # 如果需要填充到最大长度
            elif self.pad_to_max_len:
                # 确保特征是三维的
                assert z.dim() == 3
                # 在特征后面填充零
                z_pad = torch.cat(
                    (
                        z,
                        torch.zeros(
                            z.shape[0],
                            self.max_length - z.shape[1],
                            z.shape[2],
                            device=z.device,
                        ),
                    ),
                    1,
                )
                # 返回填充后的特征和第一个时间步的特征
                return z_pad, z_pad[:, 0, ...]
            # 默认返回特征 z
            return z
    # 使用视觉变换器对图像进行编码
    def encode_with_vision_transformer(self, img):
        # 如果最大裁剪数大于0,则对图像进行裁剪预处理
        # if self.max_crops > 0:
        #    img = self.preprocess_by_cropping(img)
        # 检查图像维度是否为5
        if img.dim() == 5:
            # 确保最大裁剪数与图像的第二维度匹配
            assert self.max_crops == img.shape[1]
            # 重排图像维度,将其从 (b n) c h w 变为 (b n) c h w
            img = rearrange(img, "b n c h w -> (b n) c h w")
        # 对图像进行预处理
        img = self.preprocess(img)
        # 如果不需要输出tokens
        if not self.output_tokens:
            # 确保模型不输出tokens
            assert not self.model.visual.output_tokens
            # 将图像传入模型进行处理
            x = self.model.visual(img)
            tokens = None
        else:
            # 确保模型输出tokens
            assert self.model.visual.output_tokens
            # 将图像传入模型并获取输出和tokens
            x, tokens = self.model.visual(img)
        # 如果最大裁剪数大于0
        if self.max_crops > 0:
            # 重排输出,将其从 (b n) d 变为 b n d
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # 在序列轴上进行drop out,控制一定比例的输出
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            # 如果tokens不为None
            if tokens is not None:
                # 重排tokens,将其从 (b n) t d 变为 b t (n d)
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                # 输出实验性提示信息
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        # 如果需要输出tokens,则返回
        if self.output_tokens:
            return x, tokens
        # 返回处理后的图像
        return x
    
    # 对输入文本进行编码
    def encode(self, text):
        # 调用自身对文本进行处理
        return self(text)
# 定义一个继承自 AbstractEmbModel 的类,名为 FrozenCLIPT5Encoder
class FrozenCLIPT5Encoder(AbstractEmbModel):
    # 构造函数,初始化模型的参数
    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",  # CLIP 模型的版本
        t5_version="google/t5-v1_1-xl",  # T5 模型的版本
        device="cuda",  # 指定使用的设备
        clip_max_length=77,  # CLIP 模型的最大输入长度
        t5_max_length=77,  # T5 模型的最大输入长度
    ):
        super().__init__()  # 调用父类的构造函数
        # 创建 CLIP 嵌入模型实例
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        # 创建 T5 嵌入模型实例
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        # 打印 CLIP 和 T5 模型的参数数量
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )

    # 定义编码函数,调用前向传播
    def encode(self, text):
        return self(text)

    # 定义前向传播函数
    def forward(self, text):
        # 使用 CLIP 编码器对文本进行编码
        clip_z = self.clip_encoder.encode(text)
        # 使用 T5 编码器对文本进行编码
        t5_z = self.t5_encoder.encode(text)
        # 返回 CLIP 和 T5 的编码结果
        return [clip_z, t5_z]


# 定义一个继承自 nn.Module 的类,名为 SpatialRescaler
class SpatialRescaler(nn.Module):
    # 构造函数,初始化空间重缩放器的参数
    def __init__(
        self,
        n_stages=1,  # 重缩放的阶段数
        method="bilinear",  # 插值方法
        multiplier=0.5,  # 缩放因子
        in_channels=3,  # 输入通道数
        out_channels=None,  # 输出通道数
        bias=False,  # 是否使用偏置
        wrap_video=False,  # 是否处理视频数据
        kernel_size=1,  # 卷积核大小
        remap_output=False,  # 是否重映射输出通道
    ):
        super().__init__()  # 调用父类的构造函数
        self.n_stages = n_stages  # 保存阶段数
        assert self.n_stages >= 0  # 确保阶段数非负
        # 验证插值方法是否在支持的范围内
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier  # 保存缩放因子
        # 创建部分应用的插值函数
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        # 判断是否需要重映射输出通道
        self.remap_output = out_channels is not None or remap_output
        # 如果需要重映射输出通道,创建卷积层
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video  # 保存是否处理视频的标志

    # 定义前向传播函数
    def forward(self, x):
        # 如果处理视频数据且输入是五维张量,进行维度调整
        if self.wrap_video and x.ndim == 5:
            B, C, T, H, W = x.shape  # 解包维度
            x = rearrange(x, "b c t h w -> b t c h w")  # 调整维度顺序
            x = rearrange(x, "b t c h w -> (b t) c h w")  # 合并批次和时间维度

        # 进行指定阶段的重缩放操作
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        # 如果处理视频数据,恢复维度顺序
        if self.wrap_video:
            x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)  # 恢复维度
            x = rearrange(x, "b t c h w -> b c t h w")  # 再次调整维度
        # 如果需要重映射输出,应用卷积层
        if self.remap_output:
            x = self.channel_mapper(x)
        return x  # 返回处理后的张量

    # 定义编码函数,调用前向传播
    def encode(self, x):
        return self(x)


# 定义一个继承自 nn.Module 的类,名为 LowScaleEncoder
class LowScaleEncoder(nn.Module):
    # 构造函数,初始化低缩放编码器的参数
    def __init__(
        self,
        model_config,  # 模型配置
        linear_start,  # 线性起始值
        linear_end,  # 线性结束值
        timesteps=1000,  # 时间步数
        max_noise_level=250,  # 最大噪声水平
        output_size=64,  # 输出大小
        scale_factor=1.0,  # 缩放因子
    # 定义一个类,继承自父类
    def __init__(self, max_noise_level, model_config, timesteps, linear_start, linear_end, output_size, scale_factor):
        # 调用父类的初始化方法
        super().__init__()
        # 设置最大噪声级别
        self.max_noise_level = max_noise_level
        # 根据配置实例化模型
        self.model = instantiate_from_config(model_config)
        # 注册一个调度表,用于控制噪声的变化
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
        )
        # 设置输出大小
        self.out_size = output_size
        # 设置缩放因子
        self.scale_factor = scale_factor
    
    # 注册一个调度表,用于控制噪声的变化
    def register_schedule(self, beta_schedule, timesteps, linear_start, linear_end, cosine_s):
        # 根据给定的参数生成 beta 调度表
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        # 根据 betas 计算 alphas
        alphas = 1.0 - betas
        # 计算 alphas 的累积乘积
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # 计算 alphas 的累积乘积的前一个值
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    
        # 获取 betas 的形状
        (timesteps,) = betas.shape
        # 将 timesteps 转换为整数
        self.num_timesteps = int(timesteps)
        # 设置线性起始值
        self.linear_start = linear_start
        # 设置线性结束值
        self.linear_end = linear_end
        # 判断 alphas_cumprod 的形状是否与 num_timesteps 相同
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"
    
        # 创建一个偏函数,用于将数组转换为 torch.tensor
        to_torch = partial(torch.tensor, dtype=torch.float32)
    
        # 注册缓冲区,存储 betas
        self.register_buffer("betas", to_torch(betas))
        # 注册缓冲区,存储 alphas_cumprod
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        # 注册缓冲区,存储 alphas_cumprod_prev
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
    
        # 计算扩散 q(x_t | x_{t-1}) 和其他参数
        # 注册缓冲区,存储 sqrt_alphas_cumprod
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        # 注册缓冲区,存储 sqrt_one_minus_alphas_cumprod
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        # 注册缓冲区,存储 log_one_minus_alphas_cumprod
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        # 注册缓冲区,存储 sqrt_recip_alphas_cumprod
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        # 注册缓冲区,存储 sqrt_recipm1_alphas_cumprod
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )
    
    # 从初始值 x_start 和时间步 t 生成噪声样本
    def q_sample(self, x_start, t, noise):
        # 如果没有传入噪声,则生成一个与 x_start 形状相同的随机噪声
        noise = default(noise, lambda: torch.randn_like(x_start))
        # 根据噪声和调度表生成噪声样本
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )
    # 定义前向传播函数,接收输入 x
    def forward(self, x):
        # 使用模型对输入 x 进行编码,得到潜在表示 z
        z = self.model.encode(x)
        # 检查 z 是否为对角高斯分布类型
        if isinstance(z, DiagonalGaussianDistribution):
            # 从高斯分布中采样,更新 z
            z = z.sample()
        # 将 z 乘以缩放因子,调整其大小
        z = z * self.scale_factor
        # 随机生成噪声水平,范围从 0 到 max_noise_level,形状与批大小相同
        noise_level = torch.randint(
            0, self.max_noise_level, (x.shape[0],), device=x.device
        ).long()
        # 对 z 应用 q_sample 函数,根据噪声水平生成样本
        z = self.q_sample(z, noise_level)
        # 如果指定了输出大小,则调整 z 的尺寸
        if self.out_size is not None:
            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)  # 注释掉的代码:可能用于调整 z 的形状
        # 返回处理后的 z 和噪声水平
        return z, noise_level

    # 定义解码函数,接收潜在表示 z
    def decode(self, z):
        # 将 z 除以缩放因子,恢复其原始尺度
        z = z / self.scale_factor
        # 使用模型对 z 进行解码,返回解码后的结果
        return self.model.decode(z)
# 定义一个多维时间步嵌入模型类,继承自抽象嵌入模型
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """嵌入每个维度并独立拼接它们"""

    # 初始化方法,接受输出维度参数
    def __init__(self, outdim):
        # 调用父类的初始化方法
        super().__init__()
        # 创建时间步嵌入对象
        self.timestep = Timestep(outdim)
        # 保存输出维度
        self.outdim = outdim

    # 前向传播方法,处理输入数据
    def forward(self, x):
        # 如果输入是1维,则增加一个维度
        if x.ndim == 1:
            x = x[:, None]
        # 确保输入为2维
        assert len(x.shape) == 2
        # 获取批大小和维度数量
        b, dims = x.shape[0], x.shape[1]
        # 重排输入数据为一维
        x = rearrange(x, "b d -> (b d)")
        # 获取时间步嵌入
        emb = self.timestep(x)
        # 重排嵌入为批大小和输出维度格式
        emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        # 返回最终的嵌入
        return emb


# 定义一个高斯编码器类,继承自编码器和抽象嵌入模型
class GaussianEncoder(Encoder, AbstractEmbModel):
    # 初始化方法,接受权重和是否扁平化输出参数
    def __init__(
        self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
    ):
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)
        # 创建对角高斯正则化器
        self.posterior = DiagonalGaussianRegularizer()
        # 保存权重
        self.weight = weight
        # 保存是否扁平化输出标志
        self.flatten_output = flatten_output

    # 前向传播方法,处理输入数据
    def forward(self, x) -> Tuple[Dict, torch.Tensor]:
        # 调用父类的前向传播,获取潜变量
        z = super().forward(x)
        # 通过正则化器处理潜变量
        z, log = self.posterior(z)
        # 记录损失和权重
        log["loss"] = log["kl_loss"]
        log["weight"] = self.weight
        # 如果需要,扁平化输出
        if self.flatten_output:
            z = rearrange(z, "b c h w -> b (h w ) c")
        # 返回日志和潜变量
        return log, z


# 定义一个冻结的 OpenCLIP 图像预测嵌入模型类
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    # 初始化方法,接受配置、条件帧数和副本数
    def __init__(
        self,
        open_clip_embedding_config: Dict,
        n_cond_frames: int,
        n_copies: int,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 保存条件帧数
        self.n_cond_frames = n_cond_frames
        # 保存副本数
        self.n_copies = n_copies
        # 实例化 OpenCLIP 嵌入对象
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    # 前向传播方法,处理视频输入
    def forward(self, vid):
        # 通过 OpenCLIP 嵌入视频数据
        vid = self.open_clip(vid)
        # 重排视频数据为批大小和时间步格式
        vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
        # 重复视频数据以匹配副本数
        vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)

        # 返回处理后的视频数据
        return vid


# 定义一个原始图像嵌入模型类,继承自抽象嵌入模型
class RawImageEmbedder(AbstractEmbModel):
    """
    在 Instructpix2pix 中将原始图像作为条件
    """
    
    # 前向传播方法,直接返回输入图像
    def forward(self, image):
        return image

标签:__,return,3Plus,CogView,self,torch,源码,key,def
From: https://www.cnblogs.com/apachecn/p/18494399

相关文章

  • CogView3---CogView-3Plus-微调代码源码解析-三-
    CogView3&CogView-3Plus微调代码源码解析(三).\cogview3-finetune\sat\sgm\modules\diffusionmodules\guiders.py#导入logging模块,用于记录日志信息importlogging#从abc模块导入ABC类和abstractmethod装饰器,用于定义抽象基类和抽象方法fromabcimportABC,abst......
  • CogView3---CogView-3Plus-微调代码源码解析-二-
    CogView3&CogView-3Plus微调代码源码解析(二).\cogview3-finetune\sat\sgm\models\__init__.py#从同一模块导入AutoencodingEngine类,用于后续的自动编码器操作from.autoencoderimportAutoencodingEngine#注释文本(可能是无关信息或标识符)#XuDwndGaCFo.\cogview3-fi......
  • 基于SpringBoot+Vue的大数据技术的宠物商品信息比价及推荐系统(源码+LW+调试文档+讲解
    在宠物经济日益繁荣的今天,为宠物主人提供一个高效的宠物商品信息比价及推荐系统至关重要。本系统基于SpringBoot+Vue并结合大数据技术,为宠物主人带来全新的购物体验。在设计上,系统广泛收集各类宠物商品的信息,包括价格、品牌、规格、用户评价等。通过大数据分析,对不同......
  • 基于SpringBoot+Vue的大数据高乐健身器材销售数据可视化系统设计与实现(源码+LW+调试
    在健身热潮持续升温的当下,健身器材销售数据的有效管理和分析至关重要。本系统基于SpringBoot+Vue并结合大数据技术,为高乐健身器材的销售管理提供强大的可视化解决方案。在设计上,系统全面收集高乐健身器材的销售数据,包括产品种类、销售数量、销售地区、销售时间等多维......
  • 【开题报告】基于Springboot+vue中医古方名方信息管理系统(程序+源码+论文) 计算机毕业
    本系统(程序+源码)带文档lw万字以上文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景中医作为中华民族的传统医学,承载着丰富的历史文化底蕴与独特的医疗智慧。在历史的长河中,无数中医先辈通过临床实践,总结出了大量疗效显著的古方名方,这......
  • Springboot+vue社区智慧医疗服务管理系统的设计与实现 毕业设计程序源码98275
    目 录摘要1绪论1.1研究背景1.2研究意义1.3论文结构与章节安排2 社区智慧医疗服务管理系统分析2.1可行性分析2.2系统流程分析2.2.1数据增加流程2.2.2数据修改流程2.2.3数据删除流程2.3系统功能分析2.3.1功能性分析2.4系统用例分析......
  • flask影响电影票房因素的数据分析及可视化系统 毕业设计程序源码19201
    摘 要现在电影行业飞速发展,传统影响电影票房因素的数据分析及可视化方式己经逐渐跟不上时代变化的速度。在计算机行业发达的今天,希望利用现代爬虫技术的优势,提高数据分析及可视化效率及效果。本系统采用的是 Python 语言,使用 PyCharm 这一款开发工具,综合运用了 Tkinte......
  • 基于大数据 Python+Vue 电影票房爬取可视化系统(源码+LW+部署讲解+数据库+ppt)
    !!!!!!!!!会持续一直更新下去有问必答一键收藏关注不迷路源码获取:https://pan.baidu.com/s/1aRpOv3f2sdtVYOogQjb8jg?pwd=jf1d提取码:jf1d!!!!!!!!!项目介绍在快速发展的社会中,娱乐领域也在不断进步。为了提高数据分析的效率和观众的观影体验,越来越多的影视公司和电影院选择利用互联网......
  • 开源图像超分ECBSR项目源码分析
    相关介绍项目GitHub地址:https://github.com/xindongzhang/ECBSR项目相关论文:https://www4.comp.polyu.edu.hk/~cslzhang/paper/MM21_ECBSR.pdf(也可以点这里下载)论文解读:Edge-orientedConvolutionBlockforReal-timeSuperResolutiononMobileDevicesWindows环境训练......
  • Java毕设项目案例实战II 基于移动平台的远程在线诊疗系统(开发文档+数据库+源码)
    目录一、前言二、技术介绍三、系统实现四、论文参考五、核心代码六、源码获取全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末一、前言在当今数字化时代,医疗行业正经历着前所未......