RoPE
Process summary & introduction to RoPE
The idea is to achieve the effect of relative position encoding through absolute position encoding: find functions \(f, g\) such that \(\left\langle f_{q}\left(x_{m}, m\right), f_{k}\left(x_{n}, n\right)\right\rangle=g\left(x_{m}, x_{n}, m-n\right)\).
RoPE is applied inside the attention computation.
First, a token embedding \(x\) is projected with \(W_Q, W_K\) to obtain the \(q, k\) vectors.
The rotation angles are then computed from the token's position.
The dimensions of \(q, k\) are grouped into pairs, and a rotation is applied to each pair.
Finally, the rotated \(q, k\) are used in the attention computation.
Because the absolute position information enters only through this rotation, the resulting inner product depends only on \(x_m, x_n\) and the relative position \(m-n\).
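This is easiest to see in the two-dimensional case: writing \(R_m\) for the rotation by angle \(m\theta\), rotation matrices satisfy \(R_m^{\top} R_n = R_{n-m}\), so the inner product of the rotated vectors depends only on the offset:

\[
\left\langle R_{m} q,\; R_{n} k\right\rangle=\left(R_{m} q\right)^{\top} R_{n} k=q^{\top} R_{n-m} k,
\qquad
R_{m}=\begin{pmatrix}\cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta\end{pmatrix}.
\]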
Non-LLaMA implementation
Generating the rotary position encoding
import torch

def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    # absolute positions m: (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # theta_i = 10000^(-2i/d) for i = 0 .. d/2 - 1
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    theta = torch.pow(10000, -2 * ids / output_dim)
    # mθ in the formula: (max_len, output_dim // 2)
    embeddings = position * theta
    # stack sin(mθ) and cos(mθ) along a new last dim: (max_len, output_dim // 2, 2)
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
    # repeat over the batch and head dims: (batch_size, nums_head, max_len, output_dim // 2, 2)
    embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))
    # merge the sin/cos pairs into the last dim: (batch_size, nums_head, max_len, output_dim)
    embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
    embeddings = embeddings.to(device)
    return embeddings
(q, k): (bs, head, max_len, output_dim), e.g. torch.Size([8, 12, 10, 32])
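As a quick shape check (a minimal sketch; the random q and k here are only illustrative), calling the function with the sizes used throughout this walkthrough:

q = torch.randn(8, 12, 10, 32)  # (bs, head, max_len, output_dim)
k = torch.randn(8, 12, 10, 32)
pos_emb = sinusoidal_position_embedding(8, 12, 10, 32, q.device)
print(pos_emb.shape)  # torch.Size([8, 12, 10, 32])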
1. Compute the position-encoding matrix from bs, num_head, max_len, output_dim
① Generate the absolute positions from max_len; position: torch.Size([10, 1])
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
# result: tensor([[0.], [1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.], [9.]])
② Compute θ; torch.Size([16])
ids = torch.arange(0, output_dim // 2, dtype=torch.float)
theta = torch.pow(10000, -2 * ids / output_dim)
# results
# ids   = tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15.])
# theta = tensor([1.0000e+00, 5.6234e-01, 3.1623e-01, 1.7783e-01, 1.0000e-01, 5.6234e-02, 3.1623e-02, 1.7783e-02,
#                 1.0000e-02, 5.6234e-03, 3.1623e-03, 1.7783e-03, 1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04])
③ Compute the \(m\theta\) from the formula; torch.Size([10, 16])
embeddings = position * theta
# result:
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 5.6234e-01, 3.1623e-01, 1.7783e-01, 1.0000e-01, 5.6234e-02, 3.1623e-02, 1.7783e-02, 1.0000e-02, 5.6234e-03, 3.1623e-03, 1.7783e-03, 1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04],
        [2.0000e+00, 1.1247e+00, 6.3246e-01, 3.5566e-01, 2.0000e-01, 1.1247e-01, 6.3246e-02, 3.5566e-02, 2.0000e-02, 1.1247e-02, 6.3246e-03, 3.5566e-03, 2.0000e-03, 1.1247e-03, 6.3246e-04, 3.5566e-04],
        [3.0000e+00, 1.6870e+00, 9.4868e-01, 5.3348e-01, 3.0000e-01, 1.6870e-01, 9.4868e-02, 5.3348e-02, 3.0000e-02, 1.6870e-02, 9.4868e-03, 5.3348e-03, 3.0000e-03, 1.6870e-03, 9.4868e-04, 5.3348e-04],
        [4.0000e+00, 2.2494e+00, 1.2649e+00, 7.1131e-01, 4.0000e-01, 2.2494e-01, 1.2649e-01, 7.1131e-02, 4.0000e-02, 2.2494e-02, 1.2649e-02, 7.1131e-03, 4.0000e-03, 2.2494e-03, 1.2649e-03, 7.1131e-04],
        [5.0000e+00, 2.8117e+00, 1.5811e+00, 8.8914e-01, 5.0000e-01, 2.8117e-01, 1.5811e-01, 8.8914e-02, 5.0000e-02, 2.8117e-02, 1.5811e-02, 8.8914e-03, 5.0000e-03, 2.8117e-03, 1.5811e-03, 8.8914e-04],
        [6.0000e+00, 3.3740e+00, 1.8974e+00, 1.0670e+00, 6.0000e-01, 3.3740e-01, 1.8974e-01, 1.0670e-01, 6.0000e-02, 3.3740e-02, 1.8974e-02, 1.0670e-02, 6.0000e-03, 3.3740e-03, 1.8974e-03, 1.0670e-03],
        [7.0000e+00, 3.9364e+00, 2.2136e+00, 1.2448e+00, 7.0000e-01, 3.9364e-01, 2.2136e-01, 1.2448e-01, 7.0000e-02, 3.9364e-02, 2.2136e-02, 1.2448e-02, 7.0000e-03, 3.9364e-03, 2.2136e-03, 1.2448e-03],
        [8.0000e+00, 4.4987e+00, 2.5298e+00, 1.4226e+00, 8.0000e-01, 4.4987e-01, 2.5298e-01, 1.4226e-01, 8.0000e-02, 4.4987e-02, 2.5298e-02, 1.4226e-02, 8.0000e-03, 4.4987e-03, 2.5298e-03, 1.4226e-03],
        [9.0000e+00, 5.0611e+00, 2.8460e+00, 1.6005e+00, 9.0000e-01, 5.0611e-01, 2.8460e-01, 1.6005e-01, 9.0000e-02, 5.0611e-02, 2.8460e-02, 1.6005e-02, 9.0000e-03, 5.0611e-03, 2.8460e-03, 1.6005e-03]])
④ Compute sin(mθ) and cos(mθ) separately and stack them; torch.Size([10, 16, 2])
embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
# intermediate results
# sin(mθ): torch.Size([10, 16])
tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [ 8.4147e-01, 5.3317e-01, 3.1098e-01, 1.7689e-01, 9.9833e-02, 5.6204e-02, 3.1618e-02, 1.7782e-02, 9.9998e-03, 5.6234e-03, 3.1623e-03, 1.7783e-03, 1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04],
        ... ])
# cos(mθ): torch.Size([10, 16])
tensor([[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [ 0.5403, 0.8460, 0.9504, 0.9842, 0.9950, 0.9984, 0.9995, 0.9998, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        ... ])
# result: torch.Size([10, 16, 2])
# torch.stack([sin(mθ), cos(mθ)], dim=-1): each entry pairs one sin with one cos
tensor([[[ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00],
         [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00],
         [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00],
         [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00], [ 0.0000e+00, 1.0000e+00]],
        [[ 8.4147e-01, 5.4030e-01], [ 5.3317e-01, 8.4601e-01], [ 3.1098e-01, 9.5042e-01], [ 1.7689e-01, 9.8423e-01],
         [ 9.9833e-02, 9.9500e-01], [ 5.6204e-02, 9.9842e-01], [ 3.1618e-02, 9.9950e-01], [ 1.7782e-02, 9.9984e-01],
         [ 9.9998e-03, 9.9995e-01], [ 5.6234e-03, 9.9998e-01], [ 3.1623e-03, 9.9999e-01], [ 1.7783e-03, 1.0000e+00],
         [ 1.0000e-03, 1.0000e+00], [ 5.6234e-04, 1.0000e+00], [ 3.1623e-04, 1.0000e+00], [ 1.7783e-04, 1.0000e+00]],
        ... ])
⑤ Repeat along the batch and head dimensions (the remaining dims use a repeat factor of 1 and stay unchanged); torch.Size([8, 12, 10, 16, 2])
embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape))))  # the encoding is copied for every batch element and every head
⑥ Reshape so that the sin and cos values are merged (interleaved) in the last dimension; torch.Size([8, 12, 10, 32])
embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
# PE for one batch element and one head; entries alternate sin, cos
tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00,
          0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00,
          0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00,
          0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00],
        [ 8.4147e-01, 5.4030e-01, 5.3317e-01, 8.4601e-01, 3.1098e-01, 9.5042e-01, 1.7689e-01, 9.8423e-01,
          9.9833e-02, 9.9500e-01, 5.6204e-02, 9.9842e-01, 3.1618e-02, 9.9950e-01, 1.7782e-02, 9.9984e-01,
          9.9998e-03, 9.9995e-01, 5.6234e-03, 9.9998e-01, 3.1623e-03, 9.9999e-01, 1.7783e-03, 1.0000e+00,
          1.0000e-03, 1.0000e+00, 5.6234e-04, 1.0000e+00, 3.1623e-04, 1.0000e+00, 1.7783e-04, 1.0000e+00],
        ... ])
Applying the rotation transform
def RoPE(q, k):
    # q, k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # position-encoding matrix: (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)

    # cos_pos, sin_pos: (bs, head, max_len, output_dim)
    # From the RoPE formula, adjacent pairs of dimensions share the same cos/sin values,
    # so each value is duplicated, e.g. (1, 2, 3) -> (1, 1, 2, 2, 3, 3).
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)  # odd-indexed columns are the cos values; duplicate each
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)   # even-indexed columns are the sin values; duplicate each

    # q, k: (bs, head, max_len, output_dim)
    # Build [-q1, q0, -q3, q2, ...]: swap each pair and negate the (originally) odd-indexed entries
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)  # after reshape the signs alternate as required
    # update q: elementwise (*) multiplication
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    # update k: elementwise (*) multiplication
    k = k * cos_pos + k2 * sin_pos

    return q, k
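A minimal usage sketch (the random q, k, v and the plain scaled-dot-product attention below are illustrative only): RoPE is applied to q and k before the attention scores are computed, while v is left unrotated.

import math
import torch.nn.functional as F

q = torch.randn(8, 12, 10, 32)  # (bs, head, max_len, output_dim)
k = torch.randn(8, 12, 10, 32)
v = torch.randn(8, 12, 10, 32)

q_rot, k_rot = RoPE(q, k)  # v is not rotated
scores = torch.matmul(q_rot, k_rot.transpose(-2, -1)) / math.sqrt(q.shape[-1])
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
print(out.shape)  # torch.Size([8, 12, 10, 32])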
Building on the pos_emb returned above:
① Extract the cos and sin parts from pos_emb and duplicate every element; torch.Size([8, 12, 10, 32])
cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)
# cos and sin are taken with a stride of 2, but each value is duplicated, so the shape stays torch.Size([8, 12, 10, 32])
# cos_pos result: note that each pair of adjacent cos values is identical
tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5403,  0.5403,  0.8460,  ...,  1.0000,  1.0000,  1.0000],
          [-0.4161, -0.4161,  0.4315,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [ 0.7539,  0.7539, -0.7004,  ...,  1.0000,  1.0000,  1.0000],
          [-0.1455, -0.1455, -0.2120,  ...,  1.0000,  1.0000,  1.0000],
          [-0.9111, -0.9111,  0.3417,  ...,  1.0000,  1.0000,  1.0000]],
        ... ]])
\(\left[\begin{array}{ccccccc} \cos m \theta_{0} & -\sin m \theta_{0} & 0 & 0 & \cdots & 0 & 0 \\ \sin m \theta_{0} & \cos m \theta_{0} & 0 & 0 & \cdots & 0 & 0 \\ 0 & 0 & \cos m \theta_{1} & -\sin m \theta_{1} & \cdots & 0 & 0 \\ 0 & 0 & \sin m \theta_{1} & \cos m \theta_{1} & \cdots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \cos m \theta_{d / 2-1} & -\sin m \theta_{d / 2-1} \\ 0 & 0 & 0 & 0 & \cdots & \sin m \theta_{d / 2-1} & \cos m \theta_{d / 2-1} \end{array}\right]\left[\begin{array}{c} q_{0} \\ q_{1} \\ q_{2} \\ q_{3} \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{array}\right]=\left[\begin{array}{c} q_{0} \\ q_{1} \\ q_{2} \\ q_{3} \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{array}\right] \otimes\left[\begin{array}{c} \cos m \theta_{0} \\ \cos m \theta_{0} \\ \cos m \theta_{1} \\ \cos m \theta_{1} \\ \vdots \\ \cos m \theta_{d / 2-1} \\ \cos m \theta_{d / 2-1} \end{array}\right]+\left[\begin{array}{c} -q_{1} \\ q_{0} \\ -q_{3} \\ q_{2} \\ \vdots \\ -q_{d-1} \\ q_{d-2} \end{array}\right] \otimes\left[\begin{array}{c} \sin m \theta_{0} \\ \sin m \theta_{0} \\ \sin m \theta_{1} \\ \sin m \theta_{1} \\ \vdots \\ \sin m \theta_{d / 2-1} \\ \sin m \theta_{d / 2-1} \end{array}\right]\)
② Rebuild \(q\) with alternating signs, which makes the subsequent elementwise computation convenient.
# Swap each adjacent pair and negate the (originally) odd-indexed entries
q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
q2 = q2.reshape(q.shape)
# result
# q  [  q0, q1,  q2, q3,  q4, q5, ...,  q30, q31]
# q2 [ -q1, q0, -q3, q2, -q5, q4, ..., -q31, q30]
# Compare with the formula above: this layout is exactly the vector multiplied by the sin terms.
Updating \(k\) is analogous.
③ Update \(q\), i.e. apply the rotation by multiplying matching positions.
q = q * cos_pos + q2 * sin_pos  # torch.Size([8, 12, 10, 32])
# cos_pos: torch.Size([8, 12, 10, 32]); its layout per token position:
# first token position
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]
# second token position
[0.5403, 0.5403, 0.8460, 0.8460, 0.9504, 0.9504, 0.9842, 0.9842, 0.9950, 0.9950, 0.9984, 0.9984, 0.9995, 0.9995, 0.9998, 0.9998, 0.9999, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]
# third token position
[-0.4161, -0.4161, 0.4315, 0.4315, 0.8066, 0.8066, 0.9374, 0.9374, 0.9801, 0.9801, 0.9937, 0.9937, 0.9980, 0.9980, 0.9994, 0.9994, 0.9998, 0.9998, 0.9999, 0.9999, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]
Updating \(k\) is analogous.
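As a quick numerical sanity check of the relative-position property stated at the beginning (a sketch added here, not part of the original walkthrough; the vectors x, y and the offsets chosen are arbitrary): if the same x sits at every query position and the same y at every key position, the rotated inner products should depend only on the position offset n - m.

# place the same vector at every position, rotate, then compare inner products at equal offsets
torch.manual_seed(0)
d, L = 32, 10
x = torch.randn(d)
y = torch.randn(d)
q = x.expand(1, 1, L, d).clone()  # (bs=1, head=1, max_len=L, output_dim=d)
k = y.expand(1, 1, L, d).clone()
q_rot, k_rot = RoPE(q, k)
scores = torch.matmul(q_rot, k_rot.transpose(-2, -1))[0, 0]      # (L, L)
print(torch.allclose(scores[0, 2], scores[3, 5], atol=1e-5))     # True: both offsets are +2
print(torch.allclose(scores[4, 1], scores[9, 6], atol=1e-5))     # True: both offsets are -3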