首页 > 其他分享 >联合学习MOON——无需共享原始数据,通过模型对比联合学习实现准确的图像分类

联合学习MOON——无需共享原始数据,通过模型对比联合学习实现准确的图像分类

时间:2024-04-08 12:58:56浏览次数:22  
标签:数据 模型 MOON 学习 图像 原始数据 客户端

1. 概述

联合学习(Federated Learning)是一种分布式的机器学习方法,它允许多个参与者协作训练一个共享的模型,同时保持各自数据的隐私性。这种方法特别适用于那些涉及敏感数据的场景,如医疗、金融和个人设备等。
在传统的中心化机器学习方法中,所有的数据需要被收集到一个中心服务器上,然后在这个服务器上进行模型的训练。然而,这种方法在处理隐私敏感的数据时存在较大的风险,因为数据一旦被泄露,可能会导致严重的隐私问题。
联合学习通过在每个客户端本地进行模型的训练和更新,然后只将模型的更新(如权重更新)发送到服务器,而不是原始数据,从而解决了这个问题。服务器聚合这些更新,然后分发新的全局模型给所有客户端。这个过程会重复进行,直到达到预定的性能目标或者收敛。
这种方法的优点包括:

  1. 隐私保护:由于原始数据不需要离开客户端,因此可以显著降低数据泄露的风险。
  2. 通信效率:只有模型的更新需要在网络中传输,这比传输大量的原始数据要节省带宽。
  3. 灵活性:联合学习可以应用于各种设备和环境中,包括移动设备和边缘计算设备。
  4. 可扩展性:联合学习可以轻松地扩展到成千上万的客户端,而不需要一个强大的中心服务器。
  5. 模型性能:通过结合多个客户端的数据和知识,可以训练出更强大的模型,尤其是在处理非独立同分布(non-IID)数据时。

然而,联合学习也面临一些挑战,比如如何设计有效的聚合算法、如何处理不同客户端的非独立同分布数据、如何确保系统的鲁棒性和安全性等。

总的来说,联合学习是一种非常有前景的机器学习方法,它能够在保护隐私的同时,充分利用分布式数据的优势,推动了人工智能在多个领域的应用和发展。

2. 联合学习

传统的机器学习方法通常需要将数据集中存储在中央服务器上进行模型的训练和学习。这种方法虽然在某些情况下能够带来较好的模型性能,但它存在一些显著的问题,尤其是在数据隐私和安全性方面。

首先,当原始数据需要在多个客户端或设备之间传输并汇总到中央服务器时,数据在传输过程中可能被截获或窃取,导致敏感信息的泄露。这对于那些处理个人身份信息、健康记录、金融交易记录等敏感数据的行业来说,是一个严重的安全隐患。

其次,一旦所有数据都集中在中央服务器上,这个服务器本身就成为了一个高价值的目标,容易受到黑客攻击。如果服务器的安全措施不够完善,一旦遭受攻击,可能会导致大规模的数据泄露,给企业和用户带来重大损失。

随着社会对个人隐私保护意识的提高,越来越多的企业和组织开始寻求更加安全的数据处理方法。在这种背景下,联合学习作为一种新型的机器学习方法应运而生,它能够在不共享原始数据的前提下,实现分布式数据上的模型训练和学习。

联合学习通过在本地设备上训练模型,并将模型更新(如权重变化)发送到中央服务器进行聚合,而不是传输原始数据。这样既保护了用户的隐私,又减少了数据在传输过程中被截获的风险。此外,由于攻击者无法通过攻击中央服务器来获取大量的原始数据,这大大降低了数据泄露的可能性。

3.算法类型

联合学习可以根据数据的共同点来区分为横向联合学习和纵向联合学习两种模式。

横向联合学习(Horizontal Federated Learning)适用于数据样本特征空间相同,但样本空间不同的情况。在这种模式下,参与联合学习的各个客户端拥有相似或相同的特征集,但它们的样本集合是不同的。例如,不同的医院可能收集了具有相同特征(如年龄、性别、病史等)的患者数据,但每个医院的患者群体是独立的。横向联合学习允许这些医院合作,共同训练一个模型,同时保护各自患者的隐私数据。

纵向联合学习(Vertical Federated Learning)则是针对特征空间和样本空间都不重叠的数据集。在这种情况下,不同的客户端拥有不同的特征集,但它们可能有共同的目标变量或标签。例如,医院和金融机构可能都拥有关于某位客户的信息,但医院拥有的是客户的医疗记录和健康信息,而金融机构掌握的是客户的财务和交易信息。虽然两者的数据特征集不同,但它们可能都关注客户的信用评分这一共同目标。纵向联合学习使得这两方可以在不直接共享数据的情况下,共同训练一个预测客户信用评分的模型。

MOON是一个纵向的联合学习算法。
源码地址:https://github.com/QinbinLi/MOON.git
论文地址:https://arxiv.org/pdf/2103.16257.pdf

算法实现步骤

横向联合学习是一种有效的机器学习模型训练方法,它允许多个组织或客户端在保护数据隐私的前提下共同训练和改进模型。以下是横向联合学习的基本算法步骤:
步骤 0:初始化全局模型

  • 在中央服务器上初始化一个全局模型,其参数设置为随机状态。这个模型将作为联合学习的基础。

步骤 1:分发全局模型

  • 将当前的全局模型复制并分发给参与横向联合学习的每个客户端。每个客户端都将拥有一个全局模型的副本。

步骤 2:本地模型训练

  • 每个客户端使用自己的本地数据对分发来的全局模型进行训练。由于每个客户端的数据是独立的,这一步是在本地完成的,不涉及数据共享。
  • 在训练过程中,客户端会根据本地数据调整模型的参数,以最小化预测误差。

步骤 3:上传更新权重

  • 训练完成后,客户端将更新后的模型参数(通常是权重的变化量,而不是原始权重)发送回中央服务器。这个过程确保了原始数据不会离开客户端。

步骤 4:聚合更新

  • 中央服务器收集所有客户端发送的更新权重,并通过聚合算法(如联邦平均算法Federated Averaging, FedAvg)来更新全局模型。这个聚合过程考虑了所有客户端的贡献,以确保全局模型能够反映所有参与方的数据特征。

多轮迭代

  • 从步骤 1 到步骤 4 的过程被称为一轮联合学习。为了提高全局模型的准确性和泛化能力,这个过程需要重复多轮。每一轮结束后,全局模型都会得到更新和改进。


在联合学习中,全局模型的构建通常依赖于从各个客户端收集的模型更新,这些更新通常是通过最小化损失函数来获得的。然而,当客户端数据分布不均衡时,传统的联合学习方法可能会遇到准确性下降的问题。这是因为少数类样本可能会在全局模型中被忽视,导致模型性能在这些类别上不佳。

为了解决这一问题,研究人员提出了多种改进方法。其中一种方法是在损失函数中引入修正项,以便更好地处理数据不平衡问题。这些修正项可以根据数据集的特性进行调整,从而提高模型对少数类的识别能力。

MOON(Model-Contrastive Federated Learning)就是这样一种改进方法。MOON的核心思想是利用模型表示之间的相似性来纠正单个客户端的本地训练。具体来说,MOON通过在模型层面进行对比学习,即在全局模型和本地模型之间引入对比损失(model-contrastive loss)。这种方法不仅考虑了根据标签产生的损失(如交叉熵损失),还通过模型对比损失来减少当前更新的模型产生的特征与全局模型产生的特征的距离,并增大当前模型产生的特征和上一轮模型产生的特征的距离25。

4.图像分类中使用联合学习

联合学习是一种用于表格数据分析和文本数据分类等领域的技术,目前正被应用于图像分类领域。然而,与表格或文本数据相比,图像数据更加复杂多样,而且这些数据在客户端之间往往不平衡。因此,当联合学习应用于图像分类时,全局模型和局部模型可能会出现分歧,从而导致准确率不高。

5.对比学习

对比学习作为一种无监督学习方法,在图像分类领域中正变得越来越重要。它的基本思想是通过比较图像之间的相似性和差异性来学习有用的特征表示,而不需要显式的标签信息。这种方法特别适用于那些标签难以获得或成本高昂的场景。

在对比学习中,模型的目标是使得同一类别的图像对在特征空间中更相似,而不同类别的图像对更不相似。为了实现这一目标,通常会定义一个对比损失函数,如InfoNCE损失,它鼓励模型拉近正样本对之间的距离,同时推远负样本对之间的距离。

特征向量是通过深度学习模型,特别是卷积神经网络(CNN)来提取的。CNN是一种强大的图像处理模型,它通过卷积层来提取图像的局部特征,并通过池化层(汇集层)来降低特征的空间维度,从而减少计算量并提高模型的泛化能力。CNN的设计使得它对图像的平移、缩放和旋转等变换保持不变性,这在图像识别任务中是非常重要的。

通过使用CNN作为编码器,对比学习可以有效地从原始图像中提取有用的特征表示。这些特征随后可以用于各种下游任务,包括图像分类、目标检测和语义分割等。著名的CNN模型如AlexNet和GoogleNet等在图像识别领域取得了突破性的成果,它们的成功也推动了深度学习在计算机视觉领域的广泛应用。

对比学习是主要的无监督学习方法之一,其中包括自监督学习。无监督学习有多种方法,如基于生成任务的方法、基于判别回归任务的方法和基于比较任务的方法,但归类为判别和回归任务以及比较任务的方法有时被称为自监督学习方法。这种方法可以从数据本身创建教师信号(标签)。这种方法可用于使用自动编码器生成图像,以及在自然语言处理中学习嵌入式单词表示法。自监督学习具有预学习和微调相结合的结构,是一个有潜力应用于广泛任务的领域,同时提高了学习的普遍性。它有望应用于语音识别和自动驾驶等多个领域。

对比学习(SimCLR)机制

在没有标签指导的情况下,学习图像之间的语义距离是对比学习的核心任务。SimCLR模型通过对比度学习的方法,有效地解决了这一问题。以下是SimCLR模型的关键要点:

  1. 正负样本对:SimCLR模型通过对比学习框架,为每个图像生成两个不同的增强视图作为正样本对,同时从数据集中随机选择其他图像作为负样本。模型的目标是使得正样本对在特征空间中更接近,而负样本则更远离。

  2. 数据增强:SimCLR使用一系列数据增强技术来创建图像的变体,这些技术包括随机裁剪、颜色失真、高斯模糊等。这些增强技术增加了数据的多样性,帮助模型学习到更鲁棒的特征表示。

  3. 特征提取:SimCLR使用卷积神经网络(CNN)作为编码器,将原始图像和增强后的图像转换成高维特征向量。这些特征向量应该捕捉到图像的关键视觉信息,并在特征空间中保持图像的语义相似性。

  4. 对比损失函数:SimCLR采用了一种特殊的对比损失函数,即InfoNCE损失,它通过最大化正样本对之间的相似度(通过最小化它们之间的距离)和最小化负样本对之间的相似度(通过最大化它们之间的距离)来训练模型。

  5. 归一化温度:SimCLR引入了温度参数(T)来调整相似度的尺度。通过这种方式,模型可以学习到更平滑的特征表示,有助于区分细微的图像差异。

  6. 预训练和微调:SimCLR模型可以首先在大规模的未标注数据集上进行预训练,学习通用的特征表示。之后,可以在有标注的数据集上进行微调,以适应特定的下游任务。

通过这些机制,SimCLR能够有效地学习图像的特征表示,这些表示可以用于各种视觉任务,如图像分类、目标检测等。SimCLR的成功展示了对比学习在无监督学习领域的巨大潜力,特别是在处理高维视觉数据时的有效性。其损失函数如下式所示:

这里,sim是余弦相似度的函数,即特征向量在方向上的相似程度。分子表明,学习效果越好,源自同一图像的特征向量的相似性就越大,损失函数就越小。分母计算的是特定单一图像 xk 与除 xk 以外所有图像的余弦相似度,它表明学习的越好,从不同图像中提取的特征向量的相似度就越小,损失函数就越小。

6. MOON

MOON(Model-Contrastive Federated Learning)机制是一种在联合学习框架下进行模型训练的方法,它特别关注于如何提高模型在面对非独立同分布(non-IID)数据时的性能。在MOON中,通过在客户端和全局模型之间进行特征对比学习,可以有效地整合来自不同客户端的数据特征,从而提高全局模型的泛化能力。

在MOON机制中,涉及的三个关键组件包括:

  1. 客户端一轮前的模型:这是在客户端进行一轮训练之前的模型状态。它包含了客户端数据的特征表示和学习到的知识。

  2. 全局模型:这是由中央服务器维护的模型,它聚合了所有客户端模型的更新,代表了所有客户端数据的全局知识。

  3. 当前的本地模型:这是客户端在当前轮次中更新后的模型。它通过在本地数据上训练得到,并准备上传更新到全局模型。

在这三种模型中,通常会使用卷积神经网络(CNN)作为编码器来提取图像的特征。CNN的红色层通过卷积和池化操作,从图像中提取出有用的视觉特征。随后,这些特征被送入蓝色层,即多层感知器层,它是一种全连接层,用于进一步处理和学习特征的高级表示。最终,这些特征被转换成一定维度的向量,并用于预测分类结果的概率分布,即黄色层所示。

多层感知器层是一种强大的神经网络模型,它由多个全连接层组成,能够学习复杂的非线性映射。这种层的优势在于其简单性和能力,可以从原始数据中学习到复杂的模式和结构。在MOON机制中,多层感知器层不仅用于学习特征表示,还用于在客户端和全局模型之间进行对比学习,从而提高模型对于不同数据分布的适应性。

MOON 损失函数的定义

如前所述,在用于图像分类的传统联盟学习中,图像分类是通过在客户端拥有的图像之间引入对比学习(即通过单一机器学习模型传递客户端内的图像并比较其输出)来实现的。相比之下,MOON 方法采用不同的方法得出损失函数,并将其添加到传统的损失函数中。

具体的损失函数如下

μ 是预先指定的超参数。基于单一模型的对比度训练所产生的损失函数会被添加到基于单一图像的对比度训练所产生的损失函数中。

等式的后半部分显示,对于给定的某幅图像,通过客户端模型获得的输出结果与该客户端一轮之前的模型输出结果进行对比,以进行学习。

由于这可能是模型中最重要的部分,我们将通过概念对比图来重申这一点。

左图显示的是 SimCLR 方法,来自一个客户端的图像通过一个机器学习模型来得出图像的相似度。右图显示的是 MOON 方法,该方法将一张图像通过多个不同的机器学习模型运行,并比较它们的输出结果。从图中可以看出,虽然构成相似,但我们比较的是不同的概念。

MOON 模型

联盟学习中的全局模型是将每个客户的平均损失函数加权值与所有客户的平均损失函数加权值相加,并学习如何使损失函数最小化。每个客户的平均损失函数是该客户所有图像对损失函数的平均值。

MOON 的模型更新算法如图所示。其中,T 是通信总数,N 是客户总数,E 是局部历元数,η 是联盟学习中的学习率,τ 是预先指定的超参数。

7.实验结果

为了证实 MOON 在图像分类方面的准确性,我们将图像分类的准确性与 Fed Average、Fed Prox 和其他联盟学习方法等现有方法进行了比较。我们使用了三个图像数据集**:CIFAR-10、CIFAR-100 和 Tiny-ImageNet**。这三个数据集都是计算机视觉领域的基准自然图像数据集。

Res-Net50 被用作图像分类的基础。Res-Net 是一种专门用于图像分类的机器学习模型。该模型引入了一种称为 "跳过连接"(skip connections)的机制,通过一种称为 “跳过连接”(skip connections)的过程,即跳过各层,解决了 “梯度损失”(de-gradation)问题,即随着层数加深,准确率下降,难以优化函数。这一概念已被引入各种深度学习模型中。

假设局部历时(虚拟客户机数量)为 10,实验次数为 3,得出的平均值和标准偏差结果如下。

实验表明,在这两个图像数据集上,MOON 的图像分类准确率均高于现有方法。

8.总结

联合学习是一种在数据分布的情况下,既能以低成本进行机器学习,又能保护隐私的方法,引入这种方法不仅能保护隐私,还能降低更新模型时向中央服务器发送数据的通信成本。这种方法的引入不仅有望保护隐私,还能降低更新模型时向中央服务器发送数据的通信成本。

标签:数据,模型,MOON,学习,图像,原始数据,客户端
From: https://blog.csdn.net/matt45m/article/details/137497134

相关文章

  • 前端学习<四>JavaScript基础——11-流程控制语句:选择结构(if和switch)
    代码块用{}包围起来的代码,就是代码块。在ES5语法中,代码块,只具有分组的作用,没有其他的用途。代码块中的内容,在外部是完全可见的。举例: {   vara=2;   alert('qianguyihao');   console.log('千古壹号'); } ​ console.log('a='+a);打印结......
  • 前端学习<四>JavaScript基础——10-运算符
    我们在前面讲过变量,本文讲一下运算符和表达式。运算符的定义和分类运算符的定义运算符:也叫操作符,是一种符号。通过运算符可以对一个或多个值进行运算,并获取运算结果。表达式:数字、运算符、变量的组合(组成的式子)。表达式最终都会有一个运算结果,我们将这个结果称为表达式的......
  • PCB学习记录-----入门&基础知识
    一、搭建环境1.下载嘉立创EDA 软件下载-嘉立创EDA(lceda.cn)选专业版在线编辑:嘉立创EDA(专业版)-V2.1.45(lceda.cn)官方教程:立创EDA专业版-使用教程(lceda.cn)2.新建工程文件-新建-项目,右键Board1可以重命名,原理图右键新增图页右侧图纸尺寸可自定义调整图纸......
  • 吴恩达2022机器学习专项课程(一) 5.5 特征缩放1 & 5.6 特征缩放2
    问题预览/关键词什么是特征缩放?作用是什么?特征尺度和参数w权重的关系是?算法为什么要调节w权重?不进行特征缩放对梯度下降的影响?有特征缩放对梯度下降的影响?实现特征缩放的三种方法是?如何实现最大值缩放?如何实现均值归一化?如何实现Z-score标准化?判断缩放成功的标准是?什么情况......
  • 从零开始的深度学习项目(PyTorch识别人群行为)
    PyTorch识别人群行为系统环境介绍环境版本Python3.11.5pandas2.0.3numpy1.24.3torch2.1.2+cu121注意:2.1.2+cu121这样的版本号通常用于描述TensorFlow等深度学习框架的版本信息,其中:2.1.2是TensorFlow的主要版本号,表示主要的功能和接口的变化。cu121表示该Tenso......
  • 深度学习-卷积神经网络--什么是manifold embedding--66
    目录参考:流形假设(ManifoldHypothesis)在介绍流形学习(Manifoldlearning)之前,首先需要理解一个假设,就是流形假设(ManifoldHypothesis)。这个假设认为,高维数据很多都是低维流形嵌入(embedding)于高维空间当中,比如说三维空间里的各种平面或者曲面,虽然这些平面或者曲面处于三......
  • JAVA语言学习-Day5
    集合Java中的集合是工具类,可以存储任意数量的具有共同属性的对象应用场景无法预测存储数据的数据同时存储具有一对一关系的数据需要进行数据的增删数据重复问题体系结构Collection:List、Queue、SetMap:HashMapList有序且可重复,ArrayList、LinkedList......
  • 学习笔记445—白盒测试用例设计方法(语句覆盖、判定覆盖、条件覆盖、判定/条件覆盖、组
    白盒测试用例设计方法(语句覆盖、判定覆盖、条件覆盖、判定/条件覆盖、组合覆盖、路径覆盖、基本路径覆盖语句覆盖:每条语句至少执行一次。判定覆盖:每个判定的所有可能结果至少出现一次。(又称“分支覆盖”)条件覆盖:每个条件的所有可能结果至少执行一次。判定/条件覆盖:一个判定中的每......
  • 【学习笔记】基础数据结构:猫树
    猫树是线段树的一个特殊版本,猫树不再支持修改操作,类似\(\text{ST}\)表猫树支持高速区间查询,每次查询都只需要进行\(1\)次合并操作,设单次合并操作的复杂度为\(O(k)\),建立猫树的复杂度是\(O(kn\logn)\)的,而查询的复杂度是\(O(k)\)的一般单次查询的复杂度是\(O(1)\),所......
  • HectorSlam学习
    #HectorSlam##一、安装gitclonehttps://github.com/tu-darmstadt-ros-pkg/hector_slam克隆到工作空间,编译无误后即可+注意:gitbranch检查分支+注意:gitcheckout-b分支名切换到分支+注意:gitbranch-a查看所有分支+注意:gitcheckout分支名切换到分支+注意:gitbranch......