首页 > 其他分享 >23、完整模型训练的步骤

23、完整模型训练的步骤

时间:2023-02-26 15:33:42浏览次数:32  
标签:loss nn 23 步骤 模型 train 64 data size

1、以CIFAR10为例子

 1 '''以CIFAR10为例子训练完整的模型步骤'''
 2 import torch
 3 import torchvision
 4 from torch import nn
 5 from torch.nn import Conv2d
 6 from torch.utils.data import DataLoader
 7 
 8 '''1、准备数据'''
 9 train_data=torchvision.datasets.CIFAR10(root='../../dataset/CIFAR10',train=True,transform=torchvision.transforms.ToTensor(),
10                                         download=True)
11 test_data=torchvision.datasets.CIFAR10(root='../../dataset/CIFAR10',train=False,transform=torchvision.transforms.ToTensor(),
12                                         download=True)
13 #查看数据集的大小
14 train_data_size=len(train_data)
15 test_data_size=len(test_data)
16 print('训练数据集的长度为:{}'.format(train_data_size))
17 print('测试数据集的长度为:{}'.format(test_data_size))
18 
19 '''2、利用 DataLoader 加载数据'''
20 train_dataloader=DataLoader(train_data,batch_size=64)
21 test_dataloader=DataLoader(test_data,batch_size=64)
22 
23 '''3、搭建神经网络模型,习惯上会把模型单独放在一个文件,然后使用的时候进行引入'''
24 class class_net(nn.Module):
25     def __init__(self):
26         super().__init__()
27         self.modle=nn.Sequential(
28             nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,stride=1, padding=2),
29             nn.MaxPool2d(kernel_size=2),
30             nn.Conv2d(32,32,5,1,2),
31             nn.MaxPool2d(kernel_size=2),
32             nn.Conv2d(32,64,5,1,2),
33             nn.MaxPool2d(kernel_size=2),
34             nn.Flatten(),
35             nn.Linear(in_features=64*4*4,out_features=64),
36             nn.Linear(64,10)
37         )
38 
39     def forward(self,x):
40         x=self.modle(x)
41         return x
42 #创建网络模型
43 tuidui=class_net()
44 
45 '''4、创建损失函数'''
46 loss_fn=nn.CrossEntropyLoss()
47 '''5、创建优化器'''
48 #习惯把学习速率单独提出来,方便修改,两种写法
49 # learn_rate=0.01
50 learn_rate=1e-2     #1e-2=1*(10)^(-2)=1/100=0.01
51 optimizer=torch.optim.SGD(tuidui.parameters(),lr=learn_rate)
52 
53 '''6、训练网络模型'''
54 #设置训练网络的一些参数
55 
56 total_train_step=0  #记录训练的次数
57 total_test_step=0   #记录测试的次数
58 epoch=10    #训练的轮数
59 for i in range(epoch):
60     print('--------第{}轮训练开始----------'.format(i+1))
61     #训练步骤开始
62     for data in train_dataloader:
63         imgs,targets=data
64         outputs=tuidui(imgs)
65         loss=loss_fn(outputs,targets)
66 
67         #进行优化的第一步是梯度清零
68         optimizer.zero_grad()
69         #利用损失来求每一个参数节点的梯度
70         loss.backward()
71         #进行优化
72         optimizer.step()
73         #更新训练次数
74         total_train_step+=1
75         print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))
76         #loss.item(),使用item会让tensor类型的数据直接变成数字,数值型

通常会把网络模型单独写成一个文件,然后使用的时候直接引入,但是这里我的在引入的时候一直报错。同目录下的导入报错。

 

标签:loss,nn,23,步骤,模型,train,64,data,size
From: https://www.cnblogs.com/ar-boke/p/17156788.html

相关文章

  • 电脑提示msvcp120.dll丢失解决 步骤
    电脑提示找不到msvcp120.dll怎么办?详细安装修复教程电脑提示msvcp120.dll丢失解决步骤打开电脑下载msvcp120.dll在浏览器后在顶部输入【​​dll修复程序.site​​】按下电......
  • 软件工程(3)--原型模型
    前言这是基于我所学习的软件工程课程总结的第三篇文章。原型模型又称原型化模型、快速原型模型书上对于(快速)原型模型的描述是:快速原型是快速建立起的程序,它所能完成的功能往......
  • 【jeecg-boot项目开发crm】:day07JeecgBoot-零基础入门视频-05代码生成(树模型和一对多
    代码生成(树模型和一对多模型,一对多三套模型)树模型生成流程图前期工作:先将页面搭建好页面中代码生成一对多生成流程图只能选主表将生成代码中的vue拷贝到前端目录下接下来......
  • 奶牛大学(2023寒假每日一题 6)
    FarmerJohn计划为奶牛们新开办一所大学!有每头奶牛最多愿意支付FarmerJohn可以设定所有奶牛入学需要支付的学费。如果这笔学费大于一头奶牛愿意支付的最高金额,那么这头......
  • 回收站清空了怎么恢复?2023年怎么使用Easyrecovery恢复误删的数据
    我们在使用电脑时,删除的文件都会先临时放在回收站。回收站里的垃圾文件越多,电脑也会越卡顿,很多人就会清理删除下电脑回收站中的文件。但是有时会出现后续还需要这些文件的情......
  • 2023 年 CCF 春季测试赛模拟赛 - 2 题解
    T1约数和标准解法\(n=a_1^{b_1}\timesa_2^{b_2}\dotsa_k^{b_k}\)那么根据算术基本定理的推广,约数个数和约数和都是可以快速计算得到约数和sum\(sum=(a_1^0......
  • 23_2_26关于pycharm的调试
    pycharm的调试:http://www.360doc.com/content/22/1120/20/37289152_1056826955.shtml 1.添加断点:单击代码行号后面的位置2.进入调试模式:点击“甲壳虫”(似乎已经成了所有I......
  • 2023、2、25-26学习总结
    工程目录:bean:通过bean传递servlet中的数据给daopackageBean;publicclassbean{privateStringwords;privateintid;publicStringgetWo......
  • 2023.2.26【模板】扩展Lucas定理
    2023.2.26【模板】扩展Lucas定理题目概述求\(\binom{n}{m}mod\)\(p\)的值,不保证\(p\)为质数算法流程(扩展和普通算法毫无关系)由于\(p\)不是质数,我们考虑[SDOI201......
  • 《分布式技术原理与算法解析》学习笔记Day23
    分布式数据复制我们在进行分布式数据存储设计时,通常会考虑对数据进行备份,以提高数据的可用性和可靠性,“数据复制技术”就是实现数据备份的关键技术。什么是数据复制技术?......