首页 > 编程语言 >Albert 源码解析:分组复用

Albert 源码解析:分组复用

时间:2023-08-21 18:33:23浏览次数:46  
标签:layer num self Albert 复用 states 源码 group hidden

class AlbertGroup(nn.Module):
    def __init__(self, config):
        super(AlbertGroup, self).__init__()
        self.inner_group_num = config.inner_group_num
        self.inner_group = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(self, hidden_states, attention_mask, head_mask):
        layer_attentions = ()
        layer_hidden_states = ()
        for inner_group_idx in range(self.inner_group_num): # [1]
            layer_module = self.inner_group[inner_group_idx]
            layer_outputs = layer_module(hidden_states, attention_mask, head_mask)
            hidden_states = layer_outputs[0]
            layer_attentions = layer_attentions + (layer_outputs[1],)
            layer_hidden_states = layer_hidden_states + (hidden_states,)
        return (layer_hidden_states, layer_attentions)

class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super(AlbertTransformer, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        self.group = nn.ModuleList([AlbertGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(self, hidden_states, attention_mask, head_mask):
        all_hidden_states = ()
        all_attentions = ()
        for layer_idx in range(self.num_hidden_layers):
            if self.output_hidden_states and layer_idx == 0:
                all_hidden_states = all_hidden_states + (hidden_states,)
			# [2]
            group_idx = int(layer_idx / self.num_hidden_layers * self.num_hidden_groups)
            layer_module = self.group[group_idx]
            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[layer_idx])
            hidden_states = layer_outputs[0][-1]
            if self.output_attentions:
                all_attentions = all_attentions + layer_outputs[1]
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + layer_outputs[0]
        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)

config.inner_group_num是组内TFBlock数量,这里记为 GS,num_hidden_layers是总的 TFBlock 层数,记为 LC,num_hidden_groups是分组数量,记为 GC。

可以看出来AlbertGroup含有 GS 个AlbertLayer(也就是TFBlock),逻辑就是依次调用它里面的AlbertLayer([1])。 GS 应该等于LC // GC,但是源码里面三个值都能自由设定,没有校验。

然后AlbertTransformer含有 GC 个AlbertGroup,这些AlbertGroup都是重复使用的。在正向传播期间,程序遍历每一层,根据层序号得到分组序号,然后调用整个分组,也就是把隐藏向量传入层里面所有的组([2])。

直观演示,假设:

LC = 12
GC = 3
GS = LC // GC = 4

那么第一轮迭代,Li = 0, Gi = 0,隐藏向量以此传入G0L0, G0L1, G0L2, G0L3

之后的三轮迭代,Li = 1, 2, 3,还是Gi = 0,隐藏向量以相同方式传入G0L0, G0L1, G0L2, G0L3三次。

之后,Li = 4, 5, 6, 7Gi = 1,隐藏向量传入G1L0, G1L1, G1L2, G1L3四次。

之后,Li = 8, 9, 10, 11Gi = 2,隐藏向量传入G2L0, G2L1, G2L2, G2L3四次。

整体的模块调用路径是这样:

G0L0, G0L1, G0L2, G0L3
G0L0, G0L1, G0L2, G0L3
G0L0, G0L1, G0L2, G0L3
G0L0, G0L1, G0L2, G0L3
G1L0, G1L1, G1L2, G1L3
G1L0, G1L1, G1L2, G1L3
G1L0, G1L1, G1L2, G1L3
G1L0, G1L1, G1L2, G1L3
G2L0, G2L1, G2L2, G2L3
G2L0, G2L1, G2L2, G2L3
G2L0, G2L1, G2L2, G2L3
G2L0, G2L1, G2L2, G2L3

也就是层数是 12 没错,组数是 3 没错,但是每个组被复用了 4 次。

关于复用来看,有三个关键参数,第一个是每个组的容量,也就是 GS,第二个是一共有多少组,也就是GC,第三个是每个组复用多少次,实际上等于LC // GC,但这里面没有任何一个参数直接设置这个,你只能设置 LC。估计是为了和前代保持一致,但非常非常不好用。

假设我们把这个参数开放出来,叫做 GR(分组副本),那么并让 LC = GS * GR * GC,一切就合理了。我们完全可以将重复的层看作新的一层,只不过参数是和其他层共享的。

而且这套设置同时兼容跨层复用和相邻层的复用:

L0, L0, L1, L1, ..., LN, LN

对于相邻层复用,我们只需要把GS设成 1,GR设成 2,GC设成 N。

还有一种是跨层复用:

L0, L1, L2, ..., LN, L0, L1, L2, ... LN

我们只需要把GC设成 1,GS设成 N,然后GR设成 2 。

标签:layer,num,self,Albert,复用,states,源码,group,hidden
From: https://www.cnblogs.com/apachecn/p/17646772.html

相关文章

  • EventBus源码再分析
    一、概述EventBus是一个开源的用于Android和Java上的一个:订阅--->发布事件总线。优点:1.只要是在一个JVM内,就可以实现通信2.小巧灵活、不占内存3.解耦,切换线程灵活4.库小,不占内存缺点:1.注册和反注册时一对,如果忘记了......
  • ASP.NET版LIMS系统源码 实验室信息管理系统
    实验室信息管理系统(LaboratoryInformationManagementSystem)简称LIMS系统,是指通过计算机对实验室的各种信息进行管理的计算机软、硬件系统,并将实验室的设备各种信息通过计算机网络连接起来,采用科学的管理思想和先进的数据库技术,实现以实验室为核心,集检验业务管理、检测资源管理、......
  • RocketMQ源码(四):RocketMQ生产者发送消息流程
    RocketMQ通过Producer发送消息,以同步方式发送普通消息为例,分析发送消息的整体流程。Producer的示例代码如下:1importorg.apache.rocketmq.client.producer.DefaultMQProducer;2importorg.apache.rocketmq.client.producer.SendResult;3importorg.apache.rocketmq.......
  • (一)Dubbo源码解析:增强SPI
    〇、前言在Dubbo的架构设计中,如何可以通过“类插拔”的方式,对其功能进行灵活的扩展或者削弱,那么,SPI起到了极其关键的作用。本篇文章作为分析Dubbo源码的第一篇文章,我们先暂时放下“服务注册发布流程”、“服务启动流程”、“请求处理流程”……这些功能代码的探索,我们先从最基本的......
  • 逻辑清晰,详解社交源码Android开发SDK
    前篇我们讲解了有关如何在IOS平台开发集成SDK,那么今天来给大家简单讲解下如何在社交源码Android客户端上开发集成。1.获取SDK:从提供SDK的第三方开发者或公司获得SDK的相关文件和文档。2.导入SDK文件:将SDK的库文件(.jar或.aar格式)拷贝到Android项目的libs文件夹中。3.配置权限:检查并......
  • app直播源码,读取多行文本、读取文件分割多行文本
    app直播源码,读取多行文本、读取文件分割多行文本读取文本 publicfunctiondaoru(){/* *逐行读取TXT文件  */     $rep=str_replace("\n",',',"TD92069E76EC27CA8B66B631CB49A9C6TD5A22D898050393C2F8D5C29C854F1B");    $cont=explode(',',$re......
  • 直播系统源码,实现上滑加载分页(触底加载)
    直播系统源码,实现上滑加载分页(触底加载) //依据分类查询图书  publicfunctionquery_book_by_classid(){    $token=input('token');    $class_id=input('class_id');    $page=input('page');//起始行    $per_page=input('per_page');//......
  • 机房收费管理系统-计算机毕业设计源码+LW文档
    【摘要】作为计算机机房管理的必要组成部分,计算机机房管理系统有助于机房资源的合理分配、统一管理和设备利用率的提高,从而有力地保证了机房的管理质量。现代化、信息化和自动化是计算机机房的发展方向,它们旨在实现无人或少人值守的开放式管理,并减轻管理员的压力。通过自动计费和合......
  • 实验室信息管理系统(LIMS)源码,采用灵活的架构开发,支持多种应用程序和技术
    实验室信息管理系统(LIMS)是指帮助实验室组织和管理实验数据的计算机软件系统,它将实验室操作有机地组织在一起,以满足实验室工作流程的所有要求。它能以不同的方式支持实验室的工作,从简单的过程(如样品采集和入库)到复杂的流程(如教据报告和实验结果分析),完全改变实验室的工作流程,使......
  • StoneDB 源码解读系列|Tianmu 引擎工具类模块源码详解(一)
    StoneDB源码解读系列文章正式开启,预计以周更的形式跟大家见面,请多多支持~本篇源码解读内容已进行直播分享,可在视频号观看直播回放,也可点击阅读原文跳转至B站观看回放视频。PPT内容可在社区论坛中查看下载:https://forum.stonedb.io/t/topic/89各个工具类属于Tianmu引擎......