首页 > 其他分享 >[cnn]cnn训练MINST数据集demo

[cnn]cnn训练MINST数据集demo

时间:2023-05-25 11:55:06浏览次数:44  
标签:训练 demo MINST 损失 准确率 epoch 60000 测试 cnn

[cnn]cnn训练MINST数据集demo

tips:

在文件路径进入conda

输入

jupyter nbconvert --to markdown test.ipynb

将ipynb文件转化成markdown文件

jupyter nbconvert --to html test.ipynb

jupyter nbconvert --to pdf test.ipynb

(html,pdf文件同理)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
input_size = 28   #图像尺寸 28*28
num_class = 10    #标签总数
num_epochs = 3    #训练总周期
batch_size = 64    #一个批次多少图片

train_dataset = datasets.MNIST(
  root='data',
  train=True,
  transform=transforms.ToTensor(),
  download=True,
)

test_dataset = datasets.MNIST(
  root='data',
   train=False,
  transform=transforms.ToTensor(),
  download=True,
)

train_loader = torch.utils.data.DataLoader(
  dataset = train_dataset,
  batch_size = batch_size,
  shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
  dataset = test_dataset,
  batch_size = batch_size,
  shuffle = True,
)


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(   #输入为(1,28,28)
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,      #要得到几个特征图      
                kernel_size=5,        #卷积核大小      
                stride=1,             #步长     
                padding=2,                  
            ),                         #输出特征图为(16*28*28)     
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2), #池化(2x2) 输出为(16,14,14)
        )
        self.conv2 = nn.Sequential(          #输入(16,14,14)
            nn.Conv2d(16, 32, 5, 1, 2),     #输出(32,14,14)
            nn.ReLU(),                      
            nn.MaxPool2d(2),                #输出(32,7,7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10) #全连接

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1) #flatten操作 输出为(batch_size,32*7*7)
        output = self.out(x)
        return output, x 
def accuracy(predictions,labels):
  pred = torch.max(predictions.data,1)[1]
  rights = pred.eq(labels.data.view_as(pred)).sum()
  return rights,len(labels)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'
net = CNN().to(device)
criterion = nn.CrossEntropyLoss() #损失函数
#优化器
optimizer = optim.Adam(net.parameters(),lr = 0.001)

for epoch in range(num_epochs+1):
  #保留epoch的结果
  train_rights = []
  for batch_idx,(data,target) in enumerate(train_loader):
    data = data.to(device)
    target = target.to(device)
    net.train()
    output = net(data)[0]
    loss = criterion(output,target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    right = accuracy(output,target)
    train_rights.append(right)

    if batch_idx %100 ==0:
      net.eval()
      val_rights = []
      for(data,target) in test_loader:
        data = data.to(device)
        target = target.to(device)
        output = net(data)[0]
        right = accuracy(output,target)
        val_rights.append(right)
      #计算准确率
      train_r = (sum([i[0] for i in train_rights]),sum(i[1] for i in train_rights))
      val_r = (sum([i[0] for i in val_rights]),sum(i[1] for i in val_rights))

      print('当前epoch:{}[{}/{}({:.0f}%)]\t损失:{:.2f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(
        epoch,
        batch_idx * batch_size,
        len(train_loader.dataset),
        100. * batch_idx / len(train_loader),
        loss.data,
        100. * train_r[0].cpu().numpy() / train_r[1],
        100. * val_r[0].cpu().numpy() / val_r[1]
      )
      )

当前epoch:0[0/60000(0%)]	损失:2.31	训练集准确率:4.69%	测试集准确率:21.01%
当前epoch:0[6400/60000(11%)]	损失:0.51	训练集准确率:75.94%	测试集准确率:91.43%
当前epoch:0[12800/60000(21%)]	损失:0.28	训练集准确率:84.05%	测试集准确率:93.87%
当前epoch:0[19200/60000(32%)]	损失:0.15	训练集准确率:87.77%	测试集准确率:96.42%
当前epoch:0[25600/60000(43%)]	损失:0.08	训练集准确率:89.82%	测试集准确率:97.02%
当前epoch:0[32000/60000(53%)]	损失:0.14	训练集准确率:91.20%	测试集准确率:97.42%
当前epoch:0[38400/60000(64%)]	损失:0.04	训练集准确率:92.13%	测试集准确率:97.59%
当前epoch:0[44800/60000(75%)]	损失:0.08	训练集准确率:92.83%	测试集准确率:97.83%
当前epoch:0[51200/60000(85%)]	损失:0.12	训练集准确率:93.38%	测试集准确率:97.77%
当前epoch:0[57600/60000(96%)]	损失:0.19	训练集准确率:93.81%	测试集准确率:98.24%
当前epoch:1[0/60000(0%)]	损失:0.07	训练集准确率:95.31%	测试集准确率:97.90%
当前epoch:1[6400/60000(11%)]	损失:0.08	训练集准确率:97.96%	测试集准确率:98.27%
当前epoch:1[12800/60000(21%)]	损失:0.10	训练集准确率:97.99%	测试集准确率:98.30%
当前epoch:1[19200/60000(32%)]	损失:0.02	训练集准确率:98.07%	测试集准确率:98.20%
当前epoch:1[25600/60000(43%)]	损失:0.17	训练集准确率:98.09%	测试集准确率:98.40%
当前epoch:1[32000/60000(53%)]	损失:0.12	训练集准确率:98.11%	测试集准确率:98.68%
当前epoch:1[38400/60000(64%)]	损失:0.05	训练集准确率:98.11%	测试集准确率:98.63%
当前epoch:1[44800/60000(75%)]	损失:0.10	训练集准确率:98.14%	测试集准确率:98.70%
当前epoch:1[51200/60000(85%)]	损失:0.04	训练集准确率:98.19%	测试集准确率:98.56%
当前epoch:1[57600/60000(96%)]	损失:0.03	训练集准确率:98.23%	测试集准确率:98.67%
当前epoch:2[0/60000(0%)]	损失:0.06	训练集准确率:98.44%	测试集准确率:98.32%
当前epoch:2[6400/60000(11%)]	损失:0.03	训练集准确率:98.64%	测试集准确率:98.63%
当前epoch:2[12800/60000(21%)]	损失:0.05	训练集准确率:98.70%	测试集准确率:98.62%
当前epoch:2[19200/60000(32%)]	损失:0.01	训练集准确率:98.72%	测试集准确率:98.69%
当前epoch:2[25600/60000(43%)]	损失:0.01	训练集准确率:98.70%	测试集准确率:98.76%
当前epoch:2[32000/60000(53%)]	损失:0.03	训练集准确率:98.70%	测试集准确率:98.76%
当前epoch:2[38400/60000(64%)]	损失:0.07	训练集准确率:98.70%	测试集准确率:98.62%
当前epoch:2[44800/60000(75%)]	损失:0.07	训练集准确率:98.72%	测试集准确率:98.60%
当前epoch:2[51200/60000(85%)]	损失:0.03	训练集准确率:98.71%	测试集准确率:98.99%
当前epoch:2[57600/60000(96%)]	损失:0.05	训练集准确率:98.74%	测试集准确率:98.84%

标签:训练,demo,MINST,损失,准确率,epoch,60000,测试,cnn
From: https://www.cnblogs.com/jinwan/p/17430743.html

相关文章

  • cnn全连接层
    作用根据特征的组合进行分类大大减少特征位置对分类带来的影响减少特征位置对分类带来的影响就是它把特征representation整合到一起,输出为一个值这样做,有一个什么好处?就是大大减少特征位置对分类带来的影响为啥都是两层?但是大部分是两层以上呢这是为啥子呢泰勒公式都......
  • MQTT入门DEMO(Java语言)
    目录快速开始准备下载及安装第一次安装EMQX第一次运行EMQX客户端代码快速开始准备MQTT简介EMQX简介下载及安装第一次安装EMQX版本选择EMQX支持多种操作系统,请选择合适您的版本下载。下载地址:https://www.emqx.io/cn/downloads#broker在MicrosoftWindows下安装目前EMQX......
  • 图像分类基于cnn的戴口罩和不戴口罩的分类任务-详细教程文档(视频同款)
    图像分类基于cnn的戴口罩和不戴口罩的分类任务-详细教程文档(视频同款)......
  • 讯飞开放平台机器翻译(新)golang实现demo
    最近做项目用到翻译功能,对接了一下科大讯飞的翻译api接口,demo如下:packagemainimport( "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "io/ioutil" "net/http" "time" "......
  • 微服务框架SpringCloud微-2-服务拆分及远程调用-demo黑马
    微服务框架SpringCloud微服务架构2服务拆分及远程调用2.1案例Demo2.1.1服务拆分注意事项 这里四个模块,拆成四个服务就行了 单一职责:不同微服务,不要重复开发相同业务【不能像以前那样了】数据独立:不要访问其它微服务的数据库 3.面向服务:将自己的业务暴......
  • C++ 手搓 CNN 卷积神经网络
    代码请自取https://github.com/xoslh/CNN-MNIST-CPP-1卷积神经网络-CNN的基本原理​ 卷积神经网络(ConvolutionalNeuralNetworks,CNNs)是一种深度学习算法,特别适用于图像处理和分析。其设计灵感来源于生物学中视觉皮层的机制,是一种强大的特征提取和分类工具。1.1Layers......
  • 利用卷积神经网络的Text-CNN 文本分类
    访问【WRITE-BUG数字空间】_[内附完整源码和文档]TextCNN是利用卷积神经网络对文本进行分类的算法,由YoonKim在“ConvolutionalNeuralNetworksforSentenceClassification”一文(见参考[1])中提出.TextCNN是利用卷积神经网络对文本进行分类的算法,由YoonKim在“Conv......
  • Field userClient in com.demo.order.service.OrderService required a bean of type'
    在SpringCloud项目中使用Feign进行远程调用遇到的错误。原因是因为UserClient在com.demo.feign.clients包下面,而order-service的@EnableFeignClientd注解却在com.demo.order包下面,这两个不在同一个包下,无法扫描到UserClient。解决方法有两种1.指定Feign应该扫描的包@EnableFeig......
  • Qt+QtWebApp开发笔记(二):http服务器日志系统介绍、添加日志系统至Demo测试
    前言  上一篇使用QtWebApp的基于Qt的轻量级http服务器实现了一个静态网页返回的Demo,网页服务器很重要的就是日志,因为在服务器类上并没有直接返回,所以,本篇先把日志加上。 Demo  下载地址  链接:https://pan.baidu.com/s/1BPVRLS07qk-WPi-txERKbg?pwd=1234......
  • 【小小demo】Springboot + Vue 增删改查
    vue-table-ui该工程提供的是一个简单的Vue+Element-UI的表格,增删改查操作。工程代码在最下面。环境jdk1.8ideamavenspringboot2.1.1.RELEASE示例首页查询新增修改删除官方文档Element-Ui:https://element.eleme.cn/#/zh-CN/component/installationV......