首页 > 其他分享 >02-Lenet5 图像分类网络

02-Lenet5 图像分类网络

时间:2023-01-10 16:00:36浏览次数:44  
标签:02 Lenet5 nn self list train 图像 test size

 

 图1 Lenet5 手写字符分类网络架构

LeNet-5 在 CIFAR-10 数据集上的 PyTorch 实现（请保存为 Lenet5.py，下文的 classfyNet_main.py 通过 `from Lenet5 import Lenet5_new` 导入它）：

 1 import torch 
 2 from torch import nn, optim
 3 import torch.nn.functional as F
 4 
 5 class Lenet5(nn.Module):
 6     
 7     def __init__(self):
 8         super(Lenet5, self).__init__()
 9         
10         self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
11         self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
12         self.pooling = nn.AvgPool2d(kernel_size=2, stride=2)
13         
14         self.l1 = nn.Linear(400, 120)
15         self.l2 = nn.Linear(120, 84)
16         self.l3 = nn.Linear(84, 10)
17         
18     def forward(self, x):
19         
20         # x: [64, 3, 32, 32]
21         batch_size = x.size(0)
22         
23         x = self.pooling(self.conv1(x))
24         x = self.pooling(self.conv2(x))
25         x = x.view(batch_size, -1)
26         x = F.relu(self.l1(x))
27         x = F.relu(self.l2(x))
28         
29         return self.l3(x)
30 
31     
class Lenet5_new(nn.Module):
    """nn.Sequential-based LeNet-5 for 3-channel 32x32 images with 10 output classes.

    conv_unit: conv(3->6, 5x5) -> ReLU -> avgpool(2) -> conv(6->16, 5x5) -> ReLU -> avgpool(2)
    classfy:   FC 400 -> 120 -> ReLU -> FC 84 -> ReLU -> FC 10 logits
    """

    def __init__(self):
        super(Lenet5_new, self).__init__()
        from collections import OrderedDict

        # The conv/pool layers keep the Sequential keys they had before the
        # ReLUs were added ("0".."3"). As a result, state_dict names such as
        # "conv_unit.0.weight" and "conv_unit.2.weight" are unchanged and old
        # checkpoints still load. The parameter-free ReLUs get their own names.
        # Without them, conv -> pool -> conv -> pool would be one linear map.
        self.conv_unit = nn.Sequential(OrderedDict([
            ("0", nn.Conv2d(3, 6, kernel_size=5)),
            ("relu1", nn.ReLU()),
            ("1", nn.AvgPool2d(kernel_size=2, stride=2)),
            ("2", nn.Conv2d(6, 16, kernel_size=5)),
            ("relu2", nn.ReLU()),
            ("3", nn.AvgPool2d(kernel_size=2, stride=2)),
        ]))

        # Attribute name (sic) kept as-is: it is part of the state_dict keys.
        # 400 = 16 * 5 * 5, the conv_unit output size for 32x32 inputs.
        self.classfy = nn.Sequential(
            nn.Linear(400, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10).

        Args:
            x: image batch of shape (batch, 3, 32, 32), e.g. (64, 3, 32, 32).
        """
        x = self.conv_unit(x)    # (batch, 16, 5, 5)
        x = torch.flatten(x, 1)  # (batch, 400); unlike view(), safe for non-contiguous tensors
        return self.classfy(x)
66         
67         
if __name__ == "__main__":
    # Quick shape check. It is guarded so that `from Lenet5 import Lenet5_new`
    # (used by classfyNet_main.py) does not build a model and print on import.
    model = Lenet5_new()
    x = torch.rand(64, 3, 32, 32)
    print(model(x).shape)  # expected: torch.Size([64, 10])
View Code

classfyNet_main.py（训练与测试脚本；其中 `from Resnet18 import ResNet18` 所需的 Resnet18.py 本文未给出）：

  1 import torch
  2 from torch.utils.data import DataLoader
  3 from torch import nn, optim
  4 from torchvision import datasets, transforms
  5 
  6 from matplotlib import pyplot as plt
  7 
  8 
  9 import time
 10 
 11 from Lenet5 import Lenet5_new
 12 from Resnet18 import ResNet18
 13 
 14 def main():
 15     
 16     print("Load datasets...")
 17     
 18     # transforms.RandomHorizontalFlip(p=0.5)---以0.5的概率对图片做水平横向翻转
 19     # transforms.ToTensor()---shape从(H,W,C)->(C,H,W), 每个像素点从(0-255)映射到(0-1):直接除以255
 20     # transforms.Normalize---先将输入归一化到(0,1),像素点通过"(x-mean)/std",将每个元素分布到(-1,1)
 21     transform_train = transforms.Compose([
 22                         transforms.RandomHorizontalFlip(p=0.5),
 23                         transforms.ToTensor(),
 24                         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 25                     ])
 26 
 27     transform_test = transforms.Compose([
 28                         transforms.ToTensor(),
 29                         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 30                     ])
 31     
 32     # 内置函数下载数据集
 33     train_dataset = datasets.CIFAR10(root="./data/Cifar10/", train=True, 
 34                                      transform = transform_train,
 35                                      download=True)
 36     test_dataset = datasets.CIFAR10(root = "./data/Cifar10/", 
 37                                     train = False, 
 38                                     transform = transform_test,
 39                                     download=True)
 40     
 41     print(len(train_dataset), len(test_dataset))
 42     
 43     Batch_size = 64
 44     train_loader = DataLoader(train_dataset, batch_size=Batch_size,  shuffle = True)
 45     test_loader = DataLoader(test_dataset, batch_size = Batch_size, shuffle = False)
 46     
 47     # 设置CUDA
 48     device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
 49 
 50     # 初始化模型
 51     # 直接更换模型就行,其他无需操作
 52     # model = Lenet5_new().to(device)
 53     model = ResNet18().to(device)
 54     
 55     # 构造损失函数和优化器
 56     criterion = nn.CrossEntropyLoss() # 多分类softmax构造损失
 57     opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.8, weight_decay=0.001)
 58     
 59     # 动态更新学习率 ------每隔step_size : lr = lr * gamma
 60     schedule = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.6, last_epoch=-1)
 61     
 62     # 开始训练
 63     print("Start Train...")
 64 
 65     epochs = 100
 66    
 67     loss_list = []
 68     train_acc_list =[]
 69     test_acc_list = []
 70     epochs_list = []
 71     
 72     for epoch in range(0, epochs):
 73          
 74         start = time.time()
 75         
 76         model.train()
 77         
 78         running_loss = 0.0
 79         batch_num = 0
 80         
 81         for i, (inputs, labels) in enumerate(train_loader):
 82             
 83             inputs, labels = inputs.to(device), labels.to(device)
 84             
 85             # 将数据送入模型训练
 86             outputs = model(inputs)
 87             # 计算损失
 88             loss = criterion(outputs, labels).to(device)
 89             
 90             # 重置梯度
 91             opt.zero_grad()
 92             # 计算梯度,反向传播
 93             loss.backward()
 94             # 根据反向传播的梯度值优化更新参数
 95             opt.step()
 96             
 97             # 100个batch的 loss 之和
 98             running_loss += loss.item()
 99             # loss_list.append(loss.item())
100             batch_num+=1
101             
102             
103         epochs_list.append(epoch)
104             
105         # 每一轮结束输出一下当前的学习率 lr
106         lr_1 = opt.param_groups[0]['lr']
107         print("learn_rate:%.15f" % lr_1)
108         schedule.step()
109         
110         end = time.time()
111         print('epoch = %d/100, batch_num = %d, loss = %.6f, time = %.3f' % (epoch+1, batch_num, running_loss/batch_num, end-start))
112         running_loss=0.0    
113         
114         # 每个epoch训练结束,都进行一次测试验证
115         model.eval()
116         train_correct = 0.0
117         train_total = 0
118 
119         test_correct = 0.0
120         test_total = 0
121         
122          # 训练模式不需要反向传播更新梯度
123         with torch.no_grad():
124             
125             # print("=======================train=======================")
126             for inputs, labels in train_loader:
127                 inputs, labels = inputs.to(device), labels.to(device)
128                 outputs = model(inputs)
129 
130                 pred = outputs.argmax(dim=1)  # 返回每一行中最大值元素索引
131                 train_total += inputs.size(0)
132                 train_correct += torch.eq(pred, labels).sum().item()
133           
134             
135             # print("=======================test=======================")
136             for inputs, labels in test_loader:
137                 inputs, labels = inputs.to(device), labels.to(device)
138                 outputs = model(inputs)
139 
140                 pred = outputs.argmax(dim=1)  # 返回每一行中最大值元素索引
141                 test_total += inputs.size(0)
142                 test_correct += torch.eq(pred, labels).sum().item()
143 
144             print("train_total = %d, Accuracy = %.5f %%,  test_total= %d, Accuracy = %.5f %%" %(train_total, 100 * train_correct / train_total, test_total, 100 * test_correct / test_total))    
145 
146             train_acc_list.append(100 * train_correct / train_total)
147             test_acc_list.append(100 * test_correct / test_total)
148 
149         # print("Accuracy of the network on the 10000 test images:%.5f %%" % (100 * test_correct / test_total))
150         # print("===============================================")
151 
152     fig = plt.figure(figsize=(4, 4))
153     
154     plt.plot(epochs_list, train_acc_list, label='train_acc_list')
155     plt.plot(epochs_list, test_acc_list, label='test_acc_list')
156     plt.legend()
157     plt.title("train_test_acc")
158     plt.savefig('resnet18_cc_epoch_{:04d}.png'.format(epochs))
159     plt.close()
160     
if __name__ == "__main__":
    # Start training only when this file is executed directly, not when imported.
    main()
View Code

 

标签:02,Lenet5,nn,self,list,train,图像,test,size
From: https://www.cnblogs.com/zhaopengpeng/p/17040553.html

相关文章