首页 > 其他分享 >collate_fn的应用教程

collate_fn的应用教程

时间:2023-04-11 16:55:33浏览次数:45  
标签:__ 教程 batch dataset collate data fn

作用

collate_fn:即用于collate的function,用于整理数据的函数。
说到整理数据,你当然要会用数据,即会用数据制作工具torch.utils.data.Dataset,虽然我们今天谈的是torch.utils.data.DataLoader。
collate_fn笼统的说就是用于整理数据,通常我们不需要使用,其应用的情形是:各个数据长度不一样的情况,比如第一张图片大小是2828,第二张是5050,这样的话就如果不自己写collate_fn,而使用默认的,就会报错。

基础

dataset

我们必须先看看torch.utils.data.Dataset如何使用,以一个例子为例:

import torch.utils.data as Data
class mydataset(Data.Dataset):
    def __init__(self,train_inputs,train_targets):#必须有
        super(mydataset,self).__init__()
        self.inputs=train_inputs
        self.targets=train_targets
        
    def __getitem__(self, index):#必须重写
        return self.inputs[index],self.targets[index]
        
    def __len__(self):#必须重写
        return len(self.targets)
#构造训练数据
datax=torch.randn(4,3)#构造4个输入
datay=torch.empty(4).random_(2)#构造4个标签
#制作dataset
dataset=mydataset(datax,datay)

下面,可以对dataset进行一系列操作,这些操作返回的结果和你之前那个class的三个函数定义都息息相关。我想说,那三个函数非常自由,你想怎么定义就怎么定义,上述只是一种常见的而已,你可以定制一个特色的。

len(dataset)#调用了你上面定义的def __len__()那个函数
#4

上面的输出结果和你的定义有关,比如你完全可以把def getitem()改成:

def __getitem__(self, index):
    return self.inputs[index]#不输出标签

那么,

dataset[0]#此时当然变化。
#tensor([-1.1426, -1.3239,  1.8372])

dataloader

torch.utils.data.DataLoader

dataloader=Data.DataLoader(dataset,batch_size=2)

一共有4条数据,batch_size=2,所以一共有2个batch。

collate_fn如果你不指定,会调用pytorch内部的,也就是说这个函数是一定会调用的,而且调用这个函数时pytorch会往这个函数里面传入一个参数batch。

def my_collate(batch):
    return xxx

这个batch是什么?这个东西和你定义的dataset, batch_size息息相关。batch是一个列表[x, ... , x],长度就是batch_size,里面每一个元素是dataset的某一个元素,即dataset[i]。

在我们的例子中,由于我们没有对dataloader设置需要打乱数据即shuffle=True,那么第1个batch就是前两个数据,如下:

print(datax)
print(datay)
batch=[dataset[0],dataset[1]]  # 所以才说和你dataset中get_item的定义有关。
print(batch)

image
对,你没有看错,上述代码展示的batch就会传入到pytorch默认的collate_fn中,然后经过默认的处理,输出如下:

it=iter(dataloader)
nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
print(nex)

image
其实,上面就是我们常用的,经典的输出结果,即输入和标签是分开的,第一项是输入tensor,第二项是标签tensor,输入的维度变成了(batch_size,input_size)。

但是我们乍一看,将第一个batch变成上述输出结果很容易呀,我们也会!我们下面就来自己写一个collate_fn实现这个功能。

# a simple custom collate function, just to show the idea
# `batch` is a list of tuple where first element is input tensor and the second element is corresponding label
def my_collate(batch):
    inputs=[data[0].tolist() for data in batch]
    target = torch.tensor([data[1] for data in batch])
    return [data, target]
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
print(datax)
print(datay)

image

it=iter(dataloader)
nex=next(it)
print(nex)

image

这不就和默认的collate_fn的输出结果一样了嘛!无非就是默认的还把输入变成了tensor,标签变成了tensor,我上面是列表,我改就是了嘛!如下:

def my_collate(batch):
    inputs=[data[0].tolist() for data in batch]
    inputs=torch.tensor(inputs)
    target =[data[1].tolist() for data in batch]
    target=torch.tensor(target)
    return [inputs, target]
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader)
nex=next(it)
print(nex)

image
给大家的一个经验就是,一般dataset是不会报错的,而是根据dataset制作dataloader的时候容易报错,因为默认collate_fn把dataset的类型限制得比较死。

应用情形

假设我们还是4个输入,但是维度不固定的。

a=[[1,2],[3,4,5],[1],[3,4,9]]
b=[1,0,0,1]
dataset=mydataset(a,b)
dataloader=Data.DataLoader(dataset,batch_size=2)
it=iter(dataloader)
nex=next(it)
nex

使用默认的collate_fn,直接报错,要求相同维度。
image
这个时候,我们可以使用自己的collate_fn,避免报错。

不过话说回来,我个人感受是:
在这里避免报错好像也没有什么用,因为大多数的神经网络都是定长输入的,而且很多的操作也要求相同维度才能相加或相乘,所以:这里不报错,后面还是报错。如果后面解决这个问题的方法是:在不足维度上进行补0操作,那么我们为什么不在建立dataset之前先补好呢?所以,collate_fn这个东西的应用场景还是有限的。

https://www.jb51.net/article/237011.htm
https://mp.weixin.qq.com/s/Uc2LYM6tIOY8KyxB7aQrOw

标签:__,教程,batch,dataset,collate,data,fn
From: https://www.cnblogs.com/edkong/p/16256620.html

相关文章

  • EasyCVR平台基于GB28181协议的语音对讲配置操作教程
    EasyCVR基于云边端协同,具有强大的数据接入、处理及分发能力,平台可支持海量视频的轻量化接入与汇聚管理,可提供视频监控直播、视频轮播、视频录像、云存储、回放与检索、智能告警、服务器集群、语音对讲、云台控制、电子地图、平台级联等功能。其中,语音对讲功能在视频监控场景中具有......
  • SketchUp曲面建模教程
    推荐:将NSDT场景编辑器加入你的3D工具链其他系列工具:NSDT简石数字孪生教程适用品牌型号:华硕N550JK系统版本:Windows10 专业版软件版本:SketchUp2021大伙儿建模的时候总会遇到曲面建模,那么SketchUp如何曲面建模?是不是一定要用插件才能完成曲面呢?下面就来分享两种曲面案建模的......
  • Raspberry Pi GPIO 图解教程 All In One
    RaspberryPiGPIO图解教程AllInOneRaspberryPi&GPIOGPIO图解GPIOhttps://www.raspberrypi.com/documentation/computers/os.html#gpio-and-the-40-pin-header$pinouthttps://pinout.xyzGPIO(GeneralPurposeIO)SPI(SerialPeripheralInterface)I......
  • AI 绘画 API 超详细使用教程 - 附微信小程序接入代码
    写在前面【AI绘画/AI图像生成】已成为现下炙手可热的话题,AI大模型训练的成本高昂,算法研究时间周期较长,对于大多数人来说,自研一套算法模型还是非常困难的,因此AI绘画API就应运而生,直接调用AI绘画API就能轻松将先进的图文AI融入到我们的产品中,使用门槛是非常低的。 本......
  • Android DataStore Proto框架存储接入AndroidStudio教程详解与使用
    一、介绍        通过前面的文字,我们已掌握了DataStore的存储,但是留下一个尾巴,那就是Proto的接入。Proto是什么?Protobuf,类似于json和xml,是一种序列化结构数据机制,可以用于数据通讯等场景,相对于xml而言更小,相对于json而言解析更快,支持多语言官网:LanguageGuide(proto3)|......
  • Keras中文教程
    http://www.likuli.com/doc/keras/15211966310968.html关于深度学习由于Keras是为深度学习设计的工具,这里只列举深度学习中的一些基本概念。请确保对下面的概念有一定理解:有监督学习,无监督学习,分类,聚类,回归神经元模型,多层感知器,BP算法目标函数(损失函数),激活函数,梯度下降法全......
  • Windows 系统上如何安装 Python 环境(详细教程)
    Windows系统上如何安装Python环境(详细教程)目前,Python有两个版本,一个是2.x版,一个是3.x版,这两个版本是不兼容的。由于2.x版官方只维护到2020年,所以以3.x版作为示例,但是2.x版与3.x版安装方法及环境变量配置的方法是一模一样的,所以请放心。下载Python安装包进入Python官网www.......
  • 学习笔记395—Windows10 Docker安装详细教程
    思维导航前言DockerDesktop是什么?DokcerDesktop下载启用Hyper-V以在Windows10上创建虚拟机安装DockerDesktop配置阿里云镜像加速地址WindowsPowerShell查看Docker版本验证Docker桌面版可以正常使用通过启用WSL2安装DockerDocker学习系列文章前言:在上......
  • Linux环境下nginx安装详细教程,一步步装上nginx
    本人安装Nginx环境为:CentOS7.9 下载安装包下载Nginx安装包Linux版:Nginx官网下载:https://nginx.org/en/download.html下载Stableversion(即稳定版) 上传安装包将压缩包放入系统: 解压:tar-zxvfnginx-1.22.1.tar.gz解压成功: 编译安装执行./configure配置命令:这里提示./config......
  • 定义一个基类Base,有两个公有成员函数fn1,fn2,私有派生出Derived类,如何通过Derived类
    定义一个基类Base,有两个公有成员函数fn1,fn2,私有派生出Derived类,如何通过Derived类的对象调用基类的函数fn1。#include<bits/stdc++.h>usingnamespacestd;classBase{public: intfn1(){return0;} intfn2(){return0;}};classDerived:privateBase{publi......