首页 > 其他分享 >LSTM结构原理与代码实践

LSTM结构原理与代码实践

时间:2023-08-09 21:36:25浏览次数:42  
标签:word idx 代码 list 实践 len tag LSTM data

近日学习LSTM结构,已有很多博客对LSTM结构进行说明,但某些细节仍然存在模糊情况,为此本文将进行补充与说明,可分以下内容:

一.RNN原理简介与LSTM原理阐释。

  一般来说,RNN的输入和输出都是一个序列,分别记为

有关(序列中的第t个元素我们叫做序列在time_step=t时的取值)。

注:seqin={x1,x2,...,x3}可假设指代有顺序的句子序列长度,如 I Like code ,其中x1=I,x2=Like,x3=code;以此类推seqout指代我们想输出结果。

更直观的理解可看下图:

 

LSTM结构原理与代码实践_激活函数

 

 

LSTM是一种特殊的RNN,主要通过三个门控逻辑实现(遗忘、输入、输出)。它的提出就是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。 

下图是一个LSTM结构示意图,如Xt指代Like单词:

LSTM结构原理与代码实践_词性_02

 

 

 以上可看出输出为yt,ct和ht

LSTM结构原理与代码实践_词性_03

 

LSTM结构原理与代码实践_词性_04

LSTM结构原理与代码实践_词性_05

 求解Ct公式 

LSTM结构原理与代码实践_取值_06

 求解ht

 

 σ函数表示sigmoid函数

更详细解释如下:

LSTM结构原理与代码实践_词性_07

 

 其中, 

LSTM结构原理与代码实践_激活函数_08

 

 

与普通RNN类似,输出 

 

二.LSTM代码如下:

注:主要调用nn.LSTM

 

'''
本程序实现了对单词词性的判断,输入一句话,输出该句话中每个单词的词性。
'''

import torch
import torch.nn.functional as F
from torch import nn, optim

from tqdm import tqdm

training_data = [("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
                 ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])]

def build_data(training_data):
    # 构建数据集
    # 数据转换方法
    word_to_idx = {}
    tag_to_idx = {}
    for context, tag in training_data:
        for word in context:
            if word not in word_to_idx:
                word_to_idx[word] = len(word_to_idx)
        for label in tag:
            if label not in tag_to_idx:
                tag_to_idx[label] = len(tag_to_idx)
    idx_to_tag = {tag_to_idx[tag]: tag for tag in tag_to_idx}

    return word_to_idx,tag_to_idx,idx_to_tag

class LSTMTagger(nn.Module):
    def __init__(self, n_word, n_dim, n_hidden, n_tag):
        super(LSTMTagger, self).__init__()
        self.word_embedding = nn.Embedding(n_word, n_dim)
        self.lstm = nn.LSTM(n_dim, n_hidden, batch_first=True)  # nn.lstm()接受的数据输入是(序列长度,batch,输入维数),
        # 这和我们cnn输入的方式不太一致,所以使用batch_first=True,把输入变成(batch,序列长度,输入维度),本程序的序列长度指的是一句话的单词数目
        # 同时,batch_first=True会改变输出的维度顺序。<br data-filtered="filtered">
        self.linear1 = nn.Linear(n_hidden, n_tag)

    def forward(self, x):            # x是word_list,即单词的索引列表,size为len(x)
        x = self.word_embedding(x)   # embedding之后,x的size为(len(x),n_dim)
        x = x.unsqueeze(0)           # unsqueeze之后,x的size为(1,len(x),n_dim),1在下一行程序的lstm中被当做是batchsize,len(x)被当做序列长度
        x, _ = self.lstm(x)          # lstm的隐藏层输出,x的size为(1,len(x),n_hidden),因为定义lstm网络时用了batch_first=True,所以1在第一维,如果batch_first=False,则len(x)会在第一维
        x = x.squeeze(0)             # squeeze之后,x的size为(len(x),n_hidden),在下一行的linear层中,len(x)被当做是batchsize
        x = self.linear1(x)          # linear层之后,x的size为(len(x),n_tag)
        y = F.log_softmax(x, dim=1)  # 对第1维先进行softmax计算,然后log一下。y的size为(len(x),n_tag)。
        return y

word_to_idx, tag_to_idx, idx_to_tag=build_data(training_data)

def main():
    # 用于训练
    model = LSTMTagger(len(word_to_idx), 100, 128, len(tag_to_idx))  # 模型初始化
    if torch.cuda.is_available():
        model = model.cuda()
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)
    for epoch in tqdm(range(200)):
        running_loss = 0
        for data in training_data:
            sentence, tags = data
            word_list = [word_to_idx[word] for word in sentence]     # word_list是word索引列表
            word_list = torch.LongTensor(word_list)
            tag_list = [tag_to_idx[tag] for tag in tags]             # tag_list是tag索引列表
            tag_list = torch.LongTensor(tag_list)
            if torch.cuda.is_available():
                word_list = word_list.cuda()
                tag_list = tag_list.cuda()
            # forward
            out = model(word_list)
            loss = criterion(out, tag_list)
            running_loss += loss.data.cpu().numpy()
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Epoch: {:<3d} | Loss: {:6.4f}'.format(epoch, running_loss / len(data)))


    # 模型测试
    test_sentence = "Everybody ate the apple"
    print('\n The test sentence is:\n', test_sentence)
    test_sentence = test_sentence.split()
    test_list = [word_to_idx[word] for word in test_sentence]
    test_list = torch.LongTensor(test_list)
    if torch.cuda.is_available():
        test_list = test_list.cuda()

    out = model(test_list)
    _, predict_idx = torch.max(out, 1)  # 1表示找行的最大值。 predict_idx是词性索引,是一个size为([len(test_sentence)]的张量
    predict_tag = [idx_to_tag[idx] for idx in list(predict_idx.cpu().numpy())]
    print('The predict tags are:', predict_tag)


if __name__ == '__main__':
    main()

 

结果如下:

LSTM结构原理与代码实践_取值_09

 

 

 

 

 

 

借鉴内容如下:

https://zhuanlan.zhihu.com/p/32085405

https://zhuanlan.zhihu.com/p/128098497

标签:word,idx,代码,list,实践,len,tag,LSTM,data
From: https://blog.51cto.com/u_16162011/7025557

相关文章

  • 残差网络ResNet(超详细代码解析) :你必须要知道backbone模块成员之一
      本文主要贡献代码模块(文末),在本文中对resnet进行了复现,是一份原始版本模块,里面集成了权重文件pth的载入模块(如函数:init_weights(self,pretrained=None)),layers的冻结模块(如函数:_freeze_stages(self)),更是将其改写成可读性高的代码,若你需要执行该模块,可直接将其代码模块粘......
  • [代码随想录]Day13-二叉树part02
    题目:102.二叉树的层序遍历思路:先把根放进去,然后每次都是左右就可以了。记录一个深度,当len(res)==deepth的时候就说明这个深度还没有实例化,先搞一个再去收集。代码:/***Definitionforabinarytreenode.*typeTreeNodestruct{*Valint*Left*TreeN......
  • MATLAB用改进K-Means(K-均值)聚类算法数据挖掘高校学生的期末考试成绩|附代码数据
    全文链接:http://tecdat.cn/?p=30832最近我们被客户要求撰写关于K-Means(K-均值)聚类算法的研究报告,包括一些图形和统计输出。本文首先阐明了聚类算法的基本概念,介绍了几种比较典型的聚类算法,然后重点阐述了K-均值算法的基本思想,对K-均值算法的优缺点做了分析,回顾了对K-均值改进......
  • Stata广义矩量法GMM面板向量自回归PVAR模型选择、估计、Granger因果检验分析投资、收
    原文链接:http://tecdat.cn/?p=24016原文出处:拓端数据部落公众号摘要最近我们被要求撰写关于广义矩量法GMM的研究报告,包括一些图形和统计输出。面板向量自回归(VAR)模型在应用研究中的应用越来越多。虽然专门用于估计时间序列VAR模型的程序通常作为标准功能包含在大多数统计软件......
  • 架构师必备:商业选型与项目部署实践
    标题:架构师必备:商业选型与项目部署实践引言:作为一名架构师,商业选型和项目部署是你工作中至关重要的两个环节。商业选型涉及到选择合适的技术方案和工具,以满足企业的商业需求和目标。而项目部署则是将这些选型结果实际应用于项目中,确保项目的高效运行和顺利交付。本文将深入探讨商......
  • Thinkphp 5.0.23 远程代码执行漏洞
    漏洞简介ThinkPHP是一款运用极广的PHP开发框架。在ThinkPHP5.0.23以前的版本中,获取method的方法中没有正确处理方法名,导致攻击者可以调用Request类任意方法并构造利用链,从而导致远程代码执行漏洞。漏洞复现开启vulhub靶场环境,确保ThinkPHP正常运行cdvulhub-master/thinkp......
  • 一文理解GIT的代码冲突
    对于GIT,不知道有没有人和我一样,很长时间都是小心翼翼、紧张兮兮,生怕一不小心,自己辛苦写的代码没了。特别是代码冲突,更是难到我无法理解,每次都要求助于百度,跟着人家的教程一步步解决,下一次还是这样。所有的紧张、不自信、不敢用、用不好,都来源于:不理解。只要理解了,你会发现所有......
  • 每个微服务对应一个代码库吗?
    你是把每个微服务放在它自己的git存储库中,还是使用monorepo?如果是后者,您如何在同一个repo中处理多个服务?回答1.我一直为每个服务使用一个repo,但这主要是因为我们在工作中使用maven和GitHub。我发现monorepo的想法很有趣,但我一直无法找到正确的工具,也不想花时间自己动......
  • prompt gating代码探索
    importtorchdefpromptGating(gating,adding,x):'''gating:(num_prefix,dim)adding:(num_prefix,dim)x:(seq_length,batch_size,dim)'''ifgatingisnotNone:gating=gating.unsque......
  • Modbus通信协议实践(1)-通过modbusRTU实现TPC7022kt对电流表数据的读取
    需求:1.昆仑通泰TPC7022kt触摸屏2.安装了MCGSpro的PC一台3.能使用RS485通讯协议的数显直流电流表一个 操作步骤:1.以常规1mm电线和双绞线连接电流表和昆仑通泰触摸屏,网线连接触摸屏和pc。2.根据数显直流表的说明书,设置该表的通讯地址为01,波特率为9600,N81无校验位,8个数据位,1......