首页 > 其他分享 >自注意力self-attention理解(qkv计算、代码)

自注意力self-attention理解(qkv计算、代码)

时间:2025-01-07 20:31:48浏览次数:3  
标签:dim self attention 注意力 num key qkv

1.自注意力的个人理解

     self-attention中的核心便是qkv的计算,首先是将输入向量分别乘上三个可学习的的矩阵得到Query(查询)、Key(键)、Value(值);再将q和k点乘达到全局建模的作用,将qk结果进行softmax得到Attention分数;最后将Attention和v相乘这个操作我的理解是:可以把Value这个矩阵理解成一个”筛子“(可以讨论讨论),如果attention-scores较小则对应位置与v相乘的结果也小,反之则越大,最终得到的结果就是特征图中会更加关注有目标的位置而弱化背景的位置(现在很多论文经常会提到”门控机制“,我认为都可以把这个理解成”筛子“)。

2.qkv计算

    在自注意力中涉及到最核心的就是qkv的计算来实现全局建模。主要从以下几点来展开认识:

    2.1.qkv获取

         自注意力无论是在NLP还是CV领域都取得不错的效果,其输入都是将输入转换为张量    Tensor。例如,在NLP中对于每个单词我们需要创建qkv三个向量,这些向量通过将嵌入向量        乘以训练过程中可学习的三个矩阵来创建:

         在图中,将x1乘以可学习的矩阵W_{q}会生成q1,最终为输入的句子中每个的那次创建一  个"query"、”key"、“value"投影。

    2.2.计算注意力分数Attention-score

           步骤一:注意力分数通俗理解就是相关性,就比如在直观上的理解近处的词和远处的词相关性比较低,其对应的注意力分数也就比较低,反之,分数则越高。

             步骤二:在上图中,我们分别为:Thinking和Machines生成了qkv的张量,首先Thinking的查询q和自身的k1相乘,然后再和“Machines”的k2相乘,我们从图中最下面一行可以看到 q1*k1=112,q2*k2=96。一个词的相关性肯定是个自身最相关的,这也符合我们的直觉;

             步骤三:为了训练梯度稳定,我们要将步骤二中得到的qk分数除以\sqrt{d_{k}}(词向量维度的根号)

             步骤四:将每个值向量v乘以softmax分数,对于这个步骤我的理解就是起到了“筛子”的作用,保持我们想要关注的单词的值完整,并弱化不相关的单词(或者理解成相关性弱的单词);

             步骤五:将加权值向量相加,对于第一个单词两说会产生自注意力层的输出。

        2.3自注意力矩阵的计算

                 首先我们需要通过三个可学习的权重矩阵(Wq,Wk,Wv)来生成我们需要的Query        Key、Value矩阵:

                      在图中X的每一行对应输入句子中的一个单词。

3.代码实现

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
        """Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim**-0.5
        nh_kd = self.key_dim * num_heads
        h = dim + nh_kd * 2
        self.qkv = Conv(dim, h, 1, act=False)    #用于计算查键和值的卷积层  输入通道为dim 输出通道为h
        self.proj = Conv(dim, dim, 1, act=False)   #用于线性映射的卷积层 用于将注意力后的输出投影回原始维度
        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)  #基于组卷积实现的位置编码

    def forward(self, x):
        """
        Forward pass of the Attention module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            (torch.Tensor): The output tensor after self-attention.
        """
        B, C, H, W = x.shape
        N = H * W
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
            [self.key_dim, self.key_dim, self.head_dim], dim=2
        )

        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
        x = self.proj(x)
        return x

4.论文链接

http://Vaswani, A. "Attention is all you need." Advances in Neural Information Processing Systems (2017).

标签:dim,self,attention,注意力,num,key,qkv
From: https://blog.csdn.net/m0_59959542/article/details/144991163

相关文章