Recently I came across a paper applying CLIP to my research area, so I played around with CLIP a bit; the results are quite good.
First, the zero-shot example code from the GitHub repo:
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
A brief note on the code: the model's encode_xxx methods compute features. They are essentially a forward pass through the corresponding encoder, with only some extra pre- and post-processing. The code above predicts which of CIFAR-100's 100 classes the image belongs to and prints the top-5 results.
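As a quick sanity check on what encode_image and encode_text return, here is a minimal sketch; the 512-dimensional embedding size noted in the comments is specific to ViT-B/32 (other CLIP variants use different widths):

with torch.no_grad():
    image_features = model.encode_image(image_input)  # shape: (1, 512), one row per image
    text_features = model.encode_text(text_inputs)    # shape: (100, 512), one row per class prompt
print(image_features.shape, text_features.shape)

Both encoders project into the same shared embedding space, which is what makes the image-text dot products below meaningful.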
Code 1:
with torch.no_grad():
    logits_per_image, logits_per_text = model(image_input, text_inputs)
prob = logits_per_image.softmax(-1)
Code 2:
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Normalize the features, then compute the scaled image-text similarities
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
prob = (100.0 * image_features @ text_features.T).softmax(dim=-1)
Code 1 and Code 2 do the same thing: both yield the final prediction, and prob comes out the same, since the model's forward pass multiplies the normalized features by its learned logit scale, which is roughly 100 for the released models. You can try it yourself.
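A minimal check of that equivalence, assuming the model, image_input, and text_inputs from the snippet above; model.logit_scale is the learned temperature that the forward pass applies internally, so using it (instead of the hard-coded 100.0) should reproduce Code 1 exactly:

with torch.no_grad():
    # Code 1: let the model's forward pass compute the logits
    logits_per_image, _ = model(image_input, text_inputs)
    prob1 = logits_per_image.softmax(dim=-1)

    # Code 2: compute the same logits from the features by hand
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    scale = model.logit_scale.exp()  # learned temperature, ~100 for released models
    prob2 = (scale * image_features @ text_features.T).softmax(dim=-1)

print(torch.allclose(prob1, prob2))  # expected: True (up to floating-point precision)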
Second, I tested the zero-shot performance of the ViT-B/32 CLIP model on the CIFAR-100 test set and got 61% accuracy. Of course, as the CLIP authors note near the end of the paper, the evaluation is not necessarily fully zero-shot for every popular dataset, but the performance is still very good, even if training on 0.4 billion image-text pairs feels a bit unfair.
import os
import clip
import torch
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"use device : {device}")
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset; train=False means we load the test set
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

accuracy = 0
for item in tqdm(cifar100):
    image, class_id = item
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, logits_per_text = model(image_input, text_inputs)
    temp_ans = logits_per_image.argmax().item()
    if temp_ans == class_id:
        accuracy += 1
accuracy /= 10000
print(accuracy)
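The loop above pushes one image at a time through the encoder, which wastes most of the GPU. A batched variant runs much faster; this is just a sketch under two assumptions not in the original: preprocess is passed to the dataset as a torchvision transform so the default collate can stack tensors, and batch_size=256 is an arbitrary choice. It also encodes the 100 text prompts once instead of on every step:

from torch.utils.data import DataLoader

# Re-create the test set with preprocess applied as a transform
cifar100_t = CIFAR100(root=os.path.expanduser("~/.cache"), download=True,
                      train=False, transform=preprocess)
loader = DataLoader(cifar100_t, batch_size=256)

# Encode and normalize the text prompts once, outside the loop
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)

correct = 0
for images, labels in tqdm(loader):
    with torch.no_grad():
        image_features = model.encode_image(images.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
    # argmax over classes; the softmax/temperature is monotonic, so it can be skipped
    preds = (image_features @ text_features.T).argmax(dim=-1).cpu()
    correct += (preds == labels).sum().item()

print(correct / len(cifar100_t))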