目录
概
Mamba 系列第三作.
符号说明
- \(u(t) \in \mathbb{R}\), 输入信号;
- \(x(t) \in \mathbb{R}^N\), 中间状态;
- \(y(t) \in \mathbb{R}\), 输出信号
S4
-
在 LSSL 中我们已经阐述了线性系统:
\[x'(t) = A x(t) + Bu(t), \\ y(t) = C x(t) + D u(t) \]在兼顾 RNN, CNN 的优势的可能性, 并且离散化后说明 LSSL 实际上可以改写成卷积的形式, 从而实现高效的并行化:
\[y = \mathcal{K}_L (\bar{A}, \bar{B}, C) * u + Du, \\ \mathcal{K}_L (A, B, C) := (CB, CAB, \ldots, CA^{L-1}B). \] -
现在的问题是, 如果 \(A\) 是固定的, 那么我们实际上只需要计算一次 \(\mathcal{K}_L\) 即可, 但是如果 \(A\) 不是固定的, 那么我们每次就需要付出额外的(相当多的)代价去计算 \(\mathcal{K}_L\), 其主要代价在于 \(A\).
-
假设我们能够通过某个 \(V \in \mathbb{R}^{N \times N}\) 对角化 \(A\), 则我们有:
\[\tilde{x}' = V^{-1} A V \tilde{x} + V^{-1} B u, \\ y = CV \tilde{x}. \]于是 \((V^{-1}AV)^{l}\) 计算起来就会比较方便了.
-
但是问题是, 作者发现 HiPPO 矩阵的 \(V\) 的值的大小规模可以达到 \(2^{4N/3}\), 所以计算的时候会造成严重的数值问题.
-
S4 提出了一种改进方案:
\[A = V(\Lambda - (V^*P) (V^*Q^*))V^*, \]其中
\[P, Q \in \mathbb{R}^{N \times R}, \]为低秩矩阵.
实际上可以证明, 对于所有的 HiPPO matrix, 都可以进行这样的分解. -
既然如此, S4 选择重参数化 \(A\) 为 \((\Lambda \in \mathbb{R}^{N \times 1}, P \in \mathbb{R}^{N \times 1}, Q \in \mathbb{R}^{N \times 1})\), 以及 \(B, C \in \mathbb{R}^{N \times 1}\), 为 5N 的可训练参数.
注: 我看代码的时候, 感觉发现 \(V\) 是没有保留的, 所以直接就是采用 \(V\) 变换后的那个方程了 (我一开始以为会用 HiPPO matrix 的初始的 \(V\) 最后做个转换的, 实际上没有).
注: 作者没有提及 \(\Delta t\) 是否是训练的, 我感觉应该和 LSSL 一样可训练吧.
注: \(R=1\) 不是必须的, 代码里设置了参数可以调节.