1、Related_functions.py
"""Helpers for content-based image retrieval with VGG19 features.

Feature vectors are taken from a VGG19 network whose classifier has been
truncated, and images are ranked by Euclidean distance between vectors.
"""
import functools
import os
import warnings

import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms

# NOTE(review): this silences *every* warning category process-wide as soon as
# the module is imported (presumably to hide torchvision's `pretrained=`
# deprecation notice). Kept for compatibility; consider narrowing it.
warnings.filterwarnings("ignore", category=Warning)

# Preprocessing applied to every image: fixed 224x224 resize, conversion to a
# float tensor, and per-channel normalisation with the ImageNet mean/std.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


@functools.lru_cache(maxsize=1)
def _get_vgg_model():
    """Build the truncated VGG19 feature extractor once and reuse it.

    Loading pretrained weights is expensive, so the model is created on the
    first call and cached for the lifetime of the process.
    """
    vgg_model = models.vgg19(pretrained=True)
    # Keep only the first six classifier layers. In torchvision's VGG19 this
    # drops the final class-score Linear layer, so the output is the
    # penultimate fully-connected activation instead of class logits.
    vgg_model.classifier = torch.nn.Sequential(*list(vgg_model.classifier)[:6])
    return vgg_model.eval()


def get_feature(image_dir):
    """Return the VGG19 feature vector of one image as a list of floats.

    Args:
        image_dir: path to an image file readable by PIL; it is converted to RGB.

    Returns:
        The truncated network's output for this single image (batch dimension
        removed), as a plain Python list.
    """
    with Image.open(image_dir) as img:
        # convert() returns a new, fully loaded image, so the file can be closed.
        tensor = _TRANSFORM(img.convert('RGB'))
    batch = tensor.unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():  # inference only: no autograd graph needed
        output = _get_vgg_model()(batch)
    return output[0].numpy().tolist()


def get_img_feature(img_dir):
    """Return the feature vector of the image at ``img_dir`` (see get_feature)."""
    return get_feature(img_dir)


def get_Datasets_feature(Datasets_dir):
    """Extract features for every image in a two-level dataset directory.

    The layout is ``Datasets_dir/<subfolder>/<image file>``. Entries are visited
    in sorted order so the outputs are reproducible between runs. Top-level
    entries that are not directories, and second-level entries that are not
    files, are skipped.

    Args:
        Datasets_dir: root directory of the dataset. A trailing '/' is added
            if missing.

    Returns:
        ``(all_feature, paths)``: two parallel lists, where ``all_feature[i]`` is
        the feature vector of the image at ``paths[i]``.

    Raises:
        OSError: if ``Datasets_dir`` cannot be listed (a hint is printed first).
    """
    if not Datasets_dir.endswith('/'):
        Datasets_dir = Datasets_dir + '/'
    try:
        class_dirs = sorted(os.listdir(Datasets_dir))
    except OSError:
        print('请检查数据库的路径')
        raise
    all_feature = []
    paths = []
    for fi in class_dirs:
        img_paths = Datasets_dir + fi + '/'
        if not os.path.isdir(img_paths):
            continue  # stray file at the top level (e.g. OS metadata)
        for fj in sorted(os.listdir(img_paths)):
            img_dir = img_paths + fj
            if not os.path.isfile(img_dir):
                continue
            all_feature.append(get_img_feature(img_dir))
            paths.append(img_dir)
            print('正在提特征的图像是:', img_dir)
    return all_feature, paths


def calEuclideanDistance(x, y):
    """Return the Euclidean (L2) distance between two equal-length vectors.

    Raises:
        ValueError: if ``x`` and ``y`` have different shapes. The previous
            zip-based version silently truncated to the shorter input.
    """
    a = np.asarray(x, dtype=float)
    b = np.asarray(y, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return np.sqrt(np.sum((a - b) ** 2))


def query_sim_img(query_img_feature, Datasets_features, Datasets_paths, top_num):
    """Return the paths of the ``top_num`` dataset images closest to the query.

    Args:
        query_img_feature: 1-D feature vector of the query image.
        Datasets_features: 2-D array-like with one feature vector per row.
        Datasets_paths: sequence of paths aligned with the rows of
            ``Datasets_features``.
        top_num: how many paths to return. Standard slice semantics apply.

    Returns:
        A list of paths ordered by increasing Euclidean distance. Ties keep
        their dataset order.

    Raises:
        ValueError: if the query length differs from the feature row length.
    """
    features = np.asarray(Datasets_features, dtype=float)
    if features.shape[0] == 0:
        return []
    query = np.asarray(query_img_feature, dtype=float)
    if features.ndim != 2 or query.shape != features.shape[1:]:
        raise ValueError(
            f"query shape {query.shape} incompatible with dataset shape {features.shape}")
    # One vectorised pass over all rows instead of a Python loop per image.
    dists = np.sqrt(np.sum((features - query) ** 2, axis=1))
    order = np.argsort(dists, kind='stable')
    return [Datasets_paths[index] for index in order[:top_num]]
2、img_dataset_feature.py
# -*- coding:utf-8 -*-
"""Extract VGG19 features for every image in a dataset and save them as .npy.

Usage:
    python img_dataset_feature.py [DATASETS_DIR] [--features-out F] [--paths-out P]

Running with no arguments uses the original hard-coded dataset location and
output file names.
"""
import argparse

import numpy as np

# Original hard-coded dataset location, kept as the default so that a bare
# invocation behaves exactly as before.
DEFAULT_DATASETS_DIR = 'D:/My_work/python_code/03_lianxi/C00304/Animals_with_Attributes2/JPEGImages'


def parse_args(argv=None):
    """Parse command-line options. ``argv=None`` means ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description='Extract VGG19 features for a dataset and save them as .npy files.')
    parser.add_argument('datasets_dir', nargs='?', default=DEFAULT_DATASETS_DIR,
                        help='dataset root laid out as <root>/<subfolder>/<image>')
    parser.add_argument('--features-out', default='Datasets_features.npy',
                        help='output file for the feature matrix')
    parser.add_argument('--paths-out', default='Datasets_paths.npy',
                        help='output file for the image paths (row-aligned with features)')
    return parser.parse_args(argv)


def main(argv=None):
    """Extract features for the chosen dataset and write both .npy files."""
    args = parse_args(argv)
    # Imported here rather than at the top so `--help` works without loading
    # torch/torchvision.
    from Related_functions import get_Datasets_feature
    features, paths = get_Datasets_feature(args.datasets_dir)
    np.save(args.features_out, features)
    np.save(args.paths_out, paths)


if __name__ == '__main__':
    main()
标签:检索,paths,Datasets,img,feature,实践,im,图像,dir From: https://www.cnblogs.com/wjjcjj/p/18233077