首页 > 其他分享 >深度学习—ResNet_CIFAR100代码

深度学习—ResNet_CIFAR100代码

时间:2023-04-06 11:22:43浏览次数:45  
标签:ch nn 代码 ResNet stride train CIFAR100 self out

 

 

 

 

  1 '''
  2 参考资料: PyTorch官方文档
  3 '''
  4 
  5 # 导入所需的包
  6 import torch
  7 import wandb
  8 import torch.nn as nn
  9 from torchvision import transforms
 10 from torchvision.datasets import CIFAR100
 11 from torch.utils.data import DataLoader
 12 
 13 # 使用Compose容器组合定义图像预处理方式
 14 transf = transforms.Compose([
 15     # 将给定图片转为shape为(C, H, W)的tensor
 16     transforms.ToTensor()
 17 
 18 ])
 19 # 数据准备
 20 train_set = CIFAR100(
 21     # 数据集的地址
 22     root="./",
 23     # 是否为训练集,True为训练集
 24     train=True,
 25     # 使用数据预处理
 26     transform=transf,
 27     # 是否需要下载, True为需要下载
 28     download=True
 29 )
 30 test_set = CIFAR100(
 31     root="./",
 32     train=False,
 33     transform=transf,
 34     download=True
 35 )
 36 # 定义数据加载器
 37 train_loader = DataLoader(
 38     # 需要加载的数据
 39     train_set,
 40     # 定义batch大小
 41     batch_size=16,
 42     # 是否打乱顺序,True为打乱顺序
 43     shuffle=True
 44 )
 45 test_loader = DataLoader(
 46     test_set,
 47     batch_size=16,
 48     shuffle=False
 49 )
 50 
 51 # 基础的残差模块
 52 class BasicBlock(nn.Module):
 53     expansion = 1
 54     def __init__(self, ch_in, ch_out, stride=1):
 55         super(BasicBlock, self).__init__()
 56         self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1, bias=False)
 57         self.bn1 = nn.BatchNorm2d(ch_out)
 58         # inplace为True,将计算得到的值直接覆盖之前的值,可以节省时间和内存
 59         self.relu = nn.ReLU(inplace=True)
 60         self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False)
 61         self.bn2 = nn.BatchNorm2d(ch_out)
 62         self.downsample = None
 63         if ch_out != ch_in:
 64             # 如果输入通道数和输出通道数不相同,使用1×1的卷积改变通道数
 65             self.downsample = nn.Sequential(
 66                 nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=2, bias=False),
 67                 nn.BatchNorm2d(ch_out)
 68             )
 69 
 70     def forward(self,x):
 71         identity = x
 72         out = self.bn1(self.conv1(x))
 73         out = self.relu(out)
 74         out = self.bn2(self.conv2(out))
 75         if self.downsample != None:
 76             identity = self.downsample(x)
 77         out += identity
 78         relu = nn.ReLU()
 79         out = relu(out)
 80         return out
 81 
 82 # 改进型的残差模块
 83 class Bottleneck(nn.Module):
 84     expansion = 4  #扩展,即通道数为之前的4倍
 85     def __init__(self, ch_in, ch_out, stride=1):
 86         super(Bottleneck, self).__init__()
 87         self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, bias=False)
 88         self.bn1 = nn.BatchNorm2d(ch_out)
 89         self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=stride, padding=1, bias=False)
 90         self.bn2 = nn.BatchNorm2d(ch_out)
 91         self.conv3 = nn.Conv2d(ch_out, ch_out * self.expansion, kernel_size=1, stride=1, bias=False)
 92         self.bn3 = nn.BatchNorm2d(ch_out * self.expansion)
 93         self.relu = nn.ReLU(inplace=True)
 94         self.downsample = None
 95         if ch_in != ch_out * self.expansion:
 96             self.downsample = nn.Sequential(
 97                 nn.Conv2d(ch_in, ch_out * self.expansion, kernel_size=1, stride=stride, bias=False),
 98                 nn.BatchNorm2d(ch_out * self.expansion)
 99             )
100 
101     def forward(self, x):
102         identity = x
103         out = self.bn1(self.conv1(x))
104         out = self.relu(out)
105         out = self.bn2(self.conv2(out))
106         out = self.relu(out)
107         out = self.bn3(self.conv3(out))
108         if self.downsample is not None:
109             identity = self.downsample(x)
110         out += identity
111         relu = nn.ReLU()
112         out = relu(out)
113         return out
114 
115 # 实现ResNet网络
116 class ResNet(nn.Module):
117     # 初始化;block:残差块结构;layers:残差块层数;num_classes:输出层神经元即分类数
118     def __init__(self, block, layers, num_classes=1000):
119         super(ResNet, self).__init__()
120         # 改变后的通道数
121         self.channel = 64
122         # 第一个卷积层
123         self.conv1 = nn.Conv2d(3, self.channel, kernel_size=7, stride=2, padding=3, bias=False)
124         self.bn1 = nn.BatchNorm2d(self.channel)
125         self.relu = nn.ReLU(inplace=True)
126         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
127         # 残差网络的四个残差块堆
128         self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
129         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
130         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
131         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
132         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
133         # 全连接层,也是输出层
134         self.fc = nn.Linear(512 * block.expansion, num_classes)
135 
136 
137 
138     # 用于堆叠残差块
139     def _make_layer(self, block, ch_out, blocks, stride=1):
140         layers = []
141         layers.append(block(self.channel, ch_out, stride))
142         self.channel = ch_out * block.expansion
143         for _ in range(1, blocks):
144             layers.append(block(self.channel, ch_out))
145         return nn.Sequential(*layers)
146 
147     def forward(self, x):
148         x = self.conv1(x)
149         x = self.bn1(x)
150         x = self.relu(x)
151         x = self.maxpool(x)
152         x = self.layer1(x)
153         x = self.layer2(x)
154         x = self.layer3(x)
155         x = self.layer4(x)
156         x = self.avgpool(x)
157         x = x.view(x.size(0), -1)
158         x = self.fc(x)
159         return x
160 
161 # ResNet18生成方法
162 def resnet18(num_classes=1000):
163     model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
164     return model
165 
166 # ResNet50生成方法
167 def resnet50(num_classes=1000):
168     model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
169     return model
170 
171 
172 # 定义网络的预训练
173 def train(net, train_loader, test_loader, device, l_r = 0.0002, num_epochs=50):
174     # 使用wandb跟踪训练过程
175     #experiment = wandb.init(project='ResNet18', resume='allow', anonymous='must')
176     experiment = wandb.init(project='ResNet50', resume='allow', anonymous='must')
177     # 定义损失函数
178     criterion = nn.CrossEntropyLoss()
179     # 定义优化器
180     optimizer = torch.optim.Adam(net.parameters(), lr=l_r)
181     # 将网络移动到指定设备
182     net = net.to(device)
183     # 正式开始训练
184     for epoch in range(num_epochs):
185         # 保存一个Epoch的损失
186         train_loss = 0
187         # 计算准确度
188         test_corrects = 0
189         # 设置模型为训练模式
190         net.train()
191         for step, (imgs, labels) in enumerate(train_loader):
192             # 训练使用的数据移动到指定设备
193             imgs = imgs.to(device)
194             labels = labels.to(device)
195             output = net(imgs)
196             # 计算损失
197             loss = criterion(output, labels)
198             # 将梯度清零
199             optimizer.zero_grad()
200             # 将损失进行后向传播
201             loss.backward()
202             # 更新网络参数
203             optimizer.step()
204             train_loss += loss.item()
205             pre_lab = torch.argmax(output, 1)
206             train_batch_corrects = (torch.sum(pre_lab == labels.data).double() / imgs.size(0))
207             if step % 100 == 0:
208                 print("train {} {}/{} loss: {} acc: {}".format(epoch, step, len(train_loader), loss.item(),
209                                                               train_batch_corrects.item()))
210 
211         # 设置模型为验证模式
212         net.eval()
213         for step, (imgs, labels) in enumerate(test_loader):
214             imgs = imgs.to(device)
215             labels = labels.to(device)
216             output = net(imgs)
217             loss = criterion(output, labels)
218             pre_lab = torch.argmax(output, 1)
219             test_batch_corrects = (torch.sum(pre_lab == labels.data).double() / imgs.size(0))
220             test_corrects += test_batch_corrects.item()
221             if step % 100 == 0:
222                 print("val {} {}/{} loss: {} acc: {}".format(epoch, step, len(test_loader), loss.item(), test_batch_corrects.item()))
223 
224         # 一个Epoch结束时,使用wandb保存需要可视化的数据
225         experiment.log({
226             'epoch':epoch,
227             'train loss': train_loss / len(train_loader),
228             'test acc': test_corrects / len(test_loader),
229         })
230         print('Epoch: {}/{}'.format(epoch, num_epochs-1))
231         print('{} Train Loss:{:.4f}'.format(epoch, train_loss / len(train_loader)))
232         print('{} Test Acc:{:.4f}'.format(epoch, test_corrects / len(test_loader)))
233         # 保存此Epoch训练的网络的参数
234         torch.save(net.state_dict(), './ResNet18.pth')
235 
236 if __name__ == "__main__":
237     # 定义训练使用的设备
238     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
239     # 使用自定义的resnet18()方法实现ResNet18网络
240     net = resnet50(num_classes = 100)
241     train(net, train_loader, test_loader, device, l_r=0.00003, num_epochs=10)

 

标签:ch,nn,代码,ResNet,stride,train,CIFAR100,self,out
From: https://www.cnblogs.com/jevonChao/p/17292213.html

相关文章

  • 架构师日记-如何写的一手好代码
    作者:京东零售刘慧卿一前言在日常工作中,我经常听到部分同学抱怨代码质量问题,潜台词是:“除了自己的代码,其他人写的都是垃圾,得送到绞刑架上,重构!”。今天就来聊一聊,如何写的一手好代码。要回答这个问题之前,得先弄清楚一个问题,好代码的标准是什么?易阅读,可扩展,高内聚,低耦合,编程范式,设计......
  • 全网最详细中英文ChatGPT-GPT-4示例文档-会议笔记文档智能转摘要从0到1快速入门——官
    目录Introduce简介setting设置Prompt提示Sampleresponse回复样本APIrequest接口请求python接口请求示例node.js接口请求示例curl命令示例json格式示例其它资料下载ChatGPT是目前最先进的AI聊天机器人,它能够理解图片和文字,生成流畅和有趣的回答。如果你想跟上AI时代的潮流......
  • 网络对抗实验四-恶意代码分析
    Exp4恶意代码分析实验基础实验目标1.监控自己系统的运行状态,看有没有可疑的程序在运行。2.分析一个恶意软件,就分析Exp2或Exp3中生成后门软件;分析工具尽量使用原生指令或sysinternals,systracer套件。3.假定将来工作中你觉得自己的主机有问题,就可以用实验中的这个思路,先整个......
  • VS2012、VS2013、VS2015、VS2019 代码自动注释插件【2】
    Git代码自动注释工具源码地址 VS2010、VS2012、VS2013的代码自动注释插件。安装该插件后,可以在VS的菜单中显示“注释”主菜单,可以给类、函数、成员添加标准的注释,与Doxygen配合使用,可以直接生成项目的注释文档。【插件下载】高版本的VS,可以下载源码后,自行编译使用。【插件安装】......
  • PHP 文件加密Zend Guard Loader 学习和使用(如何安装ioncube扩展对PHP代码加密)
    一、大体流程图二、PHP项目文件加密 下表列出了Zend产品中的PHP版本及其内部API版本和Zend产品版本。如何加密请往后看三、如何使用第一步:确认当前环境AmaiPhalcon前,请确认您具备以下两个条件,如果您的环境不满足此条件,建议您对系统环境进行重新配置。条件1:PHP版本在5.5.X以上(......
  • nohup python app.py 1>log.log 2>&1 & 这句话代码咋解释呀,不太明白
    nohuppythonapp.py1>log.log2>&1&这句话代码咋解释呀,不太明白 GPT给的答案 克隆ChatGpt功能nohuppythonapp.py1>log.log2>&1&这句话代码咋解释呀,不太明白  这个命令可以分成几部分:-`nohup`:意思是不挂断,即使终端关闭或者用户退出登录,进程也将继续运行。-......
  • 深度学习经典网络模型汇总——LeNet、AlexNet、ZFNet、VGGNet、GoogleNet、ResNet【对
    文章目录LeNetAlexNetZFNetVGGNetGoogleNetResNet先来看一下我们要讲述哪些经典的网络模型,如下:LeNet:最早用于手写数字识别的CNN网络AlexNet:2012年ILSVRC比赛冠军,比LeNet层数更深,这是一个历史性突破。ZFNet:2013年ILSVRC比赛效果较好,和AlexNet类似。VGGNet:2014年ILSVRC比赛分类......
  • Android之一个简单计算器源代码
    通过Android4.0 网格布局GridLayout来实现一个简单的计算器界面布局 源码如下(欢迎大家指导批评) packagecom.android.xiong.gridlayoutTest;importjava.math.BigDecimal;importjava.util.regex.Pattern;importcom.android.xiong.gridlayoutTest.R.id;......
  • 基于pytorch搭建ResNet神经网络用于花类识别
    文章目录基于pytorch搭建ResNet神经网络用于花类识别写在前面ResNet网络模型搭建✨✨✨训练结果展示小结基于pytorch搭建ResNet神经网络用于花类识别写在前面【当然这是要在你对这部分网络结构的理论有充分的了解之后】另一方面,我觉得这部分真的得你自己切切实实的钻研,自己一步步的......
  • VSCode中使用Git/git 代码管理
    1.在一个目录下clone项目:[email protected]:hemoumou-debug/libjpeg-turbo.git问题1:解决:需要从known_hosts文件中删除旧的github.comRSA密钥,然后将新的RSA密钥添加到文件中。您可以按照以下步骤操作:在文本编辑器中打开您的known_hosts文件。在Windows......