图1 Lenet5 手写字符分类网络架构
CIFAR-10 数据集上 LeNet-5 的框架实现(PyTorch,Lenet5.py):
import torch
from torch import nn, optim
import torch.nn.functional as F


class Lenet5(nn.Module):
    """LeNet-5 style classifier for 3-channel 32x32 images with 10 output classes.

    Shape walk-through for an input of shape (B, 3, 32, 32):
        conv1 (5x5) -> (B, 6, 28, 28)  -> avg-pool 2x2 -> (B, 6, 14, 14)
        conv2 (5x5) -> (B, 16, 10, 10) -> avg-pool 2x2 -> (B, 16, 5, 5)
        flatten     -> (B, 400) -> 120 -> 84 -> 10 logits
    The first linear layer takes 400 = 16 * 5 * 5 features, so only 32x32
    inputs fit.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # A single parameter-free pooling module is reused after both convolutions.
        self.pooling = nn.AvgPool2d(kernel_size=2, stride=2)

        self.l1 = nn.Linear(400, 120)
        self.l2 = nn.Linear(120, 84)
        self.l3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10) for x of shape (batch, 3, 32, 32)."""
        # NOTE(review): there is no activation between the conv/pool stages, so
        # conv1 -> pool -> conv2 -> pool composes into a single affine map.
        # Classic LeNet-5 applies a nonlinearity after each convolution; left
        # unchanged here so the model's outputs stay the same.
        x = self.pooling(self.conv1(x))
        x = self.pooling(self.conv2(x))
        # flatten keeps the batch dim and, unlike view(batch, -1), also
        # handles an empty batch and non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.l3(x)


class Lenet5_new(nn.Module):
    """The same network as Lenet5, expressed with nn.Sequential containers.

    Layer sizes, parameter count and the forward computation match Lenet5.
    The state_dict keys differ (conv_unit.* / classfy.* instead of
    conv1 / conv2 / l1 / l2 / l3).
    """

    def __init__(self):
        super().__init__()

        # Convolutions sit at indices 0 and 2, pooling at 1 and 3.
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        # The attribute name "classfy" (sic) is kept: it is part of the
        # state_dict keys of any saved checkpoint.
        self.classfy = nn.Sequential(
            nn.Linear(400, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10) for x of shape (batch, 3, 32, 32)."""
        x = self.conv_unit(x)
        x = torch.flatten(x, 1)  # (B, 16, 5, 5) -> (B, 400)
        return self.classfy(x)


# Quick shape check. It is guarded so that importing this module
# (e.g. `from Lenet5 import Lenet5_new`) neither builds a model nor prints.
if __name__ == "__main__":
    model = Lenet5_new()
    x = torch.rand(64, 3, 32, 32)
    print(model(x).shape)
classfyNet_main.py
import time

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from matplotlib import pyplot as plt

from Lenet5 import Lenet5_new  # alternative model, see main()
from Resnet18 import ResNet18


def _build_loaders(batch_size):
    """Download CIFAR-10 into ./data/Cifar10/ if needed; return (train_loader, test_loader)."""
    # ImageNet channel statistics.
    # NOTE(review): CIFAR-10's own statistics differ (mean ~ (0.491, 0.482, 0.447)).
    # Kept as-is to preserve the original preprocessing.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # RandomHorizontalFlip(p=0.5): flip the image left-right with probability 0.5.
    # ToTensor: (H, W, C) pixels in [0, 255] -> (C, H, W) floats in [0, 1].
    # Normalize: per-channel (x - mean) / std. With the statistics above,
    # values land in roughly [-2.1, 2.6], not (-1, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = datasets.CIFAR10(root="./data/Cifar10/", train=True,
                                     transform=transform_train, download=True)
    test_dataset = datasets.CIFAR10(root="./data/Cifar10/", train=False,
                                    transform=transform_test, download=True)
    print(len(train_dataset), len(test_dataset))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def _select_device():
    """Prefer the second GPU (as originally configured), else the first GPU, else the CPU."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    # A hard-coded "cuda:1" fails as soon as tensors are moved to it on a
    # machine with only one GPU.
    return torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")


def _train_one_epoch(model, loader, criterion, opt, device):
    """Run one optimisation pass over loader; return (batches seen, mean per-batch loss)."""
    model.train()
    running_loss = 0.0  # sum of batch losses over the whole epoch
    batch_num = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # The loss is computed from on-device tensors, so it is already on device.
        loss = criterion(outputs, labels)

        opt.zero_grad()   # clear gradients left over from the previous step
        loss.backward()   # backpropagate
        opt.step()        # update parameters from the gradients

        running_loss += loss.item()
        batch_num += 1
    mean_loss = running_loss / batch_num if batch_num else float("nan")
    return batch_num, mean_loss


def _accuracy(model, loader, device):
    """Return (number of samples, top-1 accuracy in percent) of model over loader.

    The caller is responsible for model.eval() and for disabling autograd.
    """
    correct = 0
    total = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        pred = model(inputs).argmax(dim=1)  # index of the largest logit in each row
        total += inputs.size(0)
        correct += torch.eq(pred, labels).sum().item()
    accuracy = 100 * correct / total if total else 0.0
    return total, accuracy


def _plot_accuracy(epochs_list, train_acc_list, test_acc_list, epochs):
    """Save the per-epoch train/test accuracy curves as a PNG in the working directory."""
    plt.figure(figsize=(4, 4))
    plt.plot(epochs_list, train_acc_list, label='train_acc_list')
    plt.plot(epochs_list, test_acc_list, label='test_acc_list')
    plt.legend()
    plt.title("train_test_acc")
    plt.savefig('resnet18_cc_epoch_{:04d}.png'.format(epochs))
    plt.close()


def main(epochs=100, batch_size=64):
    """Train a CIFAR-10 classifier and plot per-epoch train/test accuracy.

    Trains for `epochs` epochs with SGD + StepLR. After every epoch it
    measures top-1 accuracy on the full train and test sets. At the end it
    saves the accuracy curves as a PNG.
    """
    print("Load datasets...")
    train_loader, test_loader = _build_loaders(batch_size)

    device = _select_device()

    # Swap the model here; nothing else needs to change.
    # model = Lenet5_new().to(device)
    model = ResNet18().to(device)

    # Cross-entropy over class logits for multi-class classification.
    criterion = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.8, weight_decay=0.001)

    # Every step_size epochs: lr = lr * gamma
    schedule = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.6, last_epoch=-1)

    print("Start Train...")

    train_acc_list = []
    test_acc_list = []
    epochs_list = []

    for epoch in range(epochs):
        start = time.time()

        batch_num, mean_loss = _train_one_epoch(model, train_loader, criterion, opt, device)
        epochs_list.append(epoch)

        # Report the learning rate used during this epoch, then advance the schedule.
        lr_1 = opt.param_groups[0]['lr']
        print("learn_rate:%.15f" % lr_1)
        schedule.step()

        end = time.time()
        print('epoch = %d/%d, batch_num = %d, loss = %.6f, time = %.3f'
              % (epoch + 1, epochs, batch_num, mean_loss, end - start))

        # Evaluate after every epoch without tracking gradients.
        # NOTE(review): train accuracy goes through transform_train, so the
        # random horizontal flips are applied during this measurement too.
        model.eval()
        with torch.no_grad():
            train_total, train_acc = _accuracy(model, train_loader, device)
            test_total, test_acc = _accuracy(model, test_loader, device)

        print("train_total = %d, Accuracy = %.5f %%, test_total= %d, Accuracy = %.5f %%"
              % (train_total, train_acc, test_total, test_acc))

        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)

    _plot_accuracy(epochs_list, train_acc_list, test_acc_list, epochs)


if __name__ == "__main__":
    main()
标签:02,Lenet5,nn,self,list,train,图像,test,size From: https://www.cnblogs.com/zhaopengpeng/p/17040553.html