torch.multiprocessing.spawn
是 PyTorch 中用于启动多进程的函数,可以用于分布式训练等场景。其函数签名如下:
torch.multiprocessing.spawn(
fn,
args=(),
nprocs=1,
join=True,
daemon=False,
start_method='spawn',
)
参数:
- fn (function) –函数被称为派生进程的入口点。必须在模块的顶层定义此函数,以便对其进行pickle和派生。这是多进程强加的要求。该函数称为
fn(i,*args)
,其中i
是进程索引,args
是传递的参数元组。 - args (tuple) – 传递给
fn
的参数. - nprocs (int) – 派生的进程数.
- join (bool) – 执行一个阻塞的join对于所有进程.
- daemon (bool) – 派生进程守护进程标志。如果设置为True,将创建守护进程.
其中,fn
是要在子进程中运行的函数,args
是传递给该函数的参数,nprocs
是要启动的进程数。当 nprocs
大于 1 时,会创建多个子进程,并在每个子进程中调用 fn
函数,每个子进程都会使用不同的进程 ID 进行标识。当 nprocs
等于 1 时,会在当前进程中直接调用 fn
函数,而不会创建新的子进程。
在上面提到的代码中,torch.multiprocessing.spawn
函数的具体调用方式如下:
torch.multiprocessing.spawn(process_fn, args=(parsed_args,), nprocs=world_size)
其中,process_fn
是要在子进程中运行的函数,args
是传递给该函数的参数,nprocs
是要启动的进程数,即推断出的 GPU 数量。这里的 process_fn
函数应该是在其他地方定义的,用于执行具体的训练任务。在多进程编程中,每个子进程都会调用该函数来执行训练任务。
需要注意的是,torch.multiprocessing.spawn
函数会自动将数据分布到各个进程中,并在所有进程执行完成后进行同步,以确保各个进程之间的数据一致性。同时,该函数还支持多种进程间通信方式,如共享内存(Shared Memory)、管道(Pipe)等,可以根据具体的需求进行选择。
给予process_fn
函数如下:
def process_fn(rank, args):
local_args = copy.copy(args)
local_args.local_rank = rank
main(local_args)
其中,rank
参数是当前子进程的 ID,由 torch.multiprocessing.spawn
函数自动分配。而 args
参数是在调用 torch.multiprocessing.spawn
函数时传递的,其值为 (parsed_args,)
,表示 args
是一个元组,其中包含了一个元素 parsed_args
。
在 process_fn
函数内部,会先使用 copy.copy
函数复制一份 args
参数,将其赋值给 local_args
变量。然后将当前子进程的 ID 赋值给 local_args.local_rank
,再调用 main(local_args)
函数进行具体的训练任务。
由于 main
函数需要的参数是一个 args
对象,因此在 process_fn
函数中需要将 args
参数解包,并将其值赋值给 local_args
变量。然后再将 local_args
变量传递给 main
函数进行训练任务。在多进程编程中,由于各个子进程之间是相互独立的,因此需要将训练任务拆分成多个子任务来分配给各个子进程执行,以实现并行化加速训练的效果。
例子:
import utils.multiprocessing as mpu
if cfg.NUM_GPUS > 1:
torch.multiprocessing.spawn(
mpu.run,
nprocs=cfg.NUM_GPUS,
args=(
cfg.NUM_GPUS,
train,
cfg.DIST_INIT_METHOD,
cfg.SHARD_ID,
cfg.NUM_SHARDS,
cfg.DIST_BACKEND,
cfg,
),
daemon=False,
)
上面这段函数使用了torch.multiprocessing.spawn
方法,传入的参数fn
是mpu.run
,也就是utils.multiprocessing.run
函数。然后又传入7个参数,传到multiprocessing.py文件中的run()
方法,我们来找一下这个函数。
multiprocessing.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Multiprocessing helpers."""
import torch
def run(
local_rank, num_proc, func, init_method, shard_id, num_shards, backend, cfg
):
"""
Runs a function from a child process.
Args:
local_rank (int): rank of the current process on the current machine.
num_proc (int): number of processes per machine.
func (function): function to execute on each of the process.
init_method (string): method to initialize the distributed training.
TCP initialization: equiring a network address reachable from all
processes followed by the port.
Shared file-system initialization: makes use of a file system that
is shared and visible from all machines. The URL should start with
file:// and contain a path to a non-existent file on a shared file
system.
shard_id (int): the rank of the current machine.
num_shards (int): number of overall machines for the distributed
training job.
backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
supports, each with different capabilities. Details can be found
here:
https://pytorch.org/docs/stable/distributed.html
cfg (CfgNode): configs. Details can be found in
slowfast/config/defaults.py
"""
# Initialize the process group.
world_size = num_proc * num_shards
rank = shard_id * num_proc + local_rank
try:
torch.distributed.init_process_group(
backend=backend,
init_method=init_method,
world_size=world_size,
rank=rank,
)
except Exception as e:
raise e
torch.cuda.set_device(local_rank)
func(cfg)
我们找到了这个函数的run()
方法,但是这个方法需要传八个参数,我们从torch.multiprocessing.spawn
方法传进来的只有七个。
所以要注意,run()
函数中的第一个参数local_rank
是当前子进程的 ID,由 torch.multiprocessing.spawn
函数自动分配。然后会自动将数据分布到各个进程中,并在所有进程执行完成后进行同步,以确保各个进程之间的数据一致性。
参考:
https://github.com/sangho-vision/wds_example/blob/850fdff046e4b84215722d291ffad8c024062607/run.py
标签:spawn,args,函数,torch,Pytorch,进程,local,multiprocessing From: https://www.cnblogs.com/zhangxuegold/p/17506535.html