
04 A PyTorch Implementation of Positional Encoding in the Transformer

Posted: 2022-12-13 13:33:52





Why does the Transformer need positional encoding at all? Self-attention by itself has no notion of token order, so two sentences built from exactly the same tokens look alike to it:

我爱你 ("I love you")

你爱我 ("You love me")

The tokens are identical; only their order differs, yet the meaning is reversed. Position information therefore has to be injected explicitly.
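As a minimal sketch of this order-blindness (the vocabulary and embedding sizes here are made up for illustration): any order-insensitive aggregation of the token embeddings, such as a plain sum, produces exactly the same vector for both sentences.

import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(3, 8)      # hypothetical vocabulary: 0 = 我/I, 1 = 爱/love, 2 = 你/you
s1 = torch.tensor([0, 1, 2])  # 我爱你 "I love you"
s2 = torch.tensor([2, 1, 0])  # 你爱我 "You love me"
print(torch.allclose(emb(s1).sum(dim=0), emb(s2).sum(dim=0)))  # True: indistinguishable

The PositionalEncoding module below injects the missing order information by adding a fixed sinusoidal pattern to each position's embedding.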


import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):

    def __init__(self, dim, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()

        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))

        """
        Build the positional-encoding table pe.
        The formula is:
            PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
            PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        """
        pe = torch.zeros(max_len, dim)  # max_len is the longest sentence the decoder may generate, e.g. 10
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1): the pos index of each row
        # 1/10000^(2i/dim), computed in log space as exp(-2i * ln(10000) / dim)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))

        pe[:, 0::2] = torch.sin(position.float() * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position.float() * div_term)  # odd dimensions
        pe = pe.unsqueeze(1)  # (max_len, 1, dim), so it broadcasts over the batch dimension
        self.register_buffer('pe', pe)  # a buffer, not a parameter: saved with the model but not trained
        self.drop_out = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        # emb: (seq_len, batch, dim); scale the embeddings by sqrt(dim) first
        emb = emb * math.sqrt(self.dim)

        if step is None:
            # full-sequence mode: add the first seq_len rows of the table
            emb = emb + self.pe[:emb.size(0)]
        else:
            # step-wise decoding: add only the encoding of the current position
            emb = emb + self.pe[step]
        emb = self.drop_out(emb)
        return emb
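
A quick sanity check of the module in use (a minimal sketch; the dim, dropout, batch size, and step value below are arbitrary choices for illustration, not values from the original post):

pos_enc = PositionalEncoding(dim=512, dropout=0.1)
x = torch.zeros(10, 2, 512)  # dummy (seq_len=10, batch=2, dim=512) word embeddings
print(pos_enc(x).shape)  # torch.Size([10, 2, 512]) -- full-sequence mode
print(pos_enc(torch.zeros(1, 2, 512), step=3).shape)  # torch.Size([1, 2, 512]) -- step-wise decoding

It is also worth confirming that the log-space div_term really equals the 1/10000^(2i/dim) factor in the formula:

dim = 512
i2 = torch.arange(0, dim, 2, dtype=torch.float)
log_space = torch.exp(i2 * -(math.log(10000.0) / dim))
direct = 1.0 / torch.pow(torch.tensor(10000.0), i2 / dim)
print(torch.allclose(log_space, direct))  # True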



From: https://blog.51cto.com/u_13804357/5933856
