首页 > 其他分享 >CLIP模型代码

CLIP模型代码

时间:2023-10-01 18:23:11浏览次数:49  
标签:features CLIP 模型 text 代码 torch device model image

近期看到了一篇用CLIP在我这个方向应用的文章,所以玩了一下CLIP,感觉效果还是很好的。

 

首先,github上的zero-shot代码

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

这里稍微介绍一下,模型的model.encode_xxx方法是用来计算特征的,这个和前向传播没什么差别,唯一不同的是需要多一些处理操作,上面的代码主要做的事情就是预测图片属于100类中的哪一类,找出了top-5的结果。

代码1:

with torch.no_grad():
        logits_per_image, logits_per_text = model(image_input, text_inputs)
     prob = logits_per_image.softmax(-1)

代码2:

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
prob = (100.0 * image_features @ text_features.T).softmax(dim=-1)

代码1和代码2做的是一样的事情,都可以得到最后的预测,而且prob最后都是一样的,可以尝试一些。

 

第二个就是

我用cifar100测试集测试了一下VIT-B/32这个CLIP模型的zero-shot性能,最后是得到了61%的准确度,当然CLIP论文的作者在论文末尾也说了,未必对于所有目前流行的数据集都是完全zero-shot,不过它这个性能其实还是很不错的,虽然用0.4billion图片训练有点欺负人的意思。

import os
import clip
import torch
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"use device : {device}")
model, preprocess = clip.load('ViT-B/32', device)
# Download the dataset and the train=False that is mean we will download or load the test set
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
accuracy = 0
for item in tqdm(cifar100):
    image, class_id = item
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, logits_per_text = model(image_input, text_inputs)
    temp_ans = logits_per_image.argmax().item()
    if temp_ans == class_id:
        accuracy += 1

accuracy/=10000
print(accuracy)

 

 


标签:features,CLIP,模型,text,代码,torch,device,model,image
From: https://www.cnblogs.com/XY-Transfomerer/p/17739087.html

相关文章

  • adoc转换html+UPF低功耗仿真例子+python转换C代码+readmemh的@使用
    adoc转换htmladoc这种格式是很多riscv文档使用的格式,该格式可以生成pdf,生成html。生成html的好处是,选中和翻译方便,复制粘贴方便。首先是gem软件要安装,这个软件似乎是ruby相关的(RubyGemsisapackagemanagerfortheRubyprogramminglanguagethatprovidesastandardform......
  • 手把手教你在Ubuntu上部署中文LLAMA-2大模型
     一、前言 llama2作为目前最优秀的的开源大模型,相较于chatGPT,llama2占用的资源更少,推理过程更快,本文将借助llama.cpp工具在ubuntu(x86\ARM64)平台上搭建纯CPU运行的中文LLAMA2中文模型。二、准备工作 1、一个Ubuntu环境(本教程基于Ubuntu20LTS版操作) 2、确保你的环境可......
  • 使用 Gradle:将项目代码导入 IntelliJ
    1.将项目导入IntelliJ打开IntelliJ,如果还打开了其他程序,请关闭它们,再次进入欢迎屏幕。这次,不选择“创建新项目”,而是选择导入项目(ImportProject)。点击导入项目(ImportProject)后,会弹出一个窗口,提示你从某个文件夹导入项目。转到保存ud282-master的文件......
  • React 18 useEffect 代码执行两次的问题
    https://github.com/zjy4fun/notes/issues/62 React18提出的新特性“并发渲染”,为了防止组件重复挂载的问题,React在开发模式&&严格模式下,useEffect会执行两次(模拟组件挂载和组件卸载,让问题提早暴露),但是线上模式不会。开发模式下,可以通过设置标志位防止useEffect执行多......
  • 全新注意力算法PagedAttention:LLM吞吐量提高2-4倍,模型越大效果越好
    前言 吞吐量上不去有可能是内存背锅!无需修改模型架构,减少内存浪费就能提高吞吐量!本文转载自新智元仅用于学术分享,若侵权请联系删除欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。CV各大方向专栏与各个部署框架最全教程整理......
  • VCS代码保护+SOC中的复位电路+verdi生成部分原理图+verdi查看delta cycle+自定义的原
    VCS代码保护在新思公司的一些vip的实现中,一些代码进行了加密,导致无法查看源码,加密的方法也是使用新思的工具VCS。在编译的命令行添加+protect选项,在代码前后加上编译指示,则生成对应的加密vp、svp文件,中间的部分被加密。https://blog.csdn.net/woodhorse007/article/details/524......
  • 【8.0】Fastapi响应模型
    【一】自定义响应模型【1】定义视图函数fromfastapiimportAPIRouterfrompydanticimportBaseModel,EmailStrfromtypingimportOptionalapp04=APIRouter()###响应模型#定义基本类classUserBase(BaseModel):#定义字段username:用户名类型为str:......
  • Python代码转换成C++
    Python和C++是两种不同的编程语言,但它们都有各自的优势和适用场景。在某些情况下,我们可能需要将Python代码转换成C++代码,以获得更高的执行效率或更好的性能。本文将从多个方面介绍如何将Python代码转换为C++代码。一、代码结构Python和C++在代码结构上存在一些差异。Python是一种解......
  • 六种模型含义
    四种物料类型(宽料,窄料,厚料,小卷料)四种冷却方式(自然冷却,1号风机,2号风机,1号和2号风机)物料温度随时间而下降受到分类变量(物料类型,冷却方式)的影响拟合规律预测不同物料类型和冷却方式下的物料温度随时间的变化规律生成数据python机器学习六种方法解释每个......
  • 【机器学习 | 数据预处理】 提升模型性能,优化特征表达:数据标准化和归一化的数值处理技
    ......