首页 > 其他分享 >DA-CLIP关于使用BLIP生成数据集的代码注释

DA-CLIP关于使用BLIP生成数据集的代码注释

时间:2024-03-21 22:01:56浏览次数:30  
标签:CLIP image DA BLIP device model generate size

背景:

BLIP:

DA-CLIP需要的目标:

 为了在混合的退化数据集上训练 DA-CLIP,作者使用引导式视觉语言框架 BLIP 为所有 HQ 图像生成描述。

从HQ图像生成的描述是准确的,不传递退化信息。 然后,我们可以直接将这些干净的标题、LQ 图像和相应的退化类型结合起来,构建图像-文本-退化类型对。 

 

BLIP开源deom

 上BLIP的GitHub开源,readme.md有colab的简易测试代码,直接点开。hugging face的 Web demo 无法使用。

链接:https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb

代码

# install requirements
import sys
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install transformers==4.15.0 timm==0.4.12 fairscale==0.4.4
    !git clone https://github.com/salesforce/BLIP
    %cd BLIP

加载测试图。 并进行预处理

from PIL import Image
import requests
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_demo_image(image_size,device):
    img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg' 
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')   

    w,h = raw_image.size
    display(raw_image.resize((w//5,h//5)))
    
    transform = transforms.Compose([
        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ]) 
    image = transform(raw_image).unsqueeze(0).to(device)   
    return image

 加载模型,进行生成image caption


image_size = 384
image = load_demo_image(image_size=image_size, device=device)

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
    
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

with torch.no_grad():
    # beam search
    caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5) 
    # nucleus sampling
    # caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5) 
    print('caption: '+caption[0])

model.generate 是在使用 BLIP(Bootstrapped Language Image Pretraining)模型进行图像描述(Image Captioning)任务时的一个方法。这个方法接收多个参数来控制生成图像描述的过程。下面是对您提供的代码中 model.generate 方法参数的解释:

  1. image: 这是要生成描述的输入图像。它应该是一个已经加载并转移到指定设备(如GPU)的张量。

  2. sample: 这是一个布尔值,用于选择生成策略。当 sample=False 时,使用贪婪解码(beam search),即每一步都选择最可能的下一个词。当 sample=True 时,使用采样方法,如核采样(nucleus sampling)。

  3. num_beams: 当使用beam search时,这个参数定义了beam的宽度。它影响解码过程中考虑的不同可能性的数量。较大的beam size可能会导致更多样化和流畅的描述,但也会增加计算成本。

  4. max_length: 这个参数设置了生成描述的最大长度(以词为单位)。如果生成的描述在达到最大长度之前结束,它将被截断。

  5. min_length: 这个参数设置了生成描述的最小长度。如果生成的描述在达到最小长度之前结束,解码过程将继续,直到满足最小长度要求

问题

  • 问题1:装依赖的时候时间较长,需要下载1个G多的timm依赖
  • 问题2:版本报错

 修改为transformers==4.16.0。

成功。提示重启会话。全部再次运行。

结果

 想要换成我的测试图片。由于读取的imgurl设置成只能读外部链接的图片。了解到图床这种东西。

简单搜了一下聚合图床 - 免费无限图片上传 (superbed.cn)

上传一张测试图

修改imgurl路径,重新运行。

哈哈挺有意思的。 

DA-CLIP内有相关py文件

Create dataset:

To generate clean captions with BLIP, we use the clip-interrogator tool. Install it with pip install clip-interrogator==0.6.0 and run:

python ../scripts/generate_captions.py

Then you will get daclip_train.csv and daclip_val.csv under the datasets/universal directory

pip install clip-interrogator==0.6.0  

运行../scripts/generate_captions.py

代码解析晚点再写吧。。 

标签:CLIP,image,DA,BLIP,device,model,generate,size
From: https://blog.csdn.net/m0_60350022/article/details/136918987

相关文章

  • cuda 内存模型
    cuda内存模型其实概括来说就是下面两张图双箭头代表可读可写,单箭头代表只读1.localmemory#include<iostream>#include"cuda_runtime.h"#include"device_launch_parameters.h"#defineBLOCK_SIZE256__global__voidtest_kernal(){ intarray[3]; floatvalu......
  • LeetCode刷题记录——day3
    1、https://leetcode.cn/problems/gas-station/submissions/514930619/?envType=study-plan-v2&envId=top-interview-150对于这个问题可以这样来考虑,将数据看作一个环,如果答案唯一,那么就意味着从任意一个节点开始寻找,最后都会得到同一个节点的答案,那么为何不直接从0节点开始呢?其......
  • Python利用Numpy和Pandas实现数据清洗
    利用Numpy和Pandas对数据进行清洗,包括去除重复记录、处理缺失值和异常值,实现代码如下:点击此处下载数据集#coding=utf-8#导入必要的库importpandasaspdimportnumpyasnp#导入数据及输出格式defread_data(filename):data=pd.read_csv(filename)......
  • Debezium vs OGG vs Tapdata:如何实时同步 Oracle 数据到 Kafka 消息队列?
    随着信息时代的蓬勃发展,企业对实时数据处理的需求逐渐成为推动业务创新和发展的重要驱动力。在这个快速变化的环境中,许多企业选择将Oracle数据库同步到Kafka,以满足日益增长的实时数据处理需求。本文将深入探讨这一趋势的背后原因,并通过一个真实的客户案例来强调实时性在业务场......
  • DAX:GROUPBY 嵌套聚合
    GROUPBY函数的作用是根据输入的表进行数据聚合,输入的表可以是表表达式,也就是说,GRUOPBY的参数可以是一个动态查询返回的表,也就是说GROUPBY函数主要用于嵌套聚合的情况。GROUPBY(<table>[,<groupBy_columnName>[,<groupBy_columnName>[,…]]][,<name>,<expression>[,<......
  • 面向报文的UDP(User Datagram Protocol,用户数据报协议)的一个重要特点
    与TCP(TransmissionControlProtocol,传输控制协议)不同,UDP是一种无连接的协议,它不会为数据建立和维护一个持续的连接。因此,UDP的数据传输方式是面向报文的,也就是说,它会把应用层交给它的报文作为一个整体发送出去,不会进行分割或合并。具体来说,当应用层数据交给UDP后,UDP会为其......
  • requests.post传的data如果是直接使用python dict封装,有些服务端接收不了这种数据类型
    平时在自己的php项目里,使用dict方式组装data,然后requests.post,一点问题都没有。但是调了后端一个java的微服务接口,结果就一直报错422: 最后问了一下开发,得到提示“python好像还有个毛病,python的json对象转字符串的时候,转出来的字符串不是标准json字符串,还要做个字符串处理,变成......
  • js——Date()怎么将获取北京时间的日期,精确到时分秒
    JavaScript的Date对象可以获取本地时间,但不直接支持时区。要获取特定时区的时间,你需要进行时区转换。以下是一个函数,用于获取特定时区(如“Asia/Shanghai”,即北京时间)的当前日期和时间,精确到秒。functiongetBeijingTime(){constbeijing=newDate().toLocaleString('en......
  • 本地搭建深度学习训练环境(配置conda环境 cuda pytorch...)
    目录简介Nvidia驱动和cudatoolKit简介首先我们要下载的东西包括:anaconda(虚拟环境管理)pycharm(代码项目编辑器)Nvidia驱动和cudatoolKitpytorch(最好使用wheel)其中,anaconda和pycharm的下载比较简单,这里不在赘述。主要讲解后两个:Nvidia驱动和cudatoolKitNvidia驱动是向......
  • 文献学习-21-DaFoEs:混合数据集以推广微创机器人手术中的视觉状态深度学习力估计
    DaFoEs:MixingDatasetsTowardstheGeneralizationofVision-StateDeep-LearningForceEstimationinMinimallyInvasiveRoboticSurgeryAuthors: MikelDeIturrateReyzabal,GraduateStudentMember,IEEE,MingcongChen,WeiHuang,SebastienOurselin,and......