首页 > 其他分享 >注意力机制【5】Scaled Dot-Product Attention 和 mask

注意力机制【5】Scaled Dot-Product Attention 和 mask

时间:2022-09-30 08:33:13浏览次数:48  
标签:dim Product sentence lengths self Attention mask hidden

Scaled Dot-Product Attention 

在实际应用中,经常会用到 Attention 机制,其中最常用的是 Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。

  • Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
  • Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
  • Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.

mask

上代码吧

import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F


class Attention_Layer(nn.Module):
    # 用来实现mask-attention layer
    def __init__(self, input_size, hidden_dim):
        super(Attention_Layer, self).__init__()

        self.hidden_dim = hidden_dim
        self.Q_linear = nn.Linear(input_size, hidden_dim, bias=False)
        self.K_linear = nn.Linear(input_size, hidden_dim, bias=False)
        self.V_linear = nn.Linear(input_size, hidden_dim, bias=False)

    def forward(self, inputs, lens):
        size = inputs.size()        # [b h w]   h代表词总量,w代表每个词的编码长度
        # 计算生成QKV矩阵
        Q = self.Q_linear(inputs)   # [b h hidden_dim]
        K = self.K_linear(inputs).permute(0, 2, 1)  # # [b hidden_dim h]
        V = self.V_linear(inputs)   # [b h hidden_dim]

        # 还要计算生成mask矩阵
        max_len = max(lens)  # 最大的句子长度,生成mask矩阵
        sentence_lengths = torch.Tensor(lens)  # 代表每个句子的长度
        print(sentence_lengths)                     # tensor([ 7., 10.,  4.])
        print(sentence_lengths.max().item())        # 10.0
        print(torch.arange(sentence_lengths.max().item()))          # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
        print(torch.arange(sentence_lengths.max().item())[None, :]) # tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
        print(sentence_lengths[:, None])
        # tensor([[7.],
        #         [10.],
        #         [4.]])
        mask = torch.arange(sentence_lengths.max().item())[None, :] < sentence_lengths[:, None]
        print(mask) # <前每一行的所有值 分别与 <后每一列的值 进行比较
        # tensor([[True, True, True, True, True, True, True, False, False, False],
        #         [True, True, True, True, True, True, True, True, True, True],
        #         [True, True, True, True, False, False, False, False, False, False]])
        mask = mask.unsqueeze(dim=1)  # [batch_size, 1, max_len]
        mask = mask.expand(size[0], max_len, max_len)  # [batch_size, max_len, max_len]

        padding_num = torch.ones_like(mask)     # 全1
        padding_num = -2 ** 31 * padding_num.float()    # 全无穷小
        # qk=[b h hidden_dim]*[b hidden_dim h]=[b h h] 代表每句话的每个词 和 其他词的关系
        alpha = torch.matmul(Q, K)
        # mask True 区域的 alpha 值 置为 无穷小
        alpha = torch.where(mask, alpha, padding_num)   # 用法:满足条件,返回x,否则返回y
        pd.DataFrame(alpha[0].data).to_csv('mask.csv')
        alpha = F.softmax(alpha, dim=2)
        pd.DataFrame(alpha[0].data).to_csv('softmax.csv')
        # softmax*v=[b h h]*[b h hidden_dim]=[b h hidden_dim]
        out = torch.matmul(alpha, V)
        return out


if __name__ == '__main__':
    input_size = 100
    hidden_size = 8
    input = torch.rand(3, 10, input_size)
    att_L = Attention_Layer(input_size, hidden_size)
    lens = [7, 10, 4]  # 一个batch文本的真实长度

    att_out = att_L(input, lens) 

看看中间结果就明白了

 

 

参考资料:

标签:dim,Product,sentence,lengths,self,Attention,mask,hidden
From: https://www.cnblogs.com/yanshw/p/16741156.html

相关文章

  • face_masker
    安装opencvpipinstall-ihttps://pypi.tuna.tsinghua.edu.cn/simpleopencv-python测试代码importnumpyasnpimportcv2ascvcap=cv.VideoCapture(0)ifnot......
  • linux umask值的设置
    一、linux系统默认umask为022[root@localhost~]#umask0022查看umask值对应的权限[root@localhost~]#umask-Su=rwx,g=rx,o=rx022对应的目录权限是:777-022=7......
  • com.ibatis.sqlmap.client.SqlMapException: There is no statement named saveNewPr
    经常发生这种问题,其实是写代码不严谨造成的。忘记将相应的sqlMap文件名称和路径在sqlMapConfig(sql-map-config.xml)配置文件中进行配置。  在文件中加入新写的dao层xml......
  • Attention Is All You Need transformer开山之作论文精读 笔记
    参考资料1、https://www.bilibili.com/video/BV1pu411o7BE/?spm_id_from=333.337.search-card.all.click&vd_source=920f8a63e92d345556c1e229d6ce363f李沐老师讲解trans......
  • mask和RectMask2D区别
    1.Mask遮罩的大小与形状依赖于Graphic,而RectMask2D只需要依赖RectTransform2.Mask支持圆形或其他形状遮罩,而RectMask2D只支持矩形3.Mask会增加drawcall4、mask的性......
  • self-attention为什么要除以根号d_k
    参考文章:https://blog.csdn.net/tailonh/article/details/120544719正如上文所说,原因之一在于:1、首先要除以一个数,防止输入softmax的值过大,导致偏导数趋近于0;2、选......
  • ABC 269 C - Submask(dfs+位运算)
    C-Submask(dfs+位运算)题目大意:给定一个十进制的数字,让我们求出它的二进制下的1可以改变时候的数字SampleInput111SampleOutput10123891011Thebi......
  • Fastformer: Additive Attention Can Be All You Need
    创新点:本文根据transformer模型进行改进,提出了一个高效的模型,模型复杂度呈线性。主要改进了注意力机制,出发点在于降低了注意力矩阵的重要程度,该方法采用一个(1*T)一维向量......
  • Clickhouse执行报错(Double-distributed IN/JOIN subqueries is denied (distributed_p
    错误示例:DB::Exception:Double-distributedIN/JOINsubqueriesisdenied(distributed_product_mode='deny').Youmayrewritequerytouselocaltablesinsubq......
  • ECCV 2022 | k-means Mask Transformer
    前言 目前,大多数现有的基于transformer的视觉模型只是借用了自然语言处理的思想,忽略了语言和图像之间的关键差异,特别是空间扁平像素特征的巨大序列长度。这阻碍了在像素特......