首页 > 其他分享 >【Transformer】损失函数-交叉熵损失

【Transformer】损失函数-交叉熵损失

时间:2024-11-14 20:14:44浏览次数:3  
标签:真实 Transformer 交叉 标签 模型 损失 输出 序列

目录

在《Attention Is All You Need》论文中的 Transformer 模型主要用于机器翻译任务。对于这样的序列生成任务(如翻译、文本生成等),模型的损失函数通常是交叉熵损失函数(Cross-Entropy Loss)。这是一个用于分类任务的常用损失函数,适合语言模型生成时对词的预测。下面我将详细介绍这个损失函数在 Transformer 模型中的输入和输出。

1. 交叉熵损失的定义:

交叉熵损失用于衡量模型的预测分布与真实分布之间的差异。它的定义是:

Cross-Entropy Loss = − ∑ i = 1 N y i log ⁡ ( p i ) \text{Cross-Entropy Loss} = -\sum_{i=1}^{N} y_i \log(p_i) Cross-Entropy Loss=−i=1∑N​yi​log(pi​)

其中:

  • ( N N N) 是目标词汇的总数;
  • ( y i y_i yi​ ) 是目标序列中每个词的真实标签;
  • ( p i p_i pi​ ) 是模型对每个词的预测概率。

对于每个时间步,交叉熵损失会根据模型的预测概率与真实标签的匹配程度计算损失值,并将所有时间步的损失求和,得出序列的整体损失。

2.输入:模型的输出分布和真实标签

在 Transformer 模型中,交叉熵损失的输入包括:

  • 模型的输出分布:每个时间步的预测概率分布 ( p i p_i pi​ );
  • 真实标签(目标序列):目标序列的真实标签(即翻译后的词)。

详细解释输入

  1. 模型的输出分布

    • 输入是每个时间步的概率分布:Transformer 在每一个时间步都会预测一个词(例如,句子生成中的下一个词)。模型的输出经过最后一个线性层和 softmax 激活函数后,得到一个向量 ( p ),表示每个词的概率分布。
    • 词汇表大小的概率向量:在机器翻译任务中,每个时间步的输出是一个概率分布向量 ( p ) 的输出,它的大小与词汇表(vocabulary)的大小相同。例如,如果词汇表有 10,000 个词,则模型在每个时间步输出一个大小为 10,000 的向量,每个值表示该位置的词的概率。
  2. 真实标签(目标序列)

    • 目标序列的编码:真实标签通常是目标语言中的单词索引或词向量。
    • 目标序列中的每个词位置对应一个真实标签:例如,对于翻译任务中的句子“我喜欢学习”,对应的目标标签就是其正确翻译“i like learning”。

3. 输出:损失值

交叉熵损失函数计算的最终输出是一个标量值,即整个序列的损失。该值表示模型的预测与真实标签之间的差异,具体来说:

  • 损失的计算:交叉熵损失会计算模型对每个词的预测概率与真实标签之间的负对数相似度,再对整个序列取平均或求和。
  • 目标:损失越小,说明模型的预测分布与真实分布越接近。因此,在训练过程中,优化器会尝试最小化交叉熵损失,以提高模型对目标序列的预测准确性。

4. 详细的步骤

在 Transformer 模型的训练过程中,交叉熵损失的计算流程大致如下:

  1. 输入序列通过编码器:源语言句子经过编码器,编码成上下文表示。

  2. 目标序列经过解码器并生成预测:解码器在每个时间步利用编码器的输出和已经生成的词,预测下一个词。模型在输出层会生成一个向量,通过 softmax 转化为一个概率分布 ( p )。

  3. 计算交叉熵损失

    • 对于每一个时间步,将预测的概率分布 ( p ) 与真实标签 ( y ) 进行对比。
    • 根据交叉熵公式计算损失,即将正确的标签位置的概率 ( p ) 取负对数(因为越接近 1,负对数越小),并对每个时间步的损失求和。
  4. 输出损失值:整个序列的损失值会反馈给优化器,用于更新模型参数,逐渐提高模型对目标序列的预测准确性。

5.举例说明

例如将"I love you" 翻译成 "我爱你"的翻译任务,并假设单词表为[‘我’,‘爱’,‘你’,‘end’]

假设第一次输出的logits的概率分布为[0.8,0.1,0.05,0.05],称为预测序列
(易知预测值为概率最高的分量,即’我’)

而真实值也是’我’,我们用独热编码[1,0,0,0]表示,称为真实序列

在计算交叉熵时,就是将预测序列和真实序列代入公式。
而由于真实值是独热编码的,所以在这个例子中,我们只用计算第一个分量的位置,即
-log(0.8)

总结

在 Transformer 模型中,交叉熵损失函数的输入是模型的输出概率分布和目标序列的真实标签,它通过计算预测分布和真实分布之间的差异来得出一个损失值。这个损失值帮助优化器调整模型参数,以提升模型的预测能力,使其在训练过程中更准确地生成目标序列。

标签:真实,Transformer,交叉,标签,模型,损失,输出,序列
From: https://blog.csdn.net/2301_82023330/article/details/143697086

相关文章

  • 详细介绍Transformer!
     ......
  • 深度学习 PyTorch 中的 logits 和交叉熵损失函数
    在深度学习中,理解损失函数是训练模型的关键一步。在分类任务中,交叉熵损失函数是最常用的损失函数之一。本文将详细解释PyTorch中的logits、交叉熵损失函数的工作原理,并展示如何调整张量的形状以确保计算正确的损失。什么是logits?logits是模型输出的未归一化预测值,通常......
  • Transformer加载预训练模型实践
    以使用google-bert/bert-base-chinese模型为例下载预训练模型官方站点:https://www.huggingface.co/(如果无法访问,使用镜像站点)镜像站点:https://hf-mirror.com/搜索框内搜索自己需要的模型,点击Filesandversions, 一般下载config.json、pytorch_model.bin、tokenizer.json、t......
  • TransFormer--注意力机制:多头注意力
    TransFormer--注意力机制:多头注意力多头注意力是指我们可以使用多个注意力头,而不是只用一个。也就是说,我们可以应用在上一篇中学习的计算注意力矩阵Z的方法,来求得多个注意力矩阵。我们通过一个例子来理解多头注意力层的作用。以Alliswell这句话为例,假设我们需要计算w......
  • golang交叉编译
    交叉编译需要linux环境windows安装编译器aptinstallgcc-mingw-w64编译指令windows: CGO_ENABLED=1\ GOOS=windows\ GOARCH=amd64\ CC=x86_64-w64-mingw32-gcc\ gobuild\ -buildmode=c-shared\ -ldflags\ "-s-w\ -X'${ProjectName}/vers......
  • 《VATT: Transformers for Multimodal Self-Supervised Learning from Raw Video, Aud
    文章汉化系列目录文章目录文章汉化系列目录摘要1引言2相关工作2.1Vision中的Transformer2.2自监督学习3方法3.1标记化与位置编码3.1.1DropToken3.2Transformer架构3.3公共空间投影3.4多模态对比学习4实验4.1实验设置4.2结果4.2.1视频动作识别的微调4.2......
  • Transformers显存优化策略
    (原创)Transformers显存优化简易策略(本教程目标:4G显存也能跑BERT-Large)......
  • PoliFormer:使用 Transformers 扩展策略在线 RL,打造熟练导航员
    24年6月来自西雅图AI2的论文“PoliFormer:ScalingOn-PolicyRLwithTransformersResultsinMasterfulNavigators”,获得CoRL‘24最佳论文之一。POLIFORMER(策略Transformer),这是一个仅限RGB的室内导航智体,通过大规模强化学习进行端到端训练,尽管纯粹在模拟中训练,但它......
  • 【神经网络组件】Transformer Encoder
    【神经网络组件】TransformerEncoder目录【神经网络组件】TransformerEncoder1.seq2seq模型2.为什么只需要TransformerEncoder3.TransformerEncoder的结构1.seq2seq模型什么是sequence:sequence指由多个向量组成的序列。例如,有三个向量:\(\mathbf{a}=[1,0,0]^T,\math......
  • 【优化求解】蚁群算法ACO求解经济损失的航班延误恢复优化问题(目标函数:航班延误成本最
    ......