PyTorch has two primitives for working with data: torch.utils.data.Dataset and torch.utils.data.DataLoader. A Dataset stores the samples and their corresponding labels, while a DataLoader wraps the Dataset to make loading, iterating, and batching convenient.
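To make that relationship concrete, here is a minimal sketch of a custom map-style Dataset (the class name, tensor shapes, and contents are made up for illustration, not part of this tutorial); such a Dataset only needs __len__ and __getitem__, and DataLoader handles the batching:

import torch
from torch.utils.data import Dataset, DataLoader

class RandomPairs(Dataset):
    """Toy dataset: 100 random feature vectors with integer labels."""

    def __init__(self, n=100):
        self.features = torch.randn(n, 8)        # samples
        self.labels = torch.randint(0, 2, (n,))  # corresponding labels

    def __len__(self):
        return len(self.labels)                  # number of samples

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

loader = DataLoader(RandomPairs(), batch_size=16, shuffle=True)
xb, yb = next(iter(loader))  # one batch: features [16, 8], labels [16]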
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
PyTorch offers domain-specific libraries for different kinds of data, such as TorchText, TorchVision, and TorchAudio. In this tutorial we use TorchVision.
The torchvision.datasets module contains many vision datasets, such as CIFAR and COCO (full list). In this tutorial we use the FashionMNIST dataset. Every TorchVision dataset takes two arguments, transform and target_transform, which modify the samples and the labels respectively.
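The code below only uses transform; as a hedged aside on target_transform, one common pattern is converting the integer class label into a one-hot vector with a Lambda transform (the variable name one_hot_data is just for illustration):

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

# Sketch: target_transform turns the integer label into a 10-dim one-hot vector.
one_hot_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    ),
)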
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Out:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 26421880/26421880 [00:01<00:00, 13914460.04it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 29515/29515 [00:00<00:00, 326673.50it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 4422102/4422102 [00:00<00:00, 6109664.51it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 5148/5148 [00:00<00:00, 61868988.52it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
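Since a Dataset stores samples and labels, an individual example can be inspected by indexing; a small sketch continuing from the download code above (the variable names img and label are just for illustration):

# Each item is an (image, label) pair; ToTensor() yields a float tensor in [0, 1].
img, label = training_data[0]
print(img.shape)  # torch.Size([1, 28, 28]) -- one channel, 28x28 pixels
print(label)      # integer class index between 0 and 9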
We pass the Dataset as an argument to DataLoader. This wraps the dataset in an iterable and adds support for automatic batching, sampling, shuffling, and multiprocess data loading. Here we define a batch size of 64, so each iteration over the dataloader returns a batch of 64 features and 64 corresponding labels.
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
Out:
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
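The shuffling and multiprocess loading mentioned above are controlled by DataLoader arguments; a minimal sketch (the variable name and the num_workers value of 2 are arbitrary choices, not part of the original tutorial):

# Shuffle the training set each epoch and load batches in 2 background worker processes.
shuffled_train_dataloader = DataLoader(
    training_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)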
Read more: Loading Data in PyTorch.