目录
概
本文提出了一种优化器中高效的缓存策略.
符号说明
- \(W_t \in \mathbb{R}^{m \times n}\), 参数;
- \(\varphi_t\), 损失函数;
- \(G_t = -\nabla_W \varphi_t (W_t) \in \mathbb{R}^{m \times n}\);
GaLore
-
一般的优化器更新可以归结为:
\[W_{t+1} = W_t - \eta \tilde{G}_t, \]其中 \(\tilde{G}_t = \rho_t\) 是对梯度 \(G_t\) 进行的一个处理, 在 Adam 中涉及两种动量:
\[M_t = \beta_1 M_{t-1} + (1 - \beta_1) G_t, \\ V_t = \beta_2 V_{t-1} + (1 - \beta_2) G_t^2, \\ \rho_t(G_t) = M_t / \sqrt{V_t + \epsilon}. \] -
像 Adam 这种带 momentum 的, 我们需要缓存 2x 模型大小的量用于更新, 这是非常恐怖的消耗.
-
作者通过理论分析发现, \(G_t\) 随着梯度更新会逐渐趋于低秩, 本文建议一种 gradient low-rank projection (GaLore) 的方式更新:
\[W_{t+1} = W_t - \eta \tilde{G}_t, \quad \tilde{G}_t = P_t \:\rho_t (P_t^T G_t Q_t) \: Q_t^T, \]其中 \(P_t \in \mathbb{R}^{m \times r}, Q_t \in \mathbb{R}^{n \times r}, r \ll m, n\).
-
即 梯度转移到低秩空间 -> 在低秩空间中完成 \(\rho_t\) -> 恢复到原空间. 于是在整个训练过程中, 我们只需要缓存这些投影矩阵即可. 如下是 Adam 的一个例子 (只用了一半的投影):
-
收敛性是容易理解的, 每一步更新都相当于:
\[\varphi_t(\hat{W}_t), \quad \hat{W}_t = \text{stop-gradient}(W_t) + P \tilde{W}_t Q^T, \quad \tilde{W}_t \in \mathbb{R}^{r \times r}. \] -
则
\[\nabla_{\tilde{W}_t} \varphi_t = P^T G_t Q, \]此时便有:
\[\hat{W}_{t+1} = \hat{W}_t + P \Delta \tilde{W} Q^T = \hat{W}_t - \eta P \: \rho_t (P^T G_t Q) Q^T. \]