首页 > 其他分享 >Scaled Dot-Product Attention 的公式中为什么要除以 $\sqrt{d_k}$?

Scaled Dot-Product Attention 的公式中为什么要除以 $\sqrt{d_k}$?

时间:2024-10-22 17:59:33浏览次数:5  
标签:Product right mathbf Attention sqrt Var mathrm left

Scaled Dot-Product Attention 的公式中为什么要除以 \(\sqrt{d_k}\)?

在学习 Scaled Dot-Product Attention 的过程中,遇到了如下公式

\[ \mathrm{Attention} (\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax} \left( \dfrac{\mathbf{Q} \mathbf{K}}{\sqrt{d_k}} \right) \mathbf{V} \]

不禁产生疑问,其中的 \(\sqrt{d_k}\) 为什么是这个数,而不是 \(d_k\) 或者其它的什么值呢?

Attention Is All You Need 中有一段解释

We suspect that for large values of \(d_k\), the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by \(\sqrt{d_k}\).

这说明,两个向量的点积可能很大,导致 softmax 函数的梯度太小,因此需要除以一个因子,但是为什么是 \(\sqrt{d_k}\) 呢?

文章中的一行注释提及到

To illustrate why the dot products get large, assume that the components of \(\mathbf{q}\) and \(\mathbf{k}\) are independent random variables with mean \(0\) and variance \(1\). Then their dot product, $\mathbf{q} \cdot \mathbf{k} = \sum_{i=1}^{d_k} q_i k_i $ has mean \(0\) and variance \(d_k\).

本期,我们将基于上文的思路进行完整的推导,以证明 \(\sqrt{d_k}\) 的在其中的作用.

基本假设

假设独立随机变量 \(U_1 ,\ U_2 ,\ \dots ,\ U_{d_k}\) 和独立随机变量 \(V_1 ,\ V_2 ,\ \dots ,\ V_{d_k}\) 分别服从期望为 \(0\),方差为 \(1\) 的分布,即

\[E \left(U_i \right) = 0 ,\ \mathrm{Var} \left(U_i \right) = 1 \]

\[E \left(V_i \right) = 0 ,\ \mathrm{Var} \left(V_i \right) = 1 \]

其中 \(i = 1, 2, \dots ,\ d_k\),\(d_k\) 是个常数.

计算 $U_i V_i $ 的方差

由随机变量方差的定义可得 $ U_i V_i $ 的方差为

\[\begin{align*} \mathrm{Var} \left( U_i V_i \right) &= E \left[ \left( U_i V_i - E \left( U_i V_i \right) \right)^2\right] \\ &= E \left[ \left(U_i V_i \right)^2 - 2U_i V_i E \left( U_i V_i \right) + E^2 \left( U_i V_i \right)\right] \\ &= E \left[ \left( U_i V_i \right)^2 \right] - 2 E \left[ U_i V_i E \left( U_i V_i \right) \right] + E^2 \left(U_i V_i\right) \\ &= E \left( U_i^2 V_i^2 \right) - 2 E \left( U_i V_i \right) E \left( U_i V_i \right) + E^2 \left(U_i V_i\right) \\ &= E \left( U_i^2 V_i^2 \right) - E^2 \left( U_i V_i \right) \end{align*} \]

因为 \(U_i\) 和 \(V_i\) 是独立的随机变量,所以

\[E \left( U_i V_i \right) = E \left( U_i \right) E \left( V_i \right) \]

从而

\[\begin{align*} \mathrm{Var} \left( U_i V_i \right) &= E\left(U_i^2\right) E\left(V_i^2\right) - \left(E\left(U_i\right) E\left(V_i\right) \right)^2 \\ &= E\left(U_i^2\right) E\left(V_i^2\right) - E^2\left(U_i\right) E^2\left(V_i\right) \end{align*} \]

又因为 \(E(U_i) = E(V_i) = 0\),所以

\[\mathrm{Var} \left( U_i V_i \right) = E(U_i^2) E(V_i^2) \]

计算 \(E(U_i^2)\)

因为

\[ E \left( U_i \right) = 0 \]

\[\mathrm{Var} \left( U_i \right) = 1 \]

\[\mathrm{Var} \left( U_i \right) = E \left( U_i^2 \right) - E^2 \left( U_i \right) \]

所以

\[E(U_i^2) = 1 \]

同理,

\[E(V_i^2) = 1 \]

计算 \(\mathbf{q} \mathbf{k}\) 的方差

如果 \(\mathbf{q} = \left[U_1, U_2, \cdots, U_{d_k} \right]^T\),\(\mathbf{k} = \left[V_1, V_2, \cdots, V_{d_k} \right]^T\),那么

\[\mathbf{q} \mathbf{k} = \sum_{i=1}^{d_k} U_i V_i \]

\(\mathbf{q} \mathbf{k}\) 的方差

\[\begin{align*} \mathrm{Var}\left( \mathbf{q} \mathbf{k} \right) &= \mathrm{Var}\left( \sum_{i=1}^{d_k} U_i V_i \right) \\ &= \sum_{i=1}^{d_k} \mathrm{Var} \left( U_i V_i \right) \\ &= \sum_{i=1}^{d_k} E \left(U_i^2\right) E \left(V_i^2\right) \\ &= \sum_{i=1}^{d_k} 1 \cdot 1 \\ &= d_k \end{align*} \]

到这里就可以解释为什么在最后要除以 \(\sqrt{d_k}\),因为

\[\begin{align*} \mathrm{Var}\left( \dfrac{\mathbf{q} \mathbf{k} }{\sqrt{d_k}} \right) &= \dfrac{\mathrm{Var}\left( \mathbf{q} \mathbf{k} \right)}{d_k} \\ &= \dfrac{d_k}{d_k} \\ &= 1 \end{align*} \]

可见这个因子的目的是让 \(\mathbf{q} \mathbf{k}\) 的分布也归一化到期望为 \(0\),方差为 \(1\) 的分布中,增强机器学习的稳定性.

参考文献/资料

标签:Product,right,mathbf,Attention,sqrt,Var,mathrm,left
From: https://www.cnblogs.com/AkagawaTsurunaki/p/18493441

相关文章

  • 一文读懂什么是数据即产品(Data as a Product,DaaP)
    企业每天都要产生并消费大量数据,但如果这些数据一直保持在原始格式,就很难真正应用起来。因此,为了充分发挥数据的最大潜力,必须改变组织内部处理数据的方式。“数据即产品”(DaaP)就是这样一种思维方式转变的代表,即将原始数据转化为高质量的信息产品。这种转变不仅会改变企业的数据战......
  • YOLOv8改进:引入LSKAttention大核注意力机制,助力目标检测性能极限提升【YOLOv8】
    本专栏专为AI视觉领域的爱好者和从业者打造。涵盖分类、检测、分割、追踪等多项技术,带你从入门到精通!后续更有实战项目,助你轻松应对面试挑战!立即订阅,开启你的YOLOv8之旅!专栏订阅地址:https://blog.csdn.net/mrdeam/category_12804295.html文章目录YOLOv8改进:引入LSKAtte......
  • 【论文阅读】【IEEE TGARS】RRNet: Relational Reasoning Network WithParallel Multi
    引言任务:光学遥感显著目标检测-关系推理论文地址:RRNet:RelationalReasoningNetworkWithParallelMultiscaleAttentionforSalientObjectDetectioninOpticalRemoteSensingImages|IEEEJournals&Magazine|IEEEXplore代码地址:rmcong/RRNet_TGRS2021(g......
  • P5048 [Ynoi2019 模拟赛] Yuno loves sqrt technology III
    Sol蒲公英题意基本相同,但是注意到空间限制62.5MB,显然不能用蒲公英的做法。考虑先把整块的答案算出来,然后把小块的部分补上去,显然大块可以预处理,小块可以直接暴力查询是否越界。代码很简单。Code#include<iostream>#include<iomanip>#include<cstdio>#include<vector>......
  • YOLO11-pose关键点检测:可变形双级路由注意力(DBRA),魔改动态稀疏注意力的双层路由方法BiL
    ......
  • FlashAttention逐代解析与公式推导
    StandardAttention标准Attention计算可以简化为:\[O=softmax(QK^T)V\tag{1}\]此处忽略了AttentionMask和维度归一化因子\(1/\sqrt{d}\)。公式(1)的标准计算方式是分解成三步:\[S=QK^T\tag{2}\]\[P=softmax(S)\tag{3}\]\[O=PV\tag{4}\]但这样做的问题在于,假设\(......
  • YOLOv8改进 - 注意力篇 - 引入ShuffleAttention注意力机制
    一、本文介绍作为入门性篇章,这里介绍了ShuffleAttention注意力在YOLOv8中的使用。包含ShuffleAttention原理分析,ShuffleAttention的代码、ShuffleAttention的使用方法、以及添加以后的yaml文件及运行记录。二、ShuffleAttention原理分析ShuffleAttention官方论文地址:文章Sh......
  • Production Tracking是什么 ?
    【大家好,我是唐Sun,唐Sun的唐,唐Sun的Sun。一站式数智工厂解决方案服务商】ProductionTracking,即生产跟踪,是对生产过程进行全面、实时监控和记录的一种管理手段。它涵盖了从原材料采购、生产计划制定、生产工序执行,到产品最终完成的整个生产流程。通过各种技术手段,如传感器、......
  • DuoAttention: 高效的长上下文大语言模型推理方法
    在人工智能的日新月异中,长上下文大语言模型(LLMs)如同一颗闪亮的明星,吸引着研究人员的目光。然而,部署这些模型并非易事,尤其在处理长上下文时,面临着计算和内存的巨大挑战。在这一背景下,“DuoAttention”的理念应运而生,旨在通过高效的长上下文推理方法,缓解这些问题。......
  • YOLOv11改进策略【Conv和Transformer】| CVPR-2024 Single-Head Self-Attention 单头
    一、本文介绍本文记录的是利用单头自注意力SHSA改进YOLOv11检测模型,详细说明了优化原因,注意事项等。传统的自注意力机制虽能提升性能,但计算量大,内存访问成本高,而SHSA从根本上避免了多注意力头机制带来的计算冗余。并且改进后的模型在相同计算预算下,能够堆叠更多宽度更大的......