首页 > 其他分享 >论文解读(FedPCL)《Federated Learning from Pre-Trained Models: A Contrastive Learning Approach》

论文解读(FedPCL)《Federated Learning from Pre-Trained Models: A Contrastive Learning Approach》

时间:2022-10-20 19:46:57浏览次数:79  
标签:Pre right mathbf Contrastive 原型 Learning quad 客户端 left

论文信息

论文标题:Federated Learning from Pre-Trained Models: A Contrastive Learning Approach
论文作者:Yue Tan, Guodong Long, Jie Ma, Lu Liu, Tianyi Zhou, Jing Jiang
论文来源:2022,ICML
论文地址:download 
论文代码:download 

1 Introduction

  联邦学习 Federated Learning (FL) + 原型对比学习 Prototype Contrastive Learning。

  解决问题:通过使用原型对比学习,解决朴素版本联邦学习存在的数据之间异质性问题。

2 Problem Formulation

General FL Framework

  组成:$1$ 个服务器(server), $m$ 个客户端(client)。

  拥有 $m$ 个客户端的 $\text{global objective}$ 是:

    $\underset{\left(w_{1}, w_{2}, \cdots, w_{m}\right)}{\text{min}}   \frac{1}{m} \sum_{i=1}^{m} \frac{\left|D_{i}\right|}{N} L_{i}\left(w_{i} ; D_{i}\right) \quad\quad\quad(1)$

  其中,$L_{i}$ 和 $w_{i}$ 分别为客户端 $i$ 的局部损失函数和模型参数,$D_{i}$ 是第 $i$ 个客户端的私有数据集,$N$ 是所有客户端中的样本总数。

  朴素版本的联邦学习,如 FedAvg :将服务器和所有客户端的模型参数设置为一样 $w=w_{1}=w_{2}=\cdots=w_{m}$ (同步更新)。然而由于数据的异构性,参数同步有时会降低本地数据集的性能。[6,60] 等个性化联邦学习(personalized Federated Learning) 通过对客户端应用各种约束和正则化项解决这个问题,这允许客户端保留不同的模型参数 $w_{i}, i \in[1, m]$,以在本地数据集上获得更高的性能。

The Proposed Lightweight FL Framework

  Lightweight FL Framework 拥有 $m$ 个客户端和一个中心服务器。每个客户端 $i \in[1, m]$ 都拥有 $K$ 个共享和固定的预训练 backbone,以及一个不能彼此共享的私有数据集 $D_{i}$。每个可学习的模型可以由两部分组成:

  (1) Feature Encoder:$r\left(\cdot ; \Phi^{*}\right): \mathbb{R}^{d} \rightarrow \mathbb{R}^{K \times d_{e}}$,包含 $K$ 个预训练的的 backbone,每个 backbone 将 $d$ 维的原始样本  $\mathbf{x}$ 变成 $d_{e}$ 维,最终将 $K$ 个backbone 的表示拼接得到 $r_{\mathbf{x}}$ 。

  (2) Projection Network :将 $h\left(\cdot ; \theta_{i}\right): \mathbb{R}^{K \times d_{e}} \rightarrow \mathbb{R}^{d_{h}}$ , 将 $r_{\mathbf{x}}$ 进行映射,以进行进一步的表示学习。

  正式的定义如下:

    $r\left(\mathbf{x} ; \Phi^{*}\right):=\operatorname{concat}\left(r_{1}\left(\mathbf{x} ; \phi_{1}^{*}\right), \ldots, r_{K}\left(\mathbf{x} ; \phi_{K}^{*}\right)\right) \quad\quad\quad(2)$

    $z(\mathbf{x})=h\left(r_{\mathbf{x}} ; \theta_{i}\right)\quad\quad\quad(3)$

  Global objective 如下:

    $\begin{array}{l}\underset{\theta_{1}, \theta_{2}, \ldots, \theta_{m}}{\text{min}}\quad\sum\limits _{i=1}^{m}  \frac{\left|D_{i}\right|}{N} \mathbb{E}_{(\mathbf{x}, y) \in D_{i}}\left[L_{i}\left(\theta_{i} ; z(\mathbf{x}), y\right)\right]\\\text{s.t.} \quad  z(\mathbf{x})=h\left(r\left(\mathbf{x} ; \Phi^{*}\right) ; \theta_{i}\right)  \quad\text{where}\quad  \Phi^{*}=\left\{\phi_{1}^{*}, \phi_{2}^{*}, \cdots, \phi_{K}^{*}\right\}\end{array}\quad\quad\quad(4)$

3 Federated Prototype-wise Contrastive Learning (FedPCL)

  整体框架如下:

  

Prototype as the Information Carrier

  为捕获类相关的语义信息,本文建议在客户端和服务器之间传输类原型信息,传输类原型的优点:

  首先,原型的形式更紧凑,大大降低了训练过程中所需的通信成本。

  其次,非参数通信允许每个客户端学习一个更多定制的本地模型,而无需与其他客户端同步参数。第三,原型是高级的统计信息,而不是原始的特性,它不会给系统产生额外的隐私问题,并且对基于梯度的攻击[66,67]具有鲁棒性。

  本地类相关特征(class-relevant information),为同一类 $j$ 内的融合表示的平均值组成:

    ${\large C_{i}^{(j)}:=\frac{1}{\left|D_{i, j}\right|} \sum\limits _{(\mathbf{x}, y) \in D_{i, j}} z(\mathbf{x})} \quad\quad\quad(5)$

  其中,$D_{i, j}$ 表示由属于类 $j$ 的所有实例组成的 $D_{i}$ 的子集,$\boldsymbol{C}_{i}$ 表示第 $i$ 个客户端的局部原型集。经上述计算,将每个客户端的本地原型集发送到中央服务器进行知识聚合,中央服务器共享基于其本地数据集在每个特定客户端上提取的本地类相关信息。

    ${\large \bar{C}^{(j)}:=\frac{1}{\left|\mathcal{N}_{j}\right|} \sum\limits _{i \in \mathcal{N}_{j}} \frac{\left|D_{i, j}\right|}{N_{j}} C_{i}^{(j)}} \quad\quad\quad(6)$

  其中,$\mathcal{N}_{j}$ 表示拥有类 $j$ 实例的客户端集,$N_{j}$ 表示在所有客户端上属于类 $j$ 的实例数。全局原型集表示为 $\mathbb{C}= \left\{\bar{C}^{(1)}, \bar{C}^{(2)}, \ldots\right\}$。通过这种聚合机制,全局原型集总结了所有客户端共享的粗粒度的类相关知识,这为表示学习提供了一个高级的视角

  聚合之后,服务器将全局原型集和从所有客户端收集到的完整本地原型集发送回每个客户端。对于在某些非 IID 情况下,每个客户端中只有几个类的情况,我们在服务器中引入了一个原型填充过程,以确保每个本地原型集包含对应于所有类的原型:

    $C_{i}^{(j)}=\left\{\begin{array}{ll}C_{i}^{(j)}, & i \in \mathcal{N}_{j} \\\bar{C}^{(j)}, & i \notin \mathcal{N}_{j}\end{array}\right.\quad\quad\quad(7)$

  基于原型的通信和聚合允许每个客户端拥有一个独特的投影网络,该网络能够以定制的方式融合一般表示。返回的本地原型集可以鼓励从客户相关的角度相互学习,而全局原型集,其中每个元素表示总体数据中的一个类中心,提供了一个从高度总结的与客户无关的角度学习的机会。

Local Training

  在从 server 接收到原型集后,局部训练的主要目标是分别有效地从局部原型和全局原型中提取有用的知识,从而最大限度地有利于局部表示学习。为了实现这一点,我们提出了一个原型式的监督对比损失,它包括两个项,即全局项和局部项。

  为了迫使局部投影网络生成的融合表示 $z (x)$ 更接近其对应的全局类中心,从而提取更多与类相关但与客户端无关的信息,我们将基于全局原型的全局损失项定义为

    $L_{\mathrm{g}}=\sum\limits _{(\mathbf{x}, y) \in D_{i}}-\log \frac{\exp \left(z_{\mathbf{x}} \cdot \bar{C}^{(y)} / \tau\right)}{\sum\limits_{y_{a} \in A(y)} \exp \left(z_{\mathbf{x}} \cdot \bar{C}^{\left(y_{a}\right)} / \tau\right)}\quad\quad\quad(8)$   其中 $z_{\mathbf{x}}$ 表示 $z(\mathbf{x})$,$A(y):=\left\{y_{a} \in[1,|\mathbb{C}|]: y_{a} \neq y\right\}$ 是不同于 $y$ 的标签集。对于从 $D_{i}$ 中采样的特定实例 $x$,我们使用一个内点积来度量融合表示 $z_{\mathbf{x}}$ 和原型之间的相似性。

  除了全局项之外,为了通过潜在空间中的交替客户对比学习将 $z(\mathbf{x}) $ 与每个客户的本地原型对齐,并实现更多的客户间知识共享,我们将基于局部原型的损失项定义为

    ${\large L_{\mathrm{p}}=\sum\limits _{(\mathbf{x}, y) \in D_{i}}-\frac{1}{m} \sum\limits_{p=1}^{m} \log \frac{\exp \left(z_{\mathbf{x}} \cdot C_{p}^{(y)} / \tau\right)}{\sum\limits_{y_{a} \in A(y)} \exp \left(z_{\mathbf{x}} \cdot C_{p}^{\left(y_{a}\right)} / \tau\right)}} \quad\quad\quad(9)$

  对于第 $i$ 个 client,目标函数,形式如下:

    $L\left(\theta_{i} ; z(\mathbf{x}), y, \mathbb{C},\left\{\boldsymbol{C}_{p}\right\}_{p=1}^{m}\right)=L_{\mathrm{g}}\left(\theta_{i} ; z(\mathbf{x}), y, \mathbb{C}\right)+L_{\mathrm{p}}\left(\theta_{i} ; z(\mathbf{x}), y,\left\{\boldsymbol{C}_{p}\right\}_{p=1}^{m}\right)\quad\quad\quad(10)$   在每一轮的 local train 结束时,客户机将他们的本地原型集 $C_{i}$ 上传到服务器上。我们在 Algorithm 1 中详细介绍了 FedPCL 算法。

  

4 Experiments

4.1 Datasets

  包括 3 个 benchmark datasets :

    • Digit-5—— SVHN, USPS, SynthDigits, MNIST-M, MNIST
    • Office-10—— Amazon, Caltech, DSLR, WebCam
    • DomainNet—— Clipart, Info, Painting, Quickdraw, Real, Sketch

4.2 Result

  

标签:Pre,right,mathbf,Contrastive,原型,Learning,quad,客户端,left
From: https://www.cnblogs.com/BlairGrowing/p/16807674.html

相关文章

  • 记录|在WordPress设立HTTPS
    TL;DR:根据HTTPSforWordPress–WordPress.orgForums上的指示操作即可。获取HTTPS证书到Certbot上去,填MyHTTPwebsiteisrunning[Nginx]on[Ubuntu20]......
  • VitePress
    VitePress快速上手本节将帮助你从头开始搭建一个简单的VitePress文档站点。如果你已经有了一个存在的项目并且向在项目中维护文档,你可以从步骤3开始如果你还没有配......
  • 通俗易懂谈强化学习之Q-Learning算法实战
     Datawhale干货 作者:KingJames,伦敦国王大学前言:上篇介绍了什么是强化学习,应大家需求,本篇实战讲解强化学习,所有的实战代码可以自行下载运行。本篇使用强化学习领域经典的P......
  • DeepRec 做了哪些优化?
    前言这段时间参加了天池上的“DeepRecCTR模型性能优化”比赛,通过阅读DeepRec官方文档,可以了解DeepRec做了哪些优化,哪些优化可以迁移借鉴,哪些优化是针对推荐系统的......
  • 论文解读(GGD)《Rethinking and Scaling Up Graph Contrastive Learning: An Extremely
    论文信息论文标题:RethinkingandScalingUpGraphContrastiveLearning:AnExtremelyEfficientApproachwithGroupDiscrimination论文作者:YizhenZheng,ShiruiP......
  • 使用 TensorFlow 和 Express.js 实现AI图像识别
    人工智能是一个相当广泛的领域,以其惊人的能力和复杂的性质吸引着人们的兴趣。自1940年代首次提出使用电子设备重建大脑的可能性以来,这项技术已经取得了长足的进步。尽管人......
  • 【Hive】presto超过17位的数字显示问题
    在HUEPresto查询页面里,如果查询的值是超过17位的bigint,直接显示的结果是有问题的,如下所示: 从图中可以看出,直接显示的id,最后两位是30,但id转换成字符串显示的结果却是41......
  • 使用vitepress构建组件库文档
    vitepress-demoblock  为vitepress添加更专业的Demo演示能力,让您在开发vue组件库或者vue相关文档编写时,可以通过引入vue文件的时候结果显示和代码演示。文档官方文......
  • 为 vitepress 添加更专业的 Demo 演示能力
    vitepress-demoblock  为vitepress添加更专业的Demo演示能力,让您在开发vue组件库或者vue相关文档编写时,可以通过引入vue文件的时候结果显示和代码演示。文档官方文......
  • C# Expression学习笔记(一、表达式与表达式树的基本结构)
    一个美丽的邂逅        昨天心血来潮,想着用了很久的HangFire这个任务调度组件,却从来没有研究过其源码,所以我就想着看一下Hangfire的源码,然后当我看到Hangfire源码......