首页 > 其他分享 >自注意力机学习

自注意力机学习

时间:2024-06-07 23:33:02浏览次数:20  
标签:self Value 学习 Key Query embed 注意力 size

自注意力机制的核心概念

1. Query, Key 和 Value
  • Query(查询向量):可以看作是你当前在关注的输入项。假设你正在阅读一段文字,这就像你当前在读的句子。

  • Key(键向量):表示其他所有输入项的标识或特征。这就像你在书中已经读过的所有句子的摘要或要点。

  • Value(值向量):是与每个Key相关联的具体信息或内容。就像这些句子带来的详细信息。

现实比喻
想象你在图书馆寻找一本特定的书(Query),书架上有很多书,每本书都有一个书名(Key)。根据书名(Key)匹配你的查询(Query),你从合适的书中获取详细内容(Value)。

2. 点积注意力(Dot-Product Attention)

这是计算Query和Key之间相关性的方式。我们通过计算Query和Key的点积来确定它们的关系强度。

比喻
就像在图书馆,你有一本书的部分标题(Query),你对比书架上所有书的书名(Key),看哪个书名最接近你的标题,然后选出最相关的书(Value)。

3. 缩放(Scaling)

为了防止Query和Key之间的点积结果太大导致数值不稳定,我们将结果除以一个常数——通常是Key向量的维度的平方根。这使得计算更加稳定。

比喻
假设你在测试你的记忆力,如果你直接用高分数衡量,可能会出现极端值。所以你需要调整分数范围,使得评估更合理和稳定。

4. Softmax 归一化

Softmax函数将一组数值转换为概率分布,使得它们的总和为1。这意味着每个单词的注意力权重表示它对当前处理单词的重要性。

比喻
就像你在评分不同的书,Softmax就像把所有的分数转换成百分比,这样你可以看到每本书相对于其他书的重要性。

自注意力机制的工作流程

让我们更详细地看看自注意力机制是如何一步一步工作的:

  1. 生成 Query, Key 和 Value 向量

    我们首先通过线性变换将输入序列的每个单词转换成三个不同的向量:Query, Key 和 Value。

    query = W_q * input
    key = W_k * input
    value = W_v * input
    

    比喻:这是把每个单词变成三个不同的代表,就像给每个单词生成了三个不同的标签,用于不同的目的(查询、匹配和提供信息)。

  2. 计算注意力权重

    通过计算Query和Key的点积,我们得到它们之间的相关性得分。然后,我们将这些得分除以 d k \sqrt{d_k} dk​ ​ 进行缩放,最后应用Softmax函数来得到权重。

    # 计算点积
    scores = query.dot(key.T) / sqrt(d_k)
    # 使用Softmax函数归一化
    attention_weights = softmax(scores)
    

    比喻:这就像你比较当前正在读的句子(Query)和你已经读过的所有句子(Key),然后根据它们的相似程度打分。接着,你将这些分数标准化,使它们总和为1,表示每个句子的重要性百分比。

  3. 加权求和 Value 向量

    我们将Value向量按照注意力权重进行加权求和,这样每个Value对最终输出的贡献由它的重要性决定。

    # 计算加权的Value
    output = sum(attention_weights * value)
    

    比喻:就像你根据每本书的重要性百分比(注意力权重),从每本书中提取一定量的信息(Value),最终形成你对整个图书馆信息的理解。

示例和实际应用

假设你在处理一句话“我喜欢吃苹果,因为苹果很甜”:

  1. Query, Key, Value

    • Query:当前处理的词是“苹果”。
    • Key:句子中的所有单词的表示,如“我”,“喜欢”,“吃”,“苹果”,“因为”,“很”,“甜”。
    • Value:这些单词的具体信息,比如它们的词义或上下文信息。
  2. 点积注意力

    • 你在评估“苹果”和句子中其他词的关系,比如“苹果”与“甜”的关系就很重要,而与“我”关系可能不大。
  3. Softmax 归一化

    • 将关系得分转化为一个概率分布,表示每个单词对当前词“苹果”的重要性。
  4. 加权求和

    • 最后,根据重要性权重,从每个单词中提取信息,生成“苹果”的最终表示,这样“苹果”就包含了它和“甜”的关系。

自注意力机制代码示例

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, embed_size, bias=False)
        self.keys = nn.Linear(self.head_dim, embed_size, bias=False)
        self.queries = nn.Linear(self.head_dim, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # 1. 生成 Query, Key 和 Value 向量
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # 2. 计算注意力权重
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # 3. 加权求和 Value 向量
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )

        out = self.fc_out(out)
        return out

关键概念总结

  1. 自注意力机制:允许模型在处理一个输入时,同时关注到整个输入序列中的所有其他输入。提高了捕捉长距离依赖关系的能力。

  2. Query, Key 和 Value:分别代表当前处理的焦点、其他输入的标识和它们携带的信息。

  3. 点积注意力:通过计算Query和Key的相似性来确定它们之间的关系强度。

  4. 缩放:对点积结果进行调整,防止数值过大导致计算不稳定。

  5. Softmax 归一化:将相似性得分转化为概率分布,表示每个输入的重要性。

通过这些步骤,自注意力机制能够帮助模型在处理每一个输入时同时考虑整个序列,从而更好地理解上下文和词语之间的关系。

标签:self,Value,学习,Key,Query,embed,注意力,size
From: https://blog.csdn.net/pumpkin84514/article/details/139537093

相关文章

  • 【机器学习】TensorFlow 202107090086
    【源代码】importtensorflowastfl2_reg=tf.keras.regularizers.l2(0.1)#设置模型model=tf.keras.models.Sequential([tf.keras.layers.Dense(30,activation='relu',kernel_initializer='he_normal',kernel_regula......
  • 算法学习笔记(23):杜教筛
    杜教筛参考来源:OI-Wiki,网上博客线性筛可以在线性时间求积性函数前缀和,而杜教筛可以用低于线性时间求解积性函数前缀和。我们考虑\(S(n)\)就是积性函数的前缀和,所以我们尝试构造关于\(\largeS(n)\)关于\(\largeS(\lfloor\frac{n}{i}\rfloor)\)的递推式。对于任意......
  • 知识图谱学习记录(一)
    知识图谱学习记录(一)1.什么是知识图谱?知识图谱是一种用于表示知识的图形化结构,它包含了实体(如人物、地点、事件等)以及这些实体之间的关系。它的目的是将信息组织成易于理解和处理的形式,以便计算机程序能够理解和利用这些信息。知识图谱通常由三部分组成:实体(Entities):代表现实世......
  • DVWA靶场学习(一)—— Brute Force
    BruteForce暴力破解其实就是利用不同的账户和密码进行多次尝试。因为用户在设置密码时可能会选用比较容易记忆的口令,因此,可以使用一些保存常用密码的字典或者结合用户的个人信息进行爆破。DVWA安全等级有Low,Medium,High和Impossible四种,随着安全等级的提高,网站的防护等级和攻击......
  • ChatGPT-4o在临床医学日常工作、数据分析与可视化、机器学习建模中的技术
    2022年11月30日,可能将成为一个改变人类历史的日子——美国人工智能开发机构OpenAI推出了聊天机器人ChatGPT-3.5,将人工智能的发展推向了一个新的高度。2023年11月7日,OpenAI首届开发者大会被称为“科技界的春晚”,吸引了全球广大用户的关注,GPT商店更是显现了OpenAI旨在构建AI生态......
  • 滑坡、泥石流等地质灾害风险评价、基于机器学习的滑坡易发性分析技术
    入门篇,ArcGIS软件的快速入门与GIS数据源的获取与理解;方法篇,致灾因子提取方法、灾害危险性因子分析指标体系的建立方法和灾害危险性评价模型构建方法;拓展篇,GIS在灾害重建中的应用方法;高阶篇:Python环境中利用机器学习进行灾害易发性评价模型的建立与优化方法。原文链接:滑坡、泥......
  • Pixi.js学习 (四)鼠标跟随、字符拼接与图片位控
    目录目录目录前言一、鼠标移动跟随1.1获取鼠标坐标1.2 鼠标跟随二、锚点、元素组合2.1锚点2.2 元素组合2.3总结前言为了提高作者的代码编辑水品,作者在使用博客的时候使用的集成工具为HBuilderX。下文所有截图使用此集成工具,读者随意。此系列文......
  • 数据结构学习笔记-佛洛依德算法
    最短路径问题的经典解法-佛洛依德算法问题描述:设计算法求解图的最短路径【算法设计思想】初始化距离矩阵:首先,将解决方案矩阵dist[][]初始化为输入图矩阵graph[][],这个矩阵存储了顶点之间的直接距离或者权值。中间顶点迭代:然后,对每一个顶点作为中间顶点进行迭代。算法通过......
  • 盘点学习Python常犯一些错误,你中了几个
    对于刚入门的Pythonista在学习过程中运行代码是或多或少会遇到一些错误,刚开始可能看起来比较费劲。随着代码量的积累,熟能生巧当遇到一些运行时错误时能够很快的定位问题原题。下面整理了一些常见的17个错误,等你写出的代码不怎么出现这些错误的时候,你的Python功力就上......
  • 分数规划学习笔记
    1.用途分数规划常用于求一个分式的极值,就是给出两个序列\(a_i\)和\(b_i\),使得\(\dfrac{\suma_i\timesw_i}{\sumb_i\timesw_i}\)的值最大或最小,其中\(w\in\{0,1\}\)通常,题目中还会有类似于\(\sumb_i>w\)的限制2.解法通常使用二分,记当前二分的值为mid\[\begin{aligned......