torch.utils.data.TensorDataset 这个类可以初始化数据集
例子:
import torch from torch.utils import data # torch.utils.data.dataset 类的使用 x = torch.arange(12, dtype=torch.float32).reshape(6, 2) y = torch.arange(6, dtype=torch.float32).reshape(6, 1) # 初始化数据集,需要两个参数,x是特征,y是标签 torch_dataset = data.TensorDataset(x, y) # 使用data.DataLoader 导入数据集,得到可迭代对象 train_iter = data.DataLoader( dataset = torch_dataset, # 数据集 batch_size = 2, # 批量大小 shuffle=True, # 是否打乱 num_workers=2, # 读取线程 ) # 读取数据 for i in train_iter: for y in i: print(y) print('----------') 标签:TensorDataset,数据,utils,torch,dataset,data From: https://www.cnblogs.com/xinbigworld/p/17029661.html