Overview
The original author's blog posts already explain this very well: [here] and [there].
RoPE
- RoPE is a relative position encoding. Its distinguishing feature is that, like an absolute position encoding, it can be applied directly to the embeddings before the attention computation, rather than being restricted to operating on the score matrix.
- Concretely, suppose \(\bm{x}_m, \bm{x}_n\) are the two embeddings at positions \(m, n\), and let
\[\bm{z}_m := \mathbf{R}_m \mathbf{W}_q \bm{x}_m, \qquad \bm{z}_n := \mathbf{R}_n \mathbf{W}_k \bm{x}_n.\]
Then
\[\bm{z}_{m}^T \bm{z}_n\]
is an attention score that has absorbed the relative position information \((m-n)\).
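The relative-position dependence comes from the identity \(\mathbf{R}_m^T \mathbf{R}_n = \mathbf{R}_{n-m}\) satisfied by the rotation matrices described below:
\[\bm{z}_{m}^T \bm{z}_n = \bm{x}_m^T \mathbf{W}_q^T \mathbf{R}_m^T \mathbf{R}_n \mathbf{W}_k \bm{x}_n = \bm{x}_m^T \mathbf{W}_q^T \mathbf{R}_{n-m} \mathbf{W}_k \bm{x}_n,\]
so the score depends on the positions only through their offset.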
- Here \(\mathbf{R}_m \in \mathbb{R}^{d \times d}\) is a rotation matrix: applied to a vector, it rotates the dimensions two at a time, each pair by the angle \(m\theta_i\). The frequencies \(\theta_i = 10000^{-2i/d}\) are the same as in the standard Sinusoidal encoding.
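Written out explicitly, \(\mathbf{R}_m\) is block-diagonal, with one \(2 \times 2\) rotation block per pair of dimensions:
\[\mathbf{R}_m = \begin{pmatrix}
\cos m\theta_0 & -\sin m\theta_0 & \cdots & 0 & 0 \\
\sin m\theta_0 & \cos m\theta_0 & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & \cos m\theta_{d/2-1} & -\sin m\theta_{d/2-1} \\
0 & 0 & \cdots & \sin m\theta_{d/2-1} & \cos m\theta_{d/2-1}
\end{pmatrix}.\]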
- There is a more efficient way to compute \(\mathbf{R}_m\bm{x}\) than a full matrix multiplication:
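Since \(\mathbf{R}_m\) is sparse, the product reduces to the element-wise form from the RoPE paper (\(\odot\) denotes the element-wise product):
\[\mathbf{R}_m\bm{x} =
\begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1} \end{pmatrix} \odot
\begin{pmatrix} \cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1} \end{pmatrix} +
\begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \\ -x_{d-1} \\ x_{d-2} \end{pmatrix} \odot
\begin{pmatrix} \sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{pmatrix}.\]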
- Below is the implementation used in LLaMA:
```python
from typing import Tuple

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # theta_i = 10000^{-2i/d}, one frequency per pair of dimensions.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions 0, 1, ..., end - 1
    freqs = torch.outer(t, freqs).float()  # (end, dim // 2): angles m * theta_i
    # e^{i m theta_i} as complex64, i.e. unit-magnitude (cos, sin) pairs.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # Reshape (S, D // 2) -> (1, S, 1, D // 2) so it broadcasts over batch and heads.
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq, xk: (B, S, H, D); freqs_cis: (S, D // 2)
    # View adjacent dimension pairs as complex numbers: (B, S, H, D // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  # (1, S, 1, D // 2)
    # Complex multiplication by e^{i m theta_i} rotates each pair; flatten back to D.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```
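A minimal usage sketch of the functions above; the shapes are illustrative assumptions, not from the original post:

```python
import torch

# Assumed shapes: batch B=2, sequence length S=16, heads H=4, head dim D=8.
B, S, H, D = 2, 16, 4, 8
xq = torch.randn(B, S, H, D)
xk = torch.randn(B, S, H, D)

freqs_cis = precompute_freqs_cis(D, S)                # (S, D // 2), complex64
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)  # each (B, S, H, D)

print(xq_rot.shape, xk_rot.shape)
# Rotations preserve per-position norms, which can be checked numerically:
print(torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-5))  # True
```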
From: https://www.cnblogs.com/MTandHJ/p/17577879.html