首页 > 其他分享 >分布式机器学习(Parameter Server)

分布式机器学习(Parameter Server)

时间:2023-05-27 20:56:33浏览次数:43  
标签:模型 self Server 参数 device 服务器 Parameter 节点 分布式

分布式机器学习中,参数服务器(Parameter Server)用于管理和共享模型参数,其基本思想是将模型参数存储在一个或多个中央服务器上,并通过网络将这些参数共享给参与训练的各个计算节点。每个计算节点可以从参数服务器中获取当前模型参数,并将计算结果返回给参数服务器进行更新。

为了保持模型一致性,通常采用下列两种方法:

  1. 将模型参数保存在一个集中的节点上,当一个计算节点要进行模型训练时,可从集中节点获取参数,进行模型训练,然后将更新后的模型推送回集中节点。由于所有计算节点都从同一个集中节点获取参数,因此可以保证模型一致性。
  2. 每个计算节点都保存模型参数的副本,因此要定期强制同步模型副本,每个计算节点使用自己的训练数据分区来训练本地模型副本。在每个训练迭代后,由于使用不同的输入数据进行训练,存储在不同计算节点上的模型副本可能会有所不同。因此,每一次训练迭代后插入一个全局同步的步骤,这将对不同计算节点上的参数进行平均,以便以完全分布式的方式保证模型的一致性,即All-Reduce范式

PS架构

在该架构中,包含两个角色:parameter server和worker

parameter server将被视为master节点在Master/Worker架构,而worker将充当计算节点负责模型训练

整个系统的工作流程分为4个阶段:

  1. Pull Weights: 所有worker从参数服务器获取权重参数
  2. Push Gradients: 每一个worker使用本地的训练数据训练本地模型,生成本地梯度,之后将梯度上传参数服务器
  3. Aggregate Gradients:收集到所有计算节点发送的梯度后,对梯度进行求和
  4. Model Update:计算出累加梯度,参数服务器使用这个累加梯度来更新位于集中服务器上的模型参数

可见,上述的Pull Weights和Push Gradients涉及到通信,首先对于Pull Weights来说,参数服务器同时向worker发送权重,这是一对多的通信模式,称为fan-out通信模式。假设每个节点(参数服务器和工作节点)的通信带宽都为1。假设在这个数据并行训练作业中有N个工作节点,由于集中式参数服务器需要同时将模型发送给N个工作节点,因此每个工作节点的发送带宽(BW)仅为1/N。另一方面,每个工作节点的接收带宽为1,远大于参数服务器的发送带宽1/N。因此,在拉取权重阶段,参数服务器端存在通信瓶颈。

对于Push Gradients来说,所有的worker并发地发送梯度给参数服务器,称为fan-in通信模式,参数服务器同样存在通信瓶颈。

基于上述讨论,通信瓶颈总是发生在参数服务器端,将通过负载均衡解决这个问题

将模型划分为N个参数服务器,每个参数服务器负责更新1/N的模型参数。实际上是将模型参数分片(sharded model)并存储在多个参数服务器上,可以缓解参数服务器一侧的网络瓶颈问题,使得参数服务器之间的通信负载减少,提高整体的通信效率。

代码实现

定义网络结构:

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")

        self.conv1 = nn.Conv2d(1,32,3,1).to(device)
        self.dropout1 = nn.Dropout2d(0.5).to(device)
        self.conv2 = nn.Conv2d(32,64,3,1).to(device)
        self.dropout2 = nn.Dropout2d(0.75).to(device)
        self.fc1 = nn.Linear(9216,128).to(device)
        self.fc2 = nn.Linear(128,20).to(device)
        self.fc3 = nn.Linear(20,10).to(device)

    def forward(self,x):
        x = self.conv1(x)
        x = self.dropout1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = F.max_pool2d(x,2)
        x = torch.flatten(x,1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)

        output = F.log_softmax(x,dim=1)

        return output

如上定义了一个简单的CNN

实现参数服务器:

class ParamServer(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()

        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")

        self.optimizer = optim.SGD(self.model.parameters(),lr=0.5)

    def get_weights(self):
        return self.model.state_dict()

    def update_model(self,grads):
        for para,grad in zip(self.model.parameters(),grads):
            para.grad = grad

        self.optimizer.step()
        self.optimizer.zero_grad()

get_weights获取权重参数,update_model更新模型,采用SGD优化器

实现worker:

class Worker(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()
        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")

    def pull_weights(self,model_params):
        self.model.load_state_dict(model_params)

    def push_gradients(self,batch_idx,data,target):
        data,target = data.to(self.input_device),target.to(self.input_device)
        output = self.model(data)
        data.requires_grad = True
        loss = F.nll_loss(output,target)
        loss.backward()
        grads = []

        for layer in self.parameters():
            grad = layer.grad
            grads.append(grad)

        print(f"batch {batch_idx} training :: loss {loss.item()}")

        return grads

Pull_weights获取模型参数,push_gradients上传梯度

训练

训练数据集为MNIST

import torch
from torchvision import datasets,transforms

from network import Net
from worker import *
from server import *

train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True,
               transform = transforms.Compose([transforms.ToTensor(),
               transforms.Normalize((0.1307,),(0.3081,))])),
               batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=False,
              transform = transforms.Compose([transforms.ToTensor(),
              transforms.Normalize((0.1307,),(0.3081,))])),
              batch_size=128, shuffle=True)

def main():
    server = ParamServer()
    worker = Worker()

    for batch_idx, (data,target) in enumerate(train_loader):
        params = server.get_weights()
        worker.pull_weights(params)
        grads = worker.push_gradients(batch_idx,data,target)
        server.update_model(grads)

    print("Done Training")

if __name__ == "__main__":
    main()

标签:模型,self,Server,参数,device,服务器,Parameter,节点,分布式
From: https://www.cnblogs.com/N3ptune/p/17437320.html

相关文章

  • 分布式CAP理论
    分布式:一个大业务拆分成多个小业务并部署在不同的服务器上CAP:一个分布式系统最多只能同时满足一致性(Consistency)、可用性(Availability)和分区容错性(Partitiontolerance)这三项中的两项。  网络问题不可避免,P(分区容错性)是一定需要保证的如果此时有节点故障,如果剩余节点正常......
  • 分布式基础之CAP理论&BASE理论
    1.CAP理论1.1)含义C(Consistency一致性)、Availability(可用性)、PartitionTolerance(分区容错性)。1.2)具体意义一致性(Consistency):所有节点访问同一份最新的数据副本可用性(Availability):非故障的节点在合理的时间内返回合理的响应(不是错误或者超时的响应)。分区容错性(Partition......
  • 配置GlusterFS分布式文件系统​
    拓扑图:推荐步骤:在Centos01到Centos04,在每台服务器创建四个分区格式化为XFS文件系统自动设置开机自动挂载在Centos01到Centos04安装glusterFS分布式存储系统创建配置glusterfs群集和创建分布式条带卷、分布式复制卷、分布式卷、条带卷实验步骤:一.在Centos01到Centos04,在每台服务器创......
  • 分布式事务的21种武器 - 6
    在分布式系统中,事务的处理分布在不同组件、服务中,因此分布式事务的ACID保障面临着一些特殊难点。本系列文章介绍了21种分布式事务设计模式,并分析其实现原理和优缺点,在面对具体分布式事务问题时,可以选择合适的模式进行处理。原文:ExploringSolutionsforDistributedTransactio......
  • 小马哥Java分布式架构训练营第一期服务治理-鱼龙潜跃水成文
    小马哥Java分布式架构训练营第一期服务治理download:3w51xuebccom使用Netty和SpringBoot实现仿微信的示例在本文中,我们将使用Netty和SpringBoot框架来创建一个简单的聊天应用程序,类似于微信。这个应用程序将支持多用户聊天和即时消息发送。下面让我们来一步步看看如何实现。第一......
  • ubuntu server 20.4设置使用root登录
    ubuntu@ubuntu:~$sudopasswdrootNewpassword:Retypenewpassword:passwd:passwordupdatedsuccessfullyubuntu@ubuntu:~$suroot#切换到root账户Password:root@ubuntu:/home/ubuntu#使用vim/etc/ssh/sshd_config编辑配置文件找到#PermitRootLoginprohib......
  • LDAPserver相关配置
    [root@schedulershell]#catldapserver.sh#!/bin/bash##LdapServerinstallScript#author:liulingfeng#2023-04-29#--------------------------------------------#1、关闭防火墙sed-i'/SELINUX/s/enforcing/disabled/'/etc/selinux/configsystemctl......
  • 设计模式-观察者模式(Observer)
    一、 观察者(Observer)模式观察者模式又叫做发布-订阅(Publish/Subscribe)模式、模型-视图(Model/View)模式、源-监听器(Source/Listener)模式或从属者(Dependents)模式。观察者模式定义了一种一对多的依赖关系,让多个观察者对象同时监听某一个主题对象。这个主题对象在状态上发生变化时,会通......
  • 十二、集成分布式事务组件Seata
    什么是Seata网址:seata.ioSeata是一款开源的分布式事务解决方案,致力于提供高性能和简单易用的分布式事务服务。Seata将为用户提供了AT、TCC、SAGA和XA事务模式,为用户打造一站式的分布式解决方案。 seata术语TC(TransactionCoordinator)-事务协调者维护全局和分支......
  • 【K8s入门推荐】K8s1.24版本部署全教程,轻松掌握技巧kubeadm丨Kubernetes丨容器编排丨
    通过kubeadm方式极速部署Kubernetes1.24版本前言在Kubernetes的搭建过程中,繁琐的手动操作和复杂的配置往往会成为制约部署效率的关键因素。而使用kubeadm工具可以避免这些问题,大大提高集群的部署效率和部署质量。本文将为大家详细介绍如何使用kubeadm工具快速搭建Kubernetes1.24......