__all__ = ["is_hccl_available", "reinit_process_group"]
import logging
import os
import warnings
from datetime import timedelta
from functools import wraps
import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as dist_c10d
from torch._C._distributed_c10d import (
_DistributedBackendOptions,
_register_process_group,
FakeProcessGroup,
PrefixStore,
)
from torch.distributed.distributed_c10d import (
_check_p2p_op_list,
_check_single_tensor,
_check_tensor_list,
_check_valid_timeout,
_coalescing_manager,
_create_process_group_wrapper,
_ensure_all_tensors_same_dtype,
_find_pg_by_ranks_and_tag,
_get_default_group,
_get_group_size,
_get_object_coll_device,
_get_split_source,
_GLOO_AVAILABLE,
_object_to_tensor,
_process_group_color,
_rank_not_in_group,
_tensor_to_object,
_unregister_all_process_groups,
_unregister_process_group,
_update_default_pg,
_validate_output_list_for_rank,
_warn_not_in_group,
all_gather,
Backend,
BackendConfig,
DebugLevel,
GatherOptions,
get_debug_level,
get_group_rank,
get_rank,
get_world_size,
GroupMember,
is_gloo_available,
is_initialized,
is_mpi_available,
is_nccl_available,
is_ucc_available,
is_xccl_available,
isend,
ProcessGroup,
ReduceScatterOptions,
)
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch_npu import npu
from torch_npu.utils import get_cann_version
from torch_npu.utils._error_code import dist_error, ErrCode
if is_mpi_available():
from torch.distributed.distributed_c10d import ProcessGroupMPI
if is_nccl_available():
from torch.distributed.distributed_c10d import ProcessGroupNCCL
if is_gloo_available():
from torch.distributed.distributed_c10d import ProcessGroupGloo
if is_ucc_available():
from torch.distributed.distributed_c10d import ProcessGroupUCC
if is_xccl_available():
from torch.distributed.distributed_c10d import ProcessGroupXCCL
# Shares torch's distributed logger so messages from these patches appear alongside c10d's own.
logger = logging.getLogger("torch.distributed")
# Handle on the stock ProcessGroup method, captured at import time.
# _hccl_get_sequence_number_for_group delegates to it for non-HCCL groups
# (presumably that function is installed over the original elsewhere — confirm).
origin_get_sequence_number_for_group = ProcessGroup._get_sequence_number_for_group
# CANN toolkit version, read once at import. _batch_isend_irecv compares it as a
# plain string against "9.0.0".
# NOTE(review): lexicographic comparison — confirm the version-string format
# makes that ordering correct.
cann_version = get_cann_version("CANN")
# Device name cache; overwritten by _batch_isend_irecv on every npu call.
npu_device_name = None
def _batch_isend_irecv(p2p_op_list):
    """
    Issue a batch of point-to-point ops and return their work handles.

    The group and tensor device of the first op select the strategy:

    * ``cuda``: every op is issued inside an async ``_coalescing_manager``.
    * ``npu``: if the group is a ProcessGroup, CANN >= "9.0.0" (string compare)
      and the device name falls in one of the supported Ascend ranges, or if
      the npu backend is a ``FakeProcessGroup``, ops are coalesced the same
      way using group-relative peers. Otherwise the whole batch is handed to
      the backend's ``batch_isend_irecv`` in a single call.
    * any other device: each op is issued on its own.

    Side effect: on the npu path the module-level ``npu_device_name`` is
    refreshed from ``npu.get_device_name()``.

    Args:
        p2p_op_list: non-empty list of P2POp-like objects; the attributes
            ``op``, ``tensor``, ``peer``, ``group``, ``tag`` and ``group_peer``
            are read.

    Returns:
        list of work handles.

    Raises:
        RuntimeError: on the npu fallback path, when an op's tensor is not on npu.
    """
    global npu_device_name
    head = p2p_op_list[0]
    group = head.group
    device = head.tensor.device

    if device.type == "cuda":
        # NCCL-style coalescing of the whole batch.
        with _coalescing_manager(group, device, async_ops=True) as cm:
            for p2p_op in p2p_op_list:
                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
        return cm.works

    if device.type != "npu":
        # No batching support: issue ops one by one, keeping non-empty handles.
        handles = []
        for p2p_op in p2p_op_list:
            handle = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
            if handle:
                handles.append(handle)
        return handles

    npu_device_name = npu.get_device_name()
    # Peers are global ranks; on the fallback path they are translated to
    # group ranks only when the caller supplied an explicit group.
    translate_peers = group is not None
    if not translate_peers:
        group = _get_default_group()
    npu_backend = group._get_backend(device)
    # NOTE(review): device names are range-checked with plain string ordering.
    on_supported_device = (
        "Ascend910B1" <= npu_device_name <= "Ascend910B4_1"
    ) or ("Ascend910_9362" <= npu_device_name <= "Ascend910_9391")
    is_fake_backend = isinstance(npu_backend, FakeProcessGroup)

    if (
        isinstance(group, ProcessGroup)
        and cann_version >= "9.0.0"
        and on_supported_device
    ) or is_fake_backend:
        _check_p2p_op_list(p2p_op_list)
        with _coalescing_manager(group=group, device=device, async_ops=True) as cm:
            for p2p_op in p2p_op_list:
                peer_key = "group_dst" if p2p_op.op is isend else "group_src"
                p2p_op.op(
                    p2p_op.tensor,
                    group=p2p_op.group,
                    tag=p2p_op.tag,
                    **{peer_key: p2p_op.group_peer},
                )
        return cm.works

    # Fallback: hand the whole batch to the backend in a single call.
    op_names, tensors, remote_ranks = [], [], []
    for p2p_op in p2p_op_list:
        tensor_device_type = p2p_op.tensor.device.type
        if tensor_device_type != "npu":
            raise RuntimeError(
                f"No backend type associated with device type {tensor_device_type}"
                + dist_error(ErrCode.PARAM)
            )
        op_names.append(p2p_op.op.__name__)
        tensors.append(p2p_op.tensor)
        if translate_peers:
            remote_ranks.append(get_group_rank(group, p2p_op.peer))
        else:
            remote_ranks.append(p2p_op.peer)
    return [npu_backend.batch_isend_irecv(op_names, tensors, remote_ranks)]
def _gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
    """
    Gathers a list of tensors in a single process.

    Args:
        tensor (Tensor): Input tensor.
        gather_list (list[Tensor], optional): List of appropriately-sized
            tensors to use for gathered data (default is None, must be specified
            on the destination rank)
        dst (int, optional): Destination rank, given as a global rank
            (default is 0)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Note:
        In default mode, NPU implements gather with allgather. Every rank
        receives the data, and a warning is emitted on ``dst``. When
        ``npu.are_compatible_impl_enabled()`` is true, the npu backend's own
        ``gather`` is used instead.
        Backends address ranks relative to their group. For a non-default
        group, the global ``dst`` is therefore translated with
        ``get_group_rank`` before it is passed as ``GatherOptions.rootRank``.
    """
    _check_single_tensor(tensor, "tensor")
    if gather_list:
        _check_tensor_list(gather_list, "gather_list")
    else:
        gather_list = []
    _ensure_all_tensors_same_dtype(tensor, gather_list)
    if _rank_not_in_group(group):
        _warn_not_in_group("gather")
        return None
    my_rank = get_rank()
    _validate_output_list_for_rank(my_rank, dst, gather_list)

    is_default_group = group is None or group is GroupMember.WORLD
    pg = _get_default_group() if is_default_group else group
    input_tensors = [tensor]

    if tensor.device.type == "npu" and not npu.are_compatible_impl_enabled():
        # Default NPU mode: emulate gather with allgather.
        if my_rank == dst:
            warnings.warn(
                "HCCL doesn't support gather at the moment. Implemented with allgather instead."
            )
        # dst publishes the shapes it expects so non-dst ranks can allocate
        # matching receive buffers.
        group_size = _get_group_size(group)
        recv_size_list = (
            [None for _ in range(group_size)]
            if my_rank != dst
            else [t.size() for t in gather_list]
        )
        dist.broadcast_object_list(recv_size_list, dst, group)
        if not gather_list:
            # Allocate directly on the input's device (not CPU followed by a
            # copy to the *current* npu device).
            gather_list = [
                torch.empty(size, dtype=tensor.dtype, device=tensor.device)
                for size in recv_size_list
            ]
        npu_backend = pg._get_backend(torch.device("npu"))
        work = npu_backend.allgather([gather_list], input_tensors)
    else:
        opts = GatherOptions()
        # Backends expect a group-local root rank.
        opts.rootRank = dst if is_default_group else get_group_rank(group, dst)
        output_tensors = [gather_list] if my_rank == dst else []
        if tensor.device.type == "npu":
            npu_backend = pg._get_backend(torch.device("npu"))
            work = npu_backend.gather(output_tensors, input_tensors, opts)
        else:
            work = pg.gather(output_tensors, input_tensors, opts)

    if async_op:
        return work
    work.wait()
    return None
def _gather_object(obj, object_gather_list=None, dst=0, group=None):
    """
    Gather one picklable object from every rank of ``group`` onto rank ``dst``.

    Follows the same steps as ``torch.distributed.gather_object``, but the
    tensor gather is performed by this module's ``_gather`` rather than the
    upstream ``gather``. On ``dst``, ``object_gather_list[i]`` is filled with
    the object from the i-th member; other ranks return without output.

    Args:
        obj: picklable object contributed by this rank.
        object_gather_list: output list on ``dst`` (must be None elsewhere,
            as enforced by ``_validate_output_list_for_rank``).
        dst: destination global rank.
        group: process group; None means the default group.
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("gather_object")
        return
    rank = get_rank()
    _validate_output_list_for_rank(rank, dst, object_gather_list)
    is_dst = rank == dst

    device = _get_object_coll_device(group)
    payload, payload_size = _object_to_tensor(obj, device, group)

    # Share every rank's serialized size so all payloads can be padded to the
    # largest one before the fixed-size gather.
    member_count = get_world_size(group=group)
    size_buffer = torch.zeros(member_count, dtype=torch.long, device=device)
    sizes = [size_buffer[idx].unsqueeze(dim=0) for idx in range(member_count)]
    all_gather(sizes, payload_size, group=group)
    padded_len = int(max(sizes).item())
    payload.resize_(padded_len)

    slots = None
    if is_dst:
        # One contiguous buffer, viewed as member_count equal slices.
        flat = torch.empty(
            padded_len * member_count, dtype=torch.uint8, device=device
        )
        slots = [
            flat[padded_len * idx : padded_len * (idx + 1)]
            for idx in range(member_count)
        ]
    _gather(payload, gather_list=slots, dst=dst, group=group)
    if not is_dst:
        return
    for idx, (slot, size) in enumerate(zip(slots, sizes)):
        object_gather_list[idx] = _tensor_to_object(
            slot.type(torch.uint8), size, group
        )
def is_hccl_available():
    """Return True when a backend named ``"hccl"`` is registered with c10d."""
    registered_backends = Backend.backend_list
    return "hccl" in registered_backends
def _clear_pg_cache_in_torch(group: ProcessGroup):
if dist_c10d._world.pg_map.get(group) is not None:
del dist_c10d._world.pg_map[group]
if dist_c10d._world.pg_names.get(group) is not None:
del dist_c10d._world.pg_names[group]
if dist_c10d._world.pg_group_ranks.get(group) is not None:
del dist_c10d._world.pg_group_ranks[group]
if dist_c10d._world.pg_backend_config.get(group) is not None:
del dist_c10d._world.pg_backend_config[group]
if dist_c10d._world.pg_to_tag.get(group) is not None:
del dist_c10d._world.pg_to_tag[group]
tags_list = [
key for key, value in dist_c10d._world.tags_to_pg.items() if group in value
]
if len(tags_list) > 0:
for tag in tags_list:
del dist_c10d._world.tags_to_pg[tag]
_unregister_process_group(group.group_name)
def reinit_process_group(group=None, rebuild_link=True):
    """
    Recover HCCL communication for this rank.

    Args:
        group: process group to rebuild when ``rebuild_link`` is True; None
            selects the default group. It is not consulted on the resume path.
        rebuild_link: when False, every registered process group that has an
            npu backend gets ``resume_hccl_comm(device_id)`` on the current npu
            device, and None is returned. When True and the group's backend
            string contains "hccl", its npu backend has its TCPStore key
            deleted and ``abort_hccl_comm("reinit")`` called.

    Returns:
        None on the resume path, otherwise the (possibly defaulted) group.
    """
    device_id = torch.npu.current_device()
    logger.info(
        "reinit process group, group=%s, rebuild link=%s, device=%s",
        group,
        rebuild_link,
        device_id,
    )
    if group is None:
        group = dist_c10d._world.default_pg
    npu_device = torch.device("npu")

    if not rebuild_link:
        for pg in dist_c10d._pg_map:
            if npu_device not in pg._device_types:
                continue
            pg._get_backend(npu_device).resume_hccl_comm(device_id)
        logger.info("resume hccl comm end, device_id=%s", device_id)
        return None

    backend = dist_c10d.Backend(dist_c10d._world.pg_map[group][0])
    if "hccl" in backend:
        logger.info(
            "reinit hccl comm start, group=%s, device_id=%s", group, device_id
        )
        hccl_backend = group._get_backend(npu_device)
        hccl_backend._delete_tcpstore_key()
        hccl_backend.abort_hccl_comm("reinit")
        logger.info("reinit hccl comm end, group=%s, device_id=%s", group, device_id)
    return group
def _comm_switch_nic(ranks, useBackup):
    """
    Ask the npu backend of every registered process group to switch NICs.

    Each group whose device types include npu gets
    ``_set_switch_nic_comm(rank, len(ranks), ranks, useBackup)``. Here ``rank``
    is read from the ``RANK`` environment variable (KeyError if it is unset).
    Every eligible group is attempted even after one of them fails.

    Args:
        ranks: ranks to forward to the backend call.
        useBackup: flag forwarded unchanged (presumably selects the backup
            NIC — confirm against the backend implementation).

    Returns:
        True if every backend call returned a truthy result, else False.
    """
    rank_count = len(ranks)
    npu_device = torch.device("npu")
    my_rank = int(os.environ["RANK"])
    all_switched = True
    for pg in dist_c10d._pg_map:
        if npu_device not in pg._device_types:
            continue
        switched = pg._get_backend(npu_device)._set_switch_nic_comm(
            my_rank, rank_count, ranks, useBackup
        )
        if not switched:
            all_switched = False
    return all_switched
def _reduce_scatter_tensor_uneven(
output,
input,
input_split_sizes=None,
op=dist.ReduceOp.SUM,
group=None,
async_op=False,
):
if _rank_not_in_group(group):
_warn_not_in_group("reduce_scatter_tensor_uneven")
return None
if output.device.type != "npu" or input.device.type != "npu":
warnings.warn("Support for Tensors is limited to those of type npu")
return None
if group is None or group is GroupMember.WORLD:
group = _get_default_group()
group = group._get_backend(torch.device("npu"))
opts = ReduceScatterOptions()
opts.reduceOp = op
input_split_sizes = [] if input_split_sizes is None else input_split_sizes
work = group.reduce_scatter_tensor_uneven(output, input, input_split_sizes, opts)
if async_op:
return work
else:
work.wait()
return None
def _all_gather_into_tensor_uneven(
output, input, output_split_sizes=None, group=None, async_op=False
):
if _rank_not_in_group(group):
_warn_not_in_group("all_gather_into_tensor_uneven")
return None
if output.device.type != "npu" or input.device.type != "npu":
warnings.warn("Support for Tensors is limited to those of type npu")
return None
if group is None or group is GroupMember.WORLD:
group = _get_default_group()
group = group._get_backend(torch.device("npu"))
output_split_sizes = [] if output_split_sizes is None else output_split_sizes
work = group.all_gather_into_tensor_uneven(output, input, output_split_sizes)
if async_op:
return work
else:
work.wait()
return None
def _trigger__get_addr_and_port_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if (
len(args) > 0
and isinstance(args[0], RendezvousParameters)
and args[0].backend == "parallel"
):
args[0].backend = "static"
master_addr, master_port = func(*args, **kwargs)
args[0].backend = "parallel"
return master_addr, master_port
else:
return func(*args, **kwargs)
return wrapper
def _trigger_rendezvous_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
use_parallel = os.getenv("TORCH_NPU_USE_PARALLEL_TCPSTORE", "False")
if use_parallel == "True":
if len(args) > 0 and args[0] == "env://":
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
master_port = os.getenv("MASTER_PORT", "29500")
args = (f"parallel://{master_addr}:{master_port}",) + args[1:]
logger.info(
"torch_npu_run change the rendezvous url from env:// to %s", args[0]
)
return func(*args, **kwargs)
return wrapper
def _destructor_process_group():
world = dist_c10d._world
for pg_to_shutdown in sorted(
world.pg_names, key=lambda x: world.pg_names[x], reverse=True
):
try:
pg_to_shutdown.shutdown()
except Exception:
logger.warning(
"Failed to shutdown process group during NPU exit.", exc_info=True
)
_update_default_pg(None)
world.pg_map.clear()
world.pg_names.clear()
world.pg_group_ranks.clear()
world.pg_backend_config.clear()
world.pg_to_tag.clear()
world.tags_to_pg.clear()
world.pg_coalesce_state.clear()
_unregister_all_process_groups()
world.group_count = 0
def _hccl_get_sequence_number_for_group(self):
    """
    Return the sequence number for this process group.

    A group whose backend config string is exactly "hccl" or "npu:hccl" is
    answered by its npu backend. Any other config falls back to the stock
    ``ProcessGroup._get_sequence_number_for_group`` captured at import.
    """
    backend_config = torch.distributed.get_backend_config(self)
    if backend_config not in ("hccl", "npu:hccl"):
        return origin_get_sequence_number_for_group(self)
    npu_backend = self._get_backend(torch.device("npu"))
    return npu_backend._get_sequence_number_for_group()
def _patched_new_process_group_helper(
    group_size,
    group_rank,
    global_ranks_in_group,
    backend,
    store,
    group_name,
    backend_options=None,
    timeout=None,
    pg_tag=None,
    device_id=None,
    group_desc=None,
):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``global_ranks_in_group == []`` for the default group.

    NPU-aware copy of ``torch.distributed.distributed_c10d._new_process_group_helper``;
    this module assigns it over that name at import. Visible NPU-specific behavior:
    when ``backend`` is ``Backend.UNDEFINED`` and the device map contains
    ``Backend.HCCL``, the default backend type is set to ``BackendType.CUSTOM``.
    Backend strings not handled by the MPI/GLOO/NCCL/UCC/XCCL branches are built
    through ``Backend._plugins``.

    Returns:
        ``(pg, prefix_store)`` for member ranks;
        ``(existing_group, prefix_store)`` when ``pg_tag`` matches an existing
        group with the same ranks;
        ``(GroupMember.NON_GROUP_MEMBER, None)`` for non-member ranks, or when
        MPI backend creation returns nothing.

    Raises:
        ValueError: duplicate ``group_name``, ``device_id`` without an index
            or of type cpu, or an unknown backend type.
        RuntimeError: MPI, NCCL or XCCL requested but not built in.
        AssertionError: wrong ``backend_options`` type for NCCL, unexpected
            backend class, or missing ``group_name``/``group_desc``.
    """
    if group_name in dist_c10d._world.pg_names.values():
        raise ValueError(
            "The specified group name has already been "
            "created, please use a different group name"
        )
    if device_id is not None and (device_id.index is None or device_id.type == "cpu"):
        raise ValueError(
            "init_process_group device_id parameter must be an accelerator with an index"
        )
    _check_valid_timeout(timeout)
    # A user-supplied tag lets callers reuse an existing group with the same ranks.
    if pg_tag not in [None, ""]:
        existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
        if existing_group:
            _, prefix_store = dist_c10d._world.pg_map[existing_group]
            return existing_group, prefix_store
    group_desc = "undefined" if group_desc is None else group_desc
    is_default_group = len(global_ranks_in_group) == 0
    # Splitting is only attempted when the default group is bound to a device.
    if is_initialized() and _get_default_group().bound_device_id:
        split_from = _get_split_source(_get_default_group())
    else:
        split_from = None
    if not is_default_group:
        global_rank = _get_default_group().rank()
        if global_rank not in global_ranks_in_group:
            # Non-members still take part in the split (no color) so that all
            # ranks of the parent group issue the same split calls.
            if split_from:
                split_from.perform_nocolor_split(_get_default_group().bound_device_id)
            return GroupMember.NON_GROUP_MEMBER, None
    prefix_store = PrefixStore(f"{group_name}/", store)
    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
    )
    backend_config = BackendConfig(backend)
    # Choose the ProcessGroup's default backend type.
    if "," not in str(backend) and ":" not in str(backend):
        # Single backend name (no per-device mapping).
        if backend not in Backend.backend_type_map:
            raise ValueError(f"Unknown backend type {backend}")
        if backend == Backend.UNDEFINED:
            if Backend.NCCL in backend_config.get_device_backend_map().values():
                pg._set_default_backend(ProcessGroup.BackendType.NCCL)
            elif Backend.HCCL in backend_config.get_device_backend_map().values():
                # NPU addition: HCCL is registered as a custom backend.
                pg._set_default_backend(ProcessGroup.BackendType.CUSTOM)
            else:
                pg._set_default_backend(ProcessGroup.BackendType.GLOO)
        else:
            pg._set_default_backend(Backend.backend_type_map[backend])
    else:
        # Composite "device:backend,..." specification.
        if Backend.NCCL in backend_config.device_backend_map.values():
            pg._set_default_backend(ProcessGroup.BackendType.NCCL)
        elif Backend._plugins.keys():
            # NOTE(review): _plugins keys are looked up with .upper() further
            # below, while device_backend_map values are compared as-is here;
            # confirm this membership test can actually match.
            custom_backend = next(iter(Backend._plugins.keys()))
            if custom_backend in backend_config.device_backend_map.values():
                pg._set_default_backend(ProcessGroup.BackendType.CUSTOM)
        else:
            pg._set_default_backend(ProcessGroup.BackendType.GLOO)
    if device_id:
        pg.bound_device_id = device_id
    backend_class: torch._C._distributed_c10d.Backend
    # Create one backend per (device, backend) entry and register it on pg.
    for device, backend_str in backend_config.get_device_backend_map().items():
        backend_prefix_store = PrefixStore(f"{device}/", prefix_store)
        if backend_str == Backend.MPI:
            if not is_mpi_available():
                raise RuntimeError(
                    "Distributed package doesn't have MPI built in."
                    " MPI is only included if you build PyTorch from"
                    " source on a host that has MPI installed."
                )
            backend_class = ProcessGroupMPI.create(global_ranks_in_group)
            backend_type = ProcessGroup.BackendType.MPI
            if not backend_class:
                return GroupMember.NON_GROUP_MEMBER, None
            # MPI decides rank/size itself; rebuild pg when they are unset (-1).
            if pg.rank() == -1 and pg.size() == -1:
                pg = ProcessGroup(
                    backend_prefix_store,
                    backend_class.rank(),
                    backend_class.size(),
                )
                pg._set_default_backend(backend_type)
        elif backend_str == Backend.GLOO:
            backend_class = ProcessGroupGloo(
                backend_prefix_store, group_rank, group_size, timeout=timeout
            )
            backend_type = ProcessGroup.BackendType.GLOO
        elif backend_str == Backend.NCCL:
            if not is_nccl_available():
                raise RuntimeError("Distributed package doesn't have NCCL built in")
            if backend_options is not None:
                if not isinstance(backend_options, ProcessGroupNCCL.Options):
                    raise AssertionError(
                        "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
                    )
                if backend_options._timeout != timeout:
                    warnings.warn(
                        "backend_options._timeout was specified, "
                        "but timeout kwarg has a default value that will always override it. "
                    )
            else:
                backend_options = ProcessGroupNCCL.Options()
                backend_options.is_high_priority_stream = False
            # The timeout kwarg always wins over backend_options._timeout.
            backend_options._timeout = timeout
            if split_from:
                backend_options.split_from = split_from
                backend_options.split_color = _process_group_color(
                    global_ranks_in_group
                )
            backend_options.global_ranks_in_group = global_ranks_in_group
            backend_options.group_name = group_name
            backend_class = ProcessGroupNCCL(
                backend_prefix_store, group_rank, group_size, backend_options
            )
            backend_type = ProcessGroup.BackendType.NCCL
        elif backend_str == Backend.UCC and is_ucc_available():
            backend_class = ProcessGroupUCC(
                backend_prefix_store, group_rank, group_size, timeout=timeout
            )
            backend_type = ProcessGroup.BackendType.UCC
        elif backend_str == Backend.XCCL:
            if not is_xccl_available():
                raise RuntimeError("Distributed package doesn't have XCCL built in")
            backend_class = ProcessGroupXCCL(
                backend_prefix_store, group_rank, group_size
            )
            backend_type = ProcessGroup.BackendType.XCCL
        else:
            # Registered third-party backends (presumably including hccl).
            if backend_str.upper() not in Backend._plugins:
                raise ValueError(f"Unknown c10d backend type {backend_str.upper()}")
            backend_plugin = Backend._plugins[backend_str.upper()]
            creator_fn = backend_plugin.creator_fn
            extended_api = backend_plugin.extended_api
            backend_type = ProcessGroup.BackendType.CUSTOM
            if not extended_api:
                backend_class = creator_fn(
                    backend_prefix_store, group_rank, group_size, timeout
                )
            else:
                dist_backend_opts = _DistributedBackendOptions()
                dist_backend_opts.store = backend_prefix_store
                dist_backend_opts.group_rank = group_rank
                dist_backend_opts.group_size = group_size
                dist_backend_opts.timeout = timeout
                dist_backend_opts.group_id = group_name
                dist_backend_opts.global_ranks_in_group = global_ranks_in_group
                backend_class = creator_fn(dist_backend_opts, backend_options)
        # Gloo and NCCL backends get their sequence number initialised here.
        if backend_str == Backend.GLOO:
            if not isinstance(backend_class, ProcessGroupGloo):
                raise AssertionError("Expected backend_class to be ProcessGroupGloo")
            backend_class._set_sequence_number_for_group()
        elif backend_str == Backend.NCCL:
            if not isinstance(backend_class, ProcessGroupNCCL):
                raise AssertionError("Expected backend_class to be ProcessGroupNCCL")
            backend_class._set_sequence_number_for_group()
        # A creator may return a full ProcessGroup; use it as-is and stop.
        if issubclass(type(backend_class), ProcessGroup):
            pg = backend_class
            break
        # With TORCH_DISTRIBUTED_DEBUG=DETAIL, wrap supported backends in the
        # debug wrapper (requires Gloo).
        if (
            backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC]
            or backend_str.upper() in Backend._plugins
        ):
            if get_debug_level() == DebugLevel.DETAIL:
                if not _GLOO_AVAILABLE:
                    logger.info(
                        """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
                GLOO is not available. Build with Gloo to
                create a wrapper process group in debug mode
                to aid collective desynchronization debugging."""
                    )
                else:
                    backend_class = _create_process_group_wrapper(
                        wrapped_pg=backend_class,
                        store_prefix=group_name,
                        store=backend_prefix_store,
                        rank=group_rank,
                        world_size=group_size,
                        timeout=timeout,
                    )
        # If every device maps to the same backend, register this single
        # instance for all devices and stop creating more.
        if len(set(backend_config.get_device_backend_map().values())) == 1:
            for device in backend_config.get_device_backend_map():
                pg._register_backend(torch.device(device), backend_type, backend_class)
            break
        pg._register_backend(torch.device(device), backend_type, backend_class)
    if group_name is None:
        raise AssertionError("group_name should not be None")
    if group_desc is None:
        raise AssertionError("group_desc should not be None")
    pg._set_group_name(group_name)
    pg._set_group_desc(group_desc)
    # Eagerly connect the bound device when its backend supports splitting.
    if device_id and pg._get_backend(device_id).supports_splitting:
        eager_backend = pg._get_backend(device_id)
        eager_backend.eager_connect_single_device(device_id)
    # Record the new group in c10d's global bookkeeping.
    dist_c10d._world.pg_map[pg] = (backend, prefix_store)
    dist_c10d._world.pg_names[pg] = group_name
    _register_process_group(group_name, pg)
    dist_c10d._world.pg_backend_config[pg] = str(backend_config)
    # Untagged groups get an auto "ptd:" tag and also join the shared "" tag.
    if pg_tag in [None, ""]:
        pg_tag = f"ptd:{group_name}"
        dist_c10d._world.tags_to_pg.setdefault("", []).append(pg)
    else:
        pg_tag = f"user:{pg_tag}"
    dist_c10d._world.tags_to_pg.setdefault(pg_tag, []).append(pg)
    dist_c10d._world.pg_to_tag[pg] = pg_tag
    return pg, prefix_store
# Monkey-patch c10d so process-group construction goes through the NPU-aware
# helper. c10d looks this name up as a module global at call time, so
# replacing the attribute is enough.
dist_c10d._new_process_group_helper = _patched_new_process_group_helper
def _hccl_add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None:
    """
    Extend the timeout of every local HCCL process group by ``timeout``, ephemerally.

    The extension applies on this rank only and is reset once the first
    collective issued after this call finishes.

    NOTE: only HCCL backends are supported; other backends are left untouched.
    NOTE: this makes timeout handling stateful. Prefer setting the timeout
    directly, or using a barrier with its own timeout, wherever possible.

    Does nothing when HCCL is not registered, when ``ProcessGroupHCCL``
    cannot be imported, or when a backend lacks ``_add_ephemeral_timeout``.

    Args:
        timeout (timedelta): The delta of timeout to extend.

    Returns:
        None.
    """
    if not is_hccl_available():
        return
    try:
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
    except ImportError:
        return

    npu_device = torch.device("npu")
    for pg in dist_c10d._world.pg_map:
        if npu_device not in pg._device_types:
            continue
        hccl_backend = pg._get_backend(npu_device)
        supports_extension = isinstance(hccl_backend, ProcessGroupHCCL) and hasattr(
            hccl_backend, "_add_ephemeral_timeout"
        )
        if supports_extension:
            hccl_backend._add_ephemeral_timeout(timeout)