首页 > 其他分享 >Pytorch笔记:dataloader的collate_fn参数在加载数据集时的作用

Pytorch笔记:dataloader的collate_fn参数在加载数据集时的作用

时间:2022-11-11 10:34:56浏览次数:46  
标签:args 集时 img int self dataloader Pytorch import path

1. 前言

最近在复现MCNN时发现一个问题,ShanghaiTech数据集图片的尺寸不一,转换为tensor后的shape形状不一致,无法直接进行多batch_size的数据加载。经过查找资料,有人提到可以定义dataloader的collate_fn函数,在加载时将数据裁剪为最小的图片尺寸,以便于堆叠成多个batch_size。

2. 代码

2.1 数据集的定义

dataset.py

import scipy.io as sio
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch
import os
import cv2
from PIL import Image
import torchvision

class myDatasets(Dataset):
    def __init__(self,img_path, ann_path, down_sample=False,transform=None):
        self.pre_img_path = img_path
        self.pre_ann_path = ann_path
        # 图像的文件名是 IMG_15.jpg 则 标签是 GT_IMG_15.mat
        # 因此不需要listdir标签路径
        self.img_names = os.listdir(img_path)
        self.transform=transform
        self.down_sample = down_sample

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        img_name = self.img_names[index]
        mat_name = 'GT_' + img_name.replace('jpg','mat')

        img = Image.open(os.path.join(self.pre_img_path,img_name)).convert('L')
        img = np.array(img).astype(np.float32)
        
        # print(F"{h=},{w=}")
        if self.transform != None:
            img=self.transform(img)
        # img.permute(0,2,1) # totensor会自动进行维度的转换,所以这里是不必要的

        h,w = img.shape[1],img.shape[2]

        anno = sio.loadmat(self.pre_ann_path + mat_name)
        xy = anno['image_info'][0][0][0][0][0]  # N,2的坐标数组
        density_map = self.get_density((h,w), xy).astype(np.float32) # 密度图
        density_map = torch.from_numpy(density_map)

        return img,density_map


    def get_density(self,img_shape, points):
        if self.down_sample:
            h, w  = img_shape[0]//4, img_shape[1]//4
        else:
            h, w  = img_shape[0], img_shape[1]
        # 进行下采样
        # 密度图 初始化全0
        labels = np.zeros(shape=(h,w))
        for loc in points:
            f_sz = 15  # 滤波器尺寸 预设为15 也是邻域的尺寸
            sigma = 4.0  # sigma参数
            H = self.fspecial(f_sz, f_sz , sigma)  # 高斯核矩阵
            if self.down_sample:
                x = min(max(0,abs(int(loc[0]/4))),int(w))  # 头部坐标
                y = min(max(0,abs(int(loc[1]/4))),int(h))
            else:
                x = min(max(0,abs(int(loc[0]))),int(w))  # 头部坐标
                y = min(max(0,abs(int(loc[1]))),int(h))
            if x > w or y > h:
                continue
            x1 = x - f_sz/2 ; y1 = y - f_sz/2
            x2 = x + f_sz/2 ; y2 = y + f_sz/2
            dfx1 = 0; dfy1 = 0; dfx2 = 0; dfy2 = 0

            change_H = False
            if x1 < 0:
                dfx1 = abs(x1);x1 = 0 ;change_H = True
            if y1 < 0:
                dfy1 = abs(y1); y1 = 0 ; change_H = True
            if x2 > w:
                dfx2 = x2-w ; x2 =w-1 ; change_H =True
            if y2 > h:
                dfy2 = y2 -h ; y2 = h-1 ; change_H =True
            x1h =  1 + dfx1 ; y1h =  1 + dfy1
            x2h = f_sz - dfx2 ; y2h = f_sz - dfy2
            if change_H :
                H = self.fspecial(int(y2h-y1h+1), int(x2h-x1h+1),sigma)
            labels[int(y1):int(y2), int(x1):int(x2)] = labels[int(y1):int(y2), int(x1):int(x2)] + H
        return labels

    def fspecial(self,ksize_x=5, ksize_y = 5, sigma=4):
        kx = cv2.getGaussianKernel(ksize_x, sigma)
        ky = cv2.getGaussianKernel(ksize_y, sigma)
        return np.multiply(kx,np.transpose(ky))
View Code

2.2 使用

demo.py

from config import get_args
from model import MCNN
from dataset import myDatasets
import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
import time
from utils import get_mse_mae,show
import os
import numpy as np
import matplotlib.pyplot as plt
from debug_utils import ModelVerbose
import random
import cv2

args = get_args()

if args.dataset == 'ShanghaiTechA':
    if os.name == 'nt':
        # for windows
        train_imgs_path = args.dataset_path + r'\train_data\images\\'
        train_labels_path = args.dataset_path+r'\train_data\ground-truth\\'
        test_imgs_path = args.dataset_path+r'\test_data\images\\'
        test_labels_path = args.dataset_path+r'\test_data\ground-truth\\'
    else:
        # for linux
        train_imgs_path = os.path.join(args.dataset_path,'train_data/images/')
        train_labels_path = os.path.join(args.dataset_path,'train_data/ground-truth/')
        test_imgs_path = os.path.join(args.dataset_path,'test_data/images/')
        test_labels_path = os.path.join(args.dataset_path,'test_data/ground-truth/')
    # print(F"{train_imgs_path=}\n{train_labels_path=}\n{test_imgs_path=}\n{test_labels_path=}")
else:
    raise Exception(F'Dataset {args.dataset} Not Implement')

# 数据集
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
    # torchvision.transforms.Resize((768,1024)),
    # torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def get_min_size(batch):
    min_ht, min_wd = (float('inf'),float('inf'))
    for img in batch:
        c,h,w = img.shape
        if h<min_ht:
            min_ht = h
        if w<min_wd:
            min_wd = w
    return min_ht,min_wd

def random_crop_img(img,size):
    c,h,w = img.shape
    h_start = random.randint(0,h-size[0])
    h_end = h_start + size[0]
    w_start = random.randint(0,w - size[1])
    w_end = w_start + size[1]

    
    return img[:,h_start:h_end,w_start:w_end]

def random_crop_dtmap(dt_map,size):
    h,w = dt_map.shape
    h_start = random.randint(0,h-size[0])
    h_end = h_start + size[0]
    w_start = random.randint(0,w - size[1])
    w_end = w_start + size[1]
    return dt_map[h_start:h_end,w_start:w_end]

def random_crop(img,dt_map,size):
    c,h,w = img.shape
    h_start = random.randint(0,h-size[0])
    h_end = h_start + size[0]
    w_start = random.randint(0,w - size[1])
    w_end = w_start + size[1]
    return img[:,h_start:h_end,w_start:w_end],dt_map[h_start:h_end,w_start:w_end]



def c_f(batch):
    # 这里接收到的data 是[(img_1_768_1024,target_192,256)]
    # 1. 分别找到img target的最大h w
    # 2. 新建数组(h,w)
    transposed = list(zip(*batch))
    imgs, dens = [transposed[0],transposed[1]]
    error_msg = "batch must contain tensors; found {}"
    if isinstance(imgs[0],torch.Tensor) and isinstance(dens[0],torch.Tensor):
        min_h, min_w = get_min_size(imgs)
        cropped_imgs = []
        cropped_dens = []
        for i in range(len(batch)):
            # _img = random_crop_img(imgs[i],(min_h,min_w))
            # 下采样
            # _dtmap = random_crop_dtmap(dens[i],(min_h//4,min_w//4))
            # _dtmap = random_crop_dtmap(dens[i],(min_h,min_w))
            _img,_dtmap = random_crop(imgs[i],dens[i],(min_h,min_w))
            cropped_imgs.append(_img)
            cropped_dens.append(_dtmap)
        cropped_imgs = torch.stack(cropped_imgs)
        cropped_dens = torch.stack(cropped_dens)
        return [cropped_imgs,cropped_dens]
    raise TypeError((error_msg.format(type(batch[0]))))

train_datasets = myDatasets(train_imgs_path, train_labels_path,down_sample=False,transform=transform)
train_data_loader = DataLoader(train_datasets, batch_size=args.batch_size,collate_fn=c_f)
test_datasets = myDatasets(test_imgs_path, test_labels_path,down_sample=True,transform=transform)
test_data_loader = DataLoader(test_datasets, batch_size=args.batch_size)

def color_map(img,color='gray'):
    # labels是一个二维数组,是密度图
    max_pixel = np.max(img)
    min_pixel = np.min(img)
    delta = max_pixel - min_pixel
    labels_int = ((img-min_pixel)/delta*255)
    # 以下操作是为了反转jet的颜色,不然就会出现数值高的反而是蓝色,数值低的是红色,不像热力图了
    labels_int = labels_int * (-1)
    labels_int = labels_int + 255
    labels_int = labels_int.astype(np.uint8)
    if color == 'jet':
        return cv2.applyColorMap(labels_int,cv2.COLORMAP_JET)
    else:
        img_ = img[::,::]
        img_ = cv2.cvtColor(img_,cv2.COLOR_GRAY2RGB)
        return img_

for i,(imgs,targets) in enumerate(train_data_loader):
    # img.shape:        (1,1,768,1024)
    # targets.shape:    (1,192,256)
    for j in range(args.batch_size):
        img = imgs[j][0].numpy()
        dtmap = targets[j].numpy()
        # dtmap = cv2.resize(dtmap,img.shape[::-1])
        img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)
        img = img.astype(np.uint8)
        dtmap = color_map(dtmap,'jet')
        visual_img = cv2.addWeighted(img,0.5,dtmap,0.5,0)
        plt.imshow(visual_img)
        plt.show()
    if i>1:
        break
    
View Code

2.3 配置

config.py

import argparse


def get_args():
    parser = argparse.ArgumentParser(description='MCNN')

    parser.add_argument('--dataset',type=str,default='ShanghaiTechA')

    parser.add_argument('--dataset_path',type=str,default=r"C:\Users\ocean\Downloads\datasets\ShanghaiTech\part_A\\")

    parser.add_argument('--save_path',type=str,default='./save_file/')

    parser.add_argument('--print_freq',type=int,default=1)

    parser.add_argument('--device',type=str,default='cuda')

    parser.add_argument('--epochs',type=int,default=600)

    parser.add_argument('--batch_size',type=int,default=4)

    parser.add_argument('--lr',type=float,default=1e-5)

    parser.add_argument('--optimizer',type=str,default='Adam')

    args = parser.parse_args()
    # for jupyer notbook
    # args = parser.parse_know_args()[0]
    return args
View Code

 3. 总结

其中比较值得说道时collate_fn函数c_f(),它的代码如下所示

def c_f(batch):
    transposed = list(zip(*batch))
    imgs, dens = [transposed[0],transposed[1]]
    error_msg = "batch must contain tensors; found {}"
    if isinstance(imgs[0],torch.Tensor) and isinstance(dens[0],torch.Tensor):
        min_h, min_w = get_min_size(imgs)
        cropped_imgs = []
        cropped_dens = []
        for i in range(len(batch)):
            _img,_dtmap = random_crop(imgs[i],dens[i],(min_h,min_w))
            cropped_imgs.append(_img)
            cropped_dens.append(_dtmap)
        cropped_imgs = torch.stack(cropped_imgs)
        cropped_dens = torch.stack(cropped_dens)
        return [cropped_imgs,cropped_dens] # 这里不用列表包起来应该也行
    raise TypeError((error_msg.format(type(batch[0]))))

这里传入的参数batch是一个list,其长度是batch_size。它的每一个元素代表了一个数据集单元,即自定义数据集类中__getitem__方法return的值。由于我们的__getitem__方法return了img和density_map两个数据,所以batch的每一个数据单元其实是一个元组(img, density_map)。

list(zip(*batch))所做的事情是把batch中的imgs和density_maps分别拿出来各自成为一个列表,方便下一步的处理。

在处理最后还要将列表中的元素堆叠成tensor返回

标签:args,集时,img,int,self,dataloader,Pytorch,import,path
From: https://www.cnblogs.com/x-ocean/p/16878864.html

相关文章

  • conda 虚拟环境安装pytorch & d2l包
    conda虚拟环境安装pytorch1、首先,conda终端添加清华镜像源,可以加快安装速度。2、确认电脑匹配的CUDA型号,(例如,9.2)3、新建一个虚拟环境,在终端运行condacreate-nXXXp......
  • PyTorch中F.cross_entropy()函数
    对PyTorch中F.cross_entropy()的理解PyTorch提供了求交叉熵的两个常用函数:一个是F.cross_entropy(),另一个是F.nll_entropy(),是对F.cross_entropy(input,target)中参数targ......
  • pytorch张量索引
    一、pytorch返回最值索引1官方文档资料1.1torch.argmax()介绍 返回最大值的索引下标函数:torch.argmax(input,dim,keepdim=False)→LongTensor返回值:Retur......
  • pytorch tensor 张量常用方法介绍
    1. view()函数PyTorch 中的view()函数相当于numpy中的resize()函数,都是用来重构(或者调整)张量维度的,用法稍有不同。>>>importtorch>>>re=torch.tensor([1,......
  • pytorch TensorDataset和DataLoader区别
    TensorDatasetTensorDataset可以用来对tensor进行打包,就好像python中的zip功能。该类通过每一个tensor的第一个维度进行索引。因此,该类中的tensor第一维度必须......
  • pytorch入门
    初衷:看不懂论文开源代码参考:B站小土堆(土堆yyds~)   PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili 1.环境配置参考:(39条消息)win10......
  • 一文读懂 PyTorch 版本 SRGAN训练和测试【CVPR 2017】
    ????声明:作为全网AI领域干货最多的博主之一,❤️不负光阴不负卿❤️????深度学习:#超分重建、一文读懂????超分重建经典网络SRGAN详尽教程????最近更新:2022年2月28......
  • 使用PyTorch实现简单的AlphaZero的算法(1):背景和介绍
    在本文中,我们将在PyTorch中为ChainReaction[2]游戏从头开始实现DeepMind的AlphaZero[1]。为了使AlphaZero的学习过程更有效,我们还将使用一个相对较新的改进,称为“Playout......
  • PyTorch实现非极大值抑制(NMS)
    NMS即nonmaximumsuppression即非极大抑制,顾名思义就是抑制不是极大值的元素,搜索局部的极大值。在最近几年常见的物体检测算法(包括rcnn、sppnet、fast-rcnn、faster-rcnn......
  • Pytorch中模型调用
    注意:RNN、LSTM的batch_first参数,对于不同的网络层,输入的维度虽然不同,但是通常输入的第一个维度都是batch_size,比如torch.nn.Linear的输入(batch_size,in_features),torch.nn......