参考自RoPE旋转位置编码深度解析:理论推导、代码实现、长度外推 - 知乎 (zhihu.com)
位置编码: 1.绝对, 直接加到输入中. 2.相对,加在Attn的内积之前, 外推性能强。
ROPE:对Attn的K和V矩阵做ROPE
二维场景:
对于一个二维向量 :
偶数维的可以用拆成若干个2维的向量, 对这些向量分别用ROPE再拼接回原始维度(内积满足线性叠加性)
import torch def apply_rope(input_tensor): bs, seq_len, channels = input_tensor.shape # [bs, 77, 4] # 假设通道为偶数维度,这里c=4 assert channels % 2 == 0 # 对通道进行划分 (c=4 -> 两个二维向量) c1 = input_tensor[:, :, 0::2] # shape [bs, 77, 2] -> (c1, c3) c2 = input_tensor[:, :, 1::2] # shape [bs, 77, 2] -> (c2, c4) # 生成旋转角度 \theta_p positions = torch.arange(seq_len).unsqueeze(1) # [77, 1] theta = positions / (10000 ** (torch.arange(2) / channels)) # 根据RoPE公式计算theta cos_theta = torch.cos(theta).unsqueeze(0) # shape [1, 77, 2] sin_theta = torch.sin(theta).unsqueeze(0) # shape [1, 77, 2] # 对每对通道应用旋转矩阵 c1_prime = c1 * cos_theta - c2 * sin_theta # 旋转后的第一维 c2_prime = c1 * sin_theta + c2 * cos_theta # 旋转后的第二维 # 拼接回原来的通道维度 output_tensor = torch.stack([c1_prime, c2_prime], dim=-1).reshape(bs, seq_len, channels) return output_tensor # 假设输入是 [bs, 77, 4] 的张量 input_tensor = torch.randn(bs, 77, 4) output_tensor = apply_rope(input_tensor)
标签:编码,torch,tensor,位置,旋转,77,bs,theta,input From: https://www.cnblogs.com/alexlord/p/18431000