对于高效的ViT架构,近期研究通过对剪枝或融合多余的令牌减少自注意力层的二次计算成本。然而这些研究遇到了由于信息损失而导致的速度-精度平衡问题。本文认为令牌之间的不同关系以最大限度的减少信息损失。本文中提出了一种多标准令牌融合(Multi-criteria Token Fusion),该融合基于多种标准(相似性、信息性和融合令牌尺寸)。
多尺度令牌混合
给定输入令牌集合 X \mathbf{X} X,MCTF目标是融合令牌为输出令牌 X ^ ∈ R ( N − r ) × C \hat{X}\in \mathbb{R}^{(N-r)\times C} X^∈R(N−r)×C,其中 r r r 是融合令牌数量。为了最小化信息损失,首先基于多个标准评估令牌间关系,之后使用双向二分图软匹配算法分组合并令牌。
多标准吸引函数
首先基于多标准定义吸引函数为:
W
(
x
i
,
x
j
)
=
∏
k
=
1
M
(
W
k
(
x
i
,
x
j
)
)
τ
k
W(x_{i},x_{j})=\prod_{k=1}^{M}(W^{k}(x_{i},x_{j}))^{\tau_{k}}
W(xi,xj)=k=1∏M(Wk(xi,xj))τk
W
k
W^{k}
Wk 是由第k标准计算的吸引函数。
τ
k
\tau_{k}
τk 是调整第k标准影响的超参数。两个令牌间越高的吸引分布知识越高的机会进行融合。本文考虑三个标准:相似度,信息量和尺寸。
相似度
这里使用令牌集合间余弦相似度。基于相似度的令牌融合有效地消除了冗余的令牌,然而常常组合富含信息量的令牌。
信息性
为了最小化信息损失,这里引入信息性避免富含信息令牌的融合。为了衡量信息性,测量在自注意力层的平均注意力分数 a a a: a j = 1 N A i j a_{j}=\frac{1}{N}A_{ij} aj=N1Aij, A i j = s o f t m a x ( Q i K j T C ) A_{ij}=softmax(\frac{Q_{i}K_{j}^{T}}{\sqrt{C}}) Aij=softmax(C QiKjT)。当 a i → 0 a_{i}\rightarrow 0 ai→0, x i x_{i} xi 对其他令牌没有影响。基于信息性分数,定义基于信息性的吸引函数为:
W i n f o ( x i , x j ) = 1 a i a j W^{info}(x_{i},x_{j})=\frac{1}{a_{i}a_{j}} Winfo(xi,xj)=aiaj1
当 x i x_{i} xi 和 x j x_{j} xj 都是无信息性的, W i n f o W^{info} Winfo 更高,使得两个令牌倾向于融合。
尺寸
最后的标准是令牌的尺寸,这表示融合令牌的尺寸。尽管令牌没有丢弃但通过融合函数融合,但随着组成令牌数量增加很难保留所有信息。所以偏好较小令牌的融合。初始化令牌 X X X 的尺寸 s s s 为1,跟踪每个令牌的融合令牌数量。这里定义的基于尺寸的吸引函数为:
W s i z e ( x i , x j ) = 1 s i s j W^{size}(x_{i},x_{j})=\frac{1}{s_{i}s_{j}} Wsize(xi,xj)=sisj1
双向二分图软匹配
给定基于多标准的吸引函数,本文MCTF执行松弛的双向二分图软匹配算法。通过松弛一对一对应的约束,可以通过高效的算法获得解。在松弛匹配算法中,令牌集合首先划分为源
X
α
X^{\alpha}
Xα 和目标
X
β
X^{\beta}
Xβ 变量。给定二值化决策变量(
X
α
X^{\alpha}
Xα 与
X
β
X^{\beta}
Xβ 之间边矩阵
E
E
E),二分图软匹配描述为:
E
∗
=
arg
max
E
∑
i
j
w
i
j
′
e
i
j
∑
i
j
e
i
j
=
r
,
∑
j
e
i
j
≤
1
E^{*}=\arg\max_{E}\sum_{ij}w_{ij}^{\prime}e_{ij}\quad \sum_{ij}e_{ij}=r,\quad \sum_{j}e_{ij}\leq 1
E∗=argEmaxij∑wij′eijij∑eij=r,j∑eij≤1
其中
w
i
j
′
=
{
w
i
j
j
≠
arg
max
j
′
w
i
j
′
0
Otherwise
w_{ij}^{\prime}=\left\{ \begin{aligned} w_{ij} & \quad j\neq \arg\max_{j^{\prime}}w_{ij^{\prime}}\\ 0 & \quad \text{Otherwise} \end{aligned} \right.
wij′=⎩
⎨
⎧wij0j=argj′maxwij′Otherwise
e
i
j
e_{ij}
eij 指示i,j令牌
X
α
X^{\alpha}
Xα 和
X
β
X^{\beta}
Xβ 间边,
w
i
j
=
W
(
x
i
α
,
x
j
β
)
w_{ij}=W(x_{i}^{\alpha},x_{j}^{\beta})
wij=W(xiα,xjβ)。优化步骤可以由两个简单的步骤求解:1. 对于每个
i
i
i 寻找最大的边最大化
w
i
j
w_{ij}
wij,2. 选择有最大吸引分数的TOP-r边。基于软匹配结果,按照如下划分令牌:
X
j
α
→
β
=
{
x
i
α
∣
e
i
j
=
1
}
∪
{
x
j
β
}
X_{j}^{\alpha\rightarrow \beta}=\{x_{i}^{\alpha}|e_{ij}=1\}\cup\{x_{j}^{\beta}\}
Xjα→β={xiα∣eij=1}∪{xjβ}
X
j
α
→
β
X_{j}^{\alpha\rightarrow \beta}
Xjα→β 指示与
x
j
β
x_{j}^{\beta}
xjβ 匹配的令牌集合。融合
X
~
\tilde{X}
X~ 的结果描述为:
X ~ = X ~ α ∪ X ~ β , X ~ α = X α − ⋃ i N ′ X i α → β , X ^ β = ⋃ i N ′ { δ ( X i α → β ) } \tilde{X}=\tilde{X}^{\alpha}\cup \tilde{X}^{\beta},\quad \tilde{X}^{\alpha}=X^{\alpha}-\bigcup_{i}^{N^{\prime}}X_{i}^{\alpha\rightarrow \beta}, \quad \hat{X}^{\beta}=\bigcup_{i}^{N^{\prime}}\{\delta(X_{i}^{\alpha\rightarrow \beta})\} X~=X~α∪X~β,X~α=Xα−i⋃N′Xiα→β,X^β=i⋃N′{δ(Xiα→β)}
其中 δ ( X ) = δ ( { x i } i ) = ∑ i a i s i x i ∑ i ′ a a ′ s i ′ \delta(X)=\delta(\{x_{i}\}_{i})=\sum_{i}\frac{a_{i}s_{i}x_{i}}{\sum_{i^{\prime}}a_{a^{\prime}}s_{i}^{\prime}} δ(X)=δ({xi}i)=∑i∑i′aa′si′aisixi 是考虑吸引函数 a a a 和尺寸 s s s 的池化操作。此时目标令牌的数量不能减少。为了解决这个问题,MCTF基于更新的令牌集合执行反方向的匹配 。
使用更新的两组令牌计算成对权重 w ~ = W ( x ~ i β , x ~ j α ) \tilde{w}=W(\tilde{x}_{i}^{\beta},\tilde{x}_{j}^{\alpha}) w~=W(x~iβ,x~jα) 会引入额外的计算成本 O ( N ′ ( N ′ − r ) ) O(N^{\prime}(N^{\prime}-r)) O(N′(N′−r))。为了避免这一计算负荷,本文算法融合前估计吸引函数,即重新使用预计算的权重,因为 X ~ α \tilde{X}^{\alpha} X~α 是 X α \mathbf{X}^{\alpha} Xα 子集。
前一步注意模块的信息性
先前方法使用先前注意力层获得的注意力分数。先前方法使用先前层的注意力 A l A^{l} Al 融合令牌。这一技术允许在连续层注意力是相似的假设下能高效评估。然而本文观察到注意力图存在很大差异,来自前一层的注意力图可能会获得次优的令牌融合。本文提出了前一步注意力(one-step head attention),即在下一层基于注意力衡量令牌信息含量。信息量分数 a a a 基于 A l + 1 A^{l+1} Al+1 计算得到。具体地,当令牌按照 δ ( { x i } i ) \delta(\{x_{i}\}_{i}) δ({xi}i) 融合,对应的前一步注意力分数也按照 KaTeX parse error: Double subscript at position 19: …lta(A_{i}^{l+1}_̲{i}) 在查询和键方向上融合。当融合查询注意力分数时对于 δ \delta δ 使用简单的累加并确保 ∑ j A ~ i j l + 1 = 1 \sum_{j}\tilde{A}_{ij}^{l+1}=1 ∑jA~ijl+1=1。
令牌约简一致性
本文提出一种新的微调架构进一步提升使用MCTF的ViT
f
θ
(
⋅
,
r
)
f_{\theta}(\cdot,r)
fθ(⋅,r) 的性能。观察到每层不同数量的约简令牌可能导致不同的样本表示。通过基于不同
r
r
r 训练Transformer并鼓励它们之间的一致性(令牌约简一致性),本文算法可以获得额外的性能增益。本文算法的目标函数可以描述为:
L
=
L
C
E
(
f
θ
(
x
;
r
)
,
y
)
+
L
C
E
(
f
θ
(
x
;
r
′
)
,
y
)
+
λ
L
M
S
E
(
x
r
c
l
s
,
x
r
′
c
l
s
)
\mathcal{L}=\mathcal{L}_{CE}(f_{\theta}(x;r),y)+\mathcal{L}_{CE}(f_{\theta}(x;r^{\prime}),y)+\lambda \mathcal{L}_{MSE}(x_{r}^{cls},x_{r^{\prime}}^{cls})
L=LCE(fθ(x;r),y)+LCE(fθ(x;r′),y)+λLMSE(xrcls,xr′cls)