"""
模型训练代码
"""
import torch
import torchvision.datasets
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
import cv2
# There are generally two ways to define a model; this file uses approach 1:
# stack the layers, in order, inside an nn.Sequential container.
# NOTE: the list order fixes the state_dict keys ("0.weight", "3.weight",
# "7.weight", ...). Code that reloads saved weights into a Sequential must use
# the identical layer order.
_lenet_layers = [
    # Feature extractor. Assuming 1x28x28 input (MNIST), the shapes are:
    # 6x28x28 -> pool 6x14x14 -> 16x10x10 -> pool 16x5x5.
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Classifier head: flatten 16*5*5 features down to 10 class logits.
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
]
net = nn.Sequential(*_lenet_layers)
# Both MNIST splits share one root directory ('MNIST', downloaded if missing)
# and a single ToTensor transform that converts each image to a tensor.
_to_tensor = torchvision.transforms.ToTensor()


def _load_mnist_split(is_train):
    """Return the MNIST training split when *is_train* is true, else the test split."""
    return torchvision.datasets.MNIST(
        root='MNIST',
        train=is_train,
        transform=_to_tensor,
        download=True,
    )


train_data = _load_mnist_split(True)
test_data = _load_mnist_split(False)

# Mini-batches of 100 images; both loaders reshuffle on every pass.
train_loader = DataLoader(train_data, batch_size=100, shuffle=True)
test_loader = DataLoader(test_data, batch_size=100, shuffle=True)
# images, labels = next(iter(train_loader))
# img = torchvision.utils.make_grid(images, nrow=10)  # 把若干图像拼接成一张图像
# img = img.numpy().transpose(1, 2, 0)
# cv2.imshow('img', img)
# cv2.waitKey(0)
# for data in train_loader:
#     imgs, target = data
#     # print(imgs.shape)
#     # print(target.shape)
#     print(target)
#     # print(data[0].shape)  # (100, 1, 28, 28): 100 张图片, 1 个通道, 28 * 28 的图像
#     break
# Loss function. With the default reduction='mean', each call returns the
# average loss over the samples in one batch.
loss = nn.CrossEntropyLoss()
# Optimizer: Adam over all model parameters.
optim = torch.optim.Adam(net.parameters(), lr=0.001)
num_epochs = 20
for epoch in range(num_epochs):
    # Explicit training mode. The current layers behave the same in train/eval,
    # but this keeps the loop correct if dropout/batch-norm are ever added.
    net.train()
    sum_loss = 0.0
    for imgs, targets in train_loader:
        outputs = net(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()  # clear gradients left over from the previous step
        result_loss.backward()
        optim.step()  # apply the parameter update
        # Accumulate a plain Python float. Adding the loss tensor itself would
        # keep autograd history alive across the whole epoch, and would print
        # as a grad-tracking tensor. The batch mean is multiplied by the batch
        # size, so dividing by the dataset size below yields the true
        # per-sample mean loss (previously it came out ~batch_size times too small).
        sum_loss += result_loss.item() * imgs.size(0)
    print(f'epoch:{epoch + 1},训练误差 :{sum_loss/len(train_data)}')
# Evaluate classification accuracy on the test split.
net.eval()
test_acc = 0  # number of correctly classified test images
with torch.no_grad():  # inference only: do not build an autograd graph
    for imgs, targets in test_loader:
        outputs = net(imgs)
        # Predicted class = index of the largest logit along dim 1 (the class dim).
        predicted = outputs.argmax(dim=1)
        test_acc += (predicted == targets).sum().item()
# The value is the percentage of correct predictions, i.e. accuracy.
# (The old label called it "测试误差", test error.)
print("测试准确率:%.3f%%" % (test_acc * 100 / len(test_data)))
# Save only the learned parameters (the state_dict), not the pickled module.
# Whoever loads them must rebuild the same nn.Sequential layout first.
_trained_weights = net.state_dict()
torch.save(_trained_weights, "net_parameters.pth")
"""
模型测试代码
"""
import torch
import torchvision.datasets
from torch import nn
from d2l import torch as d2l
from torchvision import transforms
from torch.utils.data import DataLoader
import cv2
# The architecture must match the training script layer-for-layer. The saved
# state_dict is keyed by position ("0.weight", "3.weight", ...), so any
# reordering would break load_state_dict.
_feature_layers = [
    # 1 input channel, 6 output channels: six 5x5 convolution kernels;
    # padding=2 keeps the spatial size unchanged.
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
]
_classifier_layers = [
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),  # one logit per digit class
]
net = nn.Sequential(*_feature_layers, *_classifier_layers)
# Restore the weights written by the training step, which calls
# torch.save(net.state_dict(), "net_parameters.pth"). The previous code
# hard-coded a machine-specific absolute Windows path. This version uses the
# same relative name, resolved against the current working directory.
MODEL_PATH = "net_parameters.pth"
# map_location="cpu" lets a checkpoint saved during a GPU run load on a CPU-only machine.
net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
net.eval()  # this script only does inference
# MNIST test split (downloaded into 'MNIST' if missing), loaded one image per
# batch in random order, so each run shows a different sample.
_mnist_test_kwargs = dict(
    root='MNIST',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_data = torchvision.datasets.MNIST(**_mnist_test_kwargs)
test_loader = DataLoader(test_data, batch_size=1, shuffle=True)
# Classify one randomly drawn test image and display it.
with torch.no_grad():  # inference only; no autograd graph needed
    imgs, targets = next(iter(test_loader))
    output = net(imgs)
# Predicted class = index of the largest logit. For batch_size=1 this prints
# a one-element tensor, the same output as topk(output, 1)[1].squeeze(0).
print(output.argmax(dim=1))
# Collapse the batch and channel dims to a 28x28 image for OpenCV.
# This reshape is valid only because test_loader uses batch_size=1.
img = imgs.numpy().reshape((28, 28))
cv2.imshow('img', img)
cv2.waitKey(0)  # block until a key is pressed
cv2.destroyAllWindows()  # close the window instead of leaving it open
# 标签:nn,torch,Pytorch,test,import,手写,data,net,MNIST
# From: https://www.cnblogs.com/Sheldon2/p/16906029.html