导包
#导入包
import torch
from torch import nn
import torch.nn.functional as f
import math
TokenEmbedding
#首先定义token embadding
from torch import Tensor
"""
将输入词汇表的索引转换成指定维度的Embedding
"""
class TokenEmbedding(nn.Embedding):
def __init__(self,vocab_size,d_model):
"""
初始化TokenEmbedding类。
参数:
vocab_size (int): 词汇表的大小。
d_model (int): Embedding的维度。
注意:
此类自动将索引为1的词汇视为填充词,并将其嵌入向量初始化为全零。
如果你不希望这样,可以手动设置padding_idx参数,或者将其设置为None。
"""
super(TokenEmbedding,self).__init__(vocab_size,d_model,padding_idx=1)
PositionalEmbedding
class PositionalEmbedding(nn.Module):
def __init__(self,d_model,max_len,device):
"""
初始化位置矩阵
"""
super(PositionalEmbedding,self).__init__()
#初始化0矩阵
self.encoding = torch.zeros(max_len,d_model,device=device)
#位置编码不需要优化,就不需要梯度更新
self.encoding.requires_grad = False
#定义pos,生成位置索引
pos = torch.arange(0,max_len)
pos = pos.to(device)
#类型转换为浮点型便于计算,在进行维度拓展为二维张量,利用广播机制自动对其
pos = pos.float().unsqueeze(dim=1)
#根据公式计算
frequencies_indices = torch.arange(0, d_model, step=2, device=device).float()
frequencies = 1.0/torch.pow(10000.0,frequencies_indices//d_model).unsqueeze(dim=0)
self.encoding[:,0::2] = torch.sin(pos*frequencies)
self.encoding[:,1::2] = torch.cos(pos*frequencies)
def forward(self,x):
#获取批量大小和序列长度
batch_size,seq_len = x.size()
return self.encoding[:seq_len,:]
TransformerEmbedding
class TransformerEmbedding(nn.Module):
def __init__(self,vocab_size,d_model,max_len,drop_prob,device):
super(TransformerEmbedding,self).__init__()
self.tok_emb = TokenEmbedding(vocab_size-vocab_size,d_model=d_model)
self.pos_emb = PositionalEmbedding(d_model=d_model,max_len=max_len,device=device)
self.drop_out=nn.Dropout(p=drop_prob)
def forward(self,x):
tok_emb = self.tok_emb(x)
pos_emb = self.pos_emb(x)
return self.drop_out(tok_emb+pos_emb)