Improving Long Text Understanding with Knowledge Distilled From Summarization Model
利用从摘要模型中提炼的知识提高长文本理解能力
paper: https://arxiv.org/abs/2405.04955
github:
本文做的是一个利用抽象摘要的能力,去提升下游长文本任务的能力,具体来说就是,利用抽象摘要模型中的注意力权重作为软标签,去训练一个要点检测器,即摘要文本中最重要的几个词,最后将文本的权重应用到下游任务中,与下游向量进行整合。其实可以理解为是为了找到文本中的关键词,然后让这些关键词更加突出,其实就是监督注意力机制的思路,只不过这个监督能力由要点检测器提供。
文章目录~
- 1.背景动机
- 2.Model
- 3.下游任务
- 4.原文阅读
1.背景动机
介绍摘要模型的总结能力对于长文本下游任务的帮助:
与短文本不同,长文本包含许多噪声词。抽象摘要旨在将源文本压缩并改写成简短版本,同时保留其主要信息。摘要模型学会了关注要点相关部分,而忽略无关部分。要点检测能力可以通过让模型意识到长文本的重要部分来提高对长文本的理解。
介绍运用摘要模型的问题:
利用摘要模型的要点检测能力,将提炼出的要点信息整合到下游模型中,以增强其长文本理解能力。然而,这仍然存在两个挑战:
- 首先,从大型摘要模型中提取每个训练样本的要点信息非常耗时。
- 其次,摘要模型在每个解码步骤中都会产生最重要信息,而长文本理解模型产生的是单一表征。
介绍本文的方法:
为了解决这些难题,本文提出了 “要点检测器”(Gist Detector),用于将摘要模型中的要点信息转移到下游的长文本理解模型中。
首先训练 Gist Detector 从摘要模型中捕捉要点信息,然后提供要点感知表征作为辅助,以增强长文本理解模型。
利用知识提炼机制来训练要点检测器,其中采用编码器-解码器架构的摘要模型是教师模型,而采用较少层编码器的要点检测器是学生模型。学生模型以教师模型产生的所有解码步骤的平均注意力分布作为 "软目标 "进行训练。由于要点检测器是非自回归模型,而且比摘要模型小得多,因此提取要点的过程可以大大提高效率。
2.Model
1.介绍具体的方法流程:
在训练过程中,利用知识提炼机制将摘要检测能力从训练有素的摘要模型(教师模型)转移到 Gist Detector(学生模型)。然后,我们将要点检测器提取的要点信息整合到下游模型中。
2.Gist Detector Architecture,要点检测器:
Gist Detector 是一个encoder架构,它从摘要模型中学习源序列中每个词的重要性权重,并为下游模型生成这些信息。具体来说,
首先将输入的 { x 1 , . . . , x N } \{x_{1},...,x_{N}\} {x1,...,xN} 映射为嵌入的 { e 1 , . . . , e N } \{\mathbf{e}_{1},...,\mathbf{e}_{N}\} {e1,...,eN},然后输入一个4层transformer encoder,得到表示 H = { h 1 , . . . , h N } \mathbf{H}=\{\mathbf{h}_{1},...,\mathbf{h}_{N}\} H={h1,...,hN}。然后,应用双层 MLP 和软最大函数来生成输入文本的概率分布 p = { p 1 , . . . , p N } \mathbf{p}=\{p_{1},...,p_{N}\} p={p1,...,pN},从而揭示每个词在源序列中的重要性。
3.Training with knowledge distillation,要点检测器的训练,即蒸馏过程:
用从抽象摘要模型(教师模型)中提取的突出信息来训练要点检测器(学生模型)。从解码过程中提取的注意力分布揭示了源文本的重要信息,将教师的注意力分布作为软目标。学生模型通过学习来重现每个训练样本的注意力分布。具体来说,
软目标
q
=
{
q
1
,
.
.
.
,
q
N
}
\mathbf{q}=\{q_{1},...,q_{N}\}
q={q1,...,qN} 是根据所有解码步骤中注意力分布的几何平均值计算得出的:
q
n
=
∑
t
a
n
,
t
T
(1)
q_{n}=\frac{\sum_{t}a_{n,t}}{T}\tag{1}
qn=T∑tan,t(1)
其中
T
T
T 是解码步骤总数。最后,优化目标是学生模型的预测概率分布
p
\mathbf{p}
p 与教师模型的软目标
q
\mathbf{q}
q 之间的交叉熵:
L K D = − ∑ ( x , y ) ∑ n = 1 N q n l o g ( p n ) ) (2) L_{KD}=-\sum_{(x,y)}\sum_{n=1}^{N}q_{n}log(p_{n}))\tag{2} LKD=−(x,y)∑n=1∑Nqnlog(pn))(2)
4.Integration of salient information,重要信息的整合:
从训练好的要点检测器中提取了重要信息,并通过融合模块将其整合到下游模型中。具体来说,
对于输入的每个长文本
{
x
1
,
.
.
.
,
x
N
}
\{x_{1},...,x_{N}\}
{x1,...,xN},Gist Detector 会生成输入文本的概率分布
p
=
{
p
1
,
.
.
.
,
p
N
}
\mathbf{p}=\{p_{1},...,p_{N}\}
p={p1,...,pN} ,从而揭示每个单词的重要性权重。长文档理解模型的上下文表示为
c
=
∑
n
N
s
n
\mathbf{c}=\sum_{n}^{N}\mathbf{s_{n}}
c=∑nNsn,将上下文表示
c
\mathbf{c}
c与重要性权重
p
\mathbf{p}
p融合为:
c
′
=
(
1
−
λ
)
c
+
λ
∑
t
p
t
s
t
(3)
\mathbf{c}^{\prime}=(1-\lambda)\mathbf{c}+\lambda\sum_{t}p_{t}\mathbf{s}_{t}\tag{3}
c′=(1−λ)c+λt∑ptst(3)
其中
λ
∈
[
0
,
1
]
\lambda\in[0,1]
λ∈[0,1] 是一个可调整的超参数。作为预测输入文本中每个单词得分的下游模型,如抽取式QA模型,将预测得分
{
r
1
,
.
.
.
,
r
N
}
\{r_{1},...,r_{N}\}
{r1,...,rN} 与重要性权重
{
p
1
,
.
.
.
,
p
N
}
\{p_{1},...,p_{N}\}
{p1,...,pN} 融合:
r
t
′
=
(
1
−
λ
′
)
r
t
+
λ
p
t
(4)
r_{t}^{\prime}=(1-\lambda^{\prime})r_{t}+\lambda p_{t}\tag{4}
rt′=(1−λ′)rt+λpt(4)
3.下游任务
1.Document Classification,文档分类:
将 BiLSTM 模型作为文档分类任务的基准模型,它将上下文表示向量输入到MLP 来预测标签。使用300d 的 GloVe 对单词嵌入进行初始化。BiLSTM 的隐藏大小设置为 256。BiLSTM 和 MLP 的层数均设置为 2d。将 Adam 作为优化器,设置为 lr = 0.001 0.001 0.001、 β 1 \beta_{1} β1 = 0.9 0.9 0.9、 β 2 \beta_{2} β2 = 0.999 0.999 0.999、 0.35 0.35 0.35 dropout,并训练 6 6 6 epochs。 λ \lambda λ 设置为 0.5$,同时将 BiLSTM 模型与Gist Detector 集成。
2.Distantly Supervised Open-Domain QA,问答:
使用 OpenQA 模型作为远距离监督开放域问题解答任务的基准模型,该模型应用选择器过滤段落,然后应用精确阅读器提取潜在答案,最后汇总这些结果以预测最终答案。将通道选择器Gist Detector 结合起来,并将 λ \lambda λ 设置为 0.5。将 c ′ \mathbf{c}^{\prime} c′输入一个线性函数,然后与问题向量相乘,得出过滤段落的得分,并将其与OpenQA选择器得出的原始得分相加,预测最终段落得分。对于读者来说,直接将答案跨度的预测得分与 Gist Detector 生成的概率分布 p \mathbf{p} p 相加来生成最终得分,其中 λ ′ \lambda^{\prime} λ′ 设置为 0.2$。
3.Text Style Transfer,文本风格转移:
选择交叉对齐自动编码器(Cross-aligned AE)和逆向正则化自动编码器(Adversarially Regularized Autoencoder,ARAE)作为基准模型。将 content 向量与Gist 检测器相结合,并将 λ 设为 0.5。
4.原文阅读
Abstract
对于自然语言处理来说,理解长文本非常重要,但也极具挑战性。一篇长文章或文档通常包含许多与要点无关的冗余词,有时甚至可以被视为噪音。随着抽象总结技术的不断进步,我们提出了 “要点检测器”(Gist Detector)来利用总结模型的要点检测能力,并将提取的要点整合到下游模型中,以增强其对长篇文本的理解能力。具体来说,要点检测器首先学习从摘要模型中提炼出的要点检测知识,然后生成要点感知表征以增强下游模型。我们在三个不同的任务中评估了我们的方法:长文档分类、远程监督开放领域问题解答和非并行文本风格转移。实验结果表明,我们的方法可以显著提高基线模型在所有任务中的性能。
1 Introduction
介绍摘要模型的总结能力对于长文本下游任务的帮助:
最近,深度学习发展迅速。基于transformer的模型在众多 NLP 任务中非常普遍,但由于输入文本长度的复杂性,在处理长文本时存在困难。与短文本不同,长文本本质上包含许多与主旨无关的噪声词。虽然近期的研究取得了可喜的成果,但很少有研究注意衡量文本的每一部分是突出的还是可忽略的。抽象概括是一项经典的 NLP 任务,旨在将源文本压缩并改写成简短版本,同时保留其主要信息。有了这个优化目标,训练有素的摘要模型就有可能检测出长文本的要点。图 1 显示了 _CNN/Daily Mail_数据集中的一个示例,其中蓝色阴影强度代表从训练有素的摘要模型中提取的重要性权重。我们可以看到,摘要模型学会了关注要点相关部分,而忽略无关部分。直观地说,要点检测能力可以通过让模型意识到长文本的突出部分来提高对长文本的理解。
介绍运用摘要模型的问题:
本文提出利用摘要模型的要点检测能力,将提炼出的要点信息整合到下游模型中,以增强其长文本理解能力。然而,这仍然存在两个挑战:首先,从大型摘要模型中提取每个训练样本的要点信息非常耗时。其次,摘要模型在每个解码步骤中都会产生最重要信息,而长文本理解模型产生的是单一表征。
介绍本文的方法:
为了解决这些难题,我们提出了 “要点检测器”(Gist Detector),用于将摘要模型中的要点信息转移到下游的长文本理解模型中。具体来说,我们首先训练 Gist Detector 从摘要模型中重现要点信息,然后提供要点感知表征作为辅助,以增强长文本理解模型。我们利用知识提炼机制来训练要点检测器,其中采用编码器-解码器架构的摘要模型是教师模型,而采用较少层编码器的要点检测器是学生模型。学生模型以教师模型产生的所有解码步骤的平均注意力分布作为 "软目标 "进行训练。由于要点检测器是非自回归模型,而且比摘要模型小得多,因此提取要点的过程可以大大提高效率。然后,我们通过融合模块将经过提炼的要点检测器提取的要点信息整合到下游模型中,从而有效增强了这些模型的长文本理解能力。
为了评估我们的方法的有效性,我们在三个任务上进行了广泛的实验:长文档分类、远距离监督开放域问题解答(DS-QA)和非并行文本风格转移。实验结果表明,我们的方法有效地增强了不同基线模型的长文本理解能力,因此在所有下游任务中都取得了显著的性能提升。
2 Methodology
介绍具体的方法流程:
在本文中,我们提出了要点检测器(Gist Detector),利用摘要模型的要点检测能力,将要点信息转移到下游的长文本理解模型中。我们首先介绍要点检测器的架构。在训练过程中,我们利用知识提炼机制将摘要检测能力从训练有素的摘要模型(教师模型)转移到 Gist Detector(学生模型)。然后,我们将要点检测器提取的要点信息整合到下游模型中。更小的模型规模和非自进式架构减少了耗时问题,而生成的单一要点感知表示法克服了不匹配问题。
2.1.Gist Detector Architecture
要点检测器:
如图 2 中部所示,Gist Detector 是一个encoder架构,它从摘要模型中学习源序列中每个词的重要性权重,并为下游模型生成这些信息。Gist Detector 有多种可能的网络架构。我们用几个 Transformer 编码器层实现了我们的要旨检测器,并证明了经过简单提炼的要点检测器可以成功地使长文档理解模型受益。
具体来说,首先将输入的 { x 1 , . . . , x N } \{x_{1},...,x_{N}\} {x1,...,xN} 映射为嵌入的 { e 1 , . . . , e N } \{\mathbf{e}_{1},...,\mathbf{e}_{N}\} {e1,...,eN},然后输入一个四层transformer encoder,得到表示 H = { h 1 , . . . , h N } \mathbf{H}=\{\mathbf{h}_{1},...,\mathbf{h}_{N}\} H={h1,...,hN}。然后,应用双层 MLP 和软最大函数来生成输入文本的概率分布 p = { p 1 , . . . , p N } \mathbf{p}=\{p_{1},...,p_{N}\} p={p1,...,pN},从而揭示每个词在源序列中的重要性。
2.2.Training with knowledge distillation
要点检测器的训练,即蒸馏过程:
我们利用知识提炼机制,用从抽象摘要模型(教师模型)中提取的突出信息来训练要点检测器(学生模型)。与典型的知识提炼不同,我们假定从解码过程中提取的注意力分布揭示了源文本的突出信息,并将教师的注意力分布作为软目标。学生模型通过学习来重现每个训练样本的注意力分布。
具体来说,软目标 m a t h b f q = { q 1 , . . . , q N } n = 1 mathbf{q}=\{q_{1},...,q_{N}\}\ _{n=1} mathbfq={q1,...,qN} n=1 是根据所有解码步骤中注意力分布的几何平均值计算得出的:
q
n
=
∑
t
a
n
,
t
T
(1)
q_{n}=\frac{\sum_{t}a_{n,t}}{T}\tag{1}
qn=T∑tan,t(1)
其中
T
T
T 是解码步骤总数。最后,优化目标是学生模型的预测概率分布
p
\mathbf{p}
p 与教师模型的软目标
q
\mathbf{q}
q 之间的交叉熵:
L K D = − ∑ ( x , y ) ∑ n = 1 N q n l o g ( p n ) ) (2) L_{KD}=-\sum_{(x,y)}\sum_{n=1}^{N}q_{n}log(p_{n}))\tag{2} LKD=−(x,y)∑n=1∑Nqnlog(pn))(2)
2.3.Integration of salient information
重要信息的整合:
为了增强下游模型对长文档的理解能力,我们从训练有素的要点检测器中提取了突出信息,并通过融合模块将其整合到下游模型中。
具体来说,对于输入的每个长文本
{
x
1
,
.
.
.
,
x
N
}
\{x_{1},...,x_{N}\}
{x1,...,xN},Gist Detector 会生成输入文本的概率分布
p
=
{
p
1
,
.
.
.
,
p
N
}
\mathbf{p}=\{p_{1},...,p_{N}\}
p={p1,...,pN} 在输入文本上,从而揭示每个单词的重要性权重。鉴于长文档理解模型的上下文表示为
c
=
∑
n
N
s
n
\mathbf{c}=\sum_{n}^{N}\mathbf{s_{n}}
c=∑nNsn,我们将上下文表示
c
\mathbf{c}
c与重要性权重
p
\mathbf{p}
p融合为:
c
′
=
(
1
−
λ
)
c
+
λ
∑
t
p
t
s
t
(3)
\mathbf{c}^{\prime}=(1-\lambda)\mathbf{c}+\lambda\sum_{t}p_{t}\mathbf{s}_{t}\tag{3}
c′=(1−λ)c+λt∑ptst(3)
其中
λ
∈
[
0
,
1
]
\lambda\in[0,1]
λ∈[0,1] 是一个可调整的超参数。作为预测输入文本中每个单词得分的下游模型,如抽取式QA模型,我们将预测得分
{
r
1
,
.
.
.
,
r
N
}
\{r_{1},...,r_{N}\}
{r1,...,rN} 与重要性权重
{
p
1
,
.
.
.
,
p
N
}
\{p_{1},...,p_{N}\}
{p1,...,pN} 融合:
r
t
′
=
(
1
−
λ
′
)
r
t
+
λ
p
t
(4)
r_{t}^{\prime}=(1-\lambda^{\prime})r_{t}+\lambda p_{t}\tag{4}
rt′=(1−λ′)rt+λpt(4)
请注意,我们使用重要性权重而非上下文表示作为突出信息,因为它包含的参数要少得多,并能减轻特定领域信息的影响。
3 Experiments
3.1.Distillation
首先,我们在 CNN/Daily Mail 上使用基于 Transformer 的编码器-解码器架构训练了 8 个抽象摘要模型,作为教师模型。教师模型的平均 ROUGE F1 分数分别为 ROUGE-1、ROUGE-2 和 ROUGE-L 的 38.6、16.3 和 35.4。我们采用相同的设置,并使用 [26] 提供的脚本对 CNN/Daily Mail 数据集进行预处理。我们为 CNN 使用宽度为 5 的 100 维过滤器来捕捉字符嵌入。我们选择 300d GloVe 预训练字嵌入,并在编码器和解码器之间共享相同的字嵌入权重。Transformer 的隐藏大小为 512。我们使用 Adam 优化器 [27],学习率为 0.0004,β1 = 0.9,β2 = 0.999。辍学率和批量大小分别设置为 0.35 和 16。为了避免梯度爆炸问题,我们采用了梯度法削波,最大梯度法为 2.0。
然后,我们利用知识提炼机制,使用基于 Transformer 的编码器架构训练 Gist Detector。我们使用 100d GloVe 进行单词嵌入,使用 50d 进行字符嵌入,Transformer 编码器的隐藏大小为 256。我们采用与教师模型相同的优化设置。
3.2.Integration into Downstream Tasks
最后,我们将训练好的 "要点检测器 "中的突出信息转移到三个长文本理解任务的下游模型中:文档分类、远距离监督开放域问题解答(DS-QA)和非并行文本风格转移。
3.2.1 Document Classification:
我们将 BiLSTM 模型作为文档分类任务的基准模型,它将前向和后向传递的最终状态值作为上下文表示向量,然后将其输入 MLP 来预测标签。我们使用300d 的 GloVe 对单词嵌入进行初始化。BiLSTM 的隐藏大小设置为 256。BiLSTM 和 MLP 的层数均设置为 2d。我们将 Adam 作为优化器,设置为 lr = 0.001 0.001 0.001、 β 1 \beta_{1} β1 = 0.9 0.9 0.9、 β 2 \beta_{2} β2 = 0.999 0.999 0.999、 0.35 0.35 0.35 dropout,并训练 6 6 6 epochs。SS 2.3 中的 λ \lambda λ 设置为 0.5$,同时将 BiLSTM 模型与我们的 Gist Detector 集成。
3.2.2 Distantly Supervised Open-Domain QA:
我们使用 OpenQA 模型[28]作为远距离监督开放域问题解答任务的基准模型,该模型应用选择器过滤段落,然后应用精确阅读器提取潜在答案,最后汇总这些结果以预测最终答案。我们在两个高质量的数据集_TriviaQA_(开放领域设置)[29]和_SearchQA[30]_上评估了我们的方法,两个指标包括ExactMatch(EM)和F1分数。我们保留了与 OpenQA 相同的超参数设置和训练设置,一些重要细节如下。我们将通道选择器与 SS 2.3 中介绍的 Gist Detector 结合起来,并将 λ \lambda λ 设置为 0.5。我们将 c ′ \mathbf{c}^{\prime} c′输入一个线性函数,然后与问题向量相乘,得出过滤段落的得分,并将其与OpenQA选择器得出的原始得分相加,预测最终段落得分。对于读者来说,我们直接将答案跨度的预测得分与 SS 2.3 中介绍的由 Gist Detector 生成的概率分布 p \mathbf{p} p 相加来生成最终得分,其中 λ ′ \lambda^{\prime} λ′ 设置为 0.2$。
3.2.3 Text Style Transfer:
至于非并行文本风格转换任务,该模型旨在将文本要点压缩成固定大小的向量,与纯粹的风格信息分离。我们选择交叉对齐自动编码器(Cross-aligned AE)[31] 和逆向正则化自动编码器(Adversarially Regularized Autoencoder,ARAE)[32] 作为基准模型。我们沿用了[31]的设置,但仍保留了长度在 70 到 150 之间的视图,而不是长度不超过 15 的视图,最终分别从亚马逊和 Yelp 的评论中获得了 350K 和 280K 非并行数据。我们保留了与交叉对齐 AE 和 ARAE 相同的超参数设置和训练设置。我们将 con- tent 向量与第 2.3 节中介绍的 Gist 检测器相结合,并将 λ 设为 0.5。为了对模型进行评估,我们使用了 4 个自动指标:(i) Acc:由预先训练好的分类器测量的将样式成功转换为目标样式的准确率。根据文献 [31],我们使用 TextCNN 模型作为分类器,其在亚马逊和 Yelp 上的准确率分别为 94.2% 和 95.7%。(ii) 余弦:我们沿用 [33] 的设置,用余弦相似度来衡量内容保存情况。(iii) 实体:我们使用名词实体的比例来衡量源文本和生成文本之间的内容一致性。(iv) PPL:通过在相应数据集上预先训练的语言模型来衡量生成文本的流畅性。
4 Results and Analysis
4.1.Results on Document Classification
我们在 FDU-MTL 数据集[34]上对我们的方法进行了 16 个领域的评估。如表 1 所示,使用我们的 Gist Detector 后,基线 BiLSTM 模型在所有 16 个领域的性能都得到了显著提高,并以 88.2 的总体准确率超过了之前的方法(ASP-MTL [34]、S-LSTM [35]、Meta-MTL [36])。一项消融研究表明,如果我们使用带有随机初始参数的 Gist Detector,整体性能会下降 3.6。这表明,Gist Detector 的附加参数和从摘要模型中提炼出的要点检测能力都有助于提高性能。
4.2.Results on DS-QA
我们在 TriviaQA(开放领域设置)[29] 和 SearchQA [30] 数据集上使用 ExactMatch (EM) 和 F1 分数指标对我们的方法进行了评估。如表 2 所示,有了Gist Detector,基线 OpenQA 模型在这两个数据集上的表现要好得多。一项消融研究表明,将突出信息整合到选择器和阅读器中可获得最佳性能。表 3 显示了我们方法的最佳选择性能。我们发现,有了 Gist Detector,选择器能更精确地过滤段落,因此我们的质量保证系统能在更少的段落中汇总信息,更快地预测答案。
4.3.Results on Text Style Transfer
我们在亚马逊和 Yelp 文本风格转移数据集[31]上进一步评估了我们的方法。表 4 的自动评估结果表明,有了我们的要点检测器,基线模型 ARAE[32] 可以实现更高的传输准确率、更好的内容保留、更好的名词实体保留以及更高的流畅度。这表明,Gist Detector 可以帮助模型从长文本中检测并压缩更多重要信息。此外,我们还进行了人工评估,以进一步评价文体转换模型的质量。我们随机抽取了 1000 个例子(500/500 正/负),让人判断文本是否转换为目标文体,并对内容相关性(0 - 5,5 为最相关)和流畅性(0 - 5,5 为最流畅)进行评价。如表 5 所示,在所有评价指标上,Gist Detector 都能显著提高基线模型的性能。
5.CONCLUSION
在本文中,我们提出了 “要点检测器”(Gist Detector),通过知识提炼机制从摘要模型中学习要旨检测能力。我们将经过提炼的要点检测器检测到的要带你信息整合到不同的下游模型中,以增强它们对长文档的理解能力。实验结果表明,在需要理解长文本的不同任务中,我们的方法显著提高了所有基线模型的性能。未来的工作将包括寻找更好的策略,将我们的要旨检测器集成到更多任务中,并处理更长的序列。
标签:...,mathbf,Knowledge,模型,摘要,抽象,要点,文本,Gist From: https://blog.csdn.net/weixin_44362044/article/details/140912777