首页 > 其他分享 >Pytorch | `torch.multiprocessing.spawn` 函数的使用

Pytorch | `torch.multiprocessing.spawn` 函数的使用

时间:2023-06-26 19:22:16浏览次数:52  
标签:spawn args 函数 torch Pytorch 进程 local multiprocessing

torch.multiprocessing.spawn 是 PyTorch 中用于启动多进程的函数,可以用于分布式训练等场景。其函数签名如下:

torch.multiprocessing.spawn(
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method='spawn',
)

参数:

  • fn (function) –函数被称为派生进程的入口点。必须在模块的顶层定义此函数,以便对其进行pickle和派生。这是多进程强加的要求。该函数称为fn(i,*args),其中i是进程索引,args是传递的参数元组。
  • args (tuple) – 传递给 fn 的参数.
  • nprocs (int) – 派生的进程数.
  • join (bool) – 执行一个阻塞的join对于所有进程.
  • daemon (bool) – 派生进程守护进程标志。如果设置为True,将创建守护进程.

其中,fn 是要在子进程中运行的函数,args 是传递给该函数的参数,nprocs 是要启动的进程数。当 nprocs 大于 1 时,会创建多个子进程,并在每个子进程中调用 fn 函数,每个子进程都会使用不同的进程 ID 进行标识。当 nprocs 等于 1 时,会在当前进程中直接调用 fn 函数,而不会创建新的子进程。

在上面提到的代码中,torch.multiprocessing.spawn 函数的具体调用方式如下:

torch.multiprocessing.spawn(process_fn, args=(parsed_args,), nprocs=world_size)

其中,process_fn 是要在子进程中运行的函数,args 是传递给该函数的参数,nprocs 是要启动的进程数,即推断出的 GPU 数量。这里的 process_fn 函数应该是在其他地方定义的,用于执行具体的训练任务。在多进程编程中,每个子进程都会调用该函数来执行训练任务。

需要注意的是,torch.multiprocessing.spawn 函数会自动将数据分布到各个进程中,并在所有进程执行完成后进行同步,以确保各个进程之间的数据一致性。同时,该函数还支持多种进程间通信方式,如共享内存(Shared Memory)、管道(Pipe)等,可以根据具体的需求进行选择。

给予process_fn 函数如下:

def process_fn(rank, args):
    local_args = copy.copy(args)
    local_args.local_rank = rank
    main(local_args)

其中,rank 参数是当前子进程的 ID,由 torch.multiprocessing.spawn 函数自动分配。而 args 参数是在调用 torch.multiprocessing.spawn 函数时传递的,其值为 (parsed_args,),表示 args 是一个元组,其中包含了一个元素 parsed_args

process_fn 函数内部,会先使用 copy.copy 函数复制一份 args 参数,将其赋值给 local_args 变量。然后将当前子进程的 ID 赋值给 local_args.local_rank,再调用 main(local_args) 函数进行具体的训练任务。

由于 main 函数需要的参数是一个 args 对象,因此在 process_fn 函数中需要将 args 参数解包,并将其值赋值给 local_args 变量。然后再将 local_args 变量传递给 main 函数进行训练任务。在多进程编程中,由于各个子进程之间是相互独立的,因此需要将训练任务拆分成多个子任务来分配给各个子进程执行,以实现并行化加速训练的效果。

例子:

import utils.multiprocessing as mpu
 
  if cfg.NUM_GPUS > 1:
            torch.multiprocessing.spawn(
                mpu.run,
                nprocs=cfg.NUM_GPUS,
                args=(
                    cfg.NUM_GPUS,
                    train,
                    cfg.DIST_INIT_METHOD,
                    cfg.SHARD_ID,
                    cfg.NUM_SHARDS,
                    cfg.DIST_BACKEND,
                    cfg,
                ),
                daemon=False,
            )

上面这段函数使用了torch.multiprocessing.spawn方法,传入的参数fnmpu.run,也就是utils.multiprocessing.run函数。然后又传入7个参数,传到multiprocessing.py文件中的run()方法,我们来找一下这个函数。

multiprocessing.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 
"""Multiprocessing helpers."""
 
import torch
 
 
def run(
    local_rank, num_proc, func, init_method, shard_id, num_shards, backend, cfg
):
    """
    Runs a function from a child process.
    Args:
        local_rank (int): rank of the current process on the current machine.
        num_proc (int): number of processes per machine.
        func (function): function to execute on each of the process.
        init_method (string): method to initialize the distributed training.
            TCP initialization: equiring a network address reachable from all
            processes followed by the port.
            Shared file-system initialization: makes use of a file system that
            is shared and visible from all machines. The URL should start with
            file:// and contain a path to a non-existent file on a shared file
            system.
        shard_id (int): the rank of the current machine.
        num_shards (int): number of overall machines for the distributed
            training job.
        backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
            supports, each with different capabilities. Details can be found
            here:
            https://pytorch.org/docs/stable/distributed.html
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Initialize the process group.
    world_size = num_proc * num_shards
    rank = shard_id * num_proc + local_rank
 
    try:
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    except Exception as e:
        raise e
 
    torch.cuda.set_device(local_rank)
    func(cfg)

我们找到了这个函数的run()方法,但是这个方法需要传八个参数,我们从torch.multiprocessing.spawn方法传进来的只有七个。

所以要注意,run()函数中的第一个参数local_rank是当前子进程的 ID,由 torch.multiprocessing.spawn 函数自动分配。然后会自动将数据分布到各个进程中,并在所有进程执行完成后进行同步,以确保各个进程之间的数据一致性。

参考:

https://github.com/sangho-vision/wds_example/blob/850fdff046e4b84215722d291ffad8c024062607/run.py

标签:spawn,args,函数,torch,Pytorch,进程,local,multiprocessing
From: https://www.cnblogs.com/zhangxuegold/p/17506535.html

相关文章

  • PyTorch 从入门到放弃 —— 加载数据
    PyTorch有两种基础数据类型: torch.utils.data.DataLoader 和 torch.utils.data.Dataset. Dataset,它们存储着样本和对应的标记。 Dataset是样本数据集,DataLoader对Dataset进行封装,方便加载、遍历和分批等。importtorchfromtorchimportnnfromtorch.utils.dataimport......
  • Bert Pytorch 源码分析:四、编解码器
    #Bert编码器模块#由一个嵌入层和NL个TF层组成classBERT(nn.Module):"""BERTmodel:BidirectionalEncoderRepresentationsfromTransformers."""def__init__(self,vocab_size,hidden=768,n_layers=12,attn_heads=12,d......
  • Pytorch | 输入的形状为[seq_len, batch_size, d_model]和 [batch_size, seq_len, d_m
    首先导入依赖的torch包。importtorch我们设:seq_len(序列的最大长度):5batch_size(批量大小):2d_model(每个单词被映射为的向量的维度):10heads(多头注意力机制的头数):5d_k(每个头的特征数):21、输入形状为:[seq_len,batch_size,d_model]input_tensor=torch.randn(5,2,10)inp......
  • Pytorch | view()函数的使用
    函数简介Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。根据上面的描述可知,view函数的操作对象应该是Tensor类型。如果不是Tensor类型,可以通过tensor=torch.tensor(data)来转换。普通用法(手动调整size)view(参数a,参数b,…),其中,总......
  • Bert Pytorch 源码分析:二、注意力层
    #注意力机制的具体模块#兼容单头和多头classAttention(nn.Module):"""Compute'ScaledDotProductAttention""" #QKV尺寸都是BS*ML*ES #(或者多头情况下是BS*HC*ML*HS,最后两维之外的维度不重要) #从输入计算QKV的过程可以统一处理,不必......
  • Bert PyTorch 源码分析:一、嵌入层
    #标记嵌入就是最普通的嵌入层#接受单词ID输出单词向量#直接转发给了`nn.Embedding`classTokenEmbedding(nn.Embedding):def__init__(self,vocab_size,embed_size=512):super().__init__(vocab_size,embed_size,padding_idx=0) #片段嵌入实际上是......
  • Yann Lecun-纽约大学-深度学习(PyTorch)
    课程介绍    本课程涉及深度学习和表示学习的最新技术,重点是有监督和无监督的深度学习,嵌入方法,度量学习,卷积和递归网络,并应用于计算机视觉,自然语言理解和语音识别。前提条件包括:DS-GA1001数据科学入门或研究生水平的机器学习课程。  sphq: https://mp.weixin.qq.com/s?__b......
  • 历史最全GAN模型PyTorch代码实现整理分享
        如果你是第一次接触AE自编码器和GAN生成对抗网络,那这将会是一个非常有用且效率的学习资源。所有的内容使用PyTorch编写,编写格式清晰,非常适合PyTorch新手作为学习资源。本项目的所有模型目前都是基于MNIST数据库进行图片生成。MNIST数据集是一个比较小,一个光CPU就能跑起来的......
  • pytorch 使用多GPU训练模型测试出现:TypeError: forward() missing 1 required positio
    转载:https://blog.csdn.net/lingyunxianhe/article/details/119454778?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522168718901716800227455818%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=16871890171680022745......
  • Pytorch中查看GPU信息
    本文摘自:知乎 用Pytorch中查看GPU信息1.返回当前设备索引torch.cuda.current_device()2.返回GPU的数量torch.cuda.device_count()3.返回gpu名字,设备索引默认从0开始torch.cuda.get_device_name(0)4.cuda是否可用torch.cuda.is_available()......