title: Transformer源码
date: 2022-10-09 18:30:35
mathjax: true
tags:
- Encoder
- Decoder
- Transformer
Transformer代码(源码Pytorch版本)从零解读(Pytorch版本)_哔哩哔哩_bilibili
执行流程
1. 构建语料库
# Toy parallel corpus: [encoder input, decoder input ('S' prefix), decoder target ('E' suffix)].
# 'P' is the padding token; it has id 0 in both vocabularies, which is what the pad masks test for.
sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']
src_vocab = {word: idx for idx, word in enumerate(['P', 'ich', 'mochte', 'ein', 'bier'])}
tgt_vocab = {word: idx for idx, word in enumerate(['P', 'i', 'want', 'a', 'beer', 'S', 'E'])}
2. 生成批量数据
def make_batch(sentences):
    """Convert the first three corpus sentences into (1, len) LongTensors of ids.

    Returns (encoder input, decoder input, decoder target). The encoder input
    is looked up in the global ``src_vocab``; the other two use ``tgt_vocab``.
    """
    def as_ids(text, vocab):
        # One-sentence batch: wrap the id list in an outer list.
        return [[vocab[token] for token in text.split()]]

    src_text, dec_text, tgt_text = sentences[0], sentences[1], sentences[2]
    enc_ids = as_ids(src_text, src_vocab)
    dec_ids = as_ids(dec_text, tgt_vocab)
    tgt_ids = as_ids(tgt_text, tgt_vocab)
    return torch.LongTensor(enc_ids), torch.LongTensor(dec_ids), torch.LongTensor(tgt_ids)
3. Encoder部分
3.1 Embedding
为什么输出是(1,5,512)?
根据官网API可以知道,nn.Embedding的输出形状是在输入的shape后面再接上embedding_dim:这里输入为(1,5)(1个句子、5个词),embedding_dim=d_model=512,所以输出为(1,5,512)。
3.2 PositionalEncoding(位置编码)
# The embedding output is batch-first (batch, seq, d_model), but PositionalEncoding
# slices its table along dim 0, so swap to (seq, batch, d_model) and swap back afterwards.
enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input, then apply dropout.

    Even feature columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i),
    with w_i = 10000^(-2i / d_model). The table is stored as a buffer of shape
    (max_len, 1, d_model), so ``forward`` expects x as (seq_len, batch, d_model).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * freqs  # (max_len, ceil(d_model / 2))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Insert a size-1 batch axis so the table broadcasts over the batch dimension.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len, :])
3.3 将那些pad标记出来
为了把句子转换成矩阵、方便批量处理,要求所有句子的长度都一样:长度不够的用pad(这里是'P',索引为0)填充,长度多了就截断。填充位置没有实际含义,不应参与注意力计算,所以要把它们标记出来。
# Encoder self-attention: queries and keys are both the source ids, so mask padded source keys.
enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
def get_attn_pad_mask(seq_q, seq_k):
    """Build a padding mask for attention from seq_q (queries) to seq_k (keys).

    Both arguments are 2-D id tensors (batch, len). The result is a bool tensor
    of shape (batch, len_q, len_k) that is True wherever the *key* id is 0 (pad).
    """
    _, len_q = seq_q.size()
    bsz, len_k = seq_k.size()
    key_is_pad = seq_k.eq(0)[:, None, :]  # (batch, 1, len_k)
    return key_is_pad.expand(bsz, len_q, len_k)
3.4 多头注意力
# Encoder.forward: run one EncoderLayer; every layer reuses the same padding mask
enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
# EncoderLayer: self-attention, so Q, K and V are all the same input
self.enc_self_attn = MultiHeadAttention()
enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    Reads the module-level hyper-parameters d_model, d_k, d_v and n_heads.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # Projections are created in this order so parameter initialisation is unchanged.
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    @staticmethod
    def _split_heads(proj, x, bsz, head_dim):
        """Project x, then reshape (bsz, len, n_heads*head_dim) -> (bsz, n_heads, len, head_dim)."""
        return proj(x).view(bsz, -1, n_heads, head_dim).transpose(1, 2)

    def forward(self, Q, K, V, attn_mask):
        bsz = Q.size(0)
        heads_q = self._split_heads(self.W_Q, Q, bsz, d_k)
        heads_k = self._split_heads(self.W_K, K, bsz, d_k)
        heads_v = self._split_heads(self.W_V, V, bsz, d_v)
        # Copy the (bsz, len_q, len_k) mask once per head.
        head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context, attn = ScaledDotProductAttention()(heads_q, heads_k, heads_v, head_mask)
        # transpose leaves a non-contiguous tensor; view needs contiguous memory.
        merged = context.transpose(1, 2).contiguous().view(bsz, -1, n_heads * d_v)
        return self.layer_norm(self.linear(merged) + Q), attn
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    The scale is taken from Q's last dimension rather than the module-level
    ``d_k`` global, so the layer also works on its own and for any head size.
    """

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """Attend from Q to K/V.

        Args:
            Q: (..., len_q, d_k) queries.
            K: (..., len_k, d_k) keys.
            V: (..., len_k, d_v) values.
            attn_mask: bool mask broadcastable to (..., len_q, len_k); True marks
                positions that must not be attended to. ``None`` disables masking.

        Returns:
            (context, attn): context is (..., len_q, d_v), and attn holds the
            (..., len_q, len_k) attention weights.
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))
        if attn_mask is not None:
            # Use a large negative value instead of -inf so a fully masked row stays finite.
            scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn
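- repeat:沿各维度复制张量。这里先用unsqueeze(1)把mask从(batch_size, len_q, len_k)变成(batch_size, 1, len_q, len_k),再在第1维复制n_heads份,得到(batch_size, n_heads, len_q, len_k),让每个头都使用同一个mask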
- contiguous
contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形。
transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,而view操作要求tensor的内存连续存储,所以需要contiguous来返回一个contiguous copy;
Pytorch之contiguous函数 - 知乎 (zhihu.com)
3.5 FFN
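这里的FFN使用的是kernel_size=1的一维卷积:它对每个位置单独做同样的线性变换,等价于逐位置的全连接层(nn.Linear)。注意Conv1d要求输入为(batch, channels, length),所以前后要transpose(1, 2)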
# EncoderLayer: pass the attention output through the position-wise feed-forward net
enc_outputs = self.pos_ffn(enc_outputs)
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise FFN built from two kernel-size-1 Conv1d layers, plus residual and LayerNorm.

    Reads the module-level d_model and d_ff. Input/output shape: (batch, len, d_model).
    """

    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        # Conv1d works on (batch, channels, length), so move features to dim 1 and back.
        channels_first = inputs.transpose(1, 2)
        hidden = nn.functional.relu(self.conv1(channels_first))
        projected = self.conv2(hidden).transpose(1, 2)
        return self.layer_norm(projected + inputs)
3.6 Encoder代码
class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention followed by the position-wise FFN."""

    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        # Self-attention: the same tensor serves as query, key and value.
        attended, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        return self.pos_ffn(attended), attn
class Encoder(nn.Module):
    """Embedding + positional encoding followed by n_layers stacked EncoderLayers.

    Returns the final hidden states and a list with each layer's attention weights.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])  # N x EncoderLayer

    def forward(self, enc_inputs):
        embedded = self.src_emb(enc_inputs)
        # PositionalEncoding indexes its table along dim 0: go sequence-first and back.
        hidden = self.pos_emb(embedded.transpose(0, 1)).transpose(0, 1)
        pad_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        attn_maps = []
        for layer in self.layers:
            hidden, layer_attn = layer(hidden, pad_mask)
            attn_maps.append(layer_attn)
        return hidden, attn_maps
# Top-level model: encode the source ids; enc_self_attns is the list of per-layer attention weights
enc_outputs, enc_self_attns = self.encoder(enc_inputs)
4. Decoder部分
4.1 Embedding
# Decoder: embedding table for target-side ids (tgt_vocab_size rows of d_model features)
self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
dec_outputs = self.tgt_emb(dec_inputs)
4.2 PositionalEncoding
# Same sinusoidal encoding as the encoder; the transposes make the input sequence-first
# because PositionalEncoding slices its table along dim 0.
self.pos_emb = PositionalEncoding(d_model)
dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1)
4.3 标记dec_inputs的pad
# Mask padded (id 0) positions among the decoder's own keys
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
def get_attn_pad_mask(seq_q, seq_k):
    """Build a padding mask for attention from seq_q (queries) to seq_k (keys).

    Both arguments are 2-D id tensors (batch, len). The result is a bool tensor
    of shape (batch, len_q, len_k) that is True wherever the *key* id is 0 (pad).
    """
    _, len_q = seq_q.size()
    bsz, len_k = seq_k.size()
    key_is_pad = seq_k.eq(0)[:, None, :]  # (batch, 1, len_k)
    return key_is_pad.expand(bsz, len_q, len_k)
4.4 生成上三角
# Look-ahead mask: 1 strictly above the diagonal, so position i cannot attend to positions after i
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
def get_attn_subsequent_mask(seq):
    """Return a uint8 look-ahead mask of shape (batch, len, len) for a (batch, len) tensor.

    Entry [b, i, j] is 1 when j > i (a future position) and 0 otherwise.
    """
    bsz, length = seq.size(0), seq.size(1)
    upper = np.triu(np.ones((bsz, length, length)), k=1)
    return torch.from_numpy(upper).byte()
- triu:np.triu(m, k=1)只保留主对角线以上(不含对角线)的元素,其余置0。得到的上三角矩阵中为1的位置对应"未来"的词,解码时要把它们遮挡掉,保证第i个位置只能看到它自己和之前的词
4.5 两矩阵相加(合并pad mask与上三角mask,相加后大于0的位置都要遮挡)
# Combine both masks: a key is blocked if it is padding OR a future position (sum > 0 -> True)
dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
4.6 标记enc_inputs的pad
# Encoder-decoder attention: queries come from the decoder and keys from the encoder,
# so this masks the padded positions of the source sentence.
dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
4.7 Masked Multi-Head attention
# Decoder.forward: run one DecoderLayer with the self-attention mask and the encoder-decoder mask
dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
# DecoderLayer: masked self-attention over the decoder inputs (Q = K = V)
self.dec_self_attn = MultiHeadAttention()
dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    Reads the module-level hyper-parameters d_model, d_k, d_v and n_heads.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # Projections are created in this order so parameter initialisation is unchanged.
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    @staticmethod
    def _split_heads(proj, x, bsz, head_dim):
        """Project x, then reshape (bsz, len, n_heads*head_dim) -> (bsz, n_heads, len, head_dim)."""
        return proj(x).view(bsz, -1, n_heads, head_dim).transpose(1, 2)

    def forward(self, Q, K, V, attn_mask):
        bsz = Q.size(0)
        heads_q = self._split_heads(self.W_Q, Q, bsz, d_k)
        heads_k = self._split_heads(self.W_K, K, bsz, d_k)
        heads_v = self._split_heads(self.W_V, V, bsz, d_v)
        # Copy the (bsz, len_q, len_k) mask once per head.
        head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context, attn = ScaledDotProductAttention()(heads_q, heads_k, heads_v, head_mask)
        # transpose leaves a non-contiguous tensor; view needs contiguous memory.
        merged = context.transpose(1, 2).contiguous().view(bsz, -1, n_heads * d_v)
        return self.layer_norm(self.linear(merged) + Q), attn
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    The scale is taken from Q's last dimension rather than the module-level
    ``d_k`` global, so the layer also works on its own and for any head size.
    """

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """Attend from Q to K/V.

        Args:
            Q: (..., len_q, d_k) queries.
            K: (..., len_k, d_k) keys.
            V: (..., len_k, d_v) values.
            attn_mask: bool mask broadcastable to (..., len_q, len_k); True marks
                positions that must not be attended to. ``None`` disables masking.

        Returns:
            (context, attn): context is (..., len_q, d_v), and attn holds the
            (..., len_q, len_k) attention weights.
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))
        if attn_mask is not None:
            # Use a large negative value instead of -inf so a fully masked row stays finite.
            scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn
4.8 多头注意力机制
# DecoderLayer: cross-attention with query = decoder states and key/value = encoder outputs
dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    Reads the module-level hyper-parameters d_model, d_k, d_v and n_heads.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        # Projections are created in this order so parameter initialisation is unchanged.
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    @staticmethod
    def _split_heads(proj, x, bsz, head_dim):
        """Project x, then reshape (bsz, len, n_heads*head_dim) -> (bsz, n_heads, len, head_dim)."""
        return proj(x).view(bsz, -1, n_heads, head_dim).transpose(1, 2)

    def forward(self, Q, K, V, attn_mask):
        bsz = Q.size(0)
        heads_q = self._split_heads(self.W_Q, Q, bsz, d_k)
        heads_k = self._split_heads(self.W_K, K, bsz, d_k)
        heads_v = self._split_heads(self.W_V, V, bsz, d_v)
        # Copy the (bsz, len_q, len_k) mask once per head.
        head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context, attn = ScaledDotProductAttention()(heads_q, heads_k, heads_v, head_mask)
        # transpose leaves a non-contiguous tensor; view needs contiguous memory.
        merged = context.transpose(1, 2).contiguous().view(bsz, -1, n_heads * d_v)
        return self.layer_norm(self.linear(merged) + Q), attn
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    The scale is taken from Q's last dimension rather than the module-level
    ``d_k`` global, so the layer also works on its own and for any head size.
    """

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        """Attend from Q to K/V.

        Args:
            Q: (..., len_q, d_k) queries.
            K: (..., len_k, d_k) keys.
            V: (..., len_k, d_v) values.
            attn_mask: bool mask broadcastable to (..., len_q, len_k); True marks
                positions that must not be attended to. ``None`` disables masking.

        Returns:
            (context, attn): context is (..., len_q, d_v), and attn holds the
            (..., len_q, len_k) attention weights.
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))
        if attn_mask is not None:
            # Use a large negative value instead of -inf so a fully masked row stays finite.
            scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn
4.9 FFN
# DecoderLayer: position-wise feed-forward net applied to the cross-attention output
dec_outputs = self.pos_ffn(dec_outputs)
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise FFN built from two kernel-size-1 Conv1d layers, plus residual and LayerNorm.

    Reads the module-level d_model and d_ff. Input/output shape: (batch, len, d_model).
    """

    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        # Conv1d works on (batch, channels, length), so move features to dim 1 and back.
        channels_first = inputs.transpose(1, 2)
        hidden = nn.functional.relu(self.conv1(channels_first))
        projected = self.conv2(hidden).transpose(1, 2)
        return self.layer_norm(projected + inputs)
5. 线性层
# Final linear layer: map each d_model feature vector to one logit per target-vocabulary word
self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
dec_logits = self.projection(dec_outputs)
# Flatten to (batch * tgt_len, tgt_vocab_size) for CrossEntropyLoss; the full model returns
# this value (in this excerpt the result of view is not assigned).
dec_logits.view(-1, dec_logits.size(-1))
6. 整体流程图
7. 简化版
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import math
# Toy corpus and vocabularies; 'P' (id 0) is padding.
sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4}
src_vocab_size = len(src_vocab)
tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'S': 5, 'E': 6}
tgt_vocab_size = len(tgt_vocab)
# id -> word lookup for decoding predictions (dicts preserve insertion order).
idx_to_vocab = [word for word in tgt_vocab]
print(idx_to_vocab)


def _sentence_to_ids(sentence, vocab):
    """Map a space-separated sentence to a (1, len) LongTensor of vocabulary ids."""
    return torch.LongTensor([[vocab[word] for word in sentence.split(' ')]])


print(sentences[0])
enc_inputs = _sentence_to_ids(sentences[0], src_vocab)
enc_inputs_len = enc_inputs.shape[1]
print(enc_inputs_len)
dec_inputs = _sentence_to_ids(sentences[1], tgt_vocab)
dec_inputs_len = dec_inputs.shape[1]
tgt_outputs = _sentence_to_ids(sentences[2], tgt_vocab)
tgt_outputs_len = tgt_outputs.shape[1]
print(tgt_outputs)
class FFN(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Linear, plus residual and LayerNorm.

    Reads the module-level d_model and d_ff.
    """

    def __init__(self):
        super(FFN, self).__init__()
        layers = [nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)]
        self.module = nn.Sequential(*layers)
        self.layerNorm = nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        transformed = self.module(x)
        return self.layerNorm(transformed + residual)
def get_attn_pad_mask(q_k, q_v):
    """Build a padding mask for attention from the query ids q_k to the key ids q_v.

    Args:
        q_k: (batch, len_q) query-side ids.
        q_v: (batch, len_k) key-side ids; positions equal to 0 are padding.

    Returns:
        bool tensor (batch, len_q, len_k), True where the *key* is padding.
        The mask must come from the key sequence: for encoder-decoder attention
        (q_k = decoder ids, q_v = encoder ids) it is the source padding that
        has to be hidden.
    """
    bsz, len_q = q_k.shape
    _, len_k = q_v.shape
    # (batch, 1, len_k) broadcasts cleanly to (batch, len_q, len_k) for any lengths.
    key_is_pad = q_v.eq(0).unsqueeze(1)
    return key_is_pad.expand(bsz, len_q, len_k)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even feature columns hold sin(pos * w_i) and odd columns cos(pos * w_i),
    with w_i = 10000^(-2i / d_model). The table is stored as a (1, max_len,
    d_model) buffer, so ``forward`` expects x as (batch, seq_len, d_model).
    """

    def __init__(self, d_model, max_len=500):
        super(PositionalEncoding, self).__init__()
        # Every entry is overwritten below, so start from zeros; torch.randn here
        # was misleading and consumed global RNG state for nothing.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]
class ScaleDotProduction(nn.Module):
    """Compute masked, scaled attention scores Q K^T / sqrt(d_k) (softmax is applied by the caller).

    The scale comes from Q's last dimension instead of the module-level ``d_k``
    global, so the layer works on its own and for any head size.
    """

    def __init__(self):
        super(ScaleDotProduction, self).__init__()

    def forward(self, Q, K, V, attn):
        """Return (..., len_q, len_k) scores.

        ``attn`` is a bool mask broadcastable to the scores (True = blocked), or
        ``None`` for no masking. V is accepted for interface symmetry but unused.
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(Q.size(-1))
        if attn is not None:
            # Large negative value rather than -inf keeps fully-masked rows finite after softmax.
            scores = scores.masked_fill(attn, -1e9)
        return scores
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and LayerNorm.

    Reads the module-level hyper-parameters ``d_model``, ``heads``, ``d_k`` and
    ``d_v``. Batch size and sequence lengths are taken from the inputs, so the
    same module works for encoder self-attention, decoder self-attention and
    encoder-decoder attention.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, heads * d_k)
        self.W_K = nn.Linear(d_model, heads * d_k)
        self.W_V = nn.Linear(d_model, heads * d_v)
        self.attention = ScaleDotProduction()
        self.linear = nn.Linear(heads * d_v, d_model)
        self.layerNorm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn=None):
        """Attend from Q to K/V.

        Args:
            Q: (batch, len_q, d_model) queries; also used as the residual.
            K: (batch, len_k, d_model) keys.
            V: (batch, len_k, d_model) values.
            attn: bool mask (batch, len_q, len_k), True where attention is
                forbidden. ``None`` (default) means no masking; the former
                default ``1`` was not a valid mask and always failed.

        Returns:
            (output, attn_weights): output is (batch, len_q, d_model) and
            attn_weights is (batch, heads, len_q, len_k).
        """
        residual = Q
        # Use the actual input sizes, not the global batch_size / enc_inputs_len.
        bsz, len_q, len_k = Q.size(0), Q.size(1), K.size(1)
        q_s = self.W_Q(Q).view(bsz, -1, heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(bsz, -1, heads, d_k).transpose(1, 2)
        # Values are split with the value head size d_v.
        v_s = self.W_V(V).view(bsz, -1, heads, d_v).transpose(1, 2)
        if attn is None:
            mask = torch.zeros(bsz, 1, len_q, len_k, dtype=torch.bool, device=Q.device)
        else:
            # (batch, 1, len_q, len_k) broadcasts over the head axis for any batch size.
            mask = attn.unsqueeze(1)
        scores = self.attention(q_s, k_s, v_s, mask)
        attn_weights = nn.Softmax(dim=-1)(scores)
        contexts = torch.matmul(attn_weights, v_s).transpose(1, 2).contiguous().view(bsz, len_q, heads * d_v)
        contexts = self.linear(contexts)
        return self.layerNorm(residual + contexts), attn_weights
class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention followed by the FFN sublayer."""

    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.mul = MultiHeadAttention()
        self.ffn = FFN()

    def forward(self, enc_inputs, enc_attn_mask):
        # Self-attention: query, key and value are the same tensor.
        attended, attn = self.mul(enc_inputs, enc_inputs, enc_inputs, enc_attn_mask)
        return self.ffn(attended), attn
class Encoder(nn.Module):
    """Embedding + positional encoding followed by n_layers EncoderLayers.

    Returns the final hidden states and a list with each layer's attention weights.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.emb = nn.Embedding(src_vocab_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        hidden = self.pe(self.emb(enc_inputs))
        pad_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        attn_maps = []
        for layer in self.layers:
            hidden, layer_attn = layer(hidden, pad_mask)
            attn_maps.append(layer_attn)
        return hidden, attn_maps
def get_attn_subsequence_mask(seq):
    """Return a float64 NumPy look-ahead mask (batch, len, len): 1.0 where j > i, else 0.0."""
    bsz, length = seq.size(0), seq.size(1)
    return np.triu(np.ones((bsz, length, length)), k=1)
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, then the FFN.

    The two attention sublayers need separate weights. The previous version
    called the single ``self.mul`` for both, so self-attention and
    cross-attention shared parameters.
    """

    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.mul = MultiHeadAttention()      # masked self-attention over decoder states
        self.enc_mul = MultiHeadAttention()  # cross-attention: query = decoder, key/value = encoder
        self.ffn = FFN()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        dec_outputs, dec_self_attn = self.mul(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs, dec_enc_attn = self.enc_mul(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn
class Decoder(nn.Module):
    """Embedding + positional encoding followed by n_layers DecoderLayers.

    Returns the final hidden states and two lists with each layer's
    self-attention and encoder-decoder attention weights.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pe = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        hidden = self.pe(self.emb(dec_inputs))
        # Self-attention blocks padded keys and future positions (sum > 0 -> True).
        pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
        look_ahead = get_attn_subsequence_mask(dec_inputs)
        self_mask = torch.gt((pad_mask + look_ahead), 0)
        cross_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
        self_attns, cross_attns = [], []
        for layer in self.layers:
            hidden, self_attn, cross_attn = layer(hidden, enc_outputs, self_mask, cross_mask)
            self_attns.append(self_attn)
            cross_attns.append(cross_attn)
        return hidden, self_attns, cross_attns
class Transformer(nn.Module):
    """Encoder-decoder model whose output is flattened vocabulary logits.

    ``forward`` returns (logits of shape (batch * tgt_len, tgt_vocab_size),
    encoder self-attentions, decoder self-attentions, decoder-encoder attentions).
    """

    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)

    def forward(self, enc_inputs, dec_inputs):
        memory, enc_attns = self.encoder(enc_inputs)
        hidden, dec_attns, cross_attns = self.decoder(dec_inputs, enc_inputs, memory)
        logits = self.projection(hidden)
        flat_logits = logits.view(-1, logits.size(-1))
        return flat_logits, enc_attns, dec_attns, cross_attns
def make_batch(sentences):
    """Convert the first three corpus sentences into (1, len) LongTensors of ids.

    Returns (encoder input, decoder input, decoder target). The encoder input
    is looked up in the global ``src_vocab``; the other two use ``tgt_vocab``.
    """
    def as_ids(text, vocab):
        # One-sentence batch: wrap the id list in an outer list.
        return [[vocab[token] for token in text.split()]]

    src_text, dec_text, tgt_text = sentences[0], sentences[1], sentences[2]
    enc_ids = as_ids(src_text, src_vocab)
    dec_ids = as_ids(dec_text, tgt_vocab)
    tgt_ids = as_ids(tgt_text, tgt_vocab)
    return torch.LongTensor(enc_ids), torch.LongTensor(dec_ids), torch.LongTensor(tgt_ids)
# ---- Training script for the simplified model ----
# The classes above read these hyper-parameters as module-level globals,
# so they must be defined before Transformer() is instantiated.
enc_inputs, dec_inputs, target_batch = make_batch(sentences)
batch_size = 1
d_model = 512  # embedding / hidden size
heads = 8  # number of attention heads
d_k = d_v = 64  # per-head query/key and value sizes
n_layers = 6  # encoder and decoder depth
d_ff = 2048  # inner size of the feed-forward sublayer
epochs = 30
model = Transformer()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(epochs):
    optimizer.zero_grad()
    pro_logits, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
    # Logits are (batch * tgt_len, tgt_vocab_size); flatten the targets from make_batch to match.
    loss = criterion(pro_logits, target_batch.contiguous().view(-1))
    print(f'epoch {epoch + 1:03d} loss = {loss.item():.6f}')
    loss.backward()
    optimizer.step()
# Re-run the model so the prediction reflects the weights after the final optimizer step.
with torch.no_grad():
    pro_logits, _, _, _ = model(enc_inputs, dec_inputs)
outputs = torch.argmax(pro_logits, -1)
tgt_output = [idx_to_vocab[key] for key in outputs.tolist()]
# Dedicated name: the former `str = ' '` shadowed the builtin str for the rest of the module.
separator = ' '
print(separator.join(tgt_output))
标签:inputs,Transformer,enc,outputs,self,源码,attn,dec
From: https://www.cnblogs.com/bzwww/p/16805783.html