import logging
import torch
from torch.distributed.distributed_c10d import _pg_map
import torch_npu
import torch_npu._C
from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler
# Status codes handed to the NPU process-group backend's set_watchdog_status():
# RUN is applied by restart_device(), STOP by stop_device().
WATCHDOG_STATUS_RUN = 1
WATCHDOG_STATUS_STOP = 2
# Module logger registered under the "torch_npu.recovery" name; the
# restart/stop routines below report their progress through it.
logger = logging.getLogger("torch_npu.recovery")
def check_npu_storage_is_safe(storage_obj):
    """Run the native data-pointer check on a storage object.

    Args:
        storage_obj: A ``torch.storage.TypedStorage`` or
            ``torch.storage.UntypedStorage`` instance.

    Returns:
        Whatever ``torch_npu._C._check_npu_data_ptr`` returns for the storage
        (presumably a bool meaning "safe to use" -- confirm against the
        native binding).

    Raises:
        RuntimeError: If ``storage_obj`` is neither of the two storage types.
    """
    accepted_types = (torch.storage.TypedStorage, torch.storage.UntypedStorage)
    # Reject unsupported inputs first so the native call only ever sees storages.
    if not isinstance(storage_obj, accepted_types):
        raise RuntimeError(f"param type should be TypedStorage or UntypedStorage, could not be {type(storage_obj)}" + pta_error(ErrCode.TYPE))
    return torch_npu._C._check_npu_data_ptr(storage_obj)
def check_npu_tensor_is_safe(tensor_obj):
    """Run the storage safety check on the untyped storage behind a tensor.

    Args:
        tensor_obj: A ``torch.Tensor``.

    Returns:
        The result of ``check_npu_storage_is_safe`` for the tensor's
        untyped storage.

    Raises:
        RuntimeError: If ``tensor_obj`` is not a ``torch.Tensor``.
    """
    if not isinstance(tensor_obj, torch.Tensor):
        raise RuntimeError(f"param type should be Tensor, could not be {type(tensor_obj)}" + pta_error(ErrCode.TYPE))
    # The check operates on storages, so unwrap the tensor first.
    backing_storage = tensor_obj.untyped_storage()
    return check_npu_storage_is_safe(backing_storage)
def mark_all_npu_tensor_unsafe(device: int):
    """Invoke ``torch_npu._C._mark_all_npu_data_ptr_unsafe`` for ``device``.

    Args:
        device: Device index forwarded unchanged to the native call.

    Returns:
        The native call's return value, passed through as-is.
    """
    native_mark_unsafe = torch_npu._C._mark_all_npu_data_ptr_unsafe
    return native_mark_unsafe(device)
def update_npu_storage_to_safe(storage_obj):
    """Run the native data-pointer update on a storage object.

    Args:
        storage_obj: A ``torch.storage.TypedStorage`` or
            ``torch.storage.UntypedStorage`` instance.

    Returns:
        Whatever ``torch_npu._C._update_npu_data_ptr`` returns for the storage.

    Raises:
        RuntimeError: If ``storage_obj`` is neither of the two storage types.
    """
    is_storage = isinstance(
        storage_obj,
        (torch.storage.TypedStorage, torch.storage.UntypedStorage),
    )
    if is_storage:
        return torch_npu._C._update_npu_data_ptr(storage_obj)
    # Anything that is not a storage is a caller error.
    raise RuntimeError(f"param type should be TypedStorage or UntypedStorage, could not be {type(storage_obj)}" + pta_error(ErrCode.TYPE))
def update_npu_tensor_to_safe(tensor_obj):
    """Run the storage update on the untyped storage behind a tensor.

    Args:
        tensor_obj: A ``torch.Tensor``.

    Returns:
        The result of ``update_npu_storage_to_safe`` for the tensor's
        untyped storage.

    Raises:
        RuntimeError: If ``tensor_obj`` is not a ``torch.Tensor``.
    """
    if not isinstance(tensor_obj, torch.Tensor):
        raise RuntimeError(f"param type should be Tensor, could not be {type(tensor_obj)}" + pta_error(ErrCode.TYPE))
    # The update operates on storages, so unwrap the tensor first.
    backing_storage = tensor_obj.untyped_storage()
    return update_npu_storage_to_safe(backing_storage)
def set_npu_tensor_unsafe_check_flag(flag: bool) -> None:
    """Pass ``flag`` to ``torch_npu._C._set_npu_data_unsafe_flag``.

    Args:
        flag: Value forwarded unchanged to the native setter.

    Returns:
        The native setter's return value, passed through as-is.
    """
    native_set_flag = torch_npu._C._set_npu_data_unsafe_flag
    return native_set_flag(flag)
def get_npu_tensor_unsafe_check_flag() -> bool:
    """Return the value reported by ``torch_npu._C._get_npu_data_unsafe_flag``."""
    current_flag = torch_npu._C._get_npu_data_unsafe_flag()
    return current_flag
def _recovery_all_npu_stream(device: int) -> None:
    """Invoke ``torch_npu._C._recovery_all_npu_stream`` for ``device``.

    Args:
        device: Device index forwarded unchanged to the native call.

    Returns:
        The native call's return value, passed through as-is.
    """
    native_recover_streams = torch_npu._C._recovery_all_npu_stream
    return native_recover_streams(device)
def restart_device(
    device_id: int,
    rebuild_all_resources: bool = False,
    disable_tensor_unsafe_check: bool = False,
) -> None:
    """Restart an NPU device and switch NPU process-group watchdogs to run.

    Steps, in order:
      1. Ensure the NPU runtime is lazily initialised.
      2. When ``rebuild_all_resources`` is set, optionally mark all NPU
         tensors on the device unsafe and enable the unsafe-check flag, then
         recover the device's streams.
      3. Call the native restart routine and clear the force-stop exception
         state.
      4. For every process group that includes the NPU device type, clear
         the backend's work-meta list and set its watchdog status to
         ``WATCHDOG_STATUS_RUN``.

    Args:
        device_id: Index of the device to restart.
        rebuild_all_resources: When true, run the resource-rebuild steps
            (step 2) before the native restart.
        disable_tensor_unsafe_check: When true, skip marking tensors unsafe
            and enabling the unsafe-check flag during a rebuild.
    """
    logger.info(
        "restart device start, device_id=%s, rebuild_all_resources=%s, "
        "disable_tensor_unsafe_check=%s",
        device_id,
        rebuild_all_resources,
        disable_tensor_unsafe_check,
    )
    torch_npu.npu._lazy_init()
    # NOTE(review): the nesting here is reconstructed -- the flag is enabled
    # only together with marking tensors unsafe, and stream recovery runs for
    # every rebuild regardless of disable_tensor_unsafe_check. Confirm.
    if rebuild_all_resources:
        if not disable_tensor_unsafe_check:
            mark_all_npu_tensor_unsafe(device_id)
            set_npu_tensor_unsafe_check_flag(True)
        _recovery_all_npu_stream(device_id)
    torch_npu._C._npu_restart_device(device_id)
    _except_handler.set_force_stop_exception(False)
    npu_device = torch.device('npu')
    for pg in _pg_map:
        # Only groups that include the NPU device type have an NPU backend.
        if npu_device not in pg._device_types:
            continue
        # Look the backend up once and reuse it for both calls.
        backend = pg._get_backend(npu_device)
        backend.clear_workmeta_list()
        backend.set_watchdog_status(WATCHDOG_STATUS_RUN)
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info(
            "set watchdog status to run, device_id=%s, group=%s",
            device_id,
            pg,
        )
    logger.info("restart device end, device_id=%s", device_id)
def stop_device(device_id: int):
    """Stop an NPU device and switch NPU process-group watchdogs to stop.

    Ensures the NPU runtime is lazily initialised, calls the native stop
    routine, sets the force-stop exception state, and then sets the watchdog
    status of every process group that includes the NPU device type to
    ``WATCHDOG_STATUS_STOP``.

    Args:
        device_id: Index of the device to stop.

    Returns:
        The value returned by ``torch_npu._C._npu_stopDevice(device_id)``.
    """
    logger.info("stop device start, device_id=%s", device_id)
    torch_npu.npu._lazy_init()
    result = torch_npu._C._npu_stopDevice(device_id)
    _except_handler.set_force_stop_exception(True)
    npu_device = torch.device('npu')
    for pg in _pg_map:
        # Only groups that include the NPU device type have an NPU backend.
        if npu_device not in pg._device_types:
            continue
        pg._get_backend(npu_device).set_watchdog_status(WATCHDOG_STATUS_STOP)
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info(
            "set watchdog status to stop, device_id=%s, group=%s",
            device_id,
            pg,
        )
    logger.info("stop device end, device_id=%s, result=%s", device_id, result)
    return result