首页 > 编程语言 >Bert Pytorch 源码分析:四、编解码器

Bert Pytorch 源码分析:四、编解码器

时间:2023-06-26 15:33:25浏览次数:47  
标签:vocab Bert self param Pytorch 源码 __ hidden size

# Bert 编码器模块
# 由一个嵌入层和 NL 个 TF 层组成
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
		# 嵌入大小 ES
        self.hidden = hidden
		# TF 层数 NL
        self.n_layers = n_layers
		# 头部数量 HC
        self.attn_heads = attn_heads

        # FFN 层中的隐藏单元数量,记为 FF,一般是 ES 的四倍
        self.feed_forward_hidden = hidden * 4

        # 嵌入层,嵌入矩阵尺寸 VS * ES
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # NL 个 TF 层
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # 为`<pad>`(ID = 0)设置掩码
		# 尺寸为 BS * 1 * ML * ML,以便与相似性矩阵 S 匹配
		# 在每个 BS 的 ML * ML 矩阵中,`<pad>`标记对应的行为 1,其余为零
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # 单词 ID 传入嵌入层得到词向量
        x = self.embedding(x, segment_info)

        # 依次传入每个 TF 层,得到编码器输出
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)

        return x

# 解码器结构根据具体任务而定
# 任务一般有三种:(1)序列分类,(2)标记分类,(3)序列生成
# 但一般都是全连接的

# 用于下个句子判断的解码器
# 序列分类任务,输入两个句子,输出一个标签,1表示是相邻句子,0表示不是
class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
		# 将向量压缩到两维, 尺寸为 ES * 2
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
		# 输入 -> 取第一个向量 -> LL -> softmax -> 输出
		# 输出相邻句子和非相邻句子的概率
        return self.softmax(self.linear(x[:, 0]))

# 用于完型填空的解码器
# 序列生成任务,输入是带有`<mask>`的句子,输出是完整句子
class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
		# 将输入压缩到词汇表大小
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
		# 输入 -> LL -> softmax -> 输出
		# 输出序列中每个词是词汇表中每个词的概率
        return self.softmax(self.linear(x))

标签:vocab,Bert,self,param,Pytorch,源码,__,hidden,size
From: https://www.cnblogs.com/apachecn/p/17505736.html

相关文章

  • 基于STM32单片机的差分升级和增量升级算法源码,这些源码可以在不同平台上进行移植
    基于STM32单片机的差分升级和增量升级算法源码,这些源码可以在不同平台上进行移植。此外,IAP升级和OTA升级技术,这些技术在物联网和车联网领域中得到广泛应用。原创文章,转载请说明出处,资料来源:http://imgcs.cn/5c/653978935134.html提取的知识点和领域范围:1.单片机(STM32):单片机是一......
  • Pytorch | 输入的形状为[seq_len, batch_size, d_model]和 [batch_size, seq_len, d_m
    首先导入依赖的torch包。importtorch我们设:seq_len(序列的最大长度):5batch_size(批量大小):2d_model(每个单词被映射为的向量的维度):10heads(多头注意力机制的头数):5d_k(每个头的特征数):21、输入形状为:[seq_len,batch_size,d_model]input_tensor=torch.randn(5,2,10)inp......
  • 粮油MES质量追溯平台源码,实现一物一码,全程追溯
    粮油生产质量追溯系统源码 MES质量追溯平台源码,实现一物一码,全程追溯,正向追踪,逆向溯源,自主研发,拥有自主知识产权。技术架构:springboot+mybatis+easyui+mysql。粮油生产质量追溯系统可广泛用于粮油生产加工领域。实现种植主体、种植基地、生产计划、压榨、精炼、包装、销售、物料......
  • 语音厅源码实用功能屏幕的转换
     在我们日常生活中,我们会利用电子设备去放松、释放压力,像是利用手机去看电影、看电视剧等,今天我们要分享的知识就与这个释放压力的方式有关,那是什么哪?我们都知道现在市面上的大部分手机都是长方形的,所以在我们看手机上的内容大部分都是竖着的,那我们如果去看电影、电视剧时,则也会......
  • C# 实现 Linux 视频聊天、远程桌面(源码,支持信创国产化环境,银河麒麟,统信UOS)
        园子里的有朋友在下载并了解了《C#实现Linux视频会议(源码,支持信创环境,银河麒麟,统信UOS)》中提供的源码后,留言给我说,这个视频会议有点复杂了,代码比较多,看得有些费劲。问我能不能整个简单点的Demo,只要有视频聊天和远程桌面的功能就可以。于是,我就又写了一个Demo来供大......
  • Pytorch | view()函数的使用
    函数简介Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。根据上面的描述可知,view函数的操作对象应该是Tensor类型。如果不是Tensor类型,可以通过tensor=torch.tensor(data)来转换。普通用法(手动调整size)view(参数a,参数b,…),其中,总......
  • jQuery源码浅谈系列---$.attr()
    attr()   1、attr(name)     取得第一个匹配元素的属性值。如果元素没有相应的属性,则返回undefined。  2、attr(properties)     将一个"名/值"形式的对象设置为所有匹配元素的属性。    注:要设置class属性,必须用'className'作为属性名。     举例:......
  • springboot+vue基于Web的社区医院管理服务系统,附源码+数据库+论文+PPT,适合课程设计、
    1、项目介绍在Internet高速发展的今天,我们生活的各个领域都涉及到计算机的应用,其中包括社区医院管理服务系统的网络应用,在外国线上管理系统已经是很普遍的方式,不过国内的管理系统可能还处于起步阶段。社区医院管理服务系统具有社区医院信息管理功能的选择。社区医院管理服务系统......
  • 【源码阅读】其他
     Export语法文件export_stmt::=KW_EXPORTKW_TABLEbase_table_ref:tblRefwhere_clause:whereExprKW_TOSTRING_LITERAL:pathopt_properties:propertiesopt_broker:broker{:RESULT=newExportStmt(tblRef,whereExpr,path,pr......
  • 【源码阅读】5. Broker Load 导入任务的执行流程
    load_stmt::=KW_LOADKW_LABELjob_label:labelLPARENdata_desc_list:dataDescListRPARENopt_broker:brokeropt_properties:properties{:RESULT=newLoadStmt(label,dataDescList,broker,properties);:}|KW_LOADKW_LAB......