import os
import re
from functools import wraps
import torch
import torch_npu
from torch_npu.utils._error_code import ErrCode, pta_error
from .unsupport_api import unsupported_Tensor_api, unsupported_nn_api
from .collect_env import get_cann_version
# CANN toolkit version -> torch_npu versions listed as compatible with it.
# _cann_package_check compares torch_npu.__version__ against these lists and
# only warns for CANN versions that appear as keys; other versions are not checked.
cann_pytorch_version_map = {
    "6.3.RC2": ["1.8.1.post2", "1.11.0.post1", "2.0.0.rc1"],
    "6.3.RC1": ["1.8.1.post1", "1.11.0"],
    "6.1.RC1": ["1.8.1.post1", "1.11.0"],
    "6.0.1": ["1.8.1", "1.11.0.rc2"],
    "6.0.RC1": ["1.8.1", "1.11.0.rc1"],
}

# Export nothing through ``from <module> import *``.
__all__ = []
def _cann_install_hint(package):
    """Return the shared advice appended to errors about a missing CANN ``package``."""
    return (f"Please check whether the {package} package has been installed. If exist, please run "
            "'source set_env.sh' in the CANN installation path.")


def _cann_not_found_error(message):
    """Build the generic ``Exception`` used for a missing CANN component (tagged NOT_FOUND)."""
    return Exception(message + pta_error(ErrCode.NOT_FOUND))


def _check_cann_path_exists(label, path, package):
    """Raise the NOT_FOUND ``Exception`` if ``path`` (reported as ``label``) does not exist."""
    if not os.path.exists(path):
        raise _cann_not_found_error(f"{label} : {path} does not exist. " + _cann_install_hint(package))


def _cann_package_check():
    """Validate the CANN installation referenced by the environment.

    If ``ASCEND_HOME_PATH`` is not set, only print a warning and return.
    Otherwise verify, in order, that:

    * ``ASCEND_HOME_PATH`` exists,
    * ``ASCEND_OPP_PATH`` is set and exists,
    * ``<ASCEND_HOME_PATH>/runtime`` and ``<ASCEND_HOME_PATH>/compiler`` exist.

    Finally, print a warning when the CANN version from ``get_cann_version()``
    is a key of ``cann_pytorch_version_map`` but ``torch_npu.__version__`` is
    not in its list.

    Raises:
        Exception: if any checked path is missing or ``ASCEND_OPP_PATH`` is not
            set. The message ends with ``pta_error(ErrCode.NOT_FOUND)``.
    """
    if "ASCEND_HOME_PATH" not in os.environ:
        print("Warning : ASCEND_HOME_PATH environment variable is not set.")
        return

    ascend_home_path = os.environ["ASCEND_HOME_PATH"]
    if not os.path.exists(ascend_home_path):
        raise _cann_not_found_error(
            f"ASCEND_HOME_PATH : {ascend_home_path} does not exist. "
            "Please run 'source set_env.sh' in the CANN installation path.")

    if "ASCEND_OPP_PATH" not in os.environ:
        raise _cann_not_found_error(
            "ASCEND_OPP_PATH environment variable is not set. " + _cann_install_hint("opp"))
    _check_cann_path_exists("ASCEND_OPP_PATH", os.environ["ASCEND_OPP_PATH"], "opp")

    # "ASCEND_RUNTIME_PATH" / "ASCEND_COMPILER_PATH" are only labels for the
    # error message: both locations are derived from ASCEND_HOME_PATH, not
    # read from environment variables.
    _check_cann_path_exists("ASCEND_RUNTIME_PATH", os.path.join(ascend_home_path, "runtime"), "runtime")
    _check_cann_path_exists("ASCEND_COMPILER_PATH", os.path.join(ascend_home_path, "compiler"), "compiler")

    cann_version = get_cann_version()
    # Only CANN versions listed in the map are checked; unknown ones pass silently.
    if cann_version in cann_pytorch_version_map and \
            torch_npu.__version__ not in cann_pytorch_version_map[cann_version]:
        print(f"Warning : CANN package version {cann_version} and PyTorch version {torch_npu.__version__} "
              "is not matched, please check the README of the ascend pytorch repo.")
def _create_wrap_func(check_func):
    """Return a decorator that makes ``func`` refuse calls flagged by ``check_func``.

    Every call to the decorated function first forwards the same ``*args`` and
    ``**kwargs`` to ``check_func``. A truthy result raises ``RuntimeError``
    (the message ends with ``pta_error(ErrCode.NOT_SUPPORT)``). Otherwise the
    call is delegated to ``func`` and its result is returned unchanged. Wrapper
    metadata is copied from ``func`` with ``functools.wraps``.
    """
    def decorator(func):
        def guarded(*args, **kwargs):
            if not check_func(*args, **kwargs):
                return func(*args, **kwargs)
            message = f"{str(func)} is not supported in npu."
            raise RuntimeError(message + pta_error(ErrCode.NOT_SUPPORT))

        return wraps(func)(guarded)

    return decorator
def _is_tensor_npu_supported(*args, **kwargs):
    """Report whether the first positional argument is a tensor whose ``is_npu`` is set.

    Used as the ``check_func`` for the ``unsupported_Tensor_api`` wrappers.
    ``kwargs`` is accepted only so the signature matches the wrapped call and
    is otherwise ignored. A non-tensor first argument yields ``False``. With no
    positional arguments, ``IndexError`` is raised.
    """
    candidate = args[0]
    if not torch.is_tensor(candidate):
        return False
    return candidate.is_npu
def _is_module_parameters_supported(*args, **kwargs):
    """Return True when the first positional ``nn.Module`` has a parameter on an ``npu`` device.

    Used as the ``check_func`` for the ``unsupported_nn_api`` wrappers, where a
    truthy result makes the wrapper raise instead of calling the real function.
    Only the first positional ``torch.nn.Module`` is inspected; ``kwargs`` is
    ignored. If no positional module is passed, there is nothing to inspect and
    the call is allowed (returns ``False``) instead of crashing on an empty list.
    """
    module = next(
        (m for m in args if isinstance(m, torch.nn.Module) and hasattr(m, "_modules")),
        None,
    )
    if module is None:
        return False
    # Lazily scan parameters; stop at the first one that lives on an NPU.
    return any(
        p.device is not None and p.device.type == "npu"
        for _, p in module.named_parameters()
    )
def _apply_wrap_func_to_modules(wrap_func, unsupported_modules):
    """Replace each listed attribute in place with ``wrap_func(<current attribute>)``.

    ``unsupported_modules`` maps an attribute name to the object that owns that
    attribute. Entries are processed in the mapping's iteration order. Running
    this twice on the same mapping wraps each attribute twice.
    """
    for name in unsupported_modules:
        owner = unsupported_modules[name]
        current = getattr(owner, name)
        setattr(owner, name, wrap_func(current))
def _add_intercept_methods():
    """Install the "not supported in npu" guards on the APIs in both tables.

    Attributes listed in ``unsupported_Tensor_api`` are guarded with
    ``_is_tensor_npu_supported``, and those in ``unsupported_nn_api`` with
    ``_is_module_parameters_supported``. The Tensor table is processed first.
    Each listed attribute is replaced in place on its owning object.
    """
    guard_plan = (
        (_is_tensor_npu_supported, unsupported_Tensor_api),
        (_is_module_parameters_supported, unsupported_nn_api),
    )
    for checker, api_table in guard_plan:
        _apply_wrap_func_to_modules(_create_wrap_func(checker), api_table)