首页 > 编程语言 >CogVideo---CogVideoX-微调代码源码解析-十一-

CogVideo---CogVideoX-微调代码源码解析-十一-

时间:2024-10-23 09:25:29浏览次数:1  
标签:dim None CogVideoX self torch --- 源码 meta context

CogVideo & CogVideoX 微调代码源码解析(十一)

.\cogvideo-finetune\sat\sgm\modules\encoders\__init__.py

请提供需要注释的代码,我将为您添加注释。

.\cogvideo-finetune\sat\sgm\modules\video_attention.py

# 导入 PyTorch 库
import torch

# 从上级模块导入必要的组件
from ..modules.attention import *
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding

# 定义一个新的类 TimeMixSequential,继承自 nn.Sequential
class TimeMixSequential(nn.Sequential):
    # 重写 forward 方法
    def forward(self, x, context=None, timesteps=None):
        # 遍历所有层,将输入 x 逐层传递
        for layer in self:
            x = layer(x, context, timesteps)

        # 返回最终的输出
        return x

# 定义 VideoTransformerBlock 类,继承自 nn.Module
class VideoTransformerBlock(nn.Module):
    # 定义注意力模式的字典
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # 使用软最大注意力
        "softmax-xformers": MemoryEfficientCrossAttention,  # 使用内存高效的软最大注意力
    }

    # 初始化方法
    def __init__(
        self,
        dim,  # 输入特征的维度
        n_heads,  # 注意力头的数量
        d_head,  # 每个头的维度
        dropout=0.0,  # dropout 的比率
        context_dim=None,  # 上下文特征的维度
        gated_ff=True,  # 是否使用门控前馈网络
        checkpoint=True,  # 是否启用检查点功能
        timesteps=None,  # 时间步
        ff_in=False,  # 是否使用前馈网络输入
        inner_dim=None,  # 内部特征维度
        attn_mode="softmax",  # 注意力模式
        disable_self_attn=False,  # 是否禁用自注意力
        disable_temporal_crossattention=False,  # 是否禁用时间交叉注意力
        switch_temporal_ca_to_sa=False,  # 是否将时间交叉注意力切换为自注意力
    ):
        # 调用父类构造函数
        super().__init__()

        # 根据指定的注意力模式选择注意力类
        attn_cls = self.ATTENTION_MODES[attn_mode]

        # 确定是否使用前馈网络输入
        self.ff_in = ff_in or inner_dim is not None
        # 如果没有指定内部维度,则将其设为输入维度
        if inner_dim is None:
            inner_dim = dim

        # 确保头的数量乘以每个头的维度等于内部维度
        assert int(n_heads * d_head) == inner_dim

        # 确定是否使用残差连接
        self.is_res = inner_dim == dim

        # 如果使用前馈网络输入,定义输入层的归一化和前馈网络
        if self.ff_in:
            self.norm_in = nn.LayerNorm(dim)  # 定义层归一化
            self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff)  # 定义前馈网络

        # 设置时间步数
        self.timesteps = timesteps
        # 设置是否禁用自注意力
        self.disable_self_attn = disable_self_attn
        # 如果禁用自注意力,使用交叉注意力
        if self.disable_self_attn:
            self.attn1 = attn_cls(
                query_dim=inner_dim,  # 查询的维度
                heads=n_heads,  # 注意力头的数量
                dim_head=d_head,  # 每个头的维度
                context_dim=context_dim,  # 上下文维度
                dropout=dropout,  # dropout 的比率
            )  # 这是一个交叉注意力
        else:
            # 否则,使用自注意力
            self.attn1 = attn_cls(
                query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
            )  # 这是一个自注意力

        # 定义前馈网络
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)

        # 根据设置决定是否禁用时间交叉注意力
        if disable_temporal_crossattention:
            # 如果要切换交叉注意力为自注意力,抛出异常
            if switch_temporal_ca_to_sa:
                raise ValueError
            else:
                self.attn2 = None  # 不使用时间交叉注意力
        else:
            # 定义内部归一化层
            self.norm2 = nn.LayerNorm(inner_dim)
            # 根据设置选择交叉注意力或自注意力
            if switch_temporal_ca_to_sa:
                self.attn2 = attn_cls(
                    query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
                )  # 这是一个自注意力
            else:
                self.attn2 = attn_cls(
                    query_dim=inner_dim,
                    context_dim=context_dim,  # 上下文维度
                    heads=n_heads,  # 注意力头的数量
                    dim_head=d_head,  # 每个头的维度
                    dropout=dropout,  # dropout 的比率
                )  # 如果上下文为空,则为自注意力

        # 定义层归一化
        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)
        # 保存切换自注意力和交叉注意力的设置
        self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

        # 保存检查点设置
        self.checkpoint = checkpoint
        # 如果启用检查点,打印信息
        if self.checkpoint:
            print(f"{self.__class__.__name__} is using checkpointing")
    # 定义前向传播方法,接收输入张量 x 和可选的上下文及时间步数
        def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor:
            # 如果启用检查点,使用检查点机制来执行前向传播
            if self.checkpoint:
                return checkpoint(self._forward, x, context, timesteps)
            # 否则,直接调用内部前向传播方法
            else:
                return self._forward(x, context, timesteps=timesteps)
    
    # 定义内部前向传播方法,处理输入张量及可选的上下文和时间步数
        def _forward(self, x, context=None, timesteps=None):
            # 确保有时间步数,如果未传入则使用对象属性
            assert self.timesteps or timesteps
            # 确保时间步数一致性,若两者都提供则必须相等
            assert not (self.timesteps and timesteps) or self.timesteps == timesteps
            # 设置时间步数
            timesteps = self.timesteps or timesteps
            # 获取输入张量的批次、序列长度和通道数
            B, S, C = x.shape
            # 调整输入张量的形状以适应后续计算
            x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
    
            # 如果启用前馈输入,则进行前馈层处理
            if self.ff_in:
                x_skip = x  # 保存输入以便残差连接
                x = self.ff_in(self.norm_in(x))  # 归一化后通过前馈层
                # 如果使用残差连接,则将输入加回
                if self.is_res:
                    x += x_skip
    
            # 如果禁用自注意力,则仅使用上下文进行处理
            if self.disable_self_attn:
                x = self.attn1(self.norm1(x), context=context) + x  # 计算自注意力并加回原输入
            else:
                x = self.attn1(self.norm1(x)) + x  # 计算自注意力并加回原输入
    
            # 如果存在第二个自注意力层
            if self.attn2 is not None:
                # 根据标志决定是否使用上下文
                if self.switch_temporal_ca_to_sa:
                    x = self.attn2(self.norm2(x)) + x  # 仅计算自注意力
                else:
                    x = self.attn2(self.norm2(x), context=context) + x  # 使用上下文计算自注意力
            x_skip = x  # 保存当前状态以便残差连接
            x = self.ff(self.norm3(x))  # 归一化后通过前馈层
            # 如果使用残差连接,则将之前的状态加回
            if self.is_res:
                x += x_skip
    
            # 调整输出张量的形状为原始格式
            x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps)
            return x  # 返回处理后的张量
    
    # 定义获取最后一层的权重的方法
        def get_last_layer(self):
            return self.ff.net[-1].weight  # 返回前馈网络最后一层的权重
# 定义数据类型与 PyTorch 数据类型的映射字典
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

# 定义空间视频变换器类,继承自空间变换器
class SpatialVideoTransformer(SpatialTransformer):
    # 初始化方法,接收多个参数以配置变换器
    def __init__(
        self,
        in_channels,  # 输入通道数
        n_heads,      # 注意力头数
        d_head,       # 每个头的维度
        depth=1,      # 网络深度
        dropout=0.0,  # dropout 率
        use_linear=False,  # 是否使用线性层
        context_dim=None,  # 上下文维度
        use_spatial_context=False,  # 是否使用空间上下文
        timesteps=None,  # 时间步数
        merge_strategy: str = "fixed",  # 合并策略
        merge_factor: float = 0.5,  # 合并因子
        time_context_dim=None,  # 时间上下文维度
        ff_in=False,  # 前馈网络输入标志
        checkpoint=False,  # 是否启用检查点
        time_depth=1,  # 时间深度
        attn_mode="softmax",  # 注意力模式
        disable_self_attn=False,  # 是否禁用自注意力
        disable_temporal_crossattention=False,  # 是否禁用时间交叉注意力
        max_time_embed_period: int = 10000,  # 最大时间嵌入周期
        dtype="fp32",  # 数据类型
    ):
        # 调用父类初始化方法,传递部分参数
        super().__init__(
            in_channels,
            n_heads,
            d_head,
            depth=depth,
            dropout=dropout,
            attn_type=attn_mode,
            use_checkpoint=checkpoint,
            context_dim=context_dim,
            use_linear=use_linear,
            disable_self_attn=disable_self_attn,
        )
        # 设置类属性
        self.time_depth = time_depth
        self.depth = depth
        self.max_time_embed_period = max_time_embed_period

        # 初始化时间混合头维度与数量
        time_mix_d_head = d_head
        n_time_mix_heads = n_heads

        # 计算时间混合内部维度
        time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)

        # 计算内部维度
        inner_dim = n_heads * d_head
        # 如果使用空间上下文,更新时间上下文维度
        if use_spatial_context:
            time_context_dim = context_dim

        # 创建时间堆栈,包含多个视频变换器块
        self.time_stack = nn.ModuleList(
            [
                VideoTransformerBlock(
                    inner_dim,  # 内部维度
                    n_time_mix_heads,  # 时间混合头数
                    time_mix_d_head,  # 时间混合头维度
                    dropout=dropout,  # dropout 率
                    context_dim=time_context_dim,  # 时间上下文维度
                    timesteps=timesteps,  # 时间步数
                    checkpoint=checkpoint,  # 检查点标志
                    ff_in=ff_in,  # 前馈网络输入标志
                    inner_dim=time_mix_inner_dim,  # 时间混合内部维度
                    attn_mode=attn_mode,  # 注意力模式
                    disable_self_attn=disable_self_attn,  # 禁用自注意力标志
                    disable_temporal_crossattention=disable_temporal_crossattention,  # 禁用时间交叉注意力标志
                )
                for _ in range(self.depth)  # 创建指定深度的块
            ]
        )

        # 确保时间堆栈与变换器块的数量一致
        assert len(self.time_stack) == len(self.transformer_blocks)

        # 设置类属性
        self.use_spatial_context = use_spatial_context
        self.in_channels = in_channels

        # 计算时间嵌入维度
        time_embed_dim = self.in_channels * 4
        # 定义时间位置嵌入序列
        self.time_pos_embed = nn.Sequential(
            linear(self.in_channels, time_embed_dim),  # 从输入通道到时间嵌入维度的线性变换
            nn.SiLU(),  # 应用 SiLU 激活函数
            linear(time_embed_dim, self.in_channels),  # 从时间嵌入维度回到输入通道的线性变换
        )

        # 初始化时间混合器
        self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
        # 设置数据类型
        self.dtype = str_to_dtype[dtype]

    # 前向传播方法,接收输入张量和上下文
    def forward(
        self,
        x: torch.Tensor,  # 输入张量
        context: Optional[torch.Tensor] = None,  # 可选的上下文张量
        time_context: Optional[torch.Tensor] = None,  # 可选的时间上下文张量
        timesteps: Optional[int] = None,  # 可选的时间步数
        image_only_indicator: Optional[torch.Tensor] = None,  # 可选的图像指示器张量
    ) -> torch.Tensor:  # 函数返回一个 torch.Tensor 类型的结果
        _, _, h, w = x.shape  # 解包输入张量 x 的形状,获取高度 h 和宽度 w
        x_in = x  # 保存输入张量 x 的原始值,便于后续操作
        spatial_context = None  # 初始化空间上下文为 None
        if exists(context):  # 如果上下文存在
            spatial_context = context  # 将上下文赋值给空间上下文

        if self.use_spatial_context:  # 如果使用空间上下文
            assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}"  # 检查上下文的维度是否为 3

            time_context = context  # 将上下文赋值给时间上下文
            time_context_first_timestep = time_context[::timesteps]  # 获取时间上下文的第一时间步
            time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w)  # 重复时间上下文以适应高度和宽度的大小
        elif time_context is not None and not self.use_spatial_context:  # 如果时间上下文存在且不使用空间上下文
            time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)  # 重复时间上下文以适应高度和宽度的大小
            if time_context.ndim == 2:  # 如果时间上下文的维度为 2
                time_context = rearrange(time_context, "b c -> b 1 c")  # 将时间上下文的维度重新排列

        x = self.norm(x)  # 对输入张量 x 进行归一化处理
        if not self.use_linear:  # 如果不使用线性变换
            x = self.proj_in(x)  # 对 x 进行输入投影
        x = rearrange(x, "b c h w -> b (h w) c")  # 将 x 的维度重新排列为 (batch_size, height*width, channels)
        if self.use_linear:  # 如果使用线性变换
            x = self.proj_in(x)  # 对 x 进行输入投影

        num_frames = torch.arange(timesteps, device=x.device)  # 生成一个包含时间步数的张量 num_frames
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)  # 重复 num_frames 以适应 batch 大小
        num_frames = rearrange(num_frames, "b t -> (b t)")  # 将 num_frames 的维度重新排列为一维

        t_emb = timestep_embedding(  # 计算时间步嵌入
            num_frames,  # 输入时间步
            self.in_channels,  # 输入通道数
            repeat_only=False,  # 不仅仅重复
            max_period=self.max_time_embed_period,  # 最大周期
            dtype=self.dtype,  # 数据类型
        )
        emb = self.time_pos_embed(t_emb)  # 通过时间步嵌入计算位置嵌入
        emb = emb[:, None, :]  # 在位置维度上增加一个维度

        for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)):  # 遍历 transformer 块和时间堆栈
            x = block(  # 通过 transformer 块处理 x
                x,  # 输入张量
                context=spatial_context,  # 传入空间上下文
            )

            x_mix = x  # 复制处理后的 x 为 x_mix
            x_mix = x_mix + emb  # 将位置嵌入加到 x_mix 上

            x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)  # 通过混合块处理 x_mix
            x = self.time_mixer(  # 通过时间混合器整合空间和时间特征
                x_spatial=x,  # 输入空间特征
                x_temporal=x_mix,  # 输入时间特征
                image_only_indicator=image_only_indicator,  # 图像标识符
            )
        if self.use_linear:  # 如果使用线性变换
            x = self.proj_out(x)  # 对 x 进行输出投影
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)  # 将 x 的维度重新排列为 (batch_size, channels, height, width)
        if not self.use_linear:  # 如果不使用线性变换
            x = self.proj_out(x)  # 对 x 进行输出投影
        out = x + x_in  # 将输出张量与原始输入 x_in 相加
        return out  # 返回最终结果

.\cogvideo-finetune\sat\sgm\modules\__init__.py

# 从相对路径导入 GeneralConditioner 模块
from .encoders.modules import GeneralConditioner

# 定义一个无条件配置的字典
UNCONDITIONAL_CONFIG = {
    # 设置目标为 GeneralConditioner 模块
    "target": "sgm.modules.GeneralConditioner",
    # 定义参数,初始化 emb_models 为一个空列表
    "params": {"emb_models": []},
}

.\cogvideo-finetune\sat\sgm\util.py

# 导入 functools 模块以支持高阶函数功能
import functools
# 导入 importlib 模块以动态导入模块
import importlib
# 导入 os 模块以支持操作系统相关功能
import os
# 从 functools 导入 partial 函数以便于函数部分应用
from functools import partial
# 从 inspect 导入 isfunction 函数用于检查对象是否为函数
from inspect import isfunction

# 导入 fsspec 库以支持文件系统规范
import fsspec
# 导入 numpy 库用于科学计算
import numpy as np
# 导入 torch 库以支持深度学习功能
import torch
# 从 PIL 导入 Image、ImageDraw 和 ImageFont 以支持图像处理
from PIL import Image, ImageDraw, ImageFont
# 从 safetensors.torch 导入 load_file 函数以加载安全张量
from safetensors.torch import load_file as load_safetensors
# 导入 torch.distributed 模块以支持分布式训练
import torch.distributed

# 定义全局变量以存储并行组的上下文
_CONTEXT_PARALLEL_GROUP = None
# 定义全局变量以存储并行组的大小
_CONTEXT_PARALLEL_SIZE = None


# 定义函数以检查上下文并行是否已初始化
def is_context_parallel_initialized():
    # 检查上下文并行组是否为 None
    if _CONTEXT_PARALLEL_GROUP is None:
        return False
    else:
        return True


# 定义函数以设置上下文并行组和大小
def set_context_parallel_group(size, group):
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE
    # 设置上下文并行组
    _CONTEXT_PARALLEL_GROUP = group
    # 设置上下文并行大小
    _CONTEXT_PARALLEL_SIZE = size


# 定义函数以初始化上下文并行
def initialize_context_parallel(context_parallel_size):
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE

    # 断言上下文并行组未被初始化
    assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
    # 设置上下文并行大小
    _CONTEXT_PARALLEL_SIZE = context_parallel_size

    # 获取当前进程的 rank
    rank = torch.distributed.get_rank()
    # 获取全局进程的数量
    world_size = torch.distributed.get_world_size()

    # 按上下文并行大小遍历所有进程以创建新的并行组
    for i in range(0, world_size, context_parallel_size):
        ranks = range(i, i + context_parallel_size)
        # 创建新的分组
        group = torch.distributed.new_group(ranks)
        # 如果当前 rank 在创建的 ranks 中,则设置上下文并行组
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
            break


# 定义函数以获取当前上下文并行组
def get_context_parallel_group():
    # 断言上下文并行组已初始化
    assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"

    return _CONTEXT_PARALLEL_GROUP


# 定义函数以获取当前上下文并行的世界大小
def get_context_parallel_world_size():
    # 断言上下文并行大小已初始化
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

    return _CONTEXT_PARALLEL_SIZE


# 定义函数以获取当前上下文并行的 rank
def get_context_parallel_rank():
    # 断言上下文并行大小已初始化
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

    # 获取当前进程的 rank
    rank = torch.distributed.get_rank()
    # 计算当前上下文并行的 rank
    cp_rank = rank % _CONTEXT_PARALLEL_SIZE
    return cp_rank


# 定义函数以获取当前上下文并行组的 rank
def get_context_parallel_group_rank():
    # 断言上下文并行大小已初始化
    assert _CONTEXT_PARALLEL_SIZE is not None, "context parallel size is not initialized"

    # 获取当前进程的 rank
    rank = torch.distributed.get_rank()
    # 计算当前上下文并行组的 rank
    cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE

    return cp_group_rank


# 定义 SafeConv3d 类,继承自 torch.nn.Conv3d
class SafeConv3d(torch.nn.Conv3d):
    # 定义前向传播函数,接收输入数据
    def forward(self, input):
        # 计算输入数据的内存占用(以 GB 为单位),乘以 2 是因为需要考虑反向传播的内存
        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
        # 如果内存占用超过 2GB,则进行内存优化处理
        if memory_count > 2:
            # kernel_size 取自实例属性,表示卷积核的大小
            kernel_size = self.kernel_size[0]
            # 计算需要将输入分成的部分数量,以控制内存占用
            part_num = int(memory_count / 2) + 1
            # 将输入数据按时间维度进行分块处理
            input_chunks = torch.chunk(input, part_num, dim=2)  # NCTHW
            # 如果卷积核大小大于 1,则需要处理相邻块的拼接
            if kernel_size > 1:
                # 将第一个块保留,后续块与前一个块拼接
                input_chunks = [input_chunks[0]] + [
                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
                    for i in range(1, len(input_chunks))
                ]

            # 初始化输出块列表
            output_chunks = []
            # 对每个输入块进行前向传播,并将结果存储到输出块列表中
            for input_chunk in input_chunks:
                output_chunks.append(super(SafeConv3d, self).forward(input_chunk))
            # 将所有输出块在时间维度上拼接成最终输出
            output = torch.cat(output_chunks, dim=2)
            # 返回拼接后的输出
            return output
        else:
            # 如果内存占用不超过 2GB,直接调用父类的前向传播
            return super(SafeConv3d, self).forward(input)
# 禁用训练模式,确保模型的训练/评估模式不会再更改
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self  # 返回当前对象,保持状态不变


# 从元组字符串中提取字符串
def get_string_from_tuple(s):
    try:
        # 检查字符串是否以括号开始和结束
        if s[0] == "(" and s[-1] == ")":
            # 将字符串转换为元组
            t = eval(s)
            # 检查 t 的类型是否为元组
            if type(t) == tuple:
                return t[0]  # 返回元组的第一个元素
            else:
                pass  # 如果不是元组则不做处理
    except:
        pass  # 捕获异常,防止程序崩溃
    return s  # 返回原始字符串


# 检查一个数是否是 2 的幂
def is_power_of_two(n):
    """
    chat.openai.com/chat
    Return True if n is a power of 2, otherwise return False.

    The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
    The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
    If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
    Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.

    """
    if n <= 0:
        return False  # 如果 n 小于或等于 0,返回 False
    return (n & (n - 1)) == 0  # 返回 n 是否是 2 的幂


# 自动混合精度处理函数
def autocast(f, enabled=True):
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,  # 启用或禁用自动混合精度
            dtype=torch.get_autocast_gpu_dtype(),  # 获取 GPU 的自动混合精度数据类型
            cache_enabled=torch.is_autocast_cache_enabled(),  # 检查缓存是否启用
        ):
            return f(*args, **kwargs)  # 调用原始函数并返回结果

    return do_autocast  # 返回自动混合精度处理的封装函数


# 从配置加载部分对象
def load_partial_from_config(config):
    return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))  # 返回部分应用的对象


# 将文本作为图像进行记录
def log_txt_as_img(wh, xc, size=10):
    # wh 是一个包含 (宽度, 高度) 的元组
    # xc 是要绘制的标题列表
    b = len(xc)  # 获取标题列表的长度
    txts = list()  # 初始化文本图像列表
    for bi in range(b):  # 遍历每个标题
        txt = Image.new("RGB", wh, color="white")  # 创建一个白色背景的图像
        draw = ImageDraw.Draw(txt)  # 创建可绘制对象
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)  # 加载指定字体
        nc = int(40 * (wh[0] / 256))  # 计算每行最多字符数
        if isinstance(xc[bi], list):  # 如果标题是列表
            text_seq = xc[bi][0]  # 获取第一个元素
        else:
            text_seq = xc[bi]  # 否则直接使用标题
        lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc))  # 将文本分行

        try:
            draw.text((0, 0), lines, fill="black", font=font)  # 绘制文本
        except UnicodeEncodeError:  # 如果编码错误
            print("Cant encode string for logging. Skipping.")  # 打印错误信息并跳过

        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0  # 转换图像为数组并归一化
        txts.append(txt)  # 添加到列表中
    txts = np.stack(txts)  # 堆叠所有文本图像
    txts = torch.tensor(txts)  # 转换为 PyTorch 张量
    # 返回包含文本数据的列表或字符串
        return txts
# 定义一个部分类,用于传递参数到初始化方法
def partialclass(cls, *args, **kwargs):
    # 创建一个新类,继承自原始类
    class NewCls(cls):
        # 使用 functools.partialmethod 将原始类的初始化方法与参数绑定
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    # 返回新创建的类
    return NewCls


# 将给定路径转换为绝对路径
def make_path_absolute(path):
    # 解析路径并获取文件系统和路径
    fs, p = fsspec.core.url_to_fs(path)
    # 如果文件系统协议为文件,则返回绝对路径
    if fs.protocol == "file":
        return os.path.abspath(p)
    # 否则返回原始路径
    return path


# 判断输入是否为地图类型的张量
def ismap(x):
    # 如果输入不是张量,则返回 False
    if not isinstance(x, torch.Tensor):
        return False
    # 检查张量是否为四维且第二维大于3
    return (len(x.shape) == 4) and (x.shape[1] > 3)


# 判断输入是否为图像类型的张量
def isimage(x):
    # 如果输入不是张量,则返回 False
    if not isinstance(x, torch.Tensor):
        return False
    # 检查张量是否为四维且第二维为3或1
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


# 判断输入是否为热图类型的张量
def isheatmap(x):
    # 如果输入不是张量,则返回 False
    if not isinstance(x, torch.Tensor):
        return False
    # 检查张量是否为二维
    return x.ndim == 2


# 判断输入是否为邻接类型的张量
def isneighbors(x):
    # 如果输入不是张量,则返回 False
    if not isinstance(x, torch.Tensor):
        return False
    # 检查张量是否为五维且第三维为3或1
    return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)


# 检查输入是否存在
def exists(x):
    # 判断输入是否不为 None
    return x is not None


# 扩展张量的维度,使其与目标张量的维度相同
def expand_dims_like(x, y):
    # 当 x 的维度不等于 y 的维度时,持续扩展 x 的维度
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    # 返回扩展后的 x
    return x


# 返回存在的值或默认值
def default(val, d):
    # 如果 val 存在,则返回 val
    if exists(val):
        return val
    # 返回默认值,如果 d 是函数则调用它
    return d() if isfunction(d) else d


# 计算张量的平坦均值
def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    对所有非批次维度进行均值计算。
    """
    # 计算张量在所有非批次维度上的均值
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


# 计算模型的参数总数
def count_params(model, verbose=False):
    # 计算模型所有参数的总数量
    total_params = sum(p.numel() for p in model.parameters())
    # 如果 verbose 为真,则打印参数数量
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    # 返回参数总数
    return total_params


# 根据配置实例化对象
def instantiate_from_config(config, **extra_kwargs):
    # 检查配置中是否包含目标键
    if not "target" in config:
        # 如果配置为 "__is_first_stage__",返回 None
        if config == "__is_first_stage__":
            return None
        # 如果配置为 "__is_unconditional__",返回 None
        elif config == "__is_unconditional__":
            return None
        # 抛出缺少目标键的异常
        raise KeyError("Expected key `target` to instantiate.")
    # 根据目标字符串实例化对象,并传递参数
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **extra_kwargs)


# 从字符串获取对象
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    # 分割模块和类名
    module, cls = string.rsplit(".", 1)
    # 如果需要,失效缓存
    if invalidate_cache:
        importlib.invalidate_caches()
    # 如果需要重载模块,导入并重载
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    # 返回模块中的类对象
    return getattr(importlib.import_module(module, package=None), cls)


# 在张量末尾添加零
def append_zero(x):
    # 将零张量与输入张量拼接
    return torch.cat([x, x.new_zeros([1])])


# 添加维度到张量以达到目标维度
def append_dims(x, target_dims):
    """将维度添加到张量末尾,直到其具有目标维度。"""
    # 计算需要添加的维度数量
    dims_to_append = target_dims - x.ndim
    # 如果输入维度大于目标维度,抛出异常
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    # 返回添加了维度的张量
    return x[(...,) + (None,) * dims_to_append]


# 从配置加载模型
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    # 打印加载模型的检查点路径
    print(f"Loading model from {ckpt}")
    # 检查检查点文件名是否以 "ckpt" 结尾
    if ckpt.endswith("ckpt"):
        # 从检查点文件加载状态字典到 CPU
        pl_sd = torch.load(ckpt, map_location="cpu")
        # 如果状态字典中包含 "global_step" 键
        if "global_step" in pl_sd:
            # 打印全局步数
            print(f"Global Step: {pl_sd['global_step']}")
        # 提取状态字典
        sd = pl_sd["state_dict"]
    # 检查检查点文件名是否以 "safetensors" 结尾
    elif ckpt.endswith("safetensors"):
        # 从 safetensors 文件加载状态字典
        sd = load_safetensors(ckpt)
    # 如果文件名不符合上述格式,则抛出未实现的错误
    else:
        raise NotImplementedError

    # 根据配置实例化模型
    model = instantiate_from_config(config.model)

    # 加载模型状态字典,允许非严格匹配
    m, u = model.load_state_dict(sd, strict=False)

    # 如果有缺失的键且详细模式开启
    if len(m) > 0 and verbose:
        # 打印缺失的键
        print("missing keys:")
        print(m)
    # 如果有意外的键且详细模式开启
    if len(u) > 0 and verbose:
        # 打印意外的键
        print("unexpected keys:")
        print(u)

    # 如果冻结标志为真
    if freeze:
        # 遍历模型参数
        for param in model.parameters():
            # 设置参数为不需要梯度
            param.requires_grad = False

    # 将模型设置为评估模式
    model.eval()
    # 返回已加载的模型
    return model
# 定义一个获取配置路径的函数,返回字符串类型
def get_configs_path() -> str:
    # 函数文档说明:获取 `configs` 目录的位置
    this_dir = os.path.dirname(__file__)  # 获取当前文件所在目录的路径
    # 创建候选路径,包含当前目录下的 configs 目录和上级目录下的 configs 目录
    candidates = (
        os.path.join(this_dir, "configs"),
        os.path.join(this_dir, "..", "configs"),
    )
    # 遍历每个候选路径
    for candidate in candidates:
        candidate = os.path.abspath(candidate)  # 将候选路径转换为绝对路径
        if os.path.isdir(candidate):  # 检查该路径是否为一个目录
            return candidate  # 如果是,返回该路径
    # 如果没有找到有效的 configs 目录,抛出文件未找到错误
    raise FileNotFoundError(f"Could not find SGM configs in {candidates}")


# 定义一个获取嵌套属性的函数
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
    # 函数文档说明:递归获取对象的嵌套属性
    attributes = attribute_path.split(".")  # 根据 '.' 分割属性路径
    if depth is not None and depth > 0:  # 如果指定深度且大于零
        attributes = attributes[:depth]  # 限制属性列表到指定深度
    assert len(attributes) > 0, "At least one attribute should be selected"  # 确保至少有一个属性
    current_attribute = obj  # 初始化当前属性为对象
    current_key = None  # 初始化当前键
    # 遍历每个属性
    for level, attribute in enumerate(attributes):
        current_key = ".".join(attributes[: level + 1])  # 生成当前键的字符串
        try:
            id_ = int(attribute)  # 尝试将属性转换为整数
            current_attribute = current_attribute[id_]  # 使用索引访问属性
        except ValueError:  # 如果转换失败
            current_attribute = getattr(current_attribute, attribute)  # 使用 getattr 获取属性

    # 返回当前属性和当前键的元组,或者只返回当前属性
    return (current_attribute, current_key) if return_key else current_attribute


# 从 math 模块导入平方根函数
from math import sqrt


# 定义一个 SeededNoise 类
class SeededNoise:
    # 初始化方法,接受种子和权重
    def __init__(self, seeds, weights):
        self.seeds = seeds  # 存储种子
        self.weights = weights  # 存储权重
        weight_square_sum = 0  # 初始化权重平方和
        # 遍历每个权重
        for weight in weights:
            weight_square_sum += weight**2  # 计算权重的平方和
        self.weight_square_sum_sqrt = sqrt(weight_square_sum)  # 计算权重平方和的平方根
        self.cnt = 0  # 初始化计数器

    # 定义可调用方法
    def __call__(self, x):
        self.cnt += 1  # 计数器加一
        randn_combined = torch.zeros_like(x)  # 创建与 x 同形状的零张量
        # 遍历种子和权重
        for seed, weight in zip(self.seeds, self.weights):
            randn = np.random.RandomState(seed + self.cnt).randn(*x.shape)  # 生成正态分布随机数
            randn = torch.from_numpy(randn, dtype=x.dtype, device=x.device)  # 将随机数转换为张量
            randn_combined += randn * weight  # 将加权随机数累加到组合中
        randn_combined /= self.weight_square_sum_sqrt  # 将组合随机数归一化
        return randn_combined  # 返回最终的组合随机数

.\cogvideo-finetune\sat\sgm\webds.py

# 导入所需的标准库
import sys  # 系统相关的功能
import io  # 输入输出操作
import os  # 操作系统功能
import re  # 正则表达式操作
import json  # JSON 数据处理
import tarfile  # TAR 文件处理
from functools import partial  # 偏函数应用

# 导入 webdataset 库的相关模块
import webdataset as wds  # webdataset 的主模块
from webdataset import ResampledShards, DataPipeline, tarfile_to_samples  # 导入特定功能
from webdataset.filters import pipelinefilter  # 导入过滤功能
from webdataset.tariterators import url_opener, group_by_keys  # 导入 TAR 迭代器相关功能
from webdataset.handlers import reraise_exception  # 导入异常处理功能
from webdataset.gopen import gopen_schemes, gopen  # 导入打开函数和方案

def pytorch_worker_info(group=None):  # sourcery skip: use-contextlib-suppress
    """返回 PyTorch 和一些分布式环境的节点和工作者信息。"""
    rank = 0  # 初始化节点秩
    world_size = 1  # 初始化世界大小
    worker = 0  # 初始化工作者 ID
    num_workers = 1  # 初始化工作者数量
    try:
        import torch.distributed  # 导入分布式 PyTorch 模块

        # 检查分布式模块是否可用和已初始化
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = group or torch.distributed.group.WORLD  # 设置组为 WORLD
            rank = torch.distributed.get_rank(group=group)  # 获取节点秩
            world_size = torch.distributed.get_world_size(group=group)  # 获取世界大小
    except ModuleNotFoundError:
        pass  # 如果未找到模块,则跳过
    try:
        import torch.utils.data  # 导入数据工具模块

        worker_info = torch.utils.data.get_worker_info()  # 获取工作者信息
        if worker_info is not None:  # 如果工作者信息存在
            worker = worker_info.id  # 获取工作者 ID
            num_workers = worker_info.num_workers  # 获取工作者总数
    except ModuleNotFoundError:
        pass  # 如果未找到模块,则跳过

    return rank, world_size, worker, num_workers  # 返回节点信息

def pytorch_worker_seed(group=None):
    """为每个工作者和节点计算唯一且确定性的随机种子。"""
    rank, world_size, worker, num_workers = pytorch_worker_info(group=group)  # 获取工作者信息
    return rank * 1000 + worker  # 计算并返回随机种子

def worker_seed_sat(group=None, seed=0):
    return pytorch_worker_seed(group=group) + seed * 23  # 计算工作者的随机种子并增加偏移

class ConfiguredResampledShards(ResampledShards):
    def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
        from sat.helpers import print_rank0  # 导入打印功能

        try:
            from megatron.core.parallel_state import get_data_parallel_group  # 尝试导入 Megatron 数据并行组

            group = get_data_parallel_group()  # 获取数据并行组
            print_rank0("Using megatron data parallel group.")  # 打印使用的组信息
        except:
            from sat.mpu import get_data_parallel_group  # 导入备用的数据并行组

            try:
                group = get_data_parallel_group()  # 获取数据并行组
                print_rank0("Using sat data parallel group.")  # 打印使用的组信息
            except AssertionError:
                group = None  # 如果没有指定组,则设置为 None
                print_rank0("No data parallel group is specified!")  # 打印警告信息
        worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)  # 创建偏函数
        super().__init__(urls, nshards, worker_seed_sat_this, deterministic)  # 调用父类构造函数

class SimpleDistributedWebDataset(DataPipeline):  # 定义简单的分布式 Web 数据集类
    # 初始化方法,接收路径、处理函数、种子以及可选的洗牌缓冲区大小
    def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # 如果将 shuffle_buffer 设置为 1,则禁用洗牌,模型并行将会有所不同
        try:
            # 从 sat.mpu 模块导入获取模型并行世界大小的函数
            from sat.mpu import get_model_parallel_world_size
    
            # 检查模型并行世界大小,如果大于 1,则将洗牌缓冲区设置为 1
            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1
        except Exception:
            # 捕获异常并忽略
            pass
        # 调用父类构造函数,初始化配置的重采样分片及相关参数
        super().__init__(
            ConfiguredResampledShards(path, seed),  # 使用指定路径和种子初始化重采样分片,推荐使用多个分片以避免不均匀
            tarfile_to_samples(),  # 将 tar 文件转换为样本
            wds.shuffle(shuffle_buffer),  # 使用洗牌函数,并传入洗牌缓冲区大小
            process_fn,  # 传入处理函数
        )
# 定义一个迭代器函数,用于遍历 tar 文件,生成文件名和内容的对
def tar_file_iterator_with_meta(
    # 输入的字节流对象,适用于 tarfile
    fileobj, 
    # 元数据文件中不同项的键
    meta_names, 
    # 用于跳过某些键的正则表达式(默认值为 r"__[^/]*__($|/)")
    skip_meta=r"__[^/]*__($|/)", 
    # 文件后缀名(可选)
    suffix=None, 
    # 异常处理的处理程序(默认为 reraise_exception)
    handler=reraise_exception, 
    # 元数据流(可选)
    meta_stream=None
):
    """遍历 tar 文件,返回给定 tar 流的文件名和内容对。

    :param fileobj: 适用于 tarfile 的字节流
    :param meta_names: 元数据文件中不同项的键
    :param skip_meta: 完全跳过的键的正则表达式(默认值 = r"__[^/]*__($|/)")
    """
    # 打开 tar 文件流,以读取模式
    stream = tarfile.open(fileobj=fileobj, mode="r|*")
    # 从文件对象中提取数据目录和文件名
    data_dir, filename = fileobj.name.rsplit("/", 1)
    # 初始化元数据字典,用于存储元数据
    meta_data = {}  # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}

    # 如果没有提供元数据流
    if meta_stream is None:
        # 生成元数据文件名,使用文件名的前缀加上 ".meta.jsonl"
        meta_file_name = filename.split(".")[0] + ".meta.jsonl"
        # 构建元数据文件的完整路径
        meta_path = os.path.join(data_dir, meta_file_name)
        # 如果元数据文件存在,则打开文件
        if os.path.exists(meta_path):
            meta_stream = open(meta_path, "r")
    else:
        # 如果提供了元数据流,则使用其名称
        meta_file_name = meta_stream.name

    # 如果元数据流存在
    if meta_stream is not None:
        # 遍历元数据流的每一行,记录行号和内容
        for lineno, line in enumerate(meta_stream):
            meta_list = []
            try:
                # 尝试将行内容解析为 JSON 对象
                meta_list.append(json.loads(line))
            except Exception as exn:
                # 导入帮助函数以打印错误
                from sat.helpers import print_rank0

                # 打印解析 JSONL 时的错误信息
                print_rank0(f"Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}", level="DEBUG")
                # 继续下一行
                continue
            # 遍历解析出的每个元数据项
            for item in meta_list:
                # 如果元数据项的键不存在于元数据字典中,则初始化它
                if not item["key"] in meta_data:
                    meta_data[item["key"]] = {}
                # 遍历所有指定的元数据名称
                for meta_name in meta_names:
                    # 如果元数据项中包含该元数据名称,则将其值存储在元数据字典中
                    if meta_name in item:
                        meta_data[item["key"]][meta_name] = item[meta_name]
        # 关闭元数据流
        meta_stream.close()
    # 尝试处理流中的每个项目
        try:
            # 遍历流中的每个 tarinfo 对象
            for tarinfo in stream:
                # 获取当前项目的文件名
                fname = tarinfo.name
                try:
                    # 如果不是常规文件,则跳过
                    if not tarinfo.isreg():
                        continue
                    # 如果文件名为空,则跳过
                    if fname is None:
                        continue
                    # 跳过以双下划线开头和结尾的元数据文件
                    if "/" not in fname and fname.startswith("__") and fname.endswith("__"):
                        # 目前跳过元数据
                        continue
                    # 如果指定了 skip_meta,且文件名匹配,则跳过
                    if skip_meta is not None and re.match(skip_meta, fname):
                        continue
                    # 如果文件是 txt 类型且有后缀,则读取内容并附加后缀
                    if fname.endswith(".txt") and suffix is not None:
                        data = (stream.extractfile(tarinfo).read().decode() + suffix).encode()
                    else:
                        # 否则仅读取文件内容
                        data = stream.extractfile(tarinfo).read()
                    # 创建包含文件名和数据的字典
                    result = dict(fname=fname, data=data)
                    # 生成结果字典
                    yield result
    
                    # 如果文件名以 .id 结尾
                    if fname.endswith(".id"):
                        # 获取文件 ID,去掉扩展名
                        fid = fname.split(".")[0]
                        # 检查文件 ID 是否包含特定字符串并处理
                        if "-$#%@&" in fid:
                            sfid = fid.split("-$#%@&")[0]
                        else:
                            sfid = fid
                        # 从元数据中获取相关数据
                        meta_data_fid = meta_data.get(sfid, {})
                        # 遍历元数据名称
                        for meta_name in meta_names:
                            # 构建元数据文件名
                            meta_fname = fid + "." + meta_name
                            # 获取元数据内容
                            meta = meta_data_fid.get(meta_name, None)
                            # 生成包含元数据的字典
                            yield dict(fname=meta_fname, data=meta)
                    # 清空流的成员列表
                    stream.members = []
                except Exception as exn:
                    # 如果异常有参数,则附加当前文件对象信息
                    if hasattr(exn, "args") and len(exn.args) > 0:
                        exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                    # 处理异常,如果处理成功则继续
                    if handler(exn):
                        continue
                    else:
                        # 否则跳出循环
                        break
        except Exception as exn:
            # 打印外层异常信息
            print(exn)
        # 删除流对象以释放资源
        del stream
# 扩展一个打开的 tar 文件流,并返回包含文件内容的迭代器
def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
    """Expand a stream of open tar files into a stream of tar file contents.

    This returns an iterator over (filename, file_contents).
    """
    # 遍历输入数据中的每个源
    for source in data:
        # 从源字典中获取 URL
        url = source["url"]
        try:
            # 确保源是字典类型
            assert isinstance(source, dict)
            # 确保源字典中包含 "stream" 键
            assert "stream" in source
            # 遍历 tar 文件内容生成器
            for sample in tar_file_iterator_with_meta(source["stream"], meta_names, meta_stream=source["meta_stream"]):
                # 确保样本是字典并包含 "data" 和 "fname"
                assert isinstance(sample, dict) and "data" in sample and "fname" in sample
                # 将 URL 添加到样本字典中
                sample["__url__"] = url
                # 生成样本
                yield sample
        except Exception as exn:
            # 追加流和 URL 到异常参数中
            exn.args = exn.args + (source.get("stream"), source.get("url"))
            # 如果处理异常的函数返回真,继续循环
            if handler(exn):
                continue
            else:
                # 否则,退出循环
                break


# 打开 URL 并返回 URL 和流的配对迭代器
def url_opener(
    data,
    handler,
    **kw,
):
    """Open URLs and yield a stream of url+stream pairs.

    Args:
        data: iterator over dict(url=...)
        handler: exception handler.
        kw: keyword arguments for gopen.gopen.

    Yields:
        a stream of url+stream pairs.
    """
    # 遍历输入数据中的每个样本
    for sample in data:
        # 确保样本是字典类型
        assert isinstance(sample, dict), sample
        # 确保样本字典中包含 "url" 键
        assert "url" in sample
        # 从样本中获取 URL
        url = sample["url"]
        try:
            # 打开 URL 并获取流
            stream = gopen(url, **kw)
            # 检查流是否有 meta_stream 属性
            if hasattr(stream, "meta_stream"):
                # 获取 meta_stream,并删除该属性
                meta_stream = stream.meta_stream
                del stream.meta_stream
            else:
                # 如果没有,则设为 None
                meta_stream = None
            # 更新样本字典,包含流和 meta_stream
            sample.update(stream=stream, meta_stream=meta_stream)
            # 生成样本
            yield sample
        except Exception as exn:
            # 追加 URL 到异常参数中
            exn.args = exn.args + (url,)
            # 如果处理异常的函数返回真,继续循环
            if handler(exn):
                continue
            else:
                # 否则,退出循环
                break


# 使用元数据扩展 tar 文件样本
def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
    # 使用 URL 打开器获取流
    streams = url_opener(src, handler=handler)
    # 扩展 tar 文件流并获取文件样本
    files = tar_file_expander_with_meta(streams, meta_names, handler)
    # 按键对样本进行分组
    samples = group_by_keys(files, handler=handler)
    # 返回样本
    return samples


# 定义带有元信息文件的分布式 Web 数据集类
class MetaDistributedWebDataset(DataPipeline):
    """WebDataset with meta information files
    Extra Format:
        in webdataset (tar), for each sample there is a '.id';
        for each tar file, there is a '.meta.jsonl' file with the same name;
        The '.meta.jsonl' file contains lines of json objects, each with a 'key' field to match '.id'.
    """

    # 初始化方法,设置数据集参数
    def __init__(
        self, path, process_fn, seed, *, meta_names=[], nshards=sys.maxsize, shuffle_buffer=1000, include_dirs=None
    ):
        # 设置环境变量,控制是否显示种子(注释掉)
        # os.environ['WDS_SHOW_SEED'] = '1'
        # 导入 PyTorch 库
        import torch

        # 检查当前进程是否为主进程
        if torch.distributed.get_rank() == 0:
            # 如果包含的目录不为 None
            if include_dirs is not None:  # /webdatasets/A,/webdatasets/C
                # 初始化其他路径列表
                other_paths = []
                # 将包含的目录字符串按逗号分割
                include_dirs = include_dirs.split(",")
                # 遍历每个包含的目录
                for include_dir in include_dirs:
                    # 如果目录名中包含通配符 "*"
                    if "*" in include_dir:
                        # 分割目录名和数量
                        include_dir, n = include_dir.split("*")
                        n = int(n)  # 转换数量为整数
                    else:
                        n = 1  # 默认数量为 1
                    # 遍历当前目录及其子目录
                    for cur_dir, dirs, files in os.walk(include_dir):
                        # 遍历所有文件
                        for f in files:
                            # 检查文件是否以 "tar" 结尾且文件大小大于 0
                            if f.endswith("tar") and os.path.getsize(os.path.join(cur_dir, f)) > 0:
                                # 将符合条件的文件路径添加到其他路径列表中
                                # other_paths.append(os.path.join(cur_dir,f))
                                other_paths.extend([os.path.join(cur_dir, f)] * n)  # 根据数量扩展列表
                # print(f'Adding dataset paths {",".join(other_paths)}')
                # 从 braceexpand 库导入
                from braceexpand import braceexpand

                # 如果路径字符串不为空
                if len(path) > 0:  # not ""
                    # 扩展路径并与其他路径合并
                    path = list(braceexpand(path)) + other_paths
                else:
                    # 如果路径为空,仅使用其他路径
                    path = other_paths
            # 将路径包装成列表
            path = [path]
        else:
            # 如果不是主进程,将路径设置为 None
            path = [
                None,
            ]
        # 广播路径列表到所有进程
        torch.distributed.broadcast_object_list(path, src=0)
        # 选择第一个路径
        path = path[0]

        # 生成带元数据的 tar 文件样本处理函数
        tarfile_samples = partial(tarfile_samples_with_meta, meta_names=meta_names)
        # 对 tar 文件样本进行管道过滤
        tarfile_to_samples = pipelinefilter(tarfile_samples)

        # 如果模型并行,设置打乱缓冲区大小为 1 以禁用打乱
        try:
            # 从 sat.mpu 模块导入获取模型并行世界大小的函数
            from sat.mpu import get_model_parallel_world_size

            # 检查模型并行世界大小是否大于 1
            if get_model_parallel_world_size() > 1:
                shuffle_buffer = 1  # 设置打乱缓冲区大小
        except Exception:
            pass  # 忽略导入错误

        # 调用父类初始化方法,传入配置好的重采样分片、样本处理管道及其他参数
        super().__init__(
            ConfiguredResampledShards(path, seed, nshards=nshards),
            tarfile_to_samples(),
            wds.shuffle(shuffle_buffer),
            process_fn,
        )
# rclone 支持
from webdataset.gopen import Pipe  # 从 webdataset.gopen 导入 Pipe 类


def gopen_rclone(url, mode="rb", bufsize=1024 * 1024 * 32):
    """使用 `curl` 打开一个 URL。

    :param url: rclone URL,例如 data:bucket1/foo.tar,数据需要被配置。
    :param mode: 文件模式
    :param bufsize: 缓冲区大小
    """
    # 去掉 URL 前缀 "rclone://"
    url = url.replace("rclone://", "")
    # 如果模式以 "r" 开头,准备读取命令
    if mode[0] == "r":
        cmd = f"rclone cat '{url}'"  # 生成 rclone 读取命令
        return Pipe(  # 返回 Pipe 对象用于读取
            cmd,
            mode=mode,  # 设置文件模式
            shell=True,  # 通过 shell 执行命令
            bufsize=bufsize,  # 设置缓冲区大小
            ignore_status=[141, 23],  # 忽略特定的返回状态
        )  # skipcq: BAN-B604
    # 如果模式以 "w" 开头,准备写入命令
    elif mode[0] == "w":
        cmd = f"rclone cp - '{url}'"  # 生成 rclone 写入命令
        return Pipe(  # 返回 Pipe 对象用于写入
            cmd,
            mode=mode,  # 设置文件模式
            shell=True,  # 通过 shell 执行命令
            bufsize=bufsize,  # 设置缓冲区大小
            ignore_status=[141, 26],  # 忽略特定的返回状态
        )  # skipcq: BAN-B604
    else:
        # 如果模式未知,抛出错误
        raise ValueError(f"{mode}: unknown mode")


def gopen_boto3(url, mode="rb", bufsize=8192 * 2):
    """使用 boto3 API 打开一个 URL。

    :param url: boto3 URL,例如 boto3://bucket1/foo.tar,数据需要被配置。
    :param mode: 文件模式
    :param bufsize: 缓冲区大小
    """
    import boto3  # 导入 boto3 库

    # boto3.set_stream_logger('botocore', level='DEBUG')  # 设置日志记录(已注释)
    # 如果 URL 以 "boto3://" 开头,去掉前缀并标记是否需要元数据
    if url.startswith("boto3://"):
        url = url.replace("boto3://", "")
        need_meta = False
    else:
        url = url.replace("metaboto3://", "")  # 去掉 "metaboto3://" 前缀
        need_meta = True  # 需要元数据

    # 从环境变量获取 S3 配置
    endpoint_url = os.environ.get("S3_ENDPOINT_URL", None)  # S3 端点 URL
    access_key = os.environ.get("S3_ACCESS_KEY_ID", None)  # 访问密钥 ID
    secret_key = os.environ.get("S3_SECRET_ACCESS_KEY", None)  # 秘密访问密钥

    # 如果模式以 "r" 开头,准备读取 S3 对象
    if mode[0] == "r":
        # 创建 S3 客户端
        s3_client = boto3.client(
            "s3", endpoint_url=endpoint_url, aws_access_key_id=access_key, aws_secret_access_key=secret_key
        )
        # 分割桶名和对象键
        bucket, key = url.split("/", 1)

        # 如果需要元数据,下载相应的元数据文件
        if need_meta:
            # 下载一个元数据 JSON 文件
            meta_file_key = key.split(".")[0] + ".meta.jsonl"  # 元数据文件名
            meta_stream = io.BytesIO()  # 创建一个字节流用于存储元数据
            s3_client.download_fileobj(bucket, meta_file_key, meta_stream)  # 从 S3 下载元数据
            meta_stream.seek(0)  # 重置字节流位置
            meta_stream.name = meta_file_key  # 设置字节流名称
        else:
            meta_stream = None  # 不需要元数据时设置为 None

        # 获取数据流对象
        response = s3_client.get_object(Bucket=bucket, Key=key)  # 获取 S3 对象
        response["Body"].name = key  # 设置对象的名称(实际未使用)
        response["Body"].meta_stream = meta_stream  # 将元数据流关联到对象
        return response["Body"]  # 返回数据流对象
    else:
        # 如果模式未知,抛出错误
        raise ValueError(f"{mode}: unknown mode")


# 注册 gopen_rclone 和 gopen_boto3 到 gopen_schemes 字典
gopen_schemes["rclone"] = gopen_rclone
gopen_schemes["boto3"] = gopen_boto3
gopen_schemes["metaboto3"] = gopen_boto3

.\cogvideo-finetune\sat\sgm\__init__.py

# 从当前模块导入 AutoencodingEngine 类
from .models import AutoencodingEngine
# 从当前模块导入获取配置路径和根据配置实例化对象的函数
from .util import get_configs_path, instantiate_from_config
# 定义当前模块的版本号
__version__ = "0.1.0"

.\cogvideo-finetune\sat\train_video.py

# 导入操作系统模块
import os
# 导入命令行参数解析模块
import argparse
# 从 functools 模块导入 partial 函数,用于部分应用
from functools import partial
# 导入 NumPy 库
import numpy as np
# 导入 PyTorch 的分布式模块
import torch.distributed
# 导入 OmegaConf,用于配置管理
from omegaconf import OmegaConf
# 导入 imageio,用于读取和写入视频
import imageio

# 导入 PyTorch 库
import torch

# 从 sat 模块导入 mpu
from sat import mpu
# 从 sat.training.deepspeed_training 导入 training_main
from sat.training.deepspeed_training import training_main

# 从 sgm.util 导入工具函数
from sgm.util import get_obj_from_str, isheatmap

# 从 diffusion_video 导入 SATVideoDiffusionEngine 类
from diffusion_video import SATVideoDiffusionEngine
# 从 arguments 导入获取命令行参数的函数
from arguments import get_args

# 从 einops 导入 rearrange 函数,用于重新排列张量
from einops import rearrange

# 尝试导入 wandb 库,用于实验跟踪
try:
    import wandb
# 如果 wandb 未安装,打印警告信息
except ImportError:
    print("warning: wandb not installed")


# 打印调试信息的函数
def print_debug(args, s):
    # 如果启用了调试模式
    if args.debug:
        # 添加当前进程的排名到输出字符串
        s = f"RANK:[{torch.distributed.get_rank()}]:" + s
        # 打印调试信息
        print(s)


# 保存文本列表到指定目录的函数
def save_texts(texts, save_dir, iterations):
    # 构建输出文件的路径
    output_path = os.path.join(save_dir, f"{str(iterations).zfill(8)}")
    # 打开输出文件,设置编码为 UTF-8
    with open(output_path, "w", encoding="utf-8") as f:
        # 遍历文本列表
        for text in texts:
            # 将每个文本写入文件,每个文本后换行
            f.write(text + "\n")


# 将视频批次保存为网格和 MP4 格式的函数
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5, args=None, key=None):
    # 创建保存路径,如果不存在则创建
    os.makedirs(save_path, exist_ok=True)

    # 遍历视频批次
    for i, vid in enumerate(video_batch):
        gif_frames = []  # 用于存储帧的列表
        # 遍历视频中的每一帧
        for frame in vid:
            # 重新排列帧的维度,从 (c, h, w) 转为 (h, w, c)
            frame = rearrange(frame, "c h w -> h w c")
            # 将帧数据缩放到 [0, 255] 并转换为 uint8 类型
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            # 将处理后的帧添加到列表中
            gif_frames.append(frame)
        # 构建当前视频保存的路径
        now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
        # 使用 imageio 创建视频写入器
        with imageio.get_writer(now_save_path, fps=fps) as writer:
            # 将每一帧写入视频文件
            for frame in gif_frames:
                writer.append_data(frame)
        # 如果 args 存在且 wandb 被启用
        if args is not None and args.wandb:
            # 记录视频到 wandb
            wandb.log(
                {key + f"_video_{i}": wandb.Video(now_save_path, fps=fps, format="mp4")}, step=args.iteration + 1
            )


# 日志视频的函数
def log_video(batch, model, args, only_log_video_latents=False):
    # 获取文本数据
    texts = batch["txt"]
    # 构建保存文本的目录
    text_save_dir = os.path.join(args.save, "video_texts")
    # 创建保存文本目录
    os.makedirs(text_save_dir, exist_ok=True)
    # 保存文本数据到文件
    save_texts(texts, text_save_dir, args.iteration)

    # 准备 GPU 自动混合精度的参数
    gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
    # 在不计算梯度的上下文中,启用自动混合精度
    with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
        # 使用模型记录视频,是否只记录视频的潜在表示
        videos = model.log_video(batch, only_log_video_latents=only_log_video_latents)
    # 检查当前进程是否是主进程(rank 为 0)
    if torch.distributed.get_rank() == 0:
        # 设置视频保存的根目录
        root = os.path.join(args.save, "video")

        # 如果只记录视频潜在变量
        if only_log_video_latents:
            # 创建潜在变量的子目录
            root = os.path.join(root, "latents")
            # 格式化文件名,包含当前迭代次数
            filename = "{}_gs-{:06}".format("latents", args.iteration)
            # 生成完整的文件路径
            path = os.path.join(root, filename)
            # 创建保存路径的目录(如果不存在的话)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            # 创建文件夹(如果不存在的话)
            os.makedirs(path, exist_ok=True)
            # 保存潜在变量数据到指定路径
            torch.save(videos["latents"], os.path.join(path, "latent.pt"))
        else:
            # 遍历所有视频数据
            for k in videos:
                # 获取当前视频的帧数
                N = videos[k].shape[0]
                # 如果不是热图,裁剪视频数据
                if not isheatmap(videos[k]):
                    videos[k] = videos[k][:N]
                # 如果视频数据是张量
                if isinstance(videos[k], torch.Tensor):
                    # 将张量分离、转换为浮点数并移动到 CPU
                    videos[k] = videos[k].detach().float().cpu()
                    # 如果不是热图,限制张量值范围
                    if not isheatmap(videos[k]):
                        videos[k] = torch.clamp(videos[k], -1.0, 1.0)

            # 获取当前批次的视频帧数
            num_frames = batch["num_frames"][0]
            # 获取当前批次的帧率
            fps = batch["fps"][0].cpu().item()
            # 如果只记录视频潜在变量
            if only_log_video_latents:
                # 创建潜在变量的子目录
                root = os.path.join(root, "latents")
                # 格式化文件名,包含当前迭代次数
                filename = "{}_gs-{:06}".format("latents", args.iteration)
                # 生成完整的文件路径
                path = os.path.join(root, filename)
                # 创建保存路径的目录(如果不存在的话)
                os.makedirs(os.path.split(path)[0], exist_ok=True)
                # 创建文件夹(如果不存在的话)
                os.makedirs(path, exist_ok=True)
                # 保存潜在变量数据到指定路径
                torch.save(videos["latents"], os.path.join(path, "latents.pt"))
            else:
                # 遍历所有视频数据
                for k in videos:
                    # 将视频数据标准化到 [0, 1] 范围
                    samples = (videos[k] + 1.0) / 2.0
                    # 格式化文件名,包含当前迭代次数
                    filename = "{}_gs-{:06}".format(k, args.iteration)

                    # 生成完整的文件路径
                    path = os.path.join(root, filename)
                    # 创建保存路径的目录(如果不存在的话)
                    os.makedirs(os.path.split(path)[0], exist_ok=True)
                    # 保存视频为网格和 MP4 格式
                    save_video_as_grid_and_mp4(samples, path, num_frames // fps, fps, args, k)
# 定义广播批处理函数
def broad_cast_batch(batch):
    # 获取模型并行世界的大小
    mp_size = mpu.get_model_parallel_world_size()
    # 获取全局进程的排名
    global_rank = torch.distributed.get_rank() // mp_size
    # 计算源进程的索引
    src = global_rank * mp_size

    # 如果批次中存在 mp4 数据,则获取各数据的形状
    if batch["mp4"] is not None:
        broadcast_shape = [batch["mp4"].shape, batch["fps"].shape, batch["num_frames"].shape]
    else:
        # 否则设置为 None
        broadcast_shape = None

    # 准备要广播的对象列表,包括文本和形状信息
    txt = [batch["txt"], broadcast_shape]
    # 广播对象列表到指定源
    torch.distributed.broadcast_object_list(txt, src=src, group=mpu.get_model_parallel_group())
    # 从广播结果中提取文本数据
    batch["txt"] = txt[0]

    # 获取广播后各数据的形状
    mp4_shape = txt[1][0]
    fps_shape = txt[1][1]
    num_frames_shape = txt[1][2]

    # 如果当前模型并行进程不是 0,则初始化对应的张量
    if mpu.get_model_parallel_rank() != 0:
        batch["mp4"] = torch.zeros(mp4_shape, device="cuda")
        batch["fps"] = torch.zeros(fps_shape, device="cuda", dtype=torch.long)
        batch["num_frames"] = torch.zeros(num_frames_shape, device="cuda", dtype=torch.long)

    # 广播 mp4 数据
    torch.distributed.broadcast(batch["mp4"], src=src, group=mpu.get_model_parallel_group())
    # 广播 fps 数据
    torch.distributed.broadcast(batch["fps"], src=src, group=mpu.get_model_parallel_group())
    # 广播 num_frames 数据
    torch.distributed.broadcast(batch["num_frames"], src=src, group=mpu.get_model_parallel_group())
    # 返回处理后的批次数据
    return batch


# 定义前向评估步骤函数
def forward_step_eval(data_iterator, model, args, timers, only_log_video_latents=False, data_class=None):
    # 如果当前模型并行进程是 0,开始数据加载
    if mpu.get_model_parallel_rank() == 0:
        timers("data loader").start()  # 启动计时器
        batch_video = next(data_iterator)  # 获取下一个批次数据
        timers("data loader").stop()  # 停止计时器

        # 如果 mp4 数据的维度为 6,重新调整其形状
        if len(batch_video["mp4"].shape) == 6:
            b, v = batch_video["mp4"].shape[:2]  # 提取批次和视频维度
            batch_video["mp4"] = batch_video["mp4"].view(-1, *batch_video["mp4"].shape[2:])  # 扁平化 mp4 数据
            txt = []  # 初始化文本列表
            # 遍历批次和视频维度,构建文本列表
            for i in range(b):
                for j in range(v):
                    txt.append(batch_video["txt"][j][i])
            batch_video["txt"] = txt  # 更新批次中的文本数据

        # 将批次中的每个张量转移到 GPU
        for key in batch_video:
            if isinstance(batch_video[key], torch.Tensor):
                batch_video[key] = batch_video[key].cuda()  # 转移张量到 CUDA 设备
    else:
        # 如果当前进程不是 0,初始化空的批次数据
        batch_video = {"mp4": None, "fps": None, "num_frames": None, "txt": None}
    # 调用广播函数以同步批次数据
    broad_cast_batch(batch_video)
    # 如果数据并行进程是 0,记录视频数据
    if mpu.get_data_parallel_rank() == 0:
        log_video(batch_video, model, args, only_log_video_latents=only_log_video_latents)

    # 在批次中添加全局步骤信息
    batch_video["global_step"] = args.iteration
    # 进行共享步骤并计算损失
    loss, loss_dict = model.shared_step(batch_video)
    # 将损失字典中的 bfloat16 类型转换为 float32
    for k in loss_dict:
        if loss_dict[k].dtype == torch.bfloat16:
            loss_dict[k] = loss_dict[k].to(torch.float32)  # 转换数据类型
    return loss, loss_dict  # 返回损失和损失字典


# 定义前向步骤函数
def forward_step(data_iterator, model, args, timers, data_class=None):
    # 检查当前模型并行进程的排名是否为0
    if mpu.get_model_parallel_rank() == 0:
        # 启动计时器以记录数据加载时间
        timers("data loader").start()
        # 从数据迭代器中获取下一个批次的数据
        batch = next(data_iterator)
        # 停止计时器
        timers("data loader").stop()
        # 遍历批次中的每个键
        for key in batch:
            # 检查当前键对应的值是否为 PyTorch 张量
            if isinstance(batch[key], torch.Tensor):
                # 将张量移到 GPU 上
                batch[key] = batch[key].cuda()

        # 检查当前进程的分布式排名是否为0
        if torch.distributed.get_rank() == 0:
            # 检查保存目录下是否存在训练配置文件
            if not os.path.exists(os.path.join(args.save, "training_config.yaml")):
                # 加载基础配置文件,并将其存储在列表中
                configs = [OmegaConf.load(cfg) for cfg in args.base]
                # 合并所有基础配置
                config = OmegaConf.merge(*configs)
                # 创建保存目录(如果不存在)
                os.makedirs(args.save, exist_ok=True)
                # 将合并后的配置保存为 YAML 文件
                OmegaConf.save(config=config, f=os.path.join(args.save, "training_config.yaml"))
    else:
        # 如果当前不是模型并行进程0,则创建一个空的批次字典
        batch = {"mp4": None, "fps": None, "num_frames": None, "txt": None}

    # 在批次字典中添加全局步数
    batch["global_step"] = args.iteration

    # 广播批次数据到所有进程
    broad_cast_batch(batch)

    # 执行模型的共享步骤,计算损失和损失字典
    loss, loss_dict = model.shared_step(batch)

    # 返回计算的损失和损失字典
    return loss, loss_dict
# 如果当前脚本作为主程序运行
if __name__ == "__main__":
    # 检查环境变量中是否存在 OMPI_COMM_WORLD_LOCAL_RANK
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        # 将 OMPI_COMM_WORLD_LOCAL_RANK 的值赋给 LOCAL_RANK 环境变量
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        # 将 OMPI_COMM_WORLD_SIZE 的值赋给 WORLD_SIZE 环境变量
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        # 将 OMPI_COMM_WORLD_RANK 的值赋给 RANK 环境变量
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]

    # 创建一个不带帮助信息的命令行参数解析器
    py_parser = argparse.ArgumentParser(add_help=False)
    # 解析已知和未知的命令行参数
    known, args_list = py_parser.parse_known_args()
    # 根据解析得到的参数列表获取具体参数
    args = get_args(args_list)
    # 将已知参数和具体参数合并到一个命名空间中
    args = argparse.Namespace(**vars(args), **vars(known))

    # 根据目标字符串获取数据类对象
    data_class = get_obj_from_str(args.data_config["target"])
    # 使用部分应用来创建数据集函数
    create_dataset_function = partial(data_class.create_dataset_function, **args.data_config["params"])

    # 导入 YAML 库
    import yaml

    # 初始化配置列表
    configs = []
    # 遍历基础配置文件列表
    for config in args.base:
        # 以只读模式打开基础配置文件
        with open(config, "r") as f:
            # 安全加载 YAML 文件内容为字典
            base_config = yaml.safe_load(f)
        # 将加载的基础配置添加到配置列表中
        configs.append(base_config)
    # 将加载的配置列表赋给 args.log_config
    args.log_config = configs

    # 调用训练主函数,传入相关参数和函数
    training_main(
        args,
        model_cls=SATVideoDiffusionEngine,
        forward_step_function=partial(forward_step, data_class=data_class),
        forward_step_eval=partial(
            forward_step_eval, data_class=data_class, only_log_video_latents=args.only_log_video_latents
        ),
        create_dataset_function=create_dataset_function,
    )

.\cogvideo-finetune\sat\vae_modules\attention.py

# 导入数学库
import math
# 从 inspect 模块导入 isfunction 函数,用于检查对象是否为函数
from inspect import isfunction
# 导入 Any 和 Optional 类型注解
from typing import Any, Optional

# 导入 PyTorch 库
import torch
# 导入 PyTorch 的功能性操作
import torch.nn.functional as F
# 从 einops 库导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 导入版本控制工具
from packaging import version
# 导入 PyTorch 的神经网络模块
from torch import nn

# 检查 PyTorch 版本是否大于或等于 2.0.0
if version.parse(torch.__version__) >= version.parse("2.0.0"):
    # 设置 SDP_IS_AVAILABLE 为 True,表示 SDP 后端可用
    SDP_IS_AVAILABLE = True
    # 从 CUDA 后端导入 SDPBackend 和 sdp_kernel
    from torch.backends.cuda import SDPBackend, sdp_kernel

    # 定义后端配置映射
    BACKEND_MAP = {
        SDPBackend.MATH: {
            "enable_math": True,  # 启用数学模式
            "enable_flash": False,  # 禁用闪电模式
            "enable_mem_efficient": False,  # 禁用内存高效模式
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,  # 禁用数学模式
            "enable_flash": True,  # 启用闪电模式
            "enable_mem_efficient": False,  # 禁用内存高效模式
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,  # 禁用数学模式
            "enable_flash": False,  # 禁用闪电模式
            "enable_mem_efficient": True,  # 启用内存高效模式
        },
        None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},  # 默认情况下启用所有模式
    }
else:
    # 从上下文管理模块导入 nullcontext
    from contextlib import nullcontext

    # 设置 SDP_IS_AVAILABLE 为 False,表示 SDP 后端不可用
    SDP_IS_AVAILABLE = False
    # 将 sdp_kernel 设置为 nullcontext
    sdp_kernel = nullcontext
    # 打印警告信息,提示用户升级 PyTorch 版本
    print(
        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
    )

# 尝试导入 xformers 和 xformers.ops 模块
try:
    import xformers
    import xformers.ops

    # 如果导入成功,设置 XFORMERS_IS_AVAILABLE 为 True
    XFORMERS_IS_AVAILABLE = True
# 如果导入失败,设置 XFORMERS_IS_AVAILABLE 为 False,并打印提示信息
except:
    XFORMERS_IS_AVAILABLE = False
    print("no module 'xformers'. Processing without...")

# 从 modules.utils 模块导入 checkpoint 函数
from modules.utils import checkpoint


# 定义一个检查值是否存在的函数
def exists(val):
    # 返回值是否不为 None
    return val is not None


# 定义一个去重函数
def uniq(arr):
    # 将数组转换为字典以去重,然后返回字典的键
    return {el: True for el in arr}.keys()


# 定义一个返回默认值的函数
def default(val, d):
    # 如果 val 存在,返回 val
    if exists(val):
        return val
    # 否则返回 d 的结果,如果 d 是函数则调用它
    return d() if isfunction(d) else d


# 定义一个返回最大负值的函数
def max_neg_value(t):
    # 返回指定数据类型的最大负值
    return -torch.finfo(t.dtype).max


# 定义一个初始化张量的函数
def init_(tensor):
    # 获取张量的最后一个维度
    dim = tensor.shape[-1]
    # 计算标准差
    std = 1 / math.sqrt(dim)
    # 在[-std, std]范围内均匀初始化张量
    tensor.uniform_(-std, std)
    # 返回初始化后的张量
    return tensor


# 定义一个前馈神经网络类
class GEGLU(nn.Module):
    # 初始化方法
    def __init__(self, dim_in, dim_out):
        super().__init__()  # 调用父类的初始化方法
        # 创建线性变换层,输出维度为输入维度的两倍
        self.proj = nn.Linear(dim_in, dim_out * 2)

    # 前向传播方法
    def forward(self, x):
        # 将输入 x 通过线性层投影并拆分为两部分
        x, gate = self.proj(x).chunk(2, dim=-1)
        # 返回 x 和经过 GELU 激活的 gate 的乘积
        return x * F.gelu(gate)


# 定义一个前馈神经网络类
class FeedForward(nn.Module):
    # 初始化方法
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()  # 调用父类的初始化方法
        # 计算内部维度
        inner_dim = int(dim * mult)
        # 如果 dim_out 为空,设置为 dim
        dim_out = default(dim_out, dim)
        # 根据 glu 参数决定使用哪种输入变换
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

        # 创建包含输入变换、dropout 和线性变换的序列模型
        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    # 前向传播方法
    def forward(self, x):
        # 返回网络的输出
        return self.net(x)


# 定义一个将模块参数归零的函数
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    # 遍历模块的所有参数
    for p in module.parameters():
        # 断开与计算图的联系并将参数归零
        p.detach().zero_()
    # 返回修改后的模块
    return module


# 定义一个归一化函数
def Normalize(in_channels):
    # 返回 GroupNorm 实例,用于对输入通道进行归一化
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


# 定义一个线性注意力类
class LinearAttention(nn.Module):
    # 初始化方法,设置维度、头数和每个头的维度
        def __init__(self, dim, heads=4, dim_head=32):
            # 调用父类的初始化方法
            super().__init__()
            # 存储头数
            self.heads = heads
            # 计算隐藏层的维度,即头数与每个头的维度的乘积
            hidden_dim = dim_head * heads
            # 创建一个卷积层,将输入通道数转换为三倍的隐藏维度,用于生成查询、键和值
            self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
            # 创建一个卷积层,将隐藏维度转换回原始的输入通道数
            self.to_out = nn.Conv2d(hidden_dim, dim, 1)
    
        # 前向传播方法,处理输入数据
        def forward(self, x):
            # 获取输入的批次大小、通道数、高度和宽度
            b, c, h, w = x.shape
            # 通过卷积层生成查询、键和值的组合
            qkv = self.to_qkv(x)
            # 重新排列 qkv 数据,使其分开为 q、k 和 v,并按头数分组
            q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
            # 对键进行 softmax 操作,计算注意力权重
            k = k.softmax(dim=-1)
            # 使用爱因斯坦求和约定计算上下文,即加权后的值
            context = torch.einsum("bhdn,bhen->bhde", k, v)
            # 根据查询和上下文计算输出
            out = torch.einsum("bhde,bhdn->bhen", context, q)
            # 重新排列输出数据,恢复到原始形状
            out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
            # 返回最终的输出结果
            return self.to_out(out)
# 定义一个空间自注意力类,继承自 nn.Module
class SpatialSelfAttention(nn.Module):
    # 初始化方法,接受输入通道数
    def __init__(self, in_channels):
        # 调用父类构造函数
        super().__init__()
        # 保存输入通道数
        self.in_channels = in_channels

        # 创建归一化层
        self.norm = Normalize(in_channels)
        # 创建查询卷积层
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 创建键卷积层
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 创建值卷积层
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        # 创建输出投影卷积层
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    # 前向传播方法
    def forward(self, x):
        # 初始化 h_ 为输入 x
        h_ = x
        # 对输入进行归一化
        h_ = self.norm(h_)
        # 计算查询
        q = self.q(h_)
        # 计算键
        k = self.k(h_)
        # 计算值
        v = self.v(h_)

        # 计算注意力
        b, c, h, w = q.shape  # 获取批量大小、通道数、高度和宽度
        # 重新排列查询以便进行矩阵乘法
        q = rearrange(q, "b c h w -> b (h w) c")
        # 重新排列键以便进行矩阵乘法
        k = rearrange(k, "b c h w -> b c (h w)")
        # 计算注意力权重
        w_ = torch.einsum("bij,bjk->bik", q, k)

        # 缩放注意力权重
        w_ = w_ * (int(c) ** (-0.5))
        # 应用 softmax 得到注意力分布
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # 处理值
        v = rearrange(v, "b c h w -> b c (h w)")  # 重新排列值
        w_ = rearrange(w_, "b i j -> b j i")  # 重新排列注意力权重
        # 根据注意力权重对值进行加权
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        # 重新排列以匹配输出形状
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        # 通过输出卷积层进行投影
        h_ = self.proj_out(h_)

        # 返回原始输入与输出的和
        return x + h_


# 定义一个交叉注意力类,继承自 nn.Module
class CrossAttention(nn.Module):
    # 初始化方法,接受多个参数
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        backend=None,
    ):
        # 调用父类构造函数
        super().__init__()
        # 计算内部维度
        inner_dim = dim_head * heads
        # 设置上下文维度的默认值
        context_dim = default(context_dim, query_dim)

        # 计算缩放因子
        self.scale = dim_head**-0.5
        # 保存头的数量
        self.heads = heads

        # 创建查询线性层
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        # 创建键线性层
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        # 创建值线性层
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        # 创建输出线性层及 dropout
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # 保存后端
        self.backend = backend

    # 前向传播方法
    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        h = self.heads  # 获取头的数量,用于后续的多头注意力机制

        if additional_tokens is not None:  # 检查是否有额外的标记
            # 获取输出序列开头的掩蔽标记数量
            n_tokens_to_mask = additional_tokens.shape[1]  
            # 将额外的标记添加到输入序列前面
            x = torch.cat([additional_tokens, x], dim=1)  

        # 将输入 x 转换为查询向量 q
        q = self.to_q(x)  
        # 默认上下文为 x
        context = default(context, x)  
        # 将上下文转换为键向量 k
        k = self.to_k(context)  
        # 将上下文转换为值向量 v
        v = self.to_v(context)  

        if n_times_crossframe_attn_in_self:  # 检查是否需要进行跨帧注意力机制
            # 按照文献中的方法重新编程跨帧注意力
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0  # 确保输入长度是跨帧次数的整数倍
            n_cp = x.shape[0] // n_times_crossframe_attn_in_self  # 计算每个跨帧的大小
            # 根据跨帧次数重复键向量
            k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)  
            # 根据跨帧次数重复值向量
            v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp)  

        # 将 q, k, v 重新排列为适合多头注意力的形状
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))  

        # old
        """
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale  # 计算查询和键的相似度
        del q, k  # 删除不再需要的 q 和 k

        if exists(mask):  # 检查是否存在掩蔽
            mask = rearrange(mask, 'b ... -> b (...)')  # 重新排列掩蔽形状
            max_neg_value = -torch.finfo(sim.dtype).max  # 获取最大负值,用于掩蔽
            mask = repeat(mask, 'b j -> (b h) () j', h=h)  # 将掩蔽重复到多头
            sim.masked_fill_(~mask, max_neg_value)  # 将不需要的部分填充为最大负值

        # 计算注意力分布
        sim = sim.softmax(dim=-1)  

        out = einsum('b i j, b j d -> b i d', sim, v)  # 计算最终输出
        """
        # new
        with sdp_kernel(**BACKEND_MAP[self.backend]):  # 使用指定后端的 SDP 核心
            # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)  # 调试信息,打印 q, k, v 的形状
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # 计算缩放点积注意力,默认缩放因子为 dim_head ** -0.5

        del q, k, v  # 删除 q, k, v,释放内存
        out = rearrange(out, "b h n d -> b n (h d)", h=h)  # 将输出重新排列为合适的形状

        if additional_tokens is not None:  # 如果存在额外的标记
            # 移除额外的标记
            out = out[:, n_tokens_to_mask:]  
        return self.to_out(out)  # 将输出转换为最终输出格式并返回
# 定义一个内存高效的交叉注意力模块,继承自 nn.Module
class MemoryEfficientCrossAttention(nn.Module):
    # 初始化方法,设置参数并打印相关信息
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
        # 调用父类的初始化方法
        super().__init__()
        # 打印类名及查询维度、上下文维度、头数和维度信息
        print(
            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
            f"{heads} heads with a dimension of {dim_head}."
        )
        # 计算每个头的内部维度
        inner_dim = dim_head * heads
        # 如果上下文维度为 None,则默认为查询维度
        context_dim = default(context_dim, query_dim)

        # 保存头数和每个头的维度
        self.heads = heads
        self.dim_head = dim_head

        # 定义查询、键、值的线性变换,无偏置项
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        # 定义最终输出的线性变换和 dropout 层
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # 初始化注意力操作为 None
        self.attention_op: Optional[Any] = None

    # 前向传播方法,接收输入数据和可选参数
    def forward(
        self,
        x,
        context=None,
        mask=None,
        additional_tokens=None,
        n_times_crossframe_attn_in_self=0,
    ):
        # 检查是否有额外的令牌
        if additional_tokens is not None:
            # 获取输出序列开头被遮蔽的令牌数量
            n_tokens_to_mask = additional_tokens.shape[1]
            # 将额外的令牌与输入张量在列上拼接
            x = torch.cat([additional_tokens, x], dim=1)
        # 将输入张量转换为查询张量
        q = self.to_q(x)
        # 默认上下文为输入张量
        context = default(context, x)
        # 将上下文转换为键张量
        k = self.to_k(context)
        # 将上下文转换为值张量
        v = self.to_v(context)

        # 检查是否需要在自注意力中进行跨帧注意力
        if n_times_crossframe_attn_in_self:
            # 按照文献重新编程跨帧注意力
            assert x.shape[0] % n_times_crossframe_attn_in_self == 0
            # k的形状将被重复n_times_crossframe_attn_in_self次
            k = repeat(
                k[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )
            # v的形状将被重复n_times_crossframe_attn_in_self次
            v = repeat(
                v[::n_times_crossframe_attn_in_self],
                "b ... -> (b n) ...",
                n=n_times_crossframe_attn_in_self,
            )

        # 获取查询张量的批次大小和其他维度
        b, _, _ = q.shape
        # 对q, k, v进行处理以匹配注意力机制的要求
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # 实际计算注意力机制
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

        # TODO: 将这个直接用于注意力操作作为偏差
        if exists(mask):
            # 如果存在掩码,抛出未实现异常
            raise NotImplementedError
        # 调整输出张量的形状以符合后续处理
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        # 如果有额外的令牌,去除它们
        if additional_tokens is not None:
            out = out[:, n_tokens_to_mask:]
        # 将输出张量转换为最终输出
        return self.to_out(out)
# 基础变换器块,继承自 nn.Module
class BasicTransformerBlock(nn.Module):
    # 定义注意力模式的映射
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # 标准注意力
        "softmax-xformers": MemoryEfficientCrossAttention,  # 节省内存的注意力
    }

    # 初始化方法,接收多个参数
    def __init__(
        self,
        dim,  # 特征维度
        n_heads,  # 注意力头数量
        d_head,  # 每个注意力头的维度
        dropout=0.0,  # dropout比率
        context_dim=None,  # 上下文维度
        gated_ff=True,  # 是否使用门控前馈网络
        checkpoint=True,  # 是否使用检查点
        disable_self_attn=False,  # 是否禁用自注意力
        attn_mode="softmax",  # 注意力模式
        sdp_backend=None,  # 后端配置
    ):
        # 调用父类初始化方法
        super().__init__()
        # 确保指定的注意力模式有效
        assert attn_mode in self.ATTENTION_MODES
        # 如果选择的模式不可用,则回退到默认模式
        if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
            print(
                f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
            )
            attn_mode = "softmax"
        # 如果选择的是标准模式,但不再支持,则进行适当处理
        elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
            print("We do not support vanilla attention anymore, as it is too expensive. Sorry.")
            # 确保 xformers 可用
            if not XFORMERS_IS_AVAILABLE:
                assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
            else:
                print("Falling back to xformers efficient attention.")
                attn_mode = "softmax-xformers"
        # 根据选择的模式获取注意力类
        attn_cls = self.ATTENTION_MODES[attn_mode]
        # 检查 PyTorch 版本以确保后端有效
        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
        else:
            assert sdp_backend is None
        # 设置自注意力禁用标志
        self.disable_self_attn = disable_self_attn
        # 创建第一个注意力层
        self.attn1 = attn_cls(
            query_dim=dim,  # 查询维度
            heads=n_heads,  # 注意力头数量
            dim_head=d_head,  # 每个头的维度
            dropout=dropout,  # dropout比率
            context_dim=context_dim if self.disable_self_attn else None,  # 上下文维度,若禁用自注意力则为None
            backend=sdp_backend,  # 后端配置
        )  # 如果未禁用自注意力,则为自注意力层
        # 创建前馈网络
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # 创建第二个注意力层
        self.attn2 = attn_cls(
            query_dim=dim,  # 查询维度
            context_dim=context_dim,  # 上下文维度
            heads=n_heads,  # 注意力头数量
            dim_head=d_head,  # 每个头的维度
            dropout=dropout,  # dropout比率
            backend=sdp_backend,  # 后端配置
        )  # 如果上下文为None,则为自注意力层
        # 创建层归一化层
        self.norm1 = nn.LayerNorm(dim)  # 第一层归一化
        self.norm2 = nn.LayerNorm(dim)  # 第二层归一化
        self.norm3 = nn.LayerNorm(dim)  # 第三层归一化
        # 设置检查点标志
        self.checkpoint = checkpoint
        # 如果启用检查点,则打印提示
        if self.checkpoint:
            print(f"{self.__class__.__name__} is using checkpointing")
    # 前向传播方法,接收输入和可选的上下文、附加标记及其他参数
    def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
        # 初始化参数字典,包含输入 x
        kwargs = {"x": x}
    
        # 如果上下文不为 None,将其添加到参数字典
        if context is not None:
            kwargs.update({"context": context})
    
        # 如果附加标记不为 None,将其添加到参数字典
        if additional_tokens is not None:
            kwargs.update({"additional_tokens": additional_tokens})
    
        # 如果跨帧自注意力次数大于 0,将其添加到参数字典
        if n_times_crossframe_attn_in_self:
            kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self})
    
        # 调用检查点函数,执行前向传播,返回计算结果
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
    
    # 定义实际的前向传播逻辑
    def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
        # 对输入 x 进行归一化处理并通过第一个注意力层,结合上下文和附加标记
        x = (
            self.attn1(
                self.norm1(x),
                context=context if self.disable_self_attn else None,
                additional_tokens=additional_tokens,
                n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
            )
            + x  # 将原始输入加到输出中,形成残差连接
        )
        # 对处理后的 x 进行归一化处理,并通过第二个注意力层
        x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x  # 残差连接
        # 通过前馈层处理,完成归一化后加到输入上
        x = self.ff(self.norm3(x)) + x  # 残差连接
        # 返回最终的输出
        return x
# 定义一个单层基本变换器块,继承自 nn.Module
class BasicTransformerSingleLayerBlock(nn.Module):
    # 定义注意力模式的字典,映射名称到对应的注意力类
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # 标准注意力
        "softmax-xformers": MemoryEfficientCrossAttention,  # 在 A100 上速度不如上面的版本
        # (待办:可能依赖于 head_dim,检查,降级为针对 dim!=[16,32,64,128] 的半优化内核)
    }

    # 初始化函数,定义构造函数的参数
    def __init__(
        self,
        dim,  # 输入特征维度
        n_heads,  # 注意力头的数量
        d_head,  # 每个注意力头的维度
        dropout=0.0,  # dropout 概率
        context_dim=None,  # 上下文维度(可选)
        gated_ff=True,  # 是否使用门控前馈网络
        checkpoint=True,  # 是否使用检查点以节省内存
        attn_mode="softmax",  # 使用的注意力模式
    ):
        # 调用父类的构造函数
        super().__init__()
        # 确保给定的注意力模式在允许的模式中
        assert attn_mode in self.ATTENTION_MODES
        # 根据模式获取对应的注意力类
        attn_cls = self.ATTENTION_MODES[attn_mode]
        # 初始化注意力层
        self.attn1 = attn_cls(
            query_dim=dim,  # 查询维度
            heads=n_heads,  # 头数量
            dim_head=d_head,  # 每个头的维度
            dropout=dropout,  # dropout 概率
            context_dim=context_dim,  # 上下文维度
        )
        # 初始化前馈网络
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # 初始化层归一化
        self.norm1 = nn.LayerNorm(dim)
        # 初始化第二个层归一化
        self.norm2 = nn.LayerNorm(dim)
        # 保存检查点标志
        self.checkpoint = checkpoint

    # 前向传播函数
    def forward(self, x, context=None):
        # 使用检查点机制调用 _forward 方法
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    # 实际的前向传播逻辑
    def _forward(self, x, context=None):
        # 先应用注意力层,结合残差连接
        x = self.attn1(self.norm1(x), context=context) + x
        # 应用前馈网络,再结合残差连接
        x = self.ff(self.norm2(x)) + x
        # 返回输出
        return x


# 定义用于图像数据的变换器块,继承自 nn.Module
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    # 初始化函数,定义构造函数的参数
    def __init__(
        self,
        in_channels,  # 输入通道数
        n_heads,  # 注意力头的数量
        d_head,  # 每个头的维度
        depth=1,  # 堆叠的层数
        dropout=0.0,  # dropout 概率
        context_dim=None,  # 上下文维度(可选)
        disable_self_attn=False,  # 是否禁用自注意力
        use_linear=False,  # 是否使用线性层以提高效率
        attn_type="softmax",  # 使用的注意力类型
        use_checkpoint=True,  # 是否使用检查点以节省内存
        # sdp_backend=SDPBackend.FLASH_ATTENTION  # 备用的 sdp 后端
        sdp_backend=None,  # 指定 sdp 后端(默认为 None)
    ):
        # 调用父类的构造函数,初始化父类属性
        super().__init__()
        # 打印当前类的名称、深度、输入通道数和头的数量
        print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads")
        # 从 omegaconf 库导入 ListConfig 类
        from omegaconf import ListConfig

        # 如果 context_dim 存在且不是列表或 ListConfig 类型,则将其转换为列表
        if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
            context_dim = [context_dim]
        # 如果 context_dim 存在且是列表类型
        if exists(context_dim) and isinstance(context_dim, list):
            # 如果给定的深度与 context_dim 的长度不匹配
            if depth != len(context_dim):
                # 打印警告信息,提示深度与上下文维度不匹配,并重设上下文维度
                print(
                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                    f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                )
                # 断言所有上下文维度相同,以便自动匹配深度
                assert all(
                    map(lambda x: x == context_dim[0], context_dim)
                ), "need homogenous context_dim to match depth automatically"
                # 将上下文维度设置为深度乘以第一个上下文维度
                context_dim = depth * [context_dim[0]]
        # 如果上下文维度为 None,创建一个包含 None 的列表,长度为深度
        elif context_dim is None:
            context_dim = [None] * depth
        # 设置输入通道数的属性
        self.in_channels = in_channels
        # 计算内部维度,等于头的数量乘以每个头的维度
        inner_dim = n_heads * d_head
        # 创建归一化层
        self.norm = Normalize(in_channels)
        # 根据是否使用线性层选择不同的输入投影层
        if not use_linear:
            # 使用卷积层进行输入投影
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            # 使用线性层进行输入投影
            self.proj_in = nn.Linear(in_channels, inner_dim)

        # 创建变换器块的模块列表
        self.transformer_blocks = nn.ModuleList(
            [
                # 对于每一层深度,创建基本变换器块
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    attn_mode=attn_type,
                    checkpoint=use_checkpoint,
                    sdp_backend=sdp_backend,
                )
                for d in range(depth)  # 遍历深度范围
            ]
        )
        # 根据是否使用线性层选择不同的输出投影层
        if not use_linear:
            # 使用零初始化的卷积层作为输出投影层
            self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
        else:
            # 使用零初始化的线性层作为输出投影层
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        # 设置是否使用线性层的属性
        self.use_linear = use_linear
    # 前向传播函数,接受输入 x 和可选的上下文参数
        def forward(self, x, context=None):
            # 如果没有给定上下文,交叉注意力默认使用自注意力
            if not isinstance(context, list):
                context = [context]  # 确保上下文为列表形式
            # 获取输入 x 的形状信息
            b, c, h, w = x.shape
            # 保存原始输入以便后续使用
            x_in = x
            # 对输入 x 进行归一化处理
            x = self.norm(x)
            # 如果不使用线性变换,则进行输入投影
            if not self.use_linear:
                x = self.proj_in(x)
            # 重新排列 x 的形状,以适应后续处理
            x = rearrange(x, "b c h w -> b (h w) c").contiguous()
            # 如果使用线性变换,则进行输入投影
            if self.use_linear:
                x = self.proj_in(x)
            # 遍历每个变换块
            for i, block in enumerate(self.transformer_blocks):
                # 如果不是第一个块且上下文只有一个,则使用相同的上下文
                if i > 0 and len(context) == 1:
                    i = 0  # 每个块使用相同的上下文
                # 将 x 和对应的上下文传入当前块
                x = block(x, context=context[i])
            # 如果使用线性变换,则进行输出投影
            if self.use_linear:
                x = self.proj_out(x)
            # 重新排列 x 的形状,恢复到原始输入的形状
            x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
            # 如果不使用线性变换,则进行输出投影
            if not self.use_linear:
                x = self.proj_out(x)
            # 返回输出与原始输入的和
            return x + x_in

标签:dim,None,CogVideoX,self,torch,---,源码,meta,context
From: https://www.cnblogs.com/apachecn/p/18494411

相关文章

  • CogVideo---CogVideoX-微调代码源码解析-十三-
    CogVideo&CogVideoX微调代码源码解析(十三)VideoCaptionTypically,mostvideodatadoesnotcomewithcorrespondingdescriptivetext,soitisnecessarytoconvertthevideodataintotextualdescriptionstoprovidetheessentialtrainingdatafortext-to-vid......
  • CogVideo---CogVideoX-微调代码源码解析-十二-
    CogVideo&CogVideoX微调代码源码解析(十二).\cogvideo-finetune\sat\vae_modules\autoencoder.py#导入所需的标准库和第三方库importlogging#日志库,用于记录信息importmath#数学库,提供数学函数importre#正则表达式库,用于字符串匹配importrandom#随机数库,......
  • CogVideo---CogVideoX-微调代码源码解析-十-
    CogVideo&CogVideoX微调代码源码解析(十).\cogvideo-finetune\sat\sgm\modules\diffusionmodules\sampling.py#部分移植自https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py"""#从类型提示模块导入字典和联合类型fromtypingimportDict......
  • CogVideo---CogVideoX-微调代码源码解析-七-
    CogVideo&CogVideoX微调代码源码解析(七).\cogvideo-finetune\sat\sgm\modules\autoencoding\regularizers\base.py#导入抽象方法和类型注解fromabcimportabstractmethod#导入任意类型和元组类型fromtypingimportAny,Tuple#导入PyTorch和功能模块importtorc......
  • CogVideo---CogVideoX-微调代码源码解析-六-
    CogVideo&CogVideoX微调代码源码解析(六).\cogvideo-finetune\sat\sgm\modules\autoencoding\losses\discriminator_loss.py#导入所需的类型定义fromtypingimportDict,Iterator,List,Optional,Tuple,Union#导入NumPy库importnumpyasnp#导入PyTorch库im......
  • CogVideo---CogVideoX-微调代码源码解析-九-
    CogVideo&CogVideoX微调代码源码解析(九).\cogvideo-finetune\sat\sgm\modules\diffusionmodules\guiders.py#导入logging模块以进行日志记录importlogging#从abc模块导入ABC和abstractmethod以定义抽象基类和抽象方法fromabcimportABC,abstractmethod#从......
  • CogVideo---CogVideoX-微调代码源码解析-二-
    CogVideo&CogVideoX微调代码源码解析(二).\cogvideo-finetune\inference\cli_demo.py#脚本说明:演示如何使用CogVideoX模型通过HuggingFace`diffusers`管道生成视频"""ThisscriptdemonstrateshowtogenerateavideousingtheCogVideoXmodelwiththeHuggingF......
  • CogVideo---CogVideoX-微调代码源码解析-八-
    CogVideo&CogVideoX微调代码源码解析(八).\cogvideo-finetune\sat\sgm\modules\autoencoding\vqvae\movq_enc_3d.py#pytorch_diffusion+derivedencoderdecoderimportmath#导入数学库,提供数学函数importtorch#导入PyTorch库,用于深度学习importtorch.nnasnn......
  • CogView3---CogView-3Plus-微调代码源码解析-五-
    CogView3&CogView-3Plus微调代码源码解析(五).\cogview3-finetune\sat\sgm\modules\encoders\__init__.py请提供需要注释的代码。.\cogview3-finetune\sat\sgm\modules\__init__.py#从当前包的编码器模块导入GeneralConditioner类from.encoders.modulesimportGeneral......
  • CogView3---CogView-3Plus-微调代码源码解析-四-
    CogView3&CogView-3Plus微调代码源码解析(四).\cogview3-finetune\sat\sgm\modules\diffusionmodules\sampling_utils.py#导入数学库以进行数学运算importmath#导入PyTorch库以进行张量操作importtorch#从SciPy库导入积分函数fromscipyimportintegrate#从......