首页 > 编程语言 >llama的rope源码阅读

llama的rope源码阅读

时间:2023-12-21 17:33:10浏览次数:38  
标签:dim unsqueeze torch ids rope 源码 llama position hidden

关键代码的理解:

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size)) # 是需要学习的参数
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype) # 再变回原始分布, 所以这些norm算法,训练完之后的推理是不影响整体的,只是可以让训练bp传导时候加速模型收敛!所以非常实用.
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin) # 最核心的代码: q: 30,128  rotate_half(q): 30, 128   #这个对比二维公式,非常显然,.
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

二维公式.

image

到此源码解读完毕.

标签:dim,unsqueeze,torch,ids,rope,源码,llama,position,hidden
From: https://www.cnblogs.com/zhangbo2008/p/17919474.html

相关文章

  • 开发医疗陪诊系统源码:搭建安全高效的医患互动平台
    本文将深入探讨开发医疗陪诊系统的源码,以及如何搭建一个安全高效的医患互动平台。一、引言医疗陪诊系统旨在通过技术手段,缩短患者与医生之间的距离,提供更快速、便捷的医疗服务。二、技术选型2.1前端技术在搭建医疗陪诊系统的前端时,我们可以选择使用现代化的前端框架,如Vue.js或React......
  • HydroOJ 从入门到入土(9)源码简易修改记录——卍解!
    随着OJ的使用越来越深入,本强迫症总会觉得一些细节有时候不那么符合自己的习惯,但是想改又无处下手,最终还是走上了修改源码的邪路.目录0.重要1.超级管理员查看自测代码2.超级管理员隐身查看比赛/作业题目3.超级管理员隐身查看比赛题目列表4.关掉客观题的多选题部......
  • centos7上源码安装postgresql 13.6
    1环境描述操作系统:Centos7.6postgresql:13.6安装方式:源码安装2创建用户#groupadd-g2000pgsql#useradd-u2000-gpgsqlpgsql3目录规划#mkdir-p/postgresql/{pgdata,archive,scripts,backup,pg13,soft,pg_log}#chown-Rpgsql:pgsql/postgresql#......
  • 智慧医疗APP开发指南:医疗陪诊系统源码实战
    本文将深入探讨智慧医疗APP开发中的医疗陪诊系统,重点介绍实际应用中的源码实战经验。 一、技术选型在开发医疗陪诊系统时,合理的技术选型是确保系统高效运行的关键。可以选择采用前后端分离的架构,使用Vue.js作为前端框架,SpringBoot作为后端框架,通过RESTfulAPI进行通信。二、实战经......
  • 记录 | ubuntu源码编译安装/更新boost版本
    一、卸载当前的版本1、查看当前安装的boost版本dpkg-S/usr/include/boost/version.hpp通过上面的命令,你就可以发现boost的版本了,查看结果可能如下:libboost1.54-dev:/usr/include/boost/version.hpp2、删除当前安装的boostsudoapt-getautoremovelibboost1.54-dev这样就可以删......
  • 记录 | ubuntu源码编译安装faiss
    ubuntu源码编译安装faiss#安装依赖aptupdateaptinstallbuild-essentiallibopenblas-devliblapack-devlibopencv-dev#clonegitclonehttps://github.com/facebookresearch/faiss.gitcdfaiss./configuremake-j32makeinstall使用示例:#include<faiss/IndexF......
  • 记录 | ubuntu源码编译ccls
    ubuntu源码编译ccls#clone代码gitclone--depth=1--recursivehttps://github.com/MaskRay/ccls#安装libclang-15sudoapt-getinstallclanglibclang-15-dev#编译cmake-H.-BRelease-DCMAKE_BUILD_TYPE=Release\-DCMAKE_PREFIX_PATH=/usr/lib/llvm-15\......
  • Spring Boot原理分析 | SpringApplication、Yaml、Properties
    ......
  • Java序列化和反序列化 Serializable BeanUtils.copyProperties赋值属性方法
    Java序列化和反序列化SerializableBeanUtils.copyProperties赋值属性方法packagecom.example.core.mydemo.java;importcom.example.core.mydemo.json2.GsonUtils;importorg.springframework.beans.BeanUtils;importjava.io.*;/***Java序列化和反序列化Serializ......
  • NS-3源码学习(七)追踪和Probe
    追踪框架和WiFi的STA接入AP时使用的ProbeRequest帧、ProbeResponse帧没有关系。追踪NS-3的追踪框架主要用于追踪一个对象当中某个属性的变更、或者某个事件的发生。NS-3初始规定了一些追踪源,一般在model的GetTypeId()方法中定义了这些追踪源(和这个model的属性),我们可以使用两种......