首页 > 其他分享 >MNIST 数据集、数据加载

MNIST 数据集、数据加载

时间:2022-12-11 15:46:33浏览次数:34  
标签:set labels transforms images 数据 MNIST 加载

目录

MNIST 数据集

机器学习的入门就是MNIST。

MNIST 数据集来自美国国家标准与技术研究所,是NIST(National Institute of Standards and Technology)的缩小版,训练集(training set)由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局(the Census Bureau)的工作人员,测试集(test set)也是同样比例的手写数字数据。

获取MNIST
MNIST 数据集可在http://yann.lecun.com/exdb/mnist/获取,图片是以字节的形式进行存储,它包含了四个部分:

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

此数据集中,训练样本:共60000个,其中55000个用于训练,另外5000个用于验证。测试样本:共10000个,验证数据比例相同。

from torchvision.datasets import MNIST
mnist_train = MNIST(root='./MNIST_data', train=True, download=True, transform=transforms.PILToTensor())

数据加载

from torch.utils.data import DataLoader
from torchvision.utils import make_grid
dataloader = DataLoader(dataset=mnist_train, batch_size=2, shuffle=True, num_workers=2)
for (images, labels) in dataloader:
    print(labels)
    image = make_grid(images).permute(1, 2, 0).numpy()
    plt.imshow(image)
    plt.show()
    exit()

其中参数含义:

  1. dataset:提前定义的dataset的实例
  2. batch_size:传入数据的batch的大小,常用128,256等等
  3. shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据
  4. num_workers:加载数据的线程数

transforms

由于 DataLoader 这个加载器只能加载 tensors, numpy arrays, numbers, dicts or lists

但是 found <class 'PIL.Image.Image'>,所以就很尴尬,我们需要将图片转换一下

transforms 用于图形变换,在使用时我们还可以使用 transforms.Compose将一系列的transforms操作链接起来。

  • torchvision.transforms.Compose([ ts,ts,ts... ])ts为transforms操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

大多数情况下我们不会只transforms 一下,所以可以用如下方案

from torchvision import transforms
transforms.Compose(
    [  #文档  https://pytorch.org/vision/stable/transforms.html
        transforms.ToPILImage(),  # 转成PIL图片
        # transforms.Resize(size),  # 缩放
        transforms.ToTensor(),  # 变张量
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, )) ]
)

介绍一个概念:

transforms 处理过后,会把通道移到最前边。比如 MNIST h*w*c 为:28281

tensor处理完,通道数会提前,并且做了轴交换,变为了 c*h*w 为:12828

至于为什么要这么设计?听传言是做矩阵加减乘除以及卷积等运算是需要调用cuda和cudnn的函数的,而这些接口都设成成 chw 格式了

标签:set,labels,transforms,images,数据,MNIST,加载
From: https://www.cnblogs.com/kai-/p/16973742.html

相关文章