首页 > 其他分享 >大语言模型 MOE 简明实现指南

大语言模型 MOE 简明实现指南

时间:2024-06-21 16:02:51浏览次数:10  
标签:指南 torch 简明 专家 topk 向量 exp hid MOE

这篇文章中,我简要实现一下大语言模型的 MOE 模块。MOE 模块位于每个GPT层中,位于注意力模块的后面,每个MOE模块包含若干个MLP模块作为专家。这些专家是稀疏的,也就是每次选择部分来调用,并不会调用全部,从而节省宝贵的算力。

首先定义一些常量,通常应该在模型配置文件里面。

bs = 5 # 批量大小
seql = 32 # 序列长度
hid = 128 # 隐藏向量维度
nexp = 5 # 专家总数
topk # 所选的专家数量

模块的输入应该是句子中单词的隐藏向量。为了便于测试我直接取了随机数,正常情况下应该是有意义的值。首先需要转换成二维的,便于计算。

x = torch.randn([bs, seql, hid])
x = x.reshape([-1, hid])
x.shape
# torch.Size([160, 128])

然后我们需要一个门(定义在__init__里面,将每个隐藏向量转换成专家得分,进一步经过 softmax 转换成归一化的得分,表示每个专家对这个向量的结果有多大贡献。注意这里我们为每个向量单独分配专家,可能向量#1分配到了专家#1和#2,而向量#2分配到了专家#3和#4,总之可能是不一样的。

gate = torch.nn.Linear(hid, nexp)
exp_logits = gate(x)
exp_probs = torch.softmax(exp_logits, -1)
exp_probs.shape
# torch.Size([160, 5])

每个专家应该是 MLP(定义在__init__里面),但是为了演示我就直接省略了,大家可以从各个大语言模型的源码里面复制粘贴:

experts = [lambda x: x for _ in range(nexp)]

对每个向量分配到的专家按照贡献度排序,得到每个向量地专家排名exp_topk及其得分sc_topk

exp_topk[i, j]表示第i个词的第j个专家的序号,sc_topk[i, j]表示它的得分。

sc_topk, exp_topk = torch.topk(exp_probs, topk, -1)
sc_topk.shape
# torch.Size([160, 2])
exp_topk.shape
# torch.Size([160, 2])

将专家的得分归一化,因为我们选了两个,总和又不是一了,会对结果的大小有影响:

sc_topk /= sc_topk.sum(-1, keepdim=True)

下面我们创建该层的结果数组,累加每个专家的输出,大小和输入一样:

final_hidden_state = torch.zeros_like(x)

然后我们获取每个专家对应的单词序号,和对应的单词排名。exp_topk == exp_i把等于专家exp_i的位置标注为True其它的为False,然后where获取下标。

hid_idcs是调用专家exp_i的向量序号,hid_ranks是该专家对于对应向量的排名

for exp_i in range(nexp):
    hid_idcs, hid_ranks = torch.where(exp_topk == exp_i)

注意每个专家被调用的次数都可能不一样:

[torch.where(exp_topk == exp_i) for exp_i in range(nexp)]
'''
[tensor([  0,   1,   2,   3,  14,  16,  18,  21,  22,  30,  32,  39,  43,  44,
          45,  52,  55,  58,  66,  67,  72,  77,  78,  80,  83,  87,  89,  90,
          91,  93, 102, 103, 105, 107, 108, 115, 116, 117, 126, 131, 133, 134,
         135, 136, 137, 146, 147, 148, 149, 151, 157, 158]),
 tensor([  6,   8,   9,  11,  18,  19,  20,  23,  26,  27,  28,  31,  34,  35,
          37,  41,  47,  50,  51,  53,  54,  56,  57,  59,  60,  62,  63,  71,
          74,  75,  77,  78,  79,  82,  83,  84,  86,  93,  97,  98, 100, 107,
         109, 110, 111, 113, 114, 118, 120, 123, 124, 126, 127, 128, 129, 130,
         139, 140, 143, 144, 145, 150, 155, 159]),
 tensor([  0,   4,   7,   8,  10,  12,  13,  14,  16,  17,  24,  25,  26,  29,
          32,  33,  34,  36,  40,  41,  46,  47,  49,  50,  53,  58,  64,  65,
          68,  70,  72,  73,  76,  81,  82,  85,  88,  89,  92,  94, 101, 103,
         108, 109, 112, 114, 115, 119, 120, 121, 123, 125, 132, 133, 135, 138,
         139, 140, 141, 142, 145, 146, 147, 150, 152, 153, 155, 156, 158]),
 tensor([  1,   5,   6,   7,   9,  11,  12,  13,  15,  20,  22,  23,  28,  29,
          30,  31,  35,  37,  38,  40,  42,  46,  48,  54,  55,  56,  57,  60,
          61,  62,  64,  65,  67,  69,  70,  71,  73,  74,  79,  80,  81,  84,
          86,  95,  96,  98,  99, 102, 104, 106, 110, 111, 113, 116, 118, 119,
         122, 125, 128, 129, 132, 134, 138, 144, 153, 154, 157, 159]),
 tensor([  2,   3,   4,   5,  10,  15,  17,  19,  21,  24,  25,  27,  33,  36,
          38,  39,  42,  43,  44,  45,  48,  49,  51,  52,  59,  61,  63,  66,
          68,  69,  75,  76,  85,  87,  88,  90,  91,  92,  94,  95,  96,  97,
          99, 100, 101, 104, 105, 106, 112, 117, 121, 122, 124, 127, 130, 131,
         136, 137, 141, 142, 143, 148, 149, 151, 152, 154, 156])]
'''

然后我们把每个专家的向量获取到(x[hid_idcs]),传入该专家experts[exp_i](...)

# for ...
    hidden_state = experts[exp_i](x[hid_idcs])
    hidden_state.shape
    # torch.Size([52, 128])

然后需要乘上专家权重,最后加一维以便权重和上面的向量对齐:

# for ...
    weights = sc_topk[hid_idcs, hid_ranks].unsqueeze(-1)
    weights.shape
    # torch.Size([52, 1])
    hidden_state *= weights

然后将当前专家的输出填回到结果数组中:

# for ...
    final_hidden_state[hid_idcs] += hidden_state

每个专家都计算完之后,将结果数组变形成原始的形状,然后作为整个模块的输出:

final_hidden_state = final_hidden_state.reshape([bs, seql, hid])

标签:指南,torch,简明,专家,topk,向量,exp,hid,MOE
From: https://www.cnblogs.com/apachecn/p/18260686

相关文章

  • Windows系统上更换pip源的详细指南
    Python的包管理工具pip允许用户从Python包索引(PyPI)下载和安装第三方库。然而,默认的PyPI源有时可能因为网络问题或地理位置导致访问速度较慢。更换为更快的源可以显著提高下载和安装Python包的速度。本文将详细介绍如何在Windows系统上更换pip的源。1.理解pip源的重要性......
  • Selenium - 入门指南
    入门指南如果你是Selenium的新手,我们有一些资源帮助你快速入门.Selenium通过使用 WebDriver 支持市场上所有主流浏览器的自动化。Webdriver是一个API和协议,它定义了一个语言中立的接口,用于控制web浏览器的行为。每个浏览器都有一个特定的WebDriver实现,称为驱动程......
  • 「Java开发指南」如何使用Spring注释器实现Spring控制器?(二)
    本教程将引导您使用SpringAnnotator实现Spring控制器,标准Java类被添加到搭建项目中,SpringAnnotatorSpring启用Java类。虽然本教程的重点是Spring控制器,但是SpringAnnotator也可以用于Spring服务、组件和存储库。在本教程中,您将学习如何:创建一个Java类将类配置为Spring控制......
  • 金仓数据库全攻略:简化部署,优化管理的全流程指南
    金仓数据库人大金仓(KINGBASE)是一家拥有20多年数据库领域经验的公司,专注于数据库产品的研发和服务。公司曾参与多项国家级重大课题研究,如"863"计划、电子发展基金、信息安全专项等。其核心产品是金仓数据库管理系统KingbaseES,这是一个大型通用数据库,具有国际先进水平。金仓数据......
  • 【Emacs Verilog mode保姆级的使用指南】
    ......
  • 【Python日志模块全面指南】:记录每一行代码的呼吸,掌握应用程序的脉搏
    文章目录......
  • 【原创】EtherCAT主站IgH解析(二)-- Linux/Windows/RTOS等多操作系统IgH EtherCAT主站
    版权声明:本文为本文为博主原创文章,转载请注明出处。如有问题,欢迎指正。博客地址:https://www.cnblogs.com/wsg1100/前言目前,EtherCAT商用主站有:Acontis、TwinCAT3、KPA、Codesys等,开源EtherCAT主站则主要有两大方案:igh与SOEM,两者设计天差地别,SOEM开源于2008年底1.1.2版本,具备良好......
  • 在SQL中使用explode函数展开数组的详细指南
    目录简介示例1:简单数组展开示例2:展开嵌套数组示例3:与其他函数结合使用处理结构体数组示例:展开包含结构体的数组示例2:展开嵌套结构体数组总结简介在处理SQL中的数组数据时,explode函数非常有用。它可以将数组中的每个元素单独提取出来,便于进一步处理。本文将通过几......
  • AI绘画工具进阶指南
    AI绘画工具进阶指南目录引言高级AI绘画工具概述进阶功能及技术风格迁移的高级应用生成对抗网络(GAN)文本到图像生成进阶使用教程DeepArt高级使用教程DeepDream高级使用教程Artbreeder高级使用教程DALL·E高级使用教程结合AI绘画工具进行创作结论引言在掌握了基础的AI......
  • Kotlin 变量详解:声明、赋值与最佳实践指南
    Kotlin变量变量是用于存储数据值的容器。要创建一个变量,使用var或val,然后使用等号(=)给它赋值:语法var变量名=值val变量名=值示例varname="John"valbirthyear=1975println(name)//打印name的值println(birthyear)//打印birthyear的......