首页 > 编程语言 > Bert Pytorch 源码分析:二、注意力层

Bert Pytorch 源码分析:二、注意力层

时间:2023-06-25 18:45:49浏览次数:54  
标签:Bert ML self value Pytorch 源码 key query ES

# 注意力机制的具体模块
# 兼容单头和多头
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """

	# QKV 尺寸都是 BS * ML * ES
	# (或者多头情况下是 BS * HC * ML * HS,最后两维之外的维度不重要)
	# 从输入计算 QKV 的过程可以统一处理,不必放到每个头里面
    def forward(self, query, key, value, mask=None, dropout=None):
		# 将每个批量的 Q 和 K.T 做矩阵乘法,再除以√ES,
		# 得到相关性矩阵 S,尺寸为 BS * ML * ML
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))

		# 如果存在掩码则使用它
		# 将 scores 的 mask == 0 的位置上的元素改为 -1e9
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # 将 S 转换到概率空间,同时对其最后一维归一化
        p_attn = F.softmax(scores, dim=-1)

		# 如果存在 dropout 则使用
        if dropout is not None:
            p_attn = dropout(p_attn)

		# 最后将 S 与 V 相乘得到输出
        return torch.matmul(p_attn, value), p_attn
		
# 多头注意力就是包含很多(HC)个头,但是每个头的尺寸(HS)变为原来的 1/HC
# 把 qkv 切成小段分给每个头做运算,将结果拼起来作为整个层的输出
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

	# h 是头数(HC)
	# d_model 是嵌入向量大小(ES)
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
		# 判断 ES 是否能被 HC 整除,以便结果能拼接回去
        assert d_model % h == 0

		# d_k 是每个头的大小 HS = ES // HC
        self.d_k = d_model // h
        self.h = h

		# 创建输入转换为QKV的权重矩阵,Wq, Wk, Wv,尺寸均为 ES * ES
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
		# 输出应该还乘一个权重矩阵,Wo,尺寸也是 ES * ES
        self.output_linear = nn.Linear(d_model, d_model)
		# 创建执行注意力机制的具体模块
        self.attention = Attention()
		# 创建 droput 层
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
		# 获取批量大小(BS)
        batch_size = query.size(0)

       
		'''
        query, key, value = [
			l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
		    for l, x in zip(self.linear_layers, (query, key, value))
		]
		'''
		# 将 QKV 的每个与其相应权重矩阵 Wq, Wk, Wv 相乘
		lq, lk, lv = self.linear_layers
		query, key, value = lq(query), lk(key), lv(value) 
		
		# 然后将他们转型为 BS * ML * HC * HS
		# 也就是将最后一个维度按头部数量分割成小的向量
		query, key, value = [
			x.view(batch_size, -1, self.h, self.d_k)
			for x in (query, key, value)
		]
		
		# 然后交换 1 和 2 维,变成 BS * HC * ML  * HS
		# 这样每个头的 QKV 是内存连续的,便于矩阵相乘
		query, key, value = [
			x.transpose(1, 2)
			for x in (query, key, value)
		]

        # 对每个头应用注意力机制,输出尺寸不变
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 交换 1 和 2 维恢复原状,然后把每个头的输出相连接,尺寸变为 BS * ML * ES
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

		# 执行最后的矩阵相乘
        return self.output_linear(x)

缩写表

  • BS:批量大小,即一批数据中样本大小,训练集和测试集可能不同,那就是TBS和VBS
  • ES:嵌入大小,嵌入向量空间的维数,也是注意力层的隐藏单元数量,GPT 中一般是 768
  • ML:输入序列最大长度,一般是512或者1024,不够需要用<pad>填充
  • HC:头部的数量,需要能够整除ES,因为每个头的输出拼接起来才是层的输出
  • HS:头部大小,等于ES // HC
  • VS:词汇表大小,也就是词的种类数量

尺寸备注

  • 嵌入层的矩阵尺寸应该是VS * ES
  • 注意力层的输入尺寸是BS * ML * ES
  • 输出以及 Q K V 和输入形状相同
  • 每个头的 QKV 尺寸为BS * ML * HS
  • 权重矩阵尺寸为ES * ES
  • 相关矩阵 S 尺寸为BS * ML * ML

标签:Bert,ML,self,value,Pytorch,源码,key,query,ES
From: https://www.cnblogs.com/apachecn/p/17503687.html

相关文章

  • 【源码阅读】2. Catalog和Database
     Catalog创建|KW_CREATEKW_CATALOGopt_if_not_exists:ifNotExistsident:catalogNameopt_properties:properties{:RESULT=newCreateCatalogStmt(ifNotExists,catalogName,null,properties);:}|KW_CREATEKW_CATALOGopt_if_not_......
  • 【源码阅读】1. 配置、VARIABLE与用户PROPERTY
     配置初始化在FE启动时:● Config类ConfField注解标记的静态属性反射出Field存储到内存confFields,作为一个可读取和修改的属性列表(真正的值存储在Config类的静态属性中,反射出Field并存储到confFields只是一个读取和修改指针而已)● 读取配置文件,根据配置文件内容,设置Confi......
  • Bert PyTorch 源码分析:一、嵌入层
    #标记嵌入就是最普通的嵌入层#接受单词ID输出单词向量#直接转发给了`nn.Embedding`classTokenEmbedding(nn.Embedding):def__init__(self,vocab_size,embed_size=512):super().__init__(vocab_size,embed_size,padding_idx=0) #片段嵌入实际上是......
  • 谁与争锋!手机直播源码知识分享之主播PK功能
    今天我要分享的知识与PK有关,PK是指某些人分成几方进行对决、对抗,直到分出胜负。PK的方式有很多,在现实生活中,人们可以通过智力、力量等进行PK,方式可以是搏斗、扳手腕、现场智力问答等;而在网络中,人们可以通过游戏、网络智力问答的方式进行PK。我今天要讲的这个功能也是网络中的PK,这个......
  • 2.nacos-client源码及查看
    nacos-client.2.2.1-RC.SDK查看源码官网JAVASDK链接主要内容<dependency><groupId>com.alibaba.nacos</groupId><artifactId>nacos-client</artifactId><version>${version}</version></dependency>问题:1.获取配置api是获取快照......
  • k8s驱逐篇(7)-kube-controller-manager驱逐-taintManager源码分析
    概述taintManager的主要功能为:当某个node被打上NoExecute污点后,其上面的pod如果不能容忍该污点,则taintManager将会驱逐这些pod,而新建的pod也需要容忍该污点才能调度到该node上;通过kcm启动参数--enable-taint-manager来确定是否启动taintManager,true时启动(启动参数默认值为true);k......
  • spring源码笔记
    Bean创建流程获取对象的BeanDefinition通过反射创建空对象填充属性调用init方法  Bean创建关键方法(按顺序)getBeandoGetBeancreateBeandoCreateBeancreateBeanInstancepopulateBean  解决循环依赖:三级缓存循环依赖原因单例,每个类只有一个对象。A引用B,B又......
  • SPI的插件化设计-->JDK的SPI(ServiceLoader)实现拓展、实现Dubbo的SPI(ExtensionLoade
    (目录)1.什么是SPI?SPI的全称是ServiceProviderInterface,直译过来就是"服务提供接口",为了降低耦合,实现在模块装配的时候动态指定具体实现类的一种服务发现机制。动态地为接口寻找服务实现。它的核心来自于ServiceLoader这个类。javaSPI应用场景很广泛,在Java底层和一些......
  • 基于springboot+vue的漫画之家管理系统,附源码+数据库+论文+PPT,适合课程设计、毕业设计
    1、项目介绍随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生,各行各业相继进入信息管理时代,“漫画之家”系统就是信息时代变革中的产物之一。任何系统都要遵循系统设计......
  • TVM 源码阅读PASS — VectorizeLoop
    本文地址:https://www.cnblogs.com/wanger-sjtu/p/17501119.htmlVectorizeLoop这个PASS就是对标记为ForKind::kVectorized的For循环做向量化处理,并对For循环中的语句涉及到的变量,替换为Ramp,以便于在Codegen的过程中生成相关的向量化运算的指令。VectorizeLoop这个PASS的入口函数......