import os
import sys
import time
import warnings
import importlib
from torch._dynamo import register_backend as _register_backend
from torch._dynamo.backends.registry import _BACKENDS
from torch.library import Library, impl
from torch_npu.utils._error_code import ErrCode, pta_error
from torch_npu.utils.utils import _should_print_warning
from .trace_rule import _patch_npu_trace_rules
# Cache of torchair NPU backends keyed by backend name; filled in lazily by
# ``_get_global_npu_backend``.
_global_npu_backend = dict()
# Nothing in this module is exported via ``from ... import *``.
__all__ = []
class _TorchairImportError(Exception):
    """User-facing error for a torchair import failure that is re-raised later.

    ``_LazyException`` raises an instance of this class chained (``from``) to
    the original import exception. ``str()`` yields a fixed explanatory text
    followed by the PTA internal error code.
    """

    def __init__(self):
        self.err_info = (
            "\nAn error occurred when importing `torchair` and the above is the specific error message. \n"
            "This error message was generated when importing torchair, but thrown asynchronously here. \n"
            "Please check the error message above. \n") + pta_error(ErrCode.INTERNAL)
        # Pass the message, not ``self``, as the exception argument. With
        # ``self`` in ``args``, ``repr()`` recursed without bound (it reprs
        # ``args[0]``) and the exception referenced itself.
        super().__init__(self.err_info)

    def __str__(self):
        return self.err_info
class _LazyException:
    """Placeholder that re-raises a deferred torchair import failure.

    Calling the instance, or reading any attribute that normal lookup does not
    find, raises the stored ``_TorchairImportError`` with the original
    exception ``e`` attached as its ``__cause__``.
    """

    def __init__(self, e):
        self._info = _TorchairImportError()
        self._e = e

    def __call__(self, *args, **kwargs):
        root_cause = self._e
        raise self._info from root_cause

    def __getattr__(self, name):
        # Only reached for names missing from the instance and class; treat
        # such an access exactly like calling the placeholder.
        return self.__call__()
def _eager_npu_backend(gm, *args, **kwargs):
    """Pass-through backend: hand back the graph module ``gm`` untouched.

    Extra positional and keyword arguments are accepted and ignored.
    """
    del args, kwargs  # unused; only present to accept any backend call signature
    return gm
def _get_global_npu_backend(name, config=None):
    """Return the torchair backend cached under ``name``, creating it on first use.

    Args:
        name: Cache key in ``_global_npu_backend`` (this file uses ``"npu"``
            and ``"npugraph_ex"``).
        config: Passed as ``compiler_config`` to ``torchair.get_npu_backend``
            only when the backend is created. It is ignored once ``name`` is
            already cached.

    Returns:
        The backend object produced by ``torchair.get_npu_backend``.

    Raises:
        AssertionError: if ``torchair`` is missing from ``sys.modules`` at
            creation time.
    """
    # The module-level cache is only read and mutated, never rebound, so no
    # ``global`` declaration is required.
    if name in _global_npu_backend:
        return _global_npu_backend[name]
    # A ``torchair`` entry in sys.modules is expected (``_get_default_backend``
    # installs a lazy placeholder there). Fail loudly if it has been removed.
    if 'torchair' not in sys.modules:
        raise AssertionError("Could not find module torchair. "
                             "Please check if torchair is removed from sys.modules." + pta_error(ErrCode.NOT_FOUND))
    import torchair
    backend = torchair.get_npu_backend(compiler_config=config)
    _global_npu_backend[name] = backend
    return backend
class _LazyTorchair:
    """Placeholder for the ``torchair`` module that imports the bundled package on demand.

    ``_get_default_backend`` stores an instance in ``sys.modules['torchair']``.
    Before initialization, only ``__spec__`` and ``__path__`` may trigger the
    real ``from . import torchair``. Presumably this is because the import
    system reads these attributes on ``import torchair`` — confirm. Any other
    name raises ``AttributeError``. After a successful import, every attribute
    is forwarded to the real module. If the import fails, the original error
    propagates once, and later accesses go through a ``_LazyException``.
    """

    def __init__(self):
        self._torchair = None
        self._exception = None
        self._allowed_list = ["__spec__", "__path__"]

    def __getattr__(self, name):
        # An earlier import attempt failed: raise the deferred error.
        if self._exception is not None:
            return self._exception()

        module = self._torchair
        if module is None:
            if name not in self._allowed_list:
                raise AttributeError(f"Try to get torchair's attr `{name}` before torchair is initialized."
                                     + self._pta_error_code())
            try:
                from . import torchair as module
            except Exception as e:
                # Remember the failure so later accesses raise a wrapped error.
                self._exception = _LazyException(e)
                raise
            self._torchair = module
        return getattr(module, name)

    def _pta_error_code(self):
        """Build an ``ERR00005`` PTA-internal-error suffix stamped with local time and PID."""
        timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
        return (f"\n[ERROR] {timestamp} (PID:{os.getpid()}, Device:-1, RankID:-1) "
                "ERR00005 PTA internal error")
def _get_default_backend(name):
    """Pick the dynamo backend callable to register under ``name``.

    If there is no ``torchair`` directory next to this file (torch_npu built
    without torchair), optionally warn and fall back to ``_eager_npu_backend``.
    Otherwise, store a ``_LazyTorchair`` placeholder in
    ``sys.modules['torchair']`` and return a wrapper. On each call, the wrapper
    fetches (creating on first use) the cached torchair backend for ``name``.
    """
    torchair_dir = os.path.join(os.path.dirname(__file__), 'torchair')
    if os.path.exists(torchair_dir):
        def _lazy_exec(*args, **kwargs):
            # Defer torchair work until dynamo actually invokes the backend.
            return _get_global_npu_backend(name)(*args, **kwargs)

        sys.modules['torchair'] = _LazyTorchair()
        return _lazy_exec

    if _should_print_warning():
        warnings.warn(
            "Register eager implementation for the 'npu' backend of dynamo, "
            "as torch_npu was not compiled with torchair.")
    return _eager_npu_backend
def _get_npugraph_ex_backend():
    """Return the dynamo backend callable for ``npugraph_ex``.

    The callable imports ``torchair`` only when invoked. It builds a
    ``CompilerConfig`` with ``mode = "reduce-overhead"`` and dispatches to the
    backend cached under ``"npugraph_ex"``. The config only takes effect when
    that cache entry is first created.
    """
    cache_key = "npugraph_ex"

    def _exec(*args, **kwargs):
        import torchair
        graph_config = torchair.CompilerConfig()
        graph_config.mode = "reduce-overhead"
        backend = _get_global_npu_backend(cache_key, graph_config)
        return backend(*args, **kwargs)

    return _exec
# Build both backend callables at import time. ``_get_default_backend`` may also
# install the lazy ``torchair`` placeholder in ``sys.modules`` as a side effect.
_global_backend = _get_default_backend("npu")
_global_npugraph_ex_backend = _get_npugraph_ex_backend()
def _register_npu_backend(backend, name="npu"):
    """Register ``backend`` with dynamo under ``name``, replacing any prior entry.

    Any existing ``_BACKENDS`` entry for ``name`` is dropped first. Presumably
    this keeps re-registration of an already-known name from tripping dynamo's
    duplicate-name check.
    """
    # NOTE(review): only ``_BACKENDS`` is cleared here. Confirm it is the
    # registry ``register_backend`` checks for duplicates in the targeted torch version.
    _BACKENDS.pop(name, None)
    _register_backend(backend, name)
# Make both callables available to dynamo under the names "npu" and "npugraph_ex".
_register_npu_backend(_global_backend, name="npu")
_register_npu_backend(_global_npugraph_ex_backend, name="npugraph_ex")