首页 > 其他分享 >机器学习日志 手写数字识别 pytorch 神经网络

机器学习日志 手写数字识别 pytorch 神经网络

时间:2023-03-05 13:44:52浏览次数:53  
标签:tmp nn self torch 28 神经网络 batch pytorch 日志

我是链接

第一次用pytorch写机器学习,不得不说是真的好用

pytorch的学习可以看这里,看看基本用法就行,个人感觉主要还是要看着实践代码来学习

总结了几个点:

1.loss出现nan

这个让我头疼了好久,主要有两个方面吧:一是学习率可能太高了,可以调低一点试试。二是对于这个数据,黑白值颜色深度是用0255来表示的,让每个颜色深度除以255变成01来表示,结果会好很多,准确率也会高很多。

另外听说训练数据里有nan inf 除以0 也会出现nan

2.训练的时候犯傻了,好几万的数据训练的时候只用了前几百个。我还纳闷为啥准确率那么低(89%左右),后来发现barch挺大,但只有几百个训练了。

把上面说的改了后正确率到达98.2%,后面慢慢来改进

神经网络的模型是看着网上的搭建的,2个卷积层和池化层,2个全连接层

import torch
import torch.nn as nn
import pandas as pd
import numpy as np


class WYJ_CNN(nn.Module):
    def __init__(self):
        super(WYJ_CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16,  # 输入通道1 输出通道16
                               kernel_size=(3, 3),  # 卷积核尺寸
                               stride=(1, 1),  # 卷积核每次移动多少个像素
                               padding=1)  # 原图片边缘加几个空白像素
        # 输入尺寸1×28×28
        # 输出尺寸16×28×28
        self.pool1 = nn.MaxPool2d(kernel_size=2)  #第一次池化,尺寸16×14×14
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)  #第二次卷积,尺寸32×14×14
        self.pool2 = nn.MaxPool2d(2)  #第二次池化,尺寸32×7×7
        self.zhankai = nn.Flatten()#展平为张量,尺寸是一维1568(32*7*7)
        self.lin1 = nn.Linear(32 * 7 * 7, 16)#尺寸16
        self.jihuo = nn.ReLU()#激活函数
        self.lin2 = nn.Linear(16, 10)#尺寸10

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.zhankai(x)
        x = self.lin1(x)
        x = self.jihuo(x)
        x = self.lin2(x)
        return x


myTrain = pd.read_csv('train.csv')
vals = np.array(myTrain.values)
labels = torch.from_numpy(myTrain.values[:, 0])

net = WYJ_CNN()
CalcLoss = nn.CrossEntropyLoss()#loss用交叉熵来算
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)#lr是学习率

batch = 128
for cnt in range(10):
    for i in range(len(vals) // batch):
        tmp = vals[i * batch:(i + 1) * batch, 1:] / 255#将0~255的颜色深度转化为0~1的深度
        tmp = tmp.reshape(batch, 1, 28, 28)
        tmp = torch.from_numpy(tmp).float()
        outputs = net(tmp)
        # print(outputs)
        loss = CalcLoss(outputs, labels[i * batch:(i + 1) * batch])

        # loss = loss.requires_grad_()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i * batch % 1000 == 0:
            print("training", i * batch)

torch.save(net, "my_cnn.nn")

# net = torch.load("my_cnn.nn")
myTest = pd.read_csv('test.csv')
ImageId = [];
Label = [];

for i in range(len(myTest)):
    tmp = np.array(myTest.values[i][:]) / 255
    tmp = tmp.reshape(1, 1, 28, 28)
    tmp = torch.from_numpy(tmp).float()
    b = net(tmp)

    ImageId.append(i + 1)
    Label.append(b.argmax().item())

    if i % 1000 == 0:
        print("testing", i)

myAns = pd.DataFrame({'ImageId': ImageId, 'Label': Label})
myAns.to_csv("myAns.csv", index=False, sep=',')

标签:tmp,nn,self,torch,28,神经网络,batch,pytorch,日志
From: https://www.cnblogs.com/wljss/p/17180341.html

相关文章

  • 一种通过nacos动态配置实现多租户的log4j2日志物理隔离的设计
    1、背景1.1、背景旧服务改造为多租户服务后,log4j日志打印在一起不能区分是哪个租户的,日志太多,太杂,不好定位排除问题,排查问题较难。1.2、前提不改动以前的日志代码(工作......
  • Logging日志
    Logging日志什么时候用到日志?可预知的情况,写日志不可预知情况,写日志基础版CRITICAL=50FATAL=CRITICALERROR=40WARNING=30WARN=WARNINGINFO=20DEB......
  • 【java】动态修改日志级别
    背景开发过程中,为了方便问题快速定位,都会在代码中增加相关日志生产环境中,为了减少日志输出量,需要提高日志级别,节约资源。如果能动态修改日志级别,当出现问题时,动态降低......
  • Mysql 中二进制日志的初步认知
    二进制日志二进制日志中以“事件”的形式记录了数据库中数据的变化情况,对于MySQL数据库的灾难恢复起着重要的作用。开启二进制日志可以在​​my.cnf​​​文件或者​​my......
  • 21_Spring_日志框架和测试支持
    ​ spring5框架自带了通用的日志封装,也可以整合自己的日志 1)spring移除了LOG4jConfigListener,官方建议使用log4j2 2)spring5整合log4j2导入log4j2依赖 <......
  • 21_Spring_日志框架和测试支持
     spring5框架自带了通用的日志封装,也可以整合自己的日志 1)spring移除了LOG4jConfigListener,官方建议使用log4j2 2)spring5整合log4j2导入log4j2依赖 <!--log4j2......
  • 21_Spring_日志框架和测试支持
     spring5框架自带了通用的日志封装,也可以整合自己的日志 1)spring移除了LOG4jConfigListener,官方建议使用log4j2 2)spring5整合log4j2导入log4j2依赖 <!--log4j2......
  • 21_Spring_日志框架和测试支持
    ​ spring5框架自带了通用的日志封装,也可以整合自己的日志 1)spring移除了LOG4jConfigListener,官方建议使用log4j2 2)spring5整合log4j2导入log4j2依赖 <......
  • 吴恩达卷积神经网络——人脸识别和神经风格转换
    1.人脸识别人脸验证(FaceVerification)和人脸识别(FaceRecognition)的区别:人脸验证:一般指一个一对一问题,只需要验证输入的人脸图像是否与某个已知的......
  • 安装pytorch报错 ERROR: Could not install packages due to an OSError: [Errno 28]
    windos安装,报错如下  看了不少回答,大概是缓存和内存满了我的C盘只给了70G,然后意外发现只剩下3G多了,先用系统自带的清理工具清理了一下,然后腾讯电脑管家“工具箱”中......