首页 > 其他分享 >pytorch——DataLoader

pytorch——DataLoader

时间:2024-04-08 20:00:12浏览次数:40  
标签:False torchvision True DataLoader imgs pytorch test data

DataLoader

1.主要参数

  • datasetDataset) – 要从中加载数据的数据集。
  • batch_sizeint 可选) – 每批要加载的样品数:随即抓取 (默认值:)。1
  • shufflebool 可选) – 设置是否重新洗牌数据 在每个纪元(默认值:False)。
  • num_workersint 可选) – 用于数据的子进程数装载。 默认表示数据将在主进程中加载。 (默认值:0)
  • drop_lastbool 可选) – 设置是否删除最后一个未完成的批次, 如果数据集大小不能被批处理大小整除。如果和数据集的大小不能被批处理大小整除,然后是最后一批 会更小。(默认值:False)

2.图解

pFLsxOJ.png

3.基本使用

import torchvision
from torch.utils.data import DataLoader

#准备数据测试集,测试集已存在,不需要下载
test_data=torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())
#设置数据集的dataloader,说明如何操作数据集
test_loader=DataLoader(test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

#测试集中第一张图片及target
img,target=test_data[0]
print(img.shape)
print(target)

#使用dataloader,把随机每batch_size个数据打包
for data in test_loader:
    imgs,targets=data #imgs,targets为tensor类型
    print(imgs.shape)
    print(targets)

4.在tensorboard上显示,drop_last=False

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#准备数据测试集,测试集已存在,不需要下载
test_data=torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())
#设置数据集的dataloader,说明如何操作数据集
test_loader=DataLoader(test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)

#测试集中第一张图片及target
img,target=test_data[0]
print(img.shape)
print(target)

writer =SummaryWriter('dataloader')
step=0#设置在tensorboard中的步长
#使用dataloader,把随机每batch_size个数据打包
for data in test_loader:
    imgs,targets=data #imgs,targets为tensor类型
    writer.add_images('test_data',imgs,step)  #因为imgs中有多张图片,所以用add_images
    step=step+1

writer.close()

5.drop_last参数选择True和False的不同

drop_last=True代码:

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#准备数据测试集,测试集已存在,不需要下载
test_data=torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())
#设置数据集的dataloader,说明如何操作数据集
test_loader=DataLoader(test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=True)

#测试集中第一张图片及target
img,target=test_data[0]
print(img.shape)
print(target)

writer =SummaryWriter('dataloader')
step=0#设置在tensorboard中的步长
#使用dataloader,把随机每batch_size个数据打包
for data in test_loader:
    imgs,targets=data #imgs,targets为tensor类型
    writer.add_images('test_data_drop_last',imgs,step)  #因为imgs中有多张图片,所以用add_images
    step=step+1

writer.close()

对比:上:False,下:True

[pFLyPFx.png

可以看出当设置为False时数据集大小不能被批处理大小整除,即整除后剩余部分,不会被删除;当设置为True时,整除后剩余部分,被删除。

6.shuffle参数选择True和False的不同

shuffle(bool 可选) – 设置是否重新洗牌数据

当设置为True,即设置每次重新洗牌时:

pFLyI1O.png

当设置为False,即设置每次不重新洗牌时:

洗牌时:

[外链图片转存中…(img-3Ck9Gprx-1712577523326)]

当设置为False,即设置每次不重新洗牌时:

pFLy81S.png

标签:False,torchvision,True,DataLoader,imgs,pytorch,test,data
From: https://blog.csdn.net/m0_67855350/article/details/137521875

相关文章

  • 2024.4.8 pytorch框架初上手
    pytorchPyTorch是一个针对深度学习,并且使用GPU和CPU来优化的tensorlibrary(tensor库)中文文档:https://pytorch.org/resources梯度/导数计算#linear.pyimporttorchimportnumpyasnpx=torch.tensor(3,)w=torch.tensor(4.,requires_grad=True)b=t......
  • 从零开始的深度学习项目(PyTorch识别人群行为)
    PyTorch识别人群行为系统环境介绍环境版本Python3.11.5pandas2.0.3numpy1.24.3torch2.1.2+cu121注意:2.1.2+cu121这样的版本号通常用于描述TensorFlow等深度学习框架的版本信息,其中:2.1.2是TensorFlow的主要版本号,表示主要的功能和接口的变化。cu121表示该Tenso......
  • 每天五分钟掌握深度学习框架pytorch:本专栏说明
    专栏大纲专栏计划更新章节在100章左右,之后还会不断更新,都会配备代码实现。以下是专栏大纲部分代码实现代码获取为了方便用户浏览代码,本专栏将代码同步更新到github中,所有用户可以读完专栏内容和代码解析之后,下载对应的代码,跑一跑模型算法,这样会加深自己对算法模型......
  • Pytorch张量的数学运算:向量基础运算
    文章目录一、简单运算二、广播运算1.广播的基本规则2.广播操作的例子三、运算函数参考:与凤行  张量的数学运算是深度学习和科学计算中的基础。张量可以被视为一个多维数组,其在数学和物理学中有广泛的应用。这些运算包括但不限于加法、减法、乘法、除法、内积、......
  • Pytorch实用教程:Pytorch中enumerate(test_loader, start=0)的解释
    文章目录1.Pytorch中的enumerate(test_loader,0)数据加载器`test_loader``enumerate(test_loader,0)`数据解包`inputs,labels=data`总结2.python中enumerate的用法基本用法示例遍历列表使用不同的起始索引在字典上使用为什么使用`enumerate`?1.Pytorch......
  • Pytorch入门实战: 04-猴痘病识别
    ......
  • 0193期通过CNN-pytorch训练识别苹果树叶病害识别-含数据集-含数据集
    代码下载和视频演示地址:0193期通过CNN-pytorch训练识别苹果树叶病害识别-含数据集_哔哩哔哩_bilibili本代码是基于pythonpytorch环境安装的。下载本代码后,有个环境安装的requirement.txt文本数据集介绍,下载本资源后,界面如下:数据集文件夹存放了本次识别的各个类别图片......
  • 最简单知识点PyTorch中的nn.Linear(1, 1)
    一、nn.Linear(1,1)nn.Linear(1,1) 是PyTorch中的一个线性层(全连接层)的定义。nn 是PyTorch的神经网络模块(torch.nn)的常用缩写。nn.Linear(1,1) 的含义如下:第一个参数 1:输入特征的数量。这表示该层接受一个长度为1的向量作为输入。第二个参数 1:输出特征的数量......
  • Yann Lecun-纽约大学-深度学习(PyTorch)
    课程介绍    本课程涉及深度学习和表示学习的最新技术,重点是有监督和无监督的深度学习,嵌入方法,度量学习,卷积和递归网络,并应用于计算机视觉,自然语言理解和语音识别。前提条件包括:DS-GA1001数据科学入门或研究生水平的机器学习课程。     免费获取:YannLecun-纽约......
  • 【保姆级教程附代码】Pytorch (.pth) 到 TensorRT (.plan) 模型转化全流程
    整体流程为:.pth->.onnx->.plan(或.trt,二者等价)需要的工具和包:Docker,Pytorch,ONNX,onnxruntime,TensorRT(trtexec和polygraphy).pth到.onnx这里以SwinIR(https://github.com/JingyunLiang/SwinIR)预训练模型为例init_torch_model()函数主要是对模型初始化,这里是......