首页 > 其他分享 >Pytorch torch.utils.data.DataLoader 用法详细介绍

Pytorch torch.utils.data.DataLoader 用法详细介绍

时间:2024-04-03 17:58:33浏览次数:23  
标签:torch 批次 数据 utils DataLoader num optional workers 加载

文章目录


1. 介绍

torch.utils.data.DataLoader 是 PyTorch 提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作。Pytorch DataLoader 的详细官方介绍看这里

2. 参数详解

  • dataset (Dataset) – 加载的数据集

  • batch_size (int, optional) – 每一次处理加载多少数据

  • shuffle (bool, optional) – True 表示每次 epoch 都要重新打乱数据,默认 False

  • sampler (Sampler or Iterable, optional) – 定义采样的策略。如果定义了此参数,那么 shuffle 参数必须为 False

  • batch_sampler (Sampler or Iterable, optional) – 同 sample 一样,但每次返回数据的索引。与 batch_sizeshufflesampledrop_last 参数互斥

  • num_workers (int, optional) – 指定用于数据加载的子进程数,可以加快数据加载速度。默认0,表示用主进程加载

  • collate_fn (Callable, optional) – 批处理函数,用于将多个样本合并成一个批次,例如将多个张量拼接在一起,构建 mini-batch。当使用 map-style 数据集进行批量加载时使用。

  • pin_memory (bool, optional) – True 表示在返回张量之前将张量复制到 CUDA 固定的内存中,加快 GPU 传输速度

  • drop_last (bool, optional) – True 表示可删除最后一个不完整的批次。默认 False,如果数据集的大小不能被批次大小整除,则最后一个批次会更小。

  • timeout (numeric, optional) – 非负数,worker 收集批次数据的超时时间,默认0

  • worker_init_fn (Callable, optional) – 如果非None,则在种子设定之后和数据加载之前,将以worker id([0,num_workers-1]中的int)作为输入对每个 worker 子进程调用此函数。(默认值:None)

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – 如果为None,则将使用操作系统的默认多处理上下文。(默认值:None)

  • generator (torch.Generator, optional) – 如果非None,则RandomSampler 将使用此RNG来生成随机索引,并进行多进程处理以为 workers 生成 base_seed。(默认值:None)

  • prefetch_factor (int, optional, keyword-only arg) – 每个 worker 预先装载的批次数。2 表示在所有工作线程中总共预取2*num_workers批次。(默认值取决于为num_workers设置的值。如果num_workers=0的值,则默认为None。否则,如果num_workers>0的值,默认为2)

  • persistent_workers (bool, optional) – True 表示不会在数据集使用一次后关闭工作进程。这允许保持 worker 实例处于活动状态。(默认值:False)

  • pin_memory_device (str, optional) – 如果 pin_memory 为 True,该参数表示 pin_memory 所指向的设备

3. 用法

使用 DataLoader 进行迭代

import torch
from torch.utils.data import Dataset, DataLoader
# 假设有自定义数据集类 MyDataset
class MyDataset(Dataset):
    # 实现 __init__, __len__, 和 __getitem__ 方法...

# 实例化数据集
dataset = MyDataset(data_source)

# 创建 DataLoader
dataloader = DataLoader(dataset,
                       batch_size=64,  # 设置批次大小
                       shuffle=True,   # 是否随机打乱数据
                       num_workers=4,  # 启用4个工作进程加载数据
                       drop_last=True  # 丢弃最后一个不足批次大小的数据
                      )

# 迭代数据加载器进行训练
for epoch in range(num_epochs):
    for inputs, labels in dataloader:
        # 训练模型...
        outputs = model(inputs)
        loss = compute_loss(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

在迭代过程中,loader 会自动从数据集中加载数据,并将其组织成批次。每次迭代返回一个批次的数据,其中 batch_data 是一个包含输入数据和标签的元组或列表。

4. 参考

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

标签:torch,批次,数据,utils,DataLoader,num,optional,workers,加载
From: https://blog.csdn.net/qq_36803941/article/details/137352449

相关文章

  • 大模型中常用的注意力机制GQA详解以及Pytorch代码实现
    分组查询注意力(GroupedQueryAttention)是一种在大型语言模型中的多查询注意力(MQA)和多头注意力(MHA)之间进行插值的方法,它的目标是在保持MQA速度的同时实现MHA的质量。这篇文章中,我们将解释GQA的思想以及如何将其转化为代码。GQA是在论文GQA:TrainingGeneraliz......
  • PyTorch学习(5):并行训练模型权重的本地化与加载
    1.并行训练与非并行训练        在训练深度神经网络时,我们一般会采用CPU或GPU来完成。得益于开源传统,许多算法都提供了完整的开源代码工程,便于学习和使用。随着GPU的普及,GPGPU已经占据了大部分的训练场景。        我们在这里仅以GPU训练场景做一些说明。......
  • pytorch | torchvision.transforms.CenterCrop
    torchvision.transforms.CenterCrop==>从图像中心裁剪图片transforms.CenterCroptorchvision.transforms.CenterCrop(size)功能:从图像中心裁剪图片size:所需裁剪的图片尺寸transforms.CenterCrop(196)的效果如下:(也可以写成transforms.CenterCrop((196,196)))如果裁剪......
  • 3. dataset、dataloader
    dataset数据集dataloader数据加载器1.AI训练时的需求有一个数据集文件来,里面有100w的样本和标签训练时,通常希望,一次在100w中随机抓取batch个样本,拿去训练如果全部抓取完毕,则重新打乱后,再来一次2.dataset,数据集作用:储存数据集的信息self.xxx获取数据集长度__len_......
  • pytorch在Mac上实现像cuda一样的加速
    1.参考:https://developer.apple.com/metal/pytorch/2.具体实现:2.1RequirementsMacM芯片或者AMD的GPUmacOS12.3orlaterPython3.7orlaterXcodecommand-linetools: xcode-select--install2.2准备anac......
  • Pytorch - Dataloader
    BasicallytheDataLoaderworkswiththeDatasetobject.SotousetheDataLoaderyouneedtogetyourdataintothisDatasetwrapper.Todothisyouonlyneedtoimplementtwomagicmethods:__getitem__and__len__.The__getitem__takesanindexandretu......
  • 【PyTorch 实战2:UNet 分类模型】10min揭秘 UNet 分割网络如何工作以及pytorch代码实现
    UNet网络详解及PyTorch实现一、UNet网络原理  U-Net,自2015年诞生以来,便以其卓越的性能在生物医学图像分割领域崭露头角。作为FCN的一种变体,U-Net凭借其Encoder-Decoder的精巧结构,不仅在医学图像分析中大放异彩,更在卫星图像分割、工业瑕疵检测等多个领域展现出强大的应用......
  • 使用镜像安装cuda12.1版本pytorch
    1.添加通道condaconfig--addchannelshttps://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/condaconfig--addchannelscondaconfig--addchannelshttps://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/condaconfig--addchannelshttps://mirrors.bfs......
  • 故障诊断模型 | 基于LSTM长短期记忆神经网络的滚动轴承故障诊断(Pytorch)
    概述LSTM(LongShort-TermMemory)是一种常用的循环神经网络(RNN),在时间序列数据处理任务中表现优秀,可用于滚动轴承故障诊断。滚动轴承故障通常会导致振动信号的变化,这些振动信号可以被视为时间序列数据。LSTM能够捕捉时间序列之间的依赖关系,从而对滚动轴承的故障进行诊断。......
  • CUDA与Pytorch安装
    cuda和pytorch是使用python进行深度学习常会需要的工具,其中pytorch是深度学习的框架之一,cuda是利用GPU进行运算的工具。cuda的安装cuda是英伟达公司开发的利用显卡进行深度学习的工具。显卡的GPU比CPU的运算能力要强,在深度学习时算力十分重要,直接决定了我们训练模型的速度,所以......