首先“心算”三位整数乘法不管对人类还是对模型来说都不简单的。如果使用CoT的方式就类似于“笔算”,如果使用编程的方式就类似于人拿着计算器算。我将问题更精确一点地表述为“模型如何在心算多位整数乘法上接近或超过人的水平?”
这个问题困扰了我很久,简单乘法是推理能力的一种体现,如果可以解决,那大模型的整体性能应该还能大幅提升。于是我做了很多实验来寻找GPT连简单乘法都难以解决的技术层面原因。直到我借助CRNN里的CTCLoss来解码不定长序列,而不是通过TransformerDecoder以自回归的方式生成,这个问题才得到了很好的解决。下面我从技术层面上分析下可能的原因以及如何解决。
首先让我们看看GPT是如何训练的。 假设有足够多的样本,形式如下:
56*123=6888
122*222=27084
777*512=397824
612*324=198288
243*753=182979
896*267=239232
368*12=4416
......
GPT的原理是每次只看前n个字符,预测后续的1个字符 以56*123=6888为例,每一次模型迭代,等同于10次小的迭代的相加:
1、 输入5,预测6
2、 输入56,预测*
3、 输入56*,预测1
4、 输入56*1,预测2
5、 输入56*12,预测3
6、 输入56*123,预测=
7、 输入56*123=,预测6
8、 输入56*123=6,预测8
9、 输入56*123=68,预测8
10、输入56*123=688,预测8
细心的读者应该很快能发现,迭代1~迭代5是几乎无效的,因为对于等号前面的数字只能预测出是个0~9的字符,而没法预测出具体的值。实验表明在模型迭代过程中,屏蔽掉迭代1~迭代5,可以加快模型的收敛速度。不过虽然加快了收敛速度,但模型仍很难达到较高精度,特别是预测结果中靠近中间的数字,例如896*267,结果“239232”中间的“9”、“2”两个数字更难预测。显然还有更深层的原因我们没有发现。
那是不是因为GPT每一次小的迭代“只看前n个字符,预测后续的1个字符”导致的?也就是说每一次小的迭代都只告诉模型答案中的一个字符,而不告诉模型完整的答案。可是如果在每一次小的迭代中不告诉模型完整的答案,那么即使神经网络本身的推理能力再强,每一次小的跌代也很难学习到足够多的信息。打个比方,一个数学老师这么教他的学生:
同学们,今天我教大家2位数乘法,
11*11的结果中第一位是1
21*11的结果中第一位是2
21*21的结果中第一位是4
21*32的结果中第一位是6
32*68的结果中第一位是2
......
请同学们好好学习,学会了第一位我再教你们第二位,学会了第二位我再教你们第三位...
试想一下如果老师真这么教,那学生更多的是依靠记忆能力,而不是推理能力。有读者可能会问在迭代10中完整的答案不是出现了吗?可惜的是神经网络只会学习输入与目标之间的映射关系,而不会从输入中学到任何信息。
所以我先假设“Transformer的推理能力是足够的,之所以学不会是因为老师教的方式不对”,那接下来该如何验证该假设为真呢?由于我对CRNN比较熟悉,CRNN可以预测不定长序列,并且每一次迭代,都会告诉模型完整的答案。于是我将CRNN模型进行了一定的改造来学习3位数乘法,模型大致结构如下: (B, T) – nn.Embedding --> (B, T, C) – TransformerEncoder层 --> (B, T, C) – MLP层 --> (B, T, C o u t C_{out} Cout)
训练样本对的形式如下:
--------56*123= 6888
-------122*222= 27084
-------777*512= 397824
-------612*324= 198288
-------243*753= 182979
-------896*267= 239232
--------368*12= 4416
......
使用CTCLoss进行梯度更新。经过单卡20分钟左右的训练,训练损失基本接近0,测试准确率约99.994%。
为了进一步验证假设,我设计了对比实验:保持Transformer层不变,将模型改成多分类模型,共10个类别代表数字0~9。训练样本对的形式如下:
--------56*123= 6
-------56*123=6 8
------56*123=68 8
-----56*123=688 8
-------122*222= 2
------122*222=2 7
-----122*222=27 0
----122*222=270 8
---122*222=2708 4
......
原始的一条样本被拆成了多条子样本,打包在同一个batch里,模拟GPT的训练方式。使用交叉熵损失函数进行梯度更新。实验结果:模型收敛很慢,训练了很长时间,测试准确率都不是很高。
结论:Transformer的推理能力是足够强大的,基于Transformer的GPT作为一个学生推理能力也是足够的,但由于老师的教学方法不对,导致连3位整数乘法都很难学会。
补充一下,既然改用CTCLoss后可以学会连GPT都很难学会的3位整数乘法,那是否意味着CTCLoss可以广泛应用于大语言模型呢?很遗憾,答案暂时是否定的,因为GPT可以更好地解决模糊性问题,比如同一个问题,有多种回答都是正确的,这种情况不适合CTCLoss。
2024.04.12更新
这几天关注到Cohere公司开源的大模型Command R+,我简单测试了一下,发现Command R+对三位/四位整数乘法的精度相当高(不使用CoT,编程等任何辅助方式)。这让我意识到GPT也是能够将简单乘法训练到几乎100%测试准确率的,我之前训练不出来是因为自己对TransformerDecoder的认识不够深入。
受 Command R+ 的启发,我重新设计了基于TransformerDecoder Only的GPT模型方案,通过合成大量的3位整数乘法训练样本,使用CrossEntropyLoss损失函数进行梯度更新,单卡训练约5个小时后,训练损失几乎为0,测试准确率约99.991%。随机生成10万个样本进行测试,错了5个:
698 * 716 != 509768, 499768(expected)
949 * 959 != 900091, 910091(expected)
616 * 13 != 7008, 8008(expected)
95 * 63 != 6085, 5985(expected)
237 * 38 != 8006, 9006(expected)
结论:使用足够多的样本,训练足够多的轮数,GPT也是可以在3位整数乘法上达到99.99%以上准确率的,相对于TransformerEncoder+CTCLoss方案,TransformerDecoder+CrossEntropyLoss方案收敛更慢,需要数倍的训练时间。但现实情况是,大模型的训练成本太高,一般也就训练一轮,自然在简单乘法上达不到很高的准确率。
标签:迭代,模型,56,123,LLM,GPT,ChatGPT,算错,乘法 From: https://blog.csdn.net/weixin_42885381/article/details/143226979