首页 > 其他分享 >PyTorch -- RNN 快速实践

PyTorch -- RNN 快速实践

时间:2024-06-20 10:01:00浏览次数:12  
标签:RNN seq -- batch len PyTorch input hidden size

  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as plt
      
      import torch
      import torch.nn as nn
      import torch.optim as optim
      
      
      seq_len     = 50
      batch       = 1
      num_time_steps = seq_len
      
      input_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True 
      
      class Net(nn.Module):  ## model 定义
      	def __init__(self):
      		super(Net, self).__init__()
      		self.rnn = nn.RNN(
      		input_size=input_size,
      		hidden_size=hidden_size,
      		num_layers=num_layers,
      		batch_first=batch_first)
      		# for p in self.rnn.parameters():
      		# 	nn.init.normal_(p, mean=0.0, std=0.001)
      		self.linear = nn.Linear(hidden_size, output_size)
      
      	def forward(self, x, hidden_prev):
      		out, hidden_prev = self.rnn(x, hidden_prev)
      		# out: [batch, seq_len, hidden_size]
      		out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]
      		out = self.linear(out) 			 # [batch*seq_len, output_size]
      		out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]
      		return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01
      
      def tarin_RNN():
          model = Net()
          print('model:\n',model)
          criterion = nn.MSELoss()
          optimizer = optim.Adam(model.parameters(), lr)
          hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化h
      
          l = []
          for iter in range(100):  # 训练100次
              start = np.random.randint(10, size=1)[0]  ## 序列起点
              time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
              data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      
              x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
              y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点
      
              output, hidden_prev = model(x, hidden_prev)
              hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detach
      
              loss = criterion(output, y)
              model.zero_grad()
              loss.backward()
              optimizer.step()
      
              if iter % 100 == 0:
                  print("Iteration: {} loss {}".format(iter, loss.item()))
                  l.append(loss.item())
          #############################绘制损失函数#################################
          plt.plot(l,'r')
          plt.xlabel('训练次数')
          plt.ylabel('loss')
          plt.title('RNN LOSS')
          plt.savefig('RNN_LOSS.png')
          return hidden_prev,model
      
       hidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    
      
      predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):
          input = input.view(1, 1, 1)
          pred, hidden_prev = model(input, hidden_prev)
          input = pred  ## 循环获得每个input点输入网络
          predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"
    	p.grad.nomr()
    	torch.nn.utils.clip_grad_norm_(p, 10)  ## 其中的 norm 后面的_ 表示 in place
    
  • 对于梯度消失的解决:-> LSTM

标签:RNN,seq,--,batch,len,PyTorch,input,hidden,size
From: https://blog.csdn.net/CODE_RabbitV/article/details/139727458

相关文章

  • 【2024最新精简版】ElasticSearch面试篇
    文章目录你们项目中主要使用ElasticSearch干什么什么是正向索引?什么是倒排索引?......
  • 【Power Compiler手册】9.时钟门控(2)
    指定时钟门控延迟在综合过程中,DesignCompiler假设时钟是理想的。理想时钟在时钟网络中不产生任何延迟。这种假设是因为直到时钟树综合之后,实际的时钟网络延迟才为人所知。实际上,时钟并不是理想的,并且通过时钟网络存在非零延迟。对于具有时钟门控的设计,寄存器处的时钟网络延......
  • 【单片机毕业设计选题24017】-基于STM32的禽舍环境监测控制系统(蓝牙版)
    系统功能:系统分为主机端和从机端,主机端主动向从机端发送信息和命令,从机端收到主机端的信息后回复温湿度氨气浓度和光照强度等信息。主要功能模块原理图:电源时钟烧录接口:单片机和按键输入电路:主机部分电路:从机部分电路:资料获取地址主从机部分代码:初......
  • Java毕业设计-基于springboot开发的网上购物商城系统研发-毕业论文(附毕设源代码)
    文章目录前言一、毕设成果演示(源代码在文末)二、毕设摘要展示1、开发说明2、需求/流程分析3、系统功能结构三、系统实现展示1、用户功能模块的实现1.1用户注册界面1.2用户登录界面1.3个人中心界面1.4商品详情界面1.5购物车界面1.6我的订单界面1.7我的地址界面2、管理员......
  • 生物信息学
    HGP的主要人物是人类的DNA测序遗传图谱:连锁分析法物理图谱:指DNA链的限制性酶切片段的排列顺序序列图谱:测序获得基因图谱(genemap):基因图谱是在识别基因组所包含的蛋白质编码序列的基础上绘制的结合有关基因序列、位置及表达模式等信息的图谱。生物信息学与组学作为新兴的交......
  • 唯一工业操作系统!蓝卓supOS入榜中国500最具价值品牌
    6月19日在第21届世界品牌大会上世界品牌实验室(WorldBrandLab)正式发布2024年《中国500最具价值品牌》分析报告蓝卓supOS以131.65亿元的品牌价值成为国内首个入榜工业操作系统国内首个工业操作系统的蜕变之旅蓝卓supOS快速迭代指数增长不断引领海内外多个首创•......
  • FM151A,FM1202和利时备品
     FM151A,FM1202和利时备品。现场总线于DCS系统I/O总线上的集成――通过一个现场总线接口卡挂在DCS的I/O总线上,FM151A,FM1202和利时备品。使得在DCS控制器所看到的现场总线来的信息就如同来自一个传统的DCS设备卡一样。对原有系统影响较小系列现场总线产品可以实现在DCS系统网......
  • #C:比如有如下两个字符串:“hello“、“helhehe“不能使用库函数完成字符串的比较
    #include<stdio.h>#include<string.h>voidtest00()//比如有如下两个字符串:"hello"、"helhehe"不能使用库函数完成字符串的比较{ //时刻需要注意变量i的值  charbuf1[128]="";   printf("请输入第一个字符串buf1:");   //scanf("%s",buf1);......
  • 【计算机毕业设计】208基于微信小程序的二手物品交易平台
    ......
  • 前端:异地登录!!!
    哈喽,大家好!今天来聊一聊前端怎么实现异地登录提示!!!在数字化时代,账户安全是每个用户和开发者都不容忽视的问题。异地登录提示是一种安全措施,用于提醒用户他们的账户可能在不同的位置被访问。这通常涉及到检测登录行为的异常,比如IP地址的变化,并在检测到异常时通知用户。用户登......