LLaMA系列用的FFN层现在是SwishGLU,这里Swish是个激活函数,GLU是个线性单元,二者合起来是SwishGLU。
FFN
Transformer中原始的FFN长这样:
\[FFN(x) = ReLU(xW_1+b_1)W_2+b_2 \]两个线性层中间夹了个relu激活函数
写成模型代码就是:
x = up_proj(x)
x = relu(x)
x = down_proj(x)
或者写成一行:
x = down_proj(relu(up_proj(x)))
Swish
是一个激活函数
原论文是:Searching for Activation Functions
其中\(\sigma(x)=\frac{1}{1+e^{-x}}\),就是sigmoid函数,\(\beta\)是超参数,或者可训练参数。
写成代码就是:
y = x * sigmoid(x)
GLU(Gated Linear Units)
门控线性单元
原论文是:Language Modeling with Gated Convolutional Networks
其中\(\odot\)是矩阵element-wise乘积,\(\sigma\)在这里是“任意”激活函数。(虽然原论文说是sigmoid,但是必须认为是任意,不然下面没法解释了)
这里应该如此理解:对线性单元施加了一个门控。左边的\(xW+b\)是原本的线性单元,给他加入带门控的线性单元信息:\(\sigma(xV+c)\)。
写成代码就是:
x = up_proj(x) * act(gate_proj(x))
SwishGLU
其实SwishGLU和GLU没区别,也长这样,只要把激活函数换成swish用的sigmoid即可。
吐槽一句,GLU原论文中本来就说了激活函数是sigmoid,Swish原文说的激活函数也是sigmoid,这样GLU本来就是SwishGLU,二者没什么区别,但是大家的讨论似乎都歪了,GLU反而成为了符合上述形式的,任意激活函数的一个模块的统称
写成代码就是(只替换了sigmoid):
x = up_proj(x) * sigmoid(gate_proj(x))
LLaMA MLP
回顾一下transformer的FFN,只要将down_proj
里面的部分替换为SwishGLU
down_proj(F.sigmoid(gate_proj(x)) * up_proj(x))
也就是把up_proj(x)
变成了:F.sigmoid(gate_proj(x)) * up_proj(x)