首页 > 其他分享 >Pytorch笔记|小土堆|P14-15|torchvision数据集使用、Dataloader使用

Pytorch笔记|小土堆|P14-15|torchvision数据集使用、Dataloader使用

时间:2024-08-03 11:41:15浏览次数:4  
标签:15 torchvision CIFAR10 data Dataloader transforms test True

学会看内置数据集的官方文档:https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10

示例代码

import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

#ToTensor
tensor_trans = transforms.ToTensor()

train_set = torchvision.datasets.CIFAR10(root=r'D:\ai-learning\pytorch\cifar10', train=True, transform=tensor_trans, download=True)
test_set = torchvision.datasets.CIFAR10(root=r'D:\ai-learning\pytorch\cifar10', train=False, transform=tensor_trans, download=True)

#部分可视化
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()

以CIFAR10数据集为例,其常用的参数
root=下载路径(如果没下载过,且设置download=True,则会下载至此路径;如果此路径中有已下载的数据,会校验)
train=True(提取此数据集中的训练集)train=False(提取此数据集中的测试集)
transforms=使用什么transforms方法(如果一个方法,可以直接transforms=torchvision.transforms.ToTensor;如果多个方法,先Compose)
download=True下载数据集至root路径,如果已有,则不再下载。建议常年True
*如果下载慢,可在help文档里查看此数据集的URL,复制至迅雷中下载,下载后把压缩文件复制至root目录中。运行代码时,会自动检测到下载好的数据集并校验、解压

看官方文档时,关注
1、数据集的内容,比如CIFAR10:10类,每类6k。训练集50k,测试集10k。大小32323
2、数据集的数据类型,比如CIFAR10就是PIL——要transform为Tensor类型
3、有哪些参数
4、看getitem返回什么内容,比如CIFAR10返回img, target,DataLoader后即为imgs, targets
—————————————————————————————————————
DataLoader
导入from torch.utils.data import DataLoader

常用参数
dataset=load什么数据集
batch_size=每次抽几张出来(默认是随机抽出)
shuffle=True(每个epoch是否洗牌)
num_works=0不开并行
drop_last=如果总数量除以batch_size除不尽,余数是否扔掉

DataLoader之后
load的数据按照batch_size打包,用for循环提取每个batch的数据:for data in test_loader
要去文档里看一下DataLoader的getitem返回哪些内容,比如CIFAR10返回的就是imgs, targets,所以imgs, targets = data
之后会把imgs送入神经网络

示例代码

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10(r'./cifar10', train=False, transform=torchvision.transforms.ToTensor(), download=True)

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# img, target = test_data[0]

writer = SummaryWriter("logs")

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch: {}".format(epoch), imgs, step) # 可视化
        step = step + 1
writer.close()

标签:15,torchvision,CIFAR10,data,Dataloader,transforms,test,True
From: https://www.cnblogs.com/xjl-ultrasound/p/18339035

相关文章

  • 谷粒商城实战笔记-115-全文检索-ElasticSearch-进阶-bool复合查询
    文章目录1,must2,mustnot3,should1,must{"query":{"bool":{"must":[{"match":{"gender":"M"}},{"matc......
  • springboot演唱会门票管理系统-计算机毕业设计源码15070
    基于微信小程序的演唱会门票管理系统摘 要本文介绍的是基于Spring Boot开发的演唱会门票管理系统。该系统旨在为用户提供一个便捷、高效的平台,以实现演唱会门票的购买和退票功能。随着社交媒体和互联网的普及,传统的演唱会门票购买方式存在过程繁琐、数据统计困难等问题。......
  • 159.336 application for Android
    159.336Assignment1Due14thAugust2024ForthisassignmentyouneedtowriteasimpledialerapplicationforAndroidtomakephonecalls.ThedialermusthavethefollowingUIelements:Anumberdisplaytoshowthephonenumberwhichwillbecalled.A......
  • LeetCode 热题 HOT 100 (015/100)【宇宙最简单版】
    【栈】No.0155最小栈【中等】......
  • torch.utils.data.Dataset 和 torch.utils.data.DataLoader
    torch.utils.data是PyTorch中用于数据加载和预处理的模块。通常结合使用其中的Dataset和DataLoader两个类来加载和处理数据。Datasettorch.utils.data.Dataset是一个抽象类,用于表示数据集。需要用户自己实现两个方法:__len__和__getitem__。__len__方法返回数据集的大小,__getit......
  • vs2015卸载和安装
    vs2015卸载和安装0.摘要可能对大家有帮助的地方: a.vs2015卸载和安装的流程; b.安装时的error:“teamexplorerformicrosoftvisualstudio2015update3ctp1error”解决方式; c.vs2015社区版的下载地址;如果这三点不能解决你遇到的问题,就没必要往下看了。1.卸载......
  • PCIe学习笔记(15)
    设备就绪状态(DeviceReadinessStatus,DRS)消息(DeviceReadinessStatus(DRS)是PCIe规范中引入的一种机制,旨在改进设备初始化和就绪状态的检测与报告。在以往的PCIe版本中,系统通常依赖于固定的超时机制来判断设备是否已经成功初始化并准备好进行数据传输。然而,这种方法存......
  • springboot+vue前后端分离项目-项目搭建15-集成JWT token权限验证
    1.对之前的代码改造,之前将user存储到sessionStorage,改成存储到localStorage,全局搜索修改 之前Result.code等于0代表success,改成200代表success,vue文件全局搜索修改一、前端部分1.改造request.js,登录时将user已经存储到localStorage里,这里将user获取到,将user里的token放到......
  • LeetCode 152 乘积最大子数组
    题目描述给你一个整数数组nums,请你找出数组中乘积最大的非空连续子数组(该子数组中至少包含一个数字),并返回该子数组所对应的乘积。测试用例的答案是一个32位整数。思路这一题用普通的连续子数组思路求解时有一个问题:子问题的最优解不一定是总体的最优局部解。也就是不满足最优......
  • 题解:CF1537E2 Erase and Extend (Hard Version)
    CF1537E2EraseandExtend题解分析通过观察题目,可以证明结果一定是由多次前缀复制得来的。题目要求你进行删和复制的操作,与其交替着操作,不如直接先删到最优的前缀再进行复制。现在就是要找最优的前缀。从头一位一位往后遍历。用\(l\)来存储目前最优前缀的长度,第\(i\)位......