首页 > 其他分享 >(6-3-03)CLIP模型训练与微调(3)训练模型+模型微调+调试运行

(6-3-03)CLIP模型训练与微调(3)训练模型+模型微调+调试运行

时间:2024-09-26 21:19:54浏览次数:3  
标签:argparse 微调 训练 模型 args parser hparams

6.3.4  训练模型

文件train.py是训练 CLIP 模型的主程序,首先根据命令行参数指定的模型名称加载相应的配置文件,然后创建一个 CLIPWrapper 模型实例,并根据命令行参数初始化数据模块。接着,使用 PyTorch Lightning 的 Trainer 对象进行训练。

import yaml
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from data.text_image_dm import TextImageDataModule
from models import CLIPWrapper


def main(hparams):
    config_dir = 'models/configs/ViT.yaml' if 'ViT' in hparams.model_name else 'models/configs/RN.yaml'
    with open(config_dir) as fin:
        config = yaml.safe_load(fin)[hparams.model_name]

    if hparams.minibatch_size < 1:
        hparams.minibatch_size = hparams.batch_size

    model = CLIPWrapper(hparams.model_name, config, hparams.minibatch_size)
    del hparams.model_name
    dm = TextImageDataModule.from_argparse_args(hparams)
    trainer = Trainer.from_argparse_args(hparams, precision=16, max_epochs=32)
    trainer.fit(model, dm)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--minibatch_size', type=int, default=0)
    parser = TextImageDataModule.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    main(args)

对上述代码的具体说明如下所示:

  1. 加载模型配置文件:根据模型名称确定加载 ViT.yaml 还是 RN.yaml 配置文件。
  2. 创建模型实例:使用 CLIPWrapper 类创建模型实例,传入模型名称、配置和最小批次大小。
  3. 初始化数据模块:使用 TextImageDataModule.from_argparse_args 根据命令行参数初始化数据模块。
  4. 设置训练器参数:使用 Trainer.from_argparse_args 根据命令行参数设置训练器,包括精度和最大训练周期。
  5. 开始训练:使用 trainer.fit 方法开始训练模型。

6.3.5  模型微调

文件train_finetune.py用于微调 CLIP 模型的主程序 train_finetune.py,首先加载预训练的图像编码器(ResNet-50)、文本编码器(DECLUTR-SCI-BASE)以及相应的 tokenizer,然后创建了一个 CustomCLIPWrapper 模型实例进行微调训练。

import torch
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from data.text_image_dm import TextImageDataModule
from models import CustomCLIPWrapper
from torchvision.models import resnet50
from transformers import AutoTokenizer, AutoModel

def main(hparams):
    img_encoder = resnet50(pretrained=True)
    img_encoder.fc = torch.nn.Linear(2048, 768)
    tokenizer = AutoTokenizer.from_pretrained("johngiorgi/declutr-sci-base")
    txt_encoder = AutoModel.from_pretrained("johngiorgi/declutr-sci-base")
    if hparams.minibatch_size < 1:
        hparams.minibatch_size = hparams.batch_size
    model = CustomCLIPWrapper(img_encoder, txt_encoder, hparams.minibatch_size, avg_word_embs=True)
    dm = TextImageDataModule.from_argparse_args(hparams, custom_tokenizer=tokenizer)
    trainer = Trainer.from_argparse_args(hparams, precision=16, max_epochs=32)
    trainer.fit(model, dm)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--minibatch_size', type=int, default=0)
    parser = TextImageDataModule.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    main(args)

对上述代码的具体说明如下所示:

  1. 加载预训练模型和 tokenizer:加载预训练的 ResNet-50 图像编码器和 DECLUTR-SCI-BASE 文本编码器,以及相应的 tokenizer。
  2. 修改图像编码器:将 ResNet-50 的全连接层替换为一个线性层,将输出维度调整为 768。
  3. 创建模型实例:使用 CustomCLIPWrapper 类创建模型实例,传入图像编码器、文本编码器和其他参数,如最小批次大小。
  4. 初始化数据模块:使用 TextImageDataModule.from_argparse_args 根据命令行参数初始化数据模块,同时传入自定义的 tokenizer。
  5. 设置训练器参数:使用 Trainer.from_argparse_args 根据命令行参数设置训练器,包括精度和最大训练周期。
  6. 开始微调训练:使用 trainer.fit 方法开始微调训练模型。

6.3.6  调试运行

根据自己的需要,大家可以按照如下三种方式训练文生图模型CLIP。

1. 全新训练

在训练文生图模型CLIP时可以直接使用项目中的配置信息,只需提供一个训练目录或自己的数据集即可。在训练时需要指定模型名称,并告诉训练文件夹和批量大小,所有可能的模型都可以在models/config目录下的yaml文件中找到。例如运行命令如下:

python train.py --model_name RN50 --folder data_dir --batchsize 512

2. 微调训练

为了更高效地进行CLIP训练,可以使用类CustomCLIPWrapper,这个类用于微调预训练的图像和语言模型,这样可以大大提高性能效率。要使用这个功能,只需修改train_finetune.py文件,传入一个图像编码器和Hugging Face文本编码器。

img_encoder = resnet50(pretrained=True)
img_encoder.fc = torch.nn.Linear(2048, 768)

tokenizer = AutoTokenizer.from_pretrained("johngiorgi/declutr-sci-base")
txt_encoder = AutoModel.from_pretrained("johngiorgi/declutr-sci-base")

model = CustomCLIPWrapper(img_encoder, txt_encoder, hparams.minibatch_size, avg_word_embs=True)

具体的命令行参数与之前一样,只是去掉了 --model_name 标志:

python train_finetune.py --folder data_dir --batchsize 512

3. 使用自己的DataModule进行训练

此时需要每个图像对具有相同的stem名称(即coco_img1.png和coco_img1.txt),你只需在运行时指定文件夹即可。任何子文件夹结构都将被忽略,这意味着foo/bar/image1.jpg将始终找到它的myster/folder/image1.txt,只要它们共享一个共同的父文件夹。所有图像后缀都可以使用,唯一的期望是标题由\n分隔。

4. 使用自己的数据进行训练

如果你有不同的训练需求,可以插入自己的DataLoader。首先注释掉项目中的DataModule,并将你自己的DataModule插入到 trainer.fit(model, your_data) 中,然后编辑train.py脚本以满足你的需求。唯一的期望是返回元组的第一项是图像批次,第二项是文本批次。

标签:argparse,微调,训练,模型,args,parser,hparams
From: https://blog.csdn.net/asd343442/article/details/142484505

相关文章

  • 风速预测(三)EMD-LSTM-Attention模型
    往期精彩内容:时序预测:LSTM、ARIMA、Holt-Winters、SARIMA模型的分析与比较全是干货|数据集、学习资料、建模资源分享!拒绝信息泄露!VMD滚动分解+Informer-BiLSTM并行预测模型-CSDN博客风速预测(一)数据集介绍和预处理_风速数据在哪里下载-CSDN博客风速预测(二)基于Pytorch......
  • 手把手教你建【货币】一题的网络流模型
    现在已知如下问题,并告诉你这题可以用网络流来解决,你该怎么做,该怎么建出网络流的模型?一些前提:显然可以发现绝不可能走横向向左的边,但可能走竖向向上的边(如下图)那么图其实就是这样的:问从\(s\)到\(t\)的最小花费如果没有那\(m\)条限制,我们直接跑最短路就行了,加上这些限制......
  • 如何让智能客服像真人一样对话?容联七陌揭秘:多Agent大模型
    科技云报到原创。经历了多年的“答非所问”、“一问三不知”,很多人已经厌倦了所谓的“智能客服”。哪怕是技术已经非常成熟、可以模拟真人发音的外呼机器人,也会因为“机感”重而被用户迅速挂机或转向人工客服。智能客服似乎遇到了一道坎,在理解用户、和用户对话方面,始终无法实现真正......
  • MiniMax、商汤科技、面壁智能、西湖心辰、声网都来了!RTE 大会「实时互动和大模型」专
       当大模型进化到实时多模态,将诞生什么样的新场景和玩法? VoiceAI实现human-like的最后一步是什么? AI视频爆炸增长,新一代编解码技术将面临何种挑战? 所有AIInfra都在探寻规格和性能的最佳平衡,如何构建高可用的云边端协同架构? AI加持下,空间计算和新......
  • 从“可用”到“好用”,百度智能云如何做大模型的“超级工厂”?
    如果说,过去两三年大模型处于造锤子阶段,那么今年,更多的则是考验钉钉子的能力,面对各类业务场景大模型是否能够有的放矢、一击必中,为千行百业深度赋能。当前市场上,已经有200多把这样的锤子在疯狂找钉子。但从实际应用来看,大模型在文生文、文生图以及扮演初级的工作助理等方面还算合格,......
  • 豆包通用模型Pro:字节跳动的AI革新,引领多模态交互新纪元
    在人工智能技术的快速发展浪潮中,字节跳动凭借其最新的豆包通用模型Pro,再次站在了技术创新的前沿。豆包通用模型Pro不仅在技术上取得了显著的突破,更在实际应用中展现了其强大的多模态交互能力,为内容创作和用户交互提供了全新的解决方案。技术突破:豆包通用模型Pro的核心优势豆包通用......
  • Windows如何本地部署llamafile并运行千问7b大模型无需安装运行环境或依赖库
    文章目录前言1.下载llamafile2.下载大语言模型3.运行大语言模型4.安装Cpolar工具5.配置远程访问地址6.远程访问对话界面7.固定远程访问地址前言本文主要介绍在Windows系统电脑如何利用llamafile结合cpolar内网穿透工具,实现随时随地远程访问本地大语言模型的......
  • 代码随想录算法训练营第一天| 704. 二分查找、27. 移除元素、977.有序数组的平方。
    704.二分查找总结:防止int溢出:classSolution{public:intsearch(vector<int>&nums,inttarget){intleft=0;intright=nums.size()-1;while(left<=right){intmiddle=(left+right)/2;//intmid=(right-left)/......
  • 袋鼠云数据资产平台:数据模型标准化建表重构升级
    想要建立一个良好的数据模型,设计时需要优先考虑数据的关系,避免出现数据冗余和不一致的问题,减少数据维护的难度。正是基于这样的需求,袋鼠云数据资产平台中的数据模型提供了一种建表的能力,可以对表名、字段名等信息进行约束,并且支持批量解析模式(根据中文名批量解析字段)与建表语句模式......
  • 深入理解并发原子性、可见性、有序性与JMM内存模型
    1.并发三大特性并发编程Bug的源头:原子性、可见性和有序性问题1.1原子性一个或多个操作,要么全部执行且在执行过程中不被任何因素打断,要么全部不执行。在Java中,对基本数据类型的变量的读取和赋值操作是原子性操作(64位处理器)。不采取任何的原子性保障措施的自增操作并不是......