首页 > 其他分享 >CNN --入门MNIST识别

CNN --入门MNIST识别

时间:2023-09-26 17:11:18浏览次数:46  
标签:-- self torch batch CNN import data MNIST size

Smiling & Weeping

              ---- 下次你撑伞低头看水洼,

                就会想起我说雨是神的烟花。

 

简介:主要是看刘二大人的视频讲解:https://www.bilibili.com/video/BV1Y7411d7Ys/?spm_id_from=333.337.search-card.all.click

题目及提交链接:Digit Recognizer | Kaggle

深度学习入门的学习项目,使用CNN(Convolutional Nerual Network)

对于Basic CNN的理解:

  1. 分成两个部分:前一个部分叫做Feature Extraction,后一部分叫做Classification(其中Feature Extraction又可以分为Convolution,Subsampling等)
  2. 其中要求卷积核的通道数量与输入通道数量一致。这种卷积核的总数和输出通道数目的总数一致(详见链接PDF)
  3. 卷积(convolution)后,C(channels),W(width),H(height),其中padding和pooling(小技巧:若要卷积W,H不变,取整kernel_size/2)
  4. 卷积层:保存图像的空间信息
  5. 卷积层要求输入输出是四维张量(B,C,W,H),全连接层的输入输出都是二维张量(B,Input_feature)
  6. 卷积(线性变换),激活函数(非线性变换),池化;这个过程若干次后,view打平,进入全连接层
  1 import torch
  2 import torch.nn.functional as F
  3 import torch.nn as nn
  4 import torch.optim as optim
  5 import torch.autograd as lr_scheduler
  6 from torch.utils.data import DataLoader, Dataset
  7 from torchvision import transforms
  8 from torchvision.utils import make_grid
  9 from torchvision import datasets
 10 from torch.autograd import Variable
 11 from sklearn.model_selection import train_test_split 
 12 import pandas as pd
 13 import numpy as np
 14 import matplotlib.pyplot as plt
 15 
 16 batch_size = 64
 17 transform = transforms.Compose([transforms.ToTensor()])
 18 train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
 19 train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
 20 #同样的方式加载一下测试集
 21 test_dataset = datasets.MNIST(root='../dataset/mnist/',  train=False, download=True, transform=transform)
 22 test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False,  batch_size=batch_size)
 23 
 24 # 使用卷积神经网络进行图像特征提取
 25 # (batch, 1, 28, 28) -> (batch, 10, 24, 24) -> 池化 (batch, 10, 12, 12) -> (batch, 20, 8, 8) -> (batch, 20 , 4, 4) -> (batch, 320) -> (batch, 10) 
 26 class Net(torch.nn.Module):
 27     def __init__(self):
 28         super(Net, self).__init__()
 29         self.conv1 = torch.nn.Conv2d(1, 10, kernel_size = 5)
 30         self.conv2 = torch.nn.Conv2d(10, 20, kernel_size = 5)
 31         self.pooling = torch.nn.MaxPool2d(2)
 32         self.fc = torch.nn.Linear(320, 10)
 33         
 34     def forward(self, x):
 35         # Flatten data from (n, 1, 28, 28) to (n, 784)
 36         batch_size = x.size(0)
 37         x = F.relu(self.pooling(self.conv1(x)))
 38         x = F.relu(self.pooling(self.conv2(x)))
 39         x = x.view(batch_size, -1) # Flatten
 40         x = self.fc(x)
 41         return x
 42 
 43 model = Net()
 44 # print(model)
 45 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 46 model.to(device)
 47 criterion = torch.nn.CrossEntropyLoss(size_average=True)
 48 optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 49 
 50 def train(epoch):
 51     running_loss = 0.0
 52     for batch_idx, data in enumerate(train_loader, 0):
 53         inputs, target = data
 54         inputs, target = inputs.to(device), target.to(device)
 55         optimizer.zero_grad()
 56         
 57         # forward + backward + update
 58         outputs = model(inputs)
 59         # 计算真实值 和 测量值 之间的误差
 60         loss = criterion(outputs, target)
 61         loss.backward()
 62         optimizer.step()
 63         
 64         running_loss += loss.item()
 65         if batch_idx % 300 == 299:
 66             print('[%d, %5d] loss: %3f' % (epoch + 1, batch_idx+1, running_loss / 2000))
 67             running_loss = 0.0 
 68 
 69 def test():
 70     correct = 0
 71     total = 0
 72     with torch.no_grad():
 73         for data in test_loader:
 74             inputs, target = data
 75             inputs, target = inputs.to(device), target.to(device)
 76             outputs = model(inputs)
 77             _, prediction = torch.max(outputs.data, dim=1)
 78             total += target.size(0)
 79             correct += (prediction == target).sum().item()
 80     print('Accuracy on test set: %d %% [%d/%d]' % (100*correct / total, correct, total))
 81     return correct/total
 82 
 83 epoch_list = []
 84 acc_list = []
 85 for epoch in range(10):
 86     train(epoch)
 87     acc = test()
 88     epoch_list.append(epoch)
 89     acc_list.append(acc)
 90     
 91 plt.plot(epoch_list, acc_list)
 92 plt.ylabel("accuracy")
 93 plt.xlabel("epoch")
 94 plt.show()
 95 
 96 class DatasetSubmissionMNIST(torch.utils.data.Dataset):
 97     def __init__(self, file_path, transform=None):
 98         self.data = pd.read_csv(file_path)
 99         self.transform = transform
100         
101     def __len__(self):
102         return len(self.data)
103     
104     def __getitem__(self, index):
105         image = self.data.iloc[index].values.astype(np.uint8).reshape((28, 28, 1))
106 
107         
108         if self.transform is not None:
109             image = self.transform(image)
110             
111         return image
112 
113 transform = transforms.Compose([
114     transforms.ToPILImage(),
115     transforms.ToTensor(),
116     transforms.Normalize(mean=(0.5,), std=(0.5,))
117 ])
118 
119 submissionset = DatasetSubmissionMNIST('/kaggle/input/digit-recognizer/test.csv', transform=transform)
120 submissionloader = torch.utils.data.DataLoader(submissionset, batch_size=batch_size, shuffle=False)
121 
122 submission = [['ImageId', 'Label']]
123 
124 with torch.no_grad():
125     model.eval()
126     image_id = 1
127 
128     for images in submissionloader:
129         images = images.cuda()
130         log_ps = model(images)
131         ps = torch.exp(log_ps)
132         top_p, top_class = ps.topk(1, dim=1)
133         
134         for prediction in top_class:
135             submission.append([image_id, prediction.item()])
136             image_id += 1
137             
138 print(len(submission) - 1)
139 import csv
140 
141 with open('submission.csv', 'w') as submissionFile:
142     writer = csv.writer(submissionFile)
143     writer.writerows(submission)
144     
145 print('Submission Complete!')
146 # summission.to_csv('/kaggle/working/submission.csv', index=False)

就效果来说,也就一般,后面的Advance CNN 会有更高的效率和准确性,大家可以敲一下代码放在自己的编译器上跑一下

对了,这是GPU版本,若用CPU,把所有的device删除就可以,--<-<-<@

文章到此结束,我们下次再见

一束光线,可能会摔碎

                                 但仍旧光芒四射

标签:--,self,torch,batch,CNN,import,data,MNIST,size
From: https://www.cnblogs.com/smiling-weeping-zhr/p/17730663.html

相关文章

  • docker-compose安装Redis
    一、单机版本1、docker-composeversion:'3'services:redis:image:redis:5.0restart:alwaysprivileged:truecontainer_name:redis-javaports:-6379:6379volumes:-/var/docker/server/redis/redis.conf:/etc/redis......
  • (10/1-10/31)10月摸鱼计划,挑战7/14/21天发博文,实体礼品包邮送!
    10月摸鱼计划,来啦!本月继续以【博主任务】形式,让大家自发选择更文任务!任务达标后即可兑奖!且任务间的奖品可同享!【活动时间】发文时间:2023年10月1日—2023年10月31日【活动任务】以下任务福利可同享!!任务一:7天更文任务要求任务链接任务奖品7天发布文章(可以非连续)发文直达>>https://blo......
  • P6344 [CCO2017] Vera 与现代艺术 题解
    在\(V\timesV\)的平面上,\(n\)次修改,每次给定\(x,y,v\),令\(a,b\)为不超过\(x,y\)的最大的\(2\)的整数次幂,则所有\((x+pa,y+qb)(p,q为自然数)\)都加上\(v\),最后有\(m\)次单点询问一个位置的值。\(1\lex,y,V\le10^{18},1\lev,n,m\le2\times10^5\)我们可以......
  • ubunt docker abp 框架 Dockerfile
    #Seehttps://aka.ms/customizecontainertolearnhowtocustomizeyourdebugcontainerandhowVisualStudiousesthisDockerfiletobuildyourimagesforfasterdebugging.FROMmcr.microsoft.com/dotnet/aspnet:7.0ASbase####SQLSERVERTLS版本问题####RUN......
  • GET和POST请求的区别
    HTTP是超文本传输协议,用来定义客户端与服务器数据传输的规范。HTTP服务端默认端口为80,HTTPS默认端口为443,客户端的端口是动态分配的。GET请求和POST请求都是HTTP请求八种(GET、POST、PUT、DELETE、PATCH、HEAD、OPTIONS)方法中的其中一种。1、GETGET请求是一个幂等的请......
  • MongoDB
    MongoDB是一种流行的开源NoSQL数据库管理系统,它专为灵活性、可扩展性和易用性而设计。以下是MongoDB的一些关键特点和概念:1.面向文档的存储:MongoDB是一种面向文档的数据库,意味着它以一种灵活的、半结构化的格式(称为BSON,二进制JSON)来存储数据。每个数据记录都是一个文档,集合中的文......
  • com.qq.weixin.mp.xml.AesException: 签名验证错误
    【已解决】AesException:签名验证错误问题原因:部分语言在url接收时会将+转化为空格导致出错的。这个问题企业微信官方客服两天也没有给个解释,突然就解决了。生气...... ......
  • Jupyter Notebook配置远程服务器
    一、在远程服务器上安装JupyterNotebook首先在服务器端安装Jupyter Notebook并通过配置文件进行相应参数的设置,然后使用本地主机的浏览器远程访问。1.连接远程服务器Win+R输入cmd回车进入命令行 连接远程服务器命令:sshuser名@服务器ip输入密......
  • 取模算术运算符-应用1-奇偶数判断
    C语言中判断一个整数是奇数还是偶数,可以使用取模运算符%。不能直接使用两个整数相除来进行计算,因为直接使用两个整数相除,结果只会保留整数,会舍弃掉小数部分。比如使用C语言计算11/2结果为5,但是11是不能被2整除的,计算结果舍弃掉了小数部分。因此需要使用一个数对2取余,对2取余只......
  • P1105 平台
    贪心枚举,高度高的排在前面,相同高度序号小的排在前面分别遍历左右端点,如果符合条件直接退出,注意俩端点重叠不算在分别用L,R数组记录下标,方便输出点击查看代码#include<bits/stdc++.h>usingnamespacestd;constintN=1e3+10;structnode{ inth,l,r,id;}a[N];bo......