首页 > 编程语言 >LLM大模型: llama源码要点解读(二)

LLM大模型: llama源码要点解读(二)

时间:2024-06-12 23:10:22浏览次数:12  
标签:key LLM self attention value states 源码 llama query

  1、attention机制:这算是transformer架构最大的创新点了!利用attention机制,找到token之间的相似度(或则说距离),根据相似度调整token本身的embedding值,本质就是根据token的context调整自身的embedding值这个思路非常符合人脑对语言和语义的理解!比如”苹果“这个词,如果只看这一个token,没有任何context,根本无法分辨是水果还是手机,此时该token的embedding大概率是没法用的!但是有了context就不一样了,比如:

  • 苹果、香蕉、梨、菠萝这些水果中,我最喜欢的是菠萝!  很明显,这里的苹果是水果,V向量中水果相关的维度值会比较大,其他维度值会较小!
  • 苹果的内存128G,屏幕耐磨不易坏;电池续航8小时,上班时间不用充电!   很明显,这里的苹果是电子产品,V向量中电子产品相关的维度值会比较大,其他维度值会较小!

       具体做法也不复杂:Q和K相乘得到权重值,用权重值乘以V向量(V向量才是token最终的embedding值)!这里打个岔:目前大模型微调最主流的就是Lora了,思路是用low rank矩阵旁路主矩阵,通过旁路矩阵吸收和承载新数据的信息!理论上讲:transformer架构中只要有矩阵的地方,都可以旁路low rank矩阵;但如果真这么做,计算量会很大(算力足够的土豪当然可以每个原矩阵都旁路啦)!为了保证效果,可以优先考虑Q和K矩阵,通过这两个矩阵提取新微调数据的context信息,然后调整token原有的V值

       在llama的attention代码中哟三个大:LlamaAttention、LlamaFlashAttention2、LlamaSdpaAttention,LlamaAttention是基础类,实现了transformer的attention机制,另外两个是基于LlamaAttention做了改进。、

     (1)FlashAttention2:参考官网:https://github.com/Dao-AILab/flash-attention  ; 从名字就能看出来主要改进点是flash(Fast and Memory-Efficient Exact Attention with IO-Awareness)! 核心的改进代码如下:

  •    _upad_input 方法:对输入 tensor 进行 unpad 操作来去除填充 token,减少attention计算时内积的计算量,并生成相应的索引数组和 cumulated sequence lengths;
     #去掉padding的token,节约attention时的算力和时间
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
        #使用索引操作 index_first_axis 来避免不必要的内存复制,减少IO
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        ) # 重新排列,允许对输入进行分块处理,减少内存占用
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )# 重新排列,允许对输入进行分块处理,减少内存占用
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )# 重新排列,允许对输入进行分块处理,减少内存占用
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
  • _flash_attention_forward:核心还是调用了flash_attn包的flash_attn_varlen_func方法:
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """   先去掉padding 的token再计算attention,然后再pad回去还原
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( #先去掉padding
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
            #自动处理 softmax 的计算,并在计算过程中忽略不必要的填充 token
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )
            #attention计算完成后,重新pad
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

  flash_attn_varlen_func核心的优化思路:

  • 索引和序列长度累加器:使用索引数组和累加的序列长度数组来表示变长序列。这些数组允许函数仅对有效的非填充(non-padded)部分进行计算,从而避免了填充 token 的冗余计算
  • 逐步计算(Chunk-Based Computation):将输入序列分割成较小的块,每次只对一个块进行计算。这种方法不仅减少了内存使用,还能更好地利用缓存,提高计算效率
  • 并行计算:充分利用 GPU 的并行计算能力,通过优化的内核实现并行执行矩阵乘法和 softmax 操作。这些优化内核能够最大限度地利用 GPU 的计算资源,减少计算时间
  • 优化的 softmax 计算:通过自定义的内核实现高效的 softmax 计算,避免了数值不稳定性(减去向量的最大值避免underflow和overflow),并且能更快地完成 softmax 操作。

   2、5个model,从名字就能猜出来这5个model的作用:

        

        既然功能不同,代码肯定不同,每种model的核心代码如下:

  • LlamaModel:实现了基本的transformer架构,但并不包括实现特定功能的lm_head!
    class LlamaModel(LlamaPreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.transformer = Transformer(config)
            self.init_weights()
    
        def forward(self, input_ids, attention_mask=None):
            outputs = self.transformer(input_ids, attention_mask=attention_mask)
            return outputs
  • LlamaForCausalLM:使用的自回归auto regression,下一个token就是target,采用cross enctropy的loss计算;
class LlamaForCausalLM(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.llama = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        transformer_outputs = self.llama(input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs[0]
        logits = self.lm_head(hidden_states)
        
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        
        return (loss, logits) if loss is not None else logits
  • LlamaForSequenceClassification:用于序列级别分类任务。
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.llama = LlamaModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        transformer_outputs = self.llama(input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs[0]
        logits = self.classifier(hidden_states[:, 0, :])  # Use the CLS token hidden state
        
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        
        return (loss, logits) if loss is not None else logits
  • LlamaForQuestionAnswering:用于问答任务。添加了一个问答头,包括两个线性层,用于预测答案的起始位置和结束位置。
class LlamaForQuestionAnswering(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.llama = LlamaModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None):
        transformer_outputs = self.llama(input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs[0]
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        
        loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2
        
        return (loss, start_logits, end_logits) if loss is not None else (start_logits, end_logits)
  • LlamaForTokenClassification:用于标记级别分类任务。加了一个标记分类头,用于每个输入标记生成类别标签
    class LlamaForTokenClassification(LlamaPreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels
            self.llama = LlamaModel(config)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            self.init_weights()
    
        def forward(self, input_ids, attention_mask=None, labels=None):
            transformer_outputs = self.llama(input_ids, attention_mask=attention_mask)
            hidden_states = transformer_outputs[0]
            logits = self.classifier(hidden_states)
            
            loss = None
            if labels is not None:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            
            return (loss, logits) if loss is not None else logits

    以上各种model的作用不同,但实现的原理没有本质区别:先把hidden_states通过矩阵相乘的方式做线性变换,映射到新的空间,然后再在新的空间做各种操作

   3、旋转位置编码 

       (1)NLP任务中,token的位置是非常重要的,比如”我打小明“和”小明打我“,这两句话的token完全相同,但整体的语义截然相反!传统的RNN和LSTM因为是顺序处理每个词,所以能很好地理解词的位置,但transformer架构就不一样了:为了提升效率,底层计算的时候是并行计算的,所以embedding需要包含位置信息,避免语义出错!此外,token位置信息应该是相对位置信息,不能用绝对位置信息,同样举例:”我打小明“  和  ”昨天下午在公司,我把小明打了一顿“ ,这两句话的语义是一样的,主题都是我打小明,但第二句话中关键token的绝对位置信息完全不同,如果embedding中用绝对位置信息,可能会影响最终的语义理解,需要用相对位置信息!最终的理想结果应该是:我 打 小明  这个三个token的qk内积在第一个短句的结果,和在第二个长句的结果应该是接近甚至一样的,就说明绝对位置不影响整体的语义信息啦(这也是个取其精华、去其糟粕的过程)!所以终极问题来了:这个相对位置编码应该怎么求?

       (2)详细的推导过程参考:https://wmathor.com/index.php/archives/1542/   直接上结论,以二维为例,如下:

         

        其中m就是token在sequence中的绝对位置,seta是频率超参,计算时人为指定的,不是学习得来!位置m的token和位置n的token的q、k相乘就变成了:

        

       R是旋转矩阵。原始的q和k内积之前,先通过旋转矩阵换个位置,分别把自己的绝对位置m、n信息融入,结果中的R(m-n)不就把两个token的相对位置信息包括进去了么

  上面的embedding是2位的,但实际embedding肯定不止二维,咋办了?因为内积满足线性叠加性,因此任意偶数维的 RoPE,我们都可以表示为二维情形的拼接,如下:

        

   每两个维度分成一组,每组单独计算,q的每个维度都能包含位置信息啦!计算时只剩最后1个问题了:seta是超参数,不是学习的来的,而是计算时人为指定的,那么这个seta该怎么设置才合理?

      (3)”我打小明“、”昨天下午在公司,我把小明打了一顿“、”我昨天下午在公司把小明狠狠地按在地上打了一顿“,这3句话的语义是一样的,都是我打了小明,核心token就3个:我  小明   打;理论上讲,3个token的q和k应该比较接近才合理!但实际情况是:这3个token在3句话中位置(不论绝对位置,还是相对位置)差异都较大,怎么能在这3个句子中都能正确地反映相对位置依赖信息了?

  以 llama为例,假设模型的隐藏层维度是d=1024,那么可以选择的 seta 值会覆盖从较低频率到较高频率的范围,如下:

     

     这意味着在 1024 维空间中,会有 512 对不同频率的正弦和余弦函数,分别应用在不同的维度对上;具体的数值列举如下:

    

      有没有发现一个规律啊:seta的值随着维度增加越来越小了!seta值越小,频率越低,波长越长,正余弦在较长的序列范围内变化较慢,在整个序列中可能只有几次完整的波动,这种缓慢的变化允许长波长函数能够平滑地跨越整个序列,捕捉到远距离的位置关系,保持一定的稳定性。因此,低频(长波长)的编码在捕捉到句子整体的结构和主旨上比较有效;反之:seta值越大,频率越高,波长越短,正弦余弦函数在整个序列中完成多次波动,能更好地提取局部语义信息

      上面的描述可能有点抽象,这里举个具体的例子说明:假设我们有一个长度为 L=100 的序列,我们选取低频和高频维度的正弦函数来进行位置编码。

      参数设定

  • 序列长度 L=100
  • 向量维度 d=128
  • 低频维度 i1=0
  • 高频维度 i2​=64

    

      

       

  • 低频编码(例如 i1=0)在整个序列中变化缓慢,只有几次完整波动。比如从位置 m=0到 m=99,正弦值从0变化到接近=0.9。这种变化速度使得编码在远距离位置上仍然保持较高的相关性。例如,位置 m=0 和 m=50 的编码值分别为0和-0.262,尽管不完全相同,但它们的变化较慢,仍然保持一定的相关性。因此,低频编码能够捕捉到整个序列范围内的远距离依赖关系,例如一个句子的开始和结束之间的关系。
  • 反之,高频编码从位置 m=0 到 m=99,正弦值从0变化到接近0.0099。这样的编码变化速度非常快,在短距离内编码值就会大幅度变化

     至于实现:llama在attention之前就要计算旋转位置编码,然后才计算attention的值!

class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) #LlamaAttention的forward中在计算attention之前先计算旋转位置编码

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

 

     

 

       

  

 

标签:key,LLM,self,attention,value,states,源码,llama,query
From: https://www.cnblogs.com/theseventhson/p/18241641

相关文章