322694cb创建于 2025年3月17日历史提交
__all__ = ["NPUTensorPipeRpcBackendOptions"]

from typing import Dict, List, Optional, Union

import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
from torch._C import _get_privateuse1_backend_name
import torch.distributed.rpc.constants as rpc_constants

from torch_npu.utils._error_code import ErrCode, dist_error

DeviceType = Union[int, str, torch.device]


def _to_device(device: DeviceType) -> torch.device:
    device = torch.device(device)
    if device.type != _get_privateuse1_backend_name():
        raise ValueError(
            "`set_devices` expect a list of "
            f"{_get_privateuse1_backend_name()} devices, but got "
            f"device type {device.type}." + dist_error(ErrCode.VALUE)
        )
    return device


def _to_device_map(
    device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
    full_device_map: Dict[torch.device, torch.device] = {}
    reverse_map: Dict[torch.device, torch.device] = {}
    for k, v in device_map.items():
        k, v = torch.device(k), torch.device(v)
        if v in reverse_map:
            raise ValueError(
                "`device_map` only supports 1-to-1 mapping, "
                f"trying to map {k} and {reverse_map[v]} to {v}" + dist_error(ErrCode.VALUE)
            )
        full_device_map[k] = v
        reverse_map[v] = k
    return full_device_map


def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
    return list(map(_to_device, devices))


class NPUTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
    def __init__(
        self,
        *,
        num_worker_threads: int = rpc_constants.DEFAULT_NUM_WORKER_THREADS,
        rpc_timeout: float = rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = rpc_constants.DEFAULT_INIT_METHOD,
        device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
        devices: Optional[List[DeviceType]] = None,
        _transports: Optional[List] = None,
        _channels: Optional[List] = None,
    ):
        full_device_maps = (
            {}
            if device_maps is None
            else {k: _to_device_map(v) for k, v in device_maps.items()}
        )
        full_device_list = [] if devices is None else _to_device_list(devices)
        super().__init__(
            num_worker_threads,
            _transports,
            _channels,
            rpc_timeout,
            init_method,
            full_device_maps,
            full_device_list,
        )

    def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
        full_device_map = _to_device_map(device_map)
        curr_device_maps = super().device_maps

        if to in curr_device_maps:
            for k, v in full_device_map.items():
                if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
                    raise ValueError(
                        "`set_device_map` only supports 1-to-1 mapping, trying"
                        f" to map {k} to {v} and {curr_device_maps[to][k]}" + dist_error(ErrCode.VALUE)
                    )

        super()._set_device_map(to, full_device_map)

    def set_devices(self, devices: List[DeviceType]):
        self.devices = _to_device_list(devices)