首页 > 编程语言 >LLM大模型: MOE原理和源码解析

LLM大模型: MOE原理和源码解析

时间:2024-06-15 15:58:17浏览次数:24  
标签:current expert self states 源码 export LLM hidden MOE

  1、古人云:闻道有先后,术业有专攻!每个人的能力范围是有限的,不可能360行,行行都精通!所以搞研究都会选一个细分领域深耕,争取在这个领域做到世界top级别的泰斗!一个团队,内部也都是在各个领域擅长的人组成,比如前端、ui、后端、算法、运维、运营等,大家互相配合,完成既定目标!本人多年前做传统的数据挖掘和机器学习,最常用的就是随机森林random forrest了:树模型不需要事先做归一化预处理,模型本身根据信息增益选择合适的特征分裂;单颗树可能判断错,那就用多棵树一起判断,找到判断结果最多的那个,正确的概率就很大了!说了这么多,想表达的就一个意思:群策群力!如果目标过于复杂,单个个体已经无法达到既定目标,那就把目标拆解,不同的细分目标让不同的专业人士去做,大家群策群力,这就是常说的:专业的事让专业的人去干!截至目前,这个道理同样也适用于大模型: 用户的需求多种多样,单一的大模型很难完全满足客户需求了,那就把单个大模型拆分成多个“小模型”,每个小模型都只用各个细分领域的数据训练,专门用于回答用户在细分领域的问题,这就是所谓的Mixture-of-Experts!google的论文(https://icml.cc/media/icml-2022/Slides/17378.pdf)中有效果对比,如下:同样都是64B参数,分成64个export,每个export只有1B的参数,这样做的效果比GPT3都还要好

       

   这个也可以从我之前做的代码相似度检测的效果来印证:https://www.cnblogs.com/theseventhson/p/18211242  这个GraphCodeBERT是基于bert用代码语料训练的,参数也就1.2亿个,保存模型的bin文件不到500M,是标准的小模型!但是这个小模型使用的数据全是代码,并且代码还提取了AST/DFG作为特征,用于判断两个函数是否语义相似效果非常好!所以模型效果好不好,和大小没太大关系,主要还是训练语料和输入特征是否高质,模型没必要盲目做大!MOE的架构如下:

       

       核心在于每个transformer block的MLP层:之前只有一个神经网络,一般是先升维再降维;现在是把一个大的神经网络拆分成多个小的FFN,多个小的FFN前面有个Gating,用来判断输入数据从那个FFN继续推进(本质就是个路由器,选择合适的分发路径)!和传统的稠密dense model比,MOE这种稀疏sparse model的优势:

  • 推理时只有一小部分的export被激活用于计算,而不是整个网络,节约算力!
  • 每个export各自专注于特定的任务或数据类型,MoE 模型能够更好地处理复杂和多样化的数据
  • 增加export就能扩展模型容量(看着是不是像Lora?在原有线性层的旁边再增加一个旁路),处理新领域的问题和数据,泛化能力比dense model好!

   2、(1)MOE架构也已经实现了,在transformer包的transformers-main\src\transformers\models\mixtral\modeling_mixtral.py这个文件里面。整体的代码结构如下:新增了几个MOE相关的类,其余的结构和llama几乎一样。

            

            既然是用于推理,必然存在于decoder端(encoder端主要用于提取特征向量,没必要用MOE架构),在forward函数中的fully connect模块,attention和norm之后就是MOE啦,如下:

           

      所谓的export:就是个3层的神经网络:

           

    特别说明一下MixtralBlockSparseTop2MLP这里的forward函数:

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

  同一个hidden_states,经过w1线性转换后激活,然后和w3线性转换后相乘,再通过w2做线性转换,为啥要这么干?

  •  self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) :核心还是特征的非线性组合,目的是为了更好地生成非线性特征,举例如下:

        

  • current_hidden_states = self.w2(current_hidden_states)   再次通过线性变换进入下一个空间,后续所有的操作都在新空间进行,不会和现有空间的操作互相影响

      (2)选择export的forward函数整个流程:

"""将输入数据通过多个export进行处理,并根据动态计算的路由权重将不同输入分配给不同的export"""
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)#对输入数据应用抖动噪声(jitter noise),增加模型的鲁棒性
        hidden_states = hidden_states.view(-1, hidden_dim)#三维变为二维,方便后续处理
        # router_logits: (batch * sequence_length, n_experts)
        # 通过gate计算路由权重得分routing_weights,选择export
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        # 选择概率最高的 k 个export
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # 归一化权重
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)
        
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

  

参考:

1、https://arxiv.org/pdf/2405.11273

2、https://ar5iv.labs.arxiv.org/html/2402.07871

3、https://icml.cc/media/icml-2022/Slides/17378.pdf    GLaM: Efficient Scaling of Language Models with Mixture-of-Experts

4、https://www.bilibili.com/video/BV1jH4y177DL/?spm_id_from=333.788.recommend_more_video.0&vd_source=241a5bcb1c13e6828e519dd1f78f35b2

5、https://www.bilibili.com/video/BV1Xu4y1K7zn/?spm_id_from=333.788.recommend_more_video.2&vd_source=241a5bcb1c13e6828e519dd1f78f35b2   MOE源码

6、https://www.bilibili.com/video/BV1cy421z7er/?spm_id_from=333.788.recommend_more_video.0&vd_source=241a5bcb1c13e6828e519dd1f78f35b2  

标签:current,expert,self,states,源码,export,LLM,hidden,MOE
From: https://www.cnblogs.com/theseventhson/p/18247463

相关文章

  • ReentrantReadWriteLock:深度解析与源码探险
    1.概述ReentrantReadWriteLock是Java并发包java.util.concurrent.locks中的一个重要类,它提供了可重入的读写锁功能。与传统的互斥锁(如synchronized或ReentrantLock)不同,ReentrantReadWriteLock允许多个线程同时读取共享资源,但在写入时则要求独占锁。这种设计显著提高了在读......
  • ReentrantLock的非公平锁(NonfairSync)深度解析:源码之旅与实战策略
    1.引言在Java并发编程中,ReentrantLock作为一种可重入的互斥锁,提供了比synchronized更强大和灵活的功能。其中,NonfairSync作为ReentrantLock内部非公平锁的实现,其设计理念和源码实现都体现了对性能和公平性的权衡。2.NonfairSync概述非公平锁特性:新到达的线程在......
  • 【计算机毕业设计】基于springboot的大创管理系统【源码+lw+部署文档】
    包含论文源码的压缩包较大,请私信或者加我的绿色小软件获取免责声明:资料部分来源于合法的互联网渠道收集和整理,部分自己学习积累成果,供大家学习参考与交流。收取的费用仅用于收集和整理资料耗费时间的酬劳。本人尊重原创作者或出版方,资料版权归原作者或出版方所有,本人不对所......
  • Translation Agent 源码分析
    吴恩达老师开源了一套AIAgent翻译工作流TranslationAgent。https://github.com/andrewyng/translation-agent/工作流主要分三个步骤:通过指定大语言模型(LLM)进行语言之间的翻译;对翻译结果进行反思,并提出改进建议;再根据这些建议进行优化翻译。很多AI工作都可以用这样的......
  • 球面双站定位c++源码及原理介绍(已知2点经纬高及看向目标的方位、俯仰,求目标的经
    球面双站定位是一个空间几何问题,它用于在给定两个已知站点的经纬度和他们向特定目标看去的方位和俯仰角的情况下,计算目标的经纬度。这个问题可以通过解一个线性方程组来求解。假设两个站点分别是A和B,他们分别看向目标的方位分别是θAθA​和θBθB​,俯仰角分别是ϕAϕA​和ϕBϕB......
  • 开源模型应用落地-Qwen2-7B-Instruct与vllm实现推理加速的正确姿势(十)
    一、前言  目前,大语言模型已升级至Qwen2版本。无论是语言模型还是多模态模型,均在大规模多语言和多模态数据上进行预训练,并通过高质量数据进行后期微调以贴近人类偏好。在本篇学习中,将集成vllm实现模型推理加速,现在,我们赶紧跟上技术发展的脚步,去体验一下新版本模型的推理质......
  • 基于Java+SpringBoot+Vue前后端分离宠物管理系统(源码+万字LW+PPT+部署教程)
    博主介绍:✌全网粉丝10W+csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌一、作品包含源码+数据库+设计文档LW+PPT+全套环境和工具资源+部署教程二、项目......
  • 最新支持ChatGPT3.5/GPT4.0网站源码,AI系统源码,ChatGPT运营网站系统,支持GPTs应用、AI绘
    一、文章前言SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型+国内AI全模型。支持GPT-4o大模型、文档分析、识图图片理解、GPTs应用、GPT语音对话、联网提问、GPT-4全模型、DALL-E3文生图、GPT4-All联网搜索模型、思维导图、......
  • 美食天下 网页设计 html源码 大作业
    ......
  • AI大佬吴恩达+OpenAI团队编写:面向大模型入门者的 LLM CookBook 汉化版
    粉丝们久等了!!!我又来更LLM大模型的必备读物啦!这次给大家推荐的是AI圈无人不知的吴恩达大佬+OpenAI团队一起编写的大模型入门文档,也就是这本:大型语言模型(LLM)的权威文档<面向开发者的LLM入门PDF>在Github上已经高达56.8kstar了,这含金量啧啧啧朋友们如果有需要这份《LLMC......