目录
在《Attention Is All You Need》论文中的 Transformer 模型主要用于机器翻译任务。对于这样的序列生成任务(如翻译、文本生成等),模型的损失函数通常是交叉熵损失函数(Cross-Entropy Loss)。这是一个用于分类任务的常用损失函数,适合语言模型生成时对词的预测。下面我将详细介绍这个损失函数在 Transformer 模型中的输入和输出。
1. 交叉熵损失的定义:
交叉熵损失用于衡量模型的预测分布与真实分布之间的差异。它的定义是:
Cross-Entropy Loss = − ∑ i = 1 N y i log ( p i ) \text{Cross-Entropy Loss} = -\sum_{i=1}^{N} y_i \log(p_i) Cross-Entropy Loss=−i=1∑Nyilog(pi)
其中:
- ( N N N) 是目标词汇的总数;
- ( y i y_i yi ) 是目标序列中每个词的真实标签;
- ( p i p_i pi ) 是模型对每个词的预测概率。
对于每个时间步,交叉熵损失会根据模型的预测概率与真实标签的匹配程度计算损失值,并将所有时间步的损失求和,得出序列的整体损失。
2.输入:模型的输出分布和真实标签
在 Transformer 模型中,交叉熵损失的输入包括:
- 模型的输出分布:每个时间步的预测概率分布 ( p i p_i pi );
- 真实标签(目标序列):目标序列的真实标签(即翻译后的词)。
详细解释输入
-
模型的输出分布
- 输入是每个时间步的概率分布:Transformer 在每一个时间步都会预测一个词(例如,句子生成中的下一个词)。模型的输出经过最后一个线性层和 softmax 激活函数后,得到一个向量 ( p ),表示每个词的概率分布。
- 词汇表大小的概率向量:在机器翻译任务中,每个时间步的输出是一个概率分布向量 ( p ) 的输出,它的大小与词汇表(vocabulary)的大小相同。例如,如果词汇表有 10,000 个词,则模型在每个时间步输出一个大小为 10,000 的向量,每个值表示该位置的词的概率。
-
真实标签(目标序列)
- 目标序列的编码:真实标签通常是目标语言中的单词索引或词向量。
- 目标序列中的每个词位置对应一个真实标签:例如,对于翻译任务中的句子“我喜欢学习”,对应的目标标签就是其正确翻译“i like learning”。
3. 输出:损失值
交叉熵损失函数计算的最终输出是一个标量值,即整个序列的损失。该值表示模型的预测与真实标签之间的差异,具体来说:
- 损失的计算:交叉熵损失会计算模型对每个词的预测概率与真实标签之间的负对数相似度,再对整个序列取平均或求和。
- 目标:损失越小,说明模型的预测分布与真实分布越接近。因此,在训练过程中,优化器会尝试最小化交叉熵损失,以提高模型对目标序列的预测准确性。
4. 详细的步骤
在 Transformer 模型的训练过程中,交叉熵损失的计算流程大致如下:
-
输入序列通过编码器:源语言句子经过编码器,编码成上下文表示。
-
目标序列经过解码器并生成预测:解码器在每个时间步利用编码器的输出和已经生成的词,预测下一个词。模型在输出层会生成一个向量,通过 softmax 转化为一个概率分布 ( p )。
-
计算交叉熵损失:
- 对于每一个时间步,将预测的概率分布 ( p ) 与真实标签 ( y ) 进行对比。
- 根据交叉熵公式计算损失,即将正确的标签位置的概率 ( p ) 取负对数(因为越接近 1,负对数越小),并对每个时间步的损失求和。
-
输出损失值:整个序列的损失值会反馈给优化器,用于更新模型参数,逐渐提高模型对目标序列的预测准确性。
5.举例说明
例如将"I love you" 翻译成 "我爱你"的翻译任务,并假设单词表为[‘我’,‘爱’,‘你’,‘end’]
假设第一次输出的logits的概率分布为[0.8,0.1,0.05,0.05],称为预测序列。
(易知预测值为概率最高的分量,即’我’)
而真实值也是’我’,我们用独热编码[1,0,0,0]表示,称为真实序列。
在计算交叉熵时,就是将预测序列和真实序列代入公式。
而由于真实值是独热编码的,所以在这个例子中,我们只用计算第一个分量的位置,即
-log(0.8)
总结
在 Transformer 模型中,交叉熵损失函数的输入是模型的输出概率分布和目标序列的真实标签,它通过计算预测分布和真实分布之间的差异来得出一个损失值。这个损失值帮助优化器调整模型参数,以提升模型的预测能力,使其在训练过程中更准确地生成目标序列。
标签:真实,Transformer,交叉,标签,模型,损失,输出,序列 From: https://blog.csdn.net/2301_82023330/article/details/143697086