论文标题
《Confusion Graph: Detecting Confusion Communities in Large Scale Image Classification》
混淆图:在大规模图像分类中检测混淆社区
作者
Ruochun Jin、Yong Dou、Yueqing Wang 和 Xin Niu
来自国防科技大学并行和分布式处理国家实验室,和上一篇是姊妹篇。
初读
摘要
-
问题描述:
对于基于深度卷积神经网络(CNN)的图像分类模型,我们观察到在视觉相似性高的类别之间发生的混淆要比视觉差异大的类别之间的混淆强烈得多。
-
方法描述:
在这些不平衡的混淆中,类别可以被组织成社区(community),这类似于社交网络中的人群。基于此,我们提出了一种名为“混淆图(confusion graph)”的基于图的工具,用于量化这些混淆,并进一步揭示数据库内部的社区结构。
-
作用描述:
- 利用这种社区结构,我们可以诊断模型的弱点,并使用专门的专家子网络来提高分类准确率,这与其他最先进的技术相当。
- 利用这些社区信息,我们还可以利用预训练模型自动在大规模数据库中识别错误标注的图像。使用我们的方法,研究人员只需手动检查大约 \(3\%\) 的 ILSVRC2012 分类数据库,就能定位几乎所有的错误标注样本。
再读
Section 1 Introduction
-
深度学习的挑战:
在过去的几年里,研究人员见证了图像分类领域的巨大飞跃,尤其是在深度卷积神经网络取得显著成功以及像 ImageNet 这样的大规模图像数据集出现之后。尽管最先进的模型的分类准确率已经超过了人类,但仍存在两个关键挑战:
-
首先,改进现有的基于深度CNN的模型极其困难。
-
尽管已经提出了实用的优化方法,但很少有成熟的理论被提出来指导模型设计或优化。鉴于这种情况,大多数努力倾向于采用实验方法来改进现有模型,而这种方法的一个重要部分是如何诊断和理解模型的弱点。
-
为了解决这个问题,专注于那些对模型错误最负有责任的孤立样本,已经开发了几种诊断方法和可视化工具。然而,根据我们的观察,实际上,导致大多数误预测的是具有高视觉相似性的类别之间的混淆,这些混淆很难通过这些样本级别的诊断方法发现。
- 例如,通过先前的方法,可能会发现特定的“母鸡”或“公鸡”图像是模型失败的原因。但仅凭这种样本级别的分析,人们很难注意到,实际上是“母鸡”类别和“公鸡”类别之间的混淆导致了大多数误预测。相反,如果揭示了类别之间的混淆,我们就能有效地定位与错误相关的特定样本。
因此,为了检测模型的弱点,进而支持模型的改进,量化和理解不同类别之间的混淆至关重要。
-
-
第二个挑战是,由众包构建的大规模图像数据库中不可避免地存在错误标注问题,这会对监督分类器产生严重的负面影响。这种标注噪声是不可避免的,主要有两个原因。
-
首先,随着更细粒度的类别被添加,正确标注图像需要特定领域的专业知识,例如鸟类学。然而,参与众包的大多数人专业知识有限,这增加了错误标注的概率。
-
其次,随着数据库规模的迅速扩大,通过人工检查来识别所有错误标注的样本变得极为费力,这使得几乎不可能消除所有标注噪声。尽管研究人员已经注意到这个问题的严重性,但遗憾的是,目前提出的自动检测大规模数据库中错误标注图像的方法还很少。
-
-
混淆图:
- 通过分析深度 CNN 模型的输出,我们发现具有相似视觉特征(如形状、颜色、纹理和背景)的类别之间的混淆比视觉相似性低的类别之间的混淆要强烈得多。这种现象类似于社交网络中人们的关系,朋友之间的关系比陌生人之间的关系更紧密。
- 基于这种类比,我们提出了一种基于图的工具,名为“混淆图(confusion graph)”,通过累积每张测试图像的顶部预测来量化不同类别之间的混淆。
- 然后,我们应用社区检测算法来揭示图中预期的社区结构,其中同一社区内的类别具有高视觉相似性,而来自不同社区的类别在视觉上是不同的。
-
混淆图的应用:
混淆图至少有两个应用。
- 首先,它可以作为一个诊断工具来检测给定模型的弱点。
- 图中的每个社区都代表一个弱点,如果克服了这些弱点,模型的整体性能就可以得到提升。
- 为了说明这一点,我们选择了十个三类社区,并设计了专门的层来克服每个弱点。对于基于 AlexNet 和基于 VGG-verydeep-16 的模型,top-1 错误率的平均降低分别为 \(1.49\%\) 和 \(3.45\%\),这与其他最先进的方法相当。
- 其次,我们利用预训练模型以及社区信息来自动识别图像数据库中的错误标注样本。
- 在随机污染的 Oxford102 花卉数据集上进行评估,其中 \(15\%\) 的图像被错误标注,我们的方法可以检测到大约 \(89\%\) 的错误标签,准确率为\(72\%\)。
- 在检测 ILSVRC2012 分类验证集中的错误标注时,使用我们的方法,研究人员只需手动检查大约 \(3\%\) 的整个数据库,就能定位几乎所有错误标注的样本,这显著减少了劳动工作量。据我们所知,很少有类似的工作被报道。
- 首先,它可以作为一个诊断工具来检测给定模型的弱点。
-
本文贡献:
本文主要有以下两个贡献:
- 我们观察到,深度 CNN 模型的大多数错误是由于具有高视觉相似性的类别之间的混淆造成的,并且可以根据它们的视觉混淆将图像类别划分为社区。
- 我们开发了一个基于图的工具,名为“混淆图”,用于量化类别之间的混淆。我们进一步利用图中的社区结构来诊断模型的弱点,并在大规模数据集中自动识别错误标注的图像。
Section 2 Related Work
-
错误诊断方面:
据我们所知,很少有诊断方法被提出来理解图像分类模型的错误。最相关的工作是[Kabra等人,2015],该研究通过检查有影响力的邻居来定位对错误负有最大责任的特定样本。然而,他们的方法无法确定哪些类别是导致模型误预测的原因,而我们认为这是预测失误更根本的原因。此外,已经提出了各种可视化方法来可视化特征表示,这些方法在理解模型失败方面也是有帮助的。
-
抗标签噪声方面:
为了提高模型对标注噪声的鲁棒性,已经提出了几种方法,通过自动识别和降低错误标注样本的权重。然而,很少有方法被应用于图像数据库中标注噪声的识别。[Stokes等人,2016]在一些简单的图像数据库(如只包含10个数字类别的MNIST)中验证了他们的方法。在大规模数据库(如ImageNet)中错误标注图像识别的性能仍然是未知的。
Section 3 Confusion Graph and Communities Inside
混淆图和内部社区
3.1 Definition of the Confusion Graph
混淆图的定义
定义1:给定一个具有 \(N\) 个类别的分类任务,一个模型 \(M\) 和一个数据集 \(T\),分类的混淆图 \(G=(V,E)\) 由一组顶点 \(V=\{v_1,\dots,v_N\}\) 和无自环的无向边 \(E\) 组成。
- 每个顶点 \(v\in V\) 代表分类中的一个类别。
- 边 \(e_{i,j}\in E\) 表示模型 \(M\) 可能会将类别 \(i\) 与类别 \(j\) 混淆。
- 边 \(e_{i,j}\) 的权重 \(w_{i,j}\)(详见第3.2节)量化了 \(M\) 将类别 \(i\) 误判为类别 \(j\) 或将类别 \(j\) 误判为类别 \(i\) 的可能性。边的权重越大,混淆的可能性越高。
3.2 Establish a Confusion Graph
建立混淆图
-
算法思想:
-
参数说明:
给定一个模型 \(M\),一个包含 \(N\) 个类别的测试数据集 \(T\),每个类别有 \(n\) 个单标签样本,以及一个整数参数 \(\tau\),
-
算法1 通过将每个测试样本的前 \(\tau\) 个预测映射到一个无向图来建立相应的混淆图 \(G\)。
-
算法的主要思想:
- 首先对每个测试样本的前 \(\tau\) 个分类得分进行归一化(normalizing),这些得分中隐藏了混淆信息,
- 然后将每个归一化得分累积到连接标签类别和预测类别的边的权重上。
- 例如,假设我们将一张“猫”的图片输入模型,得到类别“猫”以 \(0.5\) 的得分作为最高预测,“狗”以 \(0.2\) 的得分作为第二预测,“鹿”以 \(0.1\) 的得分作为第三预测。
- 然后,算法将这三个得分分别归一化为 \(0.625\)、\(0.25\) 和 \(0.125\)。最终,忽略自环(self loops),算法将 \(0.25\) 累积到“猫”和“狗”之间的边的权重上,并将 \(0.125\) 加到连接“猫”和“鹿”的边的权重上。
-
-
函数详解:
具体来说,在算法1中,
-
函数 “TestOneSample” 将一张图像t输入模型 \(M\),输出包括预测的类别和得分,这些分别保存在 \(R.c\) 和 \(R.s\) 中。
-
函数 “ScoreNormalization” 接收一个得分数组作为输入,并通过类似 softmax 的公式1 对每个值进行归一化。通过归一化前 \(\tau\) 个得分,每个测试样本在构建图G中边的权重方面具有相同的贡献。
\[topR[i].s=\frac{e^{topR[i].s}}{\sum^\tau_{j=1}e^{topR[j].s}} \]
-
-
实例说明:
-
作为说明说明目,我们使用算法1并将ω设置为 5 评估了预训练的 LeNet 模型在 CIFAR10 验证集上的性能,并得到了名为 LeNet-CIFAR10 的混淆图。
-
本文中的所有实验都是使用 Matlab2014a 和 MatConvNet 完成的,预训练模型是从 MatConvNet 网站下载的。
-
如图2(a)所示,图中的每个顶点代表数据集中的一个类别,每条边的权重量化了两个类别之间的混淆程度。
- 例如,图中连接“狗”和“猫”的最强边表明模型很可能会将狗和猫混淆。相反,连接“鹿”和“卡车”的细小链接意味着模型很少将鹿与卡车混淆。
-
3.3 Detect Communities in a Confusion Graph
在混淆图中检测社区
-
设计思想:
- 在大规模图像数据库(如 ImageNet)中,不同类别之间的混淆关系非常复杂,我们无法简单地使用孤立的边来分析其中的多对多关系。
- 因此,受社交网络中的社区结构启发,该结构揭示了人们的组织方式,我们应用社区检测算法来探索混淆图中的社区,并利用每个社区的模块性作为衡量每个社区紧凑性的指标。
- 在混淆图中,社区的意义是双重的:
- 一方面,从模型的角度来看,大多数错误源于社区内部的混淆,这表明每个社区都可以被视为模型的一个弱点。
- 另一方面,从数据库的角度来看,属于同一社区的类别对模型来说是一个难题,因为大多数错误发生在这些社区内部。
-
算法细节:
-
我们使用快速社区检测算法的第一次迭代来发现混淆图中的细粒度社区。
-
由于混淆图中边的权重分布极其不平衡,为了突出主要的混淆关系,在应用算法之前,我们删除了大多数微小的边。
- 具体来说,我们首先根据边的权重对边进行排序。
- 然后,我们使用 \(p\) 位百分位数(p-th percentile,\(0 < p < 100\))作为切割点来过滤掉微小的边。任何权重小于 \(p\) 位百分位数的边都会被删除。
-
我们还引入了[Blondel等人,2008]中的模块性(modularity)概念来衡量社区内部链接的密度与社区间链接的密度。这个值使我们能够比较不同社区的紧凑性(在第4.1节中使用)。
-
给定某个社区划分,我们可以通过公式2 来计算第 \(k\) 个社区的模块性 \(Q_k\):
\[Q_k=\frac{1}{2m}\sum_{i,j}(w_{i,j}-\frac{s_is_j}{2m}) \delta(c_i,c_j)\theta(c_i,c_j,k) \]-
参数字典:
-
\(s_i = \sum_{j}w_{i,j}\) 是顶点 \(i\) 连接的边的权重之和。
-
\(c_i\) 是顶点 \(i\) 所属的社区。
-
\(\delta(u,v)\) 是一个指示函数,如果 \(u=v\) 则等于 \(1\),否则等于 \(0\)。
-
\(\theta(c_i,c_j,k)\) 也是一个指示函数,如果 \(c_i\) 或 \(c_j\) 是第 \(k\) 个社区,则等于 \(1\),否则等于 \(0\)。
-
\(m=\frac{1}{2}\sum_{i,j}w_{i,j}\) 是所有边权重总和的一半。
-
-
-
-
实例说明:
-
CIFAR10:
当我们将 \(p\) 设置为 50 时,我们得到了 LeNet-CIFAR10 的社区划分(如图2(b)所示),其中同一社区的类别被赋予相同的颜色。
- "卡车"、"汽车"、"船"和"飞机"属于同一个社区,因为它们都是有外壳的人造交通工具,这使它们与其余类别区分开来。类似地,基于视觉特征,动物类别可以进一步分为三个社区。
-
CIFAR100:
-
选择动因:
为了解释为什么某些类别会聚集在一起,我们选择了 CIFAR100 作为研究对象,因为它的复杂性介于 CIFAR10 和 ImageNet 之间。这种适中的复杂性为调查提供了多样化的社区,并且手动分析社区结构的工作量是可以接受的。
-
类别结构设计理念:
在 CIFAR100 中,有 100 个细粒度类别,这些类别进一步分为 20 个粗粒度的超类,每个超类包含 5 个子类。这种类别结构的设计基于这样一个理念:同一超类内的类别相似度较高,因此比属于不同超类的类别更难区分。
-
参数设置:
通过训练一个基于 LeNet 的模型,其 top-5 错误率为 \(22.5\%\),并将 \(\tau\) 设置为 5,我们得到了使用验证集得到的混淆图。为了便于说明,我们隐藏了微小的边,并将 \(p\) 设置为 95,从而进一步揭示了内部的社区结构(如图3所示)。
-
结果分析:
通过仔细比较,我们发现混淆图中的社区结构与 CIFAR100 的原始类别结构之间存在显著差异。CIFAR100 中不同超类的一些类别在混淆图中形成社区的主要原因总结如下:
-
类似的形状:例如,“蛇” 属于 “爬行动物” 超类,而 “蠕虫” 属于 “非昆虫无脊椎动物” 超类,这两个类别因为都有长身体而在同一个混淆社区中。
-
类似的背景或环境:一个很好的例子是最大的社区,包括 “水獭”、“海豹”、“鲸鱼”、“乌龟”、“海豚”、“水族馆鱼”、“鳐鱼”、“鲨鱼”、“鲑鱼” 和 “比目鱼”。尽管这些生物来自不同的超类,但它们都生活在水里或水附近。这些相似的水生背景将它们结合成一个社区。
-
类似的纹理或颜色:“森林”、“柳树”、“松树”、“枫树”、“橡树” 和 “棕榈树” 等类别构成一个紧密的社区,因为这些植物的叶子颜色和纹理非常相似。
-
共同出现:尽管“床”、“桌子”、“椅子”、“电视”、“沙发”、“衣柜” 和 “键盘” 在视觉上相似性不大,但这些类别形成一个社区,因为这些家具和电子设备通常一起出现在客厅或卧室的图片中。
这些发现表明,混淆图揭示了基于视觉特征和上下文信息的类别关系,这可能与原始的类别结构有所不同。这种分析有助于我们理解模型在分类时可能遇到的困难,并为改进模型提供了新的视角。
-
-
-
ILSVRC2012:
-
参数设置:
-
将 \(\tau\) 设置为 5,我们得到了使用 ILSVRC2012 验证集评估的 AlexNet 的混淆图,命名为 AN-ILSVRC2012。
-
然后我们对 AN-ILSVRC2012 进行了类似的分析,观察到具有高视觉相似性的类别聚集成社区的现象。将 \(p\) 设置为 92,我们得到了一个记录了 AN-ILSVRC2012 内部143个社区的社区列表 \(L\)。这些社区的规模从1到24不等(详见图6),并且在图4 中展示了 10 个每个包含 3 个类别的社区作为示例。
-
-
从这些示例中,我们可以看到,基于 CIFAR100 总结的原因与 ImageNet 数据库是兼容的。这进一步表明,最先进的 CNN 模型的大多数分类错误源于在视觉特征上有微小差异的类别。
-
-
Section 4 Applications of the Confusion Graph
混淆图的应用
4.1 Detect Class-scale Weaknesses of the Model
检测模型的类别级弱点
-
关键点:模块性(modularity)
- 混淆图的第一个应用是诊断相应分类模型的弱点,因为大多数错误发生在每个社区内部,而不同社区之间的混淆较少。
- 每个社区可以被视为一个弱点,如果能克服这些弱点,预测准确率就可以得到提高。
- 此外,基于我们的实验,我们发现模块性值较高的社区在提高模型分类性能方面更有潜力。
-
实验验证:
-
实验设置:
-
为了证明基于图的诊断的优势以及模块性值的效果,我们从第 3.3 节获得的社区列表 \(L\) 中选择了五个模块性值最高的 3 类社区和五个模块性值最低的社区。
-
对于每个选定的社区,我们训练一个基于 AlexNet 的专家子网络(ES),如图5所示。每个 ES 包含三个全连接层,前向预测过程可以分为两个阶段。
- 首先,通过路径 ①,图像由原始的 AlexNet 进行分类。如果没有任何一个 top-3 个预测属于 ES 的社区,整个过程就结束了。否则进入路径 ②
- 通过路径 ②,CNN 部分提取的特征将直接发送到相应的 ES,ES 的输出将替换第一阶段的 top-3 个预测。通过将随机初始化的 ES 级联到 AlexNet 的预训练卷积层上,我们使用 ILSVRC2012 训练集中相应 3 类的图像来训练每个 ES。
-
-
实验结果:
- 我们使用 ILSVRC2012 验证集中的图像测试了每个改进模型。每次测试使用了 ES 专门针对的3个类别的 150 张图像。使用 top-1 错误率作为性能指标,结果如表1所示。
- 在相同的参数设置下,我们还基于 VGG-verydeep-16 进行了类似的实验,构建了 VGG-verydeep-16 的混淆图,检测内部的社区,并最终使用 ES 来克服每个弱点。
-
-
结果分析:
- 如表1所示,在基于 AlexNet 的实验中,所有错误率都有所下降,平均下降约为 \(1.49\%\)。
- 此外,模块性排名前 5 的社区的平均 top-1 错误率下降幅度大于模块性排名最低 5 的社区,分别为 \(2.15\%\) 和 \(0.84\%\)。在基于 VGG 的实验中也可以观察到类似结果,分别为 \(5.52\%\) 和 \(1.38\%\)。
- 这表明,在相同的优化方法下,从更紧凑的社区(即模块性更高的社区)可以获得更多的改进。因此,通过在混淆图中检测社区,克服每个社区所代表的弱点,并专注于模块性高的社区,我们可以有效地降低模型的整体错误率。
-
相关工作对比:
- 与我们的方法相比,类似的优化结果在[Yan等人,2015]中也有报道,他们通过减少 top-1 错误率 \(1.11\%\) 来优化。具体来说,他们使用基于混淆矩阵的谱聚类将细粒度类别聚类成粗粒度类别。
- 然而,他们的方法不能清晰地显示哪个粗粒度类别在准确性提升方面具有更大的潜力,而且谱聚类对参数选择敏感,这比我们基于图的方法的鲁棒性要差。
4.2 Identify Mislabeled Images in the Database
在数据库中识别错误标注的图像
-
错误标注定义:
错标图像是指标签完全不相关的图像。如果图像中的任何对象都被正确标注,则该图像不属于误标注图像。这一定义与[Deng 等人,2009]中的定义一致。利用算法 2,我们可以使用预先训练好的模型来自动检测误标注图像。
-
错误标签筛选标准:
我们的方法使用两个标准来筛选可疑的错误标注样本。
- 首先,如果大多数前 \(\mu\) 个预测类别与标签类别不在同一社区,那么这张图像可能被错误标注。
- 其次,如果前 \(\mu\) 个预测得分的均方根(RMS)高于平均值,那么这个预测结果是可信的。
基于这两个标准,大多数错误标注可以被识别,因为错误标注的图像通常有一个可信的预测结果,但其大多数前预测类别并不与标签类别在同一社区。算法 2 中的社区信息至关重要,因为只有拥有社区列表,我们才能利用最先进的模型的人类水平的 top-5 错误率来识别标注错误。
-
参数 \(\mu\) 的设置:
- 参数 \(\mu\) 控制着检测的精确度和召回率之间的权衡,高 \(\mu\) 值导致高精确度低召回率,而低 \(\mu\) 值导致低精确度高召回率。
- 在实践中,我们通过手动检查算法2的输出来迭代纠正错误标注的图像。在自动检测的迭代过程中,参数 \(\mu\) 从5逐渐降低到2,迭代过程在输出中没有真正的错误标注图像时结束。
-
实验验证:
我们设计了两个实验来验证我们的方法。
-
实验一:小规模数据集上的实验
-
实验设置:
-
首先,为了展示我们方法的高精确度和召回率,我们将算法 2 应用于清理随机污染的 Oxford 102 花卉数据集。原始数据集包含102个细粒度花卉类别,所有图像都由专家正确标注。
-
我们首先使用干净的数据集训练分类模型,并获得了四个模型,其 top-1 预测错误率分别为 \(40\%\)、\(30\%\)、\(20\%\) 和 \(10\%\)。
-
然后我们使用包含 1020 张图像的验证集(即每个类别 10 张图像)来构建混淆图并获取每个模型的社区列表。
-
为了模拟某些图像被错误标注的情况,我们随机选择验证集中的 \(3\%\)、\(5\%\) 和 \(10\%\) 的图像,并用一个随机类别错误标注它们。
-
-
实验结果:
通过我们的迭代方法,我们利用这四个模型在污染的验证集中识别错误标注的样本,结果如表2所示,
- 图片注解:
- 其中 “PM” 表示错误标注的百分比,
- “ER” 表示模型的错误率,
- “NM” 表示错误标注样本的数量,
- “NMD” 表示检测到的错误标注样本的数量,
- “NTMD” 表示检测结果中真正的错误标注样本的数量。
- “精确度” 是 “NTMD” 与 “NMD” 的比率,
- “召回率”是 “NTMD” 与 “NM” 的比率。
- 图片注解:
-
结果分析:
- 如表2所示,我们的方法能够识别大约 \(92\%\) 的错误标注图像,相应的精确度约为 \(80\%\)。
- 此外,结果表明,我们方法的 “精确度” 和 “召回率” 与模型的错误率呈负相关,这意味着如果我们使用分类错误率较低的模型,我们可以更准确地找到更多的错误标注样本。
-
-
实验二:大规模数据集上的实验
-
实验设置:
- 实验二为了研究在处理大规模数据库时的性能,在我们的第二个实验中,我们使用预训练的 VGG-verydeep-16 来检测 ILSVRC2012 分类验证集中的错误标签。
- 相应的混淆图和社区列表在第4.1节中获得。在我们的实验中,我们首先使用算法2来检测可疑样本。然后我们手动检查每个自动检测到的样本,并确认真正错误标签的数量。
- 参数 \(\mu\) 从 0 变化到 5,
-
实验结果:
结果如表3所示,
-
图片注解:
其中 “TMP” 是真实错误标注的百分比,“DMP”是检测到的错误标注的百分比,它们分别是“NTMD”和“NMD”与数据集中所有图像数量(在这个实验中为50,000)的比率。图7展示了一些我们的自动检测结果作为示例。
-
-
结果分析:
基于我们的第二个实验,我们可以得出三个主要结论。
- 首先,根据[Deng等人,2009],ImageNet的标注准确率为\(99.7\%\)。然而,通过手动检查 VGGverydeep-16 检测到的可疑样本,我们在 50,000 张图像中确认了 383 个真实的错误标签,这表明之前报告的 ImageNet 的准确率可能高于实际情况。
- 其次,如果 ImageNet 的准确率真的如报告的那样是 \(99.7\%\),使用我们的方法,研究人员只需手动检查 ImageNet 数据库的大约 \(3\%\),就可以找出几乎所有的错误标注样本,这显著减少了劳动工作量。
- 第三,我们观察到大多数标注错误发生在特殊的动植物类别中,如 “鞭尾蜥蜴” 或 “长臂猿”,这证明了在标注包含多种类别的大规模图像数据库时,专业知识是必要的。
-
-