import torch
import torch.distributed.launcher.api
import torch_npu
from torch_npu._init.patches.patch_manager import PatchManager
_INTERNAL_REPLACEMENTS = [
(
"_C._distributed_c10d._verify_params_across_processes",
"distributed._verify_params_across_processes",
),
(
"_C._distributed_c10d.ProcessGroup._get_sequence_number_for_group",
"distributed.distributed_c10d._hccl_get_sequence_number_for_group",
),
(
"distributed.distributed_c10d._add_ephemeral_timeout_for_all_pgs",
"distributed.distributed_c10d._hccl_add_ephemeral_timeout_for_all_pgs",
),
]
_PUBLIC_API_ALIASES = [
(
"distributed.batch_isend_irecv",
"distributed.distributed_c10d._batch_isend_irecv",
),
(
"distributed.distributed_c10d.batch_isend_irecv",
"distributed.distributed_c10d._batch_isend_irecv",
),
(
"distributed.gather",
"distributed.distributed_c10d._gather",
),
(
"distributed.distributed_c10d.gather",
"distributed.distributed_c10d._gather",
),
(
"distributed.gather_object",
"distributed.distributed_c10d._gather_object",
),
(
"distributed.distributed_c10d.gather_object",
"distributed.distributed_c10d._gather_object",
),
(
"distributed.is_hccl_available",
"distributed.is_hccl_available",
),
(
"distributed.reinit_process_group",
"distributed.reinit_process_group",
),
]
def _resolve_attr(root, attr_path: str):
obj = root
for part in attr_path.split("."):
obj = getattr(obj, part)
return obj
def _assign_attr(target_root, target_path: str, source_root, source_path: str):
parts = target_path.split(".")
owner = (
_resolve_attr(target_root, ".".join(parts[:-1]))
if len(parts) > 1
else target_root
)
setattr(owner, parts[-1], _resolve_attr(source_root, source_path))
def _apply_internal_replacements(torch, torch_npu):
"""
Replace PyTorch internal distributed implementations with NPU/HCCL versions.
Example:
torch._C._distributed_c10d._verify_params_across_processes
-> torch_npu.distributed._verify_params_across_processes
"""
for target_path, source_path in _INTERNAL_REPLACEMENTS:
_assign_attr(torch, target_path, torch_npu, source_path)
def _apply_public_api_aliases(torch, torch_npu):
"""
Patch torch.distributed public APIs with torch_npu implementations.
Example:
torch.distributed.gather(...) -> torch_npu.distributed.distributed_c10d._gather(...)
"""
for target_path, source_path in _PUBLIC_API_ALIASES:
_assign_attr(torch, target_path, torch_npu, source_path)
def _apply_wrapped_functions(torch, torch_npu):
"""
Wrap PyTorch distributed helpers with torch_npu NPU/HCCL logic.
Example:
torch.distributed.distributed_c10d.rendezvous(...)
-> torch_npu.distributed.distributed_c10d._trigger_rendezvous_decorator(...)
torch.distributed.launcher.api._get_addr_and_port(...)
-> torch_npu.distributed.distributed_c10d._trigger__get_addr_and_port_decorator(...)
"""
torch.distributed.distributed_c10d.rendezvous = (
torch_npu.distributed.distributed_c10d._trigger_rendezvous_decorator(
torch.distributed.distributed_c10d.rendezvous
)
)
torch.distributed.launcher.api._get_addr_and_port = (
torch_npu.distributed.distributed_c10d._trigger__get_addr_and_port_decorator(
torch.distributed.launcher.api._get_addr_and_port
)
)
def _apply_sharded_grad_scaler_patch(torch):
"""
Replace PyTorch FSDP ShardedGradScaler with torch_npu implementation.
Example:
torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler
-> torch_npu.npu.amp.sharded_grad_scaler._ShardedGradScaler
"""
from torch.distributed.fsdp import sharded_grad_scaler
from torch_npu.npu.amp.sharded_grad_scaler import _ShardedGradScaler
sharded_grad_scaler.ShardedGradScaler = _ShardedGradScaler
@PatchManager.register_patch("distributed")
def apply_distributed_methods_patch():
"""
Patch PyTorch distributed APIs with torch_npu NPU/HCCL implementations.
Categories:
1. Internal replacements
2. Public API aliases
3. Wrapped functions
"""
_apply_internal_replacements(torch, torch_npu)
_apply_public_api_aliases(torch, torch_npu)
_apply_sharded_grad_scaler_patch(torch)
_apply_wrapped_functions(torch, torch_npu)