首页 > 其他分享 >Pytorch基于MNIST数据集简单实现手写数字识别

Pytorch基于MNIST数据集简单实现手写数字识别

时间:2022-11-19 14:11:40浏览次数:64  
标签:nn torch Pytorch test import 手写 data net MNIST

"""
模型训练代码
"""
import torch
import torchvision.datasets
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
import cv2

# A model is usually defined in one of two ways; this is the first one:
# stacking the layers inside an nn.Sequential container (LeNet-style network).
feature_layers = [
    # 1 input channel -> 6 feature maps; padding=2 keeps the spatial size
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # 6 -> 16 feature maps, no padding
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
]
classifier_layers = [
    nn.Flatten(),
    # 16 feature maps of 5 x 5 remain for a 28 x 28 input
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    # one output score per digit class
    nn.Linear(84, 10),
]
net = nn.Sequential(*feature_layers, *classifier_layers)


# Both MNIST splits live under ./MNIST (downloaded when missing) and every
# image is converted to a tensor by ToTensor.
to_tensor = transforms.ToTensor()

train_data = torchvision.datasets.MNIST(
    root='MNIST', train=True, transform=to_tensor, download=True)
test_data = torchvision.datasets.MNIST(
    root='MNIST', train=False, transform=to_tensor, download=True)

# Mini-batches of 100 samples, reshuffled on every pass over the data.
loader_options = dict(batch_size=100, shuffle=True)
train_loader = DataLoader(train_data, **loader_options)
test_loader = DataLoader(test_data, **loader_options)


# images, labels = next(iter(train_loader))
# img = torchvision.utils.make_grid(images, nrow=10)  # 把若干图像拼接成一张图像
# img = img.numpy().transpose(1, 2, 0)
# cv2.imshow('img', img)
# cv2.waitKey(0)


# for data in train_loader:
#     imgs, target = data
#     # print(imgs.shape)
#     # print(target.shape)
#     print(target)
#     # print(data[0].shape)  # (100, 1, 28, 28): a batch of 100 images, 1 channel, 28 x 28 pixels
#     break


# Loss function; with the default reduction it returns the mean over a batch.
loss = nn.CrossEntropyLoss()

optim = torch.optim.Adam(net.parameters(), lr=0.001)  # optimizer

num_epochs = 20


# Train for num_epochs passes over train_loader and report the average
# per-sample training loss after each pass.
for epoch in range(num_epochs):
    sum_loss = 0.0
    for imgs, targets in train_loader:
        outputs = net(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()  # clear gradients left over from the previous step
        result_loss.backward()
        optim.step()  # apply the parameter update
        # result_loss is a batch mean, so weight it by the batch size to get a
        # per-sample sum that matches the len(train_data) divisor below.
        # .item() converts to a plain float so the running total does not keep
        # extending the autograd history across iterations.
        sum_loss += result_loss.item() * imgs.size(0)
    print(f'epoch:{epoch + 1},训练误差 :{sum_loss / len(train_data)}')


# Evaluation: count correct predictions over the whole test set.
net.eval()
test_acc = 0
with torch.no_grad():  # gradients are not needed when only measuring accuracy
    for imgs, targets in test_loader:
        outputs = net(imgs)
        predicted = outputs.argmax(dim=1)  # index of the largest score per sample
        test_acc += (predicted == targets).sum().item()
# The printed value is the percentage of correct predictions (accuracy), so the
# label says accuracy rather than error.
print("测试准确率:%.3f" % (test_acc * 100 / len(test_data)))

# Save the trained parameters
torch.save(net.state_dict(), "net_parameters.pth")

"""
模型测试代码
"""
import torch
import torchvision.datasets
from torch import nn
from d2l import torch as d2l
from torchvision import transforms
from torch.utils.data import DataLoader
import cv2

# Same architecture as the training script so the saved parameters fit.
feature_layers = [
    # 1 is the number of input channels, 6 the number of output channels;
    # the layer therefore holds six 5 x 5 convolution kernels
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
]
classifier_layers = [
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
]
net = nn.Sequential(*feature_layers, *classifier_layers)

# Load the parameters written by the training script.  The path is relative to
# the working directory and uses the same file name the training script saves
# to, instead of a machine-specific absolute path.
net.load_state_dict(torch.load("net_parameters.pth"))
net.eval()  # the model is only used for inference from here on

test_data = torchvision.datasets.MNIST(root='MNIST',
                                       train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True
                                       )

# One image per batch, drawn in random order
test_loader = DataLoader(test_data, batch_size=1, shuffle=True)

# Classify a single batch taken from test_loader and display its image.
imgs, targets = next(iter(test_loader))
with torch.no_grad():  # inference only, so skip building the autograd graph
    output = net(imgs)
print(output.argmax(dim=1))  # index of the highest score = predicted digit
img = imgs.numpy().reshape((28, 28))  # drop batch/channel axes for OpenCV
cv2.imshow('img', img)
cv2.waitKey(0)  # keep the window open until a key is pressed

标签:nn,torch,Pytorch,test,import,手写,data,net,MNIST
From: https://www.cnblogs.com/Sheldon2/p/16906029.html

相关文章