首页 > 其他分享 >torch.utils.data.Dataset 和 torch.utils.data.DataLoader

torch.utils.data.Dataset 和 torch.utils.data.DataLoader

时间:2024-08-02 21:29:39浏览次数:9  
标签:__ data utils torch batch 数据 self

torch.utils.dataPyTorch中用于数据加载和预处理的模块。通常结合使用其中的DatasetDataLoader两个类来加载和处理数据。

Dataset

torch.utils.data.Dataset是一个抽象类,用于表示数据集。

需要用户自己实现两个方法:__len____getitem__

__len__方法返回数据集的大小,__getitem__方法用于根据给定的索引返回一个数据样本。

import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        return self.data_list[index]

DataLoader

torch.utils.data.Dataset用于表示数据集,torch.utils.data.DataLoader用于加载数据,并对数据进行批量处理和随机化。

dataset

要加载的数据集对象,必须是实现了len()和getitem()方法的对象。

batch_size

每个批次的数据量大小,默认为1。batch size的大小会直接影响到模型的训练速度和效果。如果batch size过大,可能会导致内存不足或者训练速度变慢;如果batch size过小,则可能会降低模型的泛化能力。因此,我们需要根据实际情况来选择合适的batch size

num_workers

num_workers参数用于指定使用多少个进程来加载数据。默认值为0,表示使用主进程加载数据。如果设置为正数,则会使用多个子进程来加载数据,从而提高数据加载的速度。

通过设置num_workers参数为正数来启用多个子进程加载数据,并利用PyTorch的自动混合精度训练(Automatic Mixed Precision, AMP)功能来提高数据加载和处理的速度。

pin_memory

pin_memory参数用于指定是否将数据加载到CUDA主机内存中的固定位置(pinned memory),以提高数据传输效率。默认值为False

collate_fn

collate_fn参数用于指定如何将样本组合成一个批次。默认情况下,DataLoader将每个样本作为一个单独的元素传递给模型,但在某些情况下,需要将样本组合成一个批次,以便一次性对整个批次进行处理。 默认为None,表示使用默认的方式进行组合。

在某些特殊情况下,我们可能需要自定义collate_fn函数来按照特定的方式组合多个数据样本。例如,在处理图像数据时,可能需要将多个图像拼接成一个大的图像作为输入;在处理文本数据时,可能需要将多个文本序列拼接成一个长的文本序列作为输入。通过自定义collate_fn函数,我们可以轻松实现这些需求。

def my_collate_fn(batch):
    # 将样本组合成一个批次
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    return [data, target]

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)

DataLoader将每个样本作为一个元素传递给my_collate_fn函数,函数将样本组合成一个批次,并返回一个包含数据和目标的列表。

shuffle

是否对数据进行随机洗牌操作,默认为False。通过启用shuffle功能来打乱数据的顺序,可以有效防止模型过拟合。但是需要注意的是,在每个epoch开始时都需要重新打乱数据的顺序,否则会导致模型训练效果不佳。

Sampler

Sampler是一个用于指定数据集采样方式的类,它控制DataLoader如何从数据集中选取样本。PyTorch提供了多种Sampler类,例如RandomSamplerSequentialSampler,分别用于随机采样和顺序采样。如果指定了Sampler,则shuffle参数将被忽略。

from torch.utils.data.sampler import RandomSampler

my_sampler = RandomSampler(my_dataset)
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=my_sampler)

自定义Sampler

class MySampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        
    def __iter__(self):
        return iter(range(len(self.data_source)))
    
    def __len__(self):
        return len(self.data_source)

MySampler类继承自torch.utils.data.sampler.Sampler类,实现了__iter____len__方法。MySampler类的构造函数接受一个数据集作为参数,__iter__方法返回一个迭代器,用于遍历数据集中的样本索引,__len__方法返回数据集中样本的数量。

drop_last

如果数据集大小不能被batch size整除,设置为True可以删除最后一个不完整的批次,默认为False。

标签:__,data,utils,torch,batch,数据,self
From: https://www.cnblogs.com/conpi/p/18339626

相关文章

  • Datawhale AI夏令营(AI+生命科学)深度学习-Task3直播笔记
    机器学习lgm上分思路    1、引入新特征(1)对于Task2特征的再刻画        GC含量是siRNA效率中的一个重要且基本的参数,可以作为模型预测的特征。这是因为低GC含量会导致非特异性和较弱的结合,而高GC含量可能会阻碍siRNA双链在解旋酶和RISC复合体作用下的解旋。......
  • pytorch深度学习实践(刘二大人)课堂代码&作业——线性回归
    一、课堂代码1.torch.nn.linear构造linear对象,对象里包含了w和b,即直接利用linear实现wx+b(linear也继承自module,可以自动实现反向传播)2.torch.nn.MSELoss损失函数MSE包含2个参数:size_average(求均值,一般只考虑这个参数)、reduce(求和降维)3.torch.optim.SGDSGD优化器,设置......
  • Pytorch笔记|小土堆|P10-13|transforms
    transforms对图像进行改造最靠谱的办法:根据help文件自行学习transforms包含哪些工具(类)以及如何使用————————————————————————————————————自学一个类时,应关注:1、如何使用各种工具(类)的使用思路:创建对象(实例化)——>传入参数,调用函数(如有__......
  • 如何计算 pandas DataFrame 列中的 NaN 值?
    我想找到数据每列中NaN的数量。可以使用isna()方法加上sum()方法来计算PandasDataFrame列中的NaN值数量。以下是一个示例:importpandasaspd#创建一个示例DataFramedf=pd.DataFrame({'A':[1,2,None,4],'B':[5,Non......
  • Pytorch笔记|小土堆|P7-8|Tensorboard数据可视化
    Tensorboard数据可视化TensorBoard是一个可视化工具,它可以用来展示网络图、张量的指标变化、张量的分布情况等。它通过运行一个本地服务器,来监听6006端口(可更改)。在浏览器发出请求时,分析训练时记录的数据,绘制训练过程中的图像当前环境下安装:pipinstalltensorboardSummaryWrit......
  • 每天五分钟玩转深度学习框架PyTorch:选择函数where和gather
    本文重点如图表所示,这几个方法可以理解为索引函数,有些函数在切片和索引一章进行了简单的介绍,本文将再次进行介绍,温故知新。index_select通过特殊的索引来获取数据index_select,这个这样来理解,第一个参数表示a的第几维度,第二个参数表示获取该维度的哪部分。我们把16,3,28,28看......
  • 【数据科学】Pandas数据库中的Series&DataFrame
    前言前文再续,书接上一回,前两回讲到了Pandas的Series和DataFrame,今天我们使用jupyternotebook来进一步聊聊series和dataframe之间的关系。之前的文章中,我们了解到series和dataframe之间可以相互转换,看完这篇文章,相信你对它们之间的关系会有进一步的了解。正文importdata首......
  • 无敌DataGrip
    jetbrains一款重磅产品——DataGrip这款产品是付费的,不过网上有许多破解补丁和教程,具体自己搜。接下来说说这款产品:     众所周知(从名字中)这是一款数据库工具软件,它跟什么DBeaver之类的不太一样    DBeaver对于大数据库和复杂查询速度较慢,界面对初学者......
  • 排序工具类 - SortUtils
    packagecom.kurumi.util;importorg.springframework.stereotype.Component;importjava.util.Collections;importjava.util.Comparator;importjava.util.List;importjava.util.Map;publicclassSortUtils{/***将list安装sortMap中的传参排......
  • OAF export data from VO in xlsx format
    InthisarticlewearegoingtoseehowtoexportviewobjectinMicrosoftofficeexcelxlsxformatToexportwithxlsxformatfewbasicthingsneededareJarfiles(Listofjari'veusedisshowninbelowscreenshot)ForbetterunderstandingI’lli......