PyTorch 手写字符识别
我们使用MNIST数据集对建立的卷积神经网络进行了训练,并加载测试集进行测试,最终的识别精度达到了99%。
但是官方和网上的测试流程只是演示最终的测试结果,没有很直观的告诉我们怎么在项目中使用他。我们学习机器学习和人工智能的目的不是跑一个官网的演示程序。
这就好比,hello world 我已经会跑了,可是我怎么用到项目中呢?
所以接下来会用一张照片或者自己手写一个数字,输入到训练好的网络中进行识别,来验证结果是不是正确。
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision.transforms import transforms
from PIL import Image, ImageOps
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, 5)
self.conv2 = nn.Conv2d(10, 20, 3)
self.fc1 = nn.Linear(20 * 10 * 10, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
input_size = x.size(0)
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, 2, 2)
x = self.conv2(x)
x = F.relu(x)
x = x.view(input_size, -1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
标签:__,10,字符识别,nn,self,torch,PyTorch,import,手写
From: https://blog.csdn.net/u010604770/article/details/143771113