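1 Rotation angle calculation
The rotation angles are computed as follows, where d is the word-embedding dimension:
\[\theta_j=10000^{-2(j-1)/d},\quad j\in \{1,2,\ldots,d/2\}\]
The elements of each word vector are grouped into adjacent pairs, and each pair gets its own rotation angle. In the code, `theta` is the base (10000) and `freqs` holds the angles \(\theta_j\):
```python
import torch

dim, theta = 128, 10000.0  # example values: embedding dimension d and base
# Rotation angle for each pair of elements; shape: [dim / 2]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
```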
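As a quick check of the formula, the following sketch (the helper name `rotation_angles` is hypothetical) evaluates \(\theta_j\) directly with a 1-based j and compares it with the one-line `freqs` expression from this section. For d = 8 the angles are \(10000^{0}, 10000^{-1/4}, 10000^{-1/2}, 10000^{-3/4}\), i.e. 1, 0.1, 0.01 and 0.001: the first pair rotates fastest and later pairs rotate ever more slowly.

```python
import torch

def rotation_angles(dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Direct transcription of θ_j = theta^(-2(j-1)/d) with 1-based j = 1..d/2
    j = torch.arange(1, dim // 2 + 1, dtype=torch.float64)
    return theta ** (-2.0 * (j - 1) / dim)

dim = 8
angles = rotation_angles(dim)
print(angles)  # approximately [1.0, 0.1, 0.01, 0.001]

# The vectorized form used in this section, computed in float32
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
print(torch.allclose(freqs, angles.float()))  # True
```

The vectorized version avoids an explicit j: `torch.arange(0, dim, 2)` already yields the exponents' numerators 0, 2, ..., d-2, which equal 2(j-1).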
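2 Computing the cos/sin matrix for the whole sequence
For every position m in the sequence and every angle \(\theta_j\), the function computes \(m\theta_j\) and stores \(\cos(m\theta_j)+i\sin(m\theta_j)\) as one complex number, so the cos and sin values are kept together:
```python
import torch

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # Rotation angle for each pair of word-vector elements; shape: [dim // 2]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Token position indices t = [0, 1, ..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # Outer product: freqs[m, j] = m * θ_j; freqs.shape = [seq_len, dim // 2]
    freqs = torch.outer(t, freqs).float()
    # torch.polar(abs, angle) returns abs * (cos(angle) + i*sin(angle)); with abs = 1,
    # if freqs = [[x, ..., y]]
    # then freqs_cis = [[cos(x) + sin(x)i, ..., cos(y) + sin(y)i]]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis
```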
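The section title refers to a cos/sin matrix, while `precompute_freqs_cis` packs both values into a single complex tensor. A real-valued equivalent can be sketched as follows (the name `precompute_cos_sin` is hypothetical, not part of the original code). It also makes explicit that `freqs_cis.real` and `freqs_cis.imag` are exactly \(\cos(m\theta_j)\) and \(\sin(m\theta_j)\).

```python
import torch

def precompute_cos_sin(dim: int, seq_len: int, theta: float = 10000.0):
    # Hypothetical real-valued variant: returns cos(m * θ_j) and sin(m * θ_j),
    # each of shape [seq_len, dim // 2], instead of one complex tensor
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)  # angles[m, j] = m * θ_j
    return torch.cos(angles), torch.sin(angles)

cos, sin = precompute_cos_sin(dim=8, seq_len=4)
print(cos.shape)  # torch.Size([4, 4])
print(cos[0], sin[0])  # position 0: all ones and all zeros, i.e. no rotation

# The same angles packed as unit complex numbers, as torch.polar does above
freqs = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))
angles = torch.outer(torch.arange(4, dtype=torch.float32), freqs)
freqs_cis = torch.polar(torch.ones_like(angles), angles)
print(torch.allclose(freqs_cis.real, cos), torch.allclose(freqs_cis.imag, sin))  # True True
```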
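3 Computing the rotary position embedding
For a query at position m and a key at position n, the rotary position embedding is defined as
\[\begin{aligned}f_q(x_m,m)&=(W_qx_m)e^{im\theta} \\ f_k(x_n,n)&=(W_kx_n)e^{in\theta}\end{aligned}\]
Applying Euler's formula \(e^{im\theta}=\cos m\theta+i\sin m\theta\) and treating the two-dimensional vector \(W_qx_m=(q_m^{(1)},q_m^{(2)})\) as the complex number \(q_m^{(1)}+iq_m^{(2)}\), the formula becomes
\[f_q(x_m,m)=\begin{pmatrix}\cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta\end{pmatrix}\begin{pmatrix}q_m^{(1)}\\ q_m^{(2)}\end{pmatrix}\]
and likewise for \(f_k(x_n,n)\), with \(W_kx_n=(k_n^{(1)},k_n^{(2)})\) and angle \(n\theta\).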
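In the d-dimensional case, each adjacent pair of elements is multiplied by its own factor \(e^{im\theta_j}\). The sketch below assumes the complex representation produced by `precompute_freqs_cis`; the helper names `make_freqs_cis` and `apply_rotary` are hypothetical. It applies \(f_q(x_m,m)=(W_qx_m)e^{im\theta}\) by complex multiplication and checks the property that motivates RoPE: the dot product between a query rotated to position m and a key rotated to position n depends only on the offset m - n.

```python
import torch

def make_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Same computation as precompute_freqs_cis: unit complex numbers e^{i m θ_j}
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), freqs)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rotary(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [seq_len, dim]; elements (2j, 2j+1) are viewed as one complex number
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # Multiplying by e^{i m θ_j} rotates each pair by the angle m * θ_j
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)

torch.manual_seed(0)
dim, seq_len = 8, 16
freqs_cis = make_freqs_cis(dim, seq_len)
q, k = torch.randn(dim), torch.randn(dim)  # stand-ins for W_q x and W_k x

# Place the same q and k at every position, then rotate each copy by its position
q_rot = apply_rotary(q.repeat(seq_len, 1), freqs_cis)
k_rot = apply_rotary(k.repeat(seq_len, 1), freqs_cis)
scores = q_rot @ k_rot.T  # scores[m, n]: q rotated to m dotted with k rotated to n

print(torch.allclose(scores[3, 1], scores[10, 8], atol=1e-4))  # True: same offset m - n = 2
print(torch.allclose(scores[5, 5], q @ k, atol=1e-4))  # True: rotations cancel when m == n
```

Because every factor \(e^{im\theta_j}\) has modulus 1, the rotation also leaves vector norms unchanged; only the relative angle between query and key pairs changes with position.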
From: https://www.cnblogs.com/liangyming/p/17816131.html