import os
import sys
import types
import pathlib
from itertools import chain
from functools import wraps, lru_cache
import torch
import torch_npu
API_LIST = [
'_npu_matmul_add_fp32',
'_npu_quant_rms_norm',
'_npu_group_topk',
'_npu_paged_attention',
'_npu_paged_attention_mla',
'_npu_paged_attention_quant',
'_npu_quantize_per_tensor',
'_npu_reshape_and_cache',
'_npu_reshape_and_cache_siso',
'_npu_rotary_embedding',
'_npu_flash_attention',
'_npu_flash_attention_unpad',
'_npu_paged_attention_splitfuse',
'_npu_flash_attention_qlens',
'_npu_paged_attention_get_workspace',
]
ATB_API_LIST = [
'npu_paged_cache_load',
'npu_multi_head_latent_attention',
'_npu_multi_head_latent_attention_get_workspace',
'_npu_paged_attention_v2',
'_npu_paged_attention_v2_get_workspace',
'_npu_flash_attention_v2',
'_npu_flash_attention_prefix_v2',
'npu_fused_add_topk_div',
'npu_ring_mla',
'npu_self_attention_prefix_encoder',
'npu_mla_preprocess',
]
ATB_MODULE_NAME = 'atb'
ATB_MODULE = types.ModuleType(ATB_MODULE_NAME)
def _add_atb_module():
setattr(torch_npu, ATB_MODULE_NAME, ATB_MODULE)
sys.modules[f'torch_npu.{ATB_MODULE_NAME}'] = ATB_MODULE
_add_atb_module()
NNAL_EX = None
GLOBAL_E = None
try:
npu_path = pathlib.Path(__file__).parents[2]
atb_so_path = os.path.join(npu_path, 'lib', 'libop_plugin_atb.so')
from torch_npu.utils._path_manager import PathManager
PathManager.check_directory_path_readable(atb_so_path)
torch.ops.load_library(atb_so_path)
import torch_npu.op_plugin.atb._atb_meta_registrations
except OSError as e:
nnal_strerror = ""
if "libatb.so" in str(e):
nnal_strerror = "Please check that the nnal package is installed. "\
"Please run 'source set_env.sh' in the NNAL installation path."
if "undefined symbol" in str(e):
nnal_strerror = "Please check the version of the NNAL package. "\
"An undefined symbol was found, "\
"which may be caused by a version mismatch between NNAL and torch_npu."
NNAL_EX = OSError(e.errno, nnal_strerror)
NNAL_EX.__traceback__ = e.__traceback__
GLOBAL_E = e
@lru_cache(None)
def _register_atb_extensions():
global NNAL_EX, GLOBAL_E
if NNAL_EX is not None:
raise NNAL_EX from GLOBAL_E
_patch_atb_ops()
from torch_npu.op_plugin.atb._atb_api_docs import _add_torch_npu_atb_api_docstr
if not torch.compiler.is_compiling():
_add_torch_npu_atb_api_docstr()
def lazy_load_atb_so(api_func):
@wraps(api_func)
def wrapper(*args, **kwargs):
_register_atb_extensions()
return api_func(*args, **kwargs)
return wrapper
def create_lazy_atb_function(api_name):
@lazy_load_atb_so
def generated_function(*args, **kwargs):
return getattr(torch.ops.atb, api_name)(*args, **kwargs)
generated_function.__name__ = api_name
return generated_function
def generate_atb_lazy_function():
for api_name in chain(API_LIST, ATB_API_LIST):
globals()[api_name] = create_lazy_atb_function(api_name)
generate_atb_lazy_function()
def _patch_atb_ops():
for api_name in API_LIST:
setattr(torch_npu, api_name, getattr(torch.ops.atb, api_name))
for api_name in ATB_API_LIST:
setattr(ATB_MODULE, api_name, getattr(torch.ops.atb, api_name))
def _patch_atb_and_loadso():
for api_name in API_LIST:
func = globals().get(api_name)
setattr(torch_npu, api_name, func)
for api_name in ATB_API_LIST:
func = globals().get(api_name)
setattr(ATB_MODULE, api_name, func)