首页 > 其他分享 >RNN详解(12)

RNN详解(12)

时间:2023-01-01 19:34:30浏览次数:68  
标签:12 RNN 导数 梯度 详解 序列 时刻 函数


本文部分参考和摘录了以下文章,在此由衷感谢以下作者的分享!
​​​https://zhuanlan.zhihu.com/p/28054589​​​

​​​https://zhuanlan.zhihu.com/p/28687529​​​

​​​https://zhuanlan.zhihu.com/p/26892413​​​
​​​https://zhuanlan.zhihu.com/p/21462488?refer=intelligentunit​


RNN(Recurrent Neural Network)是一类用于处理序列数据的神经网络。首先我们要明确什么是序列数据,摘取​​百度百科​​词条:时间序列数据是指在不同时间点上收集到的数据,这类数据反映了某一事物、现象等随时间的变化状态或程度。这是时间序列数据的定义,当然这里也可以不是时间,比如文字序列,但总归序列数据有一个特点——后面的数据跟前面的数据有关系。

RNN的结构及变体

我们从基础的神经网络中知道,神经网络包含输入层、隐层、输出层,通过激活函数控制输出,层与层之间通过权值连接。激活函数是事先确定好的,那么神经网络模型通过训练“学“到的东西就蕴含在“权值“中。

基础的神经网络只在层与层之间建立了权连接,RNN最大的不同之处就是在层之间的神经元之间也建立的权连接。如图。

RNN详解(12)_激活函数

这是一个标准的RNN结构图,图中每个箭头代表做一次变换,也就是说箭头连接带有权值。左侧是折叠起来的样子,右侧是展开的样子,左侧中h旁边的箭头代表此结构中的“循环“体现在隐层。
在展开结构中我们可以观察到,在标准的RNN结构中,隐层的神经元之间也是带有权值的。也就是说,随着序列的不断推进,前面的隐层将会影响后面的隐层。图中O代表输出,y代表样本给出的确定值,L代表损失函数,我们可以看到,“损失“也是随着序列的推荐而不断积累的。
除上述特点之外,标准RNN的还有以下特点:
1、权值共享,图中的W全是相同的,U和V也一样。
2、每一个输入值都只与它本身的那条路线建立权连接,不会和别的神经元连接。

以上是RNN的标准结构,然而在实际中这一种结构并不能解决所有问题,例如我们输入为一串文字,输出为分类类别,那么输出就不需要一个序列,只需要单个输出。如图。

RNN详解(12)_Pytorch_02


同样的,我们有时候还需要单输入但是输出为序列的情况。那么就可以使用如下结构:

RNN详解(12)_RNN_03

还有一种结构是输入虽是序列,但不随着序列变化,就可以使用如下结构:

RNN详解(12)_理论详讲_04

原始的N vs N RNN要求序列等长,然而我们遇到的大部分问题序列都是不等长的,如机器翻译中,源语言和目标语言的句子往往并没有相同的长度。

下面我们来介绍RNN最重要的一个变种:N vs M。这种结构又叫Encoder-Decoder模型,也可以称之为Seq2Seq模型。

RNN详解(12)_Pytorch_05


从名字就能看出,这个结构的原理是先编码后解码。左侧的RNN用来编码得到c,拿到c后再用右侧的RNN进行解码。得到c有多种方式,最简单的方法就是把Encoder的最后一个隐状态赋值给c,还可以对最后的隐状态做一个变换得到c,也可以对所有的隐状态做变换。

RNN详解(12)_激活函数_06

除了以上这些结构以外RNN还有很多种结构,用于应对不同的需求和解决不同的问题。还想继续了解可以看一下下面这个博客,里面又介绍了几种不同的结构。但相同的是循环神经网络除了拥有神经网络都有的一些共性元素之外,它总要在一个地方体现出“循环“,而根据“循环“体现方式的不同和输入输出的变化就形成了多种RNN结构。


标准RNN的前向输出流程

上面介绍了RNN有很多变种,但其数学推导过程其实都是​​大同小异​​。这里就介绍一下标准结构的RNN的前向传播过程。

RNN详解(12)_激活函数

再来介绍一下各个符号的含义:x是输入,h是隐层单元,o为输出,L为损失函数,y为训练集的标签。这些元素右上角带的t代表t时刻的状态,其中需要注意的是,因策单元h在t时刻的表现不仅由此刻的输入决定,还受t时刻之前时刻的影响。V、W、U是权值,同一类型的权连接权值相同。

有了上面的理解,前向传播算法其实非常简单,对于t时刻:

h(t)=ϕ(Ux(t)+Wh(t−1)+b)h(t)=ϕ(Ux(t)+Wh(t−1)+b)

为激活函数,一般来说会选择tanh函数,b为偏置。

t时刻的输出就更为简单:

o(t)=Vh(t)+co(t)=Vh(t)+c

最终模型的预测输出为:

yˆ(t)=σ(o(t))y^(t)=σ(o(t))

其中σσ为激活函数,通常RNN用于分类,故这里一般用softmax函数。

RNN的训练方法——BPTT

BPTT(back-propagation through time)算法是常用的训练RNN的方法,其实本质还是BP算法,只不过RNN处理时间序列数据,所以要基于时间反向传播,故叫随时间反向传播。BPTT的中心思想和BP算法相同,沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛。综上所述,BPTT算法本质还是BP算法,BP算法本质还是梯度下降法,那么求各个参数的梯度便成了此算法的核心。

RNN详解(12)_激活函数

再次拿出这个结构图观察,需要寻优的参数有三个,分别是U、V、W。与BP算法不同的是,其中W和U两个参数的寻优过程需要追溯之前的历史数据,参数V相对简单只需关注目前,那么我们就来先求解参数V的偏导数。

∂L(t)∂V=∂L(t)∂o(t)⋅∂o(t)∂V∂L(t)∂V=∂L(t)∂o(t)⋅∂o(t)∂V



这个式子看起来简单但是求解起来很容易出错,因为其中嵌套着激活函数函数,是复合函数的求道过程。

RNN的损失也是会随着时间累加的,所以不能只求t时刻的偏导。

L=∑t=1nL(t)L=∑t=1nL(t)

W和U的偏导的求解由于需要涉及到历史数据,其偏导求起来相对复杂,我们先假设只有三个时刻,那么在第三个时刻 L对W的偏导数为:

∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂W∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂W

可以观察到,在某个时刻的对W或是U的偏导数,需要追溯这个时刻之前所有时刻的信息,这还仅仅是一个时刻的偏导数,上面说过损失也是会累加的,那么整个损失函数对W和U的偏导数将会非常繁琐。虽然如此但好在规律还是有迹可循,我们根据上面两个式子可以写出L在t时刻对W和U偏导数的通式:

∂L(t)∂W=∑k=0t∂L(t)∂o(t)∂o(t)∂h(t)(∏j=k+1t∂h(j)∂h(j−1))∂h(k)∂W∂L(t)∂W=∑k=0t∂L(t)∂o(t)∂o(t)∂h(t)(∏j=k+1t∂h(j)∂h(j−1))∂h(k)∂W



整体的偏导公式就是将其按时刻再一一加起来。

前面说过激活函数是嵌套在里面的,如果我们把激活函数放进去,拿出中间累乘的那部分:

∏j=k+1t∂hj∂hj−1=∏j=k+1ttanh′⋅Ws∏j=k+1t∂hj∂hj−1=∏j=k+1ttanh′⋅Ws

我们会发现累乘会导致激活函数导数的累乘,进而会导致“梯度消失“和“梯度爆炸“现象的发生。

至于为什么,我们先来看看这两个激活函数的图像。

这是sigmoid函数的函数图和导数图。


这是tanh函数的函数图和导数图。


它们二者是何其的相似,都把输出压缩在了一个范围之内。他们的导数图像也非常相近,我们可以从中观察到,sigmoid函数的导数范围是(0,0.25],tach函数的导数范围是(0,1],他们的导数最大都不大于1。

这就会导致一个问题,在上面式子累乘的过程中,如果取sigmoid函数作为激活函数的话,那么必然是一堆小数在做乘法,结果就是越乘越小。随着时间序列的不断深入,小数的累乘就会导致梯度越来越小直到接近于0,这就是“梯度消失“现象。其实RNN的时间序列与深层神经网络很像,在较为深层的神经网络中使用sigmoid函数做激活函数也会导致反向传播时梯度消失,梯度消失就意味消失那一层的参数再也不更新,那么那一层隐层就变成了单纯的映射层,毫无意义了,所以在深层神经网络中,有时候多加神经元数量可能会比多家深度好。

你可能会提出异议,RNN明明与深层神经网络不同,RNN的参数都是共享的,而且某时刻的梯度是此时刻和之前时刻的累加,即使传不到最深处那浅层也是有梯度的。这当然是对的,但如果我们根据有限层的梯度来更新更多层的共享的参数一定会出现问题的,因为将有限的信息来作为寻优根据必定不会找到所有信息的最优解。

之前说过我们多用tanh函数作为激活函数,那tanh函数的导数最大也才1啊,而且又不可能所有值都取到1,那相当于还是一堆小数在累乘,还是会出现“梯度消失“,那为什么还要用它做激活函数呢?原因是tanh函数相对于sigmoid函数来说梯度较大,收敛速度更快且引起梯度消失更慢。

还有一个原因是sigmoid函数还有一个缺点,Sigmoid函数输出不是零中心对称。sigmoid的输出均大于0,这就使得输出不是0均值,称为偏移现象,这将导致后一层的神经元将上一层输出的非0均值的信号作为输入。关于原点对称的输入和中心对称的输出,网络会收敛地更好。

RNN的特点本来就是能“​​追根溯源​​​“利用历史数据,现在告诉我可利用的历史数据竟然是有限的,这就令人非常难受,解决“梯度消失“是非常必要的。解决“梯度消失“的方法主要有:
1、选取更好的激活函数
2、改变传播结构

关于第一点,一般选用ReLU函数作为激活函数,ReLU函数的图像为:


ReLU函数的左侧导数为0,右侧导数恒为1,这就避免了“梯度消失“的发生。但恒为1的导数容易导致“梯度爆炸“,但设定合适的阈值可以解决这个问题。还有一点就是如果左侧横为0的导数有可能导致把神经元学死,不过设置合适的步长(学习旅)也可以有效避免这个问题的发生。

关于第二点,LSTM结构可以解决这个问题。

总结一下,sigmoid函数的缺点:
1、导数值范围为(0,0.25],反向传播时会导致“梯度消失“。tanh函数导数值范围更大,相对好一点。
2、sigmoid函数不是0中心对称,tanh函数是,可以使网络收敛的更好。


LSTM

下面来了解一下LSTM(long short-term memory)。长短期记忆网络是RNN的一种变体,RNN由于梯度消失的原因只能有短期记忆,LSTM网络通过精妙的门控制将短期记忆与长期记忆结合起来,并且一定程度上解决了梯度消失的问题。
由于已经存在了一篇写得非常好的博客,我在这里就直接转载过来,再在其中夹杂点自己的理解。原文连接如下。

作者:朱小虎Neil 链接:​​https://www.jianshu.com/p/9dc9f41f0b29​​ 來源:简书

在此感谢原作者!

长期依赖(Long-Term Dependencies)问题

RNN 的关键点之一就是他们可以用来连接先前的信息到当前的任务上,例如使用过去的视频段来推测对当前段的理解。如果 RNN 可以做到这个,他们就变得非常有用。但是真的可以么?答案是,还有很多依赖因素。

有时候,我们仅仅需要知道先前的信息来执行当前的任务。例如,我们有一个语言模型用来基于先前的词来预测下一个词。如果我们试着预测 “the clouds are in the sky” 最后的词,我们并不需要任何其他的上下文 —— 因此下一个词很显然就应该是 sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以学会使用先前的信息。

RNN详解(12)_Pytorch_09

不太长的相关信息和位置间隔不太长的相关信息和位置间隔



但是同样会有一些更加复杂的场景。假设我们试着去预测“I grew up in France… I speak fluent French”最后的词。当前的信息建议下一个词可能是一种语言的名字,但是如果我们需要弄清楚是什么语言,我们是需要先前提到的离当前位置很远的 France 的上下文的。这说明相关信息和当前预测位置之间的间隔就肯定变得相当的大。

不幸的是,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。

RNN详解(12)_理论详讲_10

相当长的相关信息和

了解更多关于《计算机视觉与图形学》相关知识,请关注公众号:计算机视觉与图形学实战


标签:12,RNN,导数,梯度,详解,序列,时刻,函数
From: https://blog.51cto.com/u_15717531/5983285

相关文章

  • 详解前端缓存,解决前端换包之后环境中仍会出现旧版效果
    前端项目修改了很多东西:比如bug啊,样式啊。当你把前端项目打包之后满心欢喜的在Nginx(测试环境)换上它,然后在Jira上修改bug状态@测试人员复测。然后测试人员开始找你ba......
  • 好题分享、心路历程(力扣1225)
    【题目介绍】该题为力扣1225,名为报告系统状态的连续日期。【题型分类】属于连续专题。官网标为困难题。【思路分享】这里的连续属于时间连续,采用row_number()、subd......
  • SVN中trunk,branches,tags用法详解
    Subversion有一个很标准的目录结构,是这样的。比如项目是proj,svn地址为svn://proj/,那么标准的svn布局是svn://proj/|+-trunk+-branches+-tags这是一个标准的布局,trunk为主开......
  • mongodb的aggregate聚合操作详解
    ################################### 在工作中会经常遇到一些mongodb的聚合操作,特此总结下。mongo存储的可以是复杂类型,比如数组、对象等mysql不善于处理的文档型结构,并且......
  • 力扣---1262. 可被三整除的最大和
    给你一个整数数组 nums,请你找出并返回能被三整除的元素最大和。示例1:输入:nums=[3,6,5,1,8]输出:18解释:选出数字3,6,1和8,它们的和是18(可被3整除的最大和)。示例2......
  • Django——全局配置settings详解
    Django设置文件包含你所有的Django安装配置。这个文件一般在你的项目文件夹里。比如我们创建了一个名为mysite的项目,那么这个配置文件setting.py就在项目里的mysite文件夹......
  • fix协议介绍12-取消订单被拒(OrderCacelReject)
    FIX.5.0SP2MessageOrderCancelReject [type'9']<OrdCxlRej>Theordercancelrejectmessageisissuedbythebrokeruponreceiptofacancelrequestorcancel......
  • Gitlab CI 配置文件 .gitlab-ci.yaml 详解
    转载:GitlabCI配置文件.gitlab-ci.yaml详解(上)-腾讯云开发者社区-腾讯云(tencent.com)本文档用于描述.gitlab-ci.yml语法,.gitlab-ci.yml文件被用来管理项目的......
  • KMP字符串模式匹配详解
    KMP字符串模式匹配详解KMP字符串模式匹配通俗点说就是一种在一个字符串中定位另一个串的高效算法。简单匹配算法的时间复杂度为O(m*n);KMP匹配算法。可以证明它的时间复杂度......
  • 122FPS、51.8mAP 超轻量关键点检测算法PP-TinyPose来啦!
    精准的人机交互任务,如手势控制、智能健身、体感游戏等,背后的核心技术是什么?那必须是关键点检测!还有智慧城市、智慧安防等领域的打架斗殴、司机/工人违规操作等异常行为识别,......