首页 > 其他分享 >门控循环单元GRU

门控循环单元GRU

时间:2024-08-03 17:28:33浏览次数:10  
标签:vocab GRU batch 门控 token num 隐藏 单元 size

目录

一、GRU提出的背景:

1.RNN存在的问题:

循环神经网络讲解文章

由于RNN的隐藏状态ht用于记录每个句子之前的所有序列信息,而对于长序列问题来说ht会记录太多序列信息导致序列时序特征区分度很差(最前面的序列特征因为进行了太多轮迭代往往不太好从ht中提取),并且RNN默认当前时间步的token单词和该句子的隐藏状态ht中所有序列信息都有同等的相关度,因此一些比较靠前但与当前时间步输入的token相关性高的序列特征在ht中可能就不太被重视,而一些比较靠后但与当前时间步输入的token相关性低的序列特征在ht中被过于关注。

2.GRU的思想:

GRU的提出就是为了解决RNN默认句子中所有token之间的相关性相等问题。
GRU的思想是对于每个时间步的输入token,使用门的控制将隐藏状态ht中与当前token相关性高的序列信息拿来参与计算,而ht中与当前token相关性低的序列信息作为噪音不参与计算。

  • 对于需要关注的序列信息,使用更新门来提高关注度
  • 对于需要遗忘的序列信息,使用遗忘门来降低关注度

二、更新门和重置门:

GRU提出更新门和重置门的思想来改变隐藏状态ht中不同序列信息的关注度。
在这里插入图片描述
更新门和重置门可以分别看做一个全连接层的隐藏层,这样的话上图就等价于两个并排的隐藏层,其中:

  • 每个隐藏层都接收之前时间步的隐藏状态Ht-1和当前时间步的输入token或token集合(batch_size>1)。
  • 更新门和重置门有各自的可学习权重参数和偏置值,公式含义类似传统RNN
  • Rt 和 Zt 都是根据过去的隐藏状态 Ht-1 和当前输入 Xt 计算得到的 [0,1] 之间的量(激活函数)。

三、GRU网络架构:

1.更新门和重置门如何发挥作用:

重置门对过去t个时间步的序列信息(Ht-1)进行选择,更新门对当前一个时间步的序列信息(Xt)进行选择。具体原理如下:

1.1候选隐藏状态H~t:

候选隐藏状态既保留了之前的隐藏状态Ht-1,又保留了当前一个时间步的序列信息Xt。
在这里插入图片描述
因为Rt是一个[0,1] 之间的量,所以Rt×Ht-1是对之前的隐藏状态Ht-1进行一次选择:Rt 在某个位置的值越趋近于0,则表示Ht-1这个位置的序列信息越倾向于被丢弃,反之保留。

综上,重置门的作用是对过去的序列信息Ht-1进行选择,Ht-1中哪些序列信息对H~T是有用的,应该被保存下来,而哪些序列信息是不重要的,应该被遗忘。

1.2隐藏状态Ht:

在这里插入图片描述
因为Zt是一个[0,1] 之间的量,如果Zt全为0,则当前隐藏状态Ht为当前候选隐藏状态,该候选隐藏状态不仅保留了之前的序列信息,还保留了当前时间步batch的序列信息;如果Zt全为1,则当前隐藏状态Ht为上一个时间步的隐藏状态。

综上,更新门的作用是决定当前一个时间步的序列信息是否保留,如果Zt全为0,则说明当前时间步token的序列信息是有用的(候选隐藏状态包含之前的序列信息和当前一个时间步的序列信息),保留下来加入到隐藏状态Ht中;如果Zt全为1,则说明当前时间步batch的序列信息是没有用的,丢弃当前token的序列信息,直接使用上一个时间步的隐藏状态Ht-1作为当前的隐藏状态Ht。(Ht-1仅包含之前的序列信息,不包含当前一个时间步的序列信息)

2.GRU:

GRU网络架构如下,可以看做是三个隐藏层并排的架构。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

四、训练过程举例******:

以下文预测问题为例,一次epoch训练过程如下。
1.对整个文本进行数据预处理,获得数据字典,这里假设字典中有vocab_size条字典序,这样就转换成了一个vocab_size分类的序列问题。
2.将每个单词token值使用独热编码转换成1×vocab_size的一维向量,作为特征,表示各分类上的概率。
3.每轮epoch输入格式为batch_num×batch_size×num_steps×vocab_size,其中batch_num表示该轮压迫训练多少个batch,batch_size表示每个batch中有多少个句子序列,每个句子有num_steps个单词token,即该batch要训练多少个时间步,即循环time_step次传统神经网络,每个单词为一个一维向量,表示在字典序上的概率。每次训练一个batch,每个时间步t使用该batch中所有batch_size个序列的第t个token集合Xt进行训练(num_steps=t的token),batch尺寸为batch_size×num_steps×vocab_size,Xt尺寸为batch_size×vocab_size
4.隐藏层参数Whh维度为num_hiddens×num_hiddens,表示隐藏层关于序列信息(隐藏状态)的权重矩阵;Whx维度为vocab_size×num_hiddens,表示隐藏层关于输入特征的权重矩阵;参数bh维度为1×num_hiddens
5.三个并行隐藏层各自的参数Whh、Whz、Whr维度计算为num_hiddens×num_hiddens,表示隐藏层关于序列信息(隐藏状态)的权重矩阵;三个并行隐藏层各自的参数Wxh、Wxz、Wxr维度计算为vocab_size×num_hiddens,表示隐藏层关于输入特征的权重矩阵;参数bh、bz、br维度计算为1×num_hiddens。这里由于三个隐藏层输出维度相同,所以隐藏内的神经元数目都是相同的=num_hiddens。
6.对于第一个batch,训练过程如下:
6.1.初始化0时刻序列信息(隐藏层输出,隐藏状态)h0,尺寸为(batch_size,神经元个数num_hiddens)
6.2.t1时间步num_steps=1,取该batch所有序列样本的第一个token组成x0,尺寸batch_size×vocab_size,每个vocab一维向量并行放入神经网络学习,首先x0中每个token和ho同时进入更新门隐藏层和重置门隐藏层,重置门隐藏层输出R1=sigmoid(Whr×h0+Wxr×x0+br)、更新门隐藏层输出Z1=sigmoid(Whz×h0+Wxz×x0+bz),两个隐藏层分别用来筛选过去和当前的序列信息,输出维度均为batch_size×num_hiddens。
6.3.重置门输出R1、隐藏状态h0和x0中每个token进入候选隐藏状态隐藏层,使用重置门对过去的序列信息进行筛选,计算出候选隐藏状态H~1。
6.4.更新门输出Z1、隐藏状态h0和候选隐藏状态H~1联合计算,使用更新门对当前的序列信息进行筛选,计算出当前时间步的隐藏状态h1,隐藏层输出维度batch_size×num_hiddens,h1作为t1时间步的输出层输入、t2时间步的隐藏层输入序列信息(隐藏状态)。
6.5.此时两个操作并行执行:t1时间步的输出层计算、t2时间步的隐藏层计算。
6.5.1首先h1作为t1时间步的输出层输入,输出层有vocab_size个神经元,会执行多分类预测,可学习参数为Woh(num_hiddens×vocab_size)和bo(1×vocab_size),每个token输出维度1×vocab_size,输出层输出维度batch_size×vocab_size,表示各个token在各个分类上的预测。
6.5.2其次,t2时间步num_steps=2,取batch中num_steps=2的token集合为x1,维度为batch_size×vocab_size,并行将每个token一维向量放入神经网络学习,隐藏层输出h2=sigmoid(Whh×h1+Whx×x1+bh),每个token输出维度1×num_hiddens,隐藏层输出维度batch_size×num_hiddens,h2作为t2时间步的输出层输入、t3时间步的隐藏层输入序列信息。
6.6.如此反复每个时间步取一个数据点token集合进行训练,并更新隐藏层输出ht作为下一个时间步的输入,直到完成所有num_steps个时间步的训练任务,整个batch就训练完成了。
6.7.对于每个时间步上的预测batch_size×vocab_size,num_steps个时间步上总的预测为(num_steps×batch_size,vocab_size),这是该batch的训练总输出。
6.8.使用损失函数计算batch中各个句子中每个token的概率损失,并取均值。
6.9.反向传播算法计算各个参数关于损失函数的梯度。
6.10.梯度裁剪修改梯度。
6.11.梯度下降算法修改参数值。
7.该batch训练完成。进行下一个batch训练,初始化隐藏状态h0…。

五、预测过程举例******:

背景定义同训练过程,模型的预测过程如下。
1.输入prefix长度的前缀,来预测接下来num_preds个token。
2.首先还是将prefix转换成字典序并进行独热编码,尺寸为1×prefix×vocab_size,其中prefix=num_steps。
3.加载模型,初始化时序信息h0。
4.batch_size为1,在每个时间步上对句子长度每个token一维向量依次作为模型一个时间步的输入,输入维度1×vocab_size,总共计算prefix个时间步,循环计算prefix个时间步后的时序信息hp,hp尺寸为1×num_hiddens(batch_size=1)。
5.将prefix最后一个token和hp作为模型输入,来预测num_preds个token的第一个token,输出预测结果pred1和时序信息hp1,然后将pred1和hp1作为输入预测pred2和hp2(即使用预测值来预测下一个预测值),直到预测num_preds个预测值。(等价于batch=1,num_steps=num_preds的训练过程)
6.将预测值使用字典转为字符串输出。

六、底层源码:

代码中num_hiddens表示隐藏层神经元个数,由于重置门、更新门的输出维度相同,所以重置门和更新门两个隐藏层的神经元个数也是一样的=num_hiddens。

import torch
from torch import nn
from d2l import torch as d2l

# 数据预处理,获取datalodaer和字典
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# 初始化可学习参数
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal(
            (num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()
    W_xr, W_hr, b_r = three()
    W_xh, W_hh, b_h = three()
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

# 初始化隐藏状态
def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),)

# 定义门控循环单元模型
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

# 训练
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

七、Pytorch版代码:

num_inputs = vocab_size
# 调用pytorch构建网络结构
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

标签:vocab,GRU,batch,门控,token,num,隐藏,单元,size
From: https://blog.csdn.net/m0_53881899/article/details/140844993

相关文章

  • excel 中如何将指定的空白单元格填充为指定的内容
     001、测试表格  002、选中,按F5 a、选定位条件 b、 c、直接输入一个测试文本 d、ctrl+enter 。 ......
  • 在 VSCode 上,使用 #%% 单元格时,shift-enter 不再执行交互窗口中的单元格并移至下一个
    在VSCode上,当使用#%%单元格时,shift-enter不再执行交互窗口中的单元格并移至下一个单元格。它曾经工作了很多年,但现在在某些代码行上按Shift键输入会在标题为“PythonREPL”的新窗口中引发错误。我仍然可以使用Control-Enter来执行交互窗口中的单元格,但这不会将光标移动......
  • 使用 GRUB2 管理双系统
    最近给自己的老笔记本换了一块大硬盘,顺便装了Windows和Ubuntu两个操作系统。记录一下安装过程。安装Ubuntu下载UbuntuDesktop镜像文件。你可以在官网中使用标准下载;或者在镜像源列表中就近下载,比如清华源是很不错的选择;或者使用BitTorrent下载。使用镜像......
  • 软件测试之解构单元测试
    软件单元测试是对软件中的最小可测试单元进行检查和验证的过程。这些单元可以是函数、方法、类实例,或者是任何具有明确功能、规格定义和接口定义的程序代码模块。单元测试是软件开发过程中的最低级别的测试活动,它确保软件的独立单元在与程序的其他部分相隔离的情况下能够正确工......
  • C++竞赛初阶L1-05-第四单元-判断语句(第19课)100003: 最大数输出
    题目内容输入三个整数,输出最大的数。输入格式输入为一行,包含三个整数,数与数之间以一个空格分开。输出格式输出一行,包含一个整数,即最大的整数。样例1输入102056样例1输出56程序代码输出:#include<bits/stdc++.h>usingnamespacestd;intmain(){ inta,b,c......
  • 前端如何设置表格边框样式和单元格间距?
    前端如何设置表格边框样式和单元格间距?引言表格的基本概念基本结构示例一:基本表格样式CSS说明示例二:交替行颜色CSS说明示例三:固定表头CSS说明示例四:设置单元格间距HTMLCSS说明示例五:响应式表格CSSHTML说明实际工作中的使用技巧技巧一:优化单元格内的内容CSS技巧二:使......
  • 修改anolist grub entry
    之前一直用ubuntu,切换到centos上感觉诸多不适宜。ubuntu切换kernel非常方便,只要grub-update即可,centos/anolist上比较麻烦,记一下。首先是把编好的kernel放到/boot下面,一般直接makeinstall即可;然后grub2-mkconfig-o/boot/grub/grub.cfg,这将会生成新的grubentry。如果要将新添......
  • 门控循环单元(GRU)预测模型及其Python和MATLAB实现
    ##一、背景循环神经网络(RNN)是处理序列数据的一类神经网络,尤其适用于时间序列预测、自然语言处理等领域。然而,传统的RNN在长序列数据的训练中面临梯度消失和爆炸的问题,导致模型对长期依赖的学习能力不足。为了解决这一问题,研究人员提出了多种改进的RNN结构,其中包括长短期记忆......
  • 单元电路(串联阻抗、并联导纳、无耗传输线)的基本网络参量(Z矩阵、Y矩阵、A矩阵、S矩
          PDF文件下载链接如下:单元电路(串联阻抗、并联导纳、无耗传输线段、无耗传输线接头)的矩阵(Z矩阵、Y矩阵、S矩阵、A矩阵、T矩阵)推导过程资源-CSDN文库https://download.csdn.net/download/lu2289504634/89583021单元电路的网络参量,可以直接根据未归一化网络参量的......
  • 使用JUnit 5进行Java单元测试的高级技术
    使用JUnit5进行Java单元测试的高级技术大家好,我是微赚淘客系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿!今天我们来探讨如何使用JUnit5进行Java单元测试的高级技术。JUnit5是Java测试框架JUnit的最新版本,它引入了许多新功能和改进,使得编写和执行测试更加方便和灵活......