Attention implementation:
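The code below first defines a masked softmax, then two standard scoring functions. For a query $\mathbf{q}$ and a key $\mathbf{k}$ they compute

$$a_{\text{add}}(\mathbf{q},\mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q\mathbf{q} + \mathbf{W}_k\mathbf{k}), \qquad a_{\text{dot}}(\mathbf{q},\mathbf{k}) = \frac{\mathbf{q}^\top\mathbf{k}}{\sqrt{d}},$$

and the attention output is the softmax of these scores (restricted to the valid key positions) used as weights over the values.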
import math
import torch
from torch import nn
import matplotlib.pyplot as plt
from d2l import torch as d2l


def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    max_len = X.size(1)
    mask = torch.arange(max_len, dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X


def masked_softmax(x, valid_lens):
    """Perform softmax along the last axis, masking elements beyond valid_lens."""
    if valid_lens is None:
        return nn.functional.softmax(x, dim=-1)
    else:
        shape = x.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Replace masked elements on the last axis with a very large negative
        # value, so that their softmax output becomes 0
        x = sequence_mask(x.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(x.reshape(shape), dim=-1)


print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))


# Additive attention:
class AdditiveAttention(nn.Module):
    """Additive attention."""
    def __init__(self, key_size, query_size, num_hidden, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.w_k = nn.Linear(key_size, num_hidden, bias=False)
        self.w_q = nn.Linear(query_size, num_hidden, bias=False)
        self.w_v = nn.Linear(num_hidden, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.w_q(queries), self.w_k(keys)
        # Broadcast so that every query is paired with every key:
        # (batch_size, num_queries, num_kv_pairs, num_hidden)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # Shape of scores: (batch_size, num_queries, num_kv_pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
print("queries:")
print(queries)
print("keys:")
print(keys)
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
print("values:")
print(values)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hidden=8,
                              dropout=0.1)
attention.eval()
print(attention(queries, keys, values, valid_lens))
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
plt.show()


# Scaled dot-product attention:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, num_queries, d)
    # Shape of keys: (batch_size, num_kv_pairs, d)
    # Shape of values: (batch_size, num_kv_pairs, value_dim)
    # Shape of valid_lens: (batch_size,) or (batch_size, num_queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
print(attention(queries, keys, values, valid_lens))
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
plt.show()
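As a quick sanity check, the same masked, scaled dot-product attention can also be reproduced with PyTorch's built-in kernel. This is a minimal sketch, assuming PyTorch >= 2.0, where torch.nn.functional.scaled_dot_product_attention is available; the only difference from DotProductAttention above is that masked positions get -inf instead of -1e6, which does not change the result noticeably.

# Sanity check: reproduce DotProductAttention (eval mode, so dropout is a no-op)
# with PyTorch's built-in kernel. Assumes PyTorch >= 2.0.
import math
import torch
import torch.nn.functional as F

queries = torch.normal(0, 1, (2, 1, 2))   # (batch, num_queries, d)
keys = torch.ones((2, 10, 2))             # (batch, num_kv_pairs, d)
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

# Manual computation, mirroring DotProductAttention above (mask with -1e6).
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(queries.shape[-1])
mask = torch.arange(keys.shape[1])[None, None, :] < valid_lens[:, None, None]  # (2, 1, 10)
manual = torch.bmm(torch.softmax(scores.masked_fill(~mask, -1e6), dim=-1), values)

# Built-in kernel: a boolean attn_mask where True means "may attend".
builtin = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)

print(torch.allclose(manual, builtin, atol=1e-4))  # expected: True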
Seq2seq after adding attention:
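At decoding time step $t'$, the decoder uses its top-layer hidden state from the previous step as the query, and the encoder outputs as both keys and values, to form a context vector

$$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\,\mathbf{h}_t,$$

which is concatenated with the current target embedding before being fed into the GRU. The forward loop in the decoder below does exactly this, one time step at a time.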
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt


# @save
class AttentionDecoder(d2l.Decoder):
    """The base interface for decoders with an attention mechanism."""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError


class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens,
                                               num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # Shape of outputs: (batch_size, num_steps, num_hiddens)
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, x, state):
        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens)
        # Shape of hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # Shape of the embedded input: (num_steps, batch_size, embed_size)
        x = self.embedding(x).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x_ in x:
            # Shape of query: (batch_size, 1, num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # Shape of context: (batch_size, 1, num_hiddens)
            context = self.attention(query, enc_outputs, enc_outputs,
                                     enc_valid_lens)
            # Concatenate context and embedding on the feature dimension
            x_ = torch.cat((context, torch.unsqueeze(x_, dim=1)), dim=-1)
            # Reshape x_ to (1, batch_size, embed_size + num_hiddens)
            out, hidden_state = self.rnn(x_.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # After the fully connected layer, the shape of outputs is
        # (num_steps, batch_size, vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights


# A quick shape check with toy inputs
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                  num_layers=2)
decoder.eval()
x = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(x), None)
output, state = decoder(x, state)
print(output.shape, len(state), state[0].shape, len(state[1]),
      state[1][0].shape)

# Training on the English-French dataset
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens,
                             num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens,
                                  num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
plt.show()

# Translate a few sentences and report BLEU
engs = ['go .', "i lost .", "he's calm .", "i'm home ."]
fras = ['va !', "j'ai perdu .", 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
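To see which source tokens the decoder attends to, the per-step attention weights collected during prediction can be stitched together and plotted. This is a sketch that runs right after the loop above; it assumes the nesting of dec_attention_weight_seq returned by d2l.predict_seq2seq matches the d2l book, where each step holds the additive attention weights of shape (1, 1, num_steps).

# Visualize the attention weights for the last translated sentence.
# step[0][0][0] picks out the (num_steps,) weight row for one decoding step.
attention_weights = torch.cat(
    [step[0][0][0] for step in dec_attention_weight_seq], 0).reshape(
    (1, 1, -1, num_steps))
# Keep only the key positions covered by the source tokens plus <eos>.
d2l.show_heatmaps(
    attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
    xlabel='Key positions', ylabel='Query positions')
plt.show()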