1. 前言
最近在复现MCNN时发现一个问题:ShanghaiTech数据集中的图片尺寸不一,转换为tensor后的shape不一致,因此在batch_size大于1时无法直接用DataLoader加载数据。查阅资料后发现,可以自定义DataLoader的collate_fn函数,在组batch时把同一个batch内的图片(及其密度图)随机裁剪到该batch中最小的高和宽,这样就能把它们堆叠成一个batch。
2. 代码
2.1 数据集的定义
dataset.py
import os

import cv2
import numpy as np
import scipy.io as sio
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class myDatasets(Dataset):
    """Crowd-counting dataset yielding ``(image, density_map)`` pairs.

    Every image ``IMG_15.jpg`` in ``img_path`` is paired with the annotation
    file ``GT_IMG_15.mat`` in ``ann_path``, so only the image directory needs
    to be listed.

    Args:
        img_path: Directory containing the ``.jpg`` images.
        ann_path: Directory containing the ``GT_*.mat`` annotation files.
        down_sample: When true, the density map is built at 1/4 of the image
            resolution (both axes, integer division).
        transform: Optional callable applied to the float32 grayscale image
            array before it is returned.
    """

    def __init__(self, img_path, ann_path, down_sample=False, transform=None):
        self.pre_img_path = img_path
        self.pre_ann_path = ann_path
        # Keep only the images and sort them: os.listdir() returns entries in
        # arbitrary order and may include stray files (Thumbs.db, .DS_Store)
        # that have no matching annotation.
        self.img_names = sorted(
            name for name in os.listdir(img_path)
            if name.lower().endswith('.jpg')
        )
        self.transform = transform
        self.down_sample = down_sample

    def __len__(self):
        """Return the number of images found in the image directory."""
        return len(self.img_names)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(img, density_map)``. ``img`` is the grayscale image as
            float32 (passed through ``transform`` when one is set) and
            ``density_map`` is a 2-D float32 ``torch.Tensor``.
        """
        img_name = self.img_names[index]
        # Swap only the extension; str.replace('jpg', 'mat') would also
        # rewrite a 'jpg' occurring elsewhere in the file name.
        mat_name = 'GT_' + os.path.splitext(img_name)[0] + '.mat'

        with Image.open(os.path.join(self.pre_img_path, img_name)) as pil_img:
            img = np.array(pil_img.convert('L')).astype(np.float32)
        if self.transform is not None:
            img = self.transform(img)

        # The last two axes are (height, width) both for the raw (H, W) array
        # and for a (C, H, W) tensor (presumably what the transform returns,
        # e.g. torchvision's ToTensor), so no permute is needed here.
        h, w = img.shape[-2], img.shape[-1]

        # os.path.join works whether or not ann_path ends with a separator.
        anno = sio.loadmat(os.path.join(self.pre_ann_path, mat_name))
        # Presumably an (N, 2) array of (x, y) head coordinates, per the
        # original author's note -- verify against the .mat layout.
        xy = anno['image_info'][0][0][0][0][0]

        density_map = self.get_density((h, w), xy).astype(np.float32)
        density_map = torch.from_numpy(density_map)
        return img, density_map

    def get_density(self, img_shape, points):
        """Build a density map by stamping a Gaussian kernel on every point.

        Args:
            img_shape: ``(height, width)`` of the image.
            points: Iterable of ``(x, y)`` coordinates in image pixels.

        Returns:
            A 2-D ``numpy`` array of shape ``(height, width)``, or
            ``(height // 4, width // 4)`` when ``self.down_sample`` is set.
        """
        scale = 4 if self.down_sample else 1
        h, w = img_shape[0] // scale, img_shape[1] // scale

        labels = np.zeros(shape=(h, w))
        f_sz = 15     # kernel size, i.e. the neighbourhood stamped per head
        sigma = 4.0
        # The untruncated kernel is identical for every point: build it once.
        full_kernel = self.fspecial(f_sz, f_sz, sigma)

        for loc in points:
            # Head position on the (possibly down-sampled) grid, clamped to
            # [0, w] x [0, h].
            x = min(max(0, abs(int(loc[0] / scale))), int(w))
            y = min(max(0, abs(int(loc[1] / scale))), int(h))

            x1 = x - f_sz / 2
            y1 = y - f_sz / 2
            x2 = x + f_sz / 2
            y2 = y + f_sz / 2

            clipped = False
            if x1 < 0:
                x1 = 0
                clipped = True
            if y1 < 0:
                y1 = 0
                clipped = True
            if x2 > w:
                x2 = w - 1
                clipped = True
            if y2 > h:
                y2 = h - 1
                clipped = True

            row_start, row_stop = int(y1), int(y2)
            col_start, col_stop = int(x1), int(x2)
            n_rows = row_stop - row_start
            n_cols = col_stop - col_start
            if n_rows <= 0 or n_cols <= 0:
                # Nothing left of the window (degenerate, tiny map).
                continue

            if clipped:
                # A window cut by the border gets a fresh kernel of exactly
                # the window's size. Sizing it from the slice itself keeps
                # kernel and slice in agreement even when the window is
                # clipped on two opposite sides.
                kernel = self.fspecial(n_rows, n_cols, sigma)
            else:
                kernel = full_kernel

            labels[row_start:row_stop, col_start:col_stop] += kernel
        return labels

    def fspecial(self, ksize_x=5, ksize_y=5, sigma=4):
        """Return a 2-D Gaussian kernel of shape ``(ksize_x, ksize_y)``.

        The kernel is the outer product of two 1-D kernels from
        ``cv2.getGaussianKernel``; ``ksize_x`` therefore sets the number of
        rows and ``ksize_y`` the number of columns.
        """
        kx = cv2.getGaussianKernel(ksize_x, sigma)
        ky = cv2.getGaussianKernel(ksize_y, sigma)
        return np.multiply(kx, np.transpose(ky))
2.2 使用
demo.py
import os
import random
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

from config import get_args
from dataset import myDatasets
from debug_utils import ModelVerbose
from model import MCNN
from utils import get_mse_mae, show

args = get_args()

if args.dataset == 'ShanghaiTechA':
    # os.path.join picks the right separator on Windows and Linux alike; the
    # final '' keeps the trailing separator the hand-built paths used to have.
    train_imgs_path = os.path.join(args.dataset_path, 'train_data', 'images', '')
    train_labels_path = os.path.join(args.dataset_path, 'train_data', 'ground-truth', '')
    test_imgs_path = os.path.join(args.dataset_path, 'test_data', 'images', '')
    test_labels_path = os.path.join(args.dataset_path, 'test_data', 'ground-truth', '')
else:
    raise Exception(F'Dataset {args.dataset} Not Implement')

# Dataset transform
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])


def get_min_size(batch):
    """Return ``(min_height, min_width)`` over a batch of (C, H, W) images.

    Height and width are minimised independently. An empty batch yields
    ``(inf, inf)``.
    """
    min_ht = min((img.shape[1] for img in batch), default=float('inf'))
    min_wd = min((img.shape[2] for img in batch), default=float('inf'))
    return min_ht, min_wd


def _random_window(h, w, size):
    """Pick a random ``size = (height, width)`` window inside an h x w grid.

    Returns a ``(row_slice, col_slice)`` pair.
    """
    h_start = random.randint(0, h - size[0])
    w_start = random.randint(0, w - size[1])
    return (slice(h_start, h_start + size[0]),
            slice(w_start, w_start + size[1]))


def random_crop_img(img, size):
    """Randomly crop a (C, H, W) image to ``size = (height, width)``."""
    _, h, w = img.shape
    rows, cols = _random_window(h, w, size)
    return img[:, rows, cols]


def random_crop_dtmap(dt_map, size):
    """Randomly crop an (H, W) density map to ``size = (height, width)``."""
    h, w = dt_map.shape
    rows, cols = _random_window(h, w, size)
    return dt_map[rows, cols]


def random_crop(img, dt_map, size):
    """Crop a (C, H, W) image and its density map with the SAME window.

    The window is chosen on the image grid and applied unchanged to
    ``dt_map``, so the density map must have the image's resolution
    (i.e. it must not be down-sampled).
    """
    _, h, w = img.shape
    rows, cols = _random_window(h, w, size)
    return img[:, rows, cols], dt_map[rows, cols]


def c_f(batch):
    """Collate function that makes variable-size samples stackable.

    ``batch`` is a list of ``(img, density_map)`` tuples as returned by the
    dataset's ``__getitem__``. Every sample is randomly cropped to the
    smallest height and smallest width found among the batch's images, then
    images and density maps are stacked.

    Returns:
        ``[imgs, dens]`` -- two stacked tensors.

    Raises:
        TypeError: If any image or density map is not a ``torch.Tensor``.
    """
    transposed = list(zip(*batch))
    imgs, dens = transposed[0], transposed[1]

    error_msg = "batch must contain tensors; found {}"
    for sample in zip(imgs, dens):
        for item in sample:
            if not isinstance(item, torch.Tensor):
                # Report the offending element; type(batch[0]) would always
                # be the sample tuple and say nothing useful.
                raise TypeError(error_msg.format(type(item)))

    min_h, min_w = get_min_size(imgs)
    cropped_imgs = []
    cropped_dens = []
    for img, den in zip(imgs, dens):
        _img, _dtmap = random_crop(img, den, (min_h, min_w))
        cropped_imgs.append(_img)
        cropped_dens.append(_dtmap)
    return [torch.stack(cropped_imgs), torch.stack(cropped_dens)]


train_datasets = myDatasets(train_imgs_path, train_labels_path, down_sample=False, transform=transform)
train_data_loader = DataLoader(train_datasets, batch_size=args.batch_size, collate_fn=c_f)
test_datasets = myDatasets(test_imgs_path, test_labels_path, down_sample=True, transform=transform)
# Test images keep their original, differing sizes and their density maps are
# down-sampled, so c_f (same crop window for both) does not apply and the
# default collate cannot stack them: load the test set one image at a time.
test_data_loader = DataLoader(test_datasets, batch_size=1)


def color_map(img, color='gray'):
    """Turn a 2-D float map into an (H, W, 3) RGB image for display.

    Args:
        img: 2-D array, e.g. a density map.
        color: ``'jet'`` for a heat map (low = blue, high = red); any other
            value converts the input to gray RGB as-is.
    """
    if color != 'jet':
        return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    max_pixel = np.max(img)
    min_pixel = np.min(img)
    delta = max_pixel - min_pixel
    if delta > 0:
        scaled = (img - min_pixel) / delta * 255
    else:
        # A flat map (e.g. a crop containing no heads) has no range to
        # stretch; avoid dividing by zero.
        scaled = np.zeros_like(img)
    scaled = scaled.astype(np.uint8)

    # applyColorMap returns BGR while matplotlib expects RGB. Convert the
    # channel order instead of inverting the values to compensate.
    heat_bgr = cv2.applyColorMap(scaled, cv2.COLORMAP_JET)
    return cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)


for i, (imgs, targets) in enumerate(train_data_loader):
    # imgs: (B, 1, H, W); targets: (B, H, W). Iterate over the samples that
    # are actually present -- the last batch may hold fewer than batch_size.
    for img_t, target_t in zip(imgs, targets):
        img = img_t[0].numpy()
        dtmap = target_t.numpy()
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = img.astype(np.uint8)
        dtmap = color_map(dtmap, 'jet')
        visual_img = cv2.addWeighted(img, 0.5, dtmap, 0.5, 0)
        plt.imshow(visual_img)
        plt.show()
    if i > 1:
        break
2.3 配置
config.py
import argparse


def get_args(argv=None):
    """Parse the MCNN command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` as before. Pass ``[]`` to get the
            defaults regardless of the process arguments, e.g. inside a
            Jupyter notebook, where the kernel's own arguments would
            otherwise be rejected.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='MCNN')
    parser.add_argument('--dataset', type=str, default='ShanghaiTechA')
    parser.add_argument('--dataset_path', type=str,
                        default=r"C:\Users\ocean\Downloads\datasets\ShanghaiTech\part_A\\")
    parser.add_argument('--save_path', type=str, default='./save_file/')
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--epochs', type=int, default=600)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--optimizer', type=str, default='Adam')
    return parser.parse_args(argv)
3. 总结
其中比较值得一提的是collate_fn函数c_f(),它的代码如下所示:
def c_f(batch):
    """Collate function that makes variable-size samples stackable.

    ``batch`` is a list of ``(img, density_map)`` tuples as returned by the
    dataset's ``__getitem__``. Every sample is randomly cropped to the
    smallest height and smallest width found among the batch's images, then
    images and density maps are stacked.

    Returns:
        ``[imgs, dens]`` -- two stacked tensors (returning them without the
        surrounding list would presumably work as well).

    Raises:
        TypeError: If any image or density map is not a ``torch.Tensor``.
    """
    transposed = list(zip(*batch))
    imgs, dens = transposed[0], transposed[1]

    error_msg = "batch must contain tensors; found {}"
    for sample in zip(imgs, dens):
        for item in sample:
            if not isinstance(item, torch.Tensor):
                # Report the offending element; type(batch[0]) would always
                # be the sample tuple and say nothing useful.
                raise TypeError(error_msg.format(type(item)))

    min_h, min_w = get_min_size(imgs)
    cropped_imgs = []
    cropped_dens = []
    for img, den in zip(imgs, dens):
        _img, _dtmap = random_crop(img, den, (min_h, min_w))
        cropped_imgs.append(_img)
        cropped_dens.append(_dtmap)
    return [torch.stack(cropped_imgs), torch.stack(cropped_dens)]
这里传入的参数batch是一个list,其长度是batch_size。它的每一个元素代表了一个数据集单元,即自定义数据集类中__getitem__方法return的值。由于我们的__getitem__方法return了img和density_map两个数据,所以batch的每一个数据单元其实是一个元组(img, density_map)。
list(zip(*batch))所做的事情是把batch中的imgs和density_maps分别拿出来各自成为一个列表,方便下一步的处理。
在处理最后还要将列表中的元素堆叠成tensor返回
标签:args,集时,img,int,self,dataloader,Pytorch,import,path From: https://www.cnblogs.com/x-ocean/p/16878864.html