首页 > 其他分享 >PyTorch 分布式使用方式及代码解析

PyTorch 分布式使用方式及代码解析

时间:2024-07-13 10:00:22浏览次数:16  
标签:torch distributed PyTorch train device model 解析 分布式

一、PyTorch分布式 DP与DDP

1.1 PyTorch分布式支持

数据并行
 
模型并行
​​​​​​

1.2 PyTorch分布式调用-DP 

1.3 PyTorch分布式调用-DDP 

1.4 PyTorch分布式-通信后端 

gloo:具有各种原语的集体通信库,用
于多机训练。Facebook为在Linux上
运行而构建的
NCCL(Nvidia Collective multi-GPU
Communication Library) : 是一个实现
多GPU的collective communication通信
(all-gather, reduce, broadcast)库,
Nvidia做了很多优化,以在PCIe、
NVLink、InfiniBand上实现较高的通信
速度
分布式后端适用场景:
GPU分布式训练——nccl
CPU分布式训练——gloo

分布式实现中的核心通信模块为 torch.distributed,该模块支持三种后端
• nccl
• gloo
• mpi
import torch.distributed as dist
...
dist.init_process_group(backend= args.dist_backend, init_method= args.dist_url, world_size= args.world_size,
rank= args.rank)
• backend (str or Backend) – 后端使用。根据构建时配置,有效值包括mpi、gloo、nccl。
• init_method (str, optional) – 指定如何初始化进程组的URL。 tcp、file、master方式
• world_size (int, optional) – 参与作业的进程数。
• rank (int, optional) – 当前流程的排名。
• timeout (timedelta, optional) – 针对进程组执行的操作超时,默认值等于30分钟,这仅适用于gloo后端。
• group_name (str, optional, deprecated) – 团队名字。
torch.distributed.is_nccl_available()
torch.distributed.is_mpi_available()

 二、实际代码解析

2.1 PyTorch分布式调用-DP

单进程控制多 GPU

# main.py
import torch
import torch.distributed as dist
gpus = [0, 1, 2, 3]
torch.cuda.set_device('cuda:{}'.format(gpus[0]))
train_dataset = ...
train_loader =
torch.utils.data.DataLoader(train_dataset,
batch_size=...)
model = ...
model = nn.DataParallel(model.to(device),
device_ids=gpus, output_device=gpus[0])
optimizer = optim.SGD(model.parameters())

主程序段:

for epoch in range(100):
for batch_idx, (data, target) in enumerate(train_loader):
# 这里要 images/target.cuda()
images = images.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
...
output = model(images)
loss = criterion(output, target)
...
optimizer.zero_grad()
loss.backward()
optimizer.step()
• model.cuda() 模型和数据均需要load进gpu
• device_ids 参与训练得gpu有哪些
• output_device 用于汇总梯度的gpu是哪个
batchsize设置成n倍的单卡batchsize

2.2 PyTorch分布式调用-DDP

Ø 在使用 distributed 包的任何其他函数之前,需要使用 init_process_group 初始化进程组,同时
初始化 distributed 包;
Ø 如果需要进行小组内集体通信,用 new_group 创建子分组;
Ø 创建分布式并行(DistributedDataParallel)模型 DDP(model, device_ids=device_ids);
Ø 为数据集创建 Sampler;
Ø 使用启动工具 torch.distributed.launch 在每个主机上执行一次脚本,开始训练;

# main.py
import torch.distributed as dist
#进程同步
dist.init_process_group(backend='nccl’)
#把模型 数据加载到当前进程gpu
torch.cuda.set_device(args.local_rank)
#数据切分
train_dataset = ...
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)
model = ...
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
optimizer = torch.optim.SGD(model.parameters())
主程序段

2.3 DP与DDP区别

Ø DP
• 使用简单,一行代码即可;
• 不支持多机多卡分布式并行;
• 负载不均衡
Ø DDP
• 负载均衡;
• 支持多机多卡并行;
• 相较于DP,需要传输的数据量小,效率更高;
结论:综上,推荐调用DDP完成PyTorch的分布式支持。

标签:torch,distributed,PyTorch,train,device,model,解析,分布式
From: https://blog.csdn.net/qq_27815483/article/details/140377050

相关文章

  • PyTorch自学笔记——深度学习基础(1)
    PyTorch自学笔记。学习教程为ZerotoMasteryLearnPyTorchforDeepLearning,对应视频教程为https://www.youtube.com/watch?v=Z_ikDlimN6A概念(Whatisdeeplearning)机器学习(Machinelearning,ML)定义将事物(数据)转化为数字,并找出数字中的模式Machinelearning......
  • yolov8_pytorch目标检测和图像分割深度学习模型
    yolov8论文无模型结构yolov8是一种单阶段目标检测算法,该算法在YOLOV5的基础上添加了一些新的改进思路,使其速度与精度都得到了极大的性能提升。算法原理YOLOv8算法通过将图像划分为不同大小的网格,预测每个网格中的目标类别和边界框,利用特征金字塔结构和自适应的模型缩放......
  • 单体、分布式、微服务、Serverless及新兴部署模式分析
    在数字化时代,软件架构的选择对于企业的技术战略至关重要。从单体架构到Serverless,再到服务网格和服务化模型,每一种架构模式都反映了特定时期内技术发展和业务需求的特点。本文将对这些架构模式的优缺点进行讨论,供大家参考。部署方式的不断演进单体架构(MonolithicArchitect......
  • 2024年06月CCF-GESP编程能力等级认证C++编程三级真题解析
    本文收录于专栏《C++等级认证CCF-GESP真题解析》,专栏总目录:点这里。订阅后可阅读专栏内所有文章。一、单选题(每题2分,共30分)第1题小杨父母带他到某培训机构给他报名参加CCF组织的GESP认证考试的第1级,那他可以选择的认证语言有()种。A.1B.2C.3D.4答案:C第2......
  • 2023CSP真题+答案+解析
    一、 单项选择题(共15题,每题2分,共计30分:每题有且仅有一个正确选项)1. 在C++中,下面哪个关键字用于声明一个变量,其值不能被修改?()。A. unsigned B. const C. static D. mutable答案:B在C++中,关键字const用于声明一个变量,表示其值是常量,不能被修改。一旦用con......
  • 深入解析香橙派 AIpro开发板:功能、性能与应用场景全面测评
    文章目录引言香橙派AIpro开发板介绍到手第一感觉开发板正面开发板背面性能应用场景移植操作系统香橙派AIpro开发板支持哪些操作系统?烧写操作系统到SD卡中启动开发板的步骤查看系统提供的事例程序体验——开发的简洁性视频播放展示ffmpeg简介ffmpeg播放视频安装ffmpeg......
  • 主流json解析框架示例
    主流json解析框架示例jackson、gson、fastjson/fastjson2三种主流json解析框架对比●性能:在性能方面,Fastjson通常被认为是最快的JSON解析库,其次是Jackson和Gson,json-lib的性能相对较低。●API和功能:Jackson提供了非常灵活、强大的API,支持各种高级功能,例如树模型、数据绑定、......
  • JVM参数系列解析
    -XX:+UseCompressedOopsJavaSE6U23开始,JVM会默认开启压缩指针。JVM之压缩指针(CompressedOops)-XX:+DisableExplicitGC强制禁用手动gcJava虚拟机System.gc()解析CMS系列-XX:+UseParNewGC-XX:+UseParNewGC是一个与Java虚拟机(JVM)垃圾回收策略相关的命令行选项,......
  • 在 PostgreSQL 里如何实现数据的分布式查询的负载均衡?
    文章目录在PostgreSQL中实现数据分布式查询的负载均衡在PostgreSQL中实现数据分布式查询的负载均衡在当今数字化时代,数据量呈爆炸式增长,对于大规模数据处理的需求也日益迫切。在PostgreSQL中实现数据的分布式查询负载均衡成为了提升系统性能和可用性的关键......
  • 计算机网络 ARP协议(地址解析协议)
            ARP(AddressResolutionProtocol)是一种用于解析网络层地址(如IPv4地址)和数据链路层地址(如MAC地址)之间对应关系的协议。它主要用于在局域网(LAN)中根据目标设备的IP地址获取其对应的MAC地址,以便在数据链路层进行准确的数据传输。ARP协议工作原理:ARP请求和响应......