从这篇文章开始,有三AI-NLP专栏就要进入深度学习了。本文会介绍自然语言处理早期标志性的特征提取工具-循环神经网络(RNN)。首先,会介绍RNN提出的由来;然后,详细介绍RNN的模型结构,前向传播和反向传播的过程;最后,讨论RNN的特点及其优劣势。
作者&编辑 | 小Dream哥
完整的NLP深度学习介绍,应该从反向传播(BP)开始,进而介绍深度神经网络(DNN),卷积神经网络(CNN)也是必不可少的内容。鉴于有三AI已经发布了大量的CV相关的文章,其中必有相关的介绍。所以,在NLP专栏就暂不介绍相关的内容了。如果有需要的同学,可以留言提出来。
1 引言:RNN
对于一些序列输入的信息,例如语音、语言等,不同时刻之间的输入存在相互的影响,需要一种模型能够“记忆”历史输入的信息,进而对整个序列进行完整的特征提取和表征。
循环神经网络(RNN)就是面对这样的需求提出来的,它能够“记忆”序列输入的历史信息,从而能够较好的对整个序列进行语义建模。
目前,RNN及其变种在NLP领域有着广泛的应用。语音识别、对话系统、机器翻译、情感分析等等领域,在产业界,RNN及其变种都是最主要的特征提取工具。
关于RNN的特性,这里先不做太多理论上的说明,等介绍完其结构、前向传播和反向传播后我们再来讨论。
基于篇幅的限制,本文会先介绍最基本的RNN模型结构和原理,LSTM会在下一篇文章中做详细的介绍。
2 RNN的结构
如上图所示,是RNN的结构图。相较于CNN繁杂的卷积运算过程和复杂的网络层次,RNN的模型结构看上去相当的简洁。同样的,RNN模型的结构也分为输入层(Input Layer)、隐藏层(Hidden Layer)和输出层(Output Layer)。图中的箭头表示数据的流动,需要注意的是在隐藏层,有一个回流的箭头,这是这个箭头的作用,使得RNN具有了“记忆”的能力。
这样看,同学们可能还无法看清楚数据在RNN模型内到底是如何流动的。我们将RNN模型的单元按时间展开,如下图所示:
可以看到,不同时刻的数据x_t与上一时刻的状态s_(t-1),从输入层输入,经过一系列运算(激活函数)之后,得到该时刻的状态s_t,s_t再经过矩阵运算得到该时刻的输出o_t,同时t时刻的状态s_t会传给下一时刻的输入层。
通过这种方式,任意时刻的序列输入都会包含前面所有时刻的状态信息,就实现了“记忆”的目的,实际就是一种残差的结构。
需要注意的是,这里所有的RNN结构单元是权重共享的,用大白话说,就是只有一个RNN单元。
下面我们来详细看看数据的流动过程,也就是RNN的正向传播与反向传播过程。
3 RNN的正向传播
RNN的正向传播过程,就是通过输入数据x_t,求该时刻的RNN单元状态(Cell State)s_t以及输出o_t的过程。
我们先来看s_t
U和W是权重参数,f是激活函数,激活函数有sigmoid、relu以及tanh等。
o_t的计算过程为:
V是权重参数,g是输出函数,因为通常是预测类别,所以一般是softmax。
4 RNN的反向传播
下面我们基于RNN的正向传播过程来介绍下RNN的反向传播过程。RNN的反向传播与DNN的反向传播的基本理论是一致的。差别在于,因为RNN是序列的输入,因此其反向传播是基于时间的,叫BPTT(Back PropagationThrough Time)。
与DNN一致,反向传播的过程其实就是更新参数U,W,V的过程。知道反向传播的同学应该知道,更新,W,V其实就是求梯度。
用L_t表示t时刻的模型损失,则输入完一个序列后的总损失值为:
我们先来看参数V的更新,根据偏导公式,
损失函数通常为交叉熵,因此,
再来看看W和U的更新,像DNN的反向传播一样,我们引入一个中间变量,暂称之误差delta,t时刻的误差delta_t:
我们的目标是要得到一个递推公式,用delta_(t+1)来表示delta_t,注意这里激活函数用的是tanh函数。
最后时刻的误差可以表示为:
这样就可以通过delta_T一步一步得到所有时刻的误差。
那么,怎么通过误差得到W和U的梯度呢?
罗列了一大堆的公式,肯定有同学看花了眼。公式推导有不明白的地方,没有关系,我们暂且先放下,后面再慢慢的思考,最重要的是理解反向传播时,梯度更新的思想和技巧。下面我带着大家总结一下这个过程,相信你能获益匪浅。
1.正向传播,求得所有时刻的x_t,o_t,s_t
2. 根据梯度公式,求V的梯度
3. 求得T时刻的误差delta_T
4.根据误差的递推公式,求得所有时刻的误差delta_1,delta_2,...,delta_T
5. 根据梯度公式,和上述误差值求得W的梯度
6. 根据梯度公式,和上述误差值求得U的梯度
7. 更新权重参数
总结
上文详细讲述了RNN的模型结构及其正向和反向传播过程。
RNN虽然理论上可以很漂亮的解决序列数据的训练,但是它也像DNN一样有梯度消失的问题,当序列很长的时候问题尤其严重。虽然同选择合适的激活函数等方法能够一定程度的减轻该问题。但人们往往更青睐于使用RNN的变种。
因此,上面的RNN模型一般都没有直接应用的领域。在语音识别,对话系统以及机器翻译等NLP领域实际应用比较广泛的是基于RNN模型的变种。