import torch
from torch.distributed.distributed_c10d import _pg_map
import torch_npu
import torch_npu._C
from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler
# Watchdog status codes passed to the NPU backend's set_watchdog_status():
# restart_device() sends WATCHDOG_STATUS_RUN, stop_device() sends
# WATCHDOG_STATUS_STOP.
WATCHDOG_STATUS_RUN = 1
WATCHDOG_STATUS_STOP = 2
def check_npu_storage_is_safe(storage_obj):
    """Run the native NPU data-pointer check on a storage object.

    Args:
        storage_obj: A ``torch.storage.TypedStorage`` or
            ``torch.storage.UntypedStorage`` instance.

    Returns:
        Whatever ``torch_npu._C._check_npu_data_ptr`` returns for the storage
        (presumably truthy when the data pointer is still safe to use --
        confirm against the C extension).

    Raises:
        RuntimeError: If ``storage_obj`` is not one of the accepted storage
            types.
    """
    accepted_types = (torch.storage.TypedStorage, torch.storage.UntypedStorage)
    if not isinstance(storage_obj, accepted_types):
        raise RuntimeError(
            f"param type should be TypedStorage or UntypedStorage, could not be {type(storage_obj)}"
            + pta_error(ErrCode.TYPE)
        )
    return torch_npu._C._check_npu_data_ptr(storage_obj)
def check_npu_tensor_is_safe(tensor_obj):
    """Run the NPU storage safety check on the storage backing a tensor.

    Args:
        tensor_obj: A ``torch.Tensor``; its ``untyped_storage()`` is what
            gets checked.

    Returns:
        The result of ``check_npu_storage_is_safe`` for the tensor's
        untyped storage.

    Raises:
        RuntimeError: If ``tensor_obj`` is not a ``torch.Tensor``.
    """
    if not isinstance(tensor_obj, torch.Tensor):
        raise RuntimeError(
            f"param type should be Tensor, could not be {type(tensor_obj)}"
            + pta_error(ErrCode.TYPE)
        )
    backing_storage = tensor_obj.untyped_storage()
    return check_npu_storage_is_safe(backing_storage)
def mark_all_npu_tensor_unsafe(device: int):
    """Forward to the native call that marks all NPU data pointers unsafe.

    Args:
        device: Device index handed unchanged to
            ``torch_npu._C._mark_all_npu_data_ptr_unsafe``.

    Returns:
        Whatever the native call returns.
    """
    mark_unsafe = torch_npu._C._mark_all_npu_data_ptr_unsafe
    return mark_unsafe(device)
def update_npu_storage_to_safe(storage_obj):
    """Run the native NPU data-pointer update on a storage object.

    Args:
        storage_obj: A ``torch.storage.TypedStorage`` or
            ``torch.storage.UntypedStorage`` instance.

    Returns:
        Whatever ``torch_npu._C._update_npu_data_ptr`` returns for the
        storage.

    Raises:
        RuntimeError: If ``storage_obj`` is not one of the accepted storage
            types.
    """
    accepted_types = (torch.storage.TypedStorage, torch.storage.UntypedStorage)
    if not isinstance(storage_obj, accepted_types):
        raise RuntimeError(
            f"param type should be TypedStorage or UntypedStorage, could not be {type(storage_obj)}"
            + pta_error(ErrCode.TYPE)
        )
    return torch_npu._C._update_npu_data_ptr(storage_obj)
def update_npu_tensor_to_safe(tensor_obj):
    """Run the NPU storage update on the storage backing a tensor.

    Args:
        tensor_obj: A ``torch.Tensor``; its ``untyped_storage()`` is what
            gets updated.

    Returns:
        The result of ``update_npu_storage_to_safe`` for the tensor's
        untyped storage.

    Raises:
        RuntimeError: If ``tensor_obj`` is not a ``torch.Tensor``.
    """
    if not isinstance(tensor_obj, torch.Tensor):
        raise RuntimeError(
            f"param type should be Tensor, could not be {type(tensor_obj)}"
            + pta_error(ErrCode.TYPE)
        )
    backing_storage = tensor_obj.untyped_storage()
    return update_npu_storage_to_safe(backing_storage)
def set_npu_tensor_unsafe_check_flag(flag: bool) -> None:
    """Pass ``flag`` to the native NPU data unsafe-flag setter.

    Args:
        flag: Value forwarded unchanged to
            ``torch_npu._C._set_npu_data_unsafe_flag``.

    Returns:
        The native call's return value (annotated as ``None``).
    """
    set_flag = torch_npu._C._set_npu_data_unsafe_flag
    return set_flag(flag)
def get_npu_tensor_unsafe_check_flag() -> bool:
    """Read the native NPU data unsafe flag.

    Returns:
        The value reported by ``torch_npu._C._get_npu_data_unsafe_flag``.
    """
    current_flag = torch_npu._C._get_npu_data_unsafe_flag()
    return current_flag
def _recovery_all_npu_stream(device: int) -> None:
    """Forward to the native call that recovers all NPU streams of a device.

    Args:
        device: Device index handed unchanged to
            ``torch_npu._C._recovery_all_npu_stream``.

    Returns:
        The native call's return value (annotated as ``None``).
    """
    recover_streams = torch_npu._C._recovery_all_npu_stream
    return recover_streams(device)
def restart_device(device_id: int, rebuild_all_resources: bool = False) -> None:
    """Restart an NPU device and switch NPU process-group watchdogs to RUN.

    The steps run in this order:

    1. Lazily initialise ``torch_npu.npu``.
    2. If ``rebuild_all_resources`` is truthy: mark all NPU data pointers on
       the device unsafe, turn the unsafe-check flag on, and recover all NPU
       streams of the device.
    3. Call the native ``_npu_restart_device``.
    4. Clear the force-stop exception state on ``_except_handler``.
    5. For every process group in ``_pg_map`` whose device types include
       ``npu``, clear the backend's work-meta list and set its watchdog
       status to ``WATCHDOG_STATUS_RUN``.

    Args:
        device_id: Index of the device to restart.
        rebuild_all_resources: Flag enabling step 2. Only its truthiness is
            used, so integer ``0``/``1`` from existing callers still works.
    """
    torch_npu.npu._lazy_init()

    if rebuild_all_resources:
        mark_all_npu_tensor_unsafe(device_id)
        set_npu_tensor_unsafe_check_flag(True)
        _recovery_all_npu_stream(device_id)

    torch_npu._C._npu_restart_device(device_id)
    _except_handler.set_force_stop_exception(False)

    npu_device = torch.device('npu')
    for pg in _pg_map:
        if npu_device not in pg._device_types:
            continue
        # Look the backend up once and reuse it for both calls.
        backend = pg._get_backend(npu_device)
        backend.clear_workmeta_list()
        backend.set_watchdog_status(WATCHDOG_STATUS_RUN)
def stop_device(device_id):
    """Stop an NPU device and switch NPU process-group watchdogs to STOP.

    After lazily initialising ``torch_npu.npu``, the native
    ``_npu_stopDevice`` is called, the force-stop exception state on
    ``_except_handler`` is set, and every process group in ``_pg_map`` whose
    device types include ``npu`` has its backend watchdog status set to
    ``WATCHDOG_STATUS_STOP``.

    Args:
        device_id: Index of the device to stop.

    Returns:
        The value returned by ``torch_npu._C._npu_stopDevice``.
    """
    torch_npu.npu._lazy_init()
    stop_result = torch_npu._C._npu_stopDevice(device_id)
    _except_handler.set_force_stop_exception(True)

    npu_kind = torch.device('npu')
    for group in _pg_map:
        if npu_kind not in group._device_types:
            continue
        group._get_backend(npu_kind).set_watchdog_status(WATCHDOG_STATUS_STOP)

    return stop_result