总览
本文留下调试 Gemma 模型的记录。很乱,但我想不出更好的组织方式了。
gemma-2b 模型被封装在 GemmaForCausalLM
类中,这个类继承于 GemmaPreTrainedModel
。
而模型的本体是 GemmaModel
类(这个对象实例包含在 GemmaForCausalLM
实例中)。 也继承于 GemmaPreTrainedModel
。
GemmaPreTrainedModel
继承于 transformers.PreTrainedModel
,重写了 _init_weights()
负责 nn.Linear
和 nn.Embedding
的权重初始化。额外增加了 _setup_cache()
和 _reset_cache()
两个方法,处理缓存 past_key_value
。
GemmaModel
中有个 self.layers
对象,由 18 层的 GemmaDecoderLayer
构成。
总之,
- 继承关系:
GemmaForCausalLM
->GemmaModel
->GemmaPreTrainedModel
->PreTrainedModel
- 自注意力的 18 层所在位置:
GemmaModel
实例内
GemmaForCausalLM
步骤
PreTrainedModel
类继承了 GenerationMixin
,使得 GemmaForCausalLM
拥有 generate()
方法。用 model.generate()
进行文本生成,这就进入了 GemmaForCausalLM
步骤。
在进入该步骤之前,tokenizer 使用左填充的方式将多序列填充到相同长度。
首先使用 self.model
(GemmaModel
)让 input_ids
经过一系列的自注意力机制,输出 hidden_states
。
使用 self.lm_head
(nn.Linear
)映射到 256000 维度,得到 logits
。转换到 float32。
GemmaModel
步骤
这是 Gemma 的核心。
经过 Embedding
,转换为嵌入后,乘上 hidden_size**0.5
进行标准化。
接下来是 18 层 GemmaDecoderLayer
。
- 存储一个
residual
,开始自注意力机制self.input_layernorm
(GemmaRMSNorm
),首先 \(x·\frac{1}{\sqrt{x^2+\epsilon}}\) 对每个 embedding 归一化,然后使用可训练权重做乘法。全程在 float32 下运算self.self_attn
(GemmaSdpaAttention
),自注意力。是对F.scaled_dot_product_attention()
的封装。使用 MultiQueryAttention,完成注意力后进行线性变换- 使用
residual
进行残差连接
- 存储一个
residual
,开始全连接层self.post_attention_layernorm
(GemmaRMSNorm
),再次归一化self.mlp
(GemmaMLP
),多层感知器。从 2048 维到 16348 维映射出 \(x_1\) 和 \(x_2\),进行 \(\text{gelu}(x_1)·x_2\),再映射回 2048 维- 使用
residual
进行残差连接
- 重复 18 次
最后再经过 self.norm
(GemmaRMSNorm
)。
各步骤中值得注意的地方
GemmaDecoderLayer
之前
attention_mask
,避免对填充标记执行注意力操作。多用于 batch_size 大于 1、对输入序列进行 padding 的情况,避免模型对 padding 施加 attention(因为无意义)。
causal_mask
是用 _update_causal_mask()
从 attention_mask
转换而得。具体来说是从 [0,1]
转换为 [0, -inf]
(通过 torch.finfo(dtype).min
获得负无穷)。
causal_mask
与attention_mask
在代码中的角色很混乱。函数嵌套过程中两者会相互转换。
在 _update_causal_mask()
中,还会使用 AttentionMaskConverter._unmask_unattended()
取消填充部分的 mask,以适应 SDPA 的节省内存注意力方法。
GemmaSdpaAttention
之中
使用到了 past_key_value
(transformers.DynamicCache
类型) 存储每一层的 kv 中间结果,总共存储 18 对 key_states
和 value_states
。这是 KV Cache 机制,能够显著提升速度。
经过调试可得知,从生成第二个新词开始,输入到模型的 hidden_states
就只有一个 token 的长度了。多亏了 KV Cache,不需要额外计算前面已有词的 Key 和 Value。
本节参考:
- Young,“大模型推理性能优化之KV Cache解读”,https://zhuanlan.zhihu.com/p/630832593
本文参考
谷歌的 Gemma 开源模型和代码,以及 HuggingFace 的 Transformers。
标签:层面,模型,Gemma,attention,mask,self,GemmaForCausalLM,GemmaModel,代码 From: https://www.cnblogs.com/chirp/p/18153877