首页 > 其他分享 >Transformer模型:intra-attention mask实现

Transformer模型:intra-attention mask实现

时间:2024-07-14 14:57:49浏览次数:11  
标签:src Transformer torch max attention mask pos len valid

前言

        这是对Transformer模型Word Embedding、Postion Embedding、Encoder self-attention mask内容的续篇。

视频链接:20、Transformer模型Decoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili

文章链接:Transformer模型:WordEmbedding实现-CSDN博客 

                  Transformer模型:Postion Embedding实现-CSDN博客

                  Transformer模型:Encoder的self-attention mask实现-CSDN博客


正文

        首先介绍一下intra-attention mask,它指的是Decocder中间的部分:因为目标序列样本间的长度都是不一样的,而原序列间样本间的长度也是不一样的,并且一对之间的长度也是不一样的,所以如果目标序列的某个位置跟原序列之间的某个位置有pad的话就说明是无效的,他们之间的关系是无效的。所以需要得到这么一个掩码矩阵。

        先给出公式:Q*K^T,就是Q矩阵乘以K矩阵的转置,shape:[batch_size,src_seg_len,tgt_seg_len]。这里返照前面构造的valid_encoder_pos,同时还要构造一个valid_decoder_pos:

valid_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2)
valid_decoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_tgt_seg_len - L)), 0) for L in tgt_len]), 2)

tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [0.],
         [0.],
         [0.]]])
tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [0.]]])

        得到以上的结果是因为之前我们定义的src跟tgt长度分别为11、9和10、11,填充到最大长度为12,并且都扩为3维。

        接着我们要得到三维矩阵,需要将valid_decoder_pos跟valid_encoder_pos相乘,注意encoder_pos要转置,第一维跟第二维,第零维为batch_size,不要操作它。

valid_cross_pos = torch.bmm(valid_decoder_pos,valid_encoder_pos.transpose(1,2))

        得到的结果如下所示,valid_cross_pos也是三维的:batch_size * max_tgt_seg_len * max_src_seg_len,第一个样本的第一行就是第一个目标序列的第一个单词,里面的111111111110说明对应的原序列第一个句子前11个是有效的,最后一个是无效的;第一个样本的第11行全0,说明这里是无效的目标序列。类似的,第二个样本的第一行就是第二个目标序列的第一个单词,对应的是原序列的第二个句子。

 

         接下来的操作跟上一篇的差不多了,先得到一个无效的再转化为bool类型,然后填充无效位置为-1e9之后经过softmax得到目标掩码矩阵的概率分布。

invalid_cross_pos_matrix = 1- valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)
mask_score2 = score.masked_fill(mask_cross_attention, -1e9)
prob2 = F.softmax(mask_score2, -1)

 

代码

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# 句子数
batch_size = 2

# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10

# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
max_position_len = 12

# 模型的维度
model_dim = 8

# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)

# 单词索引构成的句子
src_seq = torch.cat(
    [torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_seg_len - L)), 0) for L in src_len])
tgt_seq = torch.cat(
    [torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max_tgt_seg_len - L)), 0) for L in tgt_len])

# Part1:构造Word Embedding
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

# 构造Pos序列跟i序列
pos_mat = torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, 8, 2) / model_dim)

# Part2:构造Position Embedding
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

# 构建位置索引
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in tgt_len]).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

# Part3:构造encoder self-attention mask
valid_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
score = torch.randn(batch_size, max_src_seg_len, max_src_seg_len)
mask_score1 = score.masked_fill(mask_encoder_self_attention, -1e9)
prob1 = F.softmax(mask_score1, -1)

# Part4:构造intra-attention mask
valid_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2)
valid_decoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_tgt_seg_len - L)), 0) for L in tgt_len]), 2)

valid_cross_pos_matrix = torch.bmm(valid_decoder_pos,valid_encoder_pos.transpose(1,2))
invalid_cross_pos_matrix = 1- valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)
mask_score2 = score.masked_fill(mask_cross_attention, -1e9)
prob2 = F.softmax(mask_score2, -1)

        

标签:src,Transformer,torch,max,attention,mask,pos,len,valid
From: https://blog.csdn.net/weixin_62472350/article/details/140408761

相关文章