首页 > 编程语言 >diffusers-源码解析-九-

diffusers-源码解析-九-

时间:2024-10-22 12:36:22浏览次数:1  
标签:dim None dtype self torch diffusers 源码 model 解析

diffusers 源码解析(九)

.\diffusers\models\embeddings_flax.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 该文件的使用需要遵循 Apache 2.0 许可证
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# 查看许可证以了解特定权限和限制
import math  # 导入数学库以进行数学运算

import flax.linen as nn  # 导入Flax库中的神经网络模块
import jax.numpy as jnp  # 导入JAX的numpy模块以进行数值计算


def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,  # 定义输入参数 timesteps 为一维 JAX 数组
    embedding_dim: int,  # 定义输出嵌入的维度
    freq_shift: float = 1,  # 频率偏移的默认值为1
    min_timescale: float = 1,  # 最小时间尺度的默认值
    max_timescale: float = 1.0e4,  # 最大时间尺度的默认值
    flip_sin_to_cos: bool = False,  # 是否翻转正弦和余弦
    scale: float = 1.0,  # 缩放因子的默认值
) -> jnp.ndarray:  # 函数返回一个 JAX 数组
    """Returns the positional encoding (same as Tensor2Tensor).
    
    返回位置编码,类似于Tensor2Tensor

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        输入为一维张量,N个索引,每个批次元素一个
        These may be fractional.
        embedding_dim: The number of output channels.
        嵌入的通道数
        min_timescale: The smallest time unit (should probably be 0.0).
        最小时间单位
        max_timescale: The largest time unit.
        最大时间单位
    Returns:
        a Tensor of timing signals [N, num_channels]
        返回时间信号的张量 [N, num_channels]
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"  # 检查 timesteps 是否为一维数组
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"  # 检查嵌入维度是否为偶数
    num_timescales = float(embedding_dim // 2)  # 计算时间尺度的数量
    log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)  # 计算对数时间尺度增量
    inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)  # 计算反时间尺度
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)  # 计算嵌入

    # scale embeddings
    scaled_time = scale * emb  # 对嵌入进行缩放

    if flip_sin_to_cos:  # 如果需要翻转正弦和余弦
        signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)  # 拼接余弦和正弦信号
    else:  # 否则
        signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)  # 拼接正弦和余弦信号
    signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])  # 重塑信号的形状
    return signal  # 返回信号


class FlaxTimestepEmbedding(nn.Module):  # 定义时间步嵌入模块
    r"""
    Time step Embedding Module. Learns embeddings for input time steps.
    时间步嵌入模块。学习输入时间步的嵌入

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
                时间步嵌入维度
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
                Parameters `dtype`
                参数的数据类型
    """

    time_embed_dim: int = 32  # 设置时间嵌入维度的默认值为32
    dtype: jnp.dtype = jnp.float32  # 设置参数的数据类型的默认值为jnp.float32

    @nn.compact  # 指示该方法为紧凑的神经网络模块
    def __call__(self, temb):  # 定义模块的调用方法,接收输入参数 temb
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)  # 第一个全连接层
        temb = nn.silu(temb)  # 应用Silu激活函数
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)  # 第二个全连接层
        return temb  # 返回处理后的temb


class FlaxTimesteps(nn.Module):  # 定义时间步模块
    r"""
    # 包装类,用于生成正弦时间步嵌入,详细说明见 https://arxiv.org/abs/2006.11239
    
    # 参数:
    #     dim (`int`, *可选*, 默认为 `32`):
    #             时间步嵌入的维度
        dim: int = 32  # 定义时间步嵌入的维度,默认值为 32
        flip_sin_to_cos: bool = False  # 定义是否将正弦值转换为余弦值,默认为 False
        freq_shift: float = 1  # 定义频率偏移量,默认为 1
    
        @nn.compact  # 表示这是一个紧凑模式的神经网络层,适合 JAX 使用
        def __call__(self, timesteps):  # 定义调用方法,接受时间步作为输入
            return get_sinusoidal_embeddings(  # 调用函数生成正弦嵌入
                timesteps,  # 输入的时间步
                embedding_dim=self.dim,  # 嵌入维度设置为实例属性 dim
                flip_sin_to_cos=self.flip_sin_to_cos,  # 设置是否翻转正弦到余弦
                freq_shift=self.freq_shift  # 设置频率偏移量
            )  # 返回生成的正弦嵌入

.\diffusers\models\lora.py

# 版权信息,指明文件的版权所有者和保留权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵循许可证,否则不得使用本文件。
# 可在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,
# 否则根据许可证分发的软件在“按原样”基础上提供,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参阅许可证,以获取有关权限和
# 限制的具体条款。

# 重要提示:                                                      #
###################################################################
# ----------------------------------------------------------------#
# 此文件已被弃用,将很快删除                                   #
# (一旦 PEFT 成为 LoRA 的必需依赖项)                          #
# ----------------------------------------------------------------#
###################################################################

from typing import Optional, Tuple, Union  # 导入可选类型、元组和联合类型以用于类型注解

import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 的功能性神经网络模块
from torch import nn  # 从 PyTorch 导入神经网络模块

from ..utils import deprecate, logging  # 从上级目录导入工具函数 deprecate 和 logging
from ..utils.import_utils import is_transformers_available  # 导入检查 transformers 库可用性的函数


# 如果 transformers 库可用,则导入相关模型
if is_transformers_available():
    from transformers import CLIPTextModel, CLIPTextModelWithProjection  # 导入 CLIP 文本模型及其变体


logger = logging.get_logger(__name__)  # 创建一个记录器实例,用于日志记录,禁用 pylint 的名称检查


def text_encoder_attn_modules(text_encoder):
    attn_modules = []  # 初始化一个空列表,用于存储注意力模块

    # 检查文本编码器是否为 CLIPTextModel 或 CLIPTextModelWithProjection 的实例
    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        # 遍历编码器层,收集每一层的自注意力模块
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            name = f"text_model.encoder.layers.{i}.self_attn"  # 构造注意力模块的名称
            mod = layer.self_attn  # 获取当前层的自注意力模块
            attn_modules.append((name, mod))  # 将名称和模块元组添加到列表中
    else:
        # 如果文本编码器不是预期的类型,抛出值错误
        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

    return attn_modules  # 返回注意力模块的列表


def text_encoder_mlp_modules(text_encoder):
    mlp_modules = []  # 初始化一个空列表,用于存储 MLP 模块

    # 检查文本编码器是否为 CLIPTextModel 或 CLIPTextModelWithProjection 的实例
    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        # 遍历编码器层,收集每一层的 MLP 模块
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            mlp_mod = layer.mlp  # 获取当前层的 MLP 模块
            name = f"text_model.encoder.layers.{i}.mlp"  # 构造 MLP 模块的名称
            mlp_modules.append((name, mlp_mod))  # 将名称和模块元组添加到列表中
    else:
        # 如果文本编码器不是预期的类型,抛出值错误
        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

    return mlp_modules  # 返回 MLP 模块的列表


def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
    # 遍历文本编码器中的注意力模块
    for _, attn_module in text_encoder_attn_modules(text_encoder):
        # 检查当前注意力模块的查询投影是否为 PatchedLoraProjection 实例
        if isinstance(attn_module.q_proj, PatchedLoraProjection):
            attn_module.q_proj.lora_scale = lora_scale  # 调整查询投影的 Lora 缩放因子
            attn_module.k_proj.lora_scale = lora_scale  # 调整键投影的 Lora 缩放因子
            attn_module.v_proj.lora_scale = lora_scale  # 调整值投影的 Lora 缩放因子
            attn_module.out_proj.lora_scale = lora_scale  # 调整输出投影的 Lora 缩放因子
    # 遍历文本编码器中的 MLP 模块
        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
            # 检查当前模块的 fc1 层是否为 PatchedLoraProjection 类型
            if isinstance(mlp_module.fc1, PatchedLoraProjection):
                # 设置 fc1 层的 lora_scale 属性
                mlp_module.fc1.lora_scale = lora_scale
                # 设置 fc2 层的 lora_scale 属性
                mlp_module.fc2.lora_scale = lora_scale
# 定义一个名为 PatchedLoraProjection 的类,继承自 PyTorch 的 nn.Module
class PatchedLoraProjection(torch.nn.Module):
    # 初始化方法,接受多个参数以设置 LoraProjection
    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
        # 设置弃用警告信息
        deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        # 调用 deprecate 函数记录弃用信息
        deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)

        # 调用父类的初始化方法
        super().__init__()
        # 从 lora 模块导入 LoRALinearLayer 类
        from ..models.lora import LoRALinearLayer

        # 保存传入的常规线性层
        self.regular_linear_layer = regular_linear_layer

        # 获取常规线性层的设备信息
        device = self.regular_linear_layer.weight.device

        # 如果未指定数据类型,则使用常规线性层的权重数据类型
        if dtype is None:
            dtype = self.regular_linear_layer.weight.dtype

        # 创建 LoRALinearLayer 实例
        self.lora_linear_layer = LoRALinearLayer(
            self.regular_linear_layer.in_features,
            self.regular_linear_layer.out_features,
            network_alpha=network_alpha,
            device=device,
            dtype=dtype,
            rank=rank,
        )

        # 保存 LoRA 的缩放因子
        self.lora_scale = lora_scale

    # 重写 PyTorch 的 state_dict 方法以确保仅保存 'regular_linear_layer' 权重
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # 如果没有 LoRA 线性层,返回常规线性层的状态字典
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )

        # 否则调用父类的 state_dict 方法
        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

    # 定义一个融合 LoRA 权重的方法
    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        # 如果没有 LoRA 线性层,则直接返回
        if self.lora_linear_layer is None:
            return

        # 获取常规线性层的权重数据类型和设备
        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device

        # 将常规线性层的权重转换为浮点类型
        w_orig = self.regular_linear_layer.weight.data.float()
        # 将 LoRA 层的上权重转换为浮点类型
        w_up = self.lora_linear_layer.up.weight.data.float()
        # 将 LoRA 层的下权重转换为浮点类型
        w_down = self.lora_linear_layer.down.weight.data.float()

        # 如果 network_alpha 不为 None,则调整上权重
        if self.lora_linear_layer.network_alpha is not None:
            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

        # 计算融合后的权重
        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

        # 如果安全融合并且融合权重中包含 NaN,抛出异常
        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
                "LoRA weights will not be fused."
            )

        # 更新常规线性层的权重数据
        self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)

        # 将 LoRA 线性层设为 None,表示已经融合
        self.lora_linear_layer = None

        # 将上、下权重矩阵转移到 CPU 以节省内存
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        # 更新 LoRA 的缩放因子
        self.lora_scale = lora_scale
    # 定义解融合 Lora 的私有方法
    def _unfuse_lora(self):
        # 检查 w_up 和 w_down 属性是否存在且不为 None
        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
            # 如果任一属性为 None,则直接返回
            return

        # 获取常规线性层的权重数据
        fused_weight = self.regular_linear_layer.weight.data
        # 保存权重的数据类型和设备信息
        dtype, device = fused_weight.dtype, fused_weight.device

        # 将 w_up 转换为目标设备并转为浮点类型
        w_up = self.w_up.to(device=device).float()
        # 将 w_down 转换为目标设备并转为浮点类型
        w_down = self.w_down.to(device).float()

        # 计算未融合的权重,通过从融合权重中减去 Lora 的贡献
        unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        # 将未融合的权重赋值回常规线性层
        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

        # 清空 w_up 和 w_down 属性
        self.w_up = None
        self.w_down = None

    # 定义前向传播方法
    def forward(self, input):
        # 如果 lora_scale 为 None,则设置为 1.0
        if self.lora_scale is None:
            self.lora_scale = 1.0
        # 如果 lora_linear_layer 为 None,则直接返回常规线性层的输出
        if self.lora_linear_layer is None:
            return self.regular_linear_layer(input)
        # 返回常规线性层的输出加上 Lora 的贡献
        return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
# 定义一个用于 LoRA 的线性层,继承自 nn.Module
class LoRALinearLayer(nn.Module):
    r"""
    A linear layer that is used with LoRA.

    Parameters:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
        rank (`int`, `optional`, defaults to 4):
            The rank of the LoRA layer.
        network_alpha (`float`, `optional`, defaults to `None`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the same
            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        device (`torch.device`, `optional`, defaults to `None`):
            The device to use for the layer's weights.
        dtype (`torch.dtype`, `optional`, defaults to `None`):
            The dtype to use for the layer's weights.
    """

    # 初始化方法,定义输入输出特征和其他参数
    def __init__(
        self,
        in_features: int,  # 输入特征数量
        out_features: int,  # 输出特征数量
        rank: int = 4,  # LoRA 层的秩,默认为 4
        network_alpha: Optional[float] = None,  # 用于稳定学习的网络 alpha,默认为 None
        device: Optional[Union[torch.device, str]] = None,  # 权重使用的设备,默认为 None
        dtype: Optional[torch.dtype] = None,  # 权重使用的数据类型,默认为 None
    ):
        super().__init__()  # 调用父类的初始化方法

        # 弃用提示消息,提醒用户切换到 PEFT 后端
        deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        deprecate("LoRALinearLayer", "1.0.0", deprecation_message)  # 记录弃用信息

        # 定义向下线性层,不使用偏置
        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        # 定义向上线性层,不使用偏置
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # 将网络 alpha 值赋给实例变量
        self.network_alpha = network_alpha
        self.rank = rank  # 保存秩
        self.out_features = out_features  # 保存输出特征数量
        self.in_features = in_features  # 保存输入特征数量

        # 使用正态分布初始化向下权重
        nn.init.normal_(self.down.weight, std=1 / rank)
        # 将向上权重初始化为零
        nn.init.zeros_(self.up.weight)

    # 前向传播方法,接受隐藏状态并返回处理后的结果
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype  # 保存输入数据类型
        dtype = self.down.weight.dtype  # 获取向下层权重的数据类型

        # 通过向下层处理隐藏状态
        down_hidden_states = self.down(hidden_states.to(dtype))
        # 通过向上层处理向下层输出
        up_hidden_states = self.up(down_hidden_states)

        # 如果网络 alpha 不为 None,则调整向上层输出
        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        # 返回与原始数据类型相同的输出
        return up_hidden_states.to(orig_dtype)


# 定义一个用于 LoRA 的卷积层,继承自 nn.Module
class LoRAConv2dLayer(nn.Module):
    r"""
    A convolutional layer that is used with LoRA.
    # 参数说明
    Parameters:
        in_features (`int`):  # 输入特征的数量
            Number of input features.  # 输入特征的数量
        out_features (`int`):  # 输出特征的数量
            Number of output features.  # 输出特征的数量
        rank (`int`, `optional`, defaults to 4):  # LoRA 层的秩,默认为 4
            The rank of the LoRA layer.  # LoRA 层的秩
        kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):  # 卷积核的大小,默认为 (1, 1)
            The kernel size of the convolution.  # 卷积核的大小
        stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):  # 卷积的步幅,默认为 (1, 1)
            The stride of the convolution.  # 卷积的步幅
        padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):  # 卷积的填充方式,默认为 0
            The padding of the convolution.  # 卷积的填充方式
        network_alpha (`float`, `optional`, defaults to `None`):  # 网络 alpha 的值,用于稳定学习,防止下溢
            The value of the network alpha used for stable learning and preventing underflow. This value has the same
            meaning as the `--network_alpha` option in the kohya-ss trainer script. See  # 与 kohya-ss 训练脚本中的 `--network_alpha` 选项含义相同
            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning  # 参考链接

    # 初始化方法
    def __init__(
        self,
        in_features: int,  # 输入特征数量
        out_features: int,  # 输出特征数量
        rank: int = 4,  # LoRA 层的秩,默认为 4
        kernel_size: Union[int, Tuple[int, int]] = (1, 1),  # 卷积核大小,默认为 (1, 1)
        stride: Union[int, Tuple[int, int]] = (1, 1),  # 卷积步幅,默认为 (1, 1)
        padding: Union[int, Tuple[int, int], str] = 0,  # 卷积填充,默认为 0
        network_alpha: Optional[float] = None,  # 网络 alpha 的值,默认为 None
    ):
        super().__init__()  # 调用父类的初始化方法

        # 弃用警告信息,提示用户切换到 PEFT 后端
        deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)  # 发出弃用警告

        # 定义下卷积层,输入为 in_features,输出为 rank,使用指定的卷积参数
        self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        # 根据官方 kohya_ss 训练器,向上卷积层的卷积核大小始终固定
        # # 参考链接: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
        # 定义上卷积层,输入为 rank,输出为 out_features,使用固定的卷积核大小 (1, 1)
        self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)

        # 保存网络 alpha 值,与训练脚本中的相同含义
        # 参考链接: https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha  # 设置网络 alpha 值
        self.rank = rank  # 设置秩

        # 初始化下卷积层的权重为均值为 0,标准差为 1/rank 的正态分布
        nn.init.normal_(self.down.weight, std=1 / rank)
        # 初始化上卷积层的权重为 0
        nn.init.zeros_(self.up.weight)

    # 前向传播方法
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:  # 定义前向传播函数
        orig_dtype = hidden_states.dtype  # 保存输入张量的原始数据类型
        dtype = self.down.weight.dtype  # 获取下卷积层权重的数据类型

        # 将输入的隐状态张量通过下卷积层
        down_hidden_states = self.down(hidden_states.to(dtype))
        # 将下卷积层的输出通过上卷积层
        up_hidden_states = self.up(down_hidden_states)

        # 如果 network_alpha 不为 None,则进行缩放
        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank  # 根据 network_alpha 进行缩放

        # 返回转换回原始数据类型的输出张量
        return up_hidden_states.to(orig_dtype)  # 返回最终输出
# 定义一个可以与 LoRA 兼容的卷积层,继承自 nn.Conv2d
class LoRACompatibleConv(nn.Conv2d):
    """
    A convolutional layer that can be used with LoRA.
    """

    # 初始化方法,接受可变数量的参数,lora_layer 为可选参数,其他参数通过 kwargs 接收
    def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
        # 设置弃用消息,提示用户切换到 PEFT 后端
        deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        # 调用弃用函数,记录此类的弃用信息
        deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)

        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)
        # 将 lora_layer 赋值给实例变量
        self.lora_layer = lora_layer

    # 设置 lora_layer 的方法
    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
        # 设置弃用消息,提示用户切换到 PEFT 后端
        deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        # 调用弃用函数,记录此方法的弃用信息
        deprecate("set_lora_layer", "1.0.0", deprecation_message)

        # 将传入的 lora_layer 赋值给实例变量
        self.lora_layer = lora_layer

    # 融合 LoRA 权重的方法
    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
        # 如果 lora_layer 为 None,直接返回
        if self.lora_layer is None:
            return

        # 获取当前权重的数据类型和设备
        dtype, device = self.weight.data.dtype, self.weight.data.device

        # 将权重转换为浮点型
        w_orig = self.weight.data.float()
        # 获取 lora_layer 的上升和下降权重,并转换为浮点型
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        # 如果 network_alpha 不为 None,调整上升权重
        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        # 进行矩阵乘法,融合上升和下降权重
        fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
        # 将融合的结果调整为原始权重的形状
        fusion = fusion.reshape((w_orig.shape))
        # 计算最终融合权重
        fused_weight = w_orig + (lora_scale * fusion)

        # 如果安全融合为 True,检查融合权重中是否有 NaN 值
        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
                "LoRA weights will not be fused."
            )

        # 将融合后的权重赋值回实例的权重,保持设备和数据类型
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # 融合后可以删除 lora_layer
        self.lora_layer = None

        # 将上升和下降矩阵转移到 CPU,以减少内存占用
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        # 存储 lora_scale
        self._lora_scale = lora_scale

    # 解融合 LoRA 权重的方法
    def _unfuse_lora(self):
        # 检查 w_up 和 w_down 是否存在
        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
            return

        # 获取当前融合权重
        fused_weight = self.weight.data
        # 获取当前权重的数据类型和设备
        dtype, device = fused_weight.data.dtype, fused_weight.data.device

        # 将 w_up 和 w_down 转移到正确的设备并转换为浮点型
        self.w_up = self.w_up.to(device=device).float()
        self.w_down = self.w_down.to(device).float()

        # 进行矩阵乘法,重新计算未融合权重
        fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
        # 将融合结果调整为融合权重的形状
        fusion = fusion.reshape((fused_weight.shape))
        # 计算最终的未融合权重
        unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
        # 更新实例的权重
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        # 清空 w_up 和 w_down
        self.w_up = None
        self.w_down = None
    # 定义前向传播函数,接收隐藏状态和缩放因子,返回张量
    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # 检查填充模式是否不是“零”,若是则进行相应填充
        if self.padding_mode != "zeros":
            # 对隐藏状态进行填充,使用反向填充参数和指定的填充模式
            hidden_states = F.pad(hidden_states, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            # 设置填充为 (0, 0)
            padding = (0, 0)
        else:
            # 使用类中的填充属性
            padding = self.padding
    
        # 进行二维卷积操作,返回卷积结果
        original_outputs = F.conv2d(
            hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups
        )
    
        # 如果 Lora 层不存在,则返回卷积结果
        if self.lora_layer is None:
            return original_outputs
        else:
            # 否则,将卷积结果与 Lora 层的结果按比例相加并返回
            return original_outputs + (scale * self.lora_layer(hidden_states))
# 定义一个兼容 LoRA 的线性层,继承自 nn.Linear
class LoRACompatibleLinear(nn.Linear):
    """
    A Linear layer that can be used with LoRA.
    """

    # 初始化方法,接收参数并可选传入 LoRA 层
    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        # 定义弃用提示信息,建议用户切换到 PEFT 后端
        deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        # 调用弃用函数提示用户
        deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)

        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)
        # 设置 LoRA 层
        self.lora_layer = lora_layer

    # 设置 LoRA 层的方法
    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        # 定义弃用提示信息,建议用户切换到 PEFT 后端
        deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
        # 调用弃用函数提示用户
        deprecate("set_lora_layer", "1.0.0", deprecation_message)
        # 设置 LoRA 层
        self.lora_layer = lora_layer

    # 融合 LoRA 权重的方法
    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
        # 如果没有 LoRA 层,直接返回
        if self.lora_layer is None:
            return

        # 获取权重的数据类型和设备
        dtype, device = self.weight.data.dtype, self.weight.data.device

        # 将原始权重转换为浮点型
        w_orig = self.weight.data.float()
        # 获取 LoRA 层的上权重并转换为浮点型
        w_up = self.lora_layer.up.weight.data.float()
        # 获取 LoRA 层的下权重并转换为浮点型
        w_down = self.lora_layer.down.weight.data.float()

        # 如果网络 alpha 不为 None,则调整上权重
        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        # 融合权重的计算
        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

        # 如果进行安全融合且融合权重存在 NaN,则抛出错误
        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}."
                "LoRA weights will not be fused."
            )

        # 更新当前权重为融合后的权重
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # 将 LoRA 层设为 None,表示已融合
        self.lora_layer = None

        # 将上权重和下权重移到 CPU,防止内存溢出
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        # 保存 LoRA 融合的缩放因子
        self._lora_scale = lora_scale

    # 反融合 LoRA 权重的方法
    def _unfuse_lora(self):
        # 如果上权重和下权重不存在,直接返回
        if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
            return

        # 获取当前融合权重
        fused_weight = self.weight.data
        # 获取当前权重的数据类型和设备
        dtype, device = fused_weight.dtype, fused_weight.device

        # 将上权重和下权重移到对应设备并转换为浮点型
        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()

        # 计算未融合的权重
        unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        # 更新当前权重为未融合的权重
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        # 将上权重和下权重设为 None
        self.w_up = None
        self.w_down = None

    # 前向传播方法
    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # 如果没有 LoRA 层,直接使用父类的前向传播
        if self.lora_layer is None:
            out = super().forward(hidden_states)
            return out
        else:
            # 使用父类的前向传播加上 LoRA 层的输出
            out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
            return out

.\diffusers\models\modeling_flax_pytorch_utils.py

# coding=utf-8  # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team.  # 版权信息,表明版权所有者

# Licensed under the Apache License, Version 2.0 (the "License");  # 说明该文件根据 Apache 2.0 许可证发布
# you may not use this file except in compliance with the License.  # 说明只能在遵守许可证的情况下使用此文件
# You may obtain a copy of the License at  # 提供获取许可证的地址
#
#     http://www.apache.org/licenses/LICENSE-2.0  # 许可证的具体链接
#
# Unless required by applicable law or agreed to in writing, software  # 免责声明,除非另有规定或书面同意
# distributed under the License is distributed on an "AS IS" BASIS,  # 说明软件是按“现状”提供的
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  # 没有任何形式的明示或暗示的保证
# See the License for the specific language governing permissions and  # 指向许可证以获取具体条款
# limitations under the License.  # 以及使用限制的说明
"""PyTorch - Flax general utilities."""  # 文档字符串,描述该模块的功能

import re  # 导入正则表达式模块

import jax.numpy as jnp  # 导入 JAX 的 NumPy 库,并重命名为 jnp
from flax.traverse_util import flatten_dict, unflatten_dict  # 从 flax 导入字典扁平化和还原的工具
from jax.random import PRNGKey  # 从 jax 导入伪随机数生成器的键

from ..utils import logging  # 从父目录导入 logging 模块

logger = logging.get_logger(__name__)  # 创建一个日志记录器,记录当前模块的信息

def rename_key(key):  # 定义一个函数,用于重命名键
    regex = r"\w+[.]\d+"  # 定义一个正则表达式,匹配包含点号和数字的字符串
    pats = re.findall(regex, key)  # 使用正则表达式查找所有匹配的字符串
    for pat in pats:  # 遍历所有找到的匹配
        key = key.replace(pat, "_".join(pat.split(".")))  # 将匹配的字符串中的点替换为下划线
    return key  # 返回修改后的键

#####################
# PyTorch => Flax #
#####################  # 注释区分 PyTorch 到 Flax 的转换部分

# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69  # 说明该函数的来源链接
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py  # 说明该函数的另一来源链接
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):  # 定义函数,重命名权重并在必要时改变张量形状
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""  # 文档字符串,说明函数功能
    # conv norm or layer norm  # 注释,说明即将处理的内容是卷积归一化或层归一化
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)  # 将原键的最后一个元素替换为 "scale"

    # rename attention layers  # 注释,说明将重命名注意力层
    if len(pt_tuple_key) > 1:  # 如果元组键的长度大于 1
        for rename_from, rename_to in (  # 遍历重命名映射的元组
            ("to_out_0", "proj_attn"),  # 旧名称到新名称的映射
            ("to_k", "key"),  # 旧名称到新名称的映射
            ("to_v", "value"),  # 旧名称到新名称的映射
            ("to_q", "query"),  # 旧名称到新名称的映射
        ):
            if pt_tuple_key[-2] == rename_from:  # 如果倒数第二个元素匹配旧名称
                weight_name = pt_tuple_key[-1]  # 获取最后一个元素作为权重名称
                weight_name = "kernel" if weight_name == "weight" else weight_name  # 如果权重名称是 "weight",则改为 "kernel"
                renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)  # 生成新的键
                if renamed_pt_tuple_key in random_flax_state_dict:  # 如果新键存在于状态字典中
                    assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape  # 断言新键的形状与转置的张量形状相同
                    return renamed_pt_tuple_key, pt_tensor.T  # 返回新的键和转置的张量

    if (  # 检查是否满足以下条件
        any("norm" in str_ for str_ in pt_tuple_key)  # 如果键中任何部分包含 "norm"
        and (pt_tuple_key[-1] == "bias")  # 并且最后一个元素是 "bias"
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)  # 并且去掉最后一个元素后加 "bias" 的键不在状态字典中
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)  # 并且去掉最后一个元素后加 "scale" 的键在状态字典中
    ):
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)  # 将键的最后一个元素替换为 "scale"
        return renamed_pt_tuple_key, pt_tensor  # 返回新的键和原张量

    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:  # 如果最后一个元素是 "weight" 或 "gamma" 并且去掉最后一个元素后加 "scale" 的键在状态字典中
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)  # 将键的最后一个元素替换为 "scale"
        return renamed_pt_tuple_key, pt_tensor  # 返回新的键和原张量

    # embedding  # 注释,表明此处将处理嵌入相关的内容
    # 检查元组的最后一个元素是否为 "weight",并且在字典中查找相应的 "embedding" 键
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        # 将元组的最后一个元素替换为 "embedding"
        pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        # 返回更新后的元组键和张量
        return renamed_pt_tuple_key, pt_tensor

    # 卷积层处理
    # 更新元组的最后一个元素为 "kernel"
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    # 检查元组的最后一个元素是否为 "weight",并且张量的维度是否为 4
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        # 转置张量的维度顺序
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        # 返回更新后的元组键和张量
        return renamed_pt_tuple_key, pt_tensor

    # 线性层处理
    # 更新元组的最后一个元素为 "kernel"
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    # 检查元组的最后一个元素是否为 "weight"
    if pt_tuple_key[-1] == "weight":
        # 转置张量
        pt_tensor = pt_tensor.T
        # 返回更新后的元组键和张量
        return renamed_pt_tuple_key, pt_tensor

    # 旧版 PyTorch 层归一化权重处理
    # 更新元组的最后一个元素为 "weight"
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    # 检查元组的最后一个元素是否为 "gamma"
    if pt_tuple_key[-1] == "gamma":
        # 返回更新后的元组键和张量
        return renamed_pt_tuple_key, pt_tensor

    # 旧版 PyTorch 层归一化偏置处理
    # 更新元组的最后一个元素为 "bias"
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    # 检查元组的最后一个元素是否为 "beta"
    if pt_tuple_key[-1] == "beta":
        # 返回更新后的元组键和张量
        return renamed_pt_tuple_key, pt_tensor

    # 如果没有匹配的条件,则返回原始元组键和张量
    return pt_tuple_key, pt_tensor
# 将 PyTorch 的状态字典转换为 Flax 模型的参数字典
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    # 步骤 1:将 PyTorch 张量转换为 NumPy 数组
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # 步骤 2:由于模型是无状态的,使用随机种子初始化 Flax 参数
    random_flax_params = flax_model.init_weights(PRNGKey(init_key))

    # 将随机生成的 Flax 参数展平为字典形式
    random_flax_state_dict = flatten_dict(random_flax_params)
    # 初始化一个空的 Flax 状态字典
    flax_state_dict = {}

    # 需要修改一些参数名称以匹配 Flax 的命名
    for pt_key, pt_tensor in pt_state_dict.items():
        # 重命名 PyTorch 的键
        renamed_pt_key = rename_key(pt_key)
        # 将重命名后的键分割成元组形式
        pt_tuple_key = tuple(renamed_pt_key.split("."))

        # 正确重命名权重参数并调整形状
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        # 检查 Flax 键是否在随机生成的状态字典中
        if flax_key in random_flax_state_dict:
            # 如果形状不匹配,抛出错误
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # 也将意外的权重添加到字典中,以便引发警告
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    # 返回解压缩后的 Flax 状态字典
    return unflatten_dict(flax_state_dict)

.\diffusers\models\modeling_flax_utils.py

# 指定文件编码为 UTF-8
# coding=utf-8
# 版权声明,表示文件由 HuggingFace Inc. 团队拥有
# Copyright 2024 The HuggingFace Inc. team.
#
# 根据 Apache 2.0 许可证许可本文件,使用时需遵循该许可证
# Licensed under the Apache License, Version 2.0 (the "License");
# 只能在遵循许可证的前提下使用此文件
# you may not use this file except in compliance with the License.
# 可以在此网址获取许可证副本
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有约定,软件按“原样”提供
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的明示或暗示的保证或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取特定语言管理权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入操作系统模块
import os
# 导入反序列化错误类
from pickle import UnpicklingError
# 导入类型提示所需的 Any, Dict, Union 类型
from typing import Any, Dict, Union

# 导入 JAX 库及其 NumPy 子模块
import jax
import jax.numpy as jnp
# 导入 msgpack 异常
import msgpack.exceptions
# 从 flax 库导入冻结字典及其解冻方法
from flax.core.frozen_dict import FrozenDict, unfreeze
# 从 flax 库导入字节序列化与反序列化方法
from flax.serialization import from_bytes, to_bytes
# 从 flax 库导入字典扁平化与解扁平化方法
from flax.traverse_util import flatten_dict, unflatten_dict
# 从 huggingface_hub 导入创建仓库和下载方法
from huggingface_hub import create_repo, hf_hub_download
# 导入 huggingface_hub 的一些异常类
from huggingface_hub.utils import (
    EntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    validate_hf_hub_args,
)
# 导入请求库中的 HTTP 错误类
from requests import HTTPError

# 导入当前包的版本和 PyTorch 可用性检查
from .. import __version__, is_torch_available
# 导入工具函数和常量
from ..utils import (
    CONFIG_NAME,
    FLAX_WEIGHTS_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    WEIGHTS_NAME,
    PushToHubMixin,
    logging,
)
# 从模型转换工具中导入 PyTorch 状态字典转换为 Flax 的方法
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义 FlaxModelMixin 类,继承自 PushToHubMixin
class FlaxModelMixin(PushToHubMixin):
    r"""
    所有 Flax 模型的基类。

    [`FlaxModelMixin`] 负责存储模型配置,并提供加载、下载和保存模型的方法。

        - **config_name** ([`str`]) -- 调用 [`~FlaxModelMixin.save_pretrained`] 时保存模型的文件名。
    """

    # 配置文件名常量,指定模型配置文件名
    config_name = CONFIG_NAME
    # 自动保存的参数列表
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    # Flax 内部参数列表
    _flax_internal_args = ["name", "parent", "dtype"]

    # 类方法,用于根据配置创建模型实例
    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        模型初始化所需的上下文管理器在这里定义。
        """
        # 返回类的实例,传入配置和其他参数
        return cls(config, **kwargs)
    # 定义一个方法,将给定参数的浮点值转换为指定的数据类型
    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        # 帮助方法,用于将给定 PyTree 中的浮点值转换为给定的数据类型
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """
    
        # 条件转换函数,判断参数类型并执行转换
        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            # 检查参数是否为浮点类型的数组
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                # 将数组转换为指定的数据类型
                param = param.astype(dtype)
            # 返回转换后的参数
            return param
    
        # 如果没有提供掩码,则对所有参数应用条件转换
        if mask is None:
            # 使用 jax.tree_map 对参数树中的每个元素应用条件转换
            return jax.tree_map(conditional_cast, params)
    
        # 扁平化参数字典以便处理
        flat_params = flatten_dict(params)
        # 扁平化掩码,并丢弃结构信息
        flat_mask, _ = jax.tree_flatten(mask)
    
        # 遍历掩码和参数的扁平化键
        for masked, key in zip(flat_mask, flat_params.keys()):
            # 如果掩码为真,则执行转换
            if masked:
                param = flat_params[key]
                # 将转换后的参数重新存储回扁平化参数字典中
                flat_params[key] = conditional_cast(param)
    
        # 将扁平化的参数字典转换回原始结构
        return unflatten_dict(flat_params)
    
    # 定义一个方法,将参数转换为 bfloat16 类型
    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        # 将浮点参数转换为 jax.numpy.bfloat16,返回新的参数树
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.
    
        This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
    
        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
                for params you want to cast, and `False` for those you want to skip.
    
        Examples:
    
        ```python
        >>> from diffusers import FlaxUNet2DConditionModel
    
        >>> # load model
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> params = model.to_bf16(params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util
    
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_bf16(params, mask)
        ```py"""
        # 调用内部方法,将参数转换为 bfloat16 类型
        return self._cast_floating_to(params, jnp.bfloat16, mask)
    # 将模型参数转换为浮点32位格式的方法
    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r""" 
        将浮点数 `params` 转换为 `jax.numpy.float32`。此方法可用于显式将模型参数转换为 fp32 精度。
        返回一个新的 `params` 树,而不在原地转换 `params`。
    
        参数:
            params (`Union[Dict, FrozenDict]`):
                模型参数的 `PyTree`。
            mask (`Union[Dict, FrozenDict]`):
                与 `params` 树具有相同结构的 `PyTree`。叶子应为布尔值。应为要转换的参数设置为 `True`,为要跳过的参数设置为 `False`。
    
        示例:
    
        ```python
        >>> from diffusers import FlaxUNet2DConditionModel
    
        >>> # 从 huggingface.co 下载模型和配置
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # 默认情况下,模型参数将是 fp32,为了说明此方法的用法,
        >>> # 我们将首先转换为 fp16,然后再转换回 fp32
        >>> params = model.to_f16(params)
        >>> # 现在转换回 fp32
        >>> params = model.to_fp32(params)
        ```py"""
        # 调用私有方法,将参数转换为浮点32格式,传入参数、目标类型和掩码
        return self._cast_floating_to(params, jnp.float32, mask)
    # 定义一个将浮点数参数转换为 float16 的方法,接受参数字典和可选的掩码
    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        将浮点数 `params` 转换为 `jax.numpy.float16`。该方法返回一个新的 `params` 树,不会在原地转换 `params`。

        此方法可在 GPU 上使用,显式地将模型参数转换为 float16 精度,以进行全半精度训练,或将权重保存为 float16 以便推理,从而节省内存并提高速度。

        参数:
            params (`Union[Dict, FrozenDict]`):
                一个模型参数的 `PyTree`。
            mask (`Union[Dict, FrozenDict]`):
                具有与 `params` 树相同结构的 `PyTree`。叶子节点应为布尔值。对于要转换的参数,应为 `True`,而要跳过的参数应为 `False`。

        示例:

        ```python
        >>> from diffusers import FlaxUNet2DConditionModel

        >>> # 加载模型
        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> # 默认情况下,模型参数将为 fp32,转换为 float16
        >>> params = model.to_fp16(params)
        >>> # 如果你不想转换某些参数(例如层归一化的偏差和尺度)
        >>> # 则可以按如下方式传递掩码
        >>> from flax import traverse_util

        >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> flat_params = traverse_util.flatten_dict(params)
        >>> mask = {
        ...     path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> params = model.to_fp16(params, mask)
        ```py"""
        # 调用内部方法将参数转换为 float16 类型,传入可选的掩码
        return self._cast_floating_to(params, jnp.float16, mask)

    # 定义一个初始化权重的方法,接受随机数生成器作为参数,返回字典
    def init_weights(self, rng: jax.Array) -> Dict:
        # 抛出未实现的错误,提示此方法需要被实现
        raise NotImplementedError(f"init_weights method has to be implemented for {self}")

    # 定义一个类方法用于从预训练模型加载参数,接受模型名称或路径等参数
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs,
    # 定义一个保存预训练模型的方法,接受保存目录和参数等
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params: Union[Dict, FrozenDict],
        is_main_process: bool = True,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        保存模型及其配置文件到指定目录,以便使用
        [`~FlaxModelMixin.from_pretrained`] 类方法重新加载。

        参数:
            save_directory (`str` 或 `os.PathLike`):
                保存模型及其配置文件的目录。如果目录不存在,将会被创建。
            params (`Union[Dict, FrozenDict]`):
                模型参数的 `PyTree`。
            is_main_process (`bool`, *可选*, 默认为 `True`):
                调用此函数的进程是否为主进程。在分布式训练中非常有用,
                需要在所有进程上调用此函数。此时,仅在主进程上将 `is_main_process=True`
                以避免竞争条件。
            push_to_hub (`bool`, *可选*, 默认为 `False`):
                保存模型后是否将其推送到 Hugging Face 模型库。可以使用 `repo_id`
                指定要推送到的库(默认为 `save_directory` 中的名称)。
            kwargs (`Dict[str, Any]`, *可选*):
                额外的关键字参数,将传递给 [`~utils.PushToHubMixin.push_to_hub`] 方法。
        """
        # 检查提供的路径是否为文件,如果是则记录错误并返回
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        # 如果目录不存在则创建该目录
        os.makedirs(save_directory, exist_ok=True)

        # 如果需要推送到模型库
        if push_to_hub:
            # 从关键字参数中弹出提交信息,如果没有则为 None
            commit_message = kwargs.pop("commit_message", None)
            # 从关键字参数中弹出隐私设置,默认为 False
            private = kwargs.pop("private", False)
            # 从关键字参数中弹出创建 PR 的设置,默认为 False
            create_pr = kwargs.pop("create_pr", False)
            # 从关键字参数中弹出 token,默认为 None
            token = kwargs.pop("token", None)
            # 从关键字参数中弹出 repo_id,默认为 save_directory 的最后一部分
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            # 创建库并获取 repo_id
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        # 将当前对象赋值给 model_to_save
        model_to_save = self

        # 将模型架构附加到配置中
        # 保存配置
        if is_main_process:
            # 如果是主进程,保存模型配置到指定目录
            model_to_save.save_config(save_directory)

        # 保存模型的输出文件路径
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        # 以二进制写入模式打开模型文件
        with open(output_model_file, "wb") as f:
            # 将模型参数转换为字节
            model_bytes = to_bytes(params)
            # 将字节数据写入文件
            f.write(model_bytes)

        # 记录模型权重保存的路径信息
        logger.info(f"Model weights saved in {output_model_file}")

        # 如果需要推送到模型库
        if push_to_hub:
            # 调用上传文件夹的方法,将模型文件夹推送到模型库
            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

.\diffusers\models\modeling_outputs.py

# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass

# 从上级目录的 utils 模块导入 BaseOutput 类
from ..utils import BaseOutput


# 定义 AutoencoderKLOutput 类,继承自 BaseOutput
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    AutoencoderKL 编码方法的输出。

    参数:
        latent_dist (`DiagonalGaussianDistribution`):
            编码器的输出,以 `DiagonalGaussianDistribution` 的均值和对数方差表示。
            `DiagonalGaussianDistribution` 允许从分布中采样潜在变量。
    """

    # 定义 latent_dist 属性,类型为 DiagonalGaussianDistribution
    latent_dist: "DiagonalGaussianDistribution"  # noqa: F821


# 定义 Transformer2DModelOutput 类,继承自 BaseOutput
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    [`Transformer2DModel`] 的输出。

    参数:
        sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)` 或 `(batch size, num_vector_embeds - 1, num_latent_pixels)` 如果 [`Transformer2DModel`] 是离散的):
            基于 `encoder_hidden_states` 输入的隐藏状态输出。如果是离散的,则返回无噪声潜在像素的概率分布。
    """

    # 定义 sample 属性,类型为 torch.Tensor
    sample: "torch.Tensor"  # noqa: F821

.\diffusers\models\modeling_pytorch_flax_utils.py

# 指定文件编码为 UTF-8
# coding=utf-8
# 版权所有 2024 The HuggingFace Inc. 团队。
#
# 根据 Apache 许可证版本 2.0("许可证")许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,软件在许可证下以“原样”方式分发,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 有关许可证下的特定权限和限制,请参见许可证。
"""PyTorch - Flax 一般实用工具。"""

# 从 pickle 模块导入 UnpicklingError 异常
from pickle import UnpicklingError

# 导入 jax 库及其 numpy 模块
import jax
import jax.numpy as jnp
# 导入 numpy 库
import numpy as np
# 从 flax.serialization 导入 from_bytes 函数
from flax.serialization import from_bytes
# 从 flax.traverse_util 导入 flatten_dict 函数
from flax.traverse_util import flatten_dict

# 从 utils 模块导入 logging
from ..utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

#####################
# Flax => PyTorch #
#####################

# 从指定模型文件加载 Flax 检查点到 PyTorch 模型
# 来源:https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
    # 尝试打开模型文件以读取 Flax 状态
    try:
        with open(model_file, "rb") as flax_state_f:
            # 从字节流中反序列化 Flax 状态
            flax_state = from_bytes(None, flax_state_f.read())
    # 捕获反序列化错误
    except UnpicklingError as e:
        try:
            # 以文本模式打开模型文件
            with open(model_file) as f:
                # 检查文件内容是否以 "version" 开头
                if f.read().startswith("version"):
                    # 如果是,抛出 OSError,提示缺少 git-lfs
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please"
                        " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                        " folder you cloned."
                    )
                else:
                    # 否则,抛出 ValueError
                    raise ValueError from e
        # 捕获 Unicode 解码错误和其他值错误
        except (UnicodeDecodeError, ValueError):
            # 抛出环境错误,提示无法转换文件
            raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")

    # 返回加载的 Flax 权重到 PyTorch 模型
    return load_flax_weights_in_pytorch_model(pt_model, flax_state)

# 从 Flax 状态加载权重到 PyTorch 模型
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """将 Flax 检查点加载到 PyTorch 模型中"""

    # 尝试导入 PyTorch
    try:
        import torch  # noqa: F401
    # 捕获导入错误
    except ImportError:
        # 记录错误信息,提示需要安装 PyTorch 和 Flax
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        # 抛出异常
        raise

    # 检查是否存在 bf16 权重
    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    # 如果存在 bf16 类型的权重
    if any(is_type_bf16):
        # 如果权重是 bf16 类型,转换为 fp32,因为 torch.from_numpy 无法处理 bf16
        
        # 而且 bf16 在 PyTorch 中尚未完全支持。
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        # 使用 tree_map 遍历 flax_state,将 bf16 权重转换为 float32
        flax_state = jax.tree_util.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    # 将基础模型前缀设为空
    pt_model.base_model_prefix = ""

    # 将 flax_state 字典扁平化,使用 "." 作为分隔符
    flax_state_dict = flatten_dict(flax_state, sep=".")
    # 获取 PyTorch 模型的状态字典
    pt_model_dict = pt_model.state_dict()

    # 记录意外和缺失的键
    unexpected_keys = []  # 存储意外键
    missing_keys = set(pt_model_dict.keys())  # 存储缺失键的集合

    # 遍历 flax_state_dict 中的每个键值对
    for flax_key_tuple, flax_tensor in flax_state_dict.items():
        # 将键元组转换为数组形式
        flax_key_tuple_array = flax_key_tuple.split(".")

        # 如果键的最后一个元素是 "kernel" 且张量维度为 4
        if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
            # 将最后一个元素替换为 "weight",并调整张量的维度顺序
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
        # 如果键的最后一个元素是 "kernel"
        elif flax_key_tuple_array[-1] == "kernel":
            # 将最后一个元素替换为 "weight",并转置张量
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
            flax_tensor = flax_tensor.T
        # 如果键的最后一个元素是 "scale"
        elif flax_key_tuple_array[-1] == "scale":
            # 将最后一个元素替换为 "weight"
            flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]

        # 如果键数组中不包含 "time_embedding"
        if "time_embedding" not in flax_key_tuple_array:
            # 遍历键数组,替换下划线为点
            for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
                flax_key_tuple_array[i] = (
                    flax_key_tuple_string.replace("_0", ".0")
                    .replace("_1", ".1")
                    .replace("_2", ".2")
                    .replace("_3", ".3")
                    .replace("_4", ".4")
                    .replace("_5", ".5")
                    .replace("_6", ".6")
                    .replace("_7", ".7")
                    .replace("_8", ".8")
                    .replace("_9", ".9")
                )

        # 将键数组重新连接为字符串
        flax_key = ".".join(flax_key_tuple_array)

        # 如果当前键在 PyTorch 模型的字典中
        if flax_key in pt_model_dict:
            # 如果权重形状不匹配,抛出错误
            if flax_tensor.shape != pt_model_dict[flax_key].shape:
                raise ValueError(
                    f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
                    f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )
            else:
                # 将权重添加到 PyTorch 字典中
                flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
                pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
                # 从缺失键中移除当前键
                missing_keys.remove(flax_key)
        else:
            # 权重不是 PyTorch 模型所期望的
            unexpected_keys.append(flax_key)

    # 将状态字典加载到 PyTorch 模型中
    pt_model.load_state_dict(pt_model_dict)

    # 将缺失键重新转换为列表
    # 将 missing_keys 转换为列表,以便后续处理
    missing_keys = list(missing_keys)

    # 检查 unexpected_keys 的长度,如果大于 0,表示有未使用的权重
    if len(unexpected_keys) > 0:
        # 记录警告信息,提示某些权重未被使用
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    # 检查 missing_keys 的长度,如果大于 0,表示有权重未被初始化
    if len(missing_keys) > 0:
        # 记录警告信息,提示某些权重是新初始化的
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )

    # 返回初始化后的 PyTorch 模型
    return pt_model

.\diffusers\models\modeling_utils.py

# coding=utf-8  # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team.  # HuggingFace Inc. 团队的版权声明
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.  # NVIDIA 的版权声明
#
# Licensed under the Apache License, Version 2.0 (the "License");  # 指定此文件使用 Apache 2.0 许可证
# you may not use this file except in compliance with the License.  # 使用此文件需要遵循许可证的规定
# You may obtain a copy of the License at  # 可以在以下网址获取许可证
#
#     http://www.apache.org/licenses/LICENSE-2.0  # 许可证的具体链接
#
# Unless required by applicable law or agreed to in writing, software  # 除非法律要求或书面同意
# distributed under the License is distributed on an "AS IS" BASIS,  # 否则按 "现状" 基础分发软件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  # 不提供任何形式的担保或条件
# See the License for the specific language governing permissions and  # 参见许可证了解特定权限和限制
# limitations under the License.  # 以及许可证下的限制

import inspect  # 导入 inspect 模块,用于获取对象的信息
import itertools  # 导入 itertools 模块,提供高效的迭代器
import json  # 导入 json 模块,用于 JSON 数据的解析和生成
import os  # 导入 os 模块,提供与操作系统交互的功能
import re  # 导入 re 模块,提供正则表达式操作
from collections import OrderedDict  # 从 collections 导入有序字典
from functools import partial  # 从 functools 导入部分函数应用工具
from pathlib import Path  # 从 pathlib 导入路径处理工具
from typing import Any, Callable, List, Optional, Tuple, Union  # 导入类型注解支持

import safetensors  # 导入 safetensors 库,处理安全的张量
import torch  # 导入 PyTorch 库
from huggingface_hub import create_repo, split_torch_state_dict_into_shards  # 从 huggingface_hub 导入相关功能
from huggingface_hub.utils import validate_hf_hub_args  # 导入验证 Hugging Face Hub 参数的工具
from torch import Tensor, nn  # 从 torch 导入 Tensor 和神经网络模块

from .. import __version__  # 从父级模块导入当前版本
from ..utils import (  # 从父级模块的 utils 导入多个工具
    CONFIG_NAME,  # 配置文件名常量
    FLAX_WEIGHTS_NAME,  # Flax 权重文件名常量
    SAFE_WEIGHTS_INDEX_NAME,  # 安全权重索引文件名常量
    SAFETENSORS_WEIGHTS_NAME,  # Safetensors 权重文件名常量
    WEIGHTS_INDEX_NAME,  # 权重索引文件名常量
    WEIGHTS_NAME,  # 权重文件名常量
    _add_variant,  # 导入添加变体的工具
    _get_checkpoint_shard_files,  # 导入获取检查点分片文件的工具
    _get_model_file,  # 导入获取模型文件的工具
    deprecate,  # 导入弃用标记的工具
    is_accelerate_available,  # 导入检测加速库可用性的工具
    is_torch_version,  # 导入检测 PyTorch 版本的工具
    logging,  # 导入日志记录工具
)
from ..utils.hub_utils import (  # 从父级模块的 hub_utils 导入多个工具
    PushToHubMixin,  # 导入用于推送到 Hub 的混合类
    load_or_create_model_card,  # 导入加载或创建模型卡的工具
    populate_model_card,  # 导入填充模型卡的工具
)
from .model_loading_utils import (  # 从当前包的 model_loading_utils 导入多个工具
    _determine_device_map,  # 导入确定设备映射的工具
    _fetch_index_file,  # 导入获取索引文件的工具
    _load_state_dict_into_model,  # 导入将状态字典加载到模型中的工具
    load_model_dict_into_meta,  # 导入将模型字典加载到元数据中的工具
    load_state_dict,  # 导入加载状态字典的工具
)

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")  # 编译正则表达式,用于匹配分片文件名

if is_torch_version(">=", "1.9.0"):  # 检查当前 PyTorch 版本是否大于等于 1.9.0
    _LOW_CPU_MEM_USAGE_DEFAULT = True  # 设置低 CPU 内存使用默认值为 True
else:  # 如果 PyTorch 版本小于 1.9.0
    _LOW_CPU_MEM_USAGE_DEFAULT = False  # 设置低 CPU 内存使用默认值为 False

if is_accelerate_available():  # 检查加速库是否可用
    import accelerate  # 如果可用,则导入 accelerate 库

def get_parameter_device(parameter: torch.nn.Module) -> torch.device:  # 定义获取模型参数设备的函数
    try:  # 尝试执行以下代码
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())  # 合并模型参数和缓冲区
        return next(parameters_and_buffers).device  # 返回第一个参数或缓冲区的设备
    except StopIteration:  # 如果没有参数和缓冲区
        # For torch.nn.DataParallel compatibility in PyTorch 1.5  # 为兼容 PyTorch 1.5 的 DataParallel

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:  # 定义查找张量属性的内部函数
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]  # 获取模块中所有张量属性
            return tuples  # 返回张量属性的列表

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)  # 获取模型的命名成员生成器
        first_tuple = next(gen)  # 获取生成器中的第一个元组
        return first_tuple[1].device  # 返回第一个张量的设备

def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:  # 定义获取模型参数数据类型的函数
    try:  # 尝试执行以下代码
        params = tuple(parameter.parameters())  # 将模型参数转换为元组
        if len(params) > 0:  # 如果参数数量大于零
            return params[0].dtype  # 返回第一个参数的数据类型

        buffers = tuple(parameter.buffers())  # 将缓冲区转换为元组
        if len(buffers) > 0:  # 如果缓冲区数量大于零
            return buffers[0].dtype  # 返回第一个缓冲区的数据类型
    # 捕获 StopIteration 异常,处理迭代器停止的情况
    except StopIteration:
        # 为了兼容 PyTorch 1.5 中的 torch.nn.DataParallel

        # 定义一个函数,用于查找模块中所有的张量属性,返回属性名和张量的元组列表
        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            # 生成一个元组列表,包含模块中所有张量属性的名称和对应的张量
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            # 返回元组列表
            return tuples

        # 使用指定的函数获取模块的命名成员生成器
        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        # 获取生成器中的第一个元组
        first_tuple = next(gen)
        # 返回第一个张量的 dtype(数据类型)
        return first_tuple[1].dtype
# 定义一个模型混合类,继承自 PyTorch 的 nn.Module 和 PushToHubMixin
class ModelMixin(torch.nn.Module, PushToHubMixin):
    r"""
    所有模型的基类。

    [`ModelMixin`] 负责存储模型配置,并提供加载、下载和保存模型的方法。

        - **config_name** ([`str`]) -- 保存模型时的文件名,调用 [`~models.ModelMixin.save_pretrained`]。
    """

    # 配置名称,作为模型保存时的文件名
    config_name = CONFIG_NAME
    # 自动保存的参数列表
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    # 是否支持梯度检查点
    _supports_gradient_checkpointing = False
    # 加载时忽略的意外键
    _keys_to_ignore_on_load_unexpected = None
    # 不分割的模块
    _no_split_modules = None

    # 初始化方法
    def __init__(self):
        # 调用父类的初始化方法
        super().__init__()

    # 重写 getattr 方法以优雅地弃用直接访问配置属性
    def __getattr__(self, name: str) -> Any:
        """重写 `getattr` 的唯一原因是优雅地弃用直接访问配置属性。
        参见 https://github.com/huggingface/diffusers/pull/3129 需要在这里重写
        __getattr__,以免触发 `torch.nn.Module` 的 __getattr__:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        # 检查属性是否在内部字典中,并且是否存在于内部字典的属性中
        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        # 检查属性是否在当前实例的字典中
        is_attribute = name in self.__dict__

        # 如果属性在配置中且不在实例字典中,显示弃用警告
        if is_in_config and not is_attribute:
            # 构建弃用消息
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
            # 调用弃用函数显示警告
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
            # 返回内部字典中的属性值
            return self._internal_dict[name]

        # 调用 PyTorch 的原始 __getattr__ 方法
        return super().__getattr__(name)

    # 定义一个只读属性,检查是否启用了梯度检查点
    @property
    def is_gradient_checkpointing(self) -> bool:
        """
        检查该模型是否启用了梯度检查点。
        """
        # 遍历模型中的所有模块,检查是否有启用梯度检查点的模块
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    # 启用梯度检查点的方法
    def enable_gradient_checkpointing(self) -> None:
        """
        启用当前模型的梯度检查点(在其他框架中可能称为 *激活检查点* 或
        *检查点激活*)。
        """
        # 检查当前模型是否支持梯度检查点
        if not self._supports_gradient_checkpointing:
            # 如果不支持,抛出值错误
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # 应用设置,启用梯度检查点
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    # 禁用梯度检查点的方法
    def disable_gradient_checkpointing(self) -> None:
        """
        禁用当前模型的梯度检查点(在其他框架中可能称为 *激活检查点* 或
        *检查点激活*)。
        """
        # 检查当前模型是否支持梯度检查点
        if self._supports_gradient_checkpointing:
            # 应用设置,禁用梯度检查点
            self.apply(partial(self._set_gradient_checkpointing, value=False))
    # 定义一个设置 npu flash attention 开关的方法,接收布尔值 valid
    def set_use_npu_flash_attention(self, valid: bool) -> None:
        r""" 
        设置 npu flash attention 的开关。
        """
    
        # 定义一个递归设置 npu flash attention 的内部方法,接收一个模块
        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
            # 如果模块有设置 npu flash attention 的方法,则调用它
            if hasattr(module, "set_use_npu_flash_attention"):
                module.set_use_npu_flash_attention(valid)
    
            # 递归遍历模块的所有子模块
            for child in module.children():
                fn_recursive_set_npu_flash_attention(child)
    
        # 遍历当前对象的所有子模块
        for module in self.children():
            # 如果子模块是一个 torch.nn.Module 类型,则调用递归方法
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_npu_flash_attention(module)
    
    # 定义一个启用 npu flash attention 的方法
    def enable_npu_flash_attention(self) -> None:
        r""" 
        启用来自 torch_npu 的 npu flash attention。
        """
        # 调用设置方法,将开关置为 True
        self.set_use_npu_flash_attention(True)
    
    # 定义一个禁用 npu flash attention 的方法
    def disable_npu_flash_attention(self) -> None:
        r""" 
        禁用来自 torch_npu 的 npu flash attention。
        """
        # 调用设置方法,将开关置为 False
        self.set_use_npu_flash_attention(False)
    
    # 定义一个设置内存高效注意力的 xformers 方法,接收布尔值 valid 和可选的注意力操作
    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        # 递归遍历所有子模块。
        # 任何暴露 set_use_memory_efficient_attention_xformers 方法的子模块都会接收到消息
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            # 如果模块有设置内存高效注意力的方法,则调用它
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
    
            # 递归遍历模块的所有子模块
            for child in module.children():
                fn_recursive_set_mem_eff(child)
    
        # 遍历当前对象的所有子模块
        for module in self.children():
            # 如果子模块是一个 torch.nn.Module 类型,则调用递归方法
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)
    # 启用来自 xFormers 的内存高效注意力
        def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
            # 文档字符串,描述该方法的功能和使用示例
            r"""
            Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
    
            When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
            inference. Speed up during training is not guaranteed.
    
            <Tip warning={true}>
    
            ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
            precedent.
    
            </Tip>
    
            Parameters:
                attention_op (`Callable`, *optional*):
                    Override the default `None` operator for use as `op` argument to the
                    [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                    function of xFormers.
    
            Examples:
    
            ```py
            >>> import torch
            >>> from diffusers import UNet2DConditionModel
            >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
    
            >>> model = UNet2DConditionModel.from_pretrained(
            ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
            ... )
            >>> model = model.to("cuda")
            >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
            ```py
            """
            # 设置使用 xFormers 的内存高效注意力,传入可选的注意力操作
            self.set_use_memory_efficient_attention_xformers(True, attention_op)
    
        # 禁用来自 xFormers 的内存高效注意力
        def disable_xformers_memory_efficient_attention(self) -> None:
            # 文档字符串,描述该方法的功能
            r"""
            Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
            """
            # 设置不使用 xFormers 的内存高效注意力
            self.set_use_memory_efficient_attention_xformers(False)
    
        # 保存预训练模型的方法
        def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            is_main_process: bool = True,
            save_function: Optional[Callable] = None,
            safe_serialization: bool = True,
            variant: Optional[str] = None,
            max_shard_size: Union[int, str] = "10GB",
            push_to_hub: bool = False,
            **kwargs,
        @classmethod
        # 类方法,加载预训练模型
        @validate_hf_hub_args
        @classmethod
        def _load_pretrained_model(
            cls,
            model,
            state_dict: OrderedDict,
            resolved_archive_file,
            pretrained_model_name_or_path: Union[str, os.PathLike],
            ignore_mismatched_sizes: bool = False,
        @classmethod
        # 获取对象的构造函数签名参数
        def _get_signature_keys(cls, obj):
            # 获取构造函数的参数字典
            parameters = inspect.signature(obj.__init__).parameters
            # 提取必需的参数
            required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
            # 提取可选参数
            optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
            # 计算期望的模块,排除 'self'
            expected_modules = set(required_parameters.keys()) - {"self"}
    
            return expected_modules, optional_parameters
    
        # 从 transformers 的 modeling_utils.py 修改而来
    # 定义一个私有方法,用于获取在使用 device_map 时不应拆分的模块
    def _get_no_split_modules(self, device_map: str):
        """
        获取模型中在使用 device_map 时不应拆分的模块。我们遍历模块以获取底层的 `_no_split_modules`。
    
        参数:
            device_map (`str`):
                设备映射值。选项包括 ["auto", "balanced", "balanced_low_0", "sequential"]
    
        返回:
            `List[str]`: 不应拆分的模块列表
        """
        # 初始化一个集合,用于存储不应拆分的模块
        _no_split_modules = set()
        # 将当前对象添加到待检查的模块列表中
        modules_to_check = [self]
        # 当待检查模块列表不为空时继续循环
        while len(modules_to_check) > 0:
            # 从待检查列表中弹出最后一个模块
            module = modules_to_check.pop(-1)
            # 如果模块不在不应拆分的模块集合中,检查其子模块
            if module.__class__.__name__ not in _no_split_modules:
                # 如果模块是 ModelMixin 的实例
                if isinstance(module, ModelMixin):
                    # 如果模块的 `_no_split_modules` 属性为 None,抛出异常
                    if module._no_split_modules is None:
                        raise ValueError(
                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
                            "class needs to implement the `_no_split_modules` attribute."
                        )
                    # 否则,将模块的不应拆分模块添加到集合中
                    else:
                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
                # 将当前模块的所有子模块添加到待检查列表中
                modules_to_check += list(module.children())
        # 返回不应拆分模块的列表
        return list(_no_split_modules)
    
    # 定义一个属性,用于获取模块所在的设备
    @property
    def device(self) -> torch.device:
        """
        `torch.device`: 模块所在的设备(假设所有模块参数在同一设备上)。
        """
        # 调用函数获取当前对象的参数设备
        return get_parameter_device(self)
    
    # 定义一个属性,用于获取模块的数据类型
    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: 模块的数据类型(假设所有模块参数具有相同的数据类型)。
        """
        # 调用函数获取当前对象的参数数据类型
        return get_parameter_dtype(self)
    # 定义一个方法,用于获取模块中的参数数量
    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        获取模块中(可训练或非嵌入)参数的数量。
    
        参数:
            only_trainable (`bool`, *可选*, 默认为 `False`):
                是否仅返回可训练参数的数量。
            exclude_embeddings (`bool`, *可选*, 默认为 `False`):
                是否仅返回非嵌入参数的数量。
    
        返回:
            `int`: 参数的数量。
    
        示例:
    
        ```py
        from diffusers import UNet2DConditionModel
    
        model_id = "runwayml/stable-diffusion-v1-5"
        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
        unet.num_parameters(only_trainable=True)
        859520964
        ```py
        """
    
        # 如果排除嵌入参数
        if exclude_embeddings:
            # 获取所有嵌入层的参数名
            embedding_param_names = [
                f"{name}.weight"
                for name, module_type in self.named_modules()
                if isinstance(module_type, torch.nn.Embedding)
            ]
            # 筛选出非嵌入参数
            non_embedding_parameters = [
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            ]
            # 返回所有非嵌入参数的数量(可训练或非可训练)
            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
        else:
            # 返回所有参数的数量(可训练或非可训练)
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
    # 定义一个方法,用于转换过时的注意力块
    def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
        # 初始化一个列表,用于存储过时注意力块的路径
        deprecated_attention_block_paths = []

        # 定义一个递归函数,用于查找过时的注意力块
        def recursive_find_attn_block(name, module):
            # 检查当前模块是否是过时的注意力块
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                # 将找到的模块名称添加到路径列表中
                deprecated_attention_block_paths.append(name)

            # 遍历模块的子模块
            for sub_name, sub_module in module.named_children():
                # 形成完整的子模块名称
                sub_name = sub_name if name == "" else f"{name}.{sub_name}"
                # 递归查找子模块
                recursive_find_attn_block(sub_name, sub_module)

        # 从当前对象开始递归查找过时的注意力块
        recursive_find_attn_block("", self)

        # 注意:需要检查过时参数是否在状态字典中
        # 因为可能加载的是已经转换过的状态字典

        # 遍历所有找到的过时注意力块路径
        for path in deprecated_attention_block_paths:
            # group_norm 路径保持不变

            # 将 query 参数转换为 to_q
            if f"{path}.query.weight" in state_dict:
                state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
            if f"{path}.query.bias" in state_dict:
                state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")

            # 将 key 参数转换为 to_k
            if f"{path}.key.weight" in state_dict:
                state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
            if f"{path}.key.bias" in state_dict:
                state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")

            # 将 value 参数转换为 to_v
            if f"{path}.value.weight" in state_dict:
                state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
            if f"{path}.value.bias" in state_dict:
                state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")

            # 将 proj_attn 参数转换为 to_out.0
            if f"{path}.proj_attn.weight" in state_dict:
                state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
            if f"{path}.proj_attn.bias" in state_dict:
                state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
    # 将当前对象的注意力模块转换为已弃用的注意力块
    def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
        # 初始化一个列表,用于存储已弃用的注意力块模块
        deprecated_attention_block_modules = []
    
        # 定义递归函数以查找注意力块模块
        def recursive_find_attn_block(module):
            # 检查模块是否为已弃用的注意力块
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                # 将找到的模块添加到列表中
                deprecated_attention_block_modules.append(module)
    
            # 遍历子模块并递归调用
            for sub_module in module.children():
                recursive_find_attn_block(sub_module)
    
        # 从当前对象开始递归查找
        recursive_find_attn_block(self)
    
        # 遍历所有已弃用的注意力块模块
        for module in deprecated_attention_block_modules:
            # 将新属性赋值给相应的旧属性
            module.query = module.to_q
            module.key = module.to_k
            module.value = module.to_v
            module.proj_attn = module.to_out[0]
    
            # 删除旧属性以确保所有权重都加载到新属性中
            del module.to_q
            del module.to_k
            del module.to_v
            del module.to_out
    
    # 将已弃用的注意力块模块恢复为当前对象的注意力模块
    def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
        # 初始化一个列表,用于存储已弃用的注意力块模块
        deprecated_attention_block_modules = []
    
        # 定义递归函数以查找注意力块模块
        def recursive_find_attn_block(module) -> None:
            # 检查模块是否为已弃用的注意力块
            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
                # 将找到的模块添加到列表中
                deprecated_attention_block_modules.append(module)
    
            # 遍历子模块并递归调用
            for sub_module in module.children():
                recursive_find_attn_block(sub_module)
    
        # 从当前对象开始递归查找
        recursive_find_attn_block(self)
    
        # 遍历所有已弃用的注意力块模块
        for module in deprecated_attention_block_modules:
            # 将旧属性赋值给相应的新属性
            module.to_q = module.query
            module.to_k = module.key
            module.to_v = module.value
            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
    
            # 删除新属性以恢复旧的模块结构
            del module.query
            del module.key
            del module.value
            del module.proj_attn
# 定义一个继承自 ModelMixin 的类,用于处理从旧类到特定管道类的映射
class LegacyModelMixin(ModelMixin):
    r"""
    一个 `ModelMixin` 的子类,用于从旧类(如 `Transformer2DModel`)解析到更具体的管道类(如 `DiTTransformer2DModel`)的类映射。
    """

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        # 为了避免依赖导入问题
        from .model_loading_utils import _fetch_remapped_cls_from_config

        # 创建 kwargs 的副本,以避免对后续调用中的关键字参数造成影响
        kwargs_copy = kwargs.copy()

        # 从 kwargs 中提取 cache_dir 参数,若未提供则为 None
        cache_dir = kwargs.pop("cache_dir", None)
        # 从 kwargs 中提取 force_download 参数,默认为 False
        force_download = kwargs.pop("force_download", False)
        # 从 kwargs 中提取 proxies 参数,默认为 None
        proxies = kwargs.pop("proxies", None)
        # 从 kwargs 中提取 local_files_only 参数,默认为 None
        local_files_only = kwargs.pop("local_files_only", None)
        # 从 kwargs 中提取 token 参数,默认为 None
        token = kwargs.pop("token", None)
        # 从 kwargs 中提取 revision 参数,默认为 None
        revision = kwargs.pop("revision", None)
        # 从 kwargs 中提取 subfolder 参数,默认为 None
        subfolder = kwargs.pop("subfolder", None)

        # 如果未提供配置,则将配置路径设置为预训练模型名称或路径
        config_path = pretrained_model_name_or_path

        # 设置用户代理信息
        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # 加载配置
        config, _, _ = cls.load_config(
            config_path,
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            **kwargs,
        )
        # 解析类的映射
        remapped_class = _fetch_remapped_cls_from_config(config, cls)

        # 返回映射后的类的 from_pretrained 方法调用
        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)

.\diffusers\models\model_loading_utils.py

# 指定编码为 UTF-8
# coding=utf-8
# 版权声明,表明此文件的版权归 HuggingFace Inc. 团队所有
# Copyright 2024 The HuggingFace Inc. team.
# 版权声明,表明此文件的版权归 NVIDIA CORPORATION 所有
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# 根据 Apache 许可证第 2.0 版进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 使用此文件必须遵守许可证
# you may not use this file except in compliance with the License.
# 可以在此处获取许可证的副本
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件在 "AS IS" 基础上分发
# Unless required by applicable law or agreed to in writing, software
# 不提供任何明示或暗示的担保或条件
# distributed under the License is distributed on an "AS IS" BASIS,
# 查看许可证以获取特定权限和限制的详细信息
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入标准库中的 importlib 模块
import importlib
# 导入 inspect 模块,用于检查对象
import inspect
# 导入操作系统模块
import os
# 从 collections 导入 OrderedDict,用于保持字典的顺序
from collections import OrderedDict
# 从 pathlib 导入 Path,处理文件路径
from pathlib import Path
# 导入 List、Optional 和 Union 类型提示
from typing import List, Optional, Union

# 导入 safetensors 模块
import safetensors
# 导入 PyTorch 库
import torch
# 从 huggingface_hub.utils 导入 EntryNotFoundError 异常
from huggingface_hub.utils import EntryNotFoundError

# 从 utils 模块中导入常量和函数
from ..utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_FILE_EXTENSION,
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_model_file,
    is_accelerate_available,
    is_torch_version,
    logging,
)

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义类重映射字典,将旧类名映射到新类名
_CLASS_REMAPPING_DICT = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",
        "ada_norm_single": "PixArtTransformer2DModel",
    }
}

# 如果可用,导入加速库的相关功能
if is_accelerate_available():
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device

# 根据模型和设备映射确定设备映射
# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
    # 如果 device_map 是字符串,获取不拆分模块
    if isinstance(device_map, str):
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        # 如果 device_map 不是 "sequential",计算平衡内存
        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        # 否则获取最大内存
        else:
            max_memory = get_max_memory(max_memory)

        # 更新 device_map 参数并推断设备映射
        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)

    # 返回最终的设备映射
    return device_map

# 从配置中获取重映射的类
def _fetch_remapped_cls_from_config(config, old_class):
    # 获取旧类的名称
    previous_class_name = old_class.__name__
    # 根据配置中的 norm_type 查找重映射的类名
    remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)

    # 详细信息:
    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
    # 如果 remapped_class_name 存在
        if remapped_class_name:
            # 加载 diffusers 库以导入兼容的原始调度器
            diffusers_library = importlib.import_module(__name__.split(".")[0])
            # 从 diffusers 库中获取 remapped_class_name 指定的类
            remapped_class = getattr(diffusers_library, remapped_class_name)
            # 记录日志,说明类对象正在更改,因之前的类将在未来版本中弃用
            logger.info(
                f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
                f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
                " DOESN'T affect the final results."
            )
            # 返回映射后的类
            return remapped_class
        else:
            # 如果没有 remapped_class_name,返回旧类
            return old_class
# 定义一个函数,用于加载检查点文件,返回格式化的错误信息(如有)
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    读取检查点文件,如果出现错误,则返回正确格式的错误信息。
    """
    try:
        # 获取检查点文件名的扩展名
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        # 如果文件扩展名是 SAFETENSORS_FILE_EXTENSION,则使用 safetensors 加载文件
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        else:
            # 检查 PyTorch 版本,如果大于等于 1.13,则设置 weights_only 参数
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            # 加载检查点文件,并将模型权重映射到 CPU
            return torch.load(
                checkpoint_file,
                map_location="cpu",
                **weights_only_kwarg,
            )
    except Exception as e:
        try:
            # 尝试打开检查点文件
            with open(checkpoint_file) as f:
                # 检查文件是否以 "version" 开头,以确定是否缺少 git-lfs
                if f.read().startswith("version"):
                    raise OSError(
                        "您似乎克隆了一个没有安装 git-lfs 的库。请安装 "
                        "git-lfs 并在克隆的文件夹中运行 `git lfs install` 以及 `git lfs pull`。"
                    )
                else:
                    # 如果文件不存在,抛出 ValueError
                    raise ValueError(
                        f"无法找到加载此预训练模型所需的文件 {checkpoint_file}。请确保已正确保存模型。"
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # 如果读取文件时出现错误,抛出 OSError
            raise OSError(
                f"无法从检查点文件加载权重 '{checkpoint_file}' " f"在 '{checkpoint_file}'。"
            )


# 定义一个函数,将模型状态字典加载到元数据中
def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
) -> List[str]:
    # 如果未提供设备,则默认使用 CPU
    device = device or torch.device("cpu")
    # 如果未提供数据类型,则默认使用 float32
    dtype = dtype or torch.float32

    # 检查 set_module_tensor_to_device 函数是否接受 dtype 参数
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

    # 初始化一个列表以存储意外的键
    unexpected_keys = []
    # 获取模型的空状态字典
    empty_state_dict = model.state_dict()
    # 遍历状态字典中的每个参数名称和对应的参数值
    for param_name, param in state_dict.items():
        # 如果参数名称不在空状态字典中,则记录为意外的键
        if param_name not in empty_state_dict:
            unexpected_keys.append(param_name)
            continue  # 跳过本次循环,继续下一个参数

        # 检查空状态字典中对应参数的形状是否与当前参数的形状匹配
        if empty_state_dict[param_name].shape != param.shape:
            # 如果模型路径存在,则格式化字符串以包含模型路径
            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
            # 抛出值错误,提示参数形状不匹配,并给出解决方案和参考链接
            raise ValueError(
                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
            )

        # 如果接受数据类型,则将参数设置到模型的指定设备上,并指定数据类型
        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        else:
            # 如果不接受数据类型,则仅将参数设置到模型的指定设备上
            set_module_tensor_to_device(model, param_name, device, value=param)
    # 返回意外的键列表
    return unexpected_keys
# 定义一个函数,将状态字典加载到模型中,并返回错误信息列表
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    # 如果需要,从 PyTorch 的 state_dict 转换旧格式到新格式
    # 复制 state_dict,以便 _load_from_state_dict 可以对其进行修改
    state_dict = state_dict.copy()
    # 用于存储加载过程中的错误信息
    error_msgs = []

    # PyTorch 的 `_load_from_state_dict` 不会复制模块子孙中的参数
    # 所以我们需要递归地应用这个函数
    def load(module: torch.nn.Module, prefix: str = ""):
        # 准备参数,调用模块的 `_load_from_state_dict` 方法
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        # 遍历模块的所有子模块
        for name, child in module._modules.items():
            # 如果子模块存在,递归加载
            if child is not None:
                load(child, prefix + name + ".")

    # 初始调用加载模型
    load(model_to_load)

    # 返回所有错误信息
    return error_msgs


# 定义一个函数,获取索引文件的路径
def _fetch_index_file(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    # 如果是本地文件
    if is_local:
        # 构造索引文件的路径
        index_file = Path(
            pretrained_model_name_or_path,
            subfolder or "",  # 如果子文件夹为空,则使用空字符串
            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
        )
    else:
        # 构造索引文件在远程仓库中的路径
        index_file_in_repo = Path(
            subfolder or "",  # 如果子文件夹为空,则使用空字符串
            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
        ).as_posix()  # 转换为 POSIX 路径格式
        try:
            # 获取模型文件的路径
            index_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=index_file_in_repo,  # 指定权重文件名
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=None,  # 子文件夹为 None
                user_agent=user_agent,
                commit_hash=commit_hash,
            )
            # 将返回的路径转换为 Path 对象
            index_file = Path(index_file)
        except (EntryNotFoundError, EnvironmentError):
            # 如果找不到文件或发生环境错误,将索引文件设置为 None
            index_file = None

    # 返回索引文件的路径
    return index_file

.\diffusers\models\normalization.py

# 指定文件编码为 UTF-8
# copyright 信息,标识版权所有者及年份
# 许可证声明,指明使用的许可证类型及条件
# 提供许可证的获取链接
# 声明在适用情况下,软件是以“原样”方式分发的,且不提供任何形式的担保或条件
# 引用许可证中关于权限和限制的具体条款

# 导入 numbers 模块,用于处理数值相关的操作
from typing import Dict, Optional, Tuple  # 导入类型提示所需的类型

# 导入 PyTorch 相关模块和功能
import torch
import torch.nn as nn  # 导入神经网络模块
import torch.nn.functional as F  # 导入功能性神经网络操作模块

# 导入工具函数以检查 PyTorch 版本
from ..utils import is_torch_version
# 导入激活函数获取方法
from .activations import get_activation
# 导入嵌入层相关类
from .embeddings import (
    CombinedTimestepLabelEmbeddings,
    PixArtAlphaCombinedTimestepSizeEmbeddings,
)


class AdaLayerNorm(nn.Module):  # 定义自定义的层归一化类,继承自 nn.Module
    r"""  # 文档字符串,描述此类的功能和参数
    Norm layer modified to incorporate timestep embeddings.  # 说明此层归一化是为了支持时间步嵌入

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.  # 嵌入向量的维度
        num_embeddings (`int`, *optional*): The size of the embeddings dictionary.  # 嵌入字典的大小(可选)
        output_dim (`int`, *optional*):  # 输出维度(可选)
        norm_elementwise_affine (`bool`, defaults to `False):  # 是否应用元素级仿射变换(默认 False)
        norm_eps (`bool`, defaults to `False`):  # 归一化时的小常数(默认 1e-5)
        chunk_dim (`int`, defaults to `0`):  # 分块维度(默认 0)
    """

    def __init__(  # 初始化方法,定义类的构造函数
        self,
        embedding_dim: int,  # 嵌入维度
        num_embeddings: Optional[int] = None,  # 嵌入字典的大小(可选)
        output_dim: Optional[int] = None,  # 输出维度(可选)
        norm_elementwise_affine: bool = False,  # 是否应用元素级仿射变换
        norm_eps: float = 1e-5,  # 归一化时的小常数
        chunk_dim: int = 0,  # 分块维度
    ):
        super().__init__()  # 调用父类构造函数

        self.chunk_dim = chunk_dim  # 保存分块维度
        output_dim = output_dim or embedding_dim * 2  # 如果未指定输出维度,则计算输出维度

        if num_embeddings is not None:  # 如果指定了嵌入字典大小
            self.emb = nn.Embedding(num_embeddings, embedding_dim)  # 初始化嵌入层
        else:
            self.emb = None  # 嵌入层为 None

        self.silu = nn.SiLU()  # 初始化 SiLU 激活函数
        self.linear = nn.Linear(embedding_dim, output_dim)  # 初始化线性层
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)  # 初始化层归一化

    def forward(  # 定义前向传播方法
        self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None  # 输入张量及可选时间步和嵌入
    ) -> torch.Tensor:  # 返回类型为张量
        if self.emb is not None:  # 如果嵌入层存在
            temb = self.emb(timestep)  # 通过嵌入层计算时间步的嵌入

        temb = self.linear(self.silu(temb))  # 应用激活函数并通过线性层处理嵌入

        if self.chunk_dim == 1:  # 如果分块维度为 1
            # 对于 CogVideoX 的特殊情况,分割嵌入为偏移量和缩放量
            shift, scale = temb.chunk(2, dim=1)  # 按照维度 1 分块
            shift = shift[:, None, :]  # 扩展偏移量维度
            scale = scale[:, None, :]  # 扩展缩放量维度
        else:  # 如果分块维度不是 1
            scale, shift = temb.chunk(2, dim=0)  # 按照维度 0 分块

        x = self.norm(x) * (1 + scale) + shift  # 进行层归一化,并应用缩放和偏移
        return x  # 返回结果


class FP32LayerNorm(nn.LayerNorm):  # 定义 FP32 层归一化类,继承自 nn.LayerNorm
    # 定义前向传播方法,接受输入张量并返回输出张量
        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            # 保存输入张量的数据类型
            origin_dtype = inputs.dtype
            # 进行层归一化处理,并将结果转换回原始数据类型
            return F.layer_norm(
                # 将输入张量转换为浮点型进行归一化
                inputs.float(),
                # 归一化的形状
                self.normalized_shape,
                # 如果权重存在,将其转换为浮点型;否则为 None
                self.weight.float() if self.weight is not None else None,
                # 如果偏置存在,将其转换为浮点型;否则为 None
                self.bias.float() if self.bias is not None else None,
                # 设置一个小的数值以避免除零
                self.eps,
            ).to(origin_dtype)  # 将归一化后的结果转换回原始数据类型
# 定义自适应层归一化零层的类
class AdaLayerNormZero(nn.Module):
    r"""
    自适应层归一化零层 (adaLN-Zero)。

    参数:
        embedding_dim (`int`): 每个嵌入向量的大小。
        num_embeddings (`int`): 嵌入字典的大小。
    """

    # 初始化方法,接收嵌入维度和可选的嵌入数量及归一化类型
    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
        # 调用父类初始化方法
        super().__init__()
        # 如果提供了嵌入数量,初始化嵌入层
        if num_embeddings is not None:
            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
        else:
            # 否则,嵌入层设置为 None
            self.emb = None

        # 初始化 SiLU 激活函数
        self.silu = nn.SiLU()
        # 初始化线性变换层,输出维度为 6 倍的嵌入维度
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        # 根据提供的归一化类型,初始化归一化层
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        elif norm_type == "fp32_layer_norm":
            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
        else:
            # 如果提供了不支持的归一化类型,抛出错误
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )

    # 定义前向传播方法
    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 如果嵌入层不为 None,则计算嵌入
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        # 先经过 SiLU 激活函数再经过线性变换
        emb = self.linear(self.silu(emb))
        # 将嵌入切分为 6 个部分
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        # 对输入 x 应用归一化,并结合缩放和偏移
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        # 返回处理后的 x 及其他信息
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


# 定义自适应层归一化零层单一版本的类
class AdaLayerNormZeroSingle(nn.Module):
    r"""
    自适应层归一化零层 (adaLN-Zero)。

    参数:
        embedding_dim (`int`): 每个嵌入向量的大小。
        num_embeddings (`int`): 嵌入字典的大小。
    """

    # 初始化方法,接收嵌入维度和归一化类型
    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        # 调用父类初始化方法
        super().__init__()

        # 初始化 SiLU 激活函数
        self.silu = nn.SiLU()
        # 初始化线性变换层,输出维度为 3 倍的嵌入维度
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        # 根据提供的归一化类型,初始化归一化层
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            # 如果提供了不支持的归一化类型,抛出错误
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )

    # 定义前向传播方法
    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    # 定义一个函数的返回类型为五个张量的元组
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 通过线性层和Silu激活函数处理嵌入向量
            emb = self.linear(self.silu(emb))
        # 将处理后的嵌入向量分割成三个部分:shift_msa, scale_msa 和 gate_msa
            shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        # 对输入x进行归一化,并结合scale和shift进行变换
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        # 返回变换后的x和gate_msa
            return x, gate_msa
# 定义 LuminaRMSNormZero 类,继承自 nn.Module
class LuminaRMSNormZero(nn.Module):
    """
    Norm layer adaptive RMS normalization zero.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
    """

    # 初始化方法,设置嵌入维度、正则化参数和元素级偏置
    def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
        # 调用父类构造函数
        super().__init__()
        # 初始化 SiLU 激活函数
        self.silu = nn.SiLU()
        # 初始化线性变换层,输入为 embedding_dim 或 1024 中的较小值,输出为 4 倍的 embedding_dim
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )
        # 初始化 RMSNorm 层
        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

    # 前向传播方法
    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 对 emb 应用线性变换和 SiLU 激活
        emb = self.linear(self.silu(emb))
        # 将嵌入分块为四个部分
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        # 对输入 x 应用 RMSNorm 并与 scale_msa 相乘
        x = self.norm(x) * (1 + scale_msa[:, None])

        # 返回处理后的 x 以及门控和缩放值
        return x, gate_msa, scale_mlp, gate_mlp


# 定义 AdaLayerNormSingle 类,继承自 nn.Module
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    # 初始化方法,设置嵌入维度和是否使用额外条件
    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        # 调用父类构造函数
        super().__init__()

        # 初始化 PixArtAlphaCombinedTimestepSizeEmbeddings,用于时间步嵌入
        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        # 初始化 SiLU 激活函数
        self.silu = nn.SiLU()
        # 初始化线性变换层,输出为 6 倍的嵌入维度
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    # 前向传播方法
    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # 嵌入时间步,可能使用额外的条件
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        # 返回线性变换后的嵌入和嵌入结果
        return self.linear(self.silu(embedded_timestep)), embedded_timestep


# 定义 AdaGroupNorm 类,继承自 nn.Module
class AdaGroupNorm(nn.Module):
    r"""
    GroupNorm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """
    # 初始化方法,用于设置类的基本属性
        def __init__(
            # 嵌入向量的维度
            self, embedding_dim: int, 
            # 输出向量的维度
            out_dim: int, 
            # 组的数量
            num_groups: int, 
            # 激活函数名称(可选)
            act_fn: Optional[str] = None, 
            # 防止除零错误的微小值
            eps: float = 1e-5
        ):
            # 调用父类初始化方法
            super().__init__()
            # 设置组的数量
            self.num_groups = num_groups
            # 设置用于数值稳定性的微小值
            self.eps = eps
    
            # 如果没有提供激活函数,则设置为 None
            if act_fn is None:
                self.act = None
            else:
                # 根据激活函数名称获取激活函数
                self.act = get_activation(act_fn)
    
            # 创建一个线性层,将嵌入维度映射到输出维度的两倍
            self.linear = nn.Linear(embedding_dim, out_dim * 2)
    
        # 前向传播方法,定义输入数据的处理方式
        def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
            # 如果存在激活函数,则对嵌入进行激活
            if self.act:
                emb = self.act(emb)
            # 将嵌入传递通过线性层
            emb = self.linear(emb)
            # 扩展嵌入的维度,以适配后续操作
            emb = emb[:, :, None, None]
            # 将嵌入分割为缩放因子和偏移量
            scale, shift = emb.chunk(2, dim=1)
    
            # 对输入数据进行分组归一化
            x = F.group_norm(x, self.num_groups, eps=self.eps)
            # 使用缩放因子和偏移量调整归一化后的数据
            x = x * (1 + scale) + shift
            # 返回处理后的数据
            return x
# 定义一个自定义的神经网络模块,继承自 nn.Module
class AdaLayerNormContinuous(nn.Module):
    # 初始化方法,接受多个参数以配置层的特性
    def __init__(
        self,
        embedding_dim: int,  # 嵌入维度
        conditioning_embedding_dim: int,  # 条件嵌入维度
        # 注释:规范层可以配置缩放和偏移参数有点奇怪,因为输出会被投影的条件嵌入立即缩放和偏移。
        # 注意,AdaLayerNorm 不允许规范层有缩放和偏移参数。
        # 但是这是原始代码中的实现,您应该将 `elementwise_affine` 设置为 False。
        elementwise_affine=True,  # 是否允许元素级的仿射变换
        eps=1e-5,  # 防止除零错误的小值
        bias=True,  # 是否在全连接层中使用偏置
        norm_type="layer_norm",  # 规范化类型
    ):
        super().__init__()  # 调用父类构造函数
        self.silu = nn.SiLU()  # 定义 SiLU 激活函数
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)  # 全连接层,输出两倍嵌入维度
        # 根据指定的规范类型初始化规范层
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)  # 层规范化
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)  # RMS 规范化
        else:
            raise ValueError(f"unknown norm_type {norm_type}")  # 抛出错误,若规范类型未知

    # 前向传播方法,定义如何计算输出
    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # 将条件嵌入转换为与输入 x 相同的数据类型
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))  # 应用激活函数和全连接层
        scale, shift = torch.chunk(emb, 2, dim=1)  # 将输出拆分为缩放和偏移
        # 规范化输入 x,并进行缩放和偏移操作
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]  # 返回处理后的输出
        return x  # 返回最终结果


# 定义另一个自定义的神经网络模块,继承自 nn.Module
class LuminaLayerNormContinuous(nn.Module):
    # 初始化方法,接受多个参数以配置层的特性
    def __init__(
        self,
        embedding_dim: int,  # 嵌入维度
        conditioning_embedding_dim: int,  # 条件嵌入维度
        # 注释:规范层可以配置缩放和偏移参数有点奇怪,因为输出会被投影的条件嵌入立即缩放和偏移。
        # 注意,AdaLayerNorm 不允许规范层有缩放和偏移参数。
        # 但是这是原始代码中的实现,您应该将 `elementwise_affine` 设置为 False。
        elementwise_affine=True,  # 是否允许元素级的仿射变换
        eps=1e-5,  # 防止除零错误的小值
        bias=True,  # 是否在全连接层中使用偏置
        norm_type="layer_norm",  # 规范化类型
        out_dim: Optional[int] = None,  # 可选的输出维度
    ):
        super().__init__()  # 调用父类构造函数
        # AdaLN
        self.silu = nn.SiLU()  # 定义 SiLU 激活函数
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)  # 全连接层,将条件嵌入映射到嵌入维度
        # 根据指定的规范类型初始化规范层
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)  # 层规范化
        else:
            raise ValueError(f"unknown norm_type {norm_type}")  # 抛出错误,若规范类型未知
        # 如果指定了输出维度,则创建第二个全连接层
        if out_dim is not None:
            self.linear_2 = nn.Linear(
                embedding_dim,  # 输入维度为嵌入维度
                out_dim,  # 输出维度
                bias=bias,  # 是否使用偏置
            )

    # 前向传播方法,定义如何计算输出
    def forward(
        self,
        x: torch.Tensor,  # 输入张量
        conditioning_embedding: torch.Tensor,  # 条件嵌入张量
    # 返回一个张量,类型为 torch.Tensor
    ) -> torch.Tensor:
        # 将条件嵌入转换回原始数据类型,以防止其被提升为 float32(用于 hunyuanDiT)
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        # 将嵌入值赋给 scale
        scale = emb
        # 对输入 x 进行规范化,并乘以(1 + scale),同时在新维度上扩展
        x = self.norm(x) * (1 + scale)[:, None, :]
    
        # 如果 linear_2 存在,则对 x 应用 linear_2
        if self.linear_2 is not None:
            x = self.linear_2(x)
    
        # 返回处理后的张量 x
        return x
# 定义一个自定义的层,继承自 nn.Module
class CogVideoXLayerNormZero(nn.Module):
    # 初始化方法,定义该层的参数
    def __init__(
        self,
        conditioning_dim: int,  # 输入的条件维度
        embedding_dim: int,  # 嵌入的维度
        elementwise_affine: bool = True,  # 是否启用逐元素仿射变换
        eps: float = 1e-5,  # 防止除零的一个小常数
        bias: bool = True,  # 是否添加偏置
    ) -> None:
        # 调用父类的初始化方法
        super().__init__()

        # 使用 SiLU 激活函数
        self.silu = nn.SiLU()
        # 线性变换,将条件维度映射到 6 倍的嵌入维度
        self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
        # 归一化层,使用层归一化
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    # 前向传播方法,定义输入和输出
    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 通过线性层处理 temb,并分成 6 个部分
        shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
        # 对隐藏状态进行归一化并应用缩放和平移
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        # 对编码器隐藏状态进行相同处理
        encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
        # 返回处理后的隐藏状态和编码器隐藏状态,以及门控信号
        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]


# 根据 PyTorch 版本决定是否使用标准 LayerNorm
if is_torch_version(">=", "2.1.0"):
    # 使用标准的 LayerNorm
    LayerNorm = nn.LayerNorm
else:
    # 定义自定义的 LayerNorm 类,兼容旧版本 PyTorch
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        # 初始化方法
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            # 调用父类的初始化方法
            super().__init__()

            # 设置小常数以避免除零
            self.eps = eps

            # 如果维度是整数,则转为元组
            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            # 保存维度信息
            self.dim = torch.Size(dim)

            # 如果启用逐元素仿射,则初始化权重和偏置
            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        # 前向传播方法
        def forward(self, input):
            # 应用层归一化
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


# 定义 RMSNorm 类,继承自 nn.Module
class RMSNorm(nn.Module):
    # 初始化方法
    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        # 调用父类的初始化方法
        super().__init__()

        # 设置小常数以避免除零
        self.eps = eps

        # 如果维度是整数,则转为元组
        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        # 保存维度信息
        self.dim = torch.Size(dim)

        # 如果启用逐元素仿射,则初始化权重
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    # 前向传播方法
    def forward(self, hidden_states):
        # 保存输入数据类型
        input_dtype = hidden_states.dtype
        # 计算输入的方差
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # 对隐藏状态进行缩放
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        # 如果有权重,则进行进一步处理
        if self.weight is not None:
            # 如果需要,将隐藏状态转换为半精度
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            # 应用权重
            hidden_states = hidden_states * self.weight
        else:
            # 将隐藏状态转换回原数据类型
            hidden_states = hidden_states.to(input_dtype)

        # 返回处理后的隐藏状态
        return hidden_states
# 定义一个全局响应归一化的类,继承自 nn.Module
class GlobalResponseNorm(nn.Module):
    # 初始化方法,接受一个维度参数 dim
    def __init__(self, dim):
        # 调用父类构造函数
        super().__init__()
        # 初始化可学习参数 gamma,形状为 (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        # 初始化可学习参数 beta,形状为 (1, 1, 1, dim)
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    # 定义前向传播方法,接受输入 x
    def forward(self, x):
        # 计算输入 x 在 (1, 2) 维度上的 L2 范数,保持维度
        gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # 归一化 gx,计算每个样本的均值并防止除以零
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
        # 返回归一化后的结果,加上可学习的 gamma 和 beta
        return self.gamma * (x * nx) + self.beta + x

标签:dim,None,dtype,self,torch,diffusers,源码,model,解析
From: https://www.cnblogs.com/apachecn/p/18492370

相关文章

  • diffusers-源码解析-二十一-
    diffusers源码解析(二十一).\diffusers\pipelines\controlnet\pipeline_controlnet.py#版权信息,指明该代码由HuggingFace团队版权所有##根据Apache2.0许可证授权,用户需遵循许可证规定使用该文件#许可证可以在以下网址获取##http://www.apache.org/licenses/L......
  • diffusers-源码解析-二十四-
    diffusers源码解析(二十四).\diffusers\pipelines\controlnet_sd3\pipeline_stable_diffusion_3_controlnet.py#版权声明,指出版权所有者及相关信息#Copyright2024StabilityAI,TheHuggingFaceTeamandTheInstantXTeam.Allrightsreserved.##按照Apache2.0许可......
  • diffusers-源码解析-二十三-
    diffusers源码解析(二十三).\diffusers\pipelines\controlnet\pipeline_controlnet_sd_xl_img2img.py#版权所有2024HuggingFace团队。保留所有权利。##根据Apache许可证第2.0版(“许可证”)许可;#除非遵守许可证,否则您不得使用此文件。#您可以在以下网址获得许可证副......
  • diffusers-源码解析-二十六-
    diffusers源码解析(二十六).\diffusers\pipelines\deepfloyd_if\pipeline_if_inpainting_superresolution.py#导入html模块,用于处理HTML文本importhtml#导入inspect模块,用于获取对象的信息importinspect#导入re模块,用于正则表达式匹配importre#导入urllib.......
  • diffusers-源码解析-二十九-
    diffusers源码解析(二十九).\diffusers\pipelines\deprecated\stable_diffusion_variants\pipeline_stable_diffusion_model_editing.py#版权信息,声明版权和许可协议#Copyright2024TIMEAuthorsandTheHuggingFaceTeam.Allrightsreserved."#根据ApacheLicense2.0......
  • diffusers-源码解析-十一-
    diffusers源码解析(十一).\diffusers\models\transformers\hunyuan_transformer_2d.py#版权所有2024HunyuanDiT作者,QixunWang和HuggingFace团队。保留所有权利。##根据Apache许可证第2.0版("许可证")进行许可;#除非符合许可证,否则您不得使用此文件。#您可以在以......
  • diffusers-源码解析-十五-
    diffusers源码解析(十五).\diffusers\models\unets\unet_3d_condition.py#版权声明,声明此代码的版权信息和所有权#Copyright2024AlibabaDAMO-VILABandTheHuggingFaceTeam.Allrightsreserved.#版权声明,声明此代码的版权信息和所有权#Copyright2024TheModelSco......
  • diffusers-源码解析-十四-
    diffusers源码解析(十四).\diffusers\models\unets\unet_2d_blocks_flax.py#版权声明,说明该文件的版权信息及相关许可协议#Copyright2024TheHuggingFaceTeam.Allrightsreserved.##许可信息,使用ApacheLicense2.0许可#LicensedundertheApacheLicense,Versi......
  • diffusers-源码解析-十三-
    diffusers源码解析(十三).\diffusers\models\unets\unet_2d.py#版权声明,表示该代码由HuggingFace团队所有##根据Apache2.0许可证进行许可;#除非遵循许可证,否则不得使用此文件。#可以在以下地址获取许可证的副本:##http://www.apache.org/licenses/LICENSE-2.......
  • diffusers-源码解析-五-
    diffusers源码解析(五).\diffusers\models\autoencoders\autoencoder_asym_kl.py#版权声明,标识该文件的所有权和使用条款#Copyright2024TheHuggingFaceTeam.Allrightsreserved.##根据Apache许可证第2.0版(“许可证”)进行授权;#除非遵循许可证,否则您不得使用此文......