首页 > 其他分享 >Generative AI 新世界 | 文生图领域动手实践:预训练模型的微调

Generative AI 新世界 | 文生图领域动手实践:预训练模型的微调

时间:2023-10-10 21:35:17浏览次数:42  
标签:训练 文生 模型 微调 AI train Generative model image

在上期文章,我们探讨了预训练模型的部署和推理,包括运行环境准备、角色权限配置、支持的主要推理参数、图像的压缩输出、提示工程 (Prompt Engineering)、反向提示 (Negative Prompting) 等内容。

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术,观点,和项目,并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏,看到这里请一定不要匆匆划过,点这里让它成为你的技术宝库!

本期文章,我们将探讨如何在自定义数据集上来微调(fine-tuned)模型,该模型可以针对任何图像数据集进行微调。即使你手上只有几张自定义的图像提供做训练,模型也能输出比较理想的结果。

首先,让我们通过一篇论文的概括解读,来了解这种文生图模型的微调 (fine-tuned),背后的工作原理和理论基础知识。

DreamBooth 论文概述

这种文生图模型的微调(fine-tuned)理论基础来自于 DreamBooth 论文,如下图所示:

image.png

DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-DrivenGeneration

https://arxiv.org/pdf/2208.12242.pdf?trk=cndc-detail

在论文的开头,作者提出一个挑战性的问题:

虽然当时的文生图模型已经可以根据给定的 **prompt **生成高质量的图片,但是这些模型并不能模仿给定参考图片中的物体要素,在不同情景中来生成新的图片。

举个例子。

我家里有一只叫做“小花”的可爱加菲猫,如下图:

image.png

我想让加菲猫“小花”带上一顶礼帽,如下图:

image.png

或者带上一副很酷炫的墨镜,如下图:

image.png

甚至想象下她刷牙的魔幻景象,如下图:

image.png

事实上,上面的这些加菲猫“小花”的照片(戴礼帽、戴墨镜、刷牙),都是大模型使用 DreamBooth 做微调后生成的。很有趣吧?在文末会提供生成这些魔幻照片的全部代码

我们先看下 DreamBooth 论文阐述的背后原理。

DreamBooth 论文提出一个新颖的方法:将输入图片中的物体与一个特殊标识符绑定在一起,即用这个特殊标记符来表示输入图片中的物体。因此论文为微调模型设计了一种 prompt 格式:a [identifier] [class noun],即将所有输入图片的 prompt 都设置成这种形式,其中 identifier 是一个与输入图片中物体相关联的特殊标记符,class noun 是对物体的类别描述。

这里之所以在 prompt 中加入类别,是因为想利用预训练模型中关于该类别物品的先验知识,并将先验知识与特殊标记符相关信息进行融合,这样就可以在不同场景下生成不同姿势的目标物体。

简单来说就是:不要学了新的知识,就忘了旧的知识

论文提出的方法,大致如下图所示,即仅仅通过 3 到 5 张图片去微调文生图模型,使得模型能将输入图片中特定的物品和 prompt 中的特殊标记符关联起来了。

image.png

Source: https://dreambooth.github.io\?trk=cndc-detail

关于特殊标记符的选择,论文提出通过在词表中选择罕见词来作为特殊标记符,这样避免了预训练模型对特殊标记符有很强烈的先验知识。

DreamBooth 论文提出一个新的微调方法:通过预先生成的一些图像,来保留先验损失权重;以此来解决过拟合与语言漂移问题。用模型自己生成的样本来监督模型,以便在 few-shot(小样本)微调开始后保留先验知识,如以下论文中提供的解释图所示:

image.png

Source: https://dreambooth.github.io/?trk=cndc-detail

给定大约 3-5 张拍摄对象的图像,我们分两步微调文本到图像的扩散:

  1. 使用输入图像与包含唯一标识符和主题所属类名称(例如:“A photo of a [T] dog”)的文本提示配对;同时,我们应用特定于类的预先保存损失,它利用了模型之前的语义通过在文本提示中注入类名,来鼓励它生成属于受试者类的各种实例提示(例如:“A photo of a dog”)。
  2. 使用从我们的输入图像集中拍摄的低分辨率和高分辨率图像,对超分辨率组件进行微调,这使我们能够保持对拍摄对象小细节的高保真度。

引入了先验损失的 loss 公式,如下所示:

image.png

通过这种 DreamBooth 方法,使得:输入训练集 + 提示词 [v] dog,然后还有用模型本身自己生成的 dog 图像,训练完成后得到了一个特殊标记符:[v]。通过这个特殊标记符 [v],就把这次训练的 dog 和其他本身学过的 dog 分开了。

最后得到惊艳的结果,比如给一只小熊带上太阳镜,如下图所示:

image.png

Source: https://dreambooth.github.io/?trk=cndc-detail

接下来,我们将完整用代码演示,如何给我家的加菲猫“小花”带上眼镜和礼帽。

Fine-tune 预训练模型在自有数据集上的微调

我们使用 Amazon SageMaker Studio 来实现在自有数据上的模型微调。

我首先将为我家的加菲猫“小花”拍摄几张照片,然后用这几张照片来微调模型;完成模型微调后,我们将使用 “a picture of Garfield cat with glasses” 这样的提示词,来直接为我家的加菲猫“小花”带上眼镜。

1 实例和环境准备

这个 Notebook 在带有 Python 3(Data Science)内核的 SageMaker Studio 中,使用 ml.t3.medium 实例上进行了测试。要对数据集的模型进行微调,您需要在账户中提供 ml.g4dn.2xlarge 实例类型。

完整的示例代码,可参考以下 GitHub 文档链接,从 “Fine-tune the pre-trained model on a custom dataset” 这一部分开始阅读代码:

https://github.com/aws/studio-lab-examples/blob/main/generative-deep-learning/stable-diffusion-finetune/Amazon_JumpStart_Text_To_Image.ipynb?trk=cndc-detail

你存放自定义照片的 s3 路径,应该看起来像这样:s3://bucket_name/input_directory/

请注意,后面的“/”为必填项。

以下是训练数据的示例格式:

input_directory
    |---instance_image_1.png
    |---instance_image_2.png
    |---instance_image_3.png
    |---instance_image_4.png
    |---instance_image_5.png
    |---dataset_info.json
    |---class_data_dir
        |---class_image_1.png
        |---class_image_2.png
        |---class_image_3.png
        |---class_image_4.png

 

预先保存、实例提示和类提示(Prior preservation, instance prompt and class prompt):预先保存是一种使用我们正在尝试训练的同一个类的其他图像的技术。例如,如果训练数据由特定狗的图像组成,并事先保存,则我们会合并普通犬的类别图像。它试图通过在为特定狗训练时显示不同狗的图像来避免过度拟合。类提示中缺少表示实例提示中存在的特定狗的标签。

例如,实例提示可能是 “A photo of a Garfield cat”,类提示可能是 “A photo of a cat”。

您可以通过将超参数设置为 _prior_preservation = True 来启用预先保存。

以下为使用我家加菲猫“小花”的照片的 dataset_info.json 的文件示例:

$ cat dataset_info.json
{
  "instance_prompt": "A photo of a Garfield cat",
  "class_prompt": "A photo of a cat"
}

  

以下是我为了微调模型,而拍摄的我家加菲猫“小花”的照片。我只用了下面这六张照片,就实现了模型的微调。

image.png

我存放照片(即为微调模型提供的自定义训练图片)的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/

其中 “sagemaker-us-east-1-xxxxxxxxxxxx” 需要更新为你自己定义的桶名。

最终完成微调后,模型存放的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output

其中 “sagemaker-us-east-1-xxxxxxxxxxxx” 需要更新为你自己定义的桶名。

2 检索训练数据的 Artifacts

在这里,我们检索训练 docker 容器、训练算法源和预先训练的基础模型。

请注意,model_version= “*” 获取的是最新的模型版本号。以下代码选择了 Stable Diffusion V2.1 Base 的文生图大模型。

# Select a model 
train_model_id, train_model_version, train_scope = (
    "model-txt2img-stabilityai-stable-diffusion-v2-1-base",
    "*",
    "training",
)

  

以下代码选择了微调模型的实例是 ml.g4dn.2xlarge:

training_instance_type = "ml.g4dn.2xlarge"

  

以下代码获取 Docker Image:

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type,
)

  

以下代码获取训练脚本:

# Retrieve the training script. This contains all the necessary files including data processing, model training etc.
train_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)

  

以下代码获取预训练模型的 tarball 包,用于之后的微调工作:

# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)
 

3 设置训练参数

现在我们已经完成了所有需要的设置,我们已经准备好微调 Stable Diffusion 模型了。首先,让我们创建一个 sageMaker.estimator.Estimator 对象。该 Estimator 将启动训练作业。

模型的微调训练需要设置两种参数。

第一组参数是训练作业的参数。其中包括:

  1. 训练数据路径,这是存储输入数据的 S3  路径。即之前我们准备的 “s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/” 这个路径;
  2.  输出路径,这是存储微调模型训练的输出 s3 路径。即之前我们准备的“s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output” 这个路径;
  3. 训练实例类型,这表示运行模型微调训练的机器类型。我们在上面定义了训练实例类型,以获取正确的 train_image_uri。

第二组参数是特定于算法的训练超参数。对于算法特定的超参数,我们首先获取算法接受的训练超参数的 python 字典及其默认值,然后可以将其改写为自定义值。示例代码如下所示:

from sagemaker import hyperparameters

# Retrieve the default hyper-parameters for fine-tuning the model
hyperparameters = hyperparameters.retrieve_default(
    model_id=train_model_id, model_version=train_model_version
)

# [Optional] Override default hyperparameters with custom values
hyperparameters["max_steps"] = "400"
print(hyperparameters)

4 启动模型微调训练

我们首先使用所有必需的 assets 创建 estimator 对象,然后启动训练作业。

from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.tuner import HyperparameterTuner

training_job_name = name_from_base(f"jumpstart-example-{train_model_id}-transfer-learning")

# Create SageMaker Estimator instance
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",  # Entry-point file in source_dir and present in train_source_uri.
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
)

# Launch a SageMaker Training job by passing s3 path of the training data
sd_estimator.fit({"training": training_dataset_s3_path}, logs=True)

  

模型训练开始后,如果观察 SageMaker 的控制台,会发现:

  1. 训练任务的状态,从 “InProgress” 逐渐变成 “Completed”;
  2. 超参调优的状态,从 “InProgress” 逐渐变成 “Completed”。

如下图所示:

image.png

image.png

image.png

经过那六张照片作为新的输入数据,微调后的模型重新训练完成后,就可以进入以下的模型部署阶段了。

5 微调后模型的部署

我们将遵循上一篇中介绍的模型部署的相同步骤,在训练好的模型上运行推理。我们首先检索用于部署端点的 jumpstart 工件。但是,我们部署的是经过微调的 sd_estimator 估算器,而不是上一篇中使用的 base_predictor 估算器。

inference_instance_type = "ml.g4dn.2xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

endpoint_name = name_from_base(f"jumpstart-example-FT-{train_model_id}-")

# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = sd_estimator.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)

  

在等待新模型部署的过程中,可以回到 SageMaker 的控制台,在 Endpoints 项中刷新检查模型部署的情况。当 Status 从 “Creating” 变成 “Completed”,就表示微调后的新模型已经部署完成可以开始进行推理了。如下图所示:

image.png

6 微调后模型的推理

下面进入激动人心的时刻,我们在微调后的模型上进行推理。

我输入的提示词是:“a photo of a Garfield cat with a hat”(一只带帽子的加菲猫)。

text = " a photo of a Garfield cat with a hat"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

  

模型的魔幻输出如下图所示。我们成功地给加菲猫“小花”带上礼帽了!

image.png

接着我们给加菲猫“小花”带上眼镜,我输入的提示词是:“a picture of Garfield cat with glasses”:

text = " a picture of Garfield cat with glasses"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

  

模型的输出如下:

image.png

最后让加菲猫“小花”像人类一样去刷牙,我输入的提示词是:“a picture of Garfield cat brushing her teeth”:

text = " a picture of Garfield cat brushing her teeth"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

  

image.png

神奇吧?加菲猫“小花”会自己刷牙了!

7 计算资源删除和清理

和以前一样,实验完成后别忘记清除相关的 endpoint 资源,以避免产生不必要的费用:

# Delete the SageMaker endpoint
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()

  

总结

本文我们学习了如何使用 Amazon SageMaker JumpStart 方便地微调文生图的 Stable Diffusion 模型。

Amazon SageMaker JumpStart 为预训练的模型提供了微调功能,本文的例子中,你只需使用六张训练图像即可根据自己的用例调整模型。这在创建个性化艺术品、独特的徽标、企业的 LOGO、或者其他需要自定义设计的场景时非常有用。

下一期的文章,我们将重新回到文本生成的大模型场景,探讨如何在 Amazon SageMaker JumpStart 上部署当今炙手可热的开源大语言模型。我们将以 Falcon 40B 开源大模型为例,逐行代码轻松部署高达 400 亿参数的这个大型语言模型。敬请期待。

请持续关注 Build On Cloud 专栏,了解更多面向开发者的技术分享和云开发动态!

 作者 黄浩文

亚马逊云科技资深开发者布道师,专注于 AI/ML、Data Science 等。拥有 20 多年电信、移动互联网以及云计算等行业架构设计、技术及创业管理等丰富经验,曾就职于 Microsoft、Sun Microsystems、中国电信等企业,专注为游戏、电商、媒体和广告等企业客户提供 AI/ML、数据分析和企业数字化转型等解决方案咨询服务。

 

文章来源:https://dev.amazoncloud.cn/column/article/64cb87265306fa4a7fa3a3c9?sc_medium=regulartraffic&sc_campaign=crossplatform&sc_channel=bokey

标签:训练,文生,模型,微调,AI,train,Generative,model,image
From: https://www.cnblogs.com/AmazonwebService/p/17755794.html

相关文章

  • ACK 云原生 AI 套件:云原生 AI 工程化落地最优路径
    作者:胡玉瑜(稚柳)前言在过去几年中,人工智能技术取得了突飞猛进的发展,涵盖了机器学习、深度学习和神经网络等关键技术的重大突破,这使得人工智能在各个领域都得到广泛应用,对各行各业产生了深远的影响。特别值得一提的是,近年来,ChatGPT的快速发展,使得人工智能技术在自然语言处理和......
  • SAP ABAP 域(domain)固定值读取方法
    1SELECTSINGLEVALPOS2FROMDD07V3INTO@DATA(GT_DD07V)4WHEREDOMNAME='ZSTUTYPE'ANDVALPOS=@P_ZSTUTYP."域名和值5IFSY-SUBRC<>0.6MESSAGETEXT-134TYPE'S'DISPLAYLIKE'E......
  • Qt学习随笔-3、QMainWindow
       1 QMainWindow   1.1 菜单栏最多只能有一个      1.1.1 创建菜单栏,通过QMainWindow类的menubar()函数获取主窗口菜单栏指针            QMenuBar*bar=MenuBar();      1.1.2 setMenuBar(bar);  ......
  • ansible报 MODULE FAILURE
    在使用ansibles 批量连接新升级的欧拉系统时候,报MODULEFAILURE原因:ansibles 默认的python 名字叫python,需要使用python3,;而欧拉的python链接到的是python2 解决办法: 先将python重命名, 再执行:ln -s /usr/bin/python3 /usr/bin/python将python链接到pyth......
  • Failed to find "GL/gl.h" in
     001、问题:Failedtofind"GL/gl.h"in 002、解决方法[root@pc1cmake-3.27.7-build]#yuminstallmesa-lib* 。 参考:https://www.jianshu.com/p/5eeb3dd51c08 ......
  • System.NotSupportedException:“无法显式设置 SplitterPanel 的高度。改在 SplitCont
    System.NotSupportedException:“无法显式设置SplitterPanel的高度。改在SplitContainer上设置SplitterDistance。”这个错误信息是在使用SplitContainer控件时出现的。它表明您尝试显式设置SplitterPanel的高度,但这是不支持的操作,应该在SplitContainer上设置Splitte......
  • 浅述安防视频可视化场景中TSINGSEE青犀AI智能化应用的分析
    随着社会的不断发展和安防需求的不断提升,安防视频可视化场景已经成为人们关注的焦点。而随着人工智能、大数据等技术的不断发展,智能化应用在安防视频可视化场景中的应用也越来越多。本文将分析安防视频可视化场景中的智能化应用,主要包括以下方面:背景介绍、智能化应用分析、关键技术......
  • 浅析森林烟火AI检测算法的应用及场景使用说明
    一、方案背景现有的森林防火监测系统落后,以人工地面巡护、瞭望塔高点巡查为主,存在巡护范围有限、巡护效率低等问题,建立健全的森林防火风险预警体系,实现对森林、林场等场景的全天候智能自动监测、火情预警,及时发现森林火灾并辅助决策,是当前林业管理的重要任务。二、方案概述旭帆......
  • DDD(Domain-Driven Design,领域驱动设计)
    DDD(Domain-DrivenDesign,领域驱动设计)是一种软件开发方法论,它注重对业务领域的深入理解,并将领域模型作为软件设计的核心。在DDD中,领域模型是通过对业务领域的分析和抽象而得到的,它是对业务领域中的概念、规则、行为等的描述。领域模型的设计是DDD中的一个重要环节,它需要开发团队......
  • 在hadoop虚拟机里面使用hadoop jar运行打包文件,出现Exception in thread "main" org.a
    问题描述更改了JDK版本之后,再次运行又出现了这个错误:问题解决经过查阅相关资料,发现是自己定义的hdfs的路径不太对,本来写的是这样的:然后自己确实不记得配置环境时配置的是多少,就看了看这个文件core.site.xml:catcore-site.xml然后看到这里:使用的端口号是8020,改成跟环境......