在这里先把代码放上来
import torch import time import numpy as np import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display() #利用svg显示图片 import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" #定义一个计时器 class Timer: #记录多次运行时间 def __init__(self): self.times = [] self.start() def start(self): #启动计时器 self.tik = time.time() def stop(self): #停止计时器并将时间记录在列表中 self.times.append(time.time() - self.tik) return self.times[-1] #-1代表列表中的最后一个元素的索引 def avg(self): #返回平均时间 return sum(self.times) / len(self.times) def sum(self): #返回时间总和 return sum(self.times) def cumsum(self): #返回累计时间 return np.array(self.times).cumsum().tolist() #通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式 #并除以255使得所有像素的数值均在0到1之间 trans = transforms.ToTensor() mnist_train = torchvision.datasets.FashionMNIST( root="../data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST( root="../data", train=False, transform=trans, download=True) print(len(mnist_train)) #训练集60000张图像 print(len(mnist_test)) #测试集10000张图像 #每个图像的高度和宽度都为28像素.数据集由灰度图像组成,其通道数为1 print(mnist_train[0][0].shape) def get_fashion_mnist_labels(labels): #返回Fashion-MNIST数据集的文本标签 text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels] #现在我们创建一个函数来可视化这些样本 def show_images(imgs, num_rows, num_cols, titles = None, scale = 1.5): #绘制图像列表 figsize = (num_cols * scale, num_rows * scale) _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) axes = axes.flatten() for i, (ax, img) in enumerate(zip(axes, imgs)): if torch.is_tensor(img): #图像张量 ax.imshow(img.numpy()) else: # PIL图像 ax.imshow(img) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) if titles: ax.set_title(titles[i]) return axes x, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) show_images(x.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y)); batch_size = 256 def get_dataloader_workers(): #使用4个进程来读取数据 return 4 #if __name__=="__main__": train_iter = data.DataLoader(mnist_train, batch_size, shuffle = True, num_workers = get_dataloader_workers()) #我们看一下读取训练数据所需的时间 timer = d2l.Timer() for x, y in train_iter: continue print(f'{timer.stop():.2f} sec')
标签:灰灰,return,self,train,报错,mnist,import,多线程,def From: https://www.cnblogs.com/fighting-huihui/p/17476337.html