torchvision中的数据集使用
1.torchvision介绍
torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型,一般包括左侧几个模块。
pytorch官网-Docs-torchvision(左侧修改为0.90版本就可以直接看到datasets)
torchvision.datasets:包含常用的数据集API文档,设置一些参数即可下载和使用这些数据集。
COCO数据集:常用于目标检测、语义分割
MNIST数据集:手写文字数据集(一般为入门数据集)
CIFAR数据集:常用于物体识别
torchvision.io:输入输出模块。
torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等。
torchvision.ops:提供一些少见的特殊的操作。
torchvision.transforms:常用的图片变换,例如类型转换、裁剪等。
torchvision.utils:其他的一些有用的方法。
2.举例说明
本次以CIFAR10为例进行数据集的使用(观察参数设置):
数据集的使用代码
import torchvision
# 将数据集下载到本地的文件夹中用作训练集和测试集
train_set = torchvision.datasets.CIFAR10(root="./dataset2",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset2",train=False,download=True)
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(train_set.classes[target])
img.show()
dataset和transforms的结合使用:
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 添加transforms参数可以对数据集进行转换操作
train_set = torchvision.datasets.CIFAR10(root="./dataset2",train=True,transform=dataset_transforms, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset2",train=False,transform=dataset_transforms, download=True)
# print(test_set[0])
writer = SummaryWriter("logs")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
标签:set,torchvision,print,train,transforms,使用,test,数据
From: https://www.cnblogs.com/yq-ydky/p/17621557.html