概述
长短期记忆 LSTM(Long Short Term Memory),该类型的神经网络可以利用上输入数据的时序信息。对于需要处理与顺序或时间强相关数据的领域(自然语言处理、天气预测等)相当合适。
GRU(Gate Recurrent Unit)可以视为 LSTM 的简化版本。运算量更小,却能达到 LSTM 相当的性能。
介绍 LSTM 之前,要先了解什么是 RNN。
RNN
递归神经网络 RNN(Recurssion Neural Network),通过让网络能接收上一时刻的网络输出达成处理时序数据的目标。
通常来说,网络通过输入 \(x\) 可以得到输出 \(y\)。而 RNN 的思路是将 \(i-1\) 时刻的输出 \(y\) 视为 “状态” \(h^{i-1}\),用为 \(i\) 时刻的网络输入。
如此,网络的输入有两个:\(x^i\),和上一个时刻的输出 \(h^{i-1}\)。网络的输出仍为一个,并且可以作为下一个时刻的网络输入 \(h^i\)。
LSTM
RNN 有很多缺点(遗忘、梯度爆炸与梯度消失),现在更多使用 LSTM。
LSTM 引入了单元状态(cell state)的概念。网络的输入现在有 \(x\)、隐藏态(hidden state)\(h^{i-1}\)、单元状态 \(c^{i-1}\)。。
单元状态 \(c^{i}\) 变化很慢,通常是 \(c^{t-1}\) 的基础上加一些数值。而 \(h^i\) 对于不同节点有很大区别。
LSTM 具体细节
在 \(i\) 时刻,网络先将本次输入 \(x^t\) 和上一隐藏态 \(h^{i-1}\) 拼接,经由四个不同的矩阵(矩阵参数可学习)做乘法,获得用途各异的四个状态 \(z^f\)、\(z^i\)、\(z^o\) 和 \(z\)。
\[\begin{aligned} &z^f=\text{sigmoid}(W^f\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z^i=\text{sigmoid}(W^i\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z^o=\text{sigmoid}(W^o\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z=\text{tanh}(W\cdot \text{concatenate}(x^t,h^{t-1})) \end{aligned} \]一次运算的步骤如下:
-
从 \(i-1\) 时刻传来的单元状态 \(c^{i-1}\) 首先与 \(z^f\) 相乘,用于代表记忆的遗忘(forget)
-
\(z^i\) 与 \(z\) 相乘,代表对记忆进行选择,哪些记忆需要记录(information)。上一步经过遗忘处理的 \(c^{i-1}\) 与需要记录的记忆进行加运算,完成记录。
此步完成后,\(c^{i-1}\) 化身为 \(c^{i}\) 作为下一步的单元状态输入
- 最后用 \(z^o\) 控制输出(output)。\(z^o\) 与上一步的 \(c^{i-1}\) 的 \(\text{tanh}\) 结果相乘,获得本时刻的输出 \(y^i\),并作为下一步的隐藏态输入 \(h^t\)
GRU
GRU 可以实现与 LSTM 相当的性能,且运算量更低。
GRU 具体细节
GRU 没有单元状态 \(c^{i}\)。网络接收两个输入:当前输入 \(x^i\)、上一隐藏状态 \(h^{i-1}\)。两个输入经过两个不同的矩阵(矩阵参数可学习)做乘法,获得两个门控(gate):
\[\begin{aligned} &r=\text{sigmoid}(W^r\cdot \text{concatenate}(x^t,h^{t-1}))\\ &z=\text{sigmoid}(W^z\cdot \text{concatenate}(x^t,h^{t-1})) \end{aligned} \]\(r\) 为重置门控(reset gate),\(z\) 为更新门控(update gate)。
一次运算的步骤如下:
- 从 \(i-1\) 时刻传来隐藏状态 \(h^{i-1}\) 与 \(r\) 相乘,获得 \({h^{i-1}}'\)。这一步代表有选择性地保留记忆(遗忘)
- \({h^{t-1}}'\) 与输入 \(x^i\) 拼接,再乘一个参数可学习的矩阵,取 \(\text{tanh}\) 获得 \(h'\)。这一步让 \(h'\) 记忆了当前时刻的状态(记录)
- 用 \((1-z)\) 乘上 \(h^{i-1}\),用 \(z\) 乘上 \(h'\),将两者的和视为当前的隐藏状态 \(h^i\)。可见 \(h^i\) 结合了以前的记忆与现在的状态,代表记忆的更新
现在,将 \(h^i\) 视为下一时刻的输入,即完成了一次运算。
参考来源
- Mark,“LSTM - 长短期记忆递归神经网络”,https://zhuanlan.zhihu.com/p/123857569
- 陈诚,“人人都能看懂的LSTM”,https://zhuanlan.zhihu.com/p/32085405
- 陈诚,“人人都能看懂的GRU”,https://zhuanlan.zhihu.com/p/32481747