首页 > 其他分享 >训练BlipForConditionalGeneration

训练BlipForConditionalGeneration

时间:2024-09-07 15:53:48浏览次数:11  
标签:训练 BlipForConditionalGeneration text image ids path input pixel


from transformers import BlipForConditionalGeneration, BlipProcessor, AutoTokenizer, AdamW
from PIL import Image
from datasets import load_dataset

processor = BlipProcessor.from_pretrained("huggingface.co/Salesforce/blip-image-captioning-base")
bertTokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

input_dataset = load_dataset(path="data", data_files="data.csv")

image_path_list = [Image.open(img_path) for img_path in input_dataset["train"]["image_path"]]

image_inputs = processor(image_path_list, return_tensors="pt")
text_inputs = bertTokenizer(input_dataset["train"]["caption"],
                                max_length=128,
                                padding="max_length",
                                truncation=True,
                                add_special_tokens=False,  
                                return_tensors="pt",
                                return_token_type_ids=False)

pixel_values = image_inputs["pixel_values"]
text_ids = text_inputs["input_ids"]
attention_mask = text_inputs["attention_mask"]

# 从零训练用 BlipForConditionalGeneration.from_config()
model = BlipForConditionalGeneration.from_pretrained("huggingface.co/Salesforce/blip-image-captioning-base")

learning_rate = 5e-5
epochs = 30
optimizer = AdamW(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(epochs):
        outputs = model.forward(pixel_values=pixel_values,
                                input_ids=text_ids,
                                attention_mask=attention_mask,
                                labels=text_ids)

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print(f"Epoch {epoch + 1}/{epochs} completed. Loss: {loss.item()}")

model.eval()

output_batch = model.generate(pixel_values=pixel_values, max_length=128)
for i in range(0, output_batch.shape[0]):
    caption = bertTokenizer.decode(output_batch[i], skip_special_tokens=True)
    print(caption)

注意输入的text_ids数据开头要有一个[PAD]/0数据,因为labels后续处理会把input_ids右移1位

标签:训练,BlipForConditionalGeneration,text,image,ids,path,input,pixel
From: https://blog.51cto.com/guotong1988/11945182

相关文章

  • P4649 [IOI2007] training 训练路径
    P4649[IOI2007]training训练路径题意:原题地址给你一棵\(n\)个节点的树,上面还有\(m-(n-1)\)条非树边,每条非树边有一个代价\(c_i\),要求你删掉若干条非树边使得之后的这棵树满足不存在任意一个长度为偶数的简单环。保证每个节点度数\(\le10\)。trick:如果树上不存在偶环......
  • 【保姆级教程】使用 PyTorch 自定义卷积神经网络(CNN) 实现图像分类、训练验证、预测全
    《博主简介》小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。......
  • 利用AI大语言模型和Langchain开发智能车算法训练知识库(上篇)
    今天小李哥将介绍亚马逊云科技的JupyterNotebook机器学习托管服务AmazonSageMaker上,通过AI大语言模型、向量知识库和LangChainAgent,创建用于AI智能车模型训练的RAG问答知识库。整个项目的架构图如下:本系列共分为上下两篇。在上篇内容中,我将分享该知识库的GitHub项目开源代......
  • 深入浅出孪生神经网络,高效训练模型
    大家好,在深度学习领域,神经网络几乎能处理各种任务,但通常需要依赖于海量数据来达到最佳效果。然而,对于像面部识别和签名验证这类任务,我们不可能总是有大量的数据可用。由此产生了一种新型的神经网络架构,称为孪生网络。孪生神经网络能够基于少量数据实现精准预测,本文将介绍孪生......
  • 代码随想录算法训练营第五十天 | 98. 所有可达路径
    目录98.所有可达路径思路图的存储邻接矩阵         邻接表深度优先搜索1.确认递归函数,参数2.确认终止条件3.处理目前搜索节点出发的路径方法一:邻接矩阵写法方法二:邻接表写法98.所有可达路径题目链接:卡码网题目链接(ACM模式)文章讲解:代码随想录 ......
  • 基于springboot武警警官学院训练场地管理系统的计算机毕设
    摘要随着互联网趋势的到来,各行各业都在考虑利用互联网将自己推广出去,最好方式就是建立自己的互联网系统,并对其进行维护和管理。在现实运用中,应用软件的工作规则和开发步骤,采用java技术建设训练场地管理系统。本毕业设计主要实现集人性化、高效率、便捷等优点于一身的训练场地管理......
  • 代码随想录算法训练营第十天| 232.用栈实现队列 、 225. 用队列实现栈 、20. 有效的括
    学习文章链接:代码随想录文章目录一、232.用栈实现队列二、225.用队列实现栈三、20.有效的括号四、1047.删除字符串中的所有相邻重复项一、232.用栈实现队列题目链接:232.用栈实现队列栈的操作:stack<int>s;s.empty();//如果栈为空则返回true,......
  • 腾讯:基于对话的LLM角色扮演训练框架
    ......
  • 【代码随想录训练营第42期 Day51打卡 - 岛屿问题 - 卡码网 99. 岛屿数量 100. 岛屿的
    目录一、做题心得二、题目与题解题目一:99.岛屿数量题目链接题解1:DFS 题解2:BFS 题目二:100.岛屿的最大面积题目链接题解:DFS 三、小结一、做题心得今天打卡的是经典的岛屿问题:分别从两个方向进行探讨--深搜(DFS)与广搜(BFS)。作为这两大基本搜索最经典的例题,今天......
  • BEVFormer复现(使用docker搭建训练环境)
    文章目录一、使用docker创建环境1.1创建容器1.2在容器中安装常用的包1.3安装miniconda1.4安装Pytorch二、环境配置2.1下载源码2.2安装mmcv-full2.3安装mmdet和mmseg2.4从源码安装mmdet3d2.5安装Detectron2和Timm2.6下载预训练模型三、数据准备3.1下载数据集......