首页 > 编程语言 >torch distributed源码解析torch/distributed/distributed_c10d.py 之cleanup()

torch distributed源码解析torch/distributed/distributed_c10d.py 之cleanup()

时间:2024-12-20 22:29:27浏览次数:9  
标签:group barrier ranks torch distributed 源码 pg timeout device

/mambaforge/envs/Plant/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py

def cleanup():
    try:
        dist.barrier()
        dist.destroy_process_group()
    except:
        pass

GPT-4o

Poe

以下是代码逐行的中文解释:


cleanup 函数

python

运行复制

def cleanup():
    try:
        dist.barrier()
        dist.destroy_process_group()
    except:
        pass
  1. def cleanup()
    定义了一个名为 cleanup 的函数,用来清理分布式进程组的资源。

  2. try:
    尝试执行清理操作,进入 try 块。

  3. dist.barrier()
    调用 PyTorch 分布式通信中的 barrier 方法,确保所有进程同步到这一点。如果某些进程未到达 barrier,其他进程会等待。

  4. dist.destroy_process_group()
    销毁当前的分布式进程组,释放资源。

  5. except:
    如果在 try 块中发生任何异常(例如进程组未初始化),捕获异常并忽略。


barrier 函数

python

运行复制

@_exception_logger
def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
  1. @_exception_logger
    使用装饰器 _exception_logger 来捕获函数内部的异常并可能记录日志。

  2. def barrier(...)
    定义了一个 barrier 函数,用于同步所有进程。当所有进程调用该函数时,进程才会继续执行。

  3. group=GroupMember.WORLD
    参数 group 指定了进程组,默认为整个世界组(GroupMember.WORLD)。

  4. async_op=False
    参数 async_op 控制是否为异步操作,默认为同步操作(False)。

  5. device_ids=None
    参数 device_ids 允许指定设备(例如 GPU),默认为 None


函数内部逻辑

python

运行复制

if _rank_not_in_group(group):
    _warn_not_in_group("barrier")
    return
  1. _rank_not_in_group(group)
    检查当前进程是否在指定的进程组中。

  2. _warn_not_in_group("barrier")
    如果当前进程不在组内,发出警告,提示该进程不属于指定的组。

  3. return
    如果当前进程不在组内,直接返回,结束函数。


python

运行复制

opts = BarrierOptions()
opts.device = _get_pg_default_device(group)
  1. opts = BarrierOptions()
    创建 BarrierOptions 对象,用于设置 barrier 的配置。

  2. opts.device = _get_pg_default_device(group)
    获取进程组的默认设备信息,并将其设置到 opts.device


python

运行复制

if device_ids is not None:
    if isinstance(device_ids, list):
        opts.device_ids = device_ids
    else:
        raise TypeError(
            "Invalid function argument: device_ids type should be List[int]"
        )
  1. if device_ids is not None:
    如果用户指定了 device_ids 参数。

  2. isinstance(device_ids, list)
    检查 device_ids 是否是一个列表。

  3. opts.device_ids = device_ids
    如果是列表类型,将其赋值给 opts.device_ids

  4. raise TypeError(...)
    如果 device_ids 不是列表,抛出类型错误异常。


python

运行复制

if group is None:
    default_pg = _get_default_group()
    work = default_pg.barrier(opts=opts)
else:
    work = group.barrier(opts=opts)
  1. if group is None:
    如果没有指定进程组,使用默认进程组。

  2. default_pg = _get_default_group()
    获取默认的进程组。

  3. work = default_pg.barrier(opts=opts)
    调用默认进程组的 barrier 方法,传入配置选项。

  4. work = group.barrier(opts=opts)
    如果指定了进程组,调用该进程组的 barrier 方法。


python

运行复制

if async_op:
    return work
else:
    work.wait()
  1. if async_op:
    如果启用了异步操作。

  2. return work
    返回异步操作的句柄 work

  3. work.wait()
    如果是同步操作,等待所有进程完成。


monitored_barrier 函数

函数签名

python

运行复制

def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
  1. 定义了一个名为 monitored_barrier 的函数,用于同步所有进程,并支持超时机制和故障检测功能。

  2. group=GroupMember.WORLD
    指定进程组,默认为整个世界组。

  3. timeout=None
    设置超时时间,默认为 None

  4. wait_all_ranks=False
    是否等待所有进程完成,默认为 False


函数内部逻辑

python

运行复制

if _rank_not_in_group(group):
    _warn_not_in_group("monitored_barrier")
    return
  1. 检查当前进程是否在指定组中。如果不在组中,发出警告并返回。

python

运行复制

if get_backend(group) != Backend.GLOO:
    raise ValueError("monitored_barrier is only implemented for GLOO backend.")
  1. 检查通信后端是否为 GLOO。如果不是,则抛出异常,因为 monitored_barrier 仅支持 GLOO

python

运行复制

if timeout is None:
    timeout = _get_default_timeout(get_backend(group))
elif isinstance(timeout, float):
    warnings.warn(...)
    timeout = timedelta(seconds=timeout)
  1. 如果未指定超时时间,使用默认超时时间。

  2. 如果超时时间是浮点数,将其转换为 timedelta 类型。


python

运行复制

group_to_use = _get_default_group() if group is None else group
return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
  1. 确定使用的进程组,如果未指定,使用默认组。

  2. 调用进程组的 monitored_barrier 方法,传入超时和等待选项。


_create_process_group_wrapper 函数

函数签名

python

运行复制

def _create_process_group_wrapper(
    wrapped_pg: torch._C._distributed_c10d.Backend,
    store_prefix: str,
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta = default_pg_timeout,
):
  1. 定义了一个函数,用于创建一个包装的进程组。

  2. wrapped_pg
    被包装的后端(如 GLOONCCL)。

  3. store_prefix
    用于存储的前缀。

  4. store
    存储对象,用于分布式通信。

  5. rank
    当前进程的 rank。

  6. world_size
    总的进程数。

  7. timeout
    超时时间,默认为 default_pg_timeout


函数内部逻辑

python

运行复制

prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
store = PrefixStore(prefix, store)
helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout)
wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
return wrapped_pg
  1. prefix
    生成带有前缀的存储键。

  2. store = PrefixStore(...)
    创建一个带有前缀的存储对象。

  3. helper_pg = ProcessGroupGloo(...)
    使用 GLOO 后端创建辅助进程组。

  4. wrapped_pg = _ProcessGroupWrapper(...)
    将原始进程组和辅助进程组包装起来。

  5. return wrapped_pg
    返回包装后的进程组。


_hash_ranks 函数

python

运行复制

def _hash_ranks(ranks: List[int]):
    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
  1. 定义一个函数,用于对进程的 rank 列表进行哈希。

  2. "_".join(map(str, ranks))
    rank 列表转换为字符串,并用下划线连接。

  3. hashlib.sha1(...).hexdigest()
    使用 SHA-1 算法对字符串计算哈希值,并返回十六进制字符串。


_process_group_color 函数

python

运行复制

def _process_group_color(ranks: List[int]) -> int:
    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
  1. 定义一个函数,基于 rank 列表计算进程组的颜色(整数值)。

  2. _hash_ranks(ranks)
    调用前面的 _hash_ranks 函数,获取哈希值。

  3. int(..., 16)
    将哈希值从十六进制字符串转换为整数。

  4. % (sys.maxsize >> 1)
    取模,确保返回的整数是非负值。

@_exception_logger
def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
    """
    Synchronize all processes.

    This collective blocks processes until the whole group enters this function,
    if async_op is False, or if async work handle is called on wait().

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        device_ids ([int], optional): List of device/GPU ids.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    .. note:: `ProcessGroupNCCL` now relies on stream synchronization instead of
              device synchronization to block the CPU. Thus, please do not assume that
              `barrier()` would perform a device synchronization.
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("barrier")
        return

    opts = BarrierOptions()
    opts.device = _get_pg_default_device(group)
    if device_ids is not None:
        if isinstance(device_ids, list):
            opts.device_ids = device_ids
        else:
            raise TypeError(
                "Invalid function argument: device_ids type should be List[int]"
            )

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.barrier(opts=opts)
    else:
        work = group.barrier(opts=opts)

    if async_op:
        return work
    else:
        work.wait()


def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
    """
    Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.

    It is able to report ranks that did not pass this barrier within the provided timeout.
    Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0.
    Rank 0 will block until all send /recv from other ranks are processed, and will report
    failures for ranks that failed to respond in time. Note that if one rank does not reach the
    monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier.

    This collective will block all processes/ranks in the group, until the
    whole group exits the function successfully, making it useful for debugging
    and synchronizing. However, it can have a performance impact and should only
    be used for debugging or scenarios that require full synchronization points
    on the host-side. For debugging purposes, this barrier can be inserted
    before the application's collective calls to check if any ranks are
    desynchronized.

    .. note:: Note that this collective is only supported with the GLOO backend.

    Args:
        group (ProcessGroup, optional): The process group to work on. If
            ``None``, the default process group will be used.
        timeout (datetime.timedelta, optional): Timeout for monitored_barrier.
            If ``None``, the default process group timeout will be used.
        wait_all_ranks (bool, optional): Whether to collect all failed ranks or
            not. By default, this is ``False`` and ``monitored_barrier`` on rank 0
            will throw on the first failed rank it encounters in order to fail
            fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will
            collect all failed ranks and throw an error containing information
            about all failed ranks.

    Returns:
        ``None``.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> if dist.get_rank() != 1:
        >>>     dist.monitored_barrier() # Raises exception indicating that
        >>> # rank 1 did not call into monitored_barrier.
        >>> # Example with wait_all_ranks=True
        >>> if dist.get_rank() == 0:
        >>>     dist.monitored_barrier(wait_all_ranks=True) # Raises exception
        >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into
        >>> # monitored_barrier.
    """
    # Need to call rank not in group before using the group, otherwise
    # "Invalid process group" error is raised.
    if _rank_not_in_group(group):
        _warn_not_in_group("monitored_barrier")
        return

    if get_backend(group) != Backend.GLOO:
        raise ValueError("monitored_barrier is only implemented for GLOO backend.")

    if timeout is None:
        timeout = _get_default_timeout(get_backend(group))
    elif isinstance(timeout, float):
        # TODO(whc) aparently some existing test case for monitored_barrier passes in a timeout in float format?
        warnings.warn(
            "Please specify timeout arg as a timedelta. "
            f"Converting current value of {timeout} assuming it represents seconds",
        )
        timeout = timedelta(seconds=timeout)

    _check_valid_timeout(timeout)

    group_to_use = _get_default_group() if group is None else group
    return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)


def _create_process_group_wrapper(
    wrapped_pg: torch._C._distributed_c10d.Backend,
    store_prefix: str,
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta = default_pg_timeout,
):
    # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...

    # Create a separate prefix store for the helper process group.
    prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
    store = PrefixStore(prefix, store)
    helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout)
    # Wrap the underlying pg with ProcessGroupWrapper.
    wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
    return wrapped_pg

# helper function for deterministically hashing a list of ranks
def _hash_ranks(ranks: List[int]):
    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()

# Takes a list of ranks and computes an integer color
def _process_group_color(ranks: List[int]) -> int:
    # Convert our hash to an int, but avoid negative numbers by shifting a bit.
    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)

标签:group,barrier,ranks,torch,distributed,源码,pg,timeout,device
From: https://blog.csdn.net/sinat_37574187/article/details/144510945

相关文章

  • 【开题报告+论文+源码】基于SpringBoot+Vue的“振记小食店”线上点餐系统
    项目背景与意义近年来,随着互联网的迅猛发展和人们生活水平的提高,线上订餐系统逐渐成为了人们点餐的首选方式。这种系统不仅提供了方便快捷的订餐方式,还为餐厅和顾客之间的交流提供了更多的可能性。该系统的功能包括查看菜品、线上点餐、结账、订餐、桌台预定以及订单评价,为顾......
  • 【含文档+PPT+源码】基于SpringBoot的校园论坛系统的设计与实现
    项目背景与意义随着互联网的快速发展,人们获取信息的方式也发生了巨大的变化。特别是在领域,爱好者们希望能够随时随地获取最新的新闻,了解赛事的情况,以及与其他爱好者交流互动。因此,校园论坛的出现成为了满足人们需求的重要途径。校园论坛是指通过互联网为用户提供各种相关信息......
  • 基于SpringBoot+Vue的课程答疑管理系统设计与实现毕设(文档+源码)
    目录一、项目介绍二、开发环境三、功能介绍四、核心代码五、效果图六、源码获取:         大家好呀,我是一个混迹在java圈的码农。今天要和大家分享的是一款基于SpringBoot+Vue的课程答疑管理系统,项目源码请点击文章末尾联系我哦~目前有各类成品毕设JavaWeb......
  • 基于SpringBoot+Vue的科研项目验收管理系统设计与实现毕设(文档+源码)
    目录一、项目介绍二、开发环境三、功能介绍四、核心代码五、效果图六、源码获取:         大家好呀,我是一个混迹在java圈的码农。今天要和大家分享的是一款基于SpringBoot+Vue的科研项目验收管理系统,项目源码请点击文章末尾联系我哦~目前有各类成品毕设JavaW......
  • flask框架驾校预约管理系统毕设源码+论文
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容一、选题背景在当今社会,随着汽车保有量的不断增加,驾校的规模和学员数量也在迅速增长。驾校的管理变得日益复杂,传统的管理方式难以满足高效运营的需......
  • flask框架驾照考试知识管理平台毕设源码+论文
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容驾照考试知识管理平台毕业设计相关内容一、选题背景关于驾照考试知识管理平台的研究,现有研究主要以驾照考试的理论知识教学和传统管理方式为主,专......
  • flask框架健身房信息管理系统毕设源码+论文
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容一、选题背景关于健身房信息管理系统的研究,现有研究主要以传统管理模式为主,专门针对信息化、系统化的健身房信息管理系统的研究较少。在国内外,部分......
  • flask框架监狱管理系统毕设源码+论文
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容一、选题背景关于监狱管理系统的研究,现有研究主要以监狱的整体管理模式和传统人工管理方式为主。专门针对构建信息化监狱管理系统,整合服刑人员、民......
  • flask框架监狱罪犯信息管理系统毕设源码+论文
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容一、选题背景关于监狱罪犯信息管理系统的研究,现有研究主要以监狱整体管理方面为主,专门针对罪犯信息管理系统细致功能及流程优化的研究较少。在国外......
  • flask框架健身管理系统毕设源码+论文
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容一、选题背景关于健身管理系统的研究,现有研究主要以通用的管理系统开发为主,专门针对健身领域特定功能,如用户、健身教练、健身课程、课程报名、预约......