学会看内置数据集的官方文档:https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10
示例代码:
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
#ToTensor
tensor_trans = transforms.ToTensor()
train_set = torchvision.datasets.CIFAR10(root=r'D:\ai-learning\pytorch\cifar10', train=True, transform=tensor_trans, download=True)
test_set = torchvision.datasets.CIFAR10(root=r'D:\ai-learning\pytorch\cifar10', train=False, transform=tensor_trans, download=True)
#部分可视化
writer = SummaryWriter("logs")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
以CIFAR10数据集为例,其常用的参数:
root=下载路径(如果没下载过,且设置download=True,则会下载至此路径;如果此路径中有已下载的数据,会校验)
train=True(提取此数据集中的训练集)train=False(提取此数据集中的测试集)
transforms=使用什么transforms方法(如果一个方法,可以直接transforms=torchvision.transforms.ToTensor
;如果多个方法,先Compose)
download=True下载数据集至root路径,如果已有,则不再下载。建议常年True
*如果下载慢,可在help文档里查看此数据集的URL,复制至迅雷中下载,下载后把压缩文件复制至root目录中。运行代码时,会自动检测到下载好的数据集并校验、解压
看官方文档时,关注:
1、数据集的内容,比如CIFAR10:10类,每类6k。训练集50k,测试集10k。大小32323
2、数据集的数据类型,比如CIFAR10就是PIL——要transform为Tensor类型
3、有哪些参数
4、看getitem返回什么内容,比如CIFAR10返回img, target,DataLoader后即为imgs, targets
—————————————————————————————————————
DataLoader
导入:from torch.utils.data import DataLoader
常用参数:
dataset=load什么数据集
batch_size=每次抽几张出来(默认是随机抽出)
shuffle=True(每个epoch是否洗牌)
num_works=0不开并行
drop_last=如果总数量除以batch_size除不尽,余数是否扔掉
DataLoader之后:
load的数据按照batch_size打包,用for循环提取每个batch的数据:for data in test_loader
要去文档里看一下DataLoader的getitem返回哪些内容,比如CIFAR10返回的就是imgs, targets,所以imgs, targets = data
之后会把imgs送入神经网络
示例代码:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10(r'./cifar10', train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# img, target = test_data[0]
writer = SummaryWriter("logs")
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
writer.add_images("Epoch: {}".format(epoch), imgs, step) # 可视化
step = step + 1
writer.close()
标签:15,torchvision,CIFAR10,data,Dataloader,transforms,test,True
From: https://www.cnblogs.com/xjl-ultrasound/p/18339035