首页 > 其他分享 >CNN实现手写数字识别

CNN实现手写数字识别

时间:2024-03-25 09:14:33浏览次数:28  
标签:loss self loader test train CNN 手写 识别 data

全部代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

# 超参数
batch_size = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
log_interval = 10

# 准备数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST('data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)


# 训练模型
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       0. * batch_idx / len(train_loader), loss.item()))


# 测试模型
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        0. * correct / len(test_loader.dataset)))


for epoch in range(1, epochs + 1):
    train(epoch)
    test()

 

标签:loss,self,loader,test,train,CNN,手写,识别,data
From: https://www.cnblogs.com/yhlx125/p/18093647

相关文章

  • 识别线程
    线程标识类型是std::thread::id,可以通过两种方式进行检索。第一种,可以通过调用std::thread对象的成员函数get_id()来直接获取。如果std::thread对象没有与任何执行线程相关联,get_id()将返回std::thread::type默认构造值,这个值表示“没有线程”。第二种,当前......
  • 人脸识别学习
    基于人脸识别及反作弊技术的在线评测系统如何使用摄像头Python调用摄像头Python调用摄像头是通过调用摄像头设备来实现的。具体来说,可以使用Python的OpenCV库来打开摄像头设备,并使用摄像头进行图像捕捉和处理。Python调用摄像头示例代码importcv2#打开摄像头设备#0表......
  • 基于傅里叶描述子和HSV颜色特征的KNN水果类型识别,Matlab实现
           博主简介:专注、专一于Matlab图像处理学习、交流,matlab图像代码代做/项目合作可以联系(QQ:3249726188)       个人主页:Matlab_ImagePro-CSDN博客       原则:代码均由本人编写完成,非中介,提供有偿Matlab算法代码编程服务,不从事不违反涉及学术原则......
  • 小白学视觉 | 7大类卷积神经网络(CNN)创新综述
    本文来源公众号“小白学视觉”,仅用于学术分享,侵权删,干货满满。原文链接:7大类卷积神经网络(CNN)创新综述编者荐语本综述将最近的CNN架构创新分为七个不同的类别,分别基于空间利用、深度、多路径、宽度、特征图利用、通道提升和注意力。转载自丨深度学习这件小事深度卷积......
  • 分类预测 | Matlab实现MTF-CNN-Mutilhead-Attention马尔可夫转移场卷积网络多头注意力
    分类预测|Matlab实现MTF-CNN-Mutilhead-Attention马尔可夫转移场卷积网络多头注意力机制多特征分类预测/故障识别目录分类预测|Matlab实现MTF-CNN-Mutilhead-Attention马尔可夫转移场卷积网络多头注意力机制多特征分类预测/故障识别分类效果基本介绍模型描述程序设......
  • 目标识别与分割
    开始毕设不知道多少天,这个破毕设要学的太多了,真的很烦!!!!如果你的应用需要识别手指指向的具体物体或区域,你可以使用目标检测和图像分割算法。目标检测算法可以识别图像中的物体,并给出它们的位置和类别信息,而图像分割算法可以将图像中的物体分割成不同的区域。这个需要用到深度学习......
  • 【发疯毕设日志day7】hagrid_dataset_512数据集作者论文原文逐句翻译——大疆tello手
    论文原文::::2206.08219.pdf(arxiv.org)https://arxiv.org/pdf/2206.08219.pdf摘要     本文介绍了一个庞大的手势识别数据集——海格(HAndGestrueRecognitionImagedataset),以简历一个手势识别(HGR)系统,专注于与设备的交互管理。这就是为什么所选的18个手势都呗赋予......
  • 毕业设计:基于深度学习的指纹识别系统
    目录前言课题背景和意义实现技术思路一、算法理论基础1.1 深度学习1.2迁移学习1.3指纹识别二、 数据集2.1数据集2.2数据扩充三、实验及结果分析3.1 实验环境搭建3.2 模型训练最后前言  ......
  • 【华为OD机试】真题A卷-垃圾短信识别(JAVA)
    一、题目描述【华为OD机试】真题A卷-垃圾短信识别(JAVA)题目描述:大众对垃圾短信深恶痛绝,希望能对垃圾短信发送者进行识别,为此,很多软件增加了垃圾短信的识别机制。经分析,发现正常用户的短信通常具备交互性,而垃圾短信往往都是大量单向的短信,按照如下规则进行垃圾短信识别: 本......
  • matlab实现神经网络检测手写数字
    一、要求1.计算sigmoid函数的梯度;2.随机初始化网络权重;3.编写网络的代价函数。二、算法介绍神经网络结构:不正则化的神经网络的代价函数:正则化:S型函数求导:反向传播算法:step1:初始化,然后使用前向传播算法计算step2:计算第三层的误差;step3:对于第二层 ;step4:使用......