__all__ = []

from typing import Dict, List, Tuple, cast

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch._C import _get_privateuse1_backend_name
from torch.distributed.rpc import api
from torch.distributed.rpc import constants as rpc_constants
from torch.distributed.rpc._utils import _group_membership_management, _update_group_membership

import torch_npu._C
from torch_npu.utils._error_code import ErrCode, dist_error
def _get_device_count_info():
device_count = dict()
custom_backend_name = _get_privateuse1_backend_name()
if hasattr(torch, custom_backend_name):
custom_device_count_func = torch.utils.backend_registration._get_custom_mod_func("device_count")
custom_device_count = custom_device_count_func() if custom_device_count_func else 0
device_count[custom_backend_name] = custom_device_count
return device_count
def _init_device_state(custom_backend_name):
    """Call ``torch.<custom_backend_name>.init()`` if that device module is available.

    Args:
        custom_backend_name: name of the device module attribute on ``torch``
            (callers pass a device type string or the PrivateUse1 backend name).
    """
    device_module = getattr(torch, custom_backend_name)
    if not device_module.is_available():
        return
    device_module.init()
def _tensorpipe_validate_devices(devices, device_count):
    """Return True if every device in ``devices`` has a valid index.

    CPU devices are always accepted. Any other device must satisfy
    ``0 <= device.index < device_count.get(device.type, 0)``, so a device type
    missing from ``device_count`` is treated as having zero devices.

    Args:
        devices: iterable of ``torch.device`` objects (lists and dict views are
            both passed by callers in this file).
        device_count: mapping from device type string to number of devices.
    """
    for device in devices:
        if device.type == "cpu":
            continue
        if not 0 <= device.index < device_count.get(device.type, 0):
            return False
    return True
def _validate_device_maps(
    all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True
):
    """Validate the device lists and device maps gathered from every worker.

    Args:
        all_names: worker names to check, in order.
        all_device_counts: worker name -> ``{device type: device count}``.
        all_device_maps: worker name -> ``{target worker name: {source device: target device}}``.
        all_devices: worker name -> list of devices that worker declared.
        is_static_group: when True, every target worker named in a device map
            must appear in ``all_names``; dynamic groups skip that check.

    Raises:
        ValueError: on the first inconsistency found; the message is suffixed
            with ``dist_error(ErrCode.VALUE)``.
    """
    # Pass 1: each worker's own device list must be duplicate-free and in range.
    for node in all_names:
        devices = all_devices[node]
        if len(devices) != len(set(devices)):
            raise ValueError(
                f"Node {node} has duplicated devices\n"
                f"devices = {devices}" + dist_error(ErrCode.VALUE)
            )
        node_counts = all_device_counts[node]
        if not _tensorpipe_validate_devices(devices, node_counts):
            raise ValueError(
                f"Node {node} has devices with invalid indices\n"
                f"devices = {devices}\n"
                f"device count = {node_counts}" + dist_error(ErrCode.VALUE)
            )

    # Pass 2: every device map declared by every worker.
    for source_node in all_names:
        outgoing_maps = all_device_maps[source_node]
        if is_static_group and not set(outgoing_maps.keys()).issubset(all_names):
            raise ValueError(
                f"Node {source_node} has invalid target node names in its device maps\n"
                f"device maps = {outgoing_maps.keys()}\n"
                f"node names = {all_names}" + dist_error(ErrCode.VALUE)
            )
        # Pass 1 already looked this key up, so hoisting it cannot change which error fires.
        source_devices = all_devices[source_node]
        for target_node, map_ in outgoing_maps.items():
            if len(set(map_.values())) != len(map_):
                raise ValueError(
                    f"Node {source_node} has duplicated target devices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}" + dist_error(ErrCode.VALUE)
                )

            # Source side: restricted to the declared devices when there are
            # any, otherwise only the indices are range-checked.
            if source_devices:
                if not set(map_.keys()).issubset(source_devices):
                    raise ValueError(
                        f"Node {source_node} has unexpected source devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {source_devices}" + dist_error(ErrCode.VALUE)
                    )
            elif not _tensorpipe_validate_devices(map_.keys(), all_device_counts[source_node]):
                raise ValueError(
                    f"Node {source_node} has source devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[source_node]}" + dist_error(ErrCode.VALUE)
                )

            # Target side: the target may be unknown here (dynamic groups), in
            # which case neither its devices nor its counts are checked.
            target_devices = all_devices.get(target_node, [])
            if target_devices:
                if not set(map_.values()).issubset(target_devices):
                    raise ValueError(
                        f"Node {source_node} has unexpected target devices "
                        f"in its device map for {target_node}\n"
                        f"device map = {map_}\n"
                        f"devices = {target_devices}" + dist_error(ErrCode.VALUE)
                    )
            elif target_node in all_device_counts and not _tensorpipe_validate_devices(
                map_.values(), all_device_counts[target_node]
            ):
                raise ValueError(
                    f"Node {source_node} has target devices with invalid indices "
                    f"in its device map for {target_node}\n"
                    f"device map = {map_}\n"
                    f"device count = {all_device_counts[target_node]}" + dist_error(ErrCode.VALUE)
                )
def _get_device_infos():
    """Return this worker's ``(device_count, device_maps, devices)``.

    A worker joining a dynamic group invokes this on each peer through
    ``api.rpc_sync``. It reads the current RPC agent's backend options. When
    devices are configured, it also initialises the device module for the first
    device's type.

    Returns:
        tuple: ``(device_count, opts.device_maps, opts.devices)``, where
        ``device_count`` comes from ``_get_device_count_info()``.
    """
    # `cast` is imported locally so this function does not depend on a
    # module-level typing import. It is a static-typing hint only.
    from typing import cast

    from torch_npu._C._distributed_rpc import TensorPipeAgent

    agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
    opts = agent._get_backend_options()
    device_count = _get_device_count_info()
    if opts.devices:
        _init_device_state(opts.devices[0].type)
    return device_count, opts.device_maps, opts.devices
def _tensorpipe_exchange_and_check_all_device_maps(
    my_name, my_device_count, my_device_maps, my_devices, group
):
    """All-gather device info across ``group``, validate it, and derive this worker's maps.

    Args:
        my_name: this worker's name.
        my_device_count: this worker's ``{device type: count}`` mapping.
        my_device_maps: this worker's ``{target worker: {local device: remote device}}``.
        my_devices: devices this worker declared.
        group: process group passed to ``dist.all_gather_object``.

    Returns:
        tuple: ``(reverse_device_maps, my_devices)``, computed by
        ``rpc.backend_registry._create_reverse_mapping`` and
        ``rpc.backend_registry._create_device_list``.
    """
    local_entry = (my_name, my_device_count, my_device_maps, my_devices)
    # One placeholder per rank; all_gather_object fills the list in place.
    gathered: List[Tuple[
        str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]
    ]] = [("", 0, {}, []) for _ in range(group.size())]
    dist.all_gather_object(gathered, local_entry, group)

    all_names = []
    all_device_counts, all_device_maps, all_devices = {}, {}, {}
    for worker, count, device_map, devices in gathered:
        all_names.append(worker)
        all_device_counts[worker] = count
        all_device_maps[worker] = device_map
        all_devices[worker] = devices

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)

    registry = rpc.backend_registry
    reverse_device_maps = registry._create_reverse_mapping(my_name, all_names, all_device_maps)
    my_devices = registry._create_device_list(my_devices, my_device_maps, reverse_device_maps)
    return reverse_device_maps, my_devices
def _set_devices_and_reverse_device_map(agent):
    """Exchange device info with every worker and broadcast this worker's membership.

    Used on the dynamic-group path, where no world size is known up front:

    1. Collect ``(device_count, device_maps, devices)`` from every worker known to
       ``agent``. Remote workers are queried with ``_get_device_infos`` over RPC;
       this worker reads its own backend options.
    2. Validate everything with ``is_static_group=False`` and build the reverse
       device maps from this worker's point of view.
    3. For every worker (including this one), compute its device list and call
       ``_update_group_membership`` on it with ``is_join=True``.

    Args:
        agent: the freshly created TensorPipe agent of this worker.
    """
    # `cast` is imported locally so this function does not depend on a
    # module-level typing import. It is a static-typing hint only.
    from typing import cast

    from torch_npu._C._distributed_rpc import TensorPipeAgent

    agent = cast(TensorPipeAgent, agent)
    my_worker_info = agent.get_worker_info()
    my_name = my_worker_info.name
    all_worker_infos = agent.get_worker_infos()

    all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
    for worker_info in all_worker_infos:
        worker_name = worker_info.name
        if worker_name != my_name:
            device_count, device_map, devices = api.rpc_sync(worker_name, _get_device_infos)
        else:
            opts = agent._get_backend_options()
            device_map, devices = opts.device_maps, opts.devices
            device_count = _get_device_count_info()
        all_device_counts[worker_name] = device_count
        all_device_maps[worker_name] = device_map
        all_devices[worker_name] = devices
        all_names.append(worker_name)

    _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices, is_static_group=False)
    reverse_device_maps = rpc.backend_registry._create_reverse_mapping(my_name, all_names, all_device_maps)

    for worker_name in all_names:
        all_devices[worker_name] = rpc.backend_registry._create_device_list(
            all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps
        )
        api.rpc_sync(
            worker_name,
            _update_group_membership,
            args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True),
        )
def _backend_type_repr(self):
    """Format ``self`` (presumably a BackendType enum member) as ``"BackendType.<name>"``."""
    prefix = "BackendType."
    return prefix + self.name
def _construct_rpc_backend_options(
    backend,
    rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
    init_method=rpc_constants.DEFAULT_INIT_METHOD,
    **kwargs
):
    """Build backend options by delegating to ``backend.value.construct_rpc_backend_options_handler``.

    Args:
        backend: registry entry whose ``value`` carries the options handler.
        rpc_timeout: forwarded positionally as the first handler argument.
        init_method: forwarded positionally as the second handler argument.
        **kwargs: forwarded unchanged to the handler.

    Returns:
        Whatever the handler returns.
    """
    options_handler = backend.value.construct_rpc_backend_options_handler
    return options_handler(rpc_timeout, init_method, **kwargs)
def _init_backend(backend, *args, **kwargs):
    """Initialise ``backend`` by forwarding all arguments to ``backend.value.init_backend_handler``."""
    init_handler = backend.value.init_backend_handler
    return init_handler(*args, **kwargs)
def _npu_tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
    _transports=None,
    _channels=None,
    **kwargs
):
    """Build ``NPUTensorPipeRpcBackendOptions`` for the NPU_TENSORPIPE backend.

    Only the named parameters are forwarded. Extra keyword arguments are
    accepted for signature compatibility and ignored.
    """
    from .options import NPUTensorPipeRpcBackendOptions

    option_kwargs = dict(
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        num_worker_threads=num_worker_threads,
        _transports=_transports,
        _channels=_channels,
    )
    return NPUTensorPipeRpcBackendOptions(**option_kwargs)
def _npu_tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    """Create, register and return the NPU TensorPipe RPC agent.

    With a truthy ``world_size`` the group is static. Device maps are
    all-gathered and validated first, then the agent is created with the derived
    reverse device maps and device list. Finally every worker waits on an RPC
    all-gather and a process-group barrier.

    Otherwise the worker joins a dynamic group. Inside
    ``_group_membership_management`` the agent is created with empty maps and
    devices, and those are filled in by ``_set_devices_and_reverse_device_map``.
    If that step fails, RPC is shut down and the exception is re-raised with an
    ``ErrCode.INTERNAL`` suffix on its message.

    Raises:
        TypeError: if ``store`` is not a ``dist.Store`` or
            ``rpc_backend_options`` is not ``NPUTensorPipeRpcBackendOptions``.
    """
    from torch_npu._C._distributed_rpc import TensorPipeAgent
    from .options import NPUTensorPipeRpcBackendOptions

    if not isinstance(store, dist.Store):
        raise TypeError(f"`store` must be a c10d::Store. {store}" + dist_error(ErrCode.TYPE))
    if not isinstance(rpc_backend_options, NPUTensorPipeRpcBackendOptions):
        raise TypeError(
            f"`rpc_backend_options` must be a `NPUTensorPipeRpcBackendOptions`. {rpc_backend_options}" +
            dist_error(ErrCode.TYPE)
        )

    device_count = _get_device_count_info()
    is_static_group = bool(world_size)
    if is_static_group:
        group = rpc.backend_registry._init_process_group(store, rank, world_size)
        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )
        if devices:
            _init_device_state(devices[0].type)
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )
        api._init_rpc_states(agent)
        # Wait for every worker: RPC-level all-gather, then a process-group barrier.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        group.barrier().wait()
        return agent

    with _group_membership_management(store, name, True):
        # Start with an empty reverse device map and device list. They are
        # populated by _set_devices_and_reverse_device_map below.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            {},
            [],
        )
        api._init_rpc_states(agent)
        try:
            _set_devices_and_reverse_device_map(agent)
        except Exception as e:
            api.shutdown()
            # Most exception types (RuntimeError, ValueError, ...) have no `msg`
            # attribute, so an unconditional `e.msg += ...` raised AttributeError
            # and hid the real failure. Append the error-code suffix through
            # whichever message slot this exception actually uses.
            suffix = dist_error(ErrCode.INTERNAL)
            if isinstance(getattr(e, "msg", None), str):
                e.msg += suffix
            elif e.args and isinstance(e.args[0], str):
                e.args = (e.args[0] + suffix,) + e.args[1:]
            raise
        return agent
def _faulty_tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads,
    messages_to_fail,
    messages_to_delay,
    num_fail_sends=0,
    **kwargs,
):
    """Build ``FaultyTensorPipeRpcBackendOptions`` for the NPU_FAULTY_TENSORPIPE backend.

    Only the named parameters are forwarded. Extra keyword arguments are
    accepted and ignored.
    """
    from torch_npu._C._distributed_rpc import FaultyTensorPipeRpcBackendOptions

    option_kwargs = {
        "num_worker_threads": num_worker_threads,
        "rpc_timeout": rpc_timeout,
        "init_method": init_method,
        "messages_to_fail": messages_to_fail,
        "messages_to_delay": messages_to_delay,
        "num_fail_sends": num_fail_sends,
    }
    return FaultyTensorPipeRpcBackendOptions(**option_kwargs)
def _faulty_tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    """Create, register and return a ``FaultyTensorPipeAgent``.

    Initialises the PrivateUse1 device module, then builds the agent with an
    empty map (``{}``) and an empty list (``[]``) as its last two arguments.

    Raises:
        TypeError: if ``store`` is not a ``dist.Store`` or
            ``rpc_backend_options`` is not ``FaultyTensorPipeRpcBackendOptions``.
    """
    from torch_npu._C._distributed_rpc import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions

    store_is_valid = isinstance(store, dist.Store)
    if not store_is_valid:
        raise TypeError(f"`store` must be a c10d::Store. {store}" + dist_error(ErrCode.TYPE))
    options_are_valid = isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions)
    if not options_are_valid:
        raise TypeError(
            f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" +
            dist_error(ErrCode.TYPE)
        )

    _init_device_state(_get_privateuse1_backend_name())
    agent = FaultyTensorPipeAgent(store, name, rank, world_size, rpc_backend_options, {}, [])
    api._init_rpc_states(agent)
    return agent
def _rpc_backend_registry():
    """Register the NPU TensorPipe RPC backends with ``torch.distributed.rpc``.

    Runs the optional native hook ``torch_npu._C._rpc_npu_init`` if present, then
    registers ``NPU_TENSORPIPE`` and ``NPU_FAULTY_TENSORPIPE``. Registering a
    name twice is rejected by ``register_backend``, so call this once.
    """
    if hasattr(torch_npu._C, "_rpc_npu_init"):
        torch_npu._C._rpc_npu_init()
    registry = rpc.backend_registry
    registry.register_backend(
        "NPU_TENSORPIPE",
        _npu_tensorpipe_construct_rpc_backend_options_handler,
        _npu_tensorpipe_init_backend_handler,
    )
    registry.register_backend(
        "NPU_FAULTY_TENSORPIPE",
        _faulty_tensorpipe_construct_rpc_backend_options_handler,
        _faulty_tensorpipe_init_backend_handler,
    )
    # register_backend replaces backend_registry.BackendType with a newly built
    # enum. Any alias captured before this call (such as the import-time
    # assignment to torch.distributed.rpc.BackendType in this module) therefore
    # lacks the NPU members. Refresh it so rpc.BackendType.NPU_TENSORPIPE resolves.
    rpc.BackendType = registry.BackendType
# Expose the registry's current BackendType enum as torch.distributed.rpc.BackendType.
# NOTE(review): this runs once, at import time. _rpc_backend_registry() can only be
# called after this module has been imported, and PyTorch's register_backend
# presumably replaces backend_registry.BackendType with a new enum. If so, the
# alias set here lacks the NPU_* members unless it is refreshed after
# registration — confirm.
import torch.distributed.rpc as _rpc_module
_rpc_module.BackendType = rpc.backend_registry.BackendType