import argparse
import glob
import sys
import os
import re
import datetime
import json
from typing import List
import opdesc_parser
import const_var
PYF_PATH = os.path.dirname(os.path.realpath(__file__))
IMPL_HEAD = '''#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
Copyright (c) Huawei Technologies Co., Ltd. {}-{}. All rights reserved.
"""
import os, sys
import ctypes
import json
import shutil
from asc_op_compile_base.common.platform import get_soc_spec
from asc_op_compile_base.common.utils import para_check
from asc_op_compile_base.asc_op_compiler import compile_op, replay_op, check_op_cap, generalize_op_params, get_code_channel, OpInfo
from asc_op_compile_base.common.buildcfg import get_default_build_config
from asc_op_compile_base.common.buildcfg import get_current_build_config
from asc_op_compile_base.common import register as tbe_register
PYF_PATH = os.path.dirname(os.path.realpath(__file__))
__version__ = '2.0.0'
DTYPE_MAP = {{"float32": ["DT_FLOAT", "float"],
"float16": ["DT_FLOAT16", "half"],
"int8": ["DT_INT8", "int8_t"],
"int16": ["DT_INT16", "int16_t"],
"int32": ["DT_INT32", "int32_t"],
"int64": ["DT_INT64", "int64_t"],
"uint1": ["DT_UINT1", "uint1b_t"],
"uint8": ["DT_UINT8", "uint8_t"],
"uint16": ["DT_UINT16", "uint16_t"],
"uint32": ["DT_UINT32", "uint32_t"],
"uint64": ["DT_UINT64", "uint64_t"],
"bool": ["DT_BOOL", "bool"],
"double": ["DT_DOUBLE", "double"],
"dual": ["DT_DUAL", "unknown"],
"dual_sub_int8": ["DT_DUAL_SUB_INT8", "unknown"],
"dual_sub_uint8": ["DT_DUAL_SUB_UINT8", "unknown"],
"string": ["DT_STRING", "unknown"],
"complex32": ["DT_COMPLEX32", "complex32"],
"complex64": ["DT_COMPLEX64", "complex64"],
"complex128": ["DT_COMPLEX128", "unknown"],
"qint8": ["DT_QINT8", "unknown"],
"qint16": ["DT_QINT16", "unknown"],
"qint32": ["DT_QINT32", "unknown"],
"quint8": ["DT_QUINT8", "unknown"],
"quint16": ["DT_QUINT16", "unknown"],
"resource": ["DT_RESOURCE", "unknown"],
"string_ref": ["DT_STRING_REF", "unknown"],
"int4": ["DT_INT4", "int4b_t"],
"bfloat16": ["DT_BF16", "bfloat16_t"],
"float8_e5m2": ["DT_FLOAT8_E5M2", "fp8_e5m2_t"],
"float8_e4m3fn": ["DT_FLOAT8_E4M3FN", "fp8_e4m3fn_t"],
"hifloat8":["DT_HIFLOAT8", "hifloat8_t"],
"float8_e8m0":["DT_FLOAT8_E8M0", "fp8_e8m0_t"],
"float4_e2m1":["DT_FLOAT4_E2M1", "fp4x2_e2m1_t"],
"float4_e1m2":["DT_FLOAT4_E1M2", "fp4x2_e1m2_t"],
"int2": ["DT_INT2", "int2b_t"]}}
def add_dtype_fmt_option_single(x, x_n, is_ref: bool = False):
options = []
x_fmt = x.get("format")
x_dtype = x.get("dtype")
x_n_in_kernel = x_n + '_REF' if is_ref else x_n
options.append("-DDTYPE_{{n}}={{t}}".format(n=x_n_in_kernel, t=DTYPE_MAP.get(x_dtype)[1]))
options.append("-DORIG_DTYPE_{{n}}={{ot}}".format(n=x_n_in_kernel, ot=DTYPE_MAP.get(x_dtype)[0]))
options.append("-DFORMAT_{{n}}=FORMAT_{{f}}".format(n=x_n_in_kernel, f=x_fmt))
return options
def get_dtype_fmt_options(__inputs__, __outputs__):
options = []
input_names = {}
output_names = {}
unique_param_name_set = set()
for idx, x in enumerate(__inputs__):
if x is None:
continue
x_n = input_names[idx].upper()
unique_param_name_set.add(x_n)
options += add_dtype_fmt_option_single(x, x_n)
for idx, x in enumerate(__outputs__):
if x is None:
continue
x_n = output_names[idx].upper()
if x_n in unique_param_name_set:
options += add_dtype_fmt_option_single(x, x_n, True)
else:
options += add_dtype_fmt_option_single(x, x_n)
return options
def load_dso(so_path):
try:
ctypes.CDLL(so_path)
except OSError as error :
print(error)
raise RuntimeError("cannot open %s" %(so_path))
else:
print("load so succ ", so_path)
def get_shortsoc_compile_option(compile_option_list: list, shortsoc:str):
compile_options = []
if shortsoc in compile_option_list:
compile_options.extend(compile_option_list[shortsoc])
if '__ALLSOC__' in compile_option_list:
compile_options.extend(compile_option_list['__ALLSOC__'])
return compile_options
def get_kernel_source(src_file, dir_snake, dir_ex):
src_ex = os.path.join(PYF_PATH, "..", "ascendc", dir_ex, src_file)
if os.path.exists(src_ex):
return src_ex
src = os.environ.get('BUILD_KERNEL_SRC')
if src and os.path.exists(src):
return src
src = os.path.join(PYF_PATH, "..", "ascendc", dir_snake, src_file)
if os.path.exists(src):
return src
src = os.path.join(PYF_PATH, src_file)
if os.path.exists(src):
return src
src = os.path.join(PYF_PATH, "..", "ascendc", dir_snake, dir_snake + ".cpp")
if os.path.exists(src):
return src
src = os.path.join(PYF_PATH, "..", "ascendc", dir_ex, dir_ex + ".cpp")
if os.path.exists(src):
return src
src = os.path.join(PYF_PATH, "..", "ascendc", os.path.splitext(src_file)[0], src_file)
if os.path.exists(src):
return src
return src_ex
'''
IMPL_API = '''
@tbe_register.register_operator("{}", trans_bool_to_s8=False)
@para_check.check_op_params({})
def {}({}, kernel_name="{}", impl_mode=""):
{}
if get_current_build_config("enable_op_prebuild"):
return
__inputs__, __outputs__, __attrs__ = _build_args({})
options = get_dtype_fmt_options(__inputs__, __outputs__)
options += ["-x", "cce"]
bisheng = os.environ.get('BISHENG_REAL_PATH')
if bisheng is None:
bisheng = shutil.which("bisheng")
if bisheng is not None:
bisheng_path = os.path.dirname(bisheng)
tikcpp_path = os.path.realpath(os.path.join(bisheng_path, "..", "..", "tikcpp"))
else:
tikcpp_path = os.path.realpath("/usr/local/Ascend/cann/compiler/tikcpp")
options.append("-I" + tikcpp_path)
options.append("-I" + os.path.join(tikcpp_path, "..", "..", "include"))
options.append("-I" + os.path.join(tikcpp_path, "tikcfw"))
options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "impl"))
options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "interface"))
options.append("-I" + os.path.join(tikcpp_path,
"..", "..", "..", "latest", os.uname().machine+"-linux", "asc", "atcos"))
options.append("-I" + os.path.join(PYF_PATH, "..", "ascendc", "common"))
if impl_mode == "high_performance":
options.append("-DHIGH_PERFORMANCE=1")
elif impl_mode == "high_precision":
options.append("-DHIGH_PRECISION=1")
if get_current_build_config("enable_deterministic_mode") == 1:
options.append("-DDETERMINISTIC_MODE=1")
else:
options.append("-DDETERMINISTIC_MODE=0")
custom_compile_options = {},
custom_all_compile_options = {},
soc_version = get_soc_spec("SOC_VERSION")
soc_short = get_soc_spec("SHORT_SOC_VERSION").lower()
custom_compile_options_soc = get_shortsoc_compile_option(custom_compile_options[0], soc_short)
custom_all_compile_options_soc = get_shortsoc_compile_option(custom_all_compile_options[0], soc_short)
options += custom_all_compile_options_soc
options += custom_compile_options_soc
origin_func_name = "{}"
ascendc_src_dir_ex = "{}"
ascendc_src_dir = "{}"
ascendc_src_file = "{}"
src = get_kernel_source(ascendc_src_file, ascendc_src_dir, ascendc_src_dir_ex)
'''
REPLAY_OP_API = '''
print("start replay Acend C Operator {}, kernel name is {}")
tikreplay_codegen_path = tikcpp_path + "/tikreplaylib/lib"
tikreplay_stub_path = tikcpp_path + "/tikreplaylib/lib/" + soc_version
print("start load libtikreplaylib_codegen.so and libtikreplaylib_stub.so")
codegen_so_path = tikreplay_codegen_path + "/libtikreplaylib_codegen.so"
replaystub_so_path = tikreplay_stub_path + "/libtikreplaylib_stub.so"
if PYF_PATH.endswith("dynamic"):
op_replay_path = os.path.join(PYF_PATH, "..", "..", "op_replay")
else:
op_replay_path = os.path.join(PYF_PATH, "..", "op_replay")
replayapi_so_path = os.path.join(op_replay_path, "libreplay_{}_" + soc_short + ".so")
load_dso(codegen_so_path)
load_dso(replaystub_so_path)
load_dso(replayapi_so_path)
op_type = "{}"
entry_obj = os.path.join(op_replay_path, "{}_entry_" + soc_short + ".o")
code_channel = get_code_channel(src, kernel_name, op_type, options)
op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = __inputs__, outputs = __outputs__,\\
attrs = __attrs__, impl_mode = impl_mode, param_type_dynamic = {})
res, msg = replay_op(op_info, entry_obj, code_channel, src, options)
if not res:
print("call replay op failed for %s and get into call compile op" %(msg))
compile_op(src, origin_func_name, op_info, options, code_channel, '{}')
'''
COMPILE_OP_API = '''
print("start compile Ascend C operator {}. kernel name is " + kernel_name)
op_type = "{}"
code_channel = get_code_channel(src, kernel_name, op_type, options)
op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = __inputs__, outputs = __outputs__,\\
attrs = __attrs__, impl_mode = impl_mode, origin_inputs=[{}], origin_outputs = [{}],\\
param_type_dynamic = {}, mc2_ctx = {}, param_type_list = {}, init_value_list = {},\\
output_shape_depend_on_compute = {})
compile_op(src, origin_func_name, op_info, options, code_channel, '{}')
'''
COMPILE_OP_API_BUILT_IN = '''
print("start compile Ascend C operator {}. kernel name is " + kernel_name)
op_type = "{}"
code_channel = get_code_channel(src, kernel_name, op_type, options)
op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = __inputs__, outputs = __outputs__,\\
attrs = __attrs__, impl_mode = impl_mode, origin_inputs=[{}], origin_outputs = [{}],\\
param_type_dynamic = {}, mc2_ctx = {}, param_type_list = {}, init_value_list = {},\\
output_shape_depend_on_compute = {})
op_compile_option = '{}'
dat_path = os.path.realpath(os.path.join(PYF_PATH, "..", "..", "ascendc_impl.dat"))
if os.path.exists(dat_path):
# dat file exists: built in hidden src file online compiling process. append vfs compile option in compile_op
abs_rel_kernel_src_path = "{}"
extend_options = {{}}
extend_options['opp_kernel_hidden_dat_path'] = dat_path
compile_op(abs_rel_kernel_src_path, origin_func_name, op_info, options, code_channel, op_compile_option,\\
extend_options)
else:
# dat file does not exist, run original compile cmd
compile_op(src, origin_func_name, op_info, options, code_channel, op_compile_option)
'''
SUP_API = '''
def {}({}, impl_mode=""):
__inputs__, __outputs__, __attrs__ = _build_args({})
ret_str = check_op_cap("{}", "{}", __inputs__, __outputs__, __attrs__)
ret_dict = json.loads(ret_str)
err_code = ret_dict.get("ret_code")
sup = "Unknown"
reason = "Unknown reason"
if err_code is not None:
if err_code == 0:
sup = "True"
reason = ""
elif err_code == 1:
sup = "False"
reason = ret_dict.get("reason")
else:
sup = "Unknown"
reason = ret_dict.get("reason")
return sup, reason
'''
CAP_API = '''
def {}({}, impl_mode=""):
__inputs__, __outputs__, __attrs__ = _build_args({})
result = check_op_cap("{}", "{}", __inputs__, __outputs__, __attrs__)
return result.decode("utf-8")
'''
GLZ_API = '''
@tbe_register.register_param_generalization("{}")
def {}_generalization({}, generalize_config=None):
__inputs__, __outputs__, __attrs__ = _build_args({})
ret_str = generalize_op_params("{}", __inputs__, __outputs__, __attrs__, generalize_config)
return [json.loads(ret_str)]
'''
ATTR_DEFAULT = {'bool': 'False', 'int': '0', 'float': '0.0', 'list_int': '[]',
'list_float': '[]', 'list_bool': '[]', 'list_list_int': '[[]]', 'str': ''}
def optype_snake(origin_str):
temp_str = origin_str[0].lower() + origin_str[1:]
new_str = re.sub(r'([A-Z])', r'_\1', temp_str).lower()
return new_str
def optype_snake_ex(s):
snake_case = ""
for i, c in enumerate(s):
if i == 0:
snake_case += c.lower()
elif c.isupper():
if s[i - 1] != '_':
if not s[i - 1].isupper():
snake_case += "_"
elif s[i - 1].isupper() and (i + 1) < len(s) and s[i + 1].islower():
snake_case += "_"
snake_case += c.lower()
else:
snake_case += c
return snake_case
class AdpBuilder(opdesc_parser.OpDesc):
def __init__(self: any, op_type: str):
self.argsdefv = []
self.op_compile_option:str = '{}'
super().__init__(op_type)
def write_adapt(self: any, impl_path, path: str, op_compile_option_all: list = None):
self._build_paradefault()
if os.environ.get('BUILD_BUILTIN_OPP') != '1' and impl_path != "":
if self.op_file.endswith("_apt"):
op_dir = self.op_file.replace("_apt", "")
src_file = os.path.join(impl_path, op_dir, self.op_file + ".cpp")
elif self.op_file.endswith("_910b"):
op_dir = self.op_file.replace("_910b", "")
src_file = os.path.join(impl_path, op_dir, self.op_file + ".cpp")
else:
src_file = os.path.join(impl_path, self.op_file, self.op_file + ".cpp")
if not os.path.exists(src_file):
print(f"[ERROR]: operator: {self.op_file} source file: {src_file} does not found, please check.")
return
out_path = os.path.abspath(path)
if self.dynamic_shape and not out_path.endswith('dynamic'):
out_path = os.path.join(path, 'dynamic')
os.makedirs(out_path, exist_ok=True)
adpfile = os.path.join(out_path, self.op_file + '.py')
self._gen_op_compile_option(op_compile_option_all)
with os.fdopen(os.open(adpfile, const_var.WFLAGS, const_var.WMODES), 'w') as fd:
self._write_head(fd)
self._write_argparse(fd)
self._write_impl(fd, impl_path)
if self.op_chk_support:
self._write_cap('check_supported', fd)
self._write_cap('get_op_support_info', fd)
if self.op_fmt_sel:
self._write_cap('op_select_format', fd)
self._write_cap('get_op_specific_info', fd)
if self.op_range_limit == 'limited' or self.op_range_limit == 'dynamic':
self._write_glz(fd)
def _gen_op_compile_option(self: any, op_compile_option_all: list = None):
if op_compile_option_all is not None:
if self.op_type in op_compile_option_all:
self.op_compile_option = json.dumps(op_compile_option_all[self.op_type])
elif "__all__" in op_compile_option_all:
self.op_compile_option = json.dumps(op_compile_option_all["__all__"])
def _ip_argpack(self: any, default: bool = True) -> list:
args = []
for i in range(len(self.input_name)):
arg = self.input_name[i]
if default and self.argsdefv[i] is not None:
arg += '=' + self.argsdefv[i]
args.append(arg)
return args
def _op_argpack(self: any, default: bool = True) -> list:
args = []
argidx = len(self.input_name)
for i in range(len(self.output_name)):
arg = self.output_name[i]
if default and self.argsdefv[i + argidx] is not None:
arg += '=' + self.argsdefv[i + argidx]
args.append(arg)
return args
def _attr_argpack(self: any, default: bool = True) -> list:
args = []
argidx = len(self.input_name) + len(self.output_name)
for i in range(len(self.attr_list)):
att = self.attr_list[i]
arg = att
if default and self.argsdefv[i + argidx] is not None:
if self.attr_val.get(att).get('type') == 'str':
arg += '="' + self.argsdefv[i + argidx] + '"'
elif self.attr_val.get(att).get('type') == 'bool':
arg += '=' + self.argsdefv[i + argidx].capitalize()
elif self.attr_val.get(att).get('type') == 'list_bool':
arg += '=' + "[" + ", ".join(word.strip().capitalize() \
for word in self.argsdefv[i + argidx].strip('[]').split(',')) + "]"
else:
arg += '=' + self.argsdefv[i + argidx]
args.append(arg)
return args
def _build_paralist(self: any, default: bool = True) -> str:
args = []
args.extend(self._ip_argpack(default))
args.extend(self._op_argpack(default))
args.extend(self._attr_argpack(default))
return ', '.join(args)
def _io_parachk(self: any, types: list, type_name: str) -> list:
chk = []
for iot in types:
if iot == 'optional':
ptype = 'OPTION'
else:
ptype = iot.upper()
chk.append('para_check.{}_{}'.format(ptype, type_name))
return chk
def _attr_parachk(self: any) -> list:
chk = []
for att in self.attr_list:
att_type = self.attr_val.get(att).get('type').upper()
chk.append('para_check.{}_ATTR_{}'.format('OPTION', att_type))
return chk
def _build_parachk(self: any) -> str:
chk = []
chk.extend(self._io_parachk(self.input_type, 'INPUT'))
chk.extend(self._io_parachk(self.output_type, 'OUTPUT'))
chk.extend(self._attr_parachk())
chk.append('para_check.KERNEL_NAME')
return ', '.join(chk)
def _build_virtual(self: any) -> str:
virt_exp = []
for index in range(len(self.input_name)):
if self.input_virt.get(index) is None:
continue
val = []
val.append('"param_name":"{}"'.format(self.input_name[index]))
val.append('"index":{}'.format(index))
val.append('"dtype":"{}"'.format(self.input_dtype[index].split(',')[0]))
val.append('"format":"{}"'.format(self.input_fmt[index].split(',')[0]))
val.append('"ori_format":"{}"'.format(self.input_fmt[index].split(',')[0]))
val.append('"paramType":"optional"')
val.append('"shape":[1]')
val.append('"ori_shape":[1]')
virt_exp.append(' ' + self.input_name[index] + ' = {' + ','.join(val) + '}')
if len(virt_exp) > 0:
return '\n'.join(virt_exp)
else:
return ' # do ascendc build step'
def _build_mc2_ctx(self: any):
if len(self.mc2_ctx) != 0:
return '["' + '", "'.join(self.mc2_ctx) + '"]'
return '[]'
def _build_paradefault(self: any):
optional = False
argtypes = []
argtypes.extend(self.input_type)
argtypes.extend(self.output_type)
in_idx = 0
for atype in argtypes:
if atype == 'optional':
optional = True
if optional:
self.argsdefv.append('None')
else:
self.argsdefv.append(None)
in_idx += 1
for attr in self.attr_list:
atype = self.attr_val.get(attr).get('paramType')
if atype == 'optional':
optional = True
attrval = self.attr_val.get(attr).get('defaultValue')
if attrval is not None:
optional = True
if type == "bool":
attrval = attrval.capitalize()
elif type == "str":
attrval = "\"" + attrval + "\""
self.argsdefv.append(attrval)
continue
if optional:
self.argsdefv.append(ATTR_DEFAULT.get(self.attr_val.get(attr).get('type')))
else:
self.argsdefv.append(None)
def _write_head(self: any, fd: object):
now = datetime.datetime.now()
curr_year = now.year
former_year = curr_year - 1
fd.write(IMPL_HEAD.format(former_year, curr_year, self.input_ori_name, self.output_ori_name))
def _write_argparse(self: any, fd: object):
args = self._build_paralist(False)
fd.write('def _build_args({}):\n'.format(args))
fd.write(' __inputs__ = []\n')
fd.write(' for arg in [{}]:\n'.format(', '.join(self.input_name)))
fd.write(' if arg != None:\n')
fd.write(' if isinstance(arg, (list, tuple)):\n')
fd.write(' if len(arg) == 0:\n')
fd.write(' continue\n')
fd.write(' __inputs__.append(arg[0])\n')
fd.write(' else:\n')
fd.write(' __inputs__.append(arg)\n')
fd.write(' else:\n')
fd.write(' __inputs__.append(arg)\n')
fd.write(' __outputs__ = []\n')
fd.write(' for arg in [{}]:\n'.format(', '.join(self.output_name)))
fd.write(' if arg != None:\n')
fd.write(' if isinstance(arg, (list, tuple)):\n')
fd.write(' if len(arg) == 0:\n')
fd.write(' continue\n')
fd.write(' __outputs__.append(arg[0])\n')
fd.write(' else:\n')
fd.write(' __outputs__.append(arg)\n')
fd.write(' else:\n')
fd.write(' __outputs__.append(arg)\n')
fd.write(' __attrs__ = []\n')
for attr in self.attr_list:
fd.write(' if {} != None:\n'.format(attr))
fd.write(' attr = {}\n')
fd.write(' attr["name"] = "{}"\n'.format(attr))
fd.write(' attr["dtype"] = "{}"\n'.format(self.attr_val.get(attr).get('type')))
fd.write(' attr["value"] = {}\n'.format(attr))
fd.write(' __attrs__.append(attr)\n')
fd.write(' return __inputs__, __outputs__, __attrs__\n')
def _get_kernel_source(self: any, kernel_src_dir, src_file, dir_snake, dir_ex):
src_ex = os.path.join(kernel_src_dir, dir_ex, src_file)
if os.path.exists(src_ex):
return src_ex
src = os.environ.get('BUILD_KERNEL_SRC')
if src and os.path.exists(src):
return src
src = os.path.join(kernel_src_dir, dir_snake, src_file)
if os.path.exists(src):
return src
src = os.path.join(kernel_src_dir, src_file)
if os.path.exists(src):
return src
src = os.path.join(kernel_src_dir, dir_snake, dir_snake + ".cpp")
if os.path.exists(src):
return src
src = os.path.join(kernel_src_dir, dir_ex, dir_ex + ".cpp")
if os.path.exists(src):
return src
src = os.path.join(kernel_src_dir, os.path.splitext(src_file)[0], src_file)
if os.path.exists(src):
return src
return src_ex
def _write_impl(self: any, fd: object, impl_path: str = ""):
argsdef = self._build_paralist()
argsval = self._build_paralist(False)
pchk = self._build_parachk()
if len(self.kern_name) > 0:
kern_name = self.kern_name
else:
kern_name = self.op_intf
src = self.op_file + '.cpp'
virt_exprs = self._build_virtual()
fd.write(IMPL_API.format(self.op_type, pchk, self.op_intf, argsdef, kern_name, virt_exprs, argsval,\
self.custom_compile_options, self.custom_all_compile_options, self.op_intf,\
optype_snake_ex(self.op_type), optype_snake(self.op_type), src))
if self.op_replay_flag:
fd.write(REPLAY_OP_API.format(self.op_type, kern_name, self.op_file, self.op_type, self.op_file,\
self.param_type_dynamic, self.op_compile_option))
else:
if os.environ.get('BUILD_BUILTIN_OPP') == '1':
relative_kernel_src_path = os.path.realpath(self._get_kernel_source(impl_path, src,\
optype_snake(self.op_type), optype_snake_ex(self.op_type)))
abs_rel_kernel_src_path = os.path.join("/", os.path.relpath(relative_kernel_src_path, impl_path))
fd.write(COMPILE_OP_API_BUILT_IN.format(self.op_type, self.op_type, ', '.join(self.input_name),\
', '.join(self.output_name), self.param_type_dynamic, self._build_mc2_ctx(),\
self.input_type + self.output_type, self.output_init_value, self.output_shape_depend_on_compute,\
self.op_compile_option, abs_rel_kernel_src_path))
else:
fd.write(COMPILE_OP_API.format(self.op_type, self.op_type, ', '.join(self.input_name),\
', '.join(self.output_name), self.param_type_dynamic, self._build_mc2_ctx(),\
self.input_type + self.output_type, self.output_init_value, self.output_shape_depend_on_compute,\
self.op_compile_option))
def _write_cap(self: any, cap_name: str, fd: object):
argsdef = self._build_paralist()
argsval = self._build_paralist(False)
if cap_name == 'check_supported':
fd.write(SUP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type))
else:
fd.write(CAP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type))
def _write_glz(self: any, fd: object):
argsdef = self._build_paralist()
argsval = self._build_paralist(False)
fd.write(GLZ_API.format(self.op_type, self.op_intf, argsdef, argsval, self.op_type))
def write_scripts(cfgfile: str, cfgs: dict, dirs: dict, ops: list = None, op_compile_option:list = None):
batch_lists = cfgs.get(const_var.REPLAY_BATCH).split(';')
iterator_lists = cfgs.get(const_var.REPLAY_ITERATE).split(';')
file_map = {}
op_descs = opdesc_parser.get_op_desc(cfgfile, batch_lists, iterator_lists, AdpBuilder,\
ops, dirs.get(const_var.AUTO_GEN_DIR))
for op_desc in op_descs:
op_desc.write_adapt(dirs.get(const_var.CFG_IMPL_DIR), dirs.get(const_var.CFG_OUT_DIR), op_compile_option)
file_map[op_desc.op_type] = op_desc.op_file
return file_map
class OpFileNotExistsError(Exception):
"""File does not exist error."""
def __str__(self) -> str:
return f"File aic-*-ops-info.ini does not exist in directory {super().__str__()}"
def get_ops_info_files(opsinfo_dir: List[str]) -> List[str]:
"""Get all ops info files."""
ops_info_files = []
for _dir in opsinfo_dir:
ops_info_files.extend(glob.glob(f'{_dir}/aic-*-ops-info.ini'))
return sorted(ops_info_files)
def parse_args(argv):
"""Command line parameter parsing"""
parser = argparse.ArgumentParser()
parser.add_argument('argv', nargs='+')
parser.add_argument('--opsinfo-dir', nargs='*', default=None)
return parser.parse_args(argv)
if __name__ == '__main__':
args = parse_args(sys.argv)
if len(args.argv) <= 6:
raise RuntimeError('arguments must greater equal than 6')
rep_cfg = {}
rep_cfg[const_var.REPLAY_BATCH] = args.argv[2]
rep_cfg[const_var.REPLAY_ITERATE] = args.argv[3]
cfg_dir = {}
cfg_dir[const_var.CFG_IMPL_DIR] = args.argv[4]
cfg_dir[const_var.CFG_OUT_DIR] = args.argv[5]
cfg_dir[const_var.AUTO_GEN_DIR] = args.argv[6]
ops_infos = []
if args.opsinfo_dir:
ops_infos.extend(get_ops_info_files(args.opsinfo_dir))
if not ops_infos:
raise OpFileNotExistsError(args.opsinfo_dir)
else:
ops_infos.append(args.argv[1])
for ops_info in ops_infos:
write_scripts(cfgfile=ops_info, cfgs=rep_cfg, dirs=cfg_dir)