首页 > 其他分享 >dreambooth代码阅读

dreambooth代码阅读

时间:2024-08-13 18:48:44浏览次数:19  
标签:prompt dreambooth text 代码 args encoder 阅读 model class

网上dreambooth大部分只是对论文讲解,但代码讲解不是找不到就是收费,没办法,自己硬读,记录一下。
水平不高,学机器学习不久,可能有错,欢迎指正,仅做参考。

Dreambooth 流程简单来说是1,通过在现有的Diffusion模型增加一个你要的token,变成一个新的模型,比如你给特定一只sys狗的照片训练,你新生成的模型就有dog 的token 和 sys dog 的token。2,这时你就能用dog token 生成一只普通的狗,用sys dog 生成sys狗。前一部分是微调模型,后一部分是生成图片。
这里介绍很好,比项目的readme更细一点https://huggingface.co/docs/diffusers/main/en/training/dreambooth
环境搭好后
accelerate config 设置训练使用模式,特别是分布式训练,我对这个不太懂,停留在设好能跑就行

然后会让运行一个以下代码格式的脚本,主要是运行train_dreambooth.py和设置参数,
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="./dog"
export OUTPUT_DIR="path_to_saved_model"

accelerate launch train_dreambooth.py
--pretrained_model_name_or_path=$MODEL_NAME
--instance_data_dir=$INSTANCE_DIR
--output_dir=$OUTPUT_DIR
--instance_prompt="a photo of sks dog"
--resolution=512
--train_batch_size=1
--gradient_accumulation_steps=1
--learning_rate=5e-6
--lr_scheduler="constant"
--lr_warmup_steps=0
--max_train_steps=400
--push_to_hub

重点在train_dreambooth.py的代码
1400行代码先不管翻到最下面
if name == "main":
args = parse_args()
main(args)
就两行,先进行parse_args(),再main(),至少结构是清晰的
先看parse_args()函数
主要作用就是将脚本或者命令行的传参进行解析,保存到args中

  def parse_args(input_args=None):
  ##主要是将上面脚本的传参进行解析,
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
  #使用argparse 模块创建了一个解析命令行参数的对象,该对象名为 parser。
    parser.add_argument(#设置传入参数格式,名字,格式,帮助等
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )

......后面都是这样的,不细看用到再说
然后是一些环境配置和传参的报错

if input_args is not None:#不为空,使用提供的参数来解析命令行参数;
    args = parser.parse_args(input_args)
else:#为空,使用默认的方式来解析命令行参数。
    args = parser.parse_args()

env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
#从环境变量 LOCAL_RANK 中获取一个值并将其转换为整数类型。不存在为-1
#local_rank解释建议参考https://blog.csdn.net/shenjianhua005/article/details/127318594
if env_local_rank != -1 and env_local_rank != args.local_rank:
    args.local_rank = env_local_rank
#赋值

if args.with_prior_preservation:
    #传入参数有with_prior reservation loss 是否使用先置保全损失
    #没理解错的话是用来保证微调过程中dog 的token的语义不会偏移,就是依旧是普通的狗,不是特定的sys狗
    if args.class_data_dir is None:#报错
        #--class_data_dir:包含生成的类样本图像的文件夹的路径
        raise ValueError("You must specify a data directory for class images.")
    if args.class_prompt is None:
        #--class_prompt:描述生成的样本图像类别的文本提示
        raise ValueError("You must specify prompt for class images.")
else:
    #警告
    # logger is not available yet
    if args.class_data_dir is not None:
        warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
    if args.class_prompt is not None:
        warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
#  有必要提一嘴--instance_prompt="a photo of sks dog" \  包含示例图像的特殊词的文本提示
#   --class_prompt="a photo of dog" \  描述生成的样本图像类别的文本提示
#使用with_prior reservation loss 先置保全损失 需要class_prompt参数,不使用则不需要
#在我的理解,使用with_prior reservation 模型要生成普通的多样的狗的图片,便于保证dog token不会偏移
#class_data_dir是用来存放算法过程中生成的普通的狗的图片,往后看就知道了,是这样的我不会回来改



if args.train_text_encoder and args.pre_compute_text_embeddings:
    raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
#-train_text_encoder 是否也要训练文本编码器
# --pre_compute_text_embeddings  Whether or not to pre-compute text embeddings. 是否预先计算文本嵌入。不是很懂,先不管。
return args #返回处理完的参数 400行代码大部分是参数设置,松一口气,后面也这么简单就好

后面就开始看main函数
先是如果要将模型上传到库的报错和设置

  if args.report_to == "wandb" and args.hub_token is not None:
      #--report_to是 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
      #    ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
      #将结果和日志报告集成到的平台。支持的平台有"tensorboard"(默认)、"wandb"和"comet_ml"。使用"all"来报告到所有集成平台。
      #--hub_token 是"The token to use to push to the Model Hub.
      raise ValueError(
          "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
          " Please use `huggingface-cli login` to authenticate with the Hub."
      )
  #报错,由于存在泄露令牌的安全风险,您不能同时使用 --report_to=wandb 和 --hub_token。请使用 huggingface-cli login 通过 Hub 进行身份验证。
  #不太懂,我应该不需要上传,应该不重要

  logging_dir = Path(args.output_dir, args.logging_dir)
  #连接output_dir和“TensorBoard日志目录,有默认值,所以参数不影响执行  --output_dir:保存训练好的模型,
  #logging_dir所以是生成output_dir上传TensorBoard的地址

再是一些参数设置,有些我不太懂

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
#配置项目的各种参数,以便在后续的代码中使用这些配置。
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,#用于指定梯度累积的步数。
    mixed_precision=args.mixed_precision,#mixed_precision: 用于指定是否启用混合精度训练。
    log_with=args.report_to,#og_with: 用于指定日志的输出方式。
    project_config=accelerator_project_config,#project_config: 用于指定项目的配置信息。
)# 实例化Accelerator类,有关accelerator的可以看这个https://blog.csdn.net/qq_56591814/article/details/134200839      


# Disable AMP for MPS.
if torch.backends.mps.is_available():
    accelerator.native_amp = False
#检查了是否torch后端支持MPS(Multi-Process Service 不懂 跳过

if args.report_to == "wandb":#简单的警告
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
 #目前,使用 accelerate.accumulate 在训练两个模型时不支持梯度累积。
  #这个功能很快会在 accelerate 中启用。目前,当训练两个模型时,我们不允许梯度累积。
  #   TODO (patil-suraj): 当 accelerate 中允许在训练两个模型时进行梯度累积时,请移除这个检查
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
    raise ValueError(
        "Gradient accumulation is not supported when training the text encoder in distributed training. "
        "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
    )

日志设置

# Make one log on every process with the configuration for debugging.
#在每个进程上使用配置进行调试时进行一次日志记录。 具体不想看了gpt一搜就有
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

随机数seed

# If passed along, set the training seed now.
if args.seed is not None:
    set_seed(args.seed)
#设置accelerate.utils 随机数的种子,我觉得可以参考https://blog.csdn.net/qq_40206371/article/details/139466522

使用先验损失的话生成普通dog图片防止dog语义偏移

# Generate class images if prior preservation is enabled.
#使用了with_prior_preservation 生成生成类别图像。
if args.with_prior_preservation:
    class_images_dir = Path(args.class_data_dir)#dog
    if not class_images_dir.exists():#如果 class_images_dir 不存在,则创建该目录。
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))#cur_class_images 变量存储了 class_images_dir 目录中文件的数量。

    if cur_class_images < args.num_class_images:
        #加载模型
        torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
        if args.prior_generation_precision == "fp32":
            torch_dtype = torch.float32
        elif args.prior_generation_precision == "fp16":
            torch_dtype = torch.float16
        elif args.prior_generation_precision == "bf16":
            torch_dtype = torch.bfloat16
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            safety_checker=None,
            revision=args.revision,
            variant=args.variant,
        )
        pipeline.set_progress_bar_config(disable=True)
        # 用于禁用流水线的进度条

        num_new_images = args.num_class_images - cur_class_images#还需要生成的图片数量
        logger.info(f"Number of class images to sample: {num_new_images}.")#日志

        sample_dataset = PromptDataset(args.class_prompt, num_new_images)#生成数据集
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)#加载数据集

        sample_dataloader = accelerator.prepare(sample_dataloader)
        #使用了加速器(如GPU或TPU)来准备数据加载器sample_dataloader
        pipeline.to(accelerator.device)#模型或数据处理管道移动到加速设备

        for example in tqdm(
            sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
        ):#迭代和进度条
            images = pipeline(example["prompt"]).images#生成图片

            for i, image in enumerate(images):#进行hash命名保存
                hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
                image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                image.save(image_filename)

        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()#有gpu清空缓存
#生成dog图像结束

上传仓库设置

# Handle the repository creation
#对上传的仓库的设置,不了解,没细看
if accelerator.is_main_process:
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.push_to_hub:
        repo_id = create_repo(
            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
        ).repo_id

加载tokener

# Load the tokenizer
if args.tokenizer_name:#tokener名和model名不一样的话
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )

加载scheduler and models,unet

# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
#作者自己写的一个返回模型类别的函数,用于后面使用用class.xxx函数

# Load scheduler and models
#逆向扩散的Scheduler和encoder
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)

if model_has_vae(args):#自定义函数判断是否有vae 有则初始化自动编码器
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
    )
else:
    vae = None
#加载语义分割模型
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)

hook的设置,说实话不太懂

def unwrap_model(model):
    model = accelerator.unwrap_model(model)#解包模型,具体还是看https://blog.csdn.net/qq_56591814/article/details/134200839
    model = model._orig_mod if is_compiled_module(model) else model#检查模块是否是使用 torch.compile 进行编译的
    return model

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
#创建自定义保存和加载hooks,以便使 accelerator.save_state(...) 方法以一种良好的格式序列化
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        for model in models:
            sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"#不同情况目录名
            model.save_pretrained(os.path.join(output_dir, sub_dir))#模型保存到目录

            # make sure to pop weight so that corresponding model is not saved again
            #确保弹出权重,这样相应的模型就不会再次保存。
            weights.pop()

def load_model_hook(models, input_dir):
    while len(models) > 0:
        # pop models so that they are not loaded again
        model = models.pop()

        #使用isinstance检查模型类型以确定应该使用哪种方式加载模型。如果模型类型是text_encoder,
        # 则使用transformers风格加载模型,并更新模型的配置(config);如果模型类型不是text_encoder,则使用UNet2DConditionModel风格加载模型,并注册配置信息。
        if isinstance(model, type(unwrap_model(text_encoder))):
            # load transformers style into model
            load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
            model.config = load_model.config
        else:
            # load diffusers style into model
            load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
            model.register_to_config(**load_model.config)

        model.load_state_dict(load_model.state_dict())
        #将已加载的load_model的状态字典(包含模型的参数)应用到当前的model中。
        del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
#设置hook来获得运行时中间变量,里面的函数也应该是配置,对这个不太懂,有兴趣可以去accelerator库详细看看,先跳过

一些设置和报错

if vae is not None:
    vae.requires_grad_(False)#是否跟踪梯度

if not args.train_text_encoder:
    text_encoder.requires_grad_(False)

# xformers配置警告,md,上次就是这个b库和pytorch 版本不适配,搞死我了
#还有这东西不能放前面吗,每次跑那么久你才告诉我有问题
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            logger.warning(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

if args.gradient_checkpointing:#是否使用梯度检查点来节省内存,但以减慢反向传播的速度。
    unet.enable_gradient_checkpointing()
    if args.train_text_encoder:
        text_encoder.gradient_checkpointing_enable()

# Check that all trainable models are in full precision
low_precision_error_string = (
    "Please make sure to always have all model weights in full float32 precision when starting training - even if"
    " doing mixed precision training. copy of the weights should still be float32."
)
if args.gradient_checkpointing:#是否使用梯度检查点来节省内存,但以减慢反向传播的速度。
    unet.enable_gradient_checkpointing()
    if args.train_text_encoder:
        text_encoder.gradient_checkpointing_enable()

# Check that all trainable models are in full precision
low_precision_error_string = (
    "Please make sure to always have all model weights in full float32 precision when starting training - even if"
    " doing mixed precision training. copy of the weights should still be float32."
)

if unwrap_model(unet).dtype != torch.float32:
    raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")

if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
    raise ValueError(
        f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
    )
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:#学习率
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )

    optimizer_class = bnb.optim.AdamW8bit
else:
    optimizer_class = torch.optim.AdamW

创建优化器

# Optimizer creation
params_to_optimize = (
    itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
#根据条件args.train_text_encoder的值来选择性地将神经网络模型的参数传递给优化器。如果args.train_text_encoder为真,
# 那么params_to_optimize将使用itertools.chain(unet.parameters(), text_encoder.parameters())生成的参数组合。
# 否则,params_to_optimize将仅仅使用unet.parameters()。
optimizer = optimizer_class(
    params_to_optimize,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)#优化器对象

是否预先计算文本嵌入

if args.pre_compute_text_embeddings:#是否预先计算文本嵌入。      
   #https://segmentfault.com/a/1190000044075300

    def compute_text_embeddings(prompt):
        with torch.no_grad():#不会追踪梯度。
            text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
            #将输入的文本prompt编码,
            prompt_embeds = encode_prompt(
                text_encoder,
                text_inputs.input_ids,
                text_inputs.attention_mask,
                text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                #是否使用attention_mask进行text_encoder https://developer.baidu.com/article/details/3248780
            )#使用自定义函数,对输入的文本进行编码,得到编码后向量

        return prompt_embeds

    pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)#调用刚刚函数返回sys dog编码后向量
    validation_prompt_negative_prompt_embeds = compute_text_embeddings("")

    if args.validation_prompt is not None:#validation_prompt在验证过程中用于确认模型是否在学习的prompt
        validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
    else:
        validation_prompt_encoder_hidden_states = None

    if args.class_prompt is not None:
        pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
    else:
        pre_computed_class_prompt_encoder_hidden_states = None

    text_encoder = None
    tokenizer = None

    gc.collect()#强制进行垃圾回收
    torch.cuda.empty_cache()
else:#不预先计算文本嵌入
    pre_computed_encoder_hidden_states = None
    validation_prompt_encoder_hidden_states = None
    validation_prompt_negative_prompt_embeds = None
    pre_computed_class_prompt_encoder_hidden_states = None

dataset和datasetloder

# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
    instance_data_root=args.instance_data_dir,
    instance_prompt=args.instance_prompt,
    class_data_root=args.class_data_dir if args.with_prior_preservation else None,
    class_prompt=args.class_prompt,
    class_num=args.num_class_images,
    tokenizer=tokenizer,
    size=args.resolution,
    center_crop=args.center_crop,
    encoder_hidden_states=pre_computed_encoder_hidden_states,
    class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
    tokenizer_max_length=args.tokenizer_max_length,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)

其他配置

# Scheduler and math around the number of training steps.
#训练step设置
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)

# Prepare everything with our `accelerator`.
if args.train_text_encoder:
    unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, text_encoder, optimizer, train_dataloader, lr_scheduler
    )
else:
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
#对于混合精度训练,我们将所有不可训练的权重(例如,VAE、非 LORA 文本编码器和非 LORA UNet)转换为半精度,因为这些权重仅用于推理,保持全精度权重是不必要的。
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

# Move vae and text_encoder to device and cast to weight_dtype
if vae is not None:
    vae.to(accelerator.device, dtype=weight_dtype)

if not args.train_text_encoder and text_encoder is not None:
    text_encoder.to(accelerator.device, dtype=weight_dtype)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
#训练step,不太懂,有空可以详细看看
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    tracker_config = vars(copy.deepcopy(args))
    tracker_config.pop("validation_images")
    accelerator.init_trackers("dreambooth", config=tracker_config)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
logger.info(f"  Num Epochs = {args.num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:#是否应该从以前的检查点开始训练。
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)#不是latest提前文件名
    else:
        # Get the most recent checkpoint
        dirs = os.listdir(args.output_dir)#获取指定目录中的文件和目录列表
        dirs = [d for d in dirs if d.startswith("checkpoint")]#只保留以"checkpoint"开头的名称
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))#进行排序,排序的标准是按照"-"分割后的第二部分进行整数排序。
        path = dirs[-1] if len(dirs) > 0 else None#最后选择最大的文件夹作为路径

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None
        initial_global_step = 0
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])
        #这行代码通过破折号("-")对字符串变量 path 进行分割,并取分割后的第二部分(即索引为 1 的部分),然后将其转换为整数类型
        #如果 path 是一个形如 'model-100.pth' 的文件路径,那么这行代码将提取出全局步骤数 100

        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
else:#否应该从以前的检查点开始训练。
    initial_global_step = 0

progress_bar = tqdm(#创建了一个迭代器进度条。
    range(0, args.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

训练

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:#是否应该从以前的检查点开始训练。
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)#不是latest提前文件名
    else:
        # Get the most recent checkpoint
        dirs = os.listdir(args.output_dir)#获取指定目录中的文件和目录列表
        dirs = [d for d in dirs if d.startswith("checkpoint")]#只保留以"checkpoint"开头的名称
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))#进行排序,排序的标准是按照"-"分割后的第二部分进行整数排序。
        path = dirs[-1] if len(dirs) > 0 else None#最后选择最大的文件夹作为路径

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None
        initial_global_step = 0
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])
        #这行代码通过破折号("-")对字符串变量 path 进行分割,并取分割后的第二部分(即索引为 1 的部分),然后将其转换为整数类型
        #如果 path 是一个形如 'model-100.pth' 的文件路径,那么这行代码将提取出全局步骤数 100

        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
else:#否应该从以前的检查点开始训练。
    initial_global_step = 0

progress_bar = tqdm(#创建了一个迭代器进度条。
    range(0, args.max_train_steps),
    initial=initial_global_step,
    desc="Steps",
    # Only show the progress bar once on each machine.
    disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):#训练轮次和开始轮次
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    #使用 unet.train() 和 text_encoder.train() 将模型设置为训练模式。
    for step, batch in enumerate(train_dataloader):
        #加载sys
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

            if vae is not None:
                # Convert images to latent space
                #在图像上编码并将其转换为 latent space。
                model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
            else:
                model_input = pixel_values

            # Sample noise that we'll add to the model input
            #添加noise
            if args.offset_noise:
                noise = torch.randn_like(model_input) + 0.1 * torch.randn(
                    model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
                )
            else:
                noise = torch.randn_like(model_input)
            bsz, channels, height, width = model_input.shape#图片值

            # Sample a random timestep for each image
            #为每个图像随机采样一个时间步, 就是循环神经网络认为每个输入数据与前多少个陆续输入的数据有联系。
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
            )
            timesteps = timesteps.long()
            #这段代码首先使用PyTorch的torch.randint函数生成一个大小为bsz的随机整数张量timesteps,
            # 其中每个元素的取值范围在0到noise_scheduler.config.num_train_timesteps之间。这个过程在model_input.device所指定的设备上进行。
            #接着,将生成的随机整数张量timesteps转换为long类型。

            # Add noise to the model input according to the noise magnitude at each timestep
            #添加noise,是向前传播过程
            # (this is the forward diffusion process)
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

            # Get the text embedding for conditioning
            #为条件获取文本嵌入,就是将文本变为向量
            if args.pre_compute_text_embeddings:
                encoder_hidden_states = batch["input_ids"]
            else:
                encoder_hidden_states = encode_prompt(
                    text_encoder,
                    batch["input_ids"],
                    batch["attention_mask"],
                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                )

            if unwrap_model(unet).config.in_channels == channels * 2:#channels是两倍,noise也要拼接成两倍
                noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

            if args.class_labels_conditioning == "timesteps":
                class_labels = timesteps
            else:
                class_labels = None

            # Predict the noise residual预测噪音残差
            model_pred = unet(
                noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
            )[0]

            if model_pred.shape[1] == 6:
                model_pred, _ = torch.chunk(model_pred, 2, dim=1)

            # Get the target for loss depending on the prediction type
            #根据预测类型获取用于损失计算的目标值
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(model_input, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                #将噪音和模型预测分成两部分,并分别对每部分计算损失
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)
                # Compute prior loss计算先验损失
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

            # Compute instance loss
            if args.snr_gamma is None:#“在重新平衡损失时使用的信噪比加权 gamma。推荐值为 5.0。”
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            else:
                # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                # This is discussed in Section 4.2 of the same paper.
                #计算loss weight 作者进行了更改
                snr = compute_snr(noise_scheduler, timesteps)#计算了信噪比(SNR)
                base_weight = (
                    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                )

                if noise_scheduler.config.prediction_type == "v_prediction":
                    # Velocity objective needs to be floored to an SNR weight of one.
                    mse_loss_weights = base_weight + 1
                else:
                    # Epsilon and sample both use the same loss weights.
                    mse_loss_weights = base_weight
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                loss = loss.mean()

            if args.with_prior_preservation:
                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss

            accelerator.backward(loss)#对loss反向传播
            if accelerator.sync_gradients:
            #首先检查加速器是否需要同步梯度。如果需要,它将选择需要进行梯度裁剪的参数,
            # 将所有参数整合并进行梯度裁剪操作,确保梯度大小不会超过指定的max_grad_norm。
                params_to_clip = (
                    itertools.chain(unet.parameters(), text_encoder.parameters())
                    if args.train_text_encoder
                    else unet.parameters()
                )
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            #对模型参数进行一步优化器更新,根据计算得到的梯度大小调整参数
            lr_scheduler.step()
            #新优化器中的学习率。
            optimizer.zero_grad(set_to_none=args.set_grads_to_none)#梯度归零

        # Checks if the accelerator has performed an optimization step behind the scenes
        #检查加速器是否在后台执行了优化步骤
        if accelerator.sync_gradients:#检查加速器是否需要同步梯度。
            progress_bar.update(1)#通过进度条更新显示进度。
            global_step += 1#全局步数加一。

            if accelerator.is_main_process:#检查当前进程是否为主进程,以确保以下的操作只由主进程执行,避免重复操作。
                if global_step % args.checkpointing_steps == 0:#判断是否到了保存检查点的步数。如果是,就执行保存检查点的操作。
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if args.checkpoints_total_limit is not None:
                        checkpoints = os.listdir(args.output_dir)
                        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                        if len(checkpoints) >= args.checkpoints_total_limit:#检查是否超过了保存检查点的总限制。会删除旧的检查点,以确保不超过阈值个数。
                            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                            removing_checkpoints = checkpoints[0:num_to_remove]

                            logger.info(
                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                            )
                            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                            for removing_checkpoint in removing_checkpoints:
                                removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")
                    #保存当前模型状态为检查点,并记录保存路径

                images = []

                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                    #检查是否设置了验证提示,并判断是否到了执行验证的步数
                    images = log_validation(
                        unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
                        tokenizer,
                        unwrap_model(unet),
                        vae,
                        args,
                        accelerator,
                        weight_dtype,
                        global_step,
                        validation_prompt_encoder_hidden_states,
                        validation_prompt_negative_prompt_embeds,
                    )##执行模型的验证,并记录验证生成的图像。

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)#将损失值和学习率作为进度条的附加信息显示出来,
        accelerator.log(logs, step=global_step)

        if global_step >= args.max_train_steps:#如果全局步数(global_step)超过设定的最大训练步数(args.max_train_steps),则跳出循环,即结束训练过程。
            break

create the pipeline和上传

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()#确保所有进程都已经执行完毕,等待所有的线程进行同步。
if accelerator.is_main_process:
    #创建一个pipeline,并根据训练得到的模型配置其参数,最后将pipeline保存起来。
    pipeline_args = {}

    if text_encoder is not None:
        pipeline_args["text_encoder"] = unwrap_model(text_encoder)

    if args.skip_save_text_encoder:
        pipeline_args["text_encoder"] = None

    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unwrap_model(unet),
        revision=args.revision,
        variant=args.variant,
        **pipeline_args,
    )

    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
    scheduler_args = {}
   #创建一个空字典scheduler_args用于存储调度器参数。
    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type

        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"

        scheduler_args["variance_type"] = variance_type

    pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
    pipeline.save_pretrained(args.output_dir)
    #最后将训练好的pipeline保存到args.output_dir中

    if args.push_to_hub:
        save_model_card(
            repo_id,
            images=images,
            base_model=args.pretrained_model_name_or_path,
            train_text_encoder=args.train_text_encoder,
            prompt=args.instance_prompt,
            repo_folder=args.output_dir,
            pipeline=pipeline,
        )
        upload_folder(
            repo_id=repo_id,
            folder_path=args.output_dir,
            commit_message="End of training",
            ignore_patterns=["step_*", "epoch_*"],
        )

accelerator.end_training()

后面真的懒得详细写了
附上自定义函数

  def save_model_card(
      repo_id: str,
      images: list = None,
      base_model: str = None,
      train_text_encoder=False,
      prompt: str = None,
      repo_folder: str = None,
      pipeline: DiffusionPipeline = None,
  ):
      img_str = ""
      if images is not None:
          for i, image in enumerate(images):
              image.save(os.path.join(repo_folder, f"image_{i}.png"))
              img_str += f"![img_{i}](./image_{i}.png)\n"

      model_description = f"""
  # DreamBooth - {repo_id}

  This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
  You can find some example images in the following. \n
  {img_str}

  DreamBooth for the text encoder was enabled: {train_text_encoder}.
  """
      model_card = load_or_create_model_card(
          repo_id_or_path=repo_id,
          from_training=True,
          license="creativeml-openrail-m",
          base_model=base_model,
          prompt=prompt,
          model_description=model_description,
          inference=True,
      )

      tags = ["text-to-image", "dreambooth", "diffusers-training"]
      if isinstance(pipeline, StableDiffusionPipeline):
          tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
      else:
          tags.extend(["if", "if-diffusers"])
      model_card = populate_model_card(model_card, tags=tags)

      model_card.save(os.path.join(repo_folder, "README.md"))


  def log_validation(
      text_encoder,
      tokenizer,
      unet,
      vae,
      args,
      accelerator,
      weight_dtype,
      global_step,
      prompt_embeds,
      negative_prompt_embeds,
  ):#执行模型的验证,并记录验证生成的图像。
      logger.info(
          f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
          f" {args.validation_prompt}."
      )

      pipeline_args = {}

      if vae is not None:
          pipeline_args["vae"] = vae

      # create pipeline (note: unet and vae are loaded again in float32)
      pipeline = DiffusionPipeline.from_pretrained(
          args.pretrained_model_name_or_path,
          tokenizer=tokenizer,
          text_encoder=text_encoder,
          unet=unet,
          revision=args.revision,
          variant=args.variant,
          torch_dtype=weight_dtype,
          **pipeline_args,
      )

      # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
      scheduler_args = {}

      if "variance_type" in pipeline.scheduler.config:
          variance_type = pipeline.scheduler.config.variance_type

          if variance_type in ["learned", "learned_range"]:
              variance_type = "fixed_small"

          scheduler_args["variance_type"] = variance_type

      module = importlib.import_module("diffusers")
      scheduler_class = getattr(module, args.validation_scheduler)
      pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
      pipeline = pipeline.to(accelerator.device)
      pipeline.set_progress_bar_config(disable=True)

      if args.pre_compute_text_embeddings:
          pipeline_args = {
              "prompt_embeds": prompt_embeds,
              "negative_prompt_embeds": negative_prompt_embeds,
          }
      else:
          pipeline_args = {"prompt": args.validation_prompt}

      # run inference
      generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
      images = []
      if args.validation_images is None:
          for _ in range(args.num_validation_images):
              with torch.autocast("cuda"):
                  image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
              images.append(image)
      else:
          for image in args.validation_images:
              image = Image.open(image)
              image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
              images.append(image)

      for tracker in accelerator.trackers:
          if tracker.name == "tensorboard":
              np_images = np.stack([np.asarray(img) for img in images])
              tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
          if tracker.name == "wandb":
              tracker.log(
                  {
                      "validation": [
                          wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                      ]
                  }
              )

      del pipeline
      torch.cuda.empty_cache()

      return images


  def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):#返回模型类别
      text_encoder_config = PretrainedConfig.from_pretrained(
          pretrained_model_name_or_path,
          subfolder="text_encoder",
          revision=revision,
      )
      model_class = text_encoder_config.architectures[0]

      if model_class == "CLIPTextModel":
          from transformers import CLIPTextModel

          return CLIPTextModel
      elif model_class == "RobertaSeriesModelWithTransformation":
          from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

          return RobertaSeriesModelWithTransformation
      elif model_class == "T5EncoderModel":
          from transformers import T5EncoderModel

          return T5EncoderModel
      else:
          raise ValueError(f"{model_class} is not supported.")

还有

class DreamBoothDataset(Dataset):

  #A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
  #It pre-processes the images and the tokenizes prompts.
  
  #一个数据集,用于为调整模型进行细化的实例和类别图像准备提示。它对图像进行预处理并对提示进行标记化处理
  def __init__(
      self,
      instance_data_root,
      instance_prompt,
      tokenizer,
      class_data_root=None,
      class_prompt=None,
      class_num=None,
      size=512,
      center_crop=False,
      encoder_hidden_states=None,
      class_prompt_encoder_hidden_states=None,
      tokenizer_max_length=None,
  ):
      self.size = size
      self.center_crop = center_crop
      self.tokenizer = tokenizer
      self.encoder_hidden_states = encoder_hidden_states
      self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
      self.tokenizer_max_length = tokenizer_max_length

      self.instance_data_root = Path(instance_data_root)
      if not self.instance_data_root.exists():
          raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

      self.instance_images_path = list(Path(instance_data_root).iterdir())
      self.num_instance_images = len(self.instance_images_path)
      self.instance_prompt = instance_prompt
      self._length = self.num_instance_images

      if class_data_root is not None:
          self.class_data_root = Path(class_data_root)
          self.class_data_root.mkdir(parents=True, exist_ok=True)
          self.class_images_path = list(self.class_data_root.iterdir())
          if class_num is not None:
              self.num_class_images = min(len(self.class_images_path), class_num)
          else:
              self.num_class_images = len(self.class_images_path)
          self._length = max(self.num_class_images, self.num_instance_images)
          self.class_prompt = class_prompt
      else:
          self.class_data_root = None

      self.image_transforms = transforms.Compose(
          [
              transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
              transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
              transforms.ToTensor(),
              transforms.Normalize([0.5], [0.5]),
          ]
      )

  def __len__(self):
      return self._length

  def __getitem__(self, index):
      example = {}
      instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
      instance_image = exif_transpose(instance_image)

      if not instance_image.mode == "RGB":
          instance_image = instance_image.convert("RGB")
      example["instance_images"] = self.image_transforms(instance_image)

      if self.encoder_hidden_states is not None:
          example["instance_prompt_ids"] = self.encoder_hidden_states
      else:
          text_inputs = tokenize_prompt(
              self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
          )
          example["instance_prompt_ids"] = text_inputs.input_ids
          example["instance_attention_mask"] = text_inputs.attention_mask

      if self.class_data_root:
          class_image = Image.open(self.class_images_path[index % self.num_class_images])
          class_image = exif_transpose(class_image)

          if not class_image.mode == "RGB":
              class_image = class_image.convert("RGB")
          example["class_images"] = self.image_transforms(class_image)

          if self.class_prompt_encoder_hidden_states is not None:
              example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
          else:
              class_text_inputs = tokenize_prompt(
                  self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
              )
              example["class_prompt_ids"] = class_text_inputs.input_ids
              example["class_attention_mask"] = class_text_inputs.attention_mask

      return example


def collate_fn(examples, with_prior_preservation=False):
    #返回一个批处理字典对象,其中包含输入id、像素值以及(如果存在的话)注意力掩码。
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

        if has_attention_mask:
            attention_mask += [example["class_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }

    if has_attention_mask:
        attention_mask = torch.cat(attention_mask, dim=0)
        batch["attention_mask"] = attention_mask

    return batch


class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def model_has_vae(args):
    #字面意思,是否有vae
    config_file_name = Path("vae", AutoencoderKL.config_name).as_posix()
    if os.path.isdir(args.pretrained_model_name_or_path):
        config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
        return os.path.isfile(config_file_name)
    else:
        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
        return any(file.rfilename == config_file_name for file in files_in_repo)


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    #进行编码可以看看这个https://www.cnblogs.com/carolsun/p/16903276.html
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    return text_inputs


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    #对文本进行编码
    text_input_ids = input_ids.to(text_encoder.device)

    if text_encoder_use_attention_mask:
        attention_mask = attention_mask.to(text_encoder.device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
        return_dict=False,
    )
    promp  t_embeds = prompt_embeds[0]
    #提取向量
    return prompt_embeds

标签:prompt,dreambooth,text,代码,args,encoder,阅读,model,class
From: https://www.cnblogs.com/ltlearnweb/p/18355797

相关文章

  • 【原创】【深入浅出系列】之代码可读性
    这是“深入浅出系列”文章的第一篇,主要记录和分享程序设计的一些思想和方法论,如果读者觉得所有受用,还请“一键三连”,这是对我最大的鼓励。一、老生常谈,到底啥是可读性一句话:见名知其义。有人说好的代码必然有清晰完整的注释,我不否认;也有人说代码即注释,是代码简洁之道的最高境......
  • 代码随想录Day14
    226.翻转二叉树给你一棵二叉树的根节点root,翻转这棵二叉树,并返回其根节点。示例1:输入:root=[4,2,7,1,3,6,9]输出:[4,7,2,9,6,3,1]示例2:输入:root=[2,1,3]输出:[2,3,1]示例3:输入:root=[]输出:[]提示:树中节点数目范围在[0,100]内-100<=Node.val<=100......
  • 【SpringBoot+Vue】基于混合推荐算法的小说在线阅读平台
    【1】系统介绍随着互联网技术的发展和普及,网络文学已经成为人们日常生活中不可或缺的一部分。网络小说因其便捷的获取方式、丰富的题材选择以及个性化的阅读体验而受到广大读者的喜爱。然而,在海量的小说资源中,如何为每位读者提供高质量、个性化的阅读推荐,成为了在线阅读平......
  • 【笔记】从0开始的代码审计
    【笔记】从0开始的代码审计代码审计思路敏感函数回溯参数调用过程首先特别关注程序敏感函数点,如:SQL语句拼合处、call_user_func、eval、unserialize、HTTP_CLIENT_IP等然后回溯参数调用过程查看是否全部过滤或者过滤不全,如:程序可能开启magic_quotes_gpc(转义大部分符号),但是部......
  • word中插入代码块
    一、使用word原生功能......
  • WebSockets:原理、握手及代码实现
    1.WebSockets原理WebSockets是HTML5标准的一部分,旨在为Web应用提供全双工通信能力。与传统的HTTP请求不同,WebSockets连接一旦建立,就可以在客户端和服务器之间自由传输数据,而不再需要每次通信都重新建立连接。工作流程:建立连接:客户端通过HTTP协议发起WebSocket握......
  • 京东旋转验证码识别代码
    京东旋转验证码样例如下:现在京东更新了很多新图片,我们再次进行了大量数据标记,完成了这款验证码的更新。现在正确率可以达到95%左右。下边是这款验证码的识别代码:importbase64importrequestsimportdatetimeimportnumpyasnpfromioimportBytesIOfromPILimpo......
  • 代码随想录算法训练营第 42 天 |LeetCode 188.买卖股票的最佳时机IV LeetCode309.最佳
    代码随想录算法训练营Day42代码随想录算法训练营第42天|LeetCode188.买卖股票的最佳时机IVLeetCode309.最佳买卖股票时机含冷冻期LeetCode714.买卖股票的最佳时机含手续费目录代码随想录算法训练营前言LeetCode188.买卖股票的最佳时机IVLeetCode309.最佳买卖......
  • 摘要生成—通过摘要风格控制摘要的生成/抽取,原文阅读与理解:GEMINI: Controlling The S
    GEMINI:ControllingTheSentence-LevelSummaryStyleinAbstractiveTextSummarizationGEMINI:在抽象文本摘要中控制句子级摘要风格paper:https://arxiv.org/abs/2304.03548github:https://github.com/baoguangsheng/gemini本文介绍了一种自适应摘要抽取/生成方......