首页 > 其他分享 >深度学习—AlexNet_CIFAR100代码

深度学习—AlexNet_CIFAR100代码

时间:2023-04-03 16:55:06浏览次数:45  
标签:nn 代码 train CIFAR100 net True self AlexNet size

 

 

 

 

  1 # 导入所需的包
  2 import torch
  3 #import wandb
  4 import torch.nn as nn
  5 from torchvision import transforms
  6 from torch.utils.data import DataLoader
  7 from torchvision.datasets import CIFAR100
  8 
  9 # 使用Compose容器组合定义图像预处理方式
 10 transf = transforms.Compose([
 11     # 改变图像大小
 12     transforms.Resize(224),
 13     # 将给定图片转为shape为(C, H, W)的tensor
 14     transforms.ToTensor()
 15 ])
 16 # 数据准备
 17 traindata = CIFAR100(
 18     # 数据集的地址
 19     root="./",
 20     # 是否为训练集,True为训练集
 21     train=True,
 22     # 使用数据预处理
 23     transform=transf,
 24     # 是否需要下载, True为需要下载
 25     download=True
 26 )
 27 testdata = CIFAR100(
 28     root="./",
 29     train=False,
 30     transform=transf,
 31     download=True
 32 )
 33 # 定义数据加载器
 34 trainloader = DataLoader(
 35     # 需要加载的数据
 36     traindata,
 37     # 定义batch大小
 38     batch_size=128,
 39     # 是否打乱顺序,True为打乱顺序
 40     shuffle=True
 41 )
 42 testloader = DataLoader(
 43     testdata,
 44     batch_size=128,
 45     shuffle=False
 46 )
 47 
 48 # 定义AlexNet网络
 49 # 此处写代码
 50 from torch import nn
 51 class AlexNet(nn.Module):
 52     # 初始化
 53     def __init__(self):
 54         super(AlexNet, self).__init__()
 55         self.conv1 = nn.Conv2d(in_channels=3,
 56                                out_channels=96,
 57                                kernel_size=11,
 58                                padding=2,
 59                                stride=4)
 60         self.relu1 = nn.ReLU()
 61         self.max_pool1 = nn.MaxPool2d(kernel_size=3,
 62                                       stride=2,
 63                                       padding=0)
 64 
 65         self.conv2 = nn.Conv2d(in_channels=96,
 66                                out_channels=256,
 67                                kernel_size=5,
 68                                padding=2,
 69                                stride=1)
 70         self.relu2 = nn.ReLU()
 71         self.max_pool2 = nn.MaxPool2d(kernel_size=3,
 72                                       stride=2,
 73                                       padding=0)
 74 
 75         self.conv3 = nn.Conv2d(in_channels=256,
 76                                out_channels=384,
 77                                kernel_size=3,
 78                                padding=1,
 79                                stride=1)
 80         self.relu3 = nn.ReLU()
 81 
 82         self.conv4 = nn.Conv2d(in_channels=384,
 83                                out_channels=384,
 84                                kernel_size=3,
 85                                padding=1,
 86                                stride=1)
 87         self.relu4 = nn.ReLU()
 88 
 89         self.conv5 = nn.Conv2d(in_channels=384,
 90                                out_channels=256,
 91                                kernel_size=3,
 92                                padding=1,
 93                                stride=1)
 94         self.relu5 = nn.ReLU()
 95         self.max_pool5 = nn.MaxPool2d(kernel_size=3,
 96                                       stride=2,
 97                                       padding=0)
 98 
 99         self.dropout1=nn.Dropout(0.5)
100         self.linear1 = nn.Linear(in_features=256*6*6,
101                                 out_features=4096,
102                                 bias=True)
103         self.relu6 = nn.ReLU()
104 
105         self.dropout2 = nn.Dropout(0.5)
106         self.linear2 = nn.Linear(in_features=4096,
107                                  out_features=4096,
108                                  bias=True)
109         self.relu7 = nn.ReLU()
110 
111         self.linear3 = nn.Linear(in_features=4096,
112                                  out_features=100,
113                                  bias=True)
114 
115 
116     # 定义前向计算过程
117     def forward(self, x):
118         x = self.conv1(x)
119         x = self.relu1(x)
120         x = self.max_pool1(x)
121 
122         x = self.conv2(x)
123         x = self.relu2(x)
124         x = self.max_pool2(x)
125 
126         x = self.conv3(x)
127         x = self.relu3(x)
128 
129         x = self.conv4(x)
130         x = self.relu4(x)
131 
132         x = self.conv5(x)
133         x = self.relu5(x)
134         x = self.max_pool5(x)
135 
136         # 将特征展平(超级重要!!!)
137         x = x.view(x.shape[0], -1)
138 
139         x = self.dropout1(x)
140         x = self.linear1(x)
141         x = self.relu6(x)
142 
143         x = self.dropout2(x)
144         x = self.linear2(x)
145         x = self.relu7(x)
146 
147 
148         x = self.linear3(x)
149 
150         return x
151 
152 # 定义网络的预训练
153 def train(net, train_loader, test_loader, device, l_r = 0.0002, num_epochs=25,):
154     # 使用wandb跟踪训练过程
155     #experiment = wandb.init(project='AlexNet', resume='allow', anonymous='must')
156     # 定义损失函数
157     criterion = nn.CrossEntropyLoss()
158     # 定义优化器
159     optimizer = torch.optim.Adam(net.parameters(), lr=l_r)
160     # 将网络移动到指定设备
161     net = net.to(device)
162     # 正式开始训练
163     for epoch in range(num_epochs):
164         # 保存一个Epoch的损失
165         train_loss = 0
166         # 计算准确度
167         test_corrects = 0
168         # 设置模型为训练模式
169         net.train()
170         for step, (imgs, labels) in enumerate(train_loader):
171             # 训练使用的数据移动到指定设备
172             imgs = imgs.to(device)
173             labels = labels.to(device)
174             output = net(imgs)
175             # 计算损失
176             loss = criterion(output, labels)
177             # 将梯度清零
178             optimizer.zero_grad()
179             # 将损失进行后向传播
180             loss.backward()
181             # 更新网络参数
182             optimizer.step()
183             train_loss += loss.item()
184         # 设置模型为验证模式
185         net.eval()
186         for step, (imgs, labels) in enumerate(test_loader):
187             imgs = imgs.to(device)
188             labels = labels.to(device)
189             output = net(imgs)
190             pre_lab = torch.argmax(output, 1)
191             corrects = (torch.sum(pre_lab == labels.data).double() / imgs.size(0))
192             test_corrects += corrects.item()
193         #一个Epoch结束时,使用wandb保存需要可视化的数据
194         # experiment.log({
195         #     'epoch':epoch,
196         #     'train loss': train_loss / len(train_loader),
197         #     'test acc': test_corrects / len(test_loader),
198         # })
199         print('Epoch: {}/{}'.format(epoch, num_epochs-1))
200         print('{} Train Loss:{:.4f}'.format(epoch, train_loss / len(train_loader)))
201         print('{} Test Acc:{:.4f}'.format(epoch, test_corrects / len(test_loader)))
202         # 保存此Epoch训练的网络的参数
203         torch.save(net.state_dict(), './net.pth')
204 
205 if __name__ == "__main__":
206     # 定义训练使用的设备
207     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
208     net = AlexNet()
209     train(net, trainloader, testloader, device, l_r=0.0003, num_epochs=10)

 

标签:nn,代码,train,CIFAR100,net,True,self,AlexNet,size
From: https://www.cnblogs.com/jevonChao/p/17266546.html

相关文章

  • 深度学习—VGG16_CIFAR100代码
     1#导入所需的包2importtorch3#importwandb4importtorch.nnasnn5fromtorchvisionimporttransforms6fromtorchvision.datasetsimportCIFAR1007fromtorch.utils.dataimportDataLoader89#使用Compose容器组合定义图像预处理方......
  • Pycharm创建自定义代码片段
    简介PyCharm允许您创建自定义代码片段,也称为代码模板,以提高您的开发效率实现步骤1.添加代码模板打开PyCharm并导航到File->Settings,或者按快捷键ctrl+alt+s打开设置​按照如下序号步骤进行点击,点击“+”按钮以创建新的代码模板,选择LiveTemplate,此处可以看到很多pych......
  • LeaRun低代码开发平台 赋能企业快速落地BI大屏
    在信息化变革的大势下,如何理清错综复杂的业务需求,重构企业数智化新模式,已成为关乎企业发展的“必修课”和行业共识。当前,数字化转型已经进入全面落地阶段,越来越多的中小企业、地方企业和传统企业都开始参与进来,但在转型过程中,往往存在预算有限、技术能力不足等困难。如何将先进的......
  • gitlab推送代码触发jenkins构建
    预期:推送devloop或者master分支的代码,自动执行jenkins发布测试环境首先,jenkins中需要安装如下插件打开一个任务配置,构建触发器中勾选"BuildwhenachangeispushedtoGitLab."并过滤指定分支,这里需要记下GitLabwebhookURL一会儿配置到gitlab上3.gitlab中添......
  • m基于AlexNet神经网络和GEI步态能量图的步态识别算法MATLAB仿真
    1.算法描述        AlexNet是2012年ImageNet竞赛冠军获得者Hinton和他的学生AlexKrizhevsky设计的。也是在那年之后,更多的更深的神经网络被提出,比如优秀的vgg,GoogLeNet。这对于传统的机器学习分类算法而言,已经相当的出色。Alexnet网络模型于2012年提出。它具有更高维......
  • 代码审计系统 Swallow 开发回顾
    做甲方安全建设,SDL是一个离不开的话题,其中就包含代码审计工作,我从最开始使用编辑器自带的查找,到使用fortify工具,再到后来又觉得fortify的扫描太慢影响审计效率,再后来就想着把fortify集成到自己的业务系统中去最近几年安全行业发展的很快,以前少见的组件安全产品也多了起来,......
  • 全网最详细中英文ChatGPT-GPT-4示例文档-复杂函数快速转单行函数从0到1快速入门——官
    目录Introduce简介setting设置Prompt提示Sampleresponse回复样本APIrequest接口请求python接口请求示例node.js接口请求示例curl命令示例json格式示例其它资料下载ChatGPT是目前最先进的AI聊天机器人,它能够理解图片和文字,生成流畅和有趣的回答。如果你想跟上AI时代的潮流......
  • 野火书籍《STM32库开发指南》 第26章LCD代码勘误
     第26章LCD代码,P303代码写错。原来的代码写错,因为是D/CX引脚,高电平(1)意味着数据,低电平(0)意味着命令:#defineFSMC_Addr_ILI9341_CMD((uint32_t))0x60020000#defineFSMC_Addr_ILI9341_DATA((uint32_t))0x60000000因此应该将两个宏定义对换。正确的为:#defineFS......
  • 180122 特征值与特征向量的几何解释与python代码,附matplotlib绘制多边形
    HowtoPlotPolygonsinPythonShapely-ManualShapely-Test3Blue1Brown-线性代数的几何解释DownloadsShapely-WindowsShapely-MacorLinux红色基坐标(竖着看)1001绿色变换矩阵(竖着看)3102蓝色特征向量(竖着看)1−2√202√2黑色变换矩阵(左乘)特征向量(竖着......
  • Java-Day-2(转义字符 + 注释 + 代码规范 + 变量 + 数据类型)
    Java-Day-2常用转义字符代码中只一个\会默认转义(写在“”里)\t:制表位,可以实现对齐功能,可以看作有一个无形表框(上下两行长度相差不大)\n:换行符,仅换代码行的话\\:一个\,想输出"\\"就要输入四个\\'':一个“,字符串里输出双引号\':一个‘\r:一个回车,光标......