首页 > 其他分享 >LSTM使用MNIST手写数字识别实战的代码和心得

LSTM使用MNIST手写数字识别实战的代码和心得

时间:2024-02-22 17:23:59浏览次数:30  
标签:hidden 100 self 128 LSTM out 手写 MNIST size

RNN的架构除了RNN类中的模型不同,其他的构架与CNN类似,如果还没有阅读过CNN文章的可以点击下方链接进入:
CNN使用MNIST手写数字识别实战的代码和心得
LSTM(Long Short-Term Memory长短时记忆网络)虽然在MNIST手写数字识别方面不擅长,但是也可以进行使用,效果比CNN略显逊色

对LSTM使用MNIST手写数字识别的思路图
LSTM进行MNIST手写数字识别
LSTM是在RNN的主线基础上增加了支线,增加了三个门,输入门,输出门和忘记门。
避免了可能因为加权问题,使程序忘记之前的内容,梯度弥散或者梯度爆炸。
batch_size在这里选取的是100,选择了一个隐藏层和128的神经元,对LSTM结构进行部署,
MNIST长宽为28,选取一行28作为一份数据传入input_size,RNN是按照时间序列进行传值,batch_size为100,也就是在每次传入的数据为(128,28)
进入隐藏层后,out结果张量的shape为([100, 28, 128])
在out[:, -1, :]时间序列中取得最后一次的输出,得到([100, 128])
再进入全连接层后将hidden_size的128变为所需要的输出的10种图片的维度([100, 10])

对超参数的定义
#定义超参数
input_size = 28
time_step = 28# 时间序列
Layers = 1# 隐藏单元的个数
hidden_size = 128# 每个隐藏单元中神经元个数
classes = 10
batch_size = 100
EPOCHS = 10
learning_rate = 0.01 #学习率
RNN对于数据的读取有别于CNN,按照时间来读取,在这里可以将input_size看作是图片的长,而time_step看作宽的长度。

    #Long Short-Term Memory(长短时记忆网络)
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, Layers, classes):
            super(RNN, self).__init__()
            self.Layers = Layers
            self.hidden_size = hidden_size
            self.lstm = nn.LSTM(input_size, hidden_size, Layers, batch_first=True)
            self.fc = nn.Linear(hidden_size, classes)
        def forward(self, x):
            # 设置初始隐藏状态和单元格状态
            h0 = torch.zeros(self.Layers, x.size(0), self.hidden_size).to(device)
            c0 = torch.zeros(self.Layers, x.size(0), self.hidden_size).to(device)
            # out张量的shape(batch_size, time_step, hidden_size)
            out, _ = self.lstm(x, (h0, c0))#torch.Size([100, 28, 128])
            #out[:, -1, :].shape torch.Size([100, 128])
            # 只得到时间顺序点的最后一步
            out = self.fc(out[:, -1, :])#torch.Size([100, 10])
            return out
            ```
运行结果:
```python
    RNN(
      (lstm): LSTM(28, 128, batch_first=True)
      (fc): Linear(in_features=128, out_features=10, bias=True)
    )
    Epoch [10/10],  Loss: 0.0115
    Test Accuracy to test: 98.07 %

标签:hidden,100,self,128,LSTM,out,手写,MNIST,size
From: https://www.cnblogs.com/fly-kiss/p/18027775

相关文章

  • CNN使用MNIST手写数字识别实战的代码和心得
    CNN(ConvolutionalNeuralNetwork)卷积神经网络对于MNIST手写数字识别的实战代码和心得首先是对代码结构思路进行思路图展示,如下:参数和原理剖析:因为MNIST图片为长和宽相同的28像素,为黑白两色,所以图片的高度为1,为灰度通道。在传入的时候,我定义的BATCH_SIZE为512,所以具体的......
  • LSTM 策略应用在量化交易领域的一点猜想
      LSTM(LongShortTermMemory),对于NLP(自然语言处理)和连续拍照的处理时,有额外的优势.在交易领域,最多的是应用于预判未来走势.  在自然语言处理时,将语句分为一个个单词,并预判下一个词汇.   同理:在K线图中,最简单的模式是以OHLCV,即一个Bar被当作一......
  • 每日笔记-LSTM
    今天,搞了一段代码,但没有达到应有的效果importtorchimporttorch.nnasnnimportnumpyasnp#设置随机种子以便结果可重复torch.manual_seed(42)#定义一个更复杂的LSTM模型classComplexLSTMModel(nn.Module):def__init__(self,input_size,hidden_size,ou......
  • Python用GAN生成对抗性神经网络判别模型拟合多维数组、分类识别手写数字图像可视化
    全文链接:https://tecdat.cn/?p=33566原文出处:拓端数据部落公众号生成对抗网络(GAN)是一种神经网络,可以生成类似于人类产生的材料,如图像、音乐、语音或文本。最近我们被客户要求撰写关于GAN生成对抗性神经网络的研究报告,包括一些图形和统计输出。近年来,GAN一直是研究的热门话题。F......
  • 晚上调代码时写对拍程序之——为了不手写平衡树而乱搞的可支持随机访问、快速插入、快
    前言由于需要一个可支持随机访问、快速插入、快速删除的数据结构,但是我除了平衡树实在是想不到别的东西了,于是就乱搞出了一个这样的东西——abstract数组。但是,这玩意好像码量和平衡树差不多......不过!我认为她还是有优点的:相比起平衡树,她应该更不容易出锅?总之,不管怎么样,还是......
  • Java集合篇之逐渐被遗忘的Stack,手写一个栈你会吗?
    正月初九,开工大吉!2024年,更上一层楼!写在开头其实在List的继承关系中,除了ArrayList和LinkedList之外,还有另外一个集合类stack(栈),它继承自vector,线程安全,先进后出,随着Java并发编程的发展,它在很多应用场景下被逐渐替代,成为了Java的遗落之类。不过,stack在数据结构中仍有一席之地,因此,......
  • 手写Promise
    目录参考资料Promises介绍文档Promises/A+规范Promises的一种实现方式github上2.6k+star的一个Promise实现方式手写Promise测试运行执行结果参考资料Promises介绍文档Promises/A+规范Promises的一种实现方式github上2.6k+star的一个Promise实现方式手写......
  • 手写实现cni插件
    k8sv1.19.0mycni配置文件cat>>/etc/cni/net.d/mycni.json<<EOF{"cniVersion":"0.2.0","name":"mycni","type":"mycni"}EOFtype对应/opt/cni/bin目录下二进制文件。mycni代码并编译mkdir/run/n......
  • pytorch MNIST数据集手写数字识别
    MNIST包括6万张28x28的训练样本,1万张测试样本,很多教程都会对它”下手”几乎成为一个“典范”,可以说它就是计算机视觉里面的HelloWorld。所以我们这里也会使用MNIST来进行实战。importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimporttorch.optimasopt......
  • 【Python】基于动态残差学习的堆叠式LSTM模型和传统BP在股票预测中的应用
    1.前言本论文探讨了长短时记忆网络(LSTM)和反向传播神经网络(BP)在股票价格预测中的应用。首先,我们介绍了LSTM和BP在时间序列预测中的基本原理和应用背景。通过对比分析两者的优缺点,我们选择了LSTM作为基础模型,因其能够有效处理时间序列数据中的长期依赖关系,在基础LSTM模型的基础上,......