使用MNIST测试各类CNN网络性能,在此记录,以便按需选择网络。
除了第一个CNN为自己搭的以外,其余模型使用Pytorch官方模型,这些模型提出时是在ImageNet上进行测试,在此补充在MNIST上的测试。
另外时间有限,每种模型只跑一次得出测试数据,实验结果仅供参考
各种参数:
训练集60000、测试集10000,使用GPU训练
GPU信息
NVIDIA GeForce RTX 4060 Laptop GPU
驱动程序版本: 31.0.15.5222
驱动程序日期: 2024/4/11 星期四
DirectX 版本: 12 (FL 12.1)
物理位置: PCI 总线 1、设备 0、功能 0
利用率 0%
专用 GPU 内存 0.3/8.0 GB
共享 GPU 内存 0.0/7.8 GB
GPU 内存 0.3/15.8 GB
epoch:10
batch size:16
learning rate:0.001
优化器:AdamW
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
损失函数:CrossEntropyLoss()
loss_criterion = torch.nn.CrossEntropyLoss() #交叉熵损失函数
0.CNN(自己随便搭的)
import torch
import torch.nn as nn
class CNN(nn.Module):
def __init__(self, in_channels,out_channels,padding,linear_input_size,out_features):
super(CNN,self).__init__()
# 输入通道数
self.in_channels = in_channels
# 全连接层输入大小
self.linear_input_size = linear_input_size
# 第一层卷积输出通道数
self.out_channels = out_channels
# 全连接层输出类别数
self.out_features = out_features
# 边缘填充0圈数
self.padding=padding
# 卷积层1
self.conv1 = nn.Sequential(
# 卷积层,padding:前向计算时在输入特征图周围添加0的圈数
nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=5, stride=2, padding=self.padding),
# 归一化层
nn.BatchNorm2d(self.out_channels),
# 池化层
nn.MaxPool2d(kernel_size=5, stride=2)
)
# 卷积层2
self.conv2 = nn.Sequential(
# 卷积层
nn.Conv2d(in_channels=self.out_channels, out_channels=self.out_channels*2, kernel_size=4, stride=2, padding=self.padding),
# 归一化层
nn.BatchNorm2d(self.out_channels*2),
# 池化层
nn.MaxPool2d(kernel_size=4, stride=2)
)
# 卷积层3
self.conv3 = nn.Sequential(
# 卷积层
nn.Conv2d(in_channels=self.out_channels*2, out_channels=self.out_channels*4, kernel_size=3, stride=1, padding=self.padding),
# 归一化层
nn.BatchNorm2d(self.out_channels*4),
# 池化层
nn.MaxPool2d(kernel_size=3, stride=1)
)
# 展平层
self.flatten = torch.nn.Flatten()
# 激活层1
self.gelu = nn.GELU()
# 全连接层1
self.fc1 = nn.Linear(self.linear_input_size, self.out_features)
def forward(self,x):
# 多层卷积
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
# 展平层
x=self.flatten(x)
# 激活层
x = self.gelu(x)
# 全连接层1
x = self.fc1(x)
return x
# 定义网络参数
# 输入图像通道数
in_channels = 3
# 第一层卷积输出通道数
out_channels=32
# 边缘填充0圈数
padding_size=5
# 全连接层输入大小
linear_input_size=8192
# 全连接层输出类别数
out_features = 10
# 实例化cnn网络
CNN_model = CNN(in_channels,out_channels,padding_size,linear_input_size,out_features)
1、CNN网络结果
CNN_model = CNN(in_channels,out_channels,padding_size,linear_input_size,out_features)
训练用时:0:04:25
准确率:99.02%
保存模型大小:761.36KB
模型提出时间:202406(由本人)
2、ResNet34
VisionNN_model=torchvision.models.resnet34(pretrained=False,num_classes=out_features)
训练用时:0:21:17
准确率:98.72%
保存模型大小:81.37MB
模型提出时间:2015
3、MobileNet_v3_small
VisionNN_model=torchvision.models.mobilenet_v3_small(pretrained=False,num_classes=out_features)
训练用时:0:12:52
准确率:98.35%
保存模型大小:5.99MB
模型提出时间:2017
4、ConvNeXt_small
VisionNN_model=torchvision.models.convnext_small(pretrained=False,num_classes=out_features)
由于此网络不支持MNIST的3通道28x28尺寸的图像,使用以下这一行代码将图像尺寸调整为32x32
torchvision.transforms.Resize((32, 32))
训练用时:0:32:57(图像扩大到原来的1.14倍,时长仅供参考)
准确率:98.69%
保存模型大小:188.87MB
模型提出时间:2020
5、AlexNet
VisionNN_model=torchvision.models.AlexNet(num_classes=out_features)
进行图像尺寸调整,28x28 -> 65x65
torchvision.transforms.Resize((65, 65)),
训练用时:0:18:06(图像扩大到原来的2.32倍,时长仅供参考)
准确率:97.58%
保存模型大小:217.62MB
模型提出时间:2012
6、EfficientNet_v2_s
VisionNN_model=torchvision.models.efficientnet_v2_s(num_classes=out_features)
训练用时:0:42:55
准确率:95.96%
保存模型大小:77.99MB
模型提出时间:2019
7、MNASNet0_5
VisionNN_model=torchvision.models.mnasnet0_5(num_classes=out_features)
训练用时:0:11:48
准确率:93.55%
保存模型大小:3.82MB
模型提出时间:2018
结论:
当准确率差距较小的情况下,训练速度就成了脱颖而出的关键
本次没有更改Pytorch官方模型的代码,在实际生产中还需要测试不同的架构思想结合起来开发适合的网络模型
model | time used | aac | model size | published | 备注 |
---|---|---|---|---|---|
CNN | 0:04:25 | 99.02% | 761.36KB | 2024 | |
MNASNet0_5 | 0:11:48 | 93.55% | 3.82MB | 2018 | |
MobileNet_v3_small | 0:12:52 | 98.35% | 5.99MB | 2017 | |
AlexNet | 0:18:06 | 97.58% | 217.62MB | 2012 | 输入图像尺寸65x65 |
ResNet34 | 0:21:17 | 98.72% | 81.37MB | 2015 | |
ConvNeXt_small | 0:32:57 | 98.69% | 188.87MB | 2020 | 输入图像尺寸32x32 |
EfficientNet_v2_s | 0:42:55 | 95.96% | 77.99MB | 2019 |