在阅读文献 CGMD 的代码时,看到一个不理解的损失。
一般的损失都是两者的差异,比如预测的 label 和 实际的 label ,两个梯度差,两个分类器的输出概率差等等。这里只有一个输入,然后看起来像是熵的损失。询问师兄后,说这里是一个 互信息最大化 的损失。
Introduction
MIM - mutual information maximization 互信息最大化。MIM 用来发现最能支持最终预测的特征,它能够维持模型预测分布的相对平衡;也可以扩大目标域上类之间的边界。用于学习域不变特征和任务区分性特征。
Method - MIM
模型对于目标域的预测可能不合理的偏向于某一个类别;而且许多目标域上的样本可能会距离在源域上学习到的决策边界太近。
MIM 的公式是,X 是输入 , Y 是预测值。最大化互信息可以分成两部分:最大化H(X),最小化 H(Y|X)。前者避免了 模型预测太偏向于某一类,后者增加了 置信度,扩大了类之间的边界 。
MIM 目标函数形式化如下:
ps !! 就因为没看到下面这句话,看代码卡了一下午:
参考文献《Rethinking Distributional Matching Based Domain Adaptation》
,作者使用 minibatch 中 $p_θ(y|x)$ 的均值,代替 $p_θ(y)$ ;则公式前者变为, =
q(y)
是 $p_θ(y|x)$的均值
这就是为什么这前后两个部分代码看起来几乎一样,一开始寻思也没前者也没乘积啊
Code
先看看 互信息最大化 论文的代码。
logits
是预测的 y (未经 softmax
)
则 -y_entropy
对应公式前者,经 softmax
后变成概率,再取平均,再过 log
,最后求和。ps:负号是因为在计算时也加了负号;
condi_entropy
对应公式后者,经 softmax
后,再求和取平均
OK,那么回到最初的起点。
这个互信息最大化就显而易见了。
Analysis
信息量的定义是 :如果一个事件发生的概率 $p(x)$,则该事件的信息量为 $−log p(x)$ 。将这个事件的所有可能性罗列出来,就可以求得该事件信息量的期望,信息量的期望便是熵:$H(X)=−∑_xp(xi)log{p(xi)}$
条件熵:知道X后Y还剩多少信息量$H(Y|X)$。或者知道Y后,X还剩多少信息量$H(X|Y)$
互信息:知道X,给Y的信息量带来多少损失
我对互信息最大化的理解是,知道 特征 X 后,给预测的 Y 的信息量带来越大损失越好。换句话说,就是知道 X 了,Y 也基本知道了,没有什么悬念了。这样一来,不就符合我们的预测任务了吗?
说的专业一点,就是让生成的特征 X 尽可能的是任务特异性的,也就是尽可能的 对预测 Y 作用很大。