首页 > 其他分享 >深度学习--LSTM网络、使用方法、实战情感分类问题

深度学习--LSTM网络、使用方法、实战情感分类问题

时间:2023-04-26 14:25:01浏览次数:44  
标签:实战 acc rnn -- torch LSTM data hidden

深度学习--LSTM网络、使用方法、实战情感分类问题

1.LSTM基础

长短期记忆网络(Long Short-Term Memory,简称LSTM),是RNN的一种,为了解决RNN存在长期依赖问题而设计出来的。

LSTM的基本结构:

网络图

2.LSTM的具体说明

LSTM与RNN的结构相比,在参数更新的过程中,增加了三个门,由左到右分别是遗忘门(也称记忆门)、输入门、输出门。

图片来源:

https://www.elecfans.com/d/672083.html

1.点乘操作决定多少信息可以传送过去,当为0时,不传送;当为1时,全部传送。

2.1 遗忘门

对于输入xt和ht-1,遗忘门会输出一个值域为[0, 1]的数字,放进Ct−1中。当为0时,全部删除;当为1时,全部保留。

遗忘门

2.2 输入门

对于对于输入xt和ht-1,输入门会选择信息的去留,并且通过tanh激活函数更新临时Ct

输入门

通过遗忘门和输入门输出累加,更新最终的Ct

更新Ct

2.3输出门

通过Ct和输出门,更新memory

输出门

3.PyTorch的LSTM使用方法

  1. __ init __(input _ size, hidden_size,num _layers)

  2. LSTM.foward():

​ out,[ht,ct] = lstm(x,[ht-1,ct-1])

​ x:[一句话单词数,batch几句话,表示的维度]

​ h/c:[层数,batch,记忆(参数)的维度]

​ out:[一句话单词数,batch,参数的维度]

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size = 100,hidden_size = 20,num_layers = 4)
print(lstm)
#LSTM(100, 20, num_layers=4)

x = torch.randn(10,3,100)
out,(h,c)=lstm(x)

print(out.shape,h.shape,c.shape)
#torch.Size([10, 3, 20]) torch.Size([4, 3, 20]) torch.Size([4, 3, 20])

单层使用方法:

cell = nn.LSTMCell(input_size = 100,hidden_size=20)

x = torch.randn(10,3,100)
h = torch.zeros(3,20)
c = torch.zeros(3,20)

for xt in x:
    h,c = cell(xt,[h,c])
    
print(h.shape,c.shape)

#torch.Size([3, 20]) torch.Size([3, 20])

LSTM实战--情感分类问题

Google CoLab环境,需要魔法。

import torch
from torch import nn, optim
from torchtext import data, datasets

print('GPU:', torch.cuda.is_available())

torch.manual_seed(123)

TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

print('len of train data:', len(train_data))
print('len of test data:', len(test_data))

print(train_data.examples[15].text)
print(train_data.examples[15].label)

# word2vec, glove
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data)


batchsz = 30
device = torch.device('cuda')
train_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, test_data),
    batch_size = batchsz,
    device=device
)

class RNN(nn.Module):
    
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """
        """
        super(RNN, self).__init__()
        
        # [0-10001] => [100]
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # [100] => [256]
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2, 
                           bidirectional=True, dropout=0.5)
        # [256*2] => [1]
        self.fc = nn.Linear(hidden_dim*2, 1)
        self.dropout = nn.Dropout(0.5)
        
        
    def forward(self, x):
        """
        x: [seq_len, b] vs [b, 3, 28, 28]
        """
        # [seq, b, 1] => [seq, b, 100]
        embedding = self.dropout(self.embedding(x))
        
        # output: [seq, b, hid_dim*2]
        # hidden/h: [num_layers*2, b, hid_dim]
        # cell/c: [num_layers*2, b, hid_di]
        output, (hidden, cell) = self.rnn(embedding)
        
        # [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        
        # [b, hid_dim*2] => [b, 1]
        hidden = self.dropout(hidden)
        out = self.fc(hidden)
        
        return out

rnn = RNN(len(TEXT.vocab), 100, 256)

pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')

optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)

import numpy as np

def binary_acc(preds, y):
    """
    get accuracy
    """
    preds = torch.round(torch.sigmoid(preds))
    correct = torch.eq(preds, y).float()
    acc = correct.sum() / len(correct)
    return acc

def train(rnn, iterator, optimizer, criteon):
    
    avg_acc = []
    rnn.train()
    
    for i, batch in enumerate(iterator):
        
        # [seq, b] => [b, 1] => [b]
        pred = rnn(batch.text).squeeze(1)
        # 
        loss = criteon(pred, batch.label)
        acc = binary_acc(pred, batch.label).item()
        avg_acc.append(acc)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i%10 == 0:
            print(i, acc)
        
    avg_acc = np.array(avg_acc).mean()
    print('avg acc:', avg_acc)
    
    
def eval(rnn, iterator, criteon):
    
    avg_acc = []
    
    rnn.eval()
    
    with torch.no_grad():
        for batch in iterator:

            # [b, 1] => [b]
            pred = rnn(batch.text).squeeze(1)

            #
            loss = criteon(pred, batch.label)

            acc = binary_acc(pred, batch.label).item()
            avg_acc.append(acc)
        
    avg_acc = np.array(avg_acc).mean()
    
    print('>>test:', avg_acc)

for epoch in range(10):
    
    eval(rnn, test_iterator, criteon)
    train(rnn, train_iterator, optimizer, criteon)

标签:实战,acc,rnn,--,torch,LSTM,data,hidden
From: https://www.cnblogs.com/ssl-study/p/17355727.html

相关文章

  • 为什么AutoGPT是AI领域的一件大事
    开发人员正在构建自动化ChatGPT提示的方法,鼓励该工具执行自主连接任务,这将减轻用户在使用它时遇到的一些限制。例如,开发人员ToranBruceRichards在GitHub上推出了他的开源应用程序Auto-GPT,这是一个流行的基于Web的平台,开发人员可以在其中存储代码,与他人合作并跟踪代码更改。它......
  • 访问页面中嵌入的表格
       如上图,点击ExporttoExcel就会下载一个Exce文件,但是当我们查看元素时,,并没有excel的url。查看网络的文档时,也没有excel的url这是我们清空网络的页面,重新点击页面的ExporttoExcel按钮,就会出现三个响应文件,并下载了一个excel文件。逐个分析,    如果我们......
  • 将scss文件转换成css文件
    将scss文件转换成css文件npmisass使用命令转译scss或sass文件sass.\index.scss.\index.css监听scss变化更新css文件sass--watch.\demo\page\index\index.scss.\demo\page\index\index.scss......
  • Linux扩大虚拟机系统磁盘空间
    Linux扩大虚拟机系统磁盘空间一、基本步骤1.虚拟机保持关闭状态,设置->磁盘->拓展->最大磁盘大小 设成30G2.创建新分区3.格式化分区4.挂载分区(创建新分区后,需要挂载才能使用)5.解挂分区(解挂后,数据会保留,重新挂载,数据依旧存在)6.删除分区(删除后,数据不存在) 二、创建......
  • 在线直播源码,自定义AlertDialog设置宽高并去掉默认的边框
    在线直播源码,自定义AlertDialog设置宽高并去掉默认的边框1、先写一个自定义的AlertDialog。 packagecom.phone.common_library.dialog; importandroid.annotation.SuppressLint;importandroid.content.Context;importandroid.content.DialogInterface;importandroid.vie......
  • Natasha V5.2.2.1 稳定版正式发布.
    DotNetCore.Natasha.CSharpv5.2.2.1使用NMSTemplate接管CI的部分功能.取消SourceLink.GitHub的继承性.优化几处内存占用问题.增加隐式using配置文件以支持隐式using引用.当项目开启<ImplicitUsings>enable</ImplicitUsings>时,自动生效.增加初始化PE信息判......
  • 安卓手机使用什么便签?
    随着国产安卓手机的崛起,现在越来越多的消费者在更换手机的时候会选择小米、OPPO、vivo、荣耀等国产安卓手机。不过在使用安卓手机的过程中,有一些用户提出了更高的使用需求,例如想要在手机上随手添加文字、图片、视频记事,把待办的事情记录下来并设置提醒时间等。其实使用一款支持记......
  • centos7 go语言环境安装
    要在CentOS7上安装Go环境,可以按照以下步骤进行操作:1.访问Go官网下载页面(https://golang.google.cn/dl/),并选择适合自己系统的版本和包。2.下载完成后,使用以下命令将下载的tar包解压到/usr/local目录:```sudotar-C/usr/local-xzfgo$VERSION.$OS-$ARCH.tar......
  • qt 代码设置layout中的控件的比例,以水平布局为例
    voidDisplayPathFilename::mainlayout(){m_hboxlayout->addWidget(m_filenamelabel);m_filenamelabel->setText("配置文件:");m_hboxlayout->addWidget(m_filenamelineedit);m_hboxlayout->addWidget(m_displaypathlabel);m_dis......
  • 上班族如何安排时间提高工作效率?
    对于上班族来说,合理安排时间可以兼顾生活和工作,不仅能够减少加班次数,还可以提高工作效率,减少工作中的负面情绪。但是有不少小伙伴表示,自己不知道如何安排时间从而提高工作效率,这应该怎么办呢?其实上班族想要做好时间安排,以下几点很重要:制定清晰的工作计划、将工作计划的轻重缓急程......