首页 > 其他分享 >《PyTorch深度学习实践》-刘二大人 第九讲

《PyTorch深度学习实践》-刘二大人 第九讲

时间:2022-10-23 15:55:22浏览次数:52  
标签:loss 刘二 self torch batch PyTorch train test 第九

课堂练习,课后作业不想做了……

 1 import torch
 2 from torchvision import transforms
 3 from torchvision import datasets
 4 from torch.utils.data import DataLoader
 5 import torch.nn.functional as F
 6 import torch.optim as optim
 7 
 8 # prepare dataset
 9 
10 batch_size = 64
11 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  # 归一化,均值和方差
12 
13 train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
14 train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
15 test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
16 test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
17 
18 
19 # design model using class
20 class Net(torch.nn.Module):
21     def __init__(self):
22         super(Net, self).__init__()
23         self.l1 = torch.nn.Linear(784, 512)
24         self.l2 = torch.nn.Linear(512, 256)
25         self.l3 = torch.nn.Linear(256, 128)
26         self.l4 = torch.nn.Linear(128, 64)
27         self.l5 = torch.nn.Linear(64, 10)
28 
29     def forward(self, x):
30         x = x.view(-1, 784)  # -1其实就是自动获取mini_batch
31         x = F.relu(self.l1(x))
32         x = F.relu(self.l2(x))
33         x = F.relu(self.l3(x))
34         x = F.relu(self.l4(x))
35         return self.l5(x)  # 最后一层不做激活,不进行非线性变换
36 model = Net()
37 
38 # construct loss and optimizer
39 criterion = torch.nn.CrossEntropyLoss()
40 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
41 
42 
43 # training cycle forward, backward, update
44 def train(epoch):
45     running_loss = 0.0
46     for batch_idx, data in enumerate(train_loader, 0):
47         # 获得一个批次的数据和标签
48         inputs, target = data
49         optimizer.zero_grad()
50         # 获得模型预测结果(64, 10)
51         outputs = model(inputs)
52         # 交叉熵代价函数outputs(64,10),target(64)
53         loss = criterion(outputs, target)
54         loss.backward()
55         optimizer.step()
56 
57         running_loss += loss.item()
58         if batch_idx % 300 == 299:
59             print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
60             running_loss = 0.0
61 
62 #名字不能设为test会被识别为程序入口
63 def hehe_11():
64     correct = 0
65     total = 0
66     with torch.no_grad():
67         for data in test_loader:
68             images, labels = data
69             outputs = model(images)
70             _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度
71             total += labels.size(0)
72             correct += (predicted == labels).sum().item()  # 张量之间的比较运算
73     print('accuracy on test set: %d %% ' % (100 * correct / total))
74 
75 
76 if __name__ == '__main__':
77     for epoch in range(10):
78         train(epoch)
79         hehe_11()

结果:

accuracy on test set: 97 %
[9, 300] loss: 0.039
[9, 600] loss: 0.042
[9, 900] loss: 0.040
accuracy on test set: 97 %
[10, 300] loss: 0.033
[10, 600] loss: 0.034
[10, 900] loss: 0.032
accuracy on test set: 97 %

标签:loss,刘二,self,torch,batch,PyTorch,train,test,第九
From: https://www.cnblogs.com/zhouyeqin/p/16818731.html

相关文章

  • PyTorch (1) | PyTorch的安装与简介
    本文已收录于Pytorch系列专栏:​​Pytorch入门与实践​​专栏旨在详解Pytorch,精炼地总结重点,面向入门学习者,掌握Pytorch框架,为数据分析,机器学习及深度学习的代码能力打下......
  • PyTorch 深度学习实践
    1importnumpyasnp2importtorch3importmatplotlib.pyplotasplt4importos5os.environ['KMP_DUPLICATE_LIB_OK']='True'67#1preparedataset......
  • Pytorch学习笔记
    两大强大的工具函数:1.dir(),打开一个包,输出包内含有的其他子类2.help(),帮助文档1help(torch.cuda.is_available)2Helponfunctionis_availableinmoduletorch.cuda......
  • 综合素质 第二章 教育法律法规 第九节 《学生伤害事故处理办法》和《中华人民共和国民
     1.★【学校承担事故责任的情形】因以下情形之一造成学生伤害事故。①学校的校舍,场地,其他公共设施,以及学校提供给学生使用的学具,教育教学和生活设施,设备不符合国家规定......
  • 《PyTorch深度学习实践》-刘二大人 第五讲
    1importtorch23#1preparedataset4#x,y是矩阵,3行1列也就是说总共有3个数据,每个数据只有1个特征5x_data=torch.tensor([[1.0],[2.0],[3.0]])6y_d......
  • 《PyTorch深度学习实践》-刘二大人 第六讲
    1importtorch2importtorch.nn.functionalasF34#1preparedataset5x_data=torch.Tensor([[1.0],[2.0],[3.0]])6y_data=torch.Tensor([[0],[0......
  • 【Spring第九篇】AOP
    文章目录​​AOP核心概念​​​​AOP:切点表达式​​​​AOP:使用切点表达式@annotation​​​​通知分类​​​​获取被增强方法相关信息​​​​【不使用自动注入】AOP方......
  • 《PyTorch深度学习实践》-刘二大人 第三讲
    #梯度下降法frommatplotlibimportpyplotasplt#preparethetrainingsetx_data=[1.0,2.0,3.0]y_data=[2.0,4.0,6.0]#initialguessofweightw=......
  • 《PyTorch深度学习实践》-刘二大人 第二讲
    刘二大人的Pytorch保姆式教程。我觉得算0基础学Pytorch吧,从我现在的基础看就是比较easy的程度,正和我意~课堂练习:importnumpyasnpimportmatplotlib.pyplotasplt......
  • 安装Pytorch
    下面三种需求都是可以尝试的:错误1:AssertionError:TorchnotcompiledwithCUDAenabled错误2:torch.cuda.is_available() 输出false需求3:就是想安装Pytorch 请......