引 言
在本文中,我们将介绍并解释基于 Transformer 的大语言模型的每个步骤。
当第一次接触 Transformer 架构时,我被可用于理解它的大量概念和教程所淹没。一些视频或文章假设了自然语言处理(NLP)概念的先验知识,而另一些则太长且难以理解。为了掌握Transformer 架构,我不得不阅读大量文章并观看几个视频,有时由于缺乏基础知识而无法完全理解该主题,我陷入困境,包括:
❝
为什么输入嵌入可以代表一个单词?
什么是矩阵乘法以及形状如何执行?
Softmax 是做什么的?
当模型被训练后,它存储在哪里?
…
由于涉及的所有复杂性,我发现完全掌握这种架构具有挑战性。如果我们循序渐进,从基础开始,逐渐转向更高级的主题,我相信你很快就会和我站在同一起跑线上。
如果你是新手LLM,请务必查看之前介绍的文章,其中解释了LLMs的基础知识和工作原理。
-
LLM大模型基础入门系列之:(一)什么是大语言模型?(https://zhuanlan.zhihu.com/p/704561478)
-
LLM大模型基础入门系列之:(二)大模型如何工作(https://zhuanlan.zhihu.com/p/706164564)
高级概述
Transformer 模型于 2017 年在论文 《Attention is all you need》中首次提出。Transformer 架构旨在训练语言翻译目的模型。然而,OpenAI 的团队发现 transformer 架构是角色预测的关键解决方案。一旦对整个互联网数据进行训练,该模型就有可能理解任何文本的上下文,并连贯地完成任何句子,就像人类一样。
该模型由两部分组成:编码器和解码器。通常,仅编码器体系结构擅长从文本中提取信息以执行分类和回归等任务,而仅解码器模型则**专门用于生成文本。**例如,专注于文本生成的 GPT 属于仅解码器模型的范畴。
GPT 模型仅使用 transformer 架构的解码器部分。
让我们在训练模型时了解架构的关键思想。一张图来说明类似 GPT 的仅解码器转换器架构的训练过程:
-
首先,我们需要一系列输入字符作为训练数据。这些输入被转换为向量嵌入格式。
-
在向量嵌入中添加位置编码,以捕获每个字符在序列中的位置。
-
模型通过一系列计算操作处理这些输入嵌入,最终为给定的输入文本生成可能的下一个字符的概率分布。
-
该模型根据训练数据集中的实际后续特征评估预测结果,并相应地调整概率或“权重”。
-
最后,该模型迭代地完善了这一过程,不断更新其参数以提高未来预测的精度。
让我们深入了解每个步骤的细节。
1、Tokenization 标记化
Tokenization 是 transformer 模型的第一步,该模型:
❝
将输入句子转换为数字表示格式。
Tokenization 是将文本划分为称为 Tokens 的较小单元的过程,这些单元可以是单词、子单词、短语或字符。因为将短语分解成更小的部分有助于模型识别文本的底层结构并更有效地处理它。
❝
Chapter 1: Building Rapport and Capturing
上面这句话可以切成:
❝
Chapter
, ,1
,:
, ,Building
, ,Rap
,port
, ,and
, ,Capturing
它被标记为 10 个数字:
❝
[26072, 220, 16, 25, 17283, 23097, 403, 220, 323, 220, 17013, 220, 1711]
数字 220 用于表示空格字符。有许多方法可以将字符标记为整数。对于示例数据集,我们将使用 tiktoken 库。
为了便于演示,将使用一个小型教科书数据集(来自 Hugging Face ),其中包含 460k 个字符用于训练。
❝
文件大小:450**Kb
词汇量:3,771(表示唯一单词/子单词)
训练数据包含 3,771 个不同字符的词汇量。用于标记教科书数据集的最大数量是 100069
,它被映射到一个字符Clar
。
一旦有了标记化映射,就可以为数据集中每个字符找到相应的整数索引。将利用这些分配的整数索引作为标记,而不是在与模型交互时使用整个单词。
2、Word Embeddings 词嵌入
首先,构建一个包含词汇表中所有字符的查找表。从本质上讲,该表由一个填充了随机初始化数字的矩阵组成。
给定最大标记数是 100069
,并考虑维度为 64(原始论文使用 512 维,表示为 d_model),生成的查找表变成 100,069 × 64 矩阵,这称为标记嵌入查找表。表示如下:
Token Embedding Look-Up Table: 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 0 0.625765 0.025510 0.954514 0.064349 -0.502401 -0.202555 -1.567081 -1.097956 0.235958 -0.239778 ... 0.420812 0.277596 0.778898 1.533269 1.609736 -0.403228 -0.274928 1.473840 0.068826 1.332708 1 -0.497006 0.465756 -0.257259 -1.067259 0.835319 -1.956048 -0.800265 -0.504499 -1.426664 0.905942 ... 0.008287 -0.252325 -0.657626 0.318449 -0.549586 -1.464924 -0.557690 -0.693927 -0.325247 1.243933 2 1.347121 1.690980 -0.124446 -1.682366 1.134614 -0.082384 0.289316 0.835773 0.306655 -0.747233 ... 0.543340 -0.843840 -0.687481 2.138219 0.511412 1.219090 0.097527 -0.978587 -0.432050 -1.493750 3 1.078523 -0.614952 -0.458853 0.567482 0.095883 -1.569957 0.373957 -0.142067 -1.242306 -0.961821 ... -0.882441 0.638720 1.119174 -1.907924 -0.527563 1.080655 -2.215207 0.203201 -1.115814 -1.258691 4 0.814849 -0.064297 1.423653 0.261726 -0.133177 0.211893 1.449790 3.055426 -1.783010 -0.832339 ... 0.665415 0.723436 -1.318454 0.785860 -1.150111 1.313207 -0.334949 0.149743 1.306531 -0.046524 ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... 100064 -0.898191 -1.906910 -0.906910 1.838532 2.121814 -1.654444 0.082778 0.064536 0.345121 0.262247 ... 0.438956 0.163314 0.491996 1.721039 -0.124316 1.228242 0.368963 1.058280 0.406413 -0.326223 100065 1.354992 -1.203096 -2.184551 -1.745679 -0.005853 -0.860506 1.010784 0.355051 -1.489120 -1.936192 ... 1.354665 -1.338872 -0.263905 0.284906 0.202743 -0.487176 -0.421959 0.490739 -1.056457 2.636806 100066 -0.436116 0.450023 -1.381522 0.625508 0.415576 0.628877 -0.595811 -1.074244 -1.512645 -2.027422 ... 0.436522 0.068974 1.305852 0.005790 -0.583766 -0.797004 0.144952 -0.279772 1.522029 -0.629672 100067 0.147102 0.578953 -0.668165 -0.011443 0.236621 0.348374 -0.706088 1.368070 -1.428709 -0.620189 ... 1.130942 -0.739860 -1.546209 -1.475937 -0.145684 -1.744829 0.637790 -1.064455 1.290440 -1.110520 100068 0.415268 -0.345575 0.441546 -0.579085 1.110969 -1.303691 0.143943 -0.714082 -1.426512 1.646982 ... -2.502535 1.409418 0.159812 -0.911323 0.856282 -0.404213 -0.012741 1.333426 0.372255 0.722526 [100,069 rows x 64 columns]
其中每行代表一个字符(按其标记编号索引),每列代表一个维度。
现在,可以将“维度”视为角色的特征或方面。在本例中,指定 64 个维度,这意味着将能够以 64 种不同的方式理解一个角色的文本含义,例如将其分类为名词、动词、形容词等。
假设,现在有一个上下文长度为16的示例训练输入,它是:
❝
“
. By mastering the art of identifying underlying motivations and desires, we equip ourselves with
" “. By mastering the art of identifying underlying motivations and desires, we equip ourselves with
”
现在,通过使用其整数索引来查找嵌入表,从而检索每个标记化字符(或单词)的嵌入向量。因此,我们得到了它们各自的输入嵌入:
[ 627, 1383, 88861, 279, 1989, 315, 25607, 16940, 65931, 323, 32097, 11, 584, 26458, 13520, 449]
在 transformer 架构中,多个输入序列同时并行处理,通常称为多批处理。将batch_size设置为 4。因此,可以一次处理四个随机选择的句子作为输入。
Input Sequence Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 0 627 1383 88861 279 1989 315 25607 16940 65931 323 32097 11 584 26458 13520 449 1 15749 311 9615 3619 872 6444 6 3966 11 10742 11 323 32097 13 3296 22815 2 13189 315 1701 5557 304 6763 374 88861 7528 10758 7526 13 4314 7526 2997 2613 3 323 6376 2867 26470 1603 16661 264 49148 627 18 13 81745 48023 75311 7246 66044 [4 rows x 16 columns]
每行代表一个句子;每列是该句子从第 0 位到第 15 位的字符。
结果,我们现在有了一个矩阵,表示 4 批 16 个字符的输入。该矩阵的形状为 (batch_size, context_length) = [4, 16]。
回顾一下,我们将输入嵌入查找表定义为大小为 100,069 × 64 的矩阵。下一步是获取输入序列矩阵并将其映射到这个嵌入矩阵上,以获得我们的输入嵌入。
在这里,我们将重点分解输入序列矩阵的每一行,从第一行开始。首先,将此开始行从其原始尺寸 (1, context_length) = [1, 16] 重塑为(context_length, 1)= [16, 1] 的新格式。随后,将这个重组后的行覆盖在我们之前建立的嵌入矩阵大小 (vocab_size, d_model) = [100069, 64] 上,从而将匹配的嵌入向量替换为给定上下文窗口中存在的每个字符。生成的输出是形状为 (context_length, d_model) = [16, 64] 的矩阵。
输入序列批处理的第一行:
Input Embedding: 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 0 1.051807 -0.704369 -0.913199 -1.151564 0.582201 -0.898582 0.984299 -0.075260 -0.004821 -0.743642 ... 1.151378 0.119595 0.601200 -0.940352 0.289960 0.579749 0.428623 0.263096 -0.773865 -0.734220 1 -0.293959 -1.278850 -0.050731 0.862562 0.200148 -1.732625 0.374076 -1.128507 0.281203 -1.073113 ... -0.062417 -0.440599 0.800283 0.783043 1.602350 -0.676059 -0.246531 1.005652 -1.018667 0.604092 2 -0.292196 0.109248 -0.131576 -0.700536 0.326451 -1.885801 -0.150834 0.348330 -0.777281 0.986769 ... 0.382480 1.315575 -0.144037 1.280103 1.112829 0.438884 -0.275823 -2.226698 0.108984 0.701881 3 0.427942 0.878749 -0.176951 0.548772 0.226408 -0.070323 -1.865235 1.473364 1.032885 0.696173 ... 1.270187 1.028823 -0.872329 -0.147387 -0.083287 0.142618 -0.375903 -0.101887 0.989520 -0.062560 4 -1.064934 -0.131570 0.514266 -0.759037 0.294044 0.957125 0.976445 -1.477583 -1.376966 -1.171344 ... 0.231112 1.278687 0.254688 0.516287 0.621753 0.219179 1.345463 -0.927867 0.510172 0.656851 5 2.514588 -1.001251 0.391298 -0.845712 0.046932 -0.036732 1.396451 0.934358 -0.876228 -0.024440 ... 0.089804 0.646096 -0.206935 0.187104 -1.288239 -1.068143 0.696718 -0.373597 -0.334495 -0.462218 6 0.498423 -0.349237 -1.061968 -0.093099 1.374657 -0.512061 -1.238927 -1.342982 -1.611635 2.071445 ... 0.025505 0.638072 0.104059 -0.600942 -0.367796 -0.472189 0.843934 0.706170 -1.676522 -0.266379 7 1.684027 -0.651413 -0.768050 0.599159 -0.381595 0.928799 2.188572 1.579998 -0.122685 -1.026440 ... -0.313672 1.276962 -1.142109 -0.145139 1.207923 -0.058557 -0.352806 1.506868 -2.296642 1.378678 8 -0.041210 -0.834533 -1.243622 -0.675754 -1.776586 0.038765 -2.713090 2.423366 -1.711815 0.621387 ... -1.063758 1.525688 -1.762023 0.161098 0.026806 0.462347 0.732975 0.479750 0.942445 -1.050575 9 0.708754 1.058510 0.297560 0.210548 0.460551 1.016141 2.554897 0.254032 0.935956 -0.250423 ... -0.552835 0.084124 0.437348 0.596228 0.512168 0.289721 -0.028321 -0.932675 -0.411235 1.035754 10 -0.584553 1.395676 0.727354 0.641352 0.693481 -2.113973 -0.786199 -0.327758 1.278788 -0.156118 ... 1.204587 -0.131655 -0.595295 -0.433438 -0.863684 3.272247 0.101591 0.619058 -0.982174 -1.174125 11 -0.753828 0.098016 -0.945322 0.708373 -1.493744 0.394732 0.075629 -0.049392 -1.005564 0.356353 ... 2.452891 -0.233571 0.398788 -1.597272 -1.919085 -0.405561 -0.266644 1.237022 1.079494 -2.292414 12 -0.611864 0.006810 1.989711 -0.446170 -0.670108 0.045619 -0.092834 1.226774 -1.407549 -0.096695 ... 1.181310 -0.407162 -0.086341 -0.530628 0.042921 1.369478 0.823999 -0.312957 0.591755 0.516314 13 -0.584553 1.395676 0.727354 0.641352 0.693481 -2.113973 -0.786199 -0.327758 1.278788 -0.156118 ... 1.204587 -0.131655 -0.595295 -0.433438 -0.863684 3.272247 0.101591 0.619058 -0.982174 -1.174125 14 -1.174090 0.096075 -0.749195 0.395859 -0.622460 -1.291126 0.094431 0.680156 -0.480742 0.709318 ... 0.786663 0.237733 1.513797 0.296696 0.069533 -0.236719 1.098030 -0.442940 -0.583177 1.151497 15 0.401740 -0.529587 3.016675 -1.134723 -0.256546 -0.219896 0.637936 2.000511 -0.418684 -0.242720 ... -0.442287 -1.519394 -1.007496 -0.517480 0.307449 -0.316039 -0.880636 -1.424680 -1.901644 1.968463 [16 rows x 64 columns]
矩阵显示映射后的四行之一
对其余的 3 行执行相同的操作,最终我们有 4 组 x [16 行 x 64 列]。
这会导致形状为**(batch_size, context_length, d_model)= [4, 16, 64]** 的输入嵌入矩阵。
从本质上讲,为每个单词提供唯一的嵌入允许模型适应语言的变化并管理具有多种含义或形式的单词。
**让我们继续前进,**理解输入嵌入矩阵作为模型的预期输入格式,即使我们还没有完全掌握起作用的基本数学原理。
3、位置编码
在我看来,位置编码是 transformer 架构中最具挑战性的概念。
总结一下位置编码解决了什么问题:
-
希望每个单词都带有一些关于它在句子中的位置信息。
-
希望模型将看起来彼此接近的单词视为“接近”,将距离较远的单词视为“遥远”。
-
希望位置编码表示模型可以学习的模式。
位置编码描述序列中实体的位置或位置,以便为每个位置分配唯一的表示形式。
位置编码是另一个数字向量,它被添加到每个标记化字符的输入嵌入中。位置编码是正弦波和余弦波,其频率根据标记化字符的位置而变化。
在原始论文中,引入的位置编码计算方法是:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中 pos
是位置, i
从 0 到 d_model / 2。
d_model
是在训练模型时定义的模型维度(在本例中是 64,在原始论文中使用 512)。
事实上,这个位置编码矩阵只创建一次,并重复用于每个输入序列。
让我们看一下位置编码矩阵:
Position Embedding Look-Up Table: 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 0 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 ... 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 0.000000 1.000000 1 0.841471 0.540302 0.681561 0.731761 0.533168 0.846009 0.409309 0.912396 0.310984 0.950415 ... 0.000422 1.000000 0.000316 1.000000 0.000237 1.000000 0.000178 1.000000 0.000133 1.000000 2 0.909297 -0.416147 0.997480 0.070948 0.902131 0.431463 0.746904 0.664932 0.591127 0.806578 ... 0.000843 1.000000 0.000632 1.000000 0.000474 1.000000 0.000356 1.000000 0.000267 1.000000 3 0.141120 -0.989992 0.778273 -0.627927 0.993253 -0.115966 0.953635 0.300967 0.812649 0.582754 ... 0.001265 0.999999 0.000949 1.000000 0.000711 1.000000 0.000533 1.000000 0.000400 1.000000 4 -0.756802 -0.653644 0.141539 -0.989933 0.778472 -0.627680 0.993281 -0.115730 0.953581 0.301137 ... 0.001687 0.999999 0.001265 0.999999 0.000949 1.000000 0.000711 1.000000 0.000533 1.000000 5 -0.958924 0.283662 -0.571127 -0.820862 0.323935 -0.946079 0.858896 -0.512150 0.999947 -0.010342 ... 0.002108 0.999998 0.001581 0.999999 0.001186 0.999999 0.000889 1.000000 0.000667 1.000000 6 -0.279415 0.960170 -0.977396 -0.211416 -0.230368 -0.973104 0.574026 -0.818837 0.947148 -0.320796 ... 0.002530 0.999997 0.001897 0.999998 0.001423 0.999999 0.001067 0.999999 0.000800 1.000000 7 0.656987 0.753902 -0.859313 0.511449 -0.713721 -0.700430 0.188581 -0.982058 0.800422 -0.599437 ... 0.002952 0.999996 0.002214 0.999998 0.001660 0.999999 0.001245 0.999999 0.000933 1.000000 8 0.989358 -0.145500 -0.280228 0.959933 -0.977262 -0.212036 -0.229904 -0.973213 0.574318 -0.818632 ... 0.003374 0.999994 0.002530 0.999997 0.001897 0.999998 0.001423 0.999999 0.001067 0.999999 9 0.412118 -0.911130 0.449194 0.893434 -0.939824 0.341660 -0.608108 -0.793854 0.291259 -0.956644 ... 0.003795 0.999993 0.002846 0.999996 0.002134 0.999998 0.001600 0.999999 0.001200 0.999999 10 -0.544021 -0.839072 0.937633 0.347628 -0.612937 0.790132 -0.879767 -0.475405 -0.020684 -0.999786 ... 0.004217 0.999991 0.003162 0.999995 0.002371 0.999997 0.001778 0.999998 0.001334 0.999999 11 -0.999990 0.004426 0.923052 -0.384674 -0.097276 0.995257 -0.997283 -0.073661 -0.330575 -0.943780 ... 0.004639 0.999989 0.003478 0.999994 0.002609 0.999997 0.001956 0.999998 0.001467 0.999999 12 -0.536573 0.843854 0.413275 -0.910606 0.448343 0.893862 -0.940067 0.340989 -0.607683 -0.794179 ... 0.005060 0.999987 0.003795 0.999993 0.002846 0.999996 0.002134 0.999998 0.001600 0.999999 13 0.420167 0.907447 -0.318216 -0.948018 0.855881 0.517173 -0.718144 0.695895 -0.824528 -0.565821 ... 0.005482 0.999985 0.004111 0.999992 0.003083 0.999995 0.002312 0.999997 0.001734 0.999998 14 0.990607 0.136737 -0.878990 -0.476839 0.999823 -0.018796 -0.370395 0.928874 -0.959605 -0.281349 ... 0.005904 0.999983 0.004427 0.999990 0.003320 0.999995 0.002490 0.999997 0.001867 0.999998 15 0.650288 -0.759688 -0.968206 0.250154 0.835838 -0.548975 0.042249 0.999107 -0.999519 0.031022 ... 0.006325 0.999980 0.004743 0.999989 0.003557 0.999994 0.002667 0.999996 0.002000 0.999998 [16 rows x 64 columns]
谈一谈位置编码技巧。
据了解,位置值是根据它们在序列中的相对位置建立的。此外,由于每个输入句子的上下文长度一致,它使我们能够在各种输入中回收相同的位置编码。因此,必须谨慎地创建序列号,以防止过大的幅度对输入嵌入产生负面影响,确保相邻位置表现出微小的差异,而远处的位置显示出它们之间的较大差异。
使用正弦和余弦向量的组合,该模型可以看到独立于词嵌入的位置编码向量,而不会混淆输入嵌入(语义)信息。很难想象这在神经元网络内部是如何工作的,但它是有效的。
我们可以可视化位置嵌入数字并查看模式。
每条垂直线都是从 0 到 64 的维度;每行代表一个字符。这些值介于 -1 和 1 之间,因为它们来自正弦和余弦函数。颜色越深表示值越接近 -1,颜色越亮表示值越接近 1。绿色表示介于两者之间的值。
让我们回到位置编码矩阵,正如你所看到的,这个位置编码表与输入嵌入表 [4, 16, 64] 中的每个批处理具有相同的形状,它们都是 (context_length, d_model)= [16, 64]。
由于两个具有相同形状的矩阵可以相加,因此可以将位置信息添加到每个输入嵌入行中,以获得最终输入嵌入矩阵。
batch 0: 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 0 1.051807 0.295631 -0.913199 -0.151564 0.582201 0.101418 0.984299 0.924740 -0.004821 0.256358 ... 1.151378 1.119595 0.601200 0.059648 0.289960 1.579749 0.428623 1.263096 -0.773865 0.265780 1 0.547512 -0.738548 0.630830 1.594323 0.733316 -0.886616 0.783385 -0.216111 0.592187 -0.122698 ... -0.061995 0.559401 0.800599 1.783043 1.602587 0.323941 -0.246353 2.005651 -1.018534 1.604092 2 0.617101 -0.306899 0.865904 -0.629588 1.228581 -1.454339 0.596070 1.013263 -0.186154 1.793348 ... 0.383324 2.315575 -0.143404 2.280102 1.113303 1.438884 -0.275467 -1.226698 0.109251 1.701881 3 0.569062 -0.111243 0.601322 -0.079154 1.219661 -0.186289 -0.911600 1.774332 1.845533 1.278927 ... 1.271452 2.028822 -0.871380 0.852612 -0.082575 1.142617 -0.375369 0.898113 0.989920 0.937440 4 -1.821736 -0.785214 0.655805 -1.748969 1.072516 0.329445 1.969725 -1.593312 -0.423386 -0.870206 ... 0.232799 2.278685 0.255953 1.516287 0.622701 1.219178 1.346175 0.072133 0.510705 1.656851 5 1.555663 -0.717588 -0.179829 -1.666574 0.370867 -0.982811 2.255347 0.422208 0.123719 -0.034782 ... 0.091912 1.646094 -0.205354 1.187103 -1.287054 -0.068144 0.697607 0.626403 -0.333828 0.537782 6 0.219007 0.610934 -2.039364 -0.304516 1.144289 -1.485164 -0.664902 -2.161820 -0.664487 1.750649 ... 0.028036 1.638068 0.105957 0.399056 -0.366373 0.527810 0.845001 1.706170 -1.675722 0.733621 7 2.341013 0.102489 -1.627363 1.110608 -1.095316 0.228369 2.377153 0.597940 0.677737 -1.625878 ... -0.310720 2.276958 -1.139895 0.854859 1.209583 0.941441 -0.351562 2.506867 -2.295708 2.378678 8 0.948148 -0.980033 -1.523850 0.284180 -2.753848 -0.173272 -2.942995 1.450153 -1.137498 -0.197246 ... -1.060385 2.525683 -1.759494 1.161095 0.028703 1.462346 0.734397 1.479749 0.943511 -0.050575 9 1.120872 0.147380 0.746753 1.103982 -0.479273 1.357801 1.946789 -0.539822 1.227215 -1.207067 ... -0.549040 1.084117 0.440194 1.596224 0.514303 1.289719 -0.026721 0.067324 -0.410035 2.035753 10 -1.128574 0.556604 1.664986 0.988980 0.080544 -1.323841 -1.665967 -0.803163 1.258105 -1.155904 ... 1.208804 0.868336 -0.592132 0.566557 -0.861313 4.272244 0.103369 1.619057 -0.980840 -0.174126 11 -1.753818 0.102441 -0.022270 0.323699 -1.591020 1.389990 -0.921654 -0.123053 -1.336139 -0.587427 ... 2.457530 0.766419 0.402266 -0.597278 -1.916476 0.594436 -0.264688 2.237020 1.080961 -1.292415 12 -1.148437 0.850664 2.402985 -1.356776 -0.221765 0.939481 -1.032902 1.567763 -2.015232 -0.890874 ... 1.186370 0.592825 -0.082546 0.469365 0.045767 2.369474 0.826133 0.687041 0.593355 1.516313 13 -0.164386 2.303123 0.409138 -0.306666 1.549362 -1.596800 -1.504343 0.368137 0.454260 -0.721938 ... 1.210069 0.868330 -0.591184 0.566554 -0.860601 4.272243 0.103903 1.619056 -0.980440 -0.174127 14 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496 15 1.052028 -1.289275 2.048469 -0.884570 0.579293 -0.768871 0.680185 2.999618 -1.418203 -0.211697 ... -0.435962 -0.519414 -1.002752 0.482508 0.311006 0.683955 -0.877969 -0.424683 -1.899643 2.968462 [16 rows x 64 columns] batch 1: 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 0 -0.264236 0.965681 1.909974 -0.338721 -0.554196 0.254583 -0.576111 1.766522 -0.652587 0.455450 ... -1.016426 0.458762 -0.513290 0.618411 0.877229 2.526591 0.614551 0.662366 -1.246907 1.128066 1 1.732205 -0.858178 0.324008 1.022650 -1.172865 0.513133 -0.121611 2.630085 0.072425 2.332296 ... 0.737660 1.988225 2.544661 1.995471 0.447863 3.174428 0.444989 0.860426 2.137797 1.537580 2 -1.348308 -1.080221 1.753394 0.156193 0.440652 1.015287 -0.790644 1.215537 2.037030 0.476560 ... 0.296941 1.100837 -0.153194 1.329375 -0.188958 1.229344 -1.301919 0.938138 -0.860689 -0.860137 3 0.601103 -0.156419 0.850114 -0.324190 -0.311584 -2.232454 -0.903112 0.242687 0.801908 2.502464 ... -0.397007 1.150545 -0.473907 0.318961 -1.970126 1.967961 -0.186831 0.131873 0.947445 -0.281573 4 -1.821736 -0.785214 0.655805 -1.748969 1.072516 0.329445 1.969725 -1.593312 -0.423386 -0.870206 ... 0.232799 2.278685 0.255953 1.516287 0.622701 1.219178 1.346175 0.072133 0.510705 1.656851 5 1.555663 -0.717588 -0.179829 -1.666574 0.370867 -0.982811 2.255347 0.422208 0.123719 -0.034782 ... 0.091912 1.646094 -0.205354 1.187103 -1.287054 -0.068144 0.697607 0.626403 -0.333828 0.537782 6 0.599841 0.943214 -1.397184 -0.607349 -0.333995 -1.222589 -0.731189 -0.997706 1.848611 0.254238 ... 0.340986 1.383113 1.674592 2.229903 -0.157415 0.362868 -0.493762 1.904136 0.027903 1.196017 7 0.072234 1.386670 -0.985962 -1.184486 0.958293 -0.295773 -1.529277 -0.727844 1.510503 1.268154 ... -0.356459 0.382331 0.138104 -0.360916 -0.638448 1.305404 -0.756442 0.299150 0.154600 -0.466154 8 -0.008645 -1.066763 -0.716555 2.148885 -0.709739 -0.137266 0.385401 0.699139 1.907906 -2.357567 ... 0.490190 -1.215412 1.216459 0.659227 -0.282908 -0.912266 0.595569 1.210701 0.737407 0.801672 9 -0.006332 -0.949928 0.192689 3.158421 -1.292153 -0.830248 0.966141 -2.056514 0.042364 1.485927 ... 0.480763 -0.318554 0.005837 3.031636 -0.448117 1.059403 0.598106 0.871427 0.327321 1.090921 10 -1.152681 -0.710162 -0.456591 -0.468090 -0.292566 0.747535 -0.149907 -0.395523 0.170872 -2.372754 ... -1.267461 0.043283 -0.114980 1.083042 -0.288776 1.442318 0.775591 0.728716 -0.576776 -0.727257 11 -0.955986 -0.277475 0.946888 -0.242687 1.257744 0.369994 0.460073 0.728078 -0.165204 -0.761762 ... -0.307983 2.078995 -1.067792 1.805637 0.608968 1.722982 -0.371174 -0.603182 0.285387 1.112932 12 -0.844347 0.883224 1.222388 -0.811387 -0.593557 0.157268 -0.650315 1.289236 -1.472027 -0.447092 ... -0.536433 2.465097 -0.822905 1.272786 0.703664 2.687270 -0.924388 0.596134 -0.367138 0.812242 13 0.776470 1.549248 -0.239693 0.133783 0.767255 1.996130 -0.436228 -0.327975 -0.650743 0.507769 ... -0.821793 1.387792 -1.052105 2.123603 1.421092 2.066746 -0.747766 0.627081 -1.749071 -0.679443 14 1.277579 0.653945 0.045632 -0.409790 0.829708 0.249433 -0.682051 0.601958 -1.932014 -2.077397 ... 0.160611 1.037856 0.656832 0.992817 -0.684056 1.031199 -0.180866 4.579140 -1.123555 0.181580 15 0.356328 -2.038538 -1.018938 1.112716 1.035987 -2.281600 0.416325 -0.129400 -0.718316 -1.042091 ... -0.056092 0.559381 0.805026 1.783032 1.605907 0.323934 -0.243863 2.005648 -1.016667 1.604090 [16 rows x 64 columns] batch 2: 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 0 0.645854 1.291073 -1.588931 1.814376 -0.185270 0.846816 -1.686862 0.982995 -0.973108 1.297203 ... 0.852600 1.533231 0.692729 2.437029 -0.178137 0.493413 0.597484 1.909155 1.257821 2.644325 1 1.732205 -0.858178 0.324008 1.022650 -1.172865 0.513133 -0.121611 2.630085 0.072425 2.332296 ... 0.737660 1.988225 2.544661 1.995471 0.447863 3.174428 0.444989 0.860426 2.137797 1.537580 2 3.298391 -0.363908 0.376535 -0.276692 1.262433 -0.595659 1.694541 0.542514 -0.464756 0.368460 ... -0.169474 1.420809 0.304488 1.689731 -1.128037 -0.024476 -1.356808 2.160992 -2.110703 -0.472404 3 0.626955 -2.988524 0.915578 1.123503 0.635983 0.078006 0.466728 -0.930765 2.189286 1.505499 ... 2.496649 1.691578 0.642664 2.089205 1.926187 1.185045 -0.969952 0.666007 -0.030641 0.667574 4 0.396447 -2.116415 0.384262 -1.632779 0.859029 -0.726599 2.121946 -1.314046 0.744388 -0.227106 ... -1.937352 2.378620 0.029220 1.215336 -0.405487 -0.834419 -1.219825 0.000676 -0.821293 0.340797 5 -2.133014 0.379737 -1.320323 -0.425003 -0.298524 -2.237205 0.953327 0.168006 0.519205 0.698976 ... 0.788771 1.237731 1.515378 1.296695 0.070718 0.763281 1.098920 0.557059 -0.582510 2.151497 6 -0.390918 0.634039 -1.350461 0.032129 0.106428 0.370410 1.292387 0.986316 -0.095396 0.555067 ... -1.792372 -0.357599 0.912276 0.088746 0.866950 0.927208 -0.381643 2.532119 0.464615 -1.044299 7 -0.407947 0.622332 -0.345048 -0.247587 -0.419677 0.256695 1.165026 -2.459640 -0.576545 -1.770781 ... 0.234064 2.278682 0.256901 1.516285 0.623413 1.219177 1.346708 0.072133 0.511105 1.656851 8 3.503946 -1.146751 0.111070 0.114221 -0.930330 -0.248769 1.166547 -0.038856 -0.301910 -0.843072 ... 0.093177 1.646091 -0.204405 1.187101 -1.286342 -0.068145 0.698141 0.626402 -0.333428 0.537781 9 -1.946920 -0.443788 0.560103 3.584257 -0.134643 -1.538940 -1.059084 -0.128679 2.503847 -2.244587 ... -0.643552 1.608934 -0.488734 -0.291253 1.633294 -0.018763 0.696360 -0.657761 0.692395 1.741288 10 0.376520 0.583786 -0.705047 0.855548 0.471473 0.687240 -0.605646 0.463047 1.619052 -1.894214 ... -0.688652 1.974150 -1.399412 2.567682 -0.050040 1.782055 -0.297912 2.366196 -1.888527 0.635260 11 -0.109256 -1.394054 0.565499 -0.093785 -1.803309 0.662382 -1.528203 1.644028 -0.569133 0.438101 ... 0.741877 1.988214 2.547823 1.995465 0.450234 3.174424 0.446768 0.860424 2.139130 1.537579 12 -1.553993 -0.983421 0.392842 -1.473186 1.530387 1.894017 -0.732786 -1.601045 -0.740344 0.245303 ... -0.328828 3.013883 1.178296 1.263333 0.284824 0.791874 2.402131 -0.231270 -1.025411 0.178748 13 -0.757965 1.771306 0.805440 -0.509121 1.212250 0.388750 -0.606959 2.352489 -2.445346 -0.103223 ... 0.425556 1.783019 0.698336 1.871530 2.314023 0.424368 -1.002745 0.983784 -0.090133 0.905337 14 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496 15 -0.151101 -0.257150 -0.478131 -1.170082 1.318685 -0.188166 0.146375 2.895475 -0.918949 -0.305261 ... 1.623350 1.656103 -0.600456 1.039260 -1.944202 0.894911 1.409396 1.722673 -0.172070 2.265543 [16 rows x 64 columns] batch 3: 0 1 2 3 4 5 6 7 8 9 ... 54 55 56 57 58 59 60 61 62 63 0 0.377847 -0.380613 1.958640 0.224087 -0.420293 0.915635 -1.077748 1.255988 -0.223147 0.977568 ... -1.290532 1.460963 1.365088 -2.037483 -2.213841 1.039091 -2.129649 0.108403 -0.356996 2.239356 1 0.527961 0.342787 0.096746 0.885016 0.706699 2.873656 0.139732 0.497379 -0.009022 -0.147825 ... -0.409913 0.785146 -0.138166 2.041000 0.277500 1.578947 -1.535113 0.912230 -0.312735 0.540365 2 1.054965 -0.134411 2.155045 -0.188724 0.651576 -0.265663 -0.777263 0.571080 1.508661 1.021718 ... 0.762458 2.297400 -0.624743 -0.979212 2.024008 1.295633 0.208825 0.953138 -2.962624 1.586901 3 -1.032970 -0.893918 0.029077 -0.232068 0.370793 -1.407092 1.048066 0.981123 0.331907 1.292072 ... 0.787928 1.237732 1.514746 1.296695 0.070244 0.763281 1.098564 0.557060 -0.582777 2.151497 4 -0.980037 -1.014605 1.875135 -2.459635 0.486067 -0.941092 1.205490 1.248531 1.801383 0.576983 ... 0.192097 1.784109 -0.201023 0.405095 0.982041 1.927637 0.008535 1.063376 -1.439787 2.967185 5 -0.369996 -1.151058 -0.126222 0.768431 0.107524 -0.481010 2.056029 -0.872815 1.522675 -0.440916 ... 0.246007 -1.032684 0.572565 0.944744 0.790383 -0.034063 -1.704374 -0.053319 1.739537 2.381506 6 -0.555136 -0.284736 -0.162689 -1.542923 -1.619371 -2.014224 0.957231 -0.338164 1.353500 -2.048436 ... 0.180549 -0.598603 0.427175 1.845072 0.924364 -0.013093 -0.054108 -0.082885 -0.719218 0.960552 7 0.548834 1.130444 1.207497 0.565839 -1.814344 -0.111523 0.480270 -1.741823 1.451116 -0.977640 ... 1.692325 -0.708754 -0.747591 1.373189 -0.224415 -0.074035 -0.323435 2.001849 -1.102584 1.644658 8 0.117209 -0.905490 0.272336 0.994848 0.648951 0.354459 -0.731171 -1.641071 -0.966286 -0.837498 ... 0.294006 1.008774 1.376944 2.969555 0.997452 2.076708 0.631358 1.080600 0.075384 1.819302 9 0.557786 -0.629395 1.606758 0.633762 -1.190379 -0.355466 -2.132275 -0.887707 1.208793 -0.741505 ... 0.765410 2.297393 -0.622529 -0.979216 2.025668 1.295631 0.210070 0.953136 -2.961691 1.586900 10 1.107697 -2.050459 1.399869 1.271179 -1.391529 1.103020 -0.910370 -0.398901 -0.803458 -2.081302 ... 1.462017 -0.115730 0.171052 0.594118 0.514388 1.593223 0.064085 -0.029184 -0.044621 1.206415 11 -1.771933 0.469475 0.961730 0.002798 1.386089 0.250342 -0.062900 -0.569053 -2.149857 -0.519952 ... -0.725692 -0.727693 -0.178683 1.675822 -0.401712 1.109331 0.980627 -0.357667 -0.484853 0.208340 12 -1.518213 1.899549 -0.320427 -0.929415 -0.701020 0.727833 -2.764498 0.612756 0.041370 -1.599998 ... -0.136314 1.068995 0.635501 0.765369 0.270007 0.319588 -0.652992 1.322658 1.724227 2.343042 13 0.094923 0.575470 -0.852224 -2.098593 0.998579 0.347285 -0.467688 0.773722 -1.664829 -0.412623 ... -1.274262 0.454381 -1.142107 1.853844 -1.912537 0.544311 0.667555 -1.187468 1.291108 2.275956 14 -0.183482 0.232812 -1.628186 -0.080981 0.377364 -1.309922 -0.275964 1.609030 -1.440347 0.427969 ... 0.792566 1.237715 1.518224 1.296686 0.072853 0.763276 1.100520 0.557057 -0.581310 2.151496 15 2.053710 -2.769740 -0.148796 0.983717 -0.038190 -0.655360 1.826909 -0.332533 -1.036128 -1.001430 ... 0.674310 0.695848 -0.181635 1.051397 -0.884897 1.590696 -1.375117 0.596254 -0.651398 0.797715 [16 rows x 64 columns]
最终输入嵌入将馈送到 Transformer 解码器**模块进行训练。
这个最终结果矩阵称为位置输入嵌入,其形状为 (batch_size,context_length,d_model)= [4,16,64]。
到目前为止,我们已经介绍了模型的输入编码和位置编码部分。接下来,让我们转到 Transformer Block。
Transformer 模块是由三层组成的堆栈:一个 masked 多头注意力机制、两个归一化层和一个前馈网络。
Masked 多头注意力是一组自注意,每个自注意都称为一个头。因此,让我们先来看看自注意力机制。
4、多头注意力概述
Transformers 力量来自一种叫做自注意的东西。通过自注意,模型密切关注输入中最关键的部分。每个部分都称为一个头部。
**这是头的工作原理:**头通过三个独特的层(称为查询 (Q)、键 (K) 和值 (V)) 处理输入来工作。它首先比较 Q 和 K,调整结果,然后使用这些比较创建一组分数,显示重要内容。然后使用这些分数来权衡 V 中的信息,从而更加关注重要部分。头的学习来自于随着时间的推移调整这些 Q、K 和 V 层中的设置。
多头注意力只是由几个单独的头堆叠在一起组成。所有头都接收到完全相同的输入,尽管它们在计算过程中使用自己特定的权重集。处理输入后,所有头的输出被连接起来,然后通过线性层。
下图提供了头部内过程的可视化表示,以及多头注意力模块中的详细信息。
为了进行证明计算,让我们从原始论文“Attention is all you need”中引入公式:
从公式中,首先需要三个矩阵:Q(查询)、K(键)和 V(值)。要计算注意力分数,需要执行以下步骤:
-
将 Q 乘以 K 转置(表示为 K^T)
-
除以 K 维数的平方根
-
应用 softmax 函数
-
乘以 V
准备Q,K,V
计算注意力第一步是获取 Q、K 和 V 矩阵,分别表示查询、键和值。这三个值将用于注意力层来计算注意力概率(权重)。这些是通过将上一步中的位置输入嵌入矩阵(表示为 X)应用于标记为 Wq、Wk 和 Wv 的三个不同的线性层来确定的(所有值都是随机分配且可学习的)。然后将每个线性层的输出拆分为多个头,表示为 num_heads,这里选择 4 个头。
**Wq、Wk、Wv 是三个矩阵,维度为(d_model, d_model)= [64,64]。**所有值都是随机分配的。这在神经网络中称为线性层或可训练参数。可训练参数是模型在训练期间将学习和自我更新的值。
为了获得 Q,K,V 值,在输入嵌入矩阵 X 和三个矩阵 Wq、Wk、Wv 中的每一个之间进行矩阵乘法(它们的初始值是随机分配的)。
-
Q = X*Wq
-
K = X*Wk
-
V = X*Wv
上述函数的计算(矩阵乘法)逻辑:
X 的形状为(batch_size, context_length, d_model)= [4, 16, 64],将其分解为 4 个形状为 [16, 64] 的子矩阵。而 Wq、Wk、Wv 的形状为 (d_model, d_model)= [64, 64]。可以对 4 个 X 的子矩阵中的每一个进行矩阵乘法,以 Wq、Wk、Wv 为单位。
如果回想一下线性代数,则只有当第一个矩阵中的列数等于第二个矩阵中的行数时,才有可能对两个矩阵进行乘法。在本例中,X 中的列数是 64,Wq、Wk、Wv 中的行数也是 64。因此,乘法是可能的。
**矩阵乘法得到 4 个子矩阵的形状 [16, 64],**可以组合表示为(batch_size,context_length, d_model)= [4, 16, 64]。
现在,Q、K、V 矩阵的形状为(batch_size, context_length, d_model)= [4, 16, 64]。接下来,需要将它们拆分为多个头。这就是为什么 transformer 架构将其命名为多头注意力的原因。
切分头只是意味着在d_model的 64 个维度中,将它们切割成多个头部,每个头部包含一定数量的维度。每个头部都将能够学习输入的某些模式或语义。
**假设将 num_heads 也设置为 4。**这意味着将 Q、K、V 形状为 [4, 16, 64] 的矩阵拆分为多个子矩阵。
实际的拆分是通过将 64 的最后一个维度重塑为 16 的 4 个子维度来完成的。
每个 Q、K、V 矩阵从形状 [4, 16, 64] 转换为 [4, 16, 4, 16]。最后两个维度是头部。换句话说,它从以下转变而来:
❝
[batch_size, context_length, d_model]
to:
[batch_size, context_length, num_heads, head_size]
要理解具有相同形状的 Q、K 和 V 矩阵 [4, 16, 4, 16],请考虑以下观点:
在管道中,有四个批次。每批由 16 个 tokens(单词)组成。对于每个token,有 4 个头,每个头编码 16 个维度的语义信息。
计算Q,K注意力
现在我们已经有了 Q、K 和 V 这三个矩阵,让我们开始逐步计算单头注意力。从 transformer 图中,Q 和 K 矩阵首先相乘。
现在,如果丢弃 Q 和 K 矩阵中的 batch_size,只保留最后三个维度,那么现在 Q = K = V = [context_length, num_heads, head_size] = [16, 4, 16]。
我们需要在前两个维度上再做一个转置,**使它们的形状为 Q = K = V = [num_heads, context_length, head_size] = [4 ,16, 16]。**这是因为需要在最后两个维度上进行矩阵乘法运算。
❝
Q * K^T = [4, 16, 16] * [4, 16, 16] = [4, 16, 16]
为什么要这样做?此处的转置是为了促进不同上下文之间的矩阵乘法。用图表解释更直接。最后两个维度表示为 [16, 16],可以可视化如下:
**这个矩阵,其中每行和每列在例句的上下文中代表一个标记(单词)。**矩阵乘法是衡量上下文中每个单词与所有其他单词之间的相似性。该值越高,它们越相似。
让我提出一个注意力得分的头:
[ 0.2712, 0.5608, -0.4975, ..., -0.4172, -0.2944, 0.1899], [-0.0456, 0.3352, -0.2611, ..., 0.0419, 1.0149, 0.2020], [-0.0627, 0.1498, -0.3736, ..., -0.3537, 0.6299, 0.3374], ..., ..., ..., ..., ..., ..., ..., ..., ..., ..., ..., ..., ..., ..., [-0.4166, -0.3364, -0.0458, ..., -0.2498, -0.1401, -0.0726], [ 0.4109, 1.3533, -0.9120, ..., 0.7061, -0.0945, 0.2296], [-0.0602, 0.2428, -0.3014, ..., -0.0209, -0.6606, -0.3170] [16 rows x 16 columns]
这个 16 x 16 矩阵中的数字代表例句 “ . By mastering the art of identifying underlying motivations and desires, we equip ourselves with
” 的注意力分数。
更容易看作一个 plot:
img
横轴**代表 Q 的头之一,纵轴表示 K 的头之一,彩色方块表示上下文中每个令牌和彼此令牌之间的相似性分数。颜色越深,相似度越高。
当然,上面显示的相似之处现在没有多大意义,因为这些只是来自随机分配的值。但是经过训练,相似性分数将是有意义的。
好了,现在把批次维度(batch_size)带回 Q*K 注意力分数。最终结果的形状为 [batch_size, num_heads, context_length, head_size],即 [4, 4, 16, 16]。
Scale
比例部分很简单,只需要将 Q*K^T 注意力分数除以 K 维度的平方根即可。在这里,K 维数等于 Q 的维数,d_model 除以 num_heads:64 / 4 = 16。
然后取 16 的平方根,即 4。并将 Q*K^T 注意力得分除以 4。
这样做的原因是为了防止 Q*K^T 注意力分数过大,这可能会导致 softmax 函数饱和,进而导致梯度消失。
Mask
在仅解码器 transformer 模型中,masked 自注意力本质上充当序列填充。解码器只能查看以前的字符,而不能查看未来的字符。因此**,未来的字符被屏蔽并用于计算注意力权重。**
如果再次可视化,这很容易理解:
空格表示 0 分,被屏蔽了
多头注意力层中 masked 要点是防止解码器“看到未来”。在本例中,解码器只允许看到当前单词和它之前的所有单词。
Softmax
img
softmax 步骤将数字更改为一种特殊的列表,其中整个列表加起来为 1。它增加了高数字并减少了低数字,从而创造了明确的选择。
简而言之,softmax 函数用于将线性层的输出转换为概率分布。
在现代深度学习框架(如 PyTorch****)中,softmax 函数是一个内置函数,使用起来非常简单:
torch.softmax(attention_score, dim=-1)
这行代码会将 softmax 应用于在上一步中计算的所有注意力分数,并产生介于 0 和 1 之间的概率分布。
让我们也提出应用 softmax 后同一头的注意力分数:
[1.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.4059, 0.5941, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.3368, 0.4165, 0.2468, ..., 0.0000, 0.0000, 0.0000], ..., [0.0463, 0.0501, 0.0670, ..., 0.0547, 0.0000, 0.0000], [0.0769, 0.1974, 0.0205, ..., 0.1034, 0.0464, 0.0000], [0.0684, 0.0926, 0.0537, ..., 0.0711, 0.0375, 0.0529]
现在,所有概率分数均为正数,加起来为 1。
计算 V 注意力
img
请记住,V 矩阵还将其拆分为多个头,形状为 (batch_size, num_heads, context_length, head_size) = [4, 4, 16, 16]。
而上一个 softmax 步骤的输出为 (batch_size, num_heads, context_length, head_size) = [4, 4, 16, 16]。
在这里,对两个矩阵的最后两个维度执行另一个矩阵乘法。
softmax_output * V = [4, 4, 16, 16] * [4, 4, 16, 16] = [4, 4, 16, 16]
结果的形状为 [batch_size, num_heads, context_length, head_size] = [4, 4, 16, 16]。我们称此结果为 A。
img
连接和输出
**多头注意力的最后一步是将所有头连接在一起,并将它们穿过线性层。**串联的理想是将来自所有头部的信息组合在一起。因此,需要将 A 矩阵从 [batch_size, num_heads, context_length, head_size] = [4, 4, 16, 16] 重塑为 [batch_size, context_length, num_heads, head_size] = [4, 16, 4, 16]。原因是需要将最后两个维度放在一起 num_heads
head_size
,因此可以很容易地将它们(通过矩阵乘法)组合大小 d_model = 64
。
这可以通过 PyTorch 的内置函数轻松完成:
A = A.transpose(1, 2) # [4, 16, 4, 16] [batch_size, context_length, num_heads, head_size]
接下来,需要将最后两个维度 [num_heads, head_size] = [4, 16] 组合到 [d_model] = [64]。
A = A.reshape(batch_size, -1, d_model) # [4, 16, 64] [batch_size, context_length, d_model]
**正如你所看到的,**经过一系列的计算,结果矩阵 A 现在回到了与输入嵌入矩阵 X 相同的形状,即 [batch_size, context_length, d_model] = [4, 16, 64]。由于此输出结果将作为输入传递到下一层,因此必须保持输入和输出相同的形状。
这个 Wo 被随机分配了形状 [d_model, d_model],并将在训练期间更新。
Output = A* Wo = [4, 16, 64] * [64, 64] = [4, 16, 64]
线性层的输出是单头注意力的输出,表示为输出。
恭喜!现在已经完成了Masked多头注意力部分!让我们开始 transformer block 的其余部分。
img
5、残差连接和层归一化
残差连接(有时称为跳过连接)是允许原始输入 X 绕过一个或多个层的连接。 这只是原始输入 X 和多头注意力层输出的总和。由于它们的形状相同,因此将它们相加很简单。
output = output + X
残差连接后,该过程进入层 归一化 。LayerNorm 是一种对网络中每一层的输出进行规范化的技术。 这是通过减去平均值并除以图层输出的标准差来完成的。此技术用于防止层的输出变得太大或太小,这可能导致网络变得不稳定。
这也是 PyTorch 中使用 nn.LayerNorm
函数的单行代码。
残差连接和层****归一化在“Attention is All You Need”的原始论文中表示 Add & Norm
。
6、前馈网络
一旦有了归一化的注意力权重(概率分数),它将通过一个位置前馈网络进行处理。
前馈网络 (FFN) 由两个线性层组成,它们之间具有 ReLU 激活函数。让我们看看 python 代码是如何实现的:
# Define Feed Forward Network output = nn.Linear(d_model, d_model * 4)(output) output = nn.ReLU()(output) output = nn.Linear(d_model * 4, d_model)(output)
将 ChatGPT 解释上述代码:
❝
输出 = nn.Linear(d_model, d_model * 4)(输出):这将对传入数据应用线性变换,即 y = xA^T + b。输入和输出大小分别为 d_model 和 d_model * 4。此转换增加了输入数据的维度。
输出 = nn.ReLU()(输出):这在元素上应用**整流线性单元 (ReLU) 函数。它被用作激活函数,将非线性引入模型,使其能够学习更复杂的模式。
输出 = nn.Linear(d_model * 4, d_model)(输出):这将应用另一个线性变换,将维数降低到d_model。这种“先扩张后收缩”是神经网络中的常见模式。
作为一个刚接触 机器学习或者LLM的人,可能会被这些解释所迷惑。当第一次遇到这些术语时,我得到了完全相同的感觉。
但不用担心,可以这样理解: 这个前馈网络只是一个标准的神经网络模块,它的输入和输出都是注意力分数。其目的是将注意力分数的维度从 64 扩展到 256,这使得信息更加精细,并使模型能够学习更复杂的知识结构。然后,它将尺寸压缩回 64,使其适用于后续计算。
7、重复第 4 步至第 6 步
**Cool!我们已经完成了第一个 transformer 模块。**现在,我们需要对我们想要的其余 transformer 块重复相同的过程。
引用 HuggingChat 的 AI 回答:
❝
GPT-2 在其最大配置 (GPT-2-XL) 中使用 48 个 transformer 块,而较小的配置具有较少的 transformer 块(GPT-2-Large 为 36 个,GPT-2-Medium 为 24 个,GPT-2-Small 为 12 个)。每个 transformer 模块都包含一个多头自注意力机制,然后是按位置的前馈网络。这些 transformer 模块可帮助模型捕获长程依赖关系并生成连贯的文本。
通过具有多个模块,输出被训练并作为输入 X 传递到下一个模块,因此在迭代后,模型可以学习输入序列中单词之间更复杂的模式和关系。
8、输出概率
在推理过程中, 希望从模型中获取下一个预测的 token,但到目前为止,我们得到的实际上是词汇表中所有 token 的概率分布。在上面的例子中我们的词汇量是 3,771。 因此,为了选择一个最高概率 token,将形成一个矩阵,其模型维度的大小 d_model = 64 乘以 vocab_size = 3,771。这一步在训练上与在推理上没有区别。
将这个线性层之后的输出称为 logits。 logits 是形状为 [batch_size, context_length, vocab_size] = [4, 16, 3771] 的矩阵。
然后使用最终的softmax函数将线性层的logits转换为概率分布。
logits = torch.softmax(logits, dim=-1)
注意:在训练过程中,不需要在这里应用softmax函数,而是使用nn.CrossEntropy 函数,因为它内置了 softmax 行为。
如何查看形状 [4, 16, 3771] 的结果对数? 实际上,经过所有计算,这是一个非常简单的想法:
我们有 4 个批处理管道,每个管道包含该输入序列中的所有 16 个单词,每个单词映射到词汇表中其他每个单词的概率。
如果模型在训练中,更新这些概率参数,如果模型在推理中,只需选择概率最高的一个。那么一切都有意义了。
结论
通常,在一开始完全掌握 Transformer 架构的复杂性可能具有挑战性。就我个人而言,我需要大约一个月的时间才能彻底理解系统中的每个组件。因此,建议查看我们参考页面上列出的其他资源,并从其他杰出的开拓者那里获得灵感。
一旦对理解 Transformer 架构充满信心,建议您继续通过一系列指导步骤实现代码。激动人心的进展等待着您!
随着大模型的持续爆火,各行各业都在开发搭建属于自己企业的私有化大模型,那么势必会需要大量大模型人才,同时也会带来大批量的岗位?“雷军曾说过:站在风口,猪都能飞起来”可以说现在大模型就是当下风口,是一个可以改变自身的机会,就看我们能不能抓住了。
那么,我们该如何学习大模型?
作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。
一、大模型全套的学习路线
学习大型人工智能模型,如GPT-3、BERT或任何其他先进的神经网络模型,需要系统的方法和持续的努力。既然要系统的学习大模型,那么学习路线是必不可少的,下面的这份路线能帮助你快速梳理知识,形成自己的体系。
L1级别:AI大模型时代的华丽登场
L2级别:AI大模型API应用开发工程
L3级别:大模型应用架构进阶实践
L4级别:大模型微调与私有化部署
一般掌握到第四个级别,市场上大多数岗位都是可以胜任,但要还不是天花板,天花板级别要求更加严格,对于算法和实战是非常苛刻的。建议普通人掌握到L4级别即可。
以上的AI大模型学习路线,不知道为什么发出来就有点糊,高清版可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费
】
二、640套AI大模型报告合集
这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。
三、大模型经典PDF籍
随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。
四、AI大模型商业化落地方案
以上的AI大模型学习资料,都已上传至CSDN,需要的小伙伴可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费
】
作为普通人,入局大模型时代需要持续学习和实践,不断提高自己的技能和认知水平,同时也需要有责任感和伦理意识,为人工智能的健康发展贡献力量。
标签:...,Transformer,入门,16,模型,矩阵,64,LLM,model From: https://blog.csdn.net/2401_85779703/article/details/140505669