__all__ = [
"is_initialized",
"init",
"set_dump",
"synchronize",
"device_count",
"can_device_access_peer",
"set_device",
"current_device",
"get_device_name",
"get_device_properties",
"get_device_capability",
"mem_get_info",
"is_available",
"device",
"device_of",
"StreamContext",
"stream",
"set_stream",
"current_stream",
"default_stream",
"set_sync_debug_mode",
"get_sync_debug_mode",
"init_dump",
"utilization",
"finalize_dump",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"initial_seed",
"caching_allocator_alloc",
"caching_allocator_delete",
"get_per_process_memory_fraction",
"set_per_process_memory_fraction",
"empty_cache",
"empty_virt_addr_cache",
"memory_stats",
"memory_stats_as_nested_dict",
"reset_accumulated_memory_stats",
"reset_peak_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"memory_allocated",
"max_memory_allocated",
"memory_reserved",
"max_memory_reserved",
"memory_cached",
"max_memory_cached",
"memory_snapshot",
"memory_summary",
"MemPool",
"MemPoolContext",
"use_mem_pool",
"get_allocator_backend",
"NPUPluggableAllocator",
"change_current_allocator",
"Stream",
"Event",
"set_option",
"CubeMathType",
"set_aoe",
"set_compile_mode",
"set_mm_bmm_format_nd",
"get_mm_bmm_format_nd",
"get_npu_overflow_flag",
"clear_npu_overflow_flag",
"stress_detect",
"get_rng_state",
"set_rng_state",
"get_rng_state_all",
"set_rng_state_all",
"is_jit_compile_false",
"get_amp_supported_dtype",
"is_autocast_enabled",
"is_bf16_supported",
"set_autocast_enabled",
"get_autocast_dtype",
"set_autocast_dtype",
"BoolStorage",
"ByteStorage",
"ShortStorage",
"LongStorage",
"IntStorage",
"HalfStorage",
"CharStorage",
"DoubleStorage",
"FloatStorage",
"BoolTensor",
"ByteTensor",
"CharTensor",
"DoubleTensor",
"FloatTensor",
"HalfTensor",
"IntTensor",
"LongTensor",
"ShortTensor",
"BFloat16Tensor",
"BFloat16Storage",
"current_blas_handle",
"stop_device",
"restart_device",
"check_uce_in_memory",
"config",
"matmul",
"conv",
"enable_deterministic_with_backward",
"disable_deterministic_with_backward",
"mstx",
"SyncLaunchStream",
"NPUGraph",
"graph",
"graph_pool_handle",
"is_current_stream_capturing",
"make_graphed_callables",
"ExternalEvent",
"graph_task_group_begin",
"graph_task_group_end",
"graph_task_update_begin",
"graph_task_update_end",
"set_device_limit",
"get_device_limit",
"set_stream_limit",
"reset_stream_limit",
"get_stream_limit",
"ipc_collect",
"obfuscation_initialize",
"obfuscation_finalize",
"obfuscation_calculate",
"set_op_timeout_ms",
"host_empty_cache",
"host_memory_stats",
"host_memory_stats_as_nested_dict",
"reset_accumulated_host_memory_stats",
"reset_peak_host_memory_stats",
"set_deterministic_level",
"use_compatible_impl",
"are_compatible_impl_enabled",
"NpuGraphOpHandler",
"register_npu_graph_handler",
"super_kernel_scope_begin",
"super_kernel_scope_end",
]
from typing import Tuple, Union, List, cast, Optional
from multiprocessing.util import register_after_fork as _register_after_fork
import traceback
import threading
import os
import re
import torch
from torch.storage import _LegacyStorage, _warn_typed_storage_removal
from torch._utils import classproperty
from torch_npu._init.common.warning_utils import _should_print_warning
import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error, prof_error
from .utils import (obfuscation_initialize, obfuscation_calculate, obfuscation_finalize,
synchronize, set_device, current_device, _get_device_index,
device, device_of, StreamContext, stream, set_stream, current_stream, default_stream, set_sync_debug_mode,
get_sync_debug_mode, init_dump, current_blas_handle, is_bf16_supported,
finalize_dump, set_dump, get_npu_overflow_flag, clear_npu_overflow_flag,
check_uce_in_memory, stress_detect, _get_uce_addr, ipc_collect, set_op_timeout_ms)
from ._recovery import restart_device, stop_device
from .streams import Stream, Event, SyncLaunchStream, ExternalEvent
from .mstx import mstx
from .npu_config import *
from .autocast_utils import *
from .backends import *
from ._backends import *
from .deterministic import enable_deterministic_with_backward, disable_deterministic_with_backward
from . import npugraph_ex
from .graphs import (
NPUGraph,
graph,
graph_pool_handle,
is_current_stream_capturing,
make_graphed_callables,
graph_task_group_begin,
graph_task_group_end,
graph_task_update_begin,
graph_task_update_end,
super_kernel_scope_begin,
super_kernel_scope_end,
)
from ._npugraph_handlers import (
NpuGraphOpHandler,
register_npu_graph_handler,
)
config = npu_config._npuConfig()
matmul = npu_config._allowHF32Matmul()
conv = npu_config._allowHF32Conv()
CubeMathType = npu_config._CubeMathType
default_generators: Tuple[torch._C.Generator] = ()
_is_internal_in_bad_fork = False
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = []
_original_pid = False
def _is_in_bad_fork():
return _is_internal_in_bad_fork
def is_initialized():
r"""Returns whether PyTorch's NPU state has been initialized."""
return _initialized and not _is_internal_in_bad_fork
def _lazy_call(cb):
if _initialized:
cb()
else:
_queued_calls.append((cb, traceback.format_stack()))
class _DeferredNpuCallError(Exception):
pass
def init():
r"""Initialize PyTorch's NPU state. You may need to call
this explicitly if you are interacting with PyTorch via
its C API, as Python bindings for NPU functionality will not
be until this initialization takes place. Ordinary users
should not need this, as all of PyTorch's NPU methods
automatically initialize NPU state on-demand.
Does nothing if the NPU state is already initialized.
"""
torch_npu.npu._lazy_init()
def _lazy_init():
def _queue_call(queued_calls):
for queued_call, orig_traceback in queued_calls:
try:
queued_call()
except Exception as e:
msg = (f"NPU call failed lazily at initialization with error: {str(e)}\n\n"
f"NPU call was originally invoked at:\n\n{orig_traceback}")
raise DeferredNpuCallError(msg + pta_error(ErrCode.INTERNAL)) from e
global _initialized, _original_pid, _queued_calls
if _initialized or hasattr(_tls, 'is_initializing'):
return
with _initialization_lock:
if _initialized:
return
if _is_internal_in_bad_fork:
raise RuntimeError(
"Cannot re-initialize NPU in forked subprocess. To use NPU with "
"multiprocessing, you must use the 'spawn' start method")
torch_npu._C._npu_init()
_original_pid = os.getpid()
_tls.is_initializing = True
try:
_queue_call(_queued_calls)
finally:
delattr(_tls, 'is_initializing')
_initialized = True
def _after_fork(arg):
global _initialized, _is_internal_in_bad_fork
if _initialized and _original_pid != os.getpid():
_initialized = False
_is_internal_in_bad_fork = True
torch_npu._C._npu_set_run_yet_variable_to_false()
_register_after_fork(_after_fork, _after_fork)
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
r"""Return the torch.device type object from the passed in device.
Args:
device (torch.device or int): selected device.
"""
if isinstance(device, str):
return torch.device(device)
elif isinstance(device, int):
return torch.device('npu', device)
return device
def _get_generator(device: torch.device) -> torch._C.Generator:
r"""Return the NPU Generator object for the given device.
Args:
device (torch.device): selected device.
"""
idx = device.index
if idx is None:
idx = current_device()
return torch.npu.default_generators[idx]
def _set_rng_state_offset(offset: int, device: Union[int, str, torch.device] = 'npu') -> None:
r"""Sets the random number generator state offset of the specified NPU.
Args:
offset (int): The desired offset
device (torch.device or int, optional): The device to set the RNG state.
Default: ``'npu'`` (i.e., ``torch.device('npu')``, the current NPU device).
"""
final_device = _get_device(device)
def cb():
default_generator = _get_generator(final_device)
default_generator.set_offset(offset)
_lazy_call(cb)
def _get_rng_state_offset(device: Union[int, str, torch.device] = 'npu') -> int:
r"""Returns the random number generator state offset of the specified NPU.
Args:
device (torch.device or int, optional): The device to return the RNG state offset of.
Default: ``'npu'`` (i.e., ``torch.device('npu')``, the current NPU device).
.. warning::
This function eagerly initializes NPU.
"""
_lazy_init()
final_device = _get_device(device)
default_generator = _get_generator(final_device)
return default_generator.get_offset()
def is_available():
if (not hasattr(torch_npu._C, '_npu_setDevice')):
return False
return device_count() > 0
def _parse_visible_devices() -> Union[List[int], List[str]]:
r"""Parse ASCEND_RT_VISIBLE_DEVICES environment variable."""
var = os.getenv("ASCEND_RT_VISIBLE_DEVICES")
if var is None:
return list(range(64))
rc: List[int] = []
if not var:
return rc
for idx, c in enumerate(var):
if not (c.isdigit() or c == ","):
break
if idx + 1 == len(var):
idx += 1
for elem in var[:idx].split(","):
if not elem:
return rc
x = int(elem)
rc.append(x)
return rc
def _raw_device_count_ascend_hal() -> int:
r"""Return number of devices as reported by ascend_hal or negative value if ascend_hal discovery/initialization failed."""
from ctypes import byref, c_int, CDLL
ascend_hal_h = CDLL("libascend_hal.so")
dev_count = c_int(-1)
rc = ascend_hal_h.drvGetDevNum(byref(dev_count))
if rc != 0:
warnings.warn("Can't get ascend_hal device count")
return -1
del ascend_hal_h
return dev_count.value
def _device_count_ascend_hal() -> int:
r"""Return number of devices as reported by ascend_hal taking ASCEND_RT_VISIBLE_DEVICES into account.
Negative value is returned if ascend_hal discovery or initialization has failed.
"""
visible_devices = _parse_visible_devices()
if not visible_devices:
return 0
try:
raw_cnt = _raw_device_count_ascend_hal()
if raw_cnt <= 0:
return raw_cnt
for idx, val in enumerate(visible_devices):
if idx > 0 and val <= visible_devices[idx - 1]:
return 0
if cast(int, val) >= raw_cnt:
return idx
except OSError:
return -1
except AttributeError:
return -1
return len(visible_devices)
_cached_device_count: Optional[int] = None
def device_count() -> int:
r"""Return the number of NPUs available."""
global _cached_device_count
if _cached_device_count is not None:
return _cached_device_count
ascend_hal_count = _device_count_ascend_hal()
r = torch_npu._C._npu_getDeviceCount() if ascend_hal_count < 0 else ascend_hal_count
if _initialized:
_cached_device_count = r
return r
def _launch_host_func(op_stream, fn, user_data):
torch_npu._C._launch_host_func(op_stream, fn, user_data)
def _launch_host_func_pending(op_stream, fn, user_data):
torch_npu._C._launch_host_func_pending(op_stream, fn, user_data)
def _subscribe_report(op_stream):
torch_npu._C._subscribe_report(op_stream)
def _unsubscribe_report(op_stream):
torch_npu._C._unsubscribe_report(op_stream)
def can_device_access_peer(device_id, peer_device_id):
r"""Checks if peer access between two devices is possible.
"""
device_id = _get_device_index(device_id, optional=True)
peer_device_id = _get_device_index(peer_device_id, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid devide id" + pta_error(ErrCode.VALUE))
if peer_device_id < 0 or peer_device_id >= device_count():
raise AssertionError("Invalid peer devide id" + pta_error(ErrCode.VALUE))
return torch_npu._C._npu_canDeviceAccessPeer(device_id, peer_device_id)
def get_device_name(device_name=None):
device_id = _get_device_index(device_name, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
return torch_npu._C._npu_getDeviceName()
def get_device_properties(device_name=None):
device_id = _get_device_index(device_name, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
torch_npu.npu._lazy_init()
return torch_npu._C._npu_getDeviceProperties(device_id)
def mem_get_info(device=None):
if device is None:
device = torch_npu.npu.current_device()
device_id = _get_device_index(device, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
torch_npu.npu._lazy_init()
device_prop = torch_npu._C._npu_getDeviceMemories(device_id)
return device_prop.free_memory, device_prop.total_memory
_cached_device_capability = None
_cached_device_capability_env = None
def get_device_capability(device=None):
r"""Query the minor and major data of device.
This function can be configured via the TORCH_NPU_DEVICE_CAPABILITY environment variable.
The format should be "major.minor", e.g., "9.0" or "8.0".
.. note::
The return value of get_device_capability is only for compatibility with PyTorch and does not represent the actual capability of the NPU device.
Args:
device (torch.device or int, optional): The device parameter has no practical meaning.
Returns:
tuple(int, int): the device capability of the device. Returns the tuple(major, minor) configured via
TORCH_NPU_DEVICE_CAPABILITY, or None if TORCH_NPU_DEVICE_CAPABILITY not configured.
Example:
>>> print(torch_npu.npu.get_device_capability())
None
>>> os.environ['TORCH_NPU_DEVICE_CAPABILITY'] = '8.0'
>>> print(torch_npu.npu.get_device_capability())
(8, 0)
>>> print(torch_npu.npu.get_device_capability(0))
(8, 0)
"""
global _cached_device_capability, _cached_device_capability_env
capability_env = os.getenv("TORCH_NPU_DEVICE_CAPABILITY")
warning_str = "The return value of get_device_capability is only for compatibility with PyTorch and does not represent the actual capability of the NPU device."
if not capability_env:
warnings.warn(f"You can set the device capability via the environment variable TORCH_NPU_DEVICE_CAPABILITY. {warning_str}")
return None
if _cached_device_capability_env == capability_env and _cached_device_capability is not None:
return _cached_device_capability
pattern = r'^(\d+)\.(\d+)$'
match = re.match(pattern, capability_env)
if not match:
warnings.warn(
f"torch.npu.get_device_capability can't parse TORCH_NPU_DEVICE_CAPABILITY environment variable: {capability_env}. "
"Expected format is 'major.minor' where major and minor are non-negative integers (e.g., '8.0')."
)
return None
warnings.warn(warning_str)
major_str, minor_str = match.groups()
major = int(major_str)
minor = int(minor_str)
_cached_device_capability = (major, minor)
_cached_device_capability_env = capability_env
return _cached_device_capability
def utilization(device=None):
r"""Query the comprehensive utilization rate of device
"""
device_id = _get_device_index(device, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
torch_npu.npu._lazy_init()
return torch_npu._C._npu_getDeviceUtilizationRate(device_id)
def _aclnn_reselect_static_kernel():
torch_npu.npu._lazy_init()
torch_npu._C._aclnn_reselect_static_kernel()
from .random import *
from .memory import *
@staticmethod
def _lazy_new(cls, *args, **kwargs):
_lazy_init()
return super(_NPUBase, cls).__new__(cls, *args, **kwargs)
def _comm_switch_nic(ranks, useBackup):
torch_npu.npu.synchronize()
return torch_npu.distributed.distributed_c10d._comm_switch_nic(ranks, useBackup)
def set_deterministic_level(level):
warnings.warn("After using 'torch_npu.npu.set_deterministic_level', "
"please do not use 'torch.use_deterministic_algorithms' anymore, "
"as it may cause unknown errors.")
if level == 0 and torch.are_deterministic_algorithms_enabled():
warnings.warn("The current configuration value of 'torch_npu.npu.set_deterministic_level' "
"conflicts with 'torch.use_deterministic_algorithms'. "
"'torch.use_deterministic_algorithms' has been configured to 'False'")
torch.use_deterministic_algorithms(False)
elif level >= 1 and not torch.are_deterministic_algorithms_enabled():
warnings.warn("The current configuration value of 'torch_npu.npu.set_deterministic_level' "
"conflicts with 'torch.use_deterministic_algorithms'. "
"'torch.use_deterministic_algorithms' has been configured to 'True'")
torch.use_deterministic_algorithms(True)
torch_npu._C._npu_set_deterministic_level(level)
def _get_deterministic_level():
level = torch_npu._C._npu_get_deterministic_level()
if level == 0 and torch.are_deterministic_algorithms_enabled():
level = 1
torch_npu.npu.set_deterministic_level(level)
return level
if level >= 1 and not torch.are_deterministic_algorithms_enabled():
level = 0
torch_npu.npu.set_deterministic_level(level)
return level
return level
def use_compatible_impl(is_enable):
option = {"TORCH_NPU_USE_COMPATIBLE_IMPL": "1" if is_enable else "0"}
torch_npu._C._npu_setOption(option)
def are_compatible_impl_enabled():
compatible_value = torch_npu._C._npu_getOption("TORCH_NPU_USE_COMPATIBLE_IMPL")
return compatible_value is not None and compatible_value.decode() == "1"
class _NPUBase:
is_npu = True
is_sparse = False
def type(self, *args, **kwargs):
with device(self.get_device()):
return super().type(*args, **kwargs)
__new__ = _lazy_new
class _NPULegacyStorage(_LegacyStorage):
@classmethod
def from_buffer(cls, *args, **kwargs):
_warn_typed_storage_removal()
raise RuntimeError('from_buffer: Not available for NPU storage' + pta_error(ErrCode.UNAVAIL))
@classmethod
def _new_with_weak_ptr(cls, *args, **kwargs):
raise RuntimeError('_new_with_weak_ptr: Not available for NPU storage' + pta_error(ErrCode.UNAVAIL))
@classmethod
def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
raise RuntimeError('_new_shared_filename: Not available for NPU storage' + pta_error(ErrCode.UNAVAIL))
class ByteStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.uint8
class DoubleStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.double
class FloatStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.float
class HalfStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.half
class LongStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.long
class IntStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.int
class ShortStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.short
class CharStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.int8
class BoolStorage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.bool
class BFloat16Storage(_NPULegacyStorage):
@classproperty
def dtype(self):
if _should_print_warning():
_warn_typed_storage_removal()
return self._dtype
@classproperty
def _dtype(self):
return torch.bfloat16
del _LegacyStorage
del _NPULegacyStorage
torch._storage_classes.add(DoubleStorage)
torch._storage_classes.add(FloatStorage)
torch._storage_classes.add(LongStorage)
torch._storage_classes.add(IntStorage)
torch._storage_classes.add(ShortStorage)
torch._storage_classes.add(CharStorage)
torch._storage_classes.add(ByteStorage)
torch._storage_classes.add(HalfStorage)
torch._storage_classes.add(BoolStorage)
torch._storage_classes.add(BFloat16Storage)