首页 > 其他分享 >02-Resnet18 图像分类

02-Resnet18 图像分类

时间:2023-01-10 16:12:25浏览次数:44  
标签:02 Resnet18 nn self train 图像 test ResNet18 channel

 

图1 Resnet的残差块

 

 

 图2 Resnet18 网络架构

Cifar10 数据集的 Resnet18 框架实现(Pytorch,文件 Resnet18.py):

 1 import torch
 2 from torch import nn
 3 
 4 # ResNet18_BasicBlock-残差单元
 5 class ResNet18_BasicBlock(nn.Module):
 6     def __init__(self, input_channel, output_channel, stride, use_conv1_1):
 7         super(ResNet18_BasicBlock, self).__init__()
 8 
 9         # 第一层卷积
10         self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=3, stride=stride, padding=1)
11         # 第二层卷积
12         self.conv2 = nn.Conv2d(output_channel, output_channel, kernel_size=3, stride=1, padding=1)
13 
14         # 1*1卷积核,在不改变图片尺寸的情况下给通道升维
15         self.extra = nn.Sequential(
16             nn.Conv2d(input_channel, output_channel, kernel_size=1, stride=stride, padding=0),
17             nn.BatchNorm2d(output_channel)
18         )
19 
20         self.use_conv1_1 = use_conv1_1
21 
22         self.bn = nn.BatchNorm2d(output_channel)
23         self.relu = nn.ReLU(inplace=True)
24 
25     def forward(self, x):
26         out = self.bn(self.conv1(x))
27         out = self.relu(out)
28 
29         out = self.bn(self.conv2(out))
30 
31         # 残差连接-(B,C,H,W)维度一致才能进行残差连接
32         if self.use_conv1_1:
33             out = self.extra(x) + out
34 
35         out = self.relu(out)
36         return out
37 
38 
39 # 构建 ResNet18 网络模型
class ResNet18(nn.Module):
    """ResNet-18 style classifier built from ``ResNet18_BasicBlock`` units.

    Layout: a 3x3 stem convolution with BatchNorm and ReLU, four stages of
    two residual blocks each (64, 128, 256, 512 channels; stages 2-4 halve
    the spatial size in their first block), adaptive average pooling to 1x1
    and one fully connected layer.

    Args:
        num_classes: Size of the output logits vector. Defaults to 10.
        in_channels: Number of channels of the input images. Defaults to 3.
    """

    def __init__(self, num_classes=10, in_channels=3):
        super(ResNet18, self).__init__()

        # Stem: 3x3 convolution with stride 1 so the input resolution is kept
        # for the first stage. (A stride of 3 here would drop a 32x32 image to
        # 11x11 before any residual block runs.)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )

        # Stage 1: 64 channels, resolution unchanged, identity shortcuts.
        self.block1_1 = ResNet18_BasicBlock(input_channel=64, output_channel=64, stride=1, use_conv1_1=False)
        self.block1_2 = ResNet18_BasicBlock(input_channel=64, output_channel=64, stride=1, use_conv1_1=False)

        # Stages 2-4: the first block doubles the channels and uses stride 2,
        # so it needs the 1x1 projection shortcut; the second block does not.
        self.block2_1 = ResNet18_BasicBlock(input_channel=64, output_channel=128, stride=2, use_conv1_1=True)
        self.block2_2 = ResNet18_BasicBlock(input_channel=128, output_channel=128, stride=1, use_conv1_1=False)

        self.block3_1 = ResNet18_BasicBlock(input_channel=128, output_channel=256, stride=2, use_conv1_1=True)
        self.block3_2 = ResNet18_BasicBlock(input_channel=256, output_channel=256, stride=1, use_conv1_1=False)

        self.block4_1 = ResNet18_BasicBlock(input_channel=256, output_channel=512, stride=2, use_conv1_1=True)
        self.block4_2 = ResNet18_BasicBlock(input_channel=512, output_channel=512, stride=1, use_conv1_1=False)

        # Pooling to 1x1 leaves 512 features per sample for the classifier.
        self.FC_layer = nn.Linear(512 * 1 * 1, num_classes)

        self.adaptive_avg_pool2d = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the class logits, one row per sample in the batch ``x``."""
        x = self.relu(self.conv1(x))

        # Residual stages, applied in order.
        x = self.block1_1(x)
        x = self.block1_2(x)
        x = self.block2_1(x)
        x = self.block2_2(x)
        x = self.block3_1(x)
        x = self.block3_2(x)
        x = self.block4_1(x)
        x = self.block4_2(x)

        # Global average pooling down to a 1x1 map per channel.
        x = self.adaptive_avg_pool2d(x)

        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.FC_layer(x)

        return x
View Code

classfyNet_main.py

  1 import torch
  2 from torch.utils.data import DataLoader
  3 from torch import nn, optim
  4 from torchvision import datasets, transforms
  5 
  6 from matplotlib import pyplot as plt
  7 
  8 
  9 import time
 10 
 11 from Lenet5 import Lenet5_new
 12 from Resnet18 import ResNet18
 13 
 14 def main():
 15     
 16     print("Load datasets...")
 17     
 18     # transforms.RandomHorizontalFlip(p=0.5)---以0.5的概率对图片做水平横向翻转
 19     # transforms.ToTensor()---shape从(H,W,C)->(C,H,W), 每个像素点从(0-255)映射到(0-1):直接除以255
 20     # transforms.Normalize---先将输入归一化到(0,1),像素点通过"(x-mean)/std",将每个元素分布到(-1,1)
 21     transform_train = transforms.Compose([
 22                         transforms.RandomHorizontalFlip(p=0.5),
 23                         transforms.ToTensor(),
 24                         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 25                     ])
 26 
 27     transform_test = transforms.Compose([
 28                         transforms.ToTensor(),
 29                         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 30                     ])
 31     
 32     # 内置函数下载数据集
 33     train_dataset = datasets.CIFAR10(root="./data/Cifar10/", train=True, 
 34                                      transform = transform_train,
 35                                      download=True)
 36     test_dataset = datasets.CIFAR10(root = "./data/Cifar10/", 
 37                                     train = False, 
 38                                     transform = transform_test,
 39                                     download=True)
 40     
 41     print(len(train_dataset), len(test_dataset))
 42     
 43     Batch_size = 64
 44     train_loader = DataLoader(train_dataset, batch_size=Batch_size,  shuffle = True)
 45     test_loader = DataLoader(test_dataset, batch_size = Batch_size, shuffle = False)
 46     
 47     # 设置CUDA
 48     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 49 
 50     # 初始化模型
 51     # 直接更换模型就行,其他无需操作
 52     # model = Lenet5_new().to(device)
 53     model = ResNet18().to(device)
 54     
 55     # 构造损失函数和优化器
 56     criterion = nn.CrossEntropyLoss() # 多分类softmax构造损失
 57     opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.8, weight_decay=0.001)
 58     
 59     # 动态更新学习率 ------每隔step_size : lr = lr * gamma
 60     schedule = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.6, last_epoch=-1)
 61     
 62     # 开始训练
 63     print("Start Train...")
 64 
 65     epochs = 100
 66    
 67     loss_list = []
 68     train_acc_list =[]
 69     test_acc_list = []
 70     epochs_list = []
 71     
 72     for epoch in range(0, epochs):
 73          
 74         start = time.time()
 75         
 76         model.train()
 77         
 78         running_loss = 0.0
 79         batch_num = 0
 80         
 81         for i, (inputs, labels) in enumerate(train_loader):
 82             
 83             inputs, labels = inputs.to(device), labels.to(device)
 84             
 85             # 将数据送入模型训练
 86             outputs = model(inputs)
 87             # 计算损失
 88             loss = criterion(outputs, labels).to(device)
 89             
 90             # 重置梯度
 91             opt.zero_grad()
 92             # 计算梯度,反向传播
 93             loss.backward()
 94             # 根据反向传播的梯度值优化更新参数
 95             opt.step()
 96             
 97             # 100个batch的 loss 之和
 98             running_loss += loss.item()
 99             # loss_list.append(loss.item())
100             batch_num+=1
101             
102             
103         epochs_list.append(epoch)
104             
105         # 每一轮结束输出一下当前的学习率 lr
106         lr_1 = opt.param_groups[0]['lr']
107         print("learn_rate:%.15f" % lr_1)
108         schedule.step()
109         
110         end = time.time()
111         print('epoch = %d/100, batch_num = %d, loss = %.6f, time = %.3f' % (epoch+1, batch_num, running_loss/batch_num, end-start))
112         running_loss=0.0    
113         
114         # 每个epoch训练结束,都进行一次测试验证
115         model.eval()
116         train_correct = 0.0
117         train_total = 0
118 
119         test_correct = 0.0
120         test_total = 0
121         
122          # 训练模式不需要反向传播更新梯度
123         with torch.no_grad():
124             
125             # print("=======================train=======================")
126             for inputs, labels in train_loader:
127                 inputs, labels = inputs.to(device), labels.to(device)
128                 outputs = model(inputs)
129 
130                 pred = outputs.argmax(dim=1)  # 返回每一行中最大值元素索引
131                 train_total += inputs.size(0)
132                 train_correct += torch.eq(pred, labels).sum().item()
133           
134             
135             # print("=======================test=======================")
136             for inputs, labels in test_loader:
137                 inputs, labels = inputs.to(device), labels.to(device)
138                 outputs = model(inputs)
139 
140                 pred = outputs.argmax(dim=1)  # 返回每一行中最大值元素索引
141                 test_total += inputs.size(0)
142                 test_correct += torch.eq(pred, labels).sum().item()
143 
144             print("train_total = %d, Accuracy = %.5f %%,  test_total= %d, Accuracy = %.5f %%" %(train_total, 100 * train_correct / train_total, test_total, 100 * test_correct / test_total))    
145 
146             train_acc_list.append(100 * train_correct / train_total)
147             test_acc_list.append(100 * test_correct / test_total)
148 
149         # print("Accuracy of the network on the 10000 test images:%.5f %%" % (100 * test_correct / test_total))
150         # print("===============================================")
151 
152     fig = plt.figure(figsize=(4, 4))
153     
154     plt.plot(epochs_list, train_acc_list, label='train_acc_list')
155     plt.plot(epochs_list, test_acc_list, label='test_acc_list')
156     plt.legend()
157     plt.title("train_test_acc")
158     plt.savefig('resnet18_cc_epoch_{:04d}.png'.format(epochs))
159     plt.close()
160     
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
View Code

 

标签:02,Resnet18,nn,self,train,图像,test,ResNet18,channel
From: https://www.cnblogs.com/zhaopengpeng/p/17040590.html

相关文章