The snippets below come from /mambaforge/envs/Plant/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py, together with a small cleanup helper that calls into torch.distributed. Here is a line-by-line explanation of the code.

The cleanup function
```python
def cleanup():
    try:
        dist.barrier()
        dist.destroy_process_group()
    except:
        pass
```
- `def cleanup()`: defines a function named `cleanup` that releases the resources held by the distributed process group.
- `try:`: attempts the cleanup operations inside a `try` block.
- `dist.barrier()`: calls the `barrier` collective from PyTorch's distributed package so that all processes synchronize at this point; if some processes have not reached the barrier yet, the others wait for them.
- `dist.destroy_process_group()`: destroys the current process group and frees its resources.
- `except:`: any exception raised inside the `try` block (for example, the process group was never initialized) is caught and silently ignored.

A sketch of where such a helper typically sits in a training script follows.
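This is a minimal sketch of my own, not from the original post: it assumes the launcher sets the usual `MASTER_ADDR`/`MASTER_PORT`/`RANK`/`WORLD_SIZE` environment variables, and the `nccl` backend is just one possible choice.

```python
import torch.distributed as dist

def setup():
    # Reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from the environment.
    # "nccl" is an assumption for GPU training; "gloo" also works for CPU-only runs.
    dist.init_process_group(backend="nccl")

def cleanup():
    try:
        dist.barrier()                 # make sure every rank is done before tearing down
        dist.destroy_process_group()   # release the process group's resources
    except:
        pass                           # e.g. the process group was never initialized

def main():
    setup()
    # ... per-rank training / inference work ...
    cleanup()

if __name__ == "__main__":
    main()
```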
The barrier function

```python
@_exception_logger
def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
```
- `@_exception_logger`: the `_exception_logger` decorator wraps the function so exceptions raised inside it can be caught and logged.
- `def barrier(...)`: defines the `barrier` function, which synchronizes all processes: no process continues past it until every process in the group has called it.
- `group=GroupMember.WORLD`: the `group` parameter selects the process group, defaulting to the world group (`GroupMember.WORLD`).
- `async_op=False`: `async_op` controls whether the call is asynchronous; the default is a synchronous operation (`False`).
- `device_ids=None`: `device_ids` optionally selects devices (e.g. GPUs); the default is `None`.
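As a quick illustration of the default, synchronous form (my own sketch, assuming an already initialized process group and a filesystem shared by all ranks; the flag-file path is illustrative), a common pattern is to let rank 0 do some one-off work while every other rank waits at the barrier:

```python
import os
import torch.distributed as dist

if dist.get_rank() == 0:
    # Rank 0 does some one-off preparation, e.g. writing a marker file to shared storage.
    with open("/tmp/dataset_ready.flag", "w") as f:
        f.write("ready")

dist.barrier()   # every rank blocks here until all ranks have arrived

# After the barrier, every rank can safely assume rank 0's work is done.
assert os.path.exists("/tmp/dataset_ready.flag")
```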
Function body

```python
if _rank_not_in_group(group):
    _warn_not_in_group("barrier")
    return
```
- `_rank_not_in_group(group)`: checks whether the current process belongs to the specified process group.
- `_warn_not_in_group("barrier")`: if it does not, a warning is emitted saying the process is not part of the group.
- `return`: the function then returns immediately without doing anything else.
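A small sketch of when this branch fires (my own example, assuming an initialized process group with at least three ranks): `dist.new_group` returns a non-member placeholder on ranks outside the subgroup, so calling `barrier` with it on such a rank only warns and returns instead of blocking.

```python
import torch.distributed as dist

# Every rank must call new_group, even ranks that are not members of it.
subgroup = dist.new_group(ranks=[0, 1])

# On ranks 0 and 1 this synchronizes the subgroup; on any other rank the
# _rank_not_in_group check above makes barrier warn and return immediately.
dist.barrier(group=subgroup)
```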
```python
opts = BarrierOptions()
opts.device = _get_pg_default_device(group)
```
- `opts = BarrierOptions()`: creates a `BarrierOptions` object that holds the barrier's configuration.
- `opts.device = _get_pg_default_device(group)`: looks up the process group's default device and stores it in `opts.device`.
```python
if device_ids is not None:
    if isinstance(device_ids, list):
        opts.device_ids = device_ids
    else:
        raise TypeError(
            "Invalid function argument: device_ids type should be List[int]"
        )
```
- `if device_ids is not None:`: runs only when the caller passed a `device_ids` argument.
- `isinstance(device_ids, list)`: checks whether `device_ids` is a list.
- `opts.device_ids = device_ids`: if it is, it is stored on `opts.device_ids`.
- `raise TypeError(...)`: if it is not a list, a `TypeError` is raised. Both cases are sketched below.
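A hedged sketch of the two branches (assuming an NCCL process group is already initialized and the caller has a CUDA device; `local_rank` is my own variable name, not part of the source):

```python
import torch
import torch.distributed as dist

local_rank = torch.cuda.current_device()

dist.barrier(device_ids=[local_rank])    # a List[int] is accepted and stored in opts.device_ids

try:
    dist.barrier(device_ids=local_rank)  # a bare int is rejected by the isinstance check
except TypeError as err:
    print(err)  # "Invalid function argument: device_ids type should be List[int]"
```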
```python
if group is None:
    default_pg = _get_default_group()
    work = default_pg.barrier(opts=opts)
else:
    work = group.barrier(opts=opts)
```
- `if group is None:`: if no process group was specified, the default group is used.
- `default_pg = _get_default_group()`: fetches the default process group.
- `work = default_pg.barrier(opts=opts)`: calls the default group's `barrier` method with the configured options.
- `work = group.barrier(opts=opts)`: otherwise, calls `barrier` on the group that was passed in.
```python
if async_op:
    return work
else:
    work.wait()
```
- `if async_op:`: if the call was requested as an asynchronous operation,
- `return work`: the work handle `work` is returned to the caller.
- `work.wait()`: in the synchronous case, the function blocks here until all processes have completed the barrier (see the sketch below).
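The asynchronous form, sketched below under the assumption of an initialized process group, lets a rank overlap local work with the synchronization before blocking on the handle:

```python
import torch.distributed as dist

work = dist.barrier(async_op=True)   # returns a work handle immediately
# ... rank-local work that does not depend on the other ranks ...
work.wait()                          # now block until every rank has entered the barrier
```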
The monitored_barrier function

Function signature

```python
def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
```
- Defines a function named `monitored_barrier` that synchronizes all processes like `barrier`, but adds a timeout mechanism and failure detection.
- `group=GroupMember.WORLD`: the process group to use, defaulting to the world group.
- `timeout=None`: the timeout, defaulting to `None` (i.e. the process group's default timeout).
- `wait_all_ranks=False`: whether rank 0 should keep waiting and collect every failed rank instead of failing on the first one; defaults to `False`.
Function body

```python
if _rank_not_in_group(group):
    _warn_not_in_group("monitored_barrier")
    return
```

- Checks whether the current process belongs to the specified group; if not, it emits a warning and returns.
```python
if get_backend(group) != Backend.GLOO:
    raise ValueError("monitored_barrier is only implemented for GLOO backend.")
```

- Checks that the communication backend is `GLOO`; if it is not, a `ValueError` is raised, because `monitored_barrier` is only implemented for the `GLOO` backend.
```python
if timeout is None:
    timeout = _get_default_timeout(get_backend(group))
elif isinstance(timeout, float):
    warnings.warn(...)
    timeout = timedelta(seconds=timeout)
```

- If no timeout was given, the backend's default timeout is used.
- If the timeout was passed as a float, it is interpreted as seconds, a warning is emitted, and it is converted to a `timedelta`.
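In practice (a sketch assuming a gloo process group is already initialized), the two calls below end up equivalent, but the float form triggers the warning above:

```python
from datetime import timedelta
import torch.distributed as dist

dist.monitored_barrier(timeout=30.0)                    # float: warned about, converted to seconds
dist.monitored_barrier(timeout=timedelta(seconds=30))   # preferred: pass a timedelta directly
```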
```python
group_to_use = _get_default_group() if group is None else group
return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
```

- Chooses which process group to use, falling back to the default group if none was specified.
- Calls that group's `monitored_barrier` method with the timeout and `wait_all_ranks` options. A usage sketch follows.
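Putting the pieces together, a hedged usage sketch (assuming the process group was initialized with the gloo backend) looks like this; rank 0 raises an error listing every rank that failed to respond within the timeout:

```python
from datetime import timedelta
import torch.distributed as dist

dist.monitored_barrier(
    timeout=timedelta(seconds=30),
    wait_all_ranks=True,   # collect all failed ranks instead of failing fast on the first one
)
```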
The _create_process_group_wrapper function

Function signature

```python
def _create_process_group_wrapper(
    wrapped_pg: torch._C._distributed_c10d.Backend,
    store_prefix: str,
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta = default_pg_timeout,
):
```
- Defines a function that creates a wrapped process group.
- `wrapped_pg`: the backend being wrapped (e.g. `GLOO` or `NCCL`).
- `store_prefix`: the prefix used for store keys.
- `store`: the store object used for distributed rendezvous and communication.
- `rank`: the current process's rank.
- `world_size`: the total number of processes.
- `timeout`: the timeout, defaulting to `default_pg_timeout`.
Function body

```python
prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
store = PrefixStore(prefix, store)
helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout)
wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
return wrapped_pg
```
- `prefix`: builds a prefixed store key.
- `store = PrefixStore(...)`: creates a store whose keys are namespaced with that prefix.
- `helper_pg = ProcessGroupGloo(...)`: creates a helper process group on the `GLOO` backend.
- `wrapped_pg = _ProcessGroupWrapper(...)`: wraps the original process group together with the helper group.
- `return wrapped_pg`: returns the wrapped process group. A sketch of what the prefix store buys us follows.
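To show the effect of the prefix store, here is a small standalone sketch (the prefix string is an illustrative value, and `HashStore` is used only as a stand-in for the real rendezvous store): keys written through the `PrefixStore` are namespaced, so the helper gloo group's keys cannot collide with the wrapped backend's keys in the same underlying store.

```python
from torch.distributed import HashStore, PrefixStore

# HashStore is a simple in-memory store; it stands in for the real rendezvous store here.
base_store = HashStore()
wrapper_store = PrefixStore("pg_wrapper:my_pg", base_store)

wrapper_store.set("rendezvous_key", "some_value")   # the key is namespaced with the prefix internally
print(wrapper_store.get("rendezvous_key"))          # b'some_value'
```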
The _hash_ranks function

```python
def _hash_ranks(ranks: List[int]):
    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()
```
- Defines a function that deterministically hashes a list of ranks.
- `"_".join(map(str, ranks))`: converts the rank list into an underscore-joined string.
- `hashlib.sha1(...).hexdigest()`: hashes that string with SHA-1 and returns the hexadecimal digest.
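Run standalone (the rank list here is just an example), the helper turns a rank list into a stable hex digest that is identical on every process:

```python
import hashlib
from typing import List

def _hash_ranks(ranks: List[int]):
    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()

print(_hash_ranks([0, 1, 2, 3]))   # 40-character SHA-1 hex digest of the string "0_1_2_3"
```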
The _process_group_color function

```python
def _process_group_color(ranks: List[int]) -> int:
    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
```
- Defines a function that computes a process group "color" (an integer) from the rank list.
- `_hash_ranks(ranks)`: calls the `_hash_ranks` helper above to get the hex digest.
- `int(..., 16)`: parses the hex digest as an integer.
- `% (sys.maxsize >> 1)`: takes it modulo `sys.maxsize >> 1` so that the result is a bounded, non-negative integer.
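A quick check of the two properties that matter here, determinism and non-negativity (the helper is repeated from above so the snippet is self-contained; the rank lists are arbitrary examples):

```python
import hashlib
import sys
from typing import List

def _hash_ranks(ranks: List[int]):
    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()

def _process_group_color(ranks: List[int]) -> int:
    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)

color = _process_group_color([0, 1, 2, 3])
print(color == _process_group_color([0, 1, 2, 3]))   # True: same ranks -> same color on every process
print(0 <= color < (sys.maxsize >> 1))                # True: the modulo keeps the value non-negative
```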
For reference, here is the full source of the functions walked through above:

```python
@_exception_logger
def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
    """
    Synchronize all processes.

    This collective blocks processes until the whole group enters this function,
    if async_op is False, or if async work handle is called on wait().

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        device_ids ([int], optional): List of device/GPU ids.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    .. note:: `ProcessGroupNCCL` now relies on stream synchronization instead of
        device synchronization to block the CPU. Thus, please do not assume that
        `barrier()` would perform a device synchronization.
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("barrier")
        return

    opts = BarrierOptions()
    opts.device = _get_pg_default_device(group)
    if device_ids is not None:
        if isinstance(device_ids, list):
            opts.device_ids = device_ids
        else:
            raise TypeError(
                "Invalid function argument: device_ids type should be List[int]"
            )

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.barrier(opts=opts)
    else:
        work = group.barrier(opts=opts)

    if async_op:
        return work
    else:
        work.wait()


def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
    """
    Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.

    It is able to report ranks that did not pass this barrier within the provided timeout.
    Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0.
    Rank 0 will block until all send /recv from other ranks are processed, and will report
    failures for ranks that failed to respond in time. Note that if one rank does not reach the
    monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier.

    This collective will block all processes/ranks in the group, until the
    whole group exits the function successfully, making it useful for debugging
    and synchronizing. However, it can have a performance impact and should only
    be used for debugging or scenarios that require full synchronization points
    on the host-side. For debugging purposes, this barrier can be inserted
    before the application's collective calls to check if any ranks are
    desynchronized.

    .. note:: Note that this collective is only supported with the GLOO backend.

    Args:
        group (ProcessGroup, optional): The process group to work on. If
            ``None``, the default process group will be used.
        timeout (datetime.timedelta, optional): Timeout for monitored_barrier.
            If ``None``, the default process group timeout will be used.
        wait_all_ranks (bool, optional): Whether to collect all failed ranks or
            not. By default, this is ``False`` and ``monitored_barrier`` on rank 0
            will throw on the first failed rank it encounters in order to fail
            fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will
            collect all failed ranks and throw an error containing information
            about all failed ranks.

    Returns:
        ``None``.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> if dist.get_rank() != 1:
        >>>     dist.monitored_barrier()  # Raises exception indicating that
        >>> # rank 1 did not call into monitored_barrier.
        >>> # Example with wait_all_ranks=True
        >>> if dist.get_rank() == 0:
        >>>     dist.monitored_barrier(wait_all_ranks=True)  # Raises exception
        >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into
        >>> # monitored_barrier.
    """
    # Need to call rank not in group before using the group, otherwise
    # "Invalid process group" error is raised.
    if _rank_not_in_group(group):
        _warn_not_in_group("monitored_barrier")
        return

    if get_backend(group) != Backend.GLOO:
        raise ValueError("monitored_barrier is only implemented for GLOO backend.")

    if timeout is None:
        timeout = _get_default_timeout(get_backend(group))
    elif isinstance(timeout, float):
        # TODO(whc) aparently some existing test case for monitored_barrier passes in a timeout in float format?
        warnings.warn(
            "Please specify timeout arg as a timedelta. "
            f"Converting current value of {timeout} assuming it represents seconds",
        )
        timeout = timedelta(seconds=timeout)

    _check_valid_timeout(timeout)

    group_to_use = _get_default_group() if group is None else group
    return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)


def _create_process_group_wrapper(
    wrapped_pg: torch._C._distributed_c10d.Backend,
    store_prefix: str,
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta = default_pg_timeout,
):
    # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...

    # Create a separate prefix store for the helper process group.
    prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}"
    store = PrefixStore(prefix, store)
    helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout)
    # Wrap the underlying pg with ProcessGroupWrapper.
    wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg)
    return wrapped_pg


# helper function for deterministically hashing a list of ranks
def _hash_ranks(ranks: List[int]):
    return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest()


# Takes a list of ranks and computes an integer color
def _process_group_color(ranks: List[int]) -> int:
    # Convert our hash to an int, but avoid negative numbers by shifting a bit.
    return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1)
```
Tags: group, barrier, ranks, torch, distributed, source code, pg, timeout, device
From: https://blog.csdn.net/sinat_37574187/article/details/144510945