多头自注意力机制计算示例
多头自注意力机制计算示例
1. 输入序列和权重矩阵
假设输入序列 X 如下:
X |
---|
[1, 0, 1, 0] |
[0, 1, 0, 1] |
[1, 1, 1, 1] |
我们有两个头,分别对应的权重矩阵如下:
头 1
WQ(1) | WK(1) | WV(1) |
---|---|---|
[1, 0] | [1, 0] | [1, 0] |
[0, 1] | [0, 1] | [0, 1] |
[1, 0] | [1, 0] | [1, 0] |
[0, 1] | [0, 1] | [0, 1] |
头 2
WQ(2) | WK(2) | WV(2) |
---|---|---|
[0, 1] | [0, 1] | [0, 1] |
[1, 0] | [1, 0] | [1, 0] |
[0, 1] | [0, 1] | [0, 1] |
[1, 0] | [1, 0] | [1, 0] |
2. 计算每个头的 Q、K、V
头 1
计算 Q1:
Q1 |
---|
[2, 0] |
[0, 2] |
[2, 2] |
计算 K1:
K1 |
---|
[2, 0] |
[0, 2] |
[2, 2] |
计算 V1:
V1 |
---|
[2, 0] |
[0, 2] |
[2, 2] |
头 2
计算 Q2:
Q2 |
---|
[0, 1] |
[1, 0] |
[2, 2] |
计算 K2:
K2 |
---|
[0, 1] |
[1, 0] |
[2, 2] |
计算 V2:
V2 |
---|
[0, 1] |
[1, 0] |
[2, 2] |
3. 计算每个头的自注意力
头 1
计算点积 Q1 K1T:
Q1 K1T |
---|
[4, 0, 4] |
[0, 4, 4] |
[4, 4, 8] |
缩放点积:
缩放点积 |
---|
[2.83, 0, 2.83] |
[0, 2.83, 2.83] |
[2.83, 2.83, 5.66] |
应用 softmax:
softmax |
---|
[0.5, 0, 0.5] |
[0, 0.5, 0.5] |
[0.25, 0.25, 0.5] |
计算注意力输出:
注意力输出 |
---|
[2, 1] |
[1, 2] |
[1.5, 2] |
头 2
计算点积 Q2 K2T:
Q2 K2T |
---|
[1, 0, 2] |
[0, 1, 2] |
[2, 2, 8] |
缩放点积:
缩放点积 |
---|
[0.71, 0, 1.41] |
[0, 0.71, 1.41] |
[1.41, 1.41, 5.66] |
应用 softmax:
softmax |
---|
[0.41, 0.15, 0.44] |
[0.15, 0.41, 0.44] |
[0.25, 0.25, 0.5] |
计算注意力输出:
注意力输出 |
---|
[0.88, 1.29] |
[1.29, 0.88] |
[1.50, 1.50] |
4. 合并头的输出
将所有头的输出连接起来:
Concat |
---|
[2, 1, 0.88, 1.29] |
[1, 2, 1.29, 0.88] |
[1.5, 2, 1.5, 1.5] |
5. 最终线性变换
假设线性变换矩阵 WO 为:
WO |
---|
[0.5, 0.5, 0.5, 0.5] |
[0.5, 0.5, 0.5, 0.5] |
计算线性变换输出:
Output |
---|
2.585, 2.585 |
2.585, 2.585 |
3.25, 3.25 |