"""
build_opp_kernel_static.py
"""
import argparse
import concurrent.futures
import glob
import json
import logging as log
import multiprocessing
import os
import platform
import re
import shlex
import stat
import subprocess
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
class Const:
    """CPU architecture names used by this script to select tools and targets."""

    # 64-bit ARM target.
    arm = "aarch64"
    # 64-bit x86 target; this value is also compared against ``platform.machine()``.
    x86 = "x86_64"
def shell_exec(cmd, shell=False):
    """Run an external command and terminate the script if it fails.

    Args:
        cmd: Command to run: an argument list, or a single string when
            ``shell`` is true.
        shell: Whether ``cmd`` is interpreted by the system shell.

    Raises:
        SystemExit: With status 1 when the command cannot be started, does not
            finish within 180 seconds, or exits with a non-zero status.
    """
    try:
        # ``shell`` has to be passed by keyword: the second positional
        # parameter of Popen/run is ``bufsize``, so a positional value never
        # reaches the ``shell`` option. ``run`` also kills and reaps the child
        # when the timeout expires instead of leaving it running.
        completed = subprocess.run(cmd, shell=shell, timeout=180, check=False)
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        log.error(f"shell_exec error: {e}")
        sys.exit(1)
    if completed.returncode != 0:
        # A failing tool must not be mistaken for success by the caller.
        log.error(f"shell_exec error: command {cmd} exited with status {completed.returncode}")
        sys.exit(1)
def shell_checkout_key_func(symbol_file, key_str):
    """Return the demangled symbols of a symbol dump that contain ``key_str``.

    The 8th whitespace-separated field of every line of ``symbol_file`` is
    fed through ``c++filt``; the demangled lines containing ``key_str`` are
    returned.
    NOTE(review): the file is presumably the ``readelf -Ws`` dump written by
    ``generate_symbol``, where field 8 is the symbol name - confirm.

    Args:
        symbol_file: Path of the text file holding the symbol table dump.
        key_str: Substring a demangled symbol must contain to be returned.

    Returns:
        A ``str`` with one matching symbol per line (each terminated by a
        newline), or the empty string when nothing matches.

    Raises:
        OSError: If ``symbol_file`` cannot be read or ``c++filt`` is missing.
        subprocess.CalledProcessError: If ``c++filt`` exits with an error.
    """
    names = []
    with open(symbol_file, "rb") as fd:
        for raw_line in fd:
            columns = raw_line.split()
            # Lines with fewer than 8 fields contribute an empty name.
            names.append(columns[7] if len(columns) >= 8 else b"")
    mangled = b"".join(name + b"\n" for name in names)
    demangled = subprocess.check_output(("c++filt",), input=mangled).decode("utf-8")
    matches = [line for line in demangled.splitlines() if key_str in line]
    # Always return ``str`` so callers can split the result on "::" safely.
    return "".join(f"{line}\n" for line in matches)
def to_upper_camel_case(x) -> str:
    """Convert a snake_case name to UpperCamelCase.

    The input is lower-cased first; every ``_<letter>`` pair is replaced by the
    upper-cased letter and the first character is capitalised. Underscores
    that are not followed by a letter are kept.

    Args:
        x: Name to convert.

    Returns:
        The converted name, or ``""`` for empty input.
    """
    if not x:
        # Empty input has no first character to capitalise.
        return ""
    s = re.sub(r'_([a-zA-Z])', lambda m: m.group(1).upper(), x.lower())
    return s[:1].upper() + s[1:]
def generate_symbol(args):
    """Write the ``readelf -Ws`` symbol table of a library to a file.

    Args:
        args: Parsed arguments providing ``library_file`` (input library) and
            ``symbol_file`` (output path).

    Raises:
        FileExistsError: If ``library_file`` does not exist.
        subprocess.CalledProcessError: If ``readelf`` exits with an error.
        subprocess.TimeoutExpired: If ``readelf`` runs longer than 180 seconds.
    """
    library_file = args.library_file
    symbol_file = args.symbol_file
    if not os.path.exists(library_file):
        # NOTE: FileNotFoundError would describe this better; the original
        # exception type is kept so existing callers keep working.
        raise FileExistsError(f"generate_symbol input library error, file <{library_file}> not exists.")
    # ``run`` reaps the child on timeout, and ``check=True`` stops a failed
    # readelf from producing a silently empty or partial symbol file.
    result = subprocess.run(("readelf", "-Ws", library_file), stdout=subprocess.PIPE, timeout=180, check=True)
    # Write the raw bytes: no decode/encode round trip through the locale encoding.
    with open(symbol_file, "wb") as fd:
        fd.write(result.stdout)
def parser_generate_symbol(subparsers):
    """Register the ``GenerateSymbol`` sub-command and bind it to ``generate_symbol``.

    Args:
        subparsers: Object returned by ``ArgumentParser.add_subparsers()``.
    """
    sub_command = subparsers.add_parser(name='GenerateSymbol', help='Generate Symbol file of input library')
    # (flags, destination attribute, help text) of the two optional string options.
    option_specs = (
        (('-l', '--library_file'), "library_file", "Input the library file"),
        (('-s', '--symbol_file'), "symbol_file", "The symbol file for output"),
    )
    for flags, dest_name, help_text in option_specs:
        sub_command.add_argument(*flags, type=str, required=False, dest=dest_name, default="", help=help_text)
    sub_command.set_defaults(func=generate_symbol)
class CompileOpStaticLib:
    """Embed operator files into relocatable objects for one CPU architecture.

    Every input file is turned into an ELF object with
    ``objcopy --input-target binary``; the objects of one operator are then
    merged with ``ld -r`` into ``<op>_<arch>_part<index>.o``.
    """

    def __init__(self, ops_compile_files: Dict, out_path: str, dist_index: int, arch: str):
        """Store the build inputs.

        Args:
            ops_compile_files: Mapping of operator name to an object exposing
                ``kernel_files``, ``binary_config_files`` and
                ``runtime_kb_files`` (iterables of ``pathlib.Path``).
            out_path: Directory under which one sub-directory per operator is created.
            dist_index: Index used in the name of the merged part object.
            arch: Target architecture, ``Const.x86`` or ``Const.arm``.

        Raises:
            ValueError: If ``arch`` is not a supported architecture.
        """
        self.ops_compile_files = ops_compile_files
        self.out_path = out_path
        self.part_index = dist_index
        self.cpu_arch = arch
        if self.cpu_arch not in [Const.x86, Const.arm]:
            raise ValueError(f"CompileOpStaticLib Error, input arch<{arch}> error...")

    def _binutils_tool(self, tool: str) -> str:
        """Return the binutils executable to call, cross-prefixed for aarch64 builds on an x86_64 host."""
        if self.cpu_arch == Const.arm and platform.machine() == Const.x86:
            return f"aarch64-linux-gnu-{tool}"
        return tool

    def compile_link_single(self, file_path, file_o):
        """Convert one file into an ELF object with ``objcopy``.

        Args:
            file_path: Path of the file to embed.
            file_o: Path of the object file to create.
        """
        if self.cpu_arch == Const.x86:
            output_target, binary_arch = "elf64-x86-64", "i386"
        elif self.cpu_arch == Const.arm:
            output_target, binary_arch = "elf64-littleaarch64", "aarch64"
        else:
            return
        dir_path, file_name = os.path.split(file_path)
        objcopy = self._binutils_tool("objcopy")
        # objcopy runs inside the file's directory with a bare file name.
        # NOTE(review): presumably so the generated _binary_<name>_start/_end
        # symbols depend only on the file name - confirm.
        # Paths are shell-quoted, and the output path is made absolute because
        # the command changes directory first.
        shell_exec(["bash", "-c",
                    f"cd {shlex.quote(dir_path)} && "
                    f"{objcopy} --input-target binary "
                    f"--output-target {output_target} --binary-architecture {binary_arch} "
                    f"{shlex.quote(file_name)} {shlex.quote(os.path.abspath(file_o))}"], shell=False)

    def compile_link_o(self, out_path, file_path, is_need_path=True):
        """Create ``data_<name>_<arch>.o`` in ``out_path`` for one input file.

        When ``is_need_path`` is true and the file is a ``.json`` file, a copy
        with a rewritten ``filePath`` entry is written to ``out_path`` and that
        copy is embedded instead of the original.

        Args:
            out_path: Directory receiving the object (and the rewritten JSON copy).
            file_path: File to embed (``str`` or ``pathlib.Path``).
            is_need_path: Whether ``.json`` inputs get their ``filePath`` rewritten.
        """
        base_name = os.path.basename(file_path)
        file_pre = base_name.replace('.', '_').replace('-', '_')
        path_o_prefix = os.path.join(out_path, f"data_{file_pre}_{self.cpu_arch}.o")
        if is_need_path and base_name.endswith(".json"):
            with open(file_path, 'r', encoding='UTF-8') as json_fd:
                json_dict = json.load(json_fd)
            source = str(file_path)
            kernel_relative = source.split("/bin/")[-1].split("/kernel/")[-1]
            if "opp/built-in/" in source:
                json_dict["filePath"] = kernel_relative.replace("/ops_transformer", "")
            else:
                soc = source.split("/binary/")[-1].split("/bin/")[0]
                json_dict["filePath"] = os.path.join(soc, kernel_relative)
            file_path = os.path.join(out_path, base_name)
            with open(file_path, 'w', encoding='UTF-8') as new_json_fd:
                new_json_fd.write(json.dumps(json_dict, indent=4))
        file_path = os.path.realpath(file_path)
        self.compile_link_single(file_path, path_o_prefix)
        return

    def compile_ops_part_o(self, out_path):
        """Merge all ``data_*<arch>.o`` objects of ``out_path`` with ``ld -r``.

        The result is written next to ``out_path`` as
        ``<dirname>_<arch>_part<index>.o``. Nothing is done when no object matches.

        Args:
            out_path: Directory holding the per-file objects of one operator.
        """
        pattern = f"data_*{self.cpu_arch}.o"
        if not glob.glob(os.path.join(out_path, pattern)):
            return
        dir_path, ops_name = os.path.split(out_path)
        file_part_o = f"{ops_name}_{self.cpu_arch}_part{self.part_index}.o"
        path_part_o = os.path.abspath(os.path.join(dir_path, file_part_o))
        linker = self._binutils_tool("ld")
        # The glob is expanded by bash relative to ``out_path`` after the cd,
        # so it also works when ``out_path`` is relative or contains spaces.
        shell_exec(["bash", "-c",
                    f"cd {shlex.quote(out_path)} && "
                    f"{linker} -r {pattern} -o {shlex.quote(path_part_o)}"], shell=False)
        return

    def exec_compile(self):
        """Build the part object of every operator.

        Returns:
            ``0`` when all operators were processed.

        Raises:
            BaseException: Whatever a worker raised (including the
                ``SystemExit`` raised by ``shell_exec`` on failure) is
                re-raised here instead of being dropped with the future.
        """
        # Twice the CPU count; 16 is used when the count is unknown.
        job_num = (os.cpu_count() or 8) * 2
        for op, resource in self.ops_compile_files.items():
            op_out_path = os.path.join(self.out_path, op)
            os.makedirs(op_out_path, exist_ok=True)
            with concurrent.futures.ThreadPoolExecutor(max_workers=job_num) as executor:
                futures = [
                    executor.submit(self.compile_link_o, op_out_path, file.resolve())
                    for file in resource.kernel_files
                ]
                # Config and knowledge-base JSON files are embedded unmodified.
                futures.extend(
                    executor.submit(self.compile_link_o, op_out_path, file.resolve(), False)
                    for file in (*resource.binary_config_files, *resource.runtime_kb_files)
                )
            for future in futures:
                future.result()
            # All objects exist now; merge them in the calling thread so a
            # failure propagates.
            self.compile_ops_part_o(op_out_path)
        return 0
def compile_static_library(args):
    """Run the ``StaticCompile`` sub-command.

    Args:
        args: Parsed arguments providing ``index_num``, ``cpu_aarch``,
            ``soc_version``, ``build_dir`` and ``jit``.

    Returns:
        The value returned by ``CompileOpStaticLib.exec_compile``.

    Raises:
        Exception: If ``cpu_aarch`` is not one of the supported architectures.
    """
    part_index, target_arch = args.index_num, args.cpu_aarch
    supported_archs = (Const.x86, Const.arm)
    if target_arch not in supported_archs:
        raise Exception(f"Input cpu_aarch<{target_arch}> Error, Please input parase.")
    resource_map = GenOpResourceIni(args.soc_version, args.build_dir, args.jit).analyze_ops_files()
    tmp_out_dir = os.path.join(args.build_dir, f"bin_tmp/{args.soc_version}")
    compiler = CompileOpStaticLib(resource_map, tmp_out_dir, part_index, target_arch)
    return compiler.exec_compile()
def parser_compile_static_library(subparsers):
    """Register the ``StaticCompile`` sub-command and bind it to ``compile_static_library``.

    Args:
        subparsers: Object returned by ``ArgumentParser.add_subparsers()``.
    """
    compile_lib_parser = subparsers.add_parser(name='StaticCompile',
                                               help='Compile static libraries(.a) on distributed server')
    # The help text previously described this option as an operator name.
    compile_lib_parser.add_argument('-s', '--soc_version', type=str, required=True, dest="soc_version",
                                    help="SoC version, eg: ascend910b, ascend310p")
    compile_lib_parser.add_argument('-b', '--build_dir', type=str, required=True, dest="build_dir",
                                    help="Input build dir for this project")
    compile_lib_parser.add_argument('-j', '--jit', action='store_true', dest="jit",
                                    help="Compile static libraries(.a) with cann package")
    compile_lib_parser.add_argument('-n', '--index_num', type=int, required=True, dest="index_num",
                                    help="Please input distributed compilation idx")
    compile_lib_parser.add_argument('-a', '--cpu_aarch', type=str, required=True, dest="cpu_aarch",
                                    help="Please input cpu aarch, eg:x86_64,aarch64")
    compile_lib_parser.set_defaults(func=compile_static_library)
@dataclass
class OpResource:
    """Resources collected for one operator type.

    The ``*_register`` / ``*_helper`` fields hold symbol names (``None`` until
    one is found); the ``*_files`` fields hold the files to embed.
    """

    # Symbol of the tiling registration object.
    tiling_register: Optional[str] = None
    # Symbols of the template ("extend") registrations; this is a list, not a string.
    extend_register: List[str] = field(default_factory=list)
    # Symbol of the infer-shape registration object.
    infer_shape_register: Optional[str] = None
    # Symbols of the tuning bank-key / bank-parse registrations and the tiling helper.
    tuning_bank_key_register: Optional[str] = None
    tuning_bank_parse_register: Optional[str] = None
    tuning_tiling_helper: Optional[str] = None
    # Files embedded into the static library for this operator.
    binary_config_files: list = field(default_factory=list)
    kernel_files: list = field(default_factory=list)
    runtime_kb_files: list = field(default_factory=list)
class GenOpResourceIni:
    """Collect per-operator resources and generate ``<op_type>_op_resource.cpp`` files.

    Resources come from three places: the ops-info JSON and kernel binary
    directories (``analyze_ops_files``), the ``ophost_transformer.txt`` symbol
    dump (``_analyze_symbols``) and the ``opapi_transformer.txt`` symbol dump
    (``_analyze_ops_l0op``), the latter two located in the build directory.
    """

    def __init__(self, soc_version: str, build_dir: str, build_with_package: bool):
        """Resolve the input and output locations.

        Args:
            soc_version: SoC version string used in directory and file names.
            build_dir: Build directory of the project.
            build_with_package: When true and ``ASCEND_OPP_PATH`` is set, read
                kernels and configs from that package instead of ``build_dir``.
        """
        self._soc_version = soc_version
        self._build_dir = Path(build_dir)
        opp_path = os.environ.get('ASCEND_OPP_PATH')
        if build_with_package and opp_path:
            opp_path = Path(opp_path)
            self._binary_path = opp_path / "built-in/op_impl/ai_core/tbe/kernel"
            self._tuning_basic_path = opp_path / "built-in/data/op"
            config_dir = opp_path / "built-in/op_impl/ai_core/tbe/config" / self._soc_version
            ops_info_pattern = f"aic-{self._soc_version}-ops-info-transformer.json"
        else:
            self._binary_path = self._build_dir / "binary" / self._soc_version / "bin"
            self._tuning_basic_path = self._build_dir / "tbe/config" / self._soc_version
            config_dir = self._build_dir / "custom/op_impl/ai_core/tbe/config" / self._soc_version
            ops_info_pattern = f"aic-{self._soc_version}-ops-info*.json"
        # Sorted so the chosen file does not depend on directory listing order
        # when several files match.
        ops_info = sorted(config_dir.glob(ops_info_pattern))
        self._ops_info = ops_info[0] if ops_info else None
        self._op_resource_path = self._build_dir / "autogen" / self._soc_version / "aclnnop_resource"
        self._op_res: Dict[str, OpResource] = defaultdict(OpResource)
        self._l0op_list = []

    # C++ templates; doubled braces are literal braces for ``str.format``.
    TILING_REG_DECL_FMT = """
namespace {namespace} {{
extern gert::OpImplRegisterV2 {func_name};
}}
"""
    EXTEND_REG_DECL_FMT = """
namespace {namespace} {{
extern uint32_t {func_name};
}}
"""
    TILING_REG_RES_FUNC_FMT = """
void * {op_type}TilingRegisterResource() {{
return {reference_code};
}}
"""
    INFER_SHAPE_REG_DECL_FMT = """
namespace {namespace} {{
extern gert::OpImplRegisterV2 {func_name};
}}
"""
    INFER_SHAPE_REG_RES_FUNC_FMT = """
void * {op_type}InferShapeRegisterResource() {{
return {reference_code};
}}
"""
    TUNING_REG_DECL_FMT = """
namespace {namespace} {{
class {class_type};
extern {class_type} {func_name};
}}
"""
    EXTLEND_REG_RES_FUNC_FMT = """
void * {op_type}ExtendRegisterResource() {{
static std::vector<void *> resource = {{{reference_code}}};
return &resource;
}}
"""
    TUNING_REG_RES_FUNC_FMT = """
void * {op_type}TuningRegisterResource() {{
static std::vector<void *> resource = {{{tuning_bank_key}, {tuning_bank_parse}, {tuning_helper}}};
return &resource;
}}
"""
    KERNEL_BINARY_RES_FUNC_FMT = """
const OP_BINARY_RES& {op_type}KernelResource() {{
static const OP_BINARY_RES resource = {{
{binary_config_ref_code}
{kernel_files_ref_code}
}};
return resource;
}}
"""
    TUNING_KB_BINARY_RES_FUNC_FMT = """
const OP_RUNTIME_KB_RES& {op_type}TuningResource() {{
static const OP_RUNTIME_KB_RES resource = {{
{reference_code}
}};
return resource;
}}
"""
    OP_RESOURCE_CPP_FMT = """/******************{op_type}算子的所有资源**********************/
#include "register/op_impl_registry.h"
#include <vector>
#include <tuple>
#include <map>
#include <graph/ascend_string.h>
#include <static_space.h>
using OP_HOST_FUNC_HANDLE = std::vector<void *>;
using OP_RES = std::tuple<const uint8_t *, const uint8_t *>;
using OP_BINARY_RES = std::vector<OP_RES>;
using OP_RUNTIME_KB_RES = std::vector<OP_RES>;
using OP_RESOURCES = std::map<ge::AscendString,
std::tuple<OP_HOST_FUNC_HANDLE, OP_BINARY_RES, OP_RUNTIME_KB_RES>>;
namespace {op_type} {{
auto initializer = StaticSpaceInitializer::GetInstance();;
}}
// 资源声明
// extend resource
{extend_declaration}
// Tiling
{tiling_declaration}
// InferShape
{infer_shape_declaration}
// Tuning
{tuning_bank_key_declaration}
{tuning_bank_parse_declaration}
{tuning_helper_declaration}
// kernel 二进制
{binary_config_declaration}
{kernel_files_declaration}
// kb 二进制
{tuning_kb_declaration}
namespace l0op {{
// 资源函数
// Tiling register resource func
{tiling_reg_func}
// InferShape register resource func
{infer_shape_reg_func}
// Tuning register resource func
{tuning_reg_func}
// kernel resource func
{kernel_resource}
// Tuning resource func
{tuning_kb_resource}
}}
// extend resource func
{extend_reg_func}
"""

    @staticmethod
    def _extract_op_symbol_pair(symbol_file: str, search_key: str, prefix: str, suffix: str):
        """Yield ``(op_type, symbol)`` for symbols named ``<prefix><op_type><suffix>``.

        Symbols are taken from ``shell_checkout_key_func(symbol_file, search_key)``;
        only the part after the last ``::`` is matched against prefix and suffix.
        Symbols that do not match are skipped with a warning.
        """
        symbol_ret = shell_checkout_key_func(symbol_file, search_key)
        for symbol in symbol_ret.splitlines():
            symbol_name = symbol.split("::")[-1]
            if not (symbol_name.startswith(prefix) and symbol_name.endswith(suffix)):
                log.warning(f"symbol not satisfied with the format:{prefix}<op_type>{suffix}, skip")
                continue
            op_type = symbol_name[len(prefix):]
            if suffix:
                # Guarded because ``[:-0]`` would yield an empty string.
                op_type = op_type[:-len(suffix)]
            yield op_type, symbol

    @staticmethod
    def _extract_op_symbol_pair_v2(symbol_file: str, search_key: str, prefix: str):
        """Yield ``(op_type, symbol)`` for symbols named ``<prefix><op_type>_<anything>``.

        The operator type is the text between the prefix and the next underscore.
        """
        symbol_ret = shell_checkout_key_func(symbol_file, search_key)
        for symbol in symbol_ret.splitlines():
            symbol_name = symbol.split("::")[-1]
            if not symbol_name.startswith(prefix):
                log.warning(f"symbol not satisfied with the format:{prefix}, skip")
                continue
            op_type = symbol_name[len(prefix):].split("_")[0]
            yield op_type, symbol

    @staticmethod
    def _extract_register_symbol(register_symbol: str):
        """Split a qualified symbol into ``(namespace, name, reference_code)``.

        Returns ``("", "", "nullptr")`` for an empty symbol and for symbols in
        an anonymous namespace; otherwise ``reference_code`` is ``&<symbol>``.
        """
        if not register_symbol:
            return "", "", "nullptr"
        *scopes, func_name = register_symbol.split("::")
        namespace = "::".join(scopes)
        if "anonymous" in namespace:
            return "", "", "nullptr"
        return namespace, func_name, f"&{register_symbol}"

    @staticmethod
    def _gen_binary_res_code(files):
        """Build the C++ declarations and initializer entries for embedded files.

        Files whose name contains ``relocatable`` are skipped.

        Returns:
            ``(declaration, reference_code)`` strings.
        """
        declarations = []
        references = []
        for binary_file in files:
            if "relocatable" in binary_file.name:
                continue
            binary_name = binary_file.name.replace(".", "_").replace("-", "_")
            declarations.append(
                f"// {binary_file.name}\n"
                f"extern const uint8_t _binary_{binary_name}_start[];\n"
                f"extern const uint8_t _binary_{binary_name}_end[];\n"
            )
            references.append(f"{{_binary_{binary_name}_start, _binary_{binary_name}_end}},\n")
        return "".join(declarations), "".join(references)

    def gen_ops_ini_files(self):
        """Collect all resources and write one resource file per operator type.

        Operator types from the l0op list are generated first, then every
        remaining collected operator type.
        """
        self.analyze_ops_files()
        self._analyze_symbols()
        self._analyze_ops_l0op()
        os.makedirs(self._op_resource_path, exist_ok=True)
        l0op_types = set(self._l0op_list)
        for op_type in self._l0op_list:
            self._save_op_resource(op_type, self.generate_op_resouce_ini(op_type))
        # Snapshot of the keys: ``_op_res`` is a defaultdict that grows on lookup.
        for op_type in list(self._op_res):
            if op_type in l0op_types:
                continue
            self._save_op_resource(op_type, self.generate_op_resouce_ini(op_type))

    def generate_op_resouce_ini(self, op_type: str) -> str:
        """Render the C++ resource file content for ``op_type``."""
        value_dict = {
            "op_type": op_type,
        }
        value_dict.update(self._gen_register_resouce_code(op_type))
        value_dict.update(self._gen_tuning_register_resouce_code(op_type))
        value_dict.update(self._gen_binary_resource_code(op_type))
        # These operators forward to the kernel resource of another operator
        # instead of declaring their own kernel files.
        special_ops = {"MatMulV2": "MatMul"}
        if op_type in special_ops:
            value_dict["kernel_files_declaration"] = ""
            value_dict["kernel_resource"] = f"""
extern const OP_BINARY_RES& {special_ops[op_type]}KernelResource();
const OP_BINARY_RES& {op_type}KernelResource() {{
return {special_ops[op_type]}KernelResource();
}}
"""
        return self.OP_RESOURCE_CPP_FMT.format_map(value_dict)

    def analyze_ops_files(self) -> "Dict[str, OpResource]":
        """Collect config, kernel and runtime-kb files of every operator.

        Each call appends to the per-operator file lists, so calling it more
        than once on the same instance records the files again.

        Returns:
            The operator-name to resource mapping (unchanged when no ops-info
            file was found).
        """
        if not self._ops_info:
            return self._op_res
        with open(self._ops_info, "r", encoding="utf-8") as autogen_fd:
            ops_info_json = json.load(autogen_fd)
        in_builtin_package = "opp/built-in/" in str(self._binary_path)
        for ops, op_info in ops_info_json.items():
            if 'opFile' in op_info:
                json_file = f"{op_info['opFile']['value']}.json"
            else:
                # Fall back to the directory name of a matching kernel object;
                # sorted so the choice is deterministic.
                o_lists = sorted(Path(self._binary_path).rglob(f"{self._soc_version}/**/*{ops}*.o"))
                if not o_lists:
                    continue
                json_file = f"{os.path.basename(os.path.dirname(o_lists[0]))}.json"
            if in_builtin_package:
                json_path = self._binary_path / "config" / self._soc_version / "ops_transformer" / json_file
            else:
                json_path = self._binary_path / json_file
            if not os.path.exists(json_path):
                continue
            with open(json_path, "r", encoding="utf-8") as op_json_fd:
                op_json_content = json.load(op_json_fd)
            if not op_json_content.get("binList"):
                continue
            # Drop the first path component of the recorded JSON path.
            relative_bin_json = op_json_content["binList"][0]["binInfo"]["jsonFilePath"].split("/", 1)[1]
            if in_builtin_package:
                bin_json_file = self._binary_path / self._soc_version / "ops_transformer" / relative_bin_json
            else:
                bin_json_file = self._binary_path / relative_bin_json
            ops_path = os.path.dirname(bin_json_file)
            self._op_res[ops].binary_config_files.append(json_path)
            self._op_res[ops].kernel_files.extend(sorted(Path(ops_path).iterdir()))
        kb_ops = set()
        for kb_json in Path(self._tuning_basic_path).rglob("*_AiCore_*_runtime_kb.json"):
            ops = kb_json.name.split("_AiCore_")[-1].split("_runtime_kb")[0]
            self._op_res[ops].runtime_kb_files.append(kb_json)
            kb_ops.add(ops)
        # One sort per operator after collection instead of one per file.
        for ops in kb_ops:
            self._op_res[ops].runtime_kb_files.sort(key=lambda p: p.name)
        return self._op_res

    def _analyze_ops_l0op(self):
        """Fill ``_l0op_list`` from ``opapi_transformer.txt`` in the build directory."""
        opapi_symbol = self._build_dir / "opapi_transformer.txt"
        if not os.path.exists(opapi_symbol):
            return
        found = set(self._l0op_list)
        for op_type, _ in self._extract_op_symbol_pair(
            opapi_symbol, "_kernelName_Be_Defined_Multi_Times__", "", ""
        ):
            found.add(op_type.split("_kernelName_")[0])
        # De-duplicated so the same resource file is not generated repeatedly.
        self._l0op_list = sorted(found)

    def _save_op_resource(self, op_type, res_content):
        """Write ``res_content`` to ``<op_type>_op_resource.cpp`` with owner-only permissions."""
        res_cpp_file = self._op_resource_path / f"{op_type}_op_resource.cpp"
        # Remove a stale file first so the new one is created with the mode below.
        try:
            res_cpp_file.unlink()
        except FileNotFoundError:
            pass
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        modes = stat.S_IWUSR | stat.S_IRUSR
        # The template contains non-ASCII text, so the encoding must not depend on the locale.
        with os.fdopen(os.open(res_cpp_file, flags, modes), "w", encoding="utf-8") as fd:
            fd.write(res_content)

    def _analyze_symbols(self):
        """Record registration symbols found in ``ophost_transformer.txt`` in the build directory."""
        ophost_symbol = self._build_dir / "ophost_transformer.txt"
        if not os.path.exists(ophost_symbol):
            return
        # (attribute, search key, prefix, suffix); processed in this order.
        leading_specs = (
            ("infer_shape_register", "op_impl_register_infershape_", "op_impl_register_infershape_", ""),
            ("tiling_register", "op_impl_register_optiling_", "op_impl_register_optiling_", ""),
        )
        trailing_specs = (
            ("tuning_bank_key_register", "BankKeyRegistryInterf", "g_", "BankKeyRegistryInterf"),
            ("tuning_bank_parse_register", "BankParseInterf", "g_", "BankParseInterf"),
            ("tuning_tiling_helper", "g_tuning_tiling_", "g_tuning_tiling_", "Helper"),
        )
        for attr, search_key, prefix, suffix in leading_specs:
            for op_type, symbol in self._extract_op_symbol_pair(ophost_symbol, search_key, prefix, suffix):
                setattr(self._op_res[op_type], attr, symbol)
        # Template registrations: an operator may have several of them.
        for op_type, symbol in self._extract_op_symbol_pair_v2(
            ophost_symbol, "op_impl_register_template_", "op_impl_register_template_"
        ):
            self._op_res[op_type].extend_register.append(symbol)
        for attr, search_key, prefix, suffix in trailing_specs:
            for op_type, symbol in self._extract_op_symbol_pair(ophost_symbol, search_key, prefix, suffix):
                setattr(self._op_res[op_type], attr, symbol)

    def _gen_register_resouce_code(self, op_type: str) -> Dict[str, str]:
        """Render declarations and accessor functions for tiling, infer-shape and extend registrations."""
        resource = self._op_res[op_type]

        def render(symbol, decl_fmt, func_fmt):
            # Both snippets are empty when the symbol is missing or anonymous.
            namespace, func_name, reference_code = self._extract_register_symbol(symbol)
            if not func_name:
                return "", ""
            symbol_map = {
                "op_type": op_type,
                "namespace": namespace,
                "func_name": func_name,
                "reference_code": reference_code,
            }
            return decl_fmt.format_map(symbol_map), func_fmt.format_map(symbol_map)

        tiling_declaration, tiling_reg_func = render(
            resource.tiling_register, self.TILING_REG_DECL_FMT, self.TILING_REG_RES_FUNC_FMT
        )
        extend_declarations = []
        extend_references = []
        for symbol in resource.extend_register:
            namespace, func_name, reference_code = self._extract_register_symbol(symbol)
            if func_name:
                extend_declarations.append(
                    self.EXTEND_REG_DECL_FMT.format(namespace=namespace, func_name=func_name)
                )
                extend_references.append(reference_code)
        extend_reg_func = self.EXTLEND_REG_RES_FUNC_FMT.format(
            op_type=op_type, reference_code=", ".join(extend_references)
        )
        infer_shape_declaration, infer_shape_reg_func = render(
            resource.infer_shape_register, self.INFER_SHAPE_REG_DECL_FMT, self.INFER_SHAPE_REG_RES_FUNC_FMT
        )
        return {
            "tiling_declaration": tiling_declaration,
            "infer_shape_declaration": infer_shape_declaration,
            "tiling_reg_func": tiling_reg_func,
            "infer_shape_reg_func": infer_shape_reg_func,
            "extend_reg_func": extend_reg_func,
            "extend_declaration": "".join(extend_declarations),
        }

    def _gen_tuning_register_resouce_code(self, op_type: str) -> Dict[str, str]:
        """Render declarations and the accessor function for the tuning registrations."""
        resource = self._op_res[op_type]

        def declare(symbol, class_type):
            # Returns (declaration, reference code); the declaration is empty
            # when the symbol is missing or anonymous.
            namespace, func_name, ref_code = self._extract_register_symbol(symbol)
            if not func_name:
                return "", ref_code
            declaration = self.TUNING_REG_DECL_FMT.format(
                namespace=namespace,
                class_type=class_type,
                func_name=func_name,
            )
            return declaration, ref_code

        tuning_bank_key_declaration, tuning_bank_key_ref_code = declare(
            resource.tuning_bank_key_register, "OpBankKeyFuncRegistryV2"
        )
        tuning_bank_parse_declaration, tuning_bank_parse_ref_code = declare(
            resource.tuning_bank_parse_register, "OpBankKeyFuncRegistryV2"
        )
        tuning_helper_declaration, tuning_helper_ref_code = declare(
            resource.tuning_tiling_helper, f"{op_type}ClassHelper"
        )
        tuning_reg_func = self.TUNING_REG_RES_FUNC_FMT.format(
            op_type=op_type,
            tuning_bank_key=tuning_bank_key_ref_code,
            tuning_bank_parse=tuning_bank_parse_ref_code,
            tuning_helper=tuning_helper_ref_code,
        )
        return {
            "tuning_bank_key_declaration": tuning_bank_key_declaration,
            "tuning_bank_parse_declaration": tuning_bank_parse_declaration,
            "tuning_helper_declaration": tuning_helper_declaration,
            "tuning_reg_func": tuning_reg_func,
        }

    def _gen_binary_resource_code(self, op_type: str) -> Dict[str, str]:
        """Render declarations and accessor functions for the embedded binary files."""
        resource = self._op_res[op_type]
        binary_config_declaration, binary_config_ref_code = self._gen_binary_res_code(
            resource.binary_config_files
        )
        kernel_files_declaration, kernel_files_ref_code = self._gen_binary_res_code(resource.kernel_files)
        # The kernel accessor is only emitted when at least one kernel file is referenced.
        kernel_resource = ""
        if kernel_files_ref_code:
            kernel_resource = self.KERNEL_BINARY_RES_FUNC_FMT.format(
                op_type=op_type,
                binary_config_ref_code=binary_config_ref_code,
                kernel_files_ref_code=kernel_files_ref_code,
            )
        tuning_kb_declaration, tuning_kb_ref_code = self._gen_binary_res_code(resource.runtime_kb_files)
        tuning_kb_resource = self.TUNING_KB_BINARY_RES_FUNC_FMT.format(
            op_type=op_type,
            reference_code=tuning_kb_ref_code,
        )
        return {
            "binary_config_declaration": binary_config_declaration,
            "kernel_files_declaration": kernel_files_declaration,
            "tuning_kb_declaration": tuning_kb_declaration,
            "kernel_resource": kernel_resource,
            "tuning_kb_resource": tuning_kb_resource,
        }
def generate_op_resource_h_file(args):
    """Run the ``GenStaticOpResourceIni`` sub-command.

    Args:
        args: Parsed arguments providing ``soc_version``, ``build_dir`` and ``jit``.
    """
    GenOpResourceIni(args.soc_version, args.build_dir, args.jit).gen_ops_ini_files()
def parser_generate_op_resource_h_file(subparsers):
    """Register the ``GenStaticOpResourceIni`` sub-command and bind it to ``generate_op_resource_h_file``.

    Args:
        subparsers: Object returned by ``ArgumentParser.add_subparsers()``.
    """
    # The help texts name the ``.cpp`` files that are actually written and
    # describe ``--soc_version`` as a SoC version rather than an operator name.
    gen_resource_ini_parser = subparsers.add_parser(name='GenStaticOpResourceIni',
                                                    help='Generate xxx_op_resource.cpp on consolidation server')
    gen_resource_ini_parser.add_argument('-s', '--soc_version', type=str, required=True, dest="soc_version",
                                         help="SoC version, eg: ascend910b, ascend310p")
    gen_resource_ini_parser.add_argument('-b', '--build_dir', type=str, required=True, dest="build_dir",
                                         help="Input build dir for this project")
    gen_resource_ini_parser.add_argument('-j', '--jit', action='store_true', dest="jit",
                                         help="Generate xxx_op_resource.cpp with cann package")
    gen_resource_ini_parser.set_defaults(func=generate_op_resource_h_file)
def execute_argus_parse_func():
    """Parse the command line and run the handler of the selected sub-command.

    Exits with a usage error (status 2) when no sub-command is given.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(help='Subparsers Commands')
    # Static compilation options and handler.
    parser_compile_static_library(subparsers)
    # Resource file generation options and handler.
    parser_generate_op_resource_h_file(subparsers)
    # Symbol file generation for a given library.
    parser_generate_symbol(subparsers)
    args = parser.parse_args()
    if not hasattr(args, "func"):
        # Without a sub-command no handler was bound by ``set_defaults``.
        parser.error("a sub-command is required")
    args.func(args)
if __name__ == "__main__":
    execute_argus_parse_func()
    # ``sys.exit`` rather than the ``exit`` helper, which is installed by the
    # ``site`` module and is not available when Python runs with ``-S``.
    sys.exit(0)