首页 > 其他分享 >pytorch nn.LSTM模块参数详解

pytorch nn.LSTM模块参数详解

时间:2023-08-28 14:13:03浏览次数:51  
标签:10 nn 16 torch pytorch LSTM 100 size

nn.LSTM模块参数

input_size :输入的维度

hidden_size:h的维度

num_layers:堆叠LSTM的层数,默认值为1

bias:偏置 ,默认值:True

batch_first: 如果是True,则input为(batch, seq, input_size)。默认值为:False(seq_len, batch, input_size)

bidirectional :是否双向传播,默认值为False

 

输入

(input_size, hideen_size)

以训练句子为例子,假如每个词是100维的向量,每个句子含有24个单词,一次训练10个句子。那么batch_size=10,seq=24,input_size=100。(seq指的是句子的长度,input_size作为一个的输入) ,所以在设置LSTM网络的过程中input_size=100。由于seq的长度是24,那么这个LSTM结构会循环24次最后输出预设的结果。如下图所示。

预设的hidden_size,这个hideen_size主要是下面LSTM公式中的各个W和b的维度设置,以g_{t}为例子,假设hideen_size为16,则W_{ig}为16*100,x_{t}为100*1,W_{hg}为16*16,h_{t-1}为16*1。

 

输出

output:(seq_len, batch, num_directions * hidden_size

h_n:(num_layers * num_directions, batch, hidden_size)

c_n :(num_layers * num_directions, batch, hidden_size
注:num_directions 表示单向、双向

单向

import torch.nn as nn
import torch
x = torch.rand(10,24,100)
lstm = nn.LSTM(100,16,num_layers=2)
output,(h,c) = lstm(x)
print(output.size())
print(h.size())
print(c.size())
 
output:
torch.Size([24, 10, 16])
torch.Size([2, 10, 16])
torch.Size([2, 10, 16])

双向

import torch.nn as nn
import torch
x = torch.rand(10,24,100)
lstm = nn.LSTM(100,16,bidirectional=True)
output,(h,c) = lstm(x)
print(output.size())
print(h.size())
print(c.size())
 
output:
torch.Size([24, 10, 32])
torch.Size([2, 10, 16])
torch.Size([2, 10, 16])

使用h0、c0

import torch.nn as nn
import torch
x = torch.rand(24,10,100) #seq,batch,input_size
h0 = torch.rand(1,10,16)# num_layers*num_directions, batch, hidden_size
c0 = torch.rand(1,10,16)
lstm = nn.LSTM(100,16)
output,(h,c) = lstm(x,(h0,c0))

 

标签:10,nn,16,torch,pytorch,LSTM,100,size
From: https://www.cnblogs.com/pass-ion/p/17662130.html

相关文章

  • Innodb引擎中B+树一般有几层?能容纳多少数据量?
    1、页在MySQL中InnoDB存储引擎的最小存储单元是页(大小默认是16k,可通过参数设置)。页可用于存放B+树叶节点数据,也可用于存放B+树非叶节点的“键+指针”(也就是路径节点)。在查找数据时一次页的查找代表一次IO,一般B+树高大约为1~3层,所以通过主键索引查询通常只需要1~3次IO......
  • Netty 的 ChannelOption.SO_BACKLOG 知识点整理
    Netty的ChannelOption.SO_BACKLOG知识点整理 一个基于Netty的应用,在压力测试时,Socket请求数量一多,就发送失败,监测JVM内存大小比较稳定,猜测可能是ChannelOption.SO_BACKLOG这个配置导致的,设置的值是128。调整为1024后,连接失败的次数确实减少了一些,那么这个配置到......
  • 学习笔记:DSTAGNN中ST块的代码分析
    DSTAGNN模型可以看我上一个博客学习笔记:DSTAGNN:DynamicSpatial-TemporalAwareGraphNeuralNetworkforTrafficFlowForecasting这篇博客主要写了我对代码中ST块部分的阅读。写这篇模型的初衷,是这篇论文结构图和语言描述不太一致,再加上我想要学习怎么写一个时空预测的代......
  • DWR的注释(annotations)使用及反向调用(Reverse Ajax)
    先说说注释语法,省掉dwr.xml。(自从用了java5之后,现在越看一堆堆的配置文件越烦,越来越喜欢注释方式来的直接简单了)  首先下载最新的稳定版本的dwr.jar文件放到你的工程中。(还有需要其它的吗?不需要了,dwr就是这么简单)然后在web.xml中添加如下一段<!--DWRServlet--><servle......
  • 带你上手基于Pytorch和Transformers的中文NLP训练框架
    本文分享自华为云社区《全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据》,作者:汀丶。1.简介目标:基于pytorch、transformers做中文领域的nlp开箱即用的训练框架,提供全套的训练、微调模型(包括大模型、文本转向量、文本生......
  • 简单的将pytorch模型部署到onnx
    1.创建一个pytorch模型这里我用的U2Net,直接加载好训练出的权重model=U2Net(class_nums=4)model.load_state_dict(torch.load(checkpoint_path))2.将pytorch模型转成onnx格式x=torcg.randn(1,3,512,512)withtorch.no_grad():torch.onnx.export(......
  • MySQLSTMT函数详解及使用方法(mysql_stmt())
    MySQL_STMT函数详解及使用方法 MySQL_STMT是MySQL提供的一个CAPI,用于执行预处理语句(Preparedstatements)。相比于直接执行SQL,预处理语句具有更高的运行效率和更好的安全性。本文将详细介绍MySQL_STMT函数的使用方法。 1.创建预处理语句 使用MySQL_STMT,需要先创建一个预......
  • Netty源码学习3——Channel ,ChannelHandler,ChannelPipeline
    系列文章目录和关于我零丶引入在Netty源码学习2——NioEventLoop的执行中,我们学习了NioEventLoop是如何进行事件循环以及如何修复NIO空轮询的bug的,但是没有深入了解IO事件在netty中是如何被处理的,下面我们以服务端demo代码为例子,看下和IO事件处理密切的Channel如上在编写nett......
  • 解决:docker 443: connect: network is unreachable
    1、配置镜像加速器您可以通过修改daemon配置文件/etc/docker/daemon.json来使用加速器sudomkdir-p/etc/dockersudotee/etc/docker/daemon.json<<-'EOF'{"registry-mirrors":["https://liadaibh.mirror.aliyuncs.com"]}EOFsudosystemctldaemon-......
  • NC19857 最后的晚餐(dinner)
    题目链接题目题目描述​**YZ(已被和谐)的食堂实在是太挤辣!所以Apojacsleam现在想邀请他的一些好友去校外吃一顿饭,并在某酒店包下了一桌饭。​当Apojacsleam和他的同学们来到酒店之后,他才发现了这些同学们其实是N对cp,由于要保护广大单身狗的弱小心灵(FF!),所以他不想让任意一......