首页 > 其他分享 >GaLore Memory-Efficient LLM Training by Gradient Low-Rank Projection

GaLore Memory-Efficient LLM Training by Gradient Low-Rank Projection

时间:2024-08-27 16:06:37浏览次数:13  
标签:mathbb Training Projection Efficient GaLore times rho tilde hat

目录

Zhao J., Zhang Z., Chen B., Wang Z., Anandkumar A. and Tian Y. GaLore: Memory-efficient llm training by gradient low-rank projection. ICML, 2024.

本文提出了一种优化器中高效的缓存策略.

符号说明

  • \(W_t \in \mathbb{R}^{m \times n}\), 参数;
  • \(\varphi_t\), 损失函数;
  • \(G_t = -\nabla_W \varphi_t (W_t) \in \mathbb{R}^{m \times n}\);

GaLore

  • 一般的优化器更新可以归结为:

    \[W_{t+1} = W_t - \eta \tilde{G}_t, \]

    其中 \(\tilde{G}_t = \rho_t\) 是对梯度 \(G_t\) 进行的一个处理, 在 Adam 中涉及两种动量:

    \[M_t = \beta_1 M_{t-1} + (1 - \beta_1) G_t, \\ V_t = \beta_2 V_{t-1} + (1 - \beta_2) G_t^2, \\ \rho_t(G_t) = M_t / \sqrt{V_t + \epsilon}. \]

  • 像 Adam 这种带 momentum 的, 我们需要缓存 2x 模型大小的量用于更新, 这是非常恐怖的消耗.

  • 作者通过理论分析发现, \(G_t\) 随着梯度更新会逐渐趋于低秩, 本文建议一种 gradient low-rank projection (GaLore) 的方式更新:

    \[W_{t+1} = W_t - \eta \tilde{G}_t, \quad \tilde{G}_t = P_t \:\rho_t (P_t^T G_t Q_t) \: Q_t^T, \]

    其中 \(P_t \in \mathbb{R}^{m \times r}, Q_t \in \mathbb{R}^{n \times r}, r \ll m, n\).

  • 即 梯度转移到低秩空间 -> 在低秩空间中完成 \(\rho_t\) -> 恢复到原空间. 于是在整个训练过程中, 我们只需要缓存这些投影矩阵即可. 如下是 Adam 的一个例子 (只用了一半的投影):

  • 收敛性是容易理解的, 每一步更新都相当于:

    \[\varphi_t(\hat{W}_t), \quad \hat{W}_t = \text{stop-gradient}(W_t) + P \tilde{W}_t Q^T, \quad \tilde{W}_t \in \mathbb{R}^{r \times r}. \]

  • \[\nabla_{\tilde{W}_t} \varphi_t = P^T G_t Q, \]

    此时便有:

    \[\hat{W}_{t+1} = \hat{W}_t + P \Delta \tilde{W} Q^T = \hat{W}_t - \eta P \: \rho_t (P^T G_t Q) Q^T. \]

标签:mathbb,Training,Projection,Efficient,GaLore,times,rho,tilde,hat
From: https://www.cnblogs.com/MTandHJ/p/18382905

相关文章