"""
gen tiling head file for dynamic shape
"""
import ctypes
import json
import os
import stat
import struct
import math
import sys
import subprocess
import re
from pathlib import Path
from collections import namedtuple
import asc_op_compile_base
from asc_op_compile_base.common.platform import get_soc_spec
from asc_op_compile_base.asc_op_compiler.op_tiling import do_op_tiling, _ASCEND_OPP_PATH_ENV, _ASCEND_OPP_PATH_DEFAULT, \
_BUILTIN_TILING_PATH, _CUSTOM_TILING_PATH_DEFAULT, so_arch_path2
from asc_op_compile_base.common.error_mgr import raise_tbe_python_err, TBE_DEFAULT_PYTHON_ERROR_CODE
import asc_op_compile_base.asc_op_compiler.get_op_tiling as tiling_help
OpInfo = namedtuple('OpInfo', ['kernel_name', 'op_type', 'inputs', 'outputs', 'attrs', 'impl_mode', 'origin_inputs',\
'origin_outputs', 'param_type_dynamic', 'mc2_ctx', 'param_type_list', 'init_value_list',\
'output_shape_depend_on_compute'])
OpInfo.__new__.__defaults__ = (None, None, None, None, None, None, None, None, None, None, None, None, None)
DEFAULT_TILING_KEY_VALUE = 0
_ASCEND_CUSTOM_OPP_PATH_ENV = "ASCEND_CUSTOM_OPP_PATH"
_TILING_SO_PATH = "op_impl/ai_core/tbe/op_tiling/liboptiling.so"
opp_dir = os.environ.get(_ASCEND_OPP_PATH_ENV, _ASCEND_OPP_PATH_DEFAULT)
config_default = os.path.join(opp_dir, "vendors", "config.ini")
def get_default_optiling_pathlist():
vendor_list = []
full_path_list = []
if not os.path.exists(config_default):
return full_path_list
with open(config_default) as f:
vdr_lst = f.readline().split("=")[1].split(",")
for vdr in vdr_lst:
_vendor = vdr.strip()
if _vendor not in vendor_list:
vendor_list.append(_vendor)
full_path = os.path.join(opp_dir, "vendors", _vendor)
full_path_list.append(full_path)
return full_path_list
def get_custom_opp_pathlist():
vendor_list = []
custom_opp_dir = os.environ.get(_ASCEND_CUSTOM_OPP_PATH_ENV)
if custom_opp_dir is None:
return vendor_list
_path_lst = str(custom_opp_dir).split(":")
for _path in _path_lst:
local_path = _path.strip()
if len(local_path) != 0 and local_path not in vendor_list:
vendor_list.append(local_path)
return vendor_list
def load_lib_v2(hight_priority_path:str = None):
opp_path = Path(os.environ.get(_ASCEND_OPP_PATH_ENV, _ASCEND_OPP_PATH_DEFAULT))
builtin_optiling_lib_path = opp_path.joinpath(_BUILTIN_TILING_PATH)
libregister = ctypes.CDLL("libregister.so")
builtin_optiling_lib_path2 = opp_path.joinpath(so_arch_path2)
default_lst = get_default_optiling_pathlist()
if hight_priority_path is not None:
try:
custom_opp_so_path = hight_priority_path
lib_optiling = ctypes.CDLL(custom_opp_so_path)
custom_opp_so_path_str = str(custom_opp_so_path)
lib_optiling.TbeLoadSoAndSaveToRegistry(custom_opp_so_path_str.encode('utf_8'))
except OSError as e:
raise e
return libregister
custom_opp_list = get_custom_opp_pathlist()
join_list = custom_opp_list + default_lst
for _path in join_list:
try:
custom_opp_so_path = os.path.join(_path, _TILING_SO_PATH)
lib_optiling = ctypes.CDLL(custom_opp_so_path)
custom_opp_so_path_str = str(custom_opp_so_path)
lib_optiling.TbeLoadSoAndSaveToRegistry(custom_opp_so_path_str.encode('utf_8'))
except OSError:
pass
try:
if os.path.exists(builtin_optiling_lib_path2):
lib_optiling_builtin = ctypes.CDLL(builtin_optiling_lib_path2)
builtin_optiling_lib_path2_str = str(builtin_optiling_lib_path2)
lib_optiling_builtin.TbeLoadSoAndSaveToRegistry(builtin_optiling_lib_path2_str.encode('utf_8'))
except AttributeError:
pass
return libregister
def get_default_tiling_struct(opname: str):
default_tiling_struct = ""
old_path = os.path.abspath(os.path.dirname(os.path.abspath(__file__)) + "/../ops/built-in/tbe/impl/ascendc")
old_path = os.path.join(old_path, opname)
new_base_path = (os.path.dirname(os.path.abspath(__file__)) + "/../ops/")
pattern = f"*/{opname}"
new_paths = list(Path(new_base_path).glob(pattern))
asc_file_path = old_path
if len(new_paths) > 0:
asc_file_path = str(new_paths[0])
print(f"asc_file_path:{asc_file_path}")
command = ["grep", "-rnwE", "REGISTER_TILING_DEFAULT|BroadcastSch", asc_file_path]
print("command:", " ".join(command))
try:
result = subprocess.run(command, text=True, capture_output=True, check=True)
default_tiling_struct = result.stdout
default_tiling_struct = default_tiling_struct.split(";")[0]
default_tiling_struct = re.sub(r".*REGISTER_TILING_DEFAULT\(", '', default_tiling_struct)
default_tiling_struct = default_tiling_struct.replace(')', '')
default_tiling_struct = default_tiling_struct.strip()
print("default_tiling_struct:", default_tiling_struct)
if result.returncode == 0:
default_tiling_struct = "TestUtDefaultTilingStruct"
except subprocess.CalledProcessError as e:
print(e)
return default_tiling_struct
if __name__ == "__main__":
if len(sys.argv) <= 2:
raise RuntimeError("arguments must be greater than 2.")
op_type = sys.argv[1]
op_name = sys.argv[2]
file_name = op_name + "_tiling_data.h"
so_path = None
if len(sys.argv) == 4:
so_path = sys.argv[3]
tiling_key_list = []
if len(sys.argv) == 5:
tiling_key_list = sys.argv[4].split(",")
print("tiling_key_list:", tiling_key_list)
tiling_help.load_lib = lambda: load_lib_v2(so_path)
op_info = OpInfo()
tiling_info = tiling_help.TilingInfo()
op_info_dict = op_info._asdict()
op_info_dict["op_type"] = op_type
op_info_dict["inputs"] = [{"shape": [-1]}]
op_info_dict["origin_inputs"] = [{"shape": [-1]}]
op_info_dict["outputs"] = [{"shape": [-1]}]
op_info_dict["attrs"] = []
op_info2 = OpInfo(**op_info_dict)
with asc_op_compile_base.common.context.op_context.OpContext("dynamic"):
tiling_struct = get_default_tiling_struct(op_name)
if tiling_struct:
tiling_info.file_content = tiling_help.gen_dynamic_shape_v2(op_name, tiling_struct)
else:
tiling_info = tiling_help.get_tiling_info(op_info2, tiling_key_list)
if not tiling_info.file_content:
print("ERROR gen tiling head file failed.", flush=True)
with open(file_name, "w") as file:
print(tiling_info.file_content, file=file)