首页 > 系统相关 >windows下使用pytorch进行单机多卡分布式训练

windows下使用pytorch进行单机多卡分布式训练

时间:2023-04-02 17:33:16浏览次数:55  
标签:windows distributed 多卡 rank ids -- pytorch local size

现在有四张卡,但是部署在windows10系统上,想尝试下在windows上使用单机多卡进行分布式训练,网上找了一圈硬是没找到相关的文章。以下是踩坑过程。

首先,pytorch的版本必须是大于1.7,这里使用的环境是:

pytorch==1.12+cu11.6
四张4090显卡
python==3.7.6

使用nn.DataParallel进行分布式训练

这一种方式较为简单:
首先我们要定义好使用的GPU的编号,GPU按顺序依次为0,1,2,3。gpu_ids可以通过命令行的形式传入:

gpu_ids = args.gpu_ids.split(',')
gpu_ids = [int(i) for i in gpu_ids]
torch.cuda.set_device('cuda:{}'.format(gpu_ids[0]))

创建模型后用nn.DataParallel进行处理,

 model.cuda()
 r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])

对,没错,只需要这么两步就行了。需要注意的是保存模型后进行加载时,需要先用nn.DataParallel进行处理,再加载权重,不然参数名没对齐会报错。

checkpoint = torch.load(checkpoint_path)
model.cuda()
r_model = nn.DataParallel(model, device_ids=gpu_ids, output_device=gpu_ids[0])
r_model.load_state_dict(checkpoint['state_dict'])

如果不使用分布式加载模型,你需要对权重进行映射:

new_start_dict = {}
for k, v in checkpoint['state_dict'].items():
    new_start_dict["module." + k] = v
model.load_state_dict(new_start_dict)

使用Distributed进行分布式训练

首先了解一下概念:
node:主机数,单机多卡就一个主机,也就是1。
rank:当前进程的序号,用于进程之间的通讯,rank=0的主机为master节点。
local_rank:当前进程对应的GPU编号。
world_size:总的进程数。
在windows中,我们需要在py文件里面使用:

import os
os.environ["CUDA_VISIBLE_DEVICES]='0,1,3'

来指定使用的显卡。
假设现在我们使用上面的三张显卡,运行时显卡会重新按照0-N进行编号,有:

[38664] rank = 1, world_size = 3, n = 1, device_ids = [1]
[76032] rank = 0, world_size = 3, n = 1, device_ids = [0]
[23208] rank = 2, world_size = 3, n = 1, device_ids = [2]

也就是进程0使用第1张显卡,进行1使用第2张显卡,进程2使用第三张显卡。
有了上述的基本知识,再看看具体的实现。

使用torch.distributed.launch启动

使用torch.distributed.launch启动时,我们必须要在args里面添加一个local_rank参数,也就是:
parser.add_argument("--local_rank", type=int, default=0)
1、初始化:

import torch.distributed as dist

env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
}
current_work_dir = os.getcwd()
    init_method = f"file:///{os.path.join(current_work_dir, 'ddp_example')}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=int(env_dict["RANK"]),
                                world_size=int(env_dict["WORLD_SIZE"]))

这里需要重点注意,这种启动方式在环境变量中会存在RANK和WORLD_SIZE,我们可以拿来用。backend必须指定为gloo,init_method必须是file:///,而且每次运行完一次,下一次再运行前都必须删除生成的ddp_example,不然会一直卡住。
2、构建模型并封装
local_rank会自己绑定值,不再是我们--local_rank指定的。

 model.cuda(args.local_rank)
 r_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

3、构建数据集加载器并封装

  train_dataset = dataset(file_path='data/{}/{}'.format(args.data_name, train_file))
  train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
  train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,
                              collate_fn=collate.collate_fn, num_workers=4, sampler=train_sampler)

4、计算损失函数
我们把每一个GPU上的loss进行汇聚后计算。

def loss_reduce(self, loss):
        rt = loss.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= self.args.local_world_size
        return rt

loss = self.criterion(outputs, labels)
torch.distributed.barrier()
loss = self.loss_reduce(loss)

注意打印相关信息和保存模型的时候我们通常只需要在local_rank=0时打印。同时,在需要将张量转换到GPU上时,我们需要指定使用的GPU,通过local_rank指定就行,即data.cuda(args.local_rank),保证数据在对应的GPU上进行处理。
5、启动
在windows下需要把换行符去掉,且只变为一行。

python -m torch.distributed.launch \
--nnode=1 \
--node_rank=0 \
--nproc_per_node=3 \
main_distributed.py \
--local_world_size=3 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

nproc_per_node、local_world_size和GPU的数目保持一致。

使用torch.multiprocessing启动

使用torch.multiprocessing启动和使用torch.distributed.launch启动大体上是差不多的,有一些地方需要注意。

mp.spawn(main_worker, nprocs=args.nprocs, args=(args,))

main_worker是我们的主运行函数,dist.init_process_group要放在这里面,而且第一个参数必须为local_rank。即main_worker(local_rank, args)
nprocs是进程数,也就是使用的GPU数目。
args按顺序传入main_worker真正使用的参数。
其余的就差不多。
启动指令:

python main_mp_distributed.py \
--local_world_size=4 \
--bert_dir="../model_hub/chinese-bert-wwm-ext/" \
--data_dir="./data/cnews/" \
--data_name="cnews" \
--log_dir="./logs/" \
--output_dir="./checkpoints/" \
--num_tags=10 \
--seed=123 \
--max_seq_len=512 \
--lr=3e-5 \
--train_batch_size=64 \
--train_epochs=1 \
--eval_batch_size=64 \
--do_train \
--do_predict \
--do_test

最后需要说明的,假设我们设置的batch_size=64,那么实际上的batch_size = int(batch_size / GPU数目)。
附上完整的基于bert的中文文本分类单机多卡训练代码:https://github.com/taishan1994/pytorch_bert_chinese_text_classification

参考

https://github.com/tczhangzhi/pytorch-distributed
https://murphypei.github.io/blog/2020/09/pytorch-distributed
https://pytorch.org/docs/master/distributed.html?highlight=all_gather#torch.distributed.all_gather
https://github.com/lesliejackson/PyTorch-Distributed-Training
https://github.com/pytorch/examples/blob/ddp-tutorial-code/distributed/ddp/example.py
996黄金一代:[原创][深度][PyTorch] DDP系列第一篇:入门教程
「新生手册」:PyTorch分布式训练 - 知乎 (zhihu.com)

标签:windows,distributed,多卡,rank,ids,--,pytorch,local,size
From: https://www.cnblogs.com/xiximayou/p/17280859.html

相关文章

  • mingw64 + nvim + coc.nvim + nvim-dap : C++ windows - 01
    目标用做C++编译器尽量不要扩展其它功能python是避免不了,所以才安装的。1.1下载安装https://mirror.tuna.tsinghua.edu.cn/msys2/distrib/msys2-x86_64-latest.exe安装路径:C:\gnu\msys641.2mingw64.exe使用这个:C:\gnu\msys64\mingw64.exe1.3安装程序:pac......
  • mingw64 + nvim + coc.nvim + nvim-dap : C++ windows - 02
    2.1命令行工具https://github.com/sharkdp/fdhttps://github.com/junegunn/fzfhttps://github.com/BurntSushi/ripgrephttps://github.com/tree-sitter/tree-sitterC:\gnu\cli\fd.exeC:\gnu\cli\fzf.exeC:\gnu\cli\rg.exeC:\gnu\cli\tree-sitter.exe添加到path......
  • mingw64 + nvim + coc.nvim + nvim-dap : C++ windows - 03
    3.1cclshttps://github.com/MaskRay/ccls/wiki/Build在msys中:pacman-S--noconfirmmingw-w64-x86_64-clangmingw-w64-x86_64-clang-tools-extramingw64/mingw-w64-x86_64-pollymingw-w64-x86_64-cmakemingw-w64-x86_64-jqmingw-w64-x86_64-ninjamingw-w64-x86_64-n......
  • mingw64 + nvim + coc.nvim + nvim-dap : C++ windows - 04
    4.1init.vim将init.vim放置到下面:echostdpath('config')~\AppData\Local\nvim4.2plughttps://github.com/junegunn/vim-plug将plug.vim放置到下面~\AppData\Local\nvim-data\site\autoload4.3:PlugInstall4.3.1网络问题gitproxy4.3.2process4.3......
  • mingw64 + nvim + coc.nvim + nvim-dap : C++ windows - 05
    PSC:\Users\dev\Desktop\cpp>cd.\build\PSC:\Users\dev\Desktop\cpp\build>cmake..-DCMAKE_BUILD_TYPE=Debug--Buildingfor:Ninja--TheCcompileridentificationisGNU12.2.0--TheCXXcompileridentificationisGNU12.2.0--Detect......
  • Windows窗口对象
         ......
  • 【THM】Windows Fundamentals 2(Windows基础知识2)-学习
    本文相关的TryHackMe实验房间链接:https://tryhackme.com/room/windowsfundamentals2x0x本文介绍:本文所涉及的内容是Windows基础模块的第2部分,了解有关系统配置、UAC设置、资源监控、Windows注册表等更多信息。简介在WindowsFundamentals1中,我们已经介绍了Windows的桌面......
  • Windows 环境以 CPU 运行 stable diffusion
    前言 stable-diffusion-webui要求的Python版本是3.10.6,本机还是几年前装的3.10.0,为了避免处理更多幺蛾子,直接升级到3.10.6,还好之前就是3.10,可以直接升级。还有一个好处就是不用安装conda或者miniconda,Python虚拟环境直接就是3.10.6。其实3.10其他小版本的环......
  • go创建文件的软链接,不支持windows
    funcmain(){err:=os.Symlink("/data/da","/home/go/da")iferr!=nil{fmt.Println(err)//即使是错误也不退出}err=os.Remove("/home/go/da")iferr!=nil{fmt.Println("/home/go/dafileisremove",......
  • Thinkpad T14升级Windows11ver22h2失败问题解决小记
    背景手头的ThinkPad在近一年的时间里每次升级Windows11的22h2版本每次都会报错,具体有以下几种情况:更新过程中无问题,重启后黑屏更新过程中会卡在26%左右,然后蓝屏报KENERAL_CHECK_FAIL,接着便自动重启进入修复程序在WindowsUpdate更新中报错0xC1900101在上述错误出现后,再次更......