1 CIFAR10
cifar10 的官方网址:<http://www.cs.toronto.edu/~kriz/cifar.html>
是由32\*32像素的60000张图片组成的数据集,50000张图片用于训练,10000张图片用于测试,其中有10个类别,每个类别有6000张图片,
分类之间彼此独立,不会重叠,因此是一个单标签多分类的问题。
2 读取CIFAR10数据
首先在对图像做预处理,将图像的RGB均值都设置为0.5。
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
CIFAR10(root="./data", train=True, download=True, transform=transform)
- 使用 torchvision 下载CIFAR10的图像,因为之前本地没有CIFAR10的图像内容,第一次使用时需要进行下载,download=True
- root 图像所在的位置,如果之前没有图像,这里也是图像的下载位置
- train 是否用于训练,True 为训练,False 为测试
- transform 图像数组
DataLoader(trainset, batch\_size=4, shuffle=True, num\_workers=2)
对图像进行训练
- trainset 原始图像
- batch\_size 每批次数据量,这里是每个批次4张图像
- shuffle 使用乱序,因为是进行训练,这里是True,如果是用于测试,可以设置为 False
- num\_workers 线程数,加快运行速度,因 pytorch 在 windows 平台有bug,最好设置为0
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)
3 将图像展示
将训练图像获取到后,展示到前端界面上。
对一个批次的图像转换为迭代器,并将其尺寸缩小一半后,显示在界面上。
# 获取训练图片
dataiter = iter(trainloader)
images, labels = dataiter.__next__()
def imgshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 显示图像
imgshow(torchvision.utils.make_grid(images))
4 打印出训练图像标签
将图像标签(或者叫类别)定义在一个元组中,使用迭代器,将上一步中展示出的图像对应的标签打印到控制台。
classes = (
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"
)
print(" ".join('%5s' % classes[labels[l]] for l in range(4)))
5 运行
5.1 首先下载图像到本地
5.2 随机展示一组图像
5.3 打印该组图像所属标签
关闭上一步展示出的图像,程序继续执行,将上一步展示的图片所属标签,打印到控制台中。