首页 > 其他分享 >代码层面上学习Gemma模型

代码层面上学习Gemma模型

时间:2024-04-23 22:12:21浏览次数:18  
标签:层面 模型 Gemma attention mask self GemmaForCausalLM GemmaModel 代码

总览

本文留下调试 Gemma 模型的记录。很乱,但我想不出更好的组织方式了。

gemma-2b 模型被封装在 GemmaForCausalLM 类中,这个类继承于 GemmaPreTrainedModel

而模型的本体是 GemmaModel 类(这个对象实例包含在 GemmaForCausalLM 实例中)。 也继承于 GemmaPreTrainedModel

GemmaPreTrainedModel 继承于 transformers.PreTrainedModel,重写了 _init_weights() 负责 nn.Linearnn.Embedding 的权重初始化。额外增加了 _setup_cache()_reset_cache() 两个方法,处理缓存 past_key_value

GemmaModel 中有个 self.layers 对象,由 18 层的 GemmaDecoderLayer 构成。

总之,

  • 继承关系:GemmaForCausalLM -> GemmaModel -> GemmaPreTrainedModel -> PreTrainedModel
  • 自注意力的 18 层所在位置:GemmaModel 实例内

GemmaForCausalLM 步骤

PreTrainedModel 类继承了 GenerationMixin,使得 GemmaForCausalLM 拥有 generate() 方法。用 model.generate() 进行文本生成,这就进入了 GemmaForCausalLM 步骤。

在进入该步骤之前,tokenizer 使用左填充的方式将多序列填充到相同长度。

首先使用 self.modelGemmaModel)让 input_ids 经过一系列的自注意力机制,输出 hidden_states

使用 self.lm_headnn.Linear)映射到 256000 维度,得到 logits。转换到 float32。

GemmaModel 步骤

这是 Gemma 的核心。

经过 Embedding,转换为嵌入后,乘上 hidden_size**0.5 进行标准化。

接下来是 18 层 GemmaDecoderLayer

  • 存储一个 residual,开始自注意力机制
    • self.input_layernormGemmaRMSNorm),首先 \(x·\frac{1}{\sqrt{x^2+\epsilon}}\) 对每个 embedding 归一化,然后使用可训练权重做乘法。全程在 float32 下运算
    • self.self_attnGemmaSdpaAttention),自注意力。是对 F.scaled_dot_product_attention() 的封装。使用 MultiQueryAttention,完成注意力后进行线性变换
    • 使用 residual 进行残差连接
  • 存储一个 residual,开始全连接层
    • self.post_attention_layernormGemmaRMSNorm),再次归一化
    • self.mlpGemmaMLP),多层感知器。从 2048 维到 16348 维映射出 \(x_1\) 和 \(x_2\),进行 \(\text{gelu}(x_1)·x_2\),再映射回 2048 维
    • 使用 residual 进行残差连接
  • 重复 18 次

最后再经过 self.normGemmaRMSNorm)。

各步骤中值得注意的地方

GemmaDecoderLayer 之前

attention_mask,避免对填充标记执行注意力操作。多用于 batch_size 大于 1、对输入序列进行 padding 的情况,避免模型对 padding 施加 attention(因为无意义)。

causal_mask 是用 _update_causal_mask()attention_mask 转换而得。具体来说是从 [0,1] 转换为 [0, -inf](通过 torch.finfo(dtype).min 获得负无穷)。

causal_maskattention_mask 在代码中的角色很混乱。函数嵌套过程中两者会相互转换。

_update_causal_mask() 中,还会使用 AttentionMaskConverter._unmask_unattended() 取消填充部分的 mask,以适应 SDPA 的节省内存注意力方法。

GemmaSdpaAttention 之中

使用到了 past_key_valuetransformers.DynamicCache 类型) 存储每一层的 kv 中间结果,总共存储 18 对 key_statesvalue_states。这是 KV Cache 机制,能够显著提升速度。

经过调试可得知,从生成第二个新词开始,输入到模型的 hidden_states 就只有一个 token 的长度了。多亏了 KV Cache,不需要额外计算前面已有词的 Key 和 Value。

本节参考:

本文参考

谷歌的 Gemma 开源模型和代码,以及 HuggingFace 的 Transformers。

标签:层面,模型,Gemma,attention,mask,self,GemmaForCausalLM,GemmaModel,代码
From: https://www.cnblogs.com/chirp/p/18153877

相关文章

  • 【专题STM32F03】FreeRTOS 队列queue传递结构体,野火例程代码简单修改。
    /************************************************************************@filemain.c*@authorfire*@versionV1.0*@date2018-xx-xx*@briefFreeRTOSV9.0.0+STM32消息队列******************************************************......
  • 如何从架构层面降低公有云多可用区同时故障的概率
    阿里云和腾讯云都曾出现过因一个组件故障而导致所有可用区同时瘫痪的情况。本文将探讨如何从架构设计的角度减小故障域,在故障发生时最小化业务损失,并以Sealos的稳定性实践为例,分享经验教训。抛弃主从,拥抱点对点架构从腾讯云故障报告中可以看出来多可用区一起挂基本都是因为一......
  • Android Studio 蓝牙 示例代码(转)
    原文:https://blog.csdn.net/qq_40511184/article/details/122698077因为androidstudio升级,下面代码中的startactivityresult函数有变化,不能使用,需要更换为publicActivityResultLauncher<Intent>register;ActivityResultLauncher<Intent>startBlueTooth=registerForActi......
  • 35天【代码随想录算法训练营34期】第八章 贪心算法 part04 ( ● 860.柠檬水找零 ● 4
    860.柠檬水找零classSolution:deflemonadeChange(self,bills:List[int])->bool:amt_five=0amt_ten=0amt_twenty=0foriinbills:ifi==5:amt_five+=1elifi==10:......
  • 快刀斩乱麻,DevOps让代码评审也自动起来
    在Dr.MichaelaGreiler的  HowCodeReviewsatMicrosoft一文中提到,微软有140000名员工,其中44%员工是工程师。这意味着,有超过6000名的工程师同时在同一个代码库上开发Office、VisualStudio、Windows等产品。想要确保不同子团队开发的代码能完美协作,并不是一件易事。 那么,如......
  • 贝叶斯分位数回归、lasso和自适应lasso贝叶斯分位数回归分析免疫球蛋白、前列腺癌数据
    原文链接:http://tecdat.cn/?p=22702最近我们被客户要求撰写关于贝叶斯分位数回归的研究报告,包括一些图形和统计输出。贝叶斯回归分位数在最近的文献中受到广泛关注,本文实现了贝叶斯系数估计和回归分位数(RQ)中的变量选择,带有lasso和自适应lasso惩罚的贝叶斯摘要还包括总结结果、......
  • R语言用GARCH模型波动率建模和预测、回测风险价值 (VaR)分析股市收益率时间序列|附代
    原文链接:http://tecdat.cn/?p=26897最近我们被客户要求撰写关于GARCH的研究报告,包括一些图形和统计输出。风险价值(VaR)是金融风险管理中使用最广泛的市场风险度量,也被投资组合经理等从业者用来解释未来市场风险风险价值(VaR)VaR可以定义为资产在给定时间段内以概率θ......
  • mybatis-plus 代码生成器步骤
    mybatis-plus代码生成器步骤:1.添加依赖到pom.xml<dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-generator</artifactId><version>3.4.1</version></dependency>&l......
  • 《渣男代码历险记》第五章:设计一个算法删除单链表L(有头结点)中的一个最小值结点
    为了删除单链表L中的一个最小值结点,我们可以遍历链表,找到最小值结点及其前驱结点,然后修改前驱结点的指针,使其指向最小值结点的下一个结点。以下是算法的解析和代码实现:初始化两个指针pre和cur,分别指向头结点和头结点的下一个结点。初始化一个变量min_val,用于存储当前最小值,将其......
  • 《渣男代码历险记》第四 双指针怪
    已知一个带有表头结点的单链表,结点结构为:data next假设该链表只给出了头指针head。在不改变链表的前提下,请设计一个尽可能高效的算法,查找链表中倒数第k(k为正整数)个位置上的结点。若查找成功,算法输出该结点的data值,并返回1;否则,只返回0。要求:(1)描述算法的基本设......