
FLUX Source Code Walkthrough (Complete)


.\flux\demo_gr.py

# Import the os module
import os
# Import the time module
import time
# Import BytesIO from the io module
from io import BytesIO
# Import the uuid module, for generating UUIDs
import uuid

# Import PyTorch
import torch
# Import Gradio
import gradio as gr
# Import NumPy
import numpy as np
# Import rearrange from einops
from einops import rearrange
# Import Image and ExifTags from PIL
from PIL import Image, ExifTags
# Import the pipeline factory from transformers
from transformers import pipeline

# Import the SamplingOptions class from flux.cli
from flux.cli import SamplingOptions
# Import the sampling helpers from flux.sampling
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
# Import utility helpers from flux.util
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5

# Classification threshold for flagging NSFW (not safe for work) images
NSFW_THRESHOLD = 0.85

# Load all models needed by the demo
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    # Load the T5 model; the maximum sequence length depends on whether this is the schnell variant
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    # Load the CLIP model
    clip = load_clip(device)
    # Load the flow model; keep it on the CPU when offloading is enabled
    model = load_flow_model(name, device="cpu" if offload else device)
    # Load the autoencoder, likewise on the CPU when offloading is enabled
    ae = load_ae(name, device="cpu" if offload else device)
    # Build the NSFW classification pipeline
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    # Return the loaded models and the classifier
    return model, ae, t5, clip, nsfw_classifier

# Wrapper class holding all models used for generation
class FluxGenerator:
    def __init__(self, model_name: str, device: str, offload: bool):
        # Convert the device string into a torch.device object
        self.device = torch.device(device)
        # Whether to offload models to the CPU when idle
        self.offload = offload
        # Name of the model to load
        self.model_name = model_name
        # Whether this is the schnell variant
        self.is_schnell = model_name == "flux-schnell"
        # Load the model and its companion components
        self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
            model_name,
            device=self.device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )

    # Generate an image under torch's inference mode
    @torch.inference_mode()
    def generate_image(
        self,
        width,
        height,
        num_steps,
        guidance,
        seed,
        prompt,
        init_image=None,
        image2image_strength=0.0,
        add_sampling_metadata=True,
    ):
        ...  # method body omitted in this excerpt

# Build the Gradio demo
def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
    # Instantiate the FluxGenerator
    generator = FluxGenerator(model_name, device, offload)
    # Whether this is the schnell variant
    is_schnell = model_name == "flux-schnell"
    # Build the Gradio UI layout
    with gr.Blocks() as demo:
        # Title markdown showing the model name
        gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}")

        # One row of the layout
        with gr.Row():
            # Left column: the inputs
            with gr.Column():
                # Textbox for the prompt
                prompt = gr.Textbox(label="Prompt", value="a photo of a forest with mist swirling around the tree trunks. The word \"FLUX\" is painted over it in big, red brush strokes with visible texture")
                # Checkbox enabling image-to-image generation (disabled for schnell)
                do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
                # Initially hidden image input
                init_image = gr.Image(label="Input Image", visible=False)
                # Initially hidden slider controlling the image-to-image noising strength
                image2image_strength = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False)

                # Collapsible panel with advanced options
                with gr.Accordion("Advanced Options", open=False):
                    # Slider for the image width
                    width = gr.Slider(128, 8192, 1360, step=16, label="Width")
                    # Slider for the image height
                    height = gr.Slider(128, 8192, 768, step=16, label="Height")
                    # Slider for the number of steps; the default depends on the schnell variant
                    num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
                    # Slider for the guidance strength (disabled for schnell)
                    guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell)
                    # Textbox for the seed value
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
                    # Checkbox controlling whether sampling parameters are written into the metadata
                    add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=True)

                # Generate button
                generate_btn = gr.Button("Generate")

            # Right column: the outputs
            with gr.Column():
                # Image widget showing the generated image
                output_image = gr.Image(label="Generated Image")
                # Number widget showing the seed that was actually used
                seed_output = gr.Number(label="Used Seed")
                # Textbox for warnings (hidden by default)
                warning_text = gr.Textbox(label="Warning", visible=False)
                # File widget for downloading the full-resolution image
                download_btn = gr.File(label="Download full-resolution")

        # Toggle the visibility of the image-to-image controls
        def update_img2img(do_img2img):
            return {
                init_image: gr.update(visible=do_img2img),
                image2image_strength: gr.update(visible=do_img2img),
            }

        # Run the toggle whenever the checkbox changes
        do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength])

        # Wire the generate button to the generation function with its inputs and outputs
        generate_btn.click(
            fn=generator.generate_image,
            inputs=[width, height, num_steps, guidance, seed, prompt, init_image, image2image_strength, add_sampling_metadata],
            outputs=[output_image, seed_output, download_btn, warning_text],
        )

    # Return the assembled Gradio app
    return demo
# Entry point when the script is run directly
if __name__ == "__main__":
    # argparse handles the command-line arguments
    import argparse
    parser = argparse.ArgumentParser(description="Flux")
    # --name selects the model, restricted to the known configurations
    parser.add_argument("--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name")
    # --device selects the device, defaulting to CUDA when available
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
    # --offload moves models to the CPU while they are not in use
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    # --share creates a public link to the demo
    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    args = parser.parse_args()

    # Build the demo from the parsed arguments
    demo = create_demo(args.name, args.device, args.offload)
    # Launch it; public sharing is controlled by --share
    demo.launch(share=args.share)
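
The demo can also be launched from Python rather than the command line. A minimal sketch (not part of the original file; it assumes the flux package, its weights, and Gradio are installed):

import torch

from demo_gr import create_demo

# Equivalent to `python demo_gr.py --name flux-schnell`
demo = create_demo("flux-schnell", device="cuda" if torch.cuda.is_available() else "cpu")
demo.launch(share=False)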

.\flux\demo_st.py

# Import the os module
import os
# Import the re module, for regular expressions
import re
# Import the time module
import time
# Import iglob from glob, for matching file paths against a pattern
from glob import iglob
# Import BytesIO from io, for working with byte streams
from io import BytesIO

# Import Streamlit, used to build the web app
import streamlit as st
# Import PyTorch
import torch
# Import rearrange from einops, for reshaping tensors
from einops import rearrange
# Import Fire, which exposes a function as a command-line interface
from fire import Fire
# Import ExifTags and Image from PIL, for image handling and metadata
from PIL import ExifTags, Image
# Import st_keyup, which captures keystrokes in Streamlit
from st_keyup import st_keyup
# Import transforms from torchvision, for image preprocessing
from torchvision import transforms
# Import the pipeline factory from transformers
from transformers import pipeline

# flux helpers used below (not shown in the original excerpt's import block)
from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5

# Threshold for flagging NSFW content
NSFW_THRESHOLD = 0.85


# Cache the result of model loading with Streamlit so models are not reloaded on every rerun
@st.cache_resource()
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    # Load the T5 model; the maximum length depends on the schnell variant
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    # Load the CLIP model
    clip = load_clip(device)
    # Load the flow model, on the CPU when offloading is enabled
    model = load_flow_model(name, device="cpu" if offload else device)
    # Load the autoencoder, likewise on the CPU when offloading is enabled
    ae = load_ae(name, device="cpu" if offload else device)
    # Load the NSFW classifier used to screen generated images
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    # Return the model, autoencoder, T5, CLIP and the NSFW classifier
    return model, ae, t5, clip, nsfw_classifier


# Read a user-uploaded image and return it as a preprocessed tensor
def get_image() -> torch.Tensor | None:
    # Let the user upload a JPG, JPEG or PNG image
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    # Return None when nothing was uploaded
    if image is None:
        return None
    # Open the file and convert it to RGB
    image = Image.open(image).convert("RGB")

    # Convert the image to a tensor and rescale it from [0, 1] to [-1, 1]
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),
        ]
    )
    # Apply the transform and add a batch dimension
    img: torch.Tensor = transform(image)
    return img[None, ...]
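
The Lambda in the transform maps pixel values from ToTensor's [0, 1] range into the [-1, 1] range the autoencoder expects. A quick standalone check (a sketch of mine, not from the post; it only needs PIL and torchvision):

from PIL import Image
from torchvision import transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: 2.0 * x - 1.0)]
)
# A 4x4 test image containing the extreme pixel values 0 and 255
img = Image.new("RGB", (4, 4), color=(0, 128, 255))
t = transform(img)
print(t.min().item(), t.max().item())  # -1.0 1.0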


# Main application logic
@torch.inference_mode()
def main(
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    output_dir: str = "output",
):
    # Build the PyTorch device from the user's choice
    torch_device = torch.device(device)
    # List the model names available in the configuration
    names = list(configs.keys())
    # Let the user pick a model to load
    name = st.selectbox("Which model to load?", names)
    # Bail out if no model is selected or the "Load model" checkbox is unticked
    if name is None or not st.checkbox("Load model", False):
        return

    # Whether the schnell variant was selected
    is_schnell = name == "flux-schnell"
    # Load the models and the classifier
    model, ae, t5, clip, nsfw_classifier = get_models(
        name,
        device=torch_device,
        offload=offload,
        is_schnell=is_schnell,
    )

    # Whether to run image-to-image generation (never available for schnell)
    do_img2img = (
        st.checkbox(
            "Image to Image",
            False,
            disabled=is_schnell,
            help="Partially noise an image and denoise again to get variations.\n\nOnly works for flux-dev",
        )
        and not is_schnell
    )
    if do_img2img:
        # Ask the user for an input image
        init_image = get_image()
        # Warn if no image was uploaded
        if init_image is None:
            st.warning("Please add an image to do image to image")
        # Ask for the noising strength
        image2image_strength = st.number_input("Noising strength", min_value=0.0, max_value=1.0, value=0.8)
        # Display the dimensions of the uploaded image
        if init_image is not None:
            h, w = init_image.shape[-2:]
            st.write(f"Got image of size {w}x{h} ({h*w/1e6:.2f}MP)")
        # Let the user choose whether to resize the image
        resize_img = st.checkbox("Resize image", False) or init_image is None
    else:
        # No image-to-image: reset the related state
        init_image = None
        resize_img = True
        image2image_strength = 0.0

    # Dimensions must allow packing into the latent space:
    # round the user-supplied width down to a multiple of 16
    width = int(
        16 * (st.number_input("Width", min_value=128, value=1360, step=16, disabled=not resize_img) // 16)
    )
    # Round the user-supplied height down to a multiple of 16
    height = int(
        16 * (st.number_input("Height", min_value=128, value=768, step=16, disabled=not resize_img) // 16)
    )
    # Number of sampling steps; default 4 for schnell, 50 otherwise
    num_steps = int(st.number_input("Number of steps", min_value=1, value=(4 if is_schnell else 50)))
    # Guidance value, default 3.5; the input is disabled for schnell
    guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell))
    # Seed input, disabled for schnell
    seed_str = st.text_input("Seed", disabled=is_schnell)
    # Use the seed if it parses as a decimal integer, otherwise leave it unset and show a hint
    if seed_str.isdecimal():
        seed = int(seed_str)
    else:
        st.info("No seed set, set to positive integer to enable")
        seed = None
    # Whether to save generated samples to disk
    save_samples = st.checkbox("Save samples?", not is_schnell)
    # Whether to write the sampling parameters into the image metadata
    add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)

    # Default prompt used for generation
    default_prompt = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    )
    # Prompt input with a 300 ms debounce
    prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text")

    # Build the output filename template and make sure the directory exists
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        # Create the directory and start indexing at 0
        os.makedirs(output_dir)
        idx = 0
    else:
        # Otherwise scan the existing files and pick the next free index
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0
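
The index scan continues after the largest index already on disk by parsing the digits out of each matching filename. A standalone sketch of the same logic on hypothetical filenames:

import re

# Hypothetical directory listing
fns = ["output/img_0.jpg", "output/img_3.jpg", "output/img_final.jpg"]
# The regex keeps only names whose suffix is purely numeric
fns = [fn for fn in fns if re.search(r"img_[0-9]+\.jpg$", fn)]
# Extract each index and continue one past the largest
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 if fns else 0
print(idx)  # 4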

    # CPU random-number generator used for seeding
    rng = torch.Generator(device="cpu")

    # Initialize the seed in the session state if not present
    if "seed" not in st.session_state:
        st.session_state.seed = rng.seed()

    # Increment the seed
    def increment_counter():
        st.session_state.seed += 1

    # Decrement the seed (never below 0)
    def decrement_counter():
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    # Collect the sampling options for later use
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    # For flux-schnell, show buttons that step the seed up or down
    if name == "flux-schnell":
        cols = st.columns([5, 1, 1, 5])
        with cols[1]:
            st.button("↩", on_click=increment_counter)
        with cols[2]:
            st.button("↪", on_click=decrement_counter)
    # If a previous sample exists in the session state, show it with its metadata
    samples = st.session_state.get("samples", None)
    if samples is not None:
        st.image(samples["img"], caption=samples["prompt"])
        st.download_button(
            "Download full-resolution",
            samples["bytes"],
            file_name="generated.jpg",
            mime="image/jpg",
        )
        st.write(f"Seed: {samples['seed']}")
# Application entry point
def app():
    # Expose main as a command-line interface via Fire
    Fire(main)


# Run app() when the script is executed directly (not imported)
if __name__ == "__main__":
    app()

FLUX.1 [dev]

[image: FLUX.1 [dev] sample grid]

FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our blog post.

Key Features

  1. Cutting-edge output quality, second only to our state-of-the-art model FLUX.1 [pro].
  2. Competitive prompt following, matching the performance of closed source alternatives.
  3. Trained using guidance distillation, making FLUX.1 [dev] more efficient.
  4. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
  5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the flux-1-dev-non-commercial-license.

Usage

We provide a reference implementation of FLUX.1 [dev], as well as sampling code, in a dedicated github repository.
Developers and creatives looking to build on top of FLUX.1 [dev] are encouraged to use this as a starting point.

API Endpoints

The FLUX.1 models are also available via API from the following sources

  1. bfl.ml (currently FLUX.1 [pro])
  2. replicate.com
  3. fal.ai

ComfyUI

FLUX.1 [dev] is also available in Comfy UI for local inference with a node-based workflow.


Limitations

  • This model is not intended or able to provide factual information.
  • As a statistical model this checkpoint might amplify existing societal biases.
  • The model may fail to generate output that matches the prompts.
  • Prompt following is heavily influenced by the prompting style.

Out-of-Scope Use

The model and its derivatives may not be used

  • In any way that violates any applicable national, federal, state, local or international law or regulation.
  • For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
  • To generate or disseminate verifiably false information and/or content with the purpose of harming others.
  • To generate or disseminate personal identifiable information that can be used to harm an individual.
  • To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
  • To create non-consensual nudity or illegal pornographic content.
  • For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
  • Generating or facilitating large-scale disinformation campaigns.

License

This model falls under the FLUX.1 [dev] Non-Commercial License.

FLUX.1 [schnell]

[image: FLUX.1 [schnell] sample grid]

FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our blog post.

Key Features

  1. Cutting-edge output quality and competitive prompt following, matching the performance of closed source alternatives.
  2. Trained using latent adversarial diffusion distillation, FLUX.1 [schnell] can generate high-quality images in only 1 to 4 steps.
  3. Released under the apache-2.0 licence, the model can be used for personal, scientific, and commercial purposes.

Usage

We provide a reference implementation of FLUX.1 [schnell], as well as sampling code, in a dedicated github repository.
Developers and creatives looking to build on top of FLUX.1 [schnell] are encouraged to use this as a starting point.

API Endpoints

The FLUX.1 models are also available via API from the following sources

  1. bfl.ml (currently FLUX.1 [pro])
  2. replicate.com
  3. fal.ai

ComfyUI

FLUX.1 [schnell] is also available in Comfy UI for local inference with a node-based workflow.


Limitations

  • This model is not intended or able to provide factual information.
  • As a statistical model this checkpoint might amplify existing societal biases.
  • The model may fail to generate output that matches the prompts.
  • Prompt following is heavily influenced by the prompting style.

Out-of-Scope Use

The model and its derivatives may not be used

  • In any way that violates any applicable national, federal, state, local or international law or regulation.
  • For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
  • To generate or disseminate verifiably false information and/or content with the purpose of harming others.
  • To generate or disseminate personal identifiable information that can be used to harm an individual.
  • To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
  • To create non-consensual nudity or illegal pornographic content.
  • For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
  • Generating or facilitating large-scale disinformation campaigns.

.\flux\src\flux\api.py

# Import the standard-library io module, used for handling byte streams
# (the rest of the import block and the ImageRequest.__init__ signature are
# truncated in this excerpt; the parameters are documented in the docstring below)
    ):
        """
        Manages an image generation request to the API.

        Args:
            prompt: Prompt to sample
            width: Width of the image in pixels
            height: Height of the image in pixels
            name: Name of the model
            num_steps: Number of network evaluations
            prompt_upsampling: Use prompt upsampling
            seed: Fix the generation seed
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key if not provided by the environment

        Raises:
            ValueError: For invalid input
            ApiException: For errors raised from the API
        """
        # Validate the inputs when requested
        if validate:
            # Check that the model name is one we know
            if name not in ["flux.1-pro"]:
                raise ValueError(f"Invalid model {name}")
            # Width must be a multiple of 32
            elif width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            # Width must be in the allowed range
            elif not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            # Height must be a multiple of 32
            elif height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            # Height must be in the allowed range
            elif not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            # Number of steps must be in the allowed range
            elif not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")

        # Assemble the request body with all required parameters
        self.request_json = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "variant": name,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
        }
        # Include the seed when one was specified
        if seed is not None:
            self.request_json["seed"] = seed

        # Initialize instance state
        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        # Fall back to the BFL_API_KEY environment variable when no key is given
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            # Otherwise use the provided API key
            self.api_key = api_key

        # Launch the request immediately when requested
        if launch:
            self.request()

    def request(self):
        """
        Request to generate the image.
        """
        # Do nothing if a request has already been issued
        if self.request_id is not None:
            return
        # POST the generation request
        response = requests.post(
            f"{API_ENDPOINT}/v1/image",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
        )
        # Parse the response as JSON
        result = response.json()
        # Raise an API exception on any non-200 status code
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        # Store the request id
        self.request_id = result["id"]
    # Wait for generation to finish and fetch the result
    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve the response.
        """
        # Issue the request first if no request id exists yet
        if self.request_id is None:
            self.request()
        # Poll until the result becomes available
        while self.result is None:
            # GET the current state of the request
            response = requests.get(
                f"{API_ENDPOINT}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
            )
            # Parse the response as JSON
            result = response.json()
            # The response must carry a status field
            if "status" not in result:
                # No status field: raise an API exception
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            # "Ready": store the result
            elif result["status"] == "Ready":
                self.result = result["result"]
            # "Pending": wait 0.5 s and poll again
            elif result["status"] == "Pending":
                time.sleep(0.5)
            # Any other status is an error
            else:
                raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
        # Return the final result
        return self.result

    # Raw bytes of the generated image
    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes.
        """
        # Download the image from its URL on first access
        if self._image_bytes is None:
            response = requests.get(self.url)
            # Cache the bytes on success
            if response.status_code == 200:
                self._image_bytes = response.content
            # Otherwise raise an API exception
            else:
                raise ApiException(status_code=response.status_code)
        # Return the cached bytes
        return self._image_bytes

    # Public URL of the generated image
    @property
    def url(self) -> str:
        """
        Public URL to retrieve the image from.
        """
        # Retrieve the result on first access and cache the URL
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        # Return the image URL
        return self._url

    # The generated image as a PIL object
    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image object.
        """
        return Image.open(io.BytesIO(self.bytes))

    # Save the generated image to a local path
    def save(self, path: str):
        """
        Save the generated image to a local path.
        """
        # Take the file extension from the URL
        suffix = Path(self.url).suffix
        # Append the extension when the path lacks it
        if not path.endswith(suffix):
            path = path + suffix
        # Create the parent directories if they do not exist
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        # Write the image bytes to the target path
        with open(path, "wb") as file:
            file.write(self.bytes)

# Only run when the script is executed directly
if __name__ == "__main__":
    # Import Fire from the fire library
    from fire import Fire

    # Expose ImageRequest as a command-line interface
    Fire(ImageRequest)
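
For reference, a hedged usage sketch of this client. Since the `__init__` signature is truncated in this excerpt, the keyword names below follow the docstring; the defaults and exact call pattern are assumptions. A valid key must be available, for example via the BFL_API_KEY environment variable:

from flux.api import ImageRequest

# The constructor launches the request immediately (per the docstring, launch defaults to on)
request = ImageRequest(
    prompt="a forest in the mist",
    width=1024,
    height=768,
    name="flux.1-pro",
)
print(request.url)               # polls until the API reports "Ready"
request.save("outputs/forest")   # the URL's file extension is appended automatically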

.\flux\src\flux\cli.py

# Import the os module
import os
# Import the re module
import re
# Import the time module
import time
# Import the dataclass decorator
from dataclasses import dataclass
# Import iglob from glob, for filename pattern matching
from glob import iglob

# Import PyTorch
import torch
# Import rearrange from einops, for tensor reshaping
from einops import rearrange
# Import Fire, for the command-line interface
from fire import Fire
# Import ExifTags and Image from PIL, for images and metadata
from PIL import ExifTags, Image

# Import the sampling helpers from flux.sampling
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
# Import utility helpers from flux.util
from flux.util import (configs, embed_watermark, load_ae, load_clip,
                       load_flow_model, load_t5)
# Import pipeline from transformers, for the pretrained classifier
from transformers import pipeline

# Threshold for flagging NSFW (not safe for work) content
NSFW_THRESHOLD = 0.85

# Dataclass holding the sampling options
@dataclass
class SamplingOptions:
    # Prompt text
    prompt: str
    # Image width
    width: int
    # Image height
    height: int
    # Number of sampling steps
    num_steps: int
    # Guidance strength
    guidance: float
    # Optional random seed
    seed: int | None

# Read the user's input and update the SamplingOptions accordingly
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    # Prompt shown when asking for the next input
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    # Usage text
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )
    # Keep reading input until a line does not start with a slash
    while (prompt := input(user_question)).startswith("/"):
        # "/w": set the width
        if prompt.startswith("/w"):
            # A valid command contains exactly one space
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            # Parse the width and round it down to a multiple of 16
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            # Report the new resolution and the total megapixels
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        # "/h": set the height
        elif prompt.startswith("/h"):
            # A valid command contains exactly one space
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            # Parse the height and round it down to a multiple of 16
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            # Report the new resolution and the total megapixels
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        # "/g": set the guidance value
        elif prompt.startswith("/g"):
            # A valid command contains exactly one space
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            # Parse the guidance value
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        # "/s": set the seed
        elif prompt.startswith("/s"):
            # A valid command contains exactly one space
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            # Parse the seed
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        # "/n": set the number of steps
        elif prompt.startswith("/n"):
            # A valid command contains exactly one space
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            # Parse the number of steps
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        # "/q": quit
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            # Unknown command: report it and show the usage text
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    # A non-empty input becomes the new prompt
    if prompt != "":
        options.prompt = prompt
    # Return the updated options
    return options
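
Because parse_prompt reads from input(), it can be exercised non-interactively by stubbing the input function, which makes the command handling easy to verify. A small sketch (my own test harness, not part of the CLI):

import builtins

from flux.cli import SamplingOptions, parse_prompt

opts = SamplingOptions(prompt="a cat", width=1360, height=768, num_steps=4, guidance=3.5, seed=None)

# Feed two commands followed by a plain prompt, then restore input()
commands = iter(["/w 1024", "/s 42", "a dog on a beach"])
real_input = builtins.input
builtins.input = lambda _prompt="": next(commands)
try:
    opts = parse_prompt(opts)
finally:
    builtins.input = real_input

print(opts.width, opts.seed, opts.prompt)  # 1024 42 a dog on a beach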
@torch.inference_mode()
def main(
    name: str = "flux-schnell",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    prompt: str = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 3.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the model to load
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling
        output_dir: directory in which to save the output images; files are
            named img_{idx}.jpg, where {idx} is the index of the sample
        prompt: Prompt used for sampling
        device: Pytorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        add_sampling_metadata: Add the prompt to the image Exif metadata
    """
    # Initialize an NSFW image classification pipeline with the specified model and device
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    # Check if the specified model name is valid
    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

    # Set the PyTorch device based on the provided device string
    torch_device = torch.device(device)
    # Determine the number of sampling steps based on the model name
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    # Adjust height and width to be multiples of 16 for compatibility
    height = 16 * (height // 16)
    width = 16 * (width // 16)

    # Construct the output file path and handle directory and index management
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # Initialize components for the sampling process
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    # Create a random number generator and sampling options
    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    # If loop mode is enabled, adjust the options based on the prompt
    if loop:
        opts = parse_prompt(opts)
    # Keep looping while opts is not None
    while opts is not None:
        # Draw a fresh seed when none is set
        if opts.seed is None:
            opts.seed = rng.seed()
        # Report the seed and the prompt
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        # Record the start time to measure generation time
        t0 = time.perf_counter()

        # Prepare the input noise
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        # Clear the seed so it is not reused
        opts.seed = None
        # When offloading: move the autoencoder to the CPU, free CUDA memory,
        # and move the text encoders onto the device
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        # Assemble the model inputs from the T5 and CLIP outputs, the noise and the prompt
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        # Compute the timestep schedule
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # When offloading: move the text encoders back to the CPU, free CUDA
        # memory, and move the flow model onto the device
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # Denoise the initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # When offloading: move the flow model to the CPU, free CUDA memory,
        # and move the autoencoder's decoder onto the device
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # Decode the latents into pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
        # Record the time after decoding
        t1 = time.perf_counter()

        # Format the output filename
        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
        # Bring the image into PIL format and save it
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")

        # Build a PIL image from the numpy array
        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        # Run the NSFW check
        nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]

        # Save the image with its EXIF metadata when the NSFW score is below the threshold
        if nsfw_score < NSFW_THRESHOLD:
            exif_data = Image.Exif()
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(fn, exif=exif_data, quality=95, subsampling=0)
            # Advance the image index
            idx += 1
        else:
            print("Your generated image may contain NSFW content.")

        # In loop mode, parse a new prompt and continue; otherwise stop
        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        else:
            opts = None
# Entry point
def app():
    # Expose main as a command-line interface via Fire
    Fire(main)


# Run app() when executed as the main module
if __name__ == "__main__":
    app()

.\flux\src\flux\math.py

# Import PyTorch and einops' rearrange
import torch
from einops import rearrange
from torch import Tensor


# Attention with rotary position embeddings
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    # Apply rotary position embeddings (RoPE) to q and k
    q, k = apply_rope(q, k, pe)

    # Compute the output with scaled dot-product attention
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Merge the head dimension back into the channel dimension
    x = rearrange(x, "B H L D -> B L (H D)")

    # Return the result
    return x


# Build the rotary position embedding (RoPE) table
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    # The dimension must be even
    assert dim % 2 == 0
    # Exponents for the frequency scale
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    # Per-channel angular frequencies
    omega = 1.0 / (theta**scale)
    # Outer product of positions and frequencies
    out = torch.einsum("...n,d->...nd", pos, omega)
    # Assemble the entries of the 2x2 rotation matrices
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    # Reshape the flat entries into 2x2 matrices
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    # Cast to float and return
    return out.float()


# Helper applying the rotary embeddings to queries and keys
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # View q and k as pairs of channels and cast to float
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # Rotate each channel pair
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    # Restore the original shapes and dtypes
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
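
Since apply_rope only rotates channel pairs by position-dependent angles, it preserves vector norms; only directions change. A small numeric sketch (shapes follow the (batch, heads, length, head_dim) convention used by attention above):

import torch

from flux.math import apply_rope, rope

B, H, L, D = 1, 2, 8, 16
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
pos = torch.arange(L, dtype=torch.float64)[None]   # (1, L) token positions
pe = rope(pos, D, theta=10_000).unsqueeze(1)       # (1, 1, L, D/2, 2, 2)
q_rot, k_rot = apply_rope(q, k, pe)
print(q_rot.shape)                                                    # torch.Size([1, 2, 8, 16])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True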

.\flux\src\flux\model.py

# Import the dataclass decorator
from dataclasses import dataclass

# Import PyTorch and related modules
import torch
from torch import Tensor, nn

# Import the building blocks from flux.modules.layers
from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
                                 MLPEmbedder, SingleStreamBlock,
                                 timestep_embedding)

# Dataclass holding the model hyperparameters
@dataclass
class FluxParams:
    # Number of input channels
    in_channels: int
    # Dimension of the input vector (pooled text embedding)
    vec_in_dim: int
    # Dimension of the context (sequence text embedding) input
    context_in_dim: int
    # Hidden size
    hidden_size: int
    # MLP expansion ratio
    mlp_ratio: float
    # Number of attention heads
    num_heads: int
    # Number of double-stream blocks
    depth: int
    # Number of single-stream blocks
    depth_single_blocks: int
    # Per-axis dimensions of the positional embedding
    axes_dim: list[int]
    # RoPE theta parameter
    theta: int
    # Whether the QKV projections use a bias
    qkv_bias: bool
    # Whether to use a guidance embedding
    guidance_embed: bool
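
For a sense of scale, the flux-dev configuration in flux.util fills this dataclass roughly as follows (values quoted from the reference repository; treat them as an assumption of this walkthrough). Note that axes_dim sums to hidden_size // num_heads = 128, as required by the check in Flux.__init__ below:

params = FluxParams(
    in_channels=64,          # packed 2x2 patches of 16 latent channels
    vec_in_dim=768,          # CLIP pooled-embedding size
    context_in_dim=4096,     # T5-XXL hidden size
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,                # double-stream blocks
    depth_single_blocks=38,  # single-stream blocks
    axes_dim=[16, 56, 56],   # 16 + 56 + 56 == 3072 // 24
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,     # flux-dev is guidance distilled
)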

# The Flux model
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        # Keep the parameters around
        self.params = params
        # Number of input channels
        self.in_channels = params.in_channels
        # The output has as many channels as the input
        self.out_channels = self.in_channels
        # The hidden size must be divisible by the number of heads
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        # Positional-embedding dimension per head
        pe_dim = params.hidden_size // params.num_heads
        # The per-axis dimensions must sum to the positional-embedding dimension
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        # Hidden size
        self.hidden_size = params.hidden_size
        # Number of heads
        self.num_heads = params.num_heads
        # Positional embedding over the id axes
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        # Linear projection of the image tokens
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        # Timestep embedding
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        # Embedding of the pooled text vector
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        # Guidance embedding, or identity when guidance embedding is disabled
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        # Linear projection of the text tokens
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        # Stack of double-stream blocks
        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        # Stack of single-stream blocks
        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        # Final output layer
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    # Forward pass
    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        # img and txt must both be 3-dimensional
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # Project the image tokens into the hidden size
        img = self.img_in(img)
        # Embed the timesteps and pass them through time_in
        vec = self.time_in(timestep_embedding(timesteps, 256))
        # Process the guidance embedding when it is enabled
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            # Add the embedded guidance strength to vec
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        # Add the embedded pooled text vector to vec
        vec = vec + self.vector_in(y)
        # Project the text tokens into the hidden size
        txt = self.txt_in(txt)

        # Concatenate the text and image ids along the sequence dimension
        ids = torch.cat((txt_ids, img_ids), dim=1)
        # Compute the positional embedding
        pe = self.pe_embedder(ids)

        # Run the double-stream blocks
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # Concatenate the text and image streams along the sequence dimension
        img = torch.cat((txt, img), 1)
        # Run the single-stream blocks
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the text tokens, keeping only the image part
        img = img[:, txt.shape[1] :, ...]

        # Final projection back to the output channels
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
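
A smoke-test sketch of the forward pass with toy hyperparameters, chosen only to satisfy the two divisibility checks above (an illustration of the shapes, not a real configuration):

import torch

from flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=16, vec_in_dim=8, context_in_dim=12, hidden_size=64,
    mlp_ratio=4.0, num_heads=4, depth=1, depth_single_blocks=1,
    axes_dim=[4, 6, 6], theta=10_000, qkv_bias=True, guidance_embed=False,
)
model = Flux(params)

L_img, L_txt = 32, 8
img = torch.randn(1, L_img, 16)      # packed latent tokens
img_ids = torch.zeros(1, L_img, 3)   # three positional axes per token
txt = torch.randn(1, L_txt, 12)      # text tokens from the context encoder
txt_ids = torch.zeros(1, L_txt, 3)
out = model(img, img_ids, txt, txt_ids,
            timesteps=torch.tensor([0.5]), y=torch.randn(1, 8))
print(out.shape)  # torch.Size([1, 32, 16]): the image tokens, same shape as the input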

.\flux\src\flux\modules\autoencoder.py

# Import the dataclass decorator
from dataclasses import dataclass

# Import PyTorch
import torch
# Import rearrange from einops
from einops import rearrange
# Import Tensor and nn from torch
from torch import Tensor, nn


# Dataclass holding the autoencoder parameters
@dataclass
class AutoEncoderParams:
    resolution: int  # image resolution
    in_channels: int  # number of input channels
    ch: int  # base channel count
    out_ch: int  # number of output channels
    ch_mult: list[int]  # channel multipliers per resolution level
    num_res_blocks: int  # number of residual blocks
    z_channels: int  # number of latent channels
    scale_factor: float  # latent scale factor
    shift_factor: float  # latent shift factor


# The swish activation function
def swish(x: Tensor) -> Tensor:
    # Gate x by its own sigmoid
    return x * torch.sigmoid(x)
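
swish(x) = x * sigmoid(x) is the same function PyTorch ships as SiLU, so the hand-rolled version can be sanity-checked in one line:

import torch

x = torch.randn(4)
assert torch.allclose(x * torch.sigmoid(x), torch.nn.functional.silu(x))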


# Self-attention block
class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        # Group normalization
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1 convolutions computing the attention projections
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    # The attention computation
    def attention(self, h_: Tensor) -> Tensor:
        # Normalize the input
        h_ = self.norm(h_)
        # Compute q, k, v
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Read off the dimensions
        b, c, h, w = q.shape
        # Flatten the spatial dimensions into a sequence
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        # Scaled dot-product attention
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        # Restore the spatial layout
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    # Forward pass
    def forward(self, x: Tensor) -> Tensor:
        # Residual connection around the projected attention output
        return x + self.proj_out(self.attention(x))


# Residual block
class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        # Normalization and convolution layers
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 1x1 shortcut projection when the channel counts differ
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    # Forward pass
    def forward(self, x):
        h = x
        # First norm / activation / convolution
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        # Second norm / activation / convolution
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        # Project the input when the channel counts differ
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        # Residual connection
        return x + h


# Downsampling block
class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # torch convs have no asymmetric padding, so it is done by hand in forward
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    # Forward pass
    def forward(self, x: Tensor):
        # Pad by 1 on the right and bottom only
        pad = (0, 1, 0, 1)
        # Apply the zero padding
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        # Run the strided convolution
        x = self.conv(x)
        # Return the downsampled result
        return x
# Upsampling block
class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # Convolution applied after the resolution is doubled
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        # Double the spatial resolution with nearest-neighbor interpolation
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        # Apply the convolution
        x = self.conv(x)
        # Return the upsampled feature map
        return x
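
A shape sketch of the two blocks (an illustration of mine, not from the post): the asymmetric pad plus stride-2 convolution in Downsample exactly halves the spatial size, while Upsample doubles it via nearest-neighbor interpolation before its convolution.

import torch

# Downsample and Upsample as defined above
x = torch.randn(1, 8, 32, 32)
print(Downsample(8)(x).shape)  # torch.Size([1, 8, 16, 16])
print(Upsample(8)(x).shape)    # torch.Size([1, 8, 64, 64])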


# Encoder module
class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # Input convolution producing the initial feature map
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            # Input and output channel counts for this level
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                # Append a residual block to this level
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Downsample between levels (except after the last one)
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # Middle part: two residual blocks around an attention block
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # Output part: normalization and a convolution producing mean and logvar channels
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # Downsampling path
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # Middle part
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # Output part
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        # Return the final feature map
        return h


# Decoder module
class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        # Call the parent initializer
        super().__init__()
        # Base channel count
        self.ch = ch
        # Number of resolution levels
        self.num_resolutions = len(ch_mult)
        # Number of residual blocks per level
        self.num_res_blocks = num_res_blocks
        # Image resolution
        self.resolution = resolution
        # Number of input channels
        self.in_channels = in_channels
        # Overall spatial scaling factor
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # Channel count and resolution at the lowest level
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # Shape of the latent variable z
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # Convolution from z to block_in channels
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # Middle part
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # Upsampling path
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            # Output channel count at this level
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                # Append a residual block
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Upsample between levels (except at the highest one)
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            # Prepend so the list order matches the level order
            self.up.insert(0, up)

        # Output normalization
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        # Output convolution
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # Project z through conv_in
        h = self.conv_in(z)

        # Middle part
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Upsampling path
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                # Upsample
                h = self.up[i_level].upsample(h)

        # Final layers
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        # Return the decoded image
        return h
# Diagonal Gaussian distribution over the latent
class DiagonalGaussian(nn.Module):
    # sample controls whether to draw a sample or return the mean;
    # chunk_dim is the dimension along which mean and logvar are packed
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        # Whether to sample
        self.sample = sample
        # Dimension used for splitting
        self.chunk_dim = chunk_dim

    # Forward pass
    def forward(self, z: Tensor) -> Tensor:
        # Split z along chunk_dim into mean and log-variance
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            # When sampling, compute the std and draw from a standard normal
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            # Otherwise just return the mean
            return mean
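
A small sketch of the reparameterization above: with mean 0 and log-variance 0 packed along the channel dimension, sampling draws from N(0, I), while sample=False returns the (zero) mean deterministically.

import torch

# 8 channels split into 4 mean channels and 4 logvar channels
z = torch.zeros(1, 8, 8, 8)
sampled = DiagonalGaussian(sample=True)(z)  # one draw from N(0, I)
mean = DiagonalGaussian(sample=False)(z)    # deterministic mean
print(sampled.shape, mean.abs().max().item())  # torch.Size([1, 4, 8, 8]) 0.0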


# The complete autoencoder
class AutoEncoder(nn.Module):
    # Wire up the encoder, the decoder and the Gaussian regularizer
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        # Build the encoder
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        # Build the decoder
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        # Diagonal Gaussian regularizer
        self.reg = DiagonalGaussian()

        # Scale and shift applied to the latent
        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    # Encode x into the scaled and shifted latent
    def encode(self, x: Tensor) -> Tensor:
        # Encode, then sample from the diagonal Gaussian
        z = self.reg(self.encoder(x))
        # Apply the scale and shift
        z = self.scale_factor * (z - self.shift_factor)
        return z

    # Decode a latent z back into an image
    def decode(self, z: Tensor) -> Tensor:
        # Undo the scale and shift
        z = z / self.scale_factor + self.shift_factor
        # Run the decoder
        return self.decoder(z)

    # Full forward pass: encode, then decode
    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
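
The scale/shift in encode is the exact inverse of the one in decode, so the affine part of the round trip is lossless. A sketch with example factors (the concrete values here are assumptions in the style of the repository's configs):

import torch

scale_factor, shift_factor = 0.3611, 0.1159
z = torch.randn(1, 16, 32, 32)
z_enc = scale_factor * (z - shift_factor)    # applied at the end of encode()
z_dec = z_enc / scale_factor + shift_factor  # applied at the start of decode()
assert torch.allclose(z, z_dec, atol=1e-6)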

.\flux\src\flux\modules\conditioner.py

# Import the required modules from PyTorch and Transformers
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
                          T5Tokenizer)

# HFEmbedder: a text-embedding wrapper built on nn.Module
class HFEmbedder(nn.Module):
    # Initialization
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        # Decide whether this is a CLIP model based on the version name
        self.is_clip = version.startswith("openai")
        # Store the maximum sequence length
        self.max_length = max_length
        # Pick the output key depending on the model type
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            # Load the CLIP tokenizer and text model from the pretrained checkpoint
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            # Otherwise load the T5 tokenizer and encoder from the pretrained checkpoint
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        # Put the model in eval mode and disable gradients
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    # Forward pass: tokenize the input texts and return their embeddings
    def forward(self, text: list[str]) -> Tensor:
        # Encode the texts with the tokenizer
        batch_encoding = self.tokenizer(
            text,
            truncation=True,  # truncate over-long texts
            max_length=self.max_length,  # cap at the maximum length
            return_length=False,  # do not return text lengths
            return_overflowing_tokens=False,  # do not return overflowing tokens
            padding="max_length",  # pad to the maximum length
            return_tensors="pt",  # return PyTorch tensors
        )

        # Run the underlying HF model
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),  # move input IDs to the model's device
            attention_mask=None,  # no attention mask
            output_hidden_states=False,  # do not return hidden states
        )
        # Return the output under the selected key
        return outputs[self.output_key]
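
The two embedders return tensors of different ranks: CLIP's pooler_output is one vector per prompt, while T5's last_hidden_state keeps the full token sequence. A usage sketch (checkpoints download on first use; the shapes match vec_in_dim=768 and context_in_dim=4096 in the configs later in this post):

import torch

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16)

vec = clip(["a photo of a forest"])  # (1, 768): pooled vector, feeds the "vec" input
txt = t5(["a photo of a forest"])    # (1, 512, 4096): per-token embeddings, feeds "txt"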

.\flux\src\flux\modules\layers.py

# Import the math library
import math
# Import the dataclass decorator
from dataclasses import dataclass

# Import PyTorch
import torch
# Import rearrange from einops
from einops import rearrange
# Import Tensor and nn from torch
from torch import Tensor, nn

# Import the attention and rope helpers from flux.math
from flux.math import attention, rope


# Embedding module for N-dimensional position IDs
class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        # Store the dimension, rotation base, and per-axis dimensions
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        # The last dimension of ids indexes the positional axes
        n_axes = ids.shape[-1]
        # Apply rope to each axis and concatenate along dim -3
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        # Insert a singleton head dimension at position 1
        return emb.unsqueeze(1)


# Create sinusoidal timestep embeddings
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element; may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scaling applied to t before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    # Scale the input by the time factor
    t = time_factor * t
    # Half of the output dimension carries cosines, half sines
    half = dim // 2
    # Compute the frequency ladder
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )
    # Outer product of timesteps and frequencies, then concatenate cos and sin
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    # If dim is odd, append a zero column
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    # Match t's dtype when t is floating point
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
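
Because the schedule's timesteps live in [0, 1], the time_factor of 1000 rescales them into the range these sinusoidal frequencies were designed for. A self-contained check using the function above:

import torch

t = torch.tensor([0.0, 0.5, 1.0])       # three timesteps in [0, 1]
emb = timestep_embedding(t, dim=256)    # scaled internally by time_factor=1000
print(emb.shape)                        # torch.Size([3, 256])
print(emb[0, :4])                       # tensor([1., 1., 1., 1.]): cos(0) fills the first half at t=0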


# A two-layer MLP embedder
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        # Input layer, activation, and output layer
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        # Forward pass: input layer, activation, output layer
        return self.out_layer(self.silu(self.in_layer(x)))


# RMSNorm: root-mean-square normalization
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Learnable scale parameter
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        # Compute in float32 for numerical stability, then cast back
        x_dtype = x.dtype
        x = x.float()
        # Reciprocal of the root mean square over the last dimension
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        # Normalize, restore the original dtype, and apply the scale
        return (x * rrms).to(dtype=x_dtype) * self.scale
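
Unlike LayerNorm, RMSNorm neither subtracts the mean nor adds a bias; it only rescales each vector to unit root mean square. A quick numeric check of that property with the class above:

import torch

norm = RMSNorm(dim=4)                       # scale is initialized to ones
x = torch.tensor([[3.0, -3.0, 3.0, -3.0]])  # RMS of this row is 3
y = norm(x)
print(y)                                    # approximately [[1., -1., 1., -1.]]
print(y.pow(2).mean(dim=-1))                # approximately 1.0: unit RMS after normalization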


# QKNorm: apply RMSNorm to queries and keys
class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Separate normalizations for queries and keys
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        # Normalize the queries and keys (the values pass through untouched)
        q = self.query_norm(q)
        k = self.key_norm(k)
        # Cast q and k to v's dtype and return them
        return q.to(v), k.to(v)


# Self-attention module
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        # Number of heads and the per-head dimension
        self.num_heads = num_heads
        head_dim = dim // num_heads

        # Fused linear projection for queries, keys, and values
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Query/key normalization
        self.norm = QKNorm(head_dim)
        # Output projection
        self.proj = nn.Linear(dim, dim)

    # Forward pass: takes the input and positional encoding, returns the attended output
    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        # Project the input into the fused query/key/value representation
        qkv = self.qkv(x)
        # Rearrange into separate q, k, v, grouped by head
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        # Normalize the queries and keys (v is passed along only for dtype matching)
        q, k = self.norm(q, k, v)
        # Compute attention with rotary position embeddings applied
        x = attention(q, k, v, pe=pe)
        # Project the attention result back to the output space
        x = self.proj(x)
        return x


# A small container holding the three modulation tensors
@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


# Modulation module
class Modulation(nn.Module):
    # Initialization: takes the dimension and whether to produce a double set
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double  # remember whether this is the double variant
        self.multiplier = 6 if double else 3  # six outputs when double, else three
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)  # projection layer

    # Forward pass
    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        # Apply SiLU, project, add a sequence dim, and split into multiplier chunks
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        # Return the first triple, and the second triple only in the double case
        return (
            ModulationOut(*out[:3]),  # first three chunks: shift, scale, gate
            ModulationOut(*out[3:]) if self.is_double else None,  # last three, if double
        )
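
Each triple drives one adaptive-LayerNorm site: (1 + scale) * norm(x) + shift before a sublayer, and gate on its residual output. A standalone sketch of the chunking itself (hypothetical dim; no learned weights needed to see the shapes):

import torch

dim, batch = 8, 2
projected = torch.randn(batch, 6 * dim)          # what self.lin(silu(vec)) yields when double=True
chunks = projected[:, None, :].chunk(6, dim=-1)  # six (batch, 1, dim) tensors
shift1, scale1, gate1, shift2, scale2, gate2 = chunks
print(shift1.shape)                              # torch.Size([2, 1, 8])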


# DoubleStreamBlock: parallel image and text streams with joint attention
class DoubleStreamBlock(nn.Module):
    # Initialization: hidden size, number of heads, MLP ratio, and optional qkv bias
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)  # hidden dimension of the MLPs
        self.num_heads = num_heads  # number of attention heads
        self.hidden_size = hidden_size  # hidden size
        self.img_mod = Modulation(hidden_size, double=True)  # modulation for the image stream
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # first image norm
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)  # image self-attention

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # second image norm
        self.img_mlp = nn.Sequential(  # image MLP
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),  # first linear layer
            nn.GELU(approximate="tanh"),  # activation
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),  # second linear layer
        )

        self.txt_mod = Modulation(hidden_size, double=True)  # modulation for the text stream
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # first text norm
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)  # text self-attention

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)  # second text norm
        self.txt_mlp = nn.Sequential(  # text MLP
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),  # first linear layer
            nn.GELU(approximate="tanh"),  # activation
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),  # second linear layer
        )

    # Forward pass: processes the image and text streams and returns the updated pair;
    # the method body is cut off in the original listing, so a reconstruction follows below
    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        ...
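
The listing breaks off at the forward signature. The sketch below reconstructs the omitted body, following the layers defined in __init__ above and the reference implementation in the black-forest-labs/flux repository: both streams are normalized and modulated, their tokens attend jointly over the concatenated text + image sequence, and each stream then applies its own gated residual attention and MLP updates.

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        # Compute both modulation triples for each stream
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # Prepare the image stream: norm, modulate, project to fused qkv
        img_modulated = (1 + img_mod1.scale) * self.img_norm1(img) + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # Prepare the text stream the same way
        txt_modulated = (1 + txt_mod1.scale) * self.txt_norm1(txt) + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # Joint attention over the concatenated text + image tokens
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)
        attn = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # Gated residual updates for the image stream: attention, then MLP
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # Gated residual updates for the text stream
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt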


# A DiT block with parallel linear layers and an adapted modulation interface
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        # Hidden dimension and number of attention heads
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # Attention scale factor
        self.scale = qk_scale or head_dim**-0.5

        # Hidden dimension of the MLP
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # Fused linear layer producing qkv and the MLP input
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # Fused linear layer for the attention projection and MLP output
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        # Query/key normalization
        self.norm = QKNorm(head_dim)

        # Pre-normalization layer
        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        # MLP activation and the (single) modulation layer
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        # Compute the modulation triple from the conditioning vector
        mod, _ = self.modulation(vec)
        # Pre-normalize the input and apply the modulation
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        # Split the fused projection into qkv and the MLP input
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # Rearrange qkv into per-head q, k, v and normalize q and k
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # Compute attention
        attn = attention(q, k, v, pe=pe)
        # Run the MLP activation in parallel, concatenate, and project back down
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # Gated residual connection
        return x + mod.gate * output


# Final output layer
class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        # Final layer normalization
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # Linear layer mapping the hidden size to the final per-patch output channels
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        # Adaptive LayerNorm modulation
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        # Compute shift and scale from the conditioning vector
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # Normalize the input and apply shift and scale
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        # Project to the final output
        x = self.linear(x)
        return x

.\flux\src\flux\sampling.py

# Import the math library
import math
# Import the Callable type
from typing import Callable

# Import PyTorch
import torch
# Import rearrange and repeat from einops
from einops import rearrange, repeat
# Import the Tensor type
from torch import Tensor

# Import the Flux model class
from .model import Flux
# Import the HFEmbedder conditioner
from .modules.conditioner import HFEmbedder


# Generate the initial noise latent
def get_noise(
    num_samples: int,  # number of samples to generate
    height: int,  # image height in pixels
    width: int,  # image width in pixels
    device: torch.device,  # compute device
    dtype: torch.dtype,  # data type
    seed: int,  # random seed
):
    return torch.randn(
        num_samples,  # number of samples
        16,  # latent channels
        # round the spatial dims up so they remain packable into 2x2 patches
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,  # target device
        dtype=dtype,  # target dtype
        generator=torch.Generator(device=device).manual_seed(seed),  # seeded RNG
    )
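
The autoencoder downsamples by 8x and prepare later packs 2x2 latent patches, so the noise is allocated at 2 * ceil(dim / 16) per side. A worked shape check for a 1024x1024 request:

import math
import torch

height, width = 1024, 1024
noise = get_noise(1, height, width, device=torch.device("cpu"), dtype=torch.float32, seed=0)
print(noise.shape)                             # torch.Size([1, 16, 128, 128])
assert noise.shape[-1] == 2 * math.ceil(width / 16)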


# Prepare the model inputs
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape  # batch size, channels, height, width
    if bs == 1 and not isinstance(prompt, str):  # a batch of 1 but a list of prompts
        bs = len(prompt)  # use the number of prompts as the batch size

    # Pack 2x2 latent patches into the sequence dimension
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:  # a single image but a larger effective batch
        img = repeat(img, "1 ... -> bs ...", bs=bs)  # repeat the image across the batch

    img_ids = torch.zeros(h // 2, w // 2, 3)  # positional IDs for the image tokens
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]  # row indices
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]  # column indices
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)  # broadcast across the batch

    if isinstance(prompt, str):  # normalize a single prompt to a list
        prompt = [prompt]
    txt = t5(prompt)  # per-token text embeddings from T5
    if txt.shape[0] == 1 and bs > 1:  # repeat the text embeddings to match the batch
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)  # text tokens carry zero positional IDs

    vec = clip(prompt)  # pooled text vector from CLIP
    if vec.shape[0] == 1 and bs > 1:  # repeat the vector to match the batch
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,  # packed image tokens
        "img_ids": img_ids.to(img.device),  # image positional IDs, on the image's device
        "txt": txt.to(img.device),  # text embeddings, on the image's device
        "txt_ids": txt_ids.to(img.device),  # text positional IDs, on the image's device
        "vec": vec.to(img.device),  # pooled vector, on the image's device
    }


# Time-shifting function for the sampling schedule
def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)  # shifted timestep values


# Build a linear function through two points
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15  # default anchor points
) -> Callable[[float], float]:  # returns a float -> float function
    m = (y2 - y1) / (x2 - x1)  # slope
    b = y1 - m * x1  # intercept
    return lambda x: m * x + b  # the linear function
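
The line interpolates the shift parameter mu between a short (256-token) and a long (4096-token) image sequence. Worked through with the defaults:

# slope m = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4, intercept b = 0.5 - m * 256
mu_fn = get_lin_function()
print(mu_fn(1024))  # ≈ 0.63 (a 512x512 image: 32 * 32 = 1024 packed tokens)
print(mu_fn(4096))  # 1.15  (a 1024x1024 image: 64 * 64 = 4096 packed tokens)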


# Build the sampling timestep schedule
def get_schedule(
    num_steps: int,  # number of sampling steps
    image_seq_len: int,  # image sequence length (number of packed tokens)
    base_shift: float = 0.5,  # base shift
    max_shift: float = 1.15,  # maximum shift
    shift: bool = True,  # whether to apply the shift
) -> list[float]:  # returns a list of floats
    # Linearly spaced timesteps from 1 down to 0
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # Optionally shift the schedule toward higher noise levels
    if shift:
        # Estimate mu from the linear interpolation over the sequence length
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)  # apply the time shift

    return timesteps.tolist()  # return the schedule as a list


# Denoising loop
def denoise(
    model: Flux,  # the flow model
    # model inputs
    img: Tensor,  # packed image latents
    img_ids: Tensor,  # image positional IDs
    txt: Tensor,  # text embeddings
    txt_ids: Tensor,  # text positional IDs
    vec: Tensor,  # pooled conditioning vector
    # sampling parameters
    timesteps: list[float],  # timestep schedule
    guidance: float = 4.0,  # guidance strength
):
    # One guidance value per image in the batch
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    # Iterate over pairs of consecutive timesteps
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        # Broadcast the current timestep across the batch
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        # Query the model for the predicted velocity at the current timestep
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        # Euler step: move the latent along the predicted velocity by (t_prev - t_curr)
        img = img + (t_prev - t_curr) * pred

    # Return the denoised latent
    return img
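
The update img += (t_prev - t_curr) * pred is a first-order Euler step for the rectified-flow ODE dx/dt = v(x, t), integrated from t = 1 (pure noise) down toward t = 0. A toy integration with a known velocity field shows the mechanics (a hypothetical field standing in for the FLUX model):

import torch

# toy field v(x, t) = x / t, whose exact solution through x(1) = 1 is x(t) = t
x = torch.tensor([1.0])
timesteps = torch.linspace(1, 0.1, 10).tolist()
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
    pred = x / t_curr                 # stand-in for model(...)
    x = x + (t_prev - t_curr) * pred  # the same Euler update as in denoise
print(x)                              # tensor([0.1000]): matches the exact x(0.1)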


# Unpack the token sequence back into a spatial latent
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    # Invert the packing done in prepare: tokens back into 2x2 latent patches
    return rearrange(
        x,
        # rearrangement rule from sequence layout back to spatial layout
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        # recover the latent grid size from the requested height and width
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
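
prepare's rearrange and unpack are exact inverses for dimensions divisible by 16. A standalone round trip over a dummy latent, using the unpack above:

import torch
from einops import rearrange

latent = torch.randn(1, 16, 128, 128)  # e.g. the 1024x1024 noise from get_noise
packed = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)                    # torch.Size([1, 4096, 64])
restored = unpack(packed, height=1024, width=1024)
assert torch.equal(latent, restored)   # packing and unpacking round-trip exactly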

.\flux\src\flux\util.py

# Import the os module
import os
# Import the dataclass decorator for creating data classes
from dataclasses import dataclass

# Import PyTorch for tensor operations
import torch
# Import rearrange from einops for tensor reshaping
from einops import rearrange
# Import hf_hub_download for fetching model files from the Hugging Face Hub
from huggingface_hub import hf_hub_download
# Import WatermarkEncoder for embedding watermarks into images
from imwatermark import WatermarkEncoder
# Import load_file from safetensors (renamed load_sft) for loading safetensors checkpoints
from safetensors.torch import load_file as load_sft

# Import the Flux model and its parameter class
from flux.model import Flux, FluxParams
# Import the AutoEncoder and its parameter class
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
# Import the HFEmbedder conditioner
from flux.modules.conditioner import HFEmbedder


# ModelSpec: a data class holding a model's specification
@dataclass
class ModelSpec:
    # Flux model parameters
    params: FluxParams
    # Autoencoder parameters
    ae_params: AutoEncoderParams
    # Local checkpoint path (may be None)
    ckpt_path: str | None
    # Local autoencoder path (may be None)
    ae_path: str | None
    # Hugging Face repo ID (may be None)
    repo_id: str | None
    # Filename of the flow checkpoint within the repo (may be None)
    repo_flow: str | None
    # Filename of the autoencoder checkpoint within the repo (may be None)
    repo_ae: str | None

# Configuration dictionary mapping model names to their specs
configs = {
    # Spec for the "flux-dev" model
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",  # Hugging Face repo
        repo_flow="flux1-dev.safetensors",  # flow checkpoint filename
        repo_ae="ae.safetensors",  # autoencoder checkpoint filename
        ckpt_path=os.getenv("FLUX_DEV"),  # local checkpoint path from the environment
        # Flux model parameters
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),  # local autoencoder path from the environment
        # Autoencoder parameters
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    # Spec for the "flux-schnell" model; identical apart from the repo,
    # the FLUX_SCHNELL environment variable, and guidance_embed=False
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}

# Print warnings about missing or unexpected keys when loading a state dict
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    # If both missing and unexpected keys exist, print counts and lists for both
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    # If only missing keys exist, print their count and list
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    # If only unexpected keys exist, print their count and list
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


# Load the flow model by name, choosing the device and whether to download from HF
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Announce model initialization
    print("Init model")
    # Look up the checkpoint path in the config
    ckpt_path = configs[name].ckpt_path
    # If no local path is set and an HF download is possible and allowed
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        # Download the checkpoint from the Hugging Face Hub
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    # Build on the meta device when a checkpoint will be loaded, else on the target device
    with torch.device("meta" if ckpt_path is not None else device):
        # Instantiate the model in bfloat16
        model = Flux(configs[name].params).to(torch.bfloat16)

    # If a checkpoint is available, load its weights
    if ckpt_path is not None:
        print("Loading checkpoint")
        # Load the safetensors state dict directly onto the target device
        sd = load_sft(ckpt_path, device=str(device))
        # Load the state dict, collecting missing and unexpected keys
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    # Return the model
    return model
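
Building on the meta device allocates no weight memory, and assign=True then adopts the checkpoint tensors directly instead of copying them into (nonexistent) meta storage. A minimal demonstration of this pattern with a small module in place of Flux (assumes PyTorch 2.1+ for assign=True):

import torch
from torch import nn

with torch.device("meta"):
    layer = nn.Linear(4, 4)      # parameters exist but have no backing storage
print(layer.weight.device)       # meta

sd = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}  # stands in for load_sft(...)
layer.load_state_dict(sd, assign=True)  # adopt the real tensors in place
print(layer.weight.device)       # cpu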


# Load the T5 embedder on the given device with the given maximum sequence length
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # Build an HFEmbedder around T5-XXL in bfloat16
    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)


# Load the CLIP embedder on the given device
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    # Build an HFEmbedder around CLIP ViT-L/14 with a 77-token limit in bfloat16
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)


# Load the autoencoder by name, choosing the device and whether to download from HF
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    # Look up the autoencoder checkpoint path in the config
    ckpt_path = configs[name].ae_path
    # If no local path is set and an HF download is possible and allowed
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        # Download the autoencoder checkpoint from the Hugging Face Hub
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)

    # Announce autoencoder initialization
    print("Init AE")
    # Build on the meta device when a checkpoint will be loaded, else on the target device
    with torch.device("meta" if ckpt_path is not None else device):
        # Instantiate the autoencoder
        ae = AutoEncoder(configs[name].ae_params)

    # If a checkpoint is available, load its weights
    if ckpt_path is not None:
        # Load the safetensors state dict directly onto the target device
        sd = load_sft(ckpt_path, device=str(device))
        # Load the state dict, collecting missing and unexpected keys
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    # Return the autoencoder
    return ae


# Watermark embedder
class WatermarkEmbedder:
    def __init__(self, watermark):
        # Store the watermark bits and their count
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        # Build the watermark encoder
        self.encoder = WatermarkEncoder()
        # Register the watermark bits with the encoder
        self.encoder.set_watermark("bits", self.watermark)

    # Make the embedder callable: adds the predefined watermark to the input image
    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        # Map pixel values linearly from [-1, 1] to [0, 1]
        image = 0.5 * image + 0.5
        # Remember whether the input was 4-D, i.e. (B, C, H, W)
        squeeze = len(image.shape) == 4
        if squeeze:
            # If so, add a leading dimension to make it 5-D: (1, B, C, H, W)
            image = image[None, ...]
        # Leading batch dimension
        n = image.shape[0]
        # Convert to a numpy array, reshaping and flipping the channel order
        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) in [0, 255]
        # the watermarking library expects input in cv2's BGR format
        # Apply the watermark encoder to each image
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # Convert back to a torch tensor, restoring the original layout and device
        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
            image.device
        )
        # Normalize pixel values from [0, 255] back to [0, 1]
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            # Remove the leading dimension added earlier
            image = image[0]
        # Map pixel values from [0, 1] back to [-1, 1]
        image = 2 * image - 1
        # Return the watermarked image
        return image


# A fixed, randomly chosen 48-bit message
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] renders x as a binary string (without the '0b' prefix); each character becomes a 0/1 int
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
# Build the module-level WatermarkEmbedder from those bits
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)

.\flux\src\flux\__init__.py

# Try to import version and version_tuple from the package's _version module
try:
    from ._version import version as __version__  # type: ignore  # suppresses type-checker warnings
    from ._version import version_tuple
# If the import fails (module missing), fall back to default version info
except ImportError:
    __version__ = "unknown (no version information available)"  # fallback version string
    version_tuple = (0, 0, "unknown", "noinfo")  # fallback version tuple

# Import Path for filesystem path handling
from pathlib import Path

# Package name, with underscores replaced by hyphens
PACKAGE = __package__.replace("_", "-")
# Directory containing this file
PACKAGE_ROOT = Path(__file__).parent

.\flux\src\flux\__main__.py

# Import the app entry point from the cli module in the same package
from .cli import app

# If run as the main module, invoke app
if __name__ == "__main__":
    app()

From: https://www.cnblogs.com/apachecn/p/18398089
