import copy
import ast
import operator
import re
from enum import Enum, auto
from dataclasses import dataclass
from typing import List, Any
from tbe.common.context import get_context
from .log_utils import AscendCLogLevel
from .ascendc_common_utility import CommonUtility
from .get_op_tiling import OpInfo
TILING_DECLARE_MAP = []
TILING_SELECT_MAP = []
ONE_BYTE_BIT = 8
UI_RANGE = 0
UI_LIST = 1
UI_MIX = 2
ASCENDC_TPL_INPUT_BIAS = 1e8
ASCENDC_TPL_OUTPUT_BIAS = 9e8
ASCENDC_TPL_DATATYPE_MAX = 42
ASCENDC_COMPILE_DATATYPE_MAP = {
"DT_FLOAT": ["float32", 0],
"DT_FLOAT16": ["float16", 1],
"DT_INT8": ["int8", 2],
"DT_INT32": ["int32", 3],
"DT_UINT8": ["uint8", 4],
"DT_INT16": ["int16", 6],
"DT_UINT16": ["uint16", 7],
"DT_UINT32": ["uint32", 8],
"DT_INT64": ["int64", 9],
"DT_UINT64": ["uint64", 10],
"DT_DOUBLE": ["double", 11],
"DT_BOOL": ["bool", 12],
"DT_COMPLEX64": ["complex64", 16],
"DT_BF16": ["bfloat16", 27],
"DT_INT4": ["int4", 29],
"DT_UINT1": ["uint1", 30],
"DT_INT2": ["int2", 31],
"DT_COMPLEX32": ["complex32", 33],
"DT_HIFLOAT8": ["hifloat8", 34],
"DT_FLOAT8_E5M2": ["float8_e5m2", 35],
"DT_FLOAT8_E4M3FN": ["float8_e4m3fn", 36],
"DT_FLOAT4_E2M1": ["float4_e2m1", 40],
"DT_FLOAT4_E1M2": ["float4_e1m2", 41],
"DT_MAX": ["unknown", ASCENDC_TPL_DATATYPE_MAX],
}
ASCENDC_COMPILE_DATATYPE_REVERT_MAP = {v[-1]: v[0] for v in ASCENDC_COMPILE_DATATYPE_MAP.values()}
ASCENDC_KERNEL_TEMPLATE_INPUT_DATATYPE_MAP = {
"float": 0,
"half": 1,
"int8_t": 2,
"int32_t": 3,
"uint8_t": 4,
"int16_t": 6,
"uint16_t": 7,
"uint32_t": 8,
"int64_t": 9,
"uint64_t": 10,
"double": 11,
"bool": 12,
"complex64": 16,
"bfloat16_t": 27,
"int4b_t": 29,
"hifloat8_t": 34,
"fp8_e5m2_t": 35,
"fp8_e4m3fn_t": 36,
"fp4x2_e2m1_t": 40,
"fp4x2_e1m2_t": 41
}
ASCENDC_KERNEL_TEMPLATE_INPUT_KERNEL_TYPE_MAP = {
"ASCENDC_TPL_AIV_ONLY": 0,
"ASCENDC_TPL_AIC_ONLY": 1,
"ASCENDC_TPL_MIX_AIV_1_0": 4,
"ASCENDC_TPL_MIX_AIC_1_0": 5,
"ASCENDC_TPL_MIX_AIC_1_1": 6,
"ASCENDC_TPL_MIX_AIC_1_2": 7,
"ASCENDC_TPL_AICORE": 8,
"ASCENDC_TPL_VECTORCORE": 9,
"ASCENDC_TPL_MIX_AICORE": 10,
"ASCENDC_TPL_MIX_VECTOR_CORE": 11,
"ASCENDC_TPL_MAX": 12
}
ASCENDC_TPL_DATAFORMAT_MAX = 55
ASCENDC_COMPILE_DATAFORMAT_MAP = {
"FORMAT_NCHW": ["NCHW", 0],
"FORMAT_NHWC": ["NHWC", 1],
"FORMAT_ND": ["ND", 2],
"FORMAT_NC1HWC0": ["NC1HWC0", 3],
"FORMAT_FRACTAL_Z": ["FRACTAL_Z", 4],
"FORMAT_NC1C0HWPAD": ["NC1C0HWPAD", 5],
"FORMAT_NHWC1C0": ["NHWC1C0", 6],
"FORMAT_FSR_NCHW": ["FSR_NCHW", 7],
"FORMAT_FRACTAL_DECONV": ["FRACTAL_DECONV", 8],
"FORMAT_C1HWNC0": ["C1HWNC0", 9],
"FORMAT_FRACTAL_DECONV_TRANSPOSE": ["FRACTAL_DECONV_TRANSPOSE", 10],
"FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS": ["FRACTAL_DECONV_SP_STRIDE_TRANS", 11],
"FORMAT_NC1HWC0_C04": ["NC1HWC0_C04", 12],
"FORMAT_FRACTAL_Z_C04": ["FRACTAL_Z_C04", 13],
"FORMAT_CHWN": ["CHWN", 14],
"FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS": ["FRACTAL_DECONV_SP_STRIDE8_TRANS", 15],
"FORMAT_HWCN": ["HWCN", 16],
"FORMAT_NC1KHKWHWC0": ["NC1KHKWHWC0", 17],
"FORMAT_BN_WEIGHT": ["BN_WEIGHT", 18],
"FORMAT_FILTER_HWCK": ["FILTER_HWCK", 19],
"FORMAT_HASHTABLE_LOOKUP_LOOKUPS": ["HASHTABLE_LOOKUP_LOOKUPS", 20],
"FORMAT_HASHTABLE_LOOKUP_KEYS": ["HASHTABLE_LOOKUP_KEYS", 21],
"FORMAT_HASHTABLE_LOOKUP_VALUE": ["HASHTABLE_LOOKUP_VALUE", 22],
"FORMAT_HASHTABLE_LOOKUP_OUTPUT": ["HASHTABLE_LOOKUP_OUTPUT", 23],
"FORMAT_HASHTABLE_LOOKUP_HITS": ["HASHTABLE_LOOKUP_HITS", 24],
"FORMAT_C1HWNCoC0": ["C1HWNCoC0", 25],
"FORMAT_MD": ["MD", 26],
"FORMAT_NDHWC": ["NDHWC", 27],
"FORMAT_FRACTAL_ZZ": ["FRACTAL_ZZ", 28],
"FORMAT_FRACTAL_NZ": ["FRACTAL_NZ", 29],
"FORMAT_NCDHW": ["NCDHW", 30],
"FORMAT_DHWCN": ["DHWCN", 31],
"FORMAT_NDC1HWC0": ["NDC1HWC0", 32],
"FORMAT_FRACTAL_Z_3D": ["FRACTAL_Z_3D", 33],
"FORMAT_CN": ["CN", 34],
"FORMAT_NC": ["NC", 35],
"FORMAT_DHWNC": ["DHWNC", 36],
"FORMAT_FRACTAL_Z_3D_TRANSPOSE": ["FRACTAL_Z_3D_TRANSPOSE", 37],
"FORMAT_FRACTAL_ZN_LSTM": ["FRACTAL_ZN_LSTM", 38],
"FORMAT_FRACTAL_Z_G": ["FRACTAL_Z_G", 39],
"FORMAT_RESERVED": ["RESERVED", 40],
"FORMAT_ALL": ["ALL", 41],
"FORMAT_NULL": ["NULL", 42],
"FORMAT_ND_RNN_BIAS": ["ND_RNN_BIAS", 43],
"FORMAT_FRACTAL_ZN_RNN": ["FRACTAL_ZN_RNN", 44],
"FORMAT_NYUV": ["NYUV", 45],
"FORMAT_NYUV_A": ["NYUV_A", 46],
"FORMAT_NCL": ["NCL", 47],
"FORMAT_FRACTAL_Z_WINO": ["FRACTAL_Z_WINO", 48],
"FORMAT_C1HWC0": ["C1HWC0", 49],
"FORMAT_FRACTAL_NZ_C0_16": ["FRACTAL_NZ_C0_16", 50],
"FORMAT_FRACTAL_NZ_C0_32": ["FRACTAL_NZ_C0_32", 51],
"FORMAT_FRACTAL_NZ_C0_2": ["FRACTAL_NZ_C0_2", 52],
"FORMAT_FRACTAL_NZ_C0_4": ["FRACTAL_NZ_C0_4", 53],
"FORMAT_FRACTAL_NZ_C0_8": ["FRACTAL_NZ_C0_8", 54],
"FORMAT_MAX": ["unknown", ASCENDC_TPL_DATAFORMAT_MAX],
}
ASCENDC_COMPILE_DATAFORMAT_REVERT_MAP = {v[-1]: v[0] for v in ASCENDC_COMPILE_DATAFORMAT_MAP.values()}
SAFE_EXPR_BIN_OPS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.FloorDiv: operator.floordiv,
ast.Div: operator.floordiv,
ast.Mod: operator.mod,
ast.LShift: operator.lshift,
ast.RShift: operator.rshift,
ast.BitOr: operator.or_,
ast.BitAnd: operator.and_,
ast.BitXor: operator.xor,
}
SAFE_EXPR_UNARY_OPS = {
ast.UAdd: operator.pos,
ast.USub: operator.neg,
ast.Invert: operator.invert,
}
def safe_eval_int_expr(expr):
def _eval(node):
if isinstance(node, ast.Expression):
return _eval(node.body)
if isinstance(node, ast.Constant) and isinstance(node.value, int):
return node.value
if isinstance(node, ast.BinOp) and type(node.op) in SAFE_EXPR_BIN_OPS:
return SAFE_EXPR_BIN_OPS[type(node.op)](_eval(node.left), _eval(node.right))
if isinstance(node, ast.UnaryOp) and type(node.op) in SAFE_EXPR_UNARY_OPS:
return SAFE_EXPR_UNARY_OPS[type(node.op)](_eval(node.operand))
raise ValueError("unsupported integer expression")
return _eval(ast.parse(expr, mode="eval"))
def extract_num(s):
numbers = re.findall(r'\d+', s)
return [int(num) for num in numbers]
def extract_str(s):
strs = re.findall(r'\w+', s)
return strs
def safe_parse_value(s):
try:
return int(s)
except Exception:
pass
try:
result = int(safe_eval_int_expr(s))
return result
except Exception:
pass
try:
if "DT_" in s:
if s[:5] != "C_DT_":
CommonUtility.print_compile_log(
"",
f"datatype only support C_DT_*, but now is [{s}]",
AscendCLogLevel.LOG_ERROR,
)
datatype = s.split("::")[-1]
if datatype[:2] == "C_":
datatype = datatype[2:]
result = ASCENDC_COMPILE_DATATYPE_MAP[datatype][-1]
return result
except Exception:
pass
try:
if "FORMAT_" in s:
if s[:9] != "C_FORMAT_":
CommonUtility.print_compile_log(
"",
f"format only support C_FORMAT_*, but now is [{s}]",
AscendCLogLevel.LOG_ERROR,
)
dataformat = s.split("::")[-1]
if dataformat[:2] == "C_":
dataformat = dataformat[2:]
result = ASCENDC_COMPILE_DATAFORMAT_MAP[dataformat][-1]
return result
except Exception:
pass
CommonUtility.print_compile_log(
"",
"can't convert str : [{}], it is invalid value.".format(s),
AscendCLogLevel.LOG_ERROR,
)
return -1
def extract_expr(s):
match = re.search(r"{([^}]*)}", s)
content = match.group(1).split(",")
content = [x.strip() for x in content if x.strip()]
results = [safe_parse_value(item) for item in content if safe_parse_value(item) >= 0]
return results
class TilingParamType(Enum):
TPL_DTYPE = auto()
TPL_DATATYPE = auto()
TPL_FORMAT = auto()
TPL_DATAFORMAT = auto()
TPL_UINT = auto()
TPL_BOOL = auto()
TPL_TILING_STRUCT = auto()
TPL_KERNEL_TYPE = auto()
TPL_DETERMINISTIC = auto()
TPL_SHARED_KERNEL_TYPE = auto()
TPL_NONE = auto()
@dataclass
class TilingTemplateParams:
name: str
param_type: TilingParamType
values: List[Any]
bit_width: int = 0
macro_type: str = ""
def __post_init__(self):
if self.param_type == TilingParamType.TPL_UINT:
if len(self.values) < 2:
raise RuntimeError("Length of ASCENDC_TPL_UINT_{} {} is too short.".format(self.macro_type, self.name))
if self.bit_width == 0:
self.bit_width = self.values[0]
uint_declare_flag = self.values[1]
vals = self.values[2:]
else:
uint_declare_flag = self.values[0]
vals = self.values[1:]
if uint_declare_flag == UI_LIST:
self.values = vals
self.__check_valid()
return
if uint_declare_flag != UI_RANGE and uint_declare_flag != UI_MIX:
raise RuntimeError("Cannot find flag in ASCENDC_TPL_UINT_{} {}! Value should be in"
" [ASCENDC_TPL_UI_RANGE, ASCENDC_TPL_UI_LIST, ASCENDC_TPL_UI_MIX].".format(self.macro_type, self.name))
range_num = vals[0]
extend_num = []
for i in range(range_num):
if 1 + i * 2 >= len(vals) - 1:
raise RuntimeError("UI_RANGE declare parse failed, "
"because the announced length of ASCENDC_TPL_UINT_{} {}"
" is greater than actual values' length.".format(self.macro_type, self.name))
extend_num.extend(range(vals[1 + 2 * i], vals[2 + i * 2] + 1))
if uint_declare_flag == UI_MIX:
extend_num.extend(vals[1 + range_num * 2:])
self.values = extend_num
self.__check_valid()
def __check_valid(self):
if not self.values:
raise RuntimeError("values of ASCENDC_TPL_{}_{} {} is empty!".format(self.get_param_type_str(),
self.macro_type, self.name))
if len(set(self.values)) != len(self.values):
raise RuntimeError("There is duplicated number in ASCENDC_TPL_{}_{} {}!"
" Duplicated List: {}.".format(self.macro_type, self.get_param_type_str(), self.name, self.values))
if self.param_type == TilingParamType.TPL_BOOL:
if not set(self.values).issubset({0, 1}):
raise RuntimeError("There is invalid number in ASCENDC_TPL_BOOL_{} {}!"
" Value should only be in [0, 1].".format(self.macro_type, self.name))
elif self.param_type == TilingParamType.TPL_DETERMINISTIC:
if not set(self.values).issubset({'true', 'false'}):
raise RuntimeError("There is invalid number in ASCENDC_TPL_DETERMINISTIC_SEL!"
"Value should only be in [0, 1, true, false].")
else:
self.check_bit_width_valid()
def check_bit_width_valid(self):
bit_width_check_mode = "value" if self.param_type != TilingParamType.TPL_UINT else "index"
if self.bit_width <= 0:
raise RuntimeError("Bit width in ASCENDC_TPL_{}_{} {}"
" cannot be less than zero!".format(self.get_param_type_str(), self.macro_type, self.name))
max_encode_num = 2 ** self.bit_width
if bit_width_check_mode == "value":
if max_encode_num < max(list(filter(lambda x: x < ASCENDC_TPL_INPUT_BIAS, self.values))):
raise RuntimeError("Bit width:{} in ASCENDC_TPL_{}_{} {} is not enough to represent all values: {}!"
" Please make sure 2^bitWidth is greater than or equal to the [max value of values]."
.format(self.bit_width, self.get_param_type_str(), self.macro_type, self.name, self.values))
elif bit_width_check_mode == "index":
if max_encode_num < len(list(filter(lambda x: x < ASCENDC_TPL_INPUT_BIAS, self.values))):
raise RuntimeError("Bit width:{} in ASCENDC_TPL_{}_{} {} is not enough to represent all values: {}!"
" Please make sure 2^bitWidth is greater than or equal to the [number of values]."
.format(self.bit_width, self.get_param_type_str(), self.macro_type, self.name, self.values))
def get_encodes(self):
encodes = dict()
if self.param_type == TilingParamType.TPL_UINT:
for i, value in enumerate(self.values):
encodes[bin(i)[2:].rjust(self.bit_width, "0")] = value
elif (
self.param_type == TilingParamType.TPL_DTYPE or
self.param_type == TilingParamType.TPL_FORMAT
):
for value in self.values:
encodes[bin(value)[2:].rjust(self.bit_width, "0")] = value
elif (self.param_type == TilingParamType.TPL_DATATYPE):
for value in self.values:
if value >= ASCENDC_TPL_DATATYPE_MAX:
continue
encodes[bin(value)[2:].rjust(self.bit_width, "0")] = value
elif (self.param_type == TilingParamType.TPL_DATAFORMAT):
for value in self.values:
if value >= ASCENDC_TPL_DATAFORMAT_MAX:
continue
encodes[bin(value)[2:].rjust(self.bit_width, "0")] = value
elif self.param_type == TilingParamType.TPL_SHARED_KERNEL_TYPE:
for value in self.values:
encodes[bin(value)[2:].rjust(self.bit_width, '0')] = value
else:
encodes['0'] = 0
encodes['1'] = 1
return encodes
def get_param_type_str(self) -> str:
name_dict = {
TilingParamType.TPL_DTYPE: "DTYPE",
TilingParamType.TPL_DATATYPE: "DATATYPE",
TilingParamType.TPL_BOOL: "BOOL",
TilingParamType.TPL_FORMAT: "FORMAT",
TilingParamType.TPL_DATAFORMAT: "DATAFORMAT",
TilingParamType.TPL_UINT: "UINT",
TilingParamType.TPL_DETERMINISTIC: "DETERMINISTIC",
TilingParamType.TPL_NONE: "",
TilingParamType.TPL_SHARED_KERNEL_TYPE: "SHARED_KERNEL_TYPE",
}
return name_dict.get(self.param_type, "")
def extract_template_tiling_info(tiling_declare_str: str, tiling_select_str: str) -> None:
global TILING_DECLARE_MAP
global TILING_SELECT_MAP
declare_param_list = tiling_declare_str.split("@@")
select_param_list = tiling_select_str.split("@@")
TILING_DECLARE_MAP = extract_template_tiling_params(declare_param_list)[0]
bit_map = {}
for declare_param in TILING_DECLARE_MAP:
bit_map[declare_param.name] = declare_param.bit_width
TILING_SELECT_MAP = extract_template_tiling_params(select_param_list, bit_map)
name_order = [declare_param.name for declare_param in TILING_DECLARE_MAP]
for tiling_param_list in TILING_SELECT_MAP:
check_valid_select_param(tiling_param_list, name_order)
for select_params in TILING_SELECT_MAP:
select_params.sort(key=lambda x: name_order.index(x.name) if x.name in name_order else -1)
def check_valid_select_param(tiling_param_list: List[TilingTemplateParams], name_order: List[str]) -> None:
tiling_param_name_list = [tiling_param.name for tiling_param in tiling_param_list if
tiling_param.param_type != TilingParamType.TPL_TILING_STRUCT]
for tiling_declare_param in TILING_DECLARE_MAP:
if tiling_declare_param.name not in tiling_param_name_list:
raise RuntimeError("There is missing marco define: {} in ASCENDC_TPL_{}_SEL.".format(
tiling_declare_param.name, tiling_declare_param.get_param_type_str()))
total_bit_length = 0
has_shared_kernel = any(p.param_type == TilingParamType.TPL_SHARED_KERNEL_TYPE for p in tiling_param_list)
has_kernel = any(p.param_type == TilingParamType.TPL_KERNEL_TYPE for p in tiling_param_list)
if has_shared_kernel and has_kernel:
raise RuntimeError("There exists both TPL_KERNEL_TYPE and SHARED_KERNEL_TYPE")
for tiling_param in tiling_param_list:
if tiling_param.param_type == TilingParamType.TPL_TILING_STRUCT or \
tiling_param.param_type == TilingParamType.TPL_KERNEL_TYPE or \
tiling_param.param_type == TilingParamType.TPL_DETERMINISTIC:
continue
name = tiling_param.name
total_bit_length += tiling_param.bit_width
if total_bit_length > 64:
raise RuntimeError(f"name:{tiling_param.name}, type:{tiling_param.get_param_type_str()}, \
Total bit width cannot be greater than 64!")
tpl_type = tiling_param.param_type
vals = tiling_param.values
matched_param = None
for param in TILING_DECLARE_MAP:
if param.name == name:
matched_param = param
break
if matched_param is None:
raise RuntimeError(
"Cannot find ASCENDC_TPL_{}_SEL name: {} in ASCENDC_TPL_{}_DECL.".format(
tiling_param.get_param_type_str(), name, tiling_param.get_param_type_str()))
if tpl_type != matched_param.param_type:
if (
tpl_type == TilingParamType.TPL_FORMAT
and matched_param.param_type == TilingParamType.TPL_DATAFORMAT
):
raise RuntimeError(
"[Error] Not support both used custom FORMAT and native FORMAT, "
"please check {} in ASCENDC_TPL_FORMAT_DECL.".format(name, tiling_param.get_param_type_str())
)
else:
raise RuntimeError(
"{} has different type in ASCENDC_TPL_{}_SEL and ASCENDC_TPL_{}_DECL.".format(
name, tiling_param.get_param_type_str(), matched_param.get_param_type_str()
)
)
if not set(vals).issubset(set(matched_param.values)):
raise RuntimeError(
"Cannot find value {} from ASCENDC_TPL_{}_SEL {} in ASCENDC_TPL_{}_DECL,"
" please check your macro define.".format(set(vals) - set(matched_param.values),
tiling_param.get_param_type_str(),
name, tiling_param.get_param_type_str()))
def remove_prefix(text, prefix):
if text.startswith(prefix):
return text[len(prefix):]
return text
def build_format_class(tiling_param_str: str, name: str, macro_type: str):
sparse_values_old = extract_num(tiling_param_str)
sparse_values_new = extract_expr(tiling_param_str)
if sparse_values_old == sparse_values_new:
return TilingTemplateParams(name, TilingParamType.TPL_FORMAT, sparse_values_old, 8, macro_type)
else:
return TilingTemplateParams(name, TilingParamType.TPL_DATAFORMAT, sparse_values_new, 8, macro_type)
def extract_template_tiling_params(tiling_param_list: List[str], bit_map: dict = None) -> list:
res = []
tiling_param = []
for i, _ in enumerate(tiling_param_list):
sub_str = tiling_param_list[i]
if not sub_str:
continue
if sub_str == '{' and tiling_param:
res.append(tiling_param)
tiling_param = []
macro_type = (
"SEL"
if sub_str.startswith("ASCENDC_TPL_DTYPE_SEL_")
or sub_str.startswith("ASCENDC_TPL_DATATYPE_SEL_")
or sub_str.startswith("ASCENDC_TPL_FORMAT_SEL_")
or sub_str.startswith("ASCENDC_TPL_DATAFORMAT_SEL_")
or sub_str.startswith("ASCENDC_TPL_UINT_SEL_")
or sub_str.startswith("ASCENDC_TPL_BOOL_SEL_")
or sub_str.startswith("ASCENDC_TPL_TILING_STRUCT_SEL_")
or sub_str.startswith("ASCENDC_TPL_KERNEL_TYPE_SEL")
or sub_str.startswith("ASCENDC_TPL_DETERMINISTIC_SEL")
or sub_str.startswith("ASCENDC_TPL_SHARED_KERNEL_TYPE_SEL")
else "DECL"
)
if sub_str.startswith("ASCENDC_TPL_DTYPE"):
dtype_list = extract_num(tiling_param_list[i + 1])
name = remove_prefix(sub_str, 'ASCENDC_TPL_DTYPE_{}_'.format(macro_type))
tiling_param.append(TilingTemplateParams(
name, TilingParamType.TPL_DTYPE, dtype_list, 8, macro_type))
elif sub_str.startswith("ASCENDC_TPL_DATATYPE"):
dtype_list = extract_expr(tiling_param_list[i + 1])
name = remove_prefix(sub_str, 'ASCENDC_TPL_DATATYPE_{}_'.format(macro_type))
tiling_param.append(TilingTemplateParams(
name, TilingParamType.TPL_DATATYPE, dtype_list, 8, macro_type))
elif sub_str.startswith("ASCENDC_TPL_FORMAT"):
name = remove_prefix(sub_str, 'ASCENDC_TPL_FORMAT_{}_'.format(macro_type))
tiling_param.append(build_format_class(tiling_param_list[i + 1], name, macro_type))
elif sub_str.startswith("ASCENDC_TPL_KERNEL_TYPE_SEL"):
format_list = extract_num(tiling_param_list[i + 1])
name = "KERNEL_TYPE"
tiling_param.append(TilingTemplateParams(
name, TilingParamType.TPL_KERNEL_TYPE, format_list, 8, macro_type))
elif sub_str.startswith("ASCENDC_TPL_DETERMINISTIC_SEL"):
cur_deter_flag = extract_str(tiling_param_list[i + 1])
if len(cur_deter_flag) != 1:
raise RuntimeError("ASCENDC_TPL_DETERMINISTIC_SEL can only one value can be specified!")
if cur_deter_flag[0] == "1":
cur_deter_flag[0] = "true"
elif cur_deter_flag[0] == "0":
cur_deter_flag[0] = "false"
tiling_param.append(TilingTemplateParams("DETERMINISTIC", \
TilingParamType.TPL_DETERMINISTIC, cur_deter_flag, 1, macro_type))
elif sub_str.startswith("ASCENDC_TPL_UINT"):
uint_list = extract_num(tiling_param_list[i + 1])
name = remove_prefix(sub_str, 'ASCENDC_TPL_UINT_{}_'.format(macro_type))
if macro_type == "SEL":
tiling_param.append(TilingTemplateParams(name,
TilingParamType.TPL_UINT, uint_list, bit_map[name], macro_type))
else:
tiling_param.append(TilingTemplateParams(name, TilingParamType.TPL_UINT, uint_list, 0, macro_type))
elif "ASCENDC_TPL_BOOL" in sub_str:
bool_list = extract_num(tiling_param_list[i + 1])
name = remove_prefix(sub_str, 'ASCENDC_TPL_BOOL_{}_'.format(macro_type))
tiling_param.append(TilingTemplateParams(name, TilingParamType.TPL_BOOL, bool_list, 1, macro_type))
elif "ASCENDC_TPL_SHARED_KERNEL_TYPE" in sub_str:
shared_kernel_type_list = extract_num(tiling_param_list[i + 1])
name = remove_prefix(sub_str, 'ASCENDC_TPL_SHARED_KERNEL_TYPE_{}_'.format(macro_type))
tiling_param.append(TilingTemplateParams(name, TilingParamType.TPL_SHARED_KERNEL_TYPE, \
shared_kernel_type_list, 8, macro_type))
elif "ASCENDC_TPL_KERNEL_TYPE" in sub_str:
if macro_type == "DECL":
shared_kernel_type_list = extract_num(tiling_param_list[i + 1])
name = remove_prefix(sub_str, 'ASCENDC_TPL_KERNEL_TYPE_{}_'.format(macro_type))
tiling_param.append(TilingTemplateParams(name, TilingParamType.TPL_SHARED_KERNEL_TYPE, \
shared_kernel_type_list, 8, macro_type))
elif "ASCENDC_TPL_TILING_STRUCT" in sub_str:
name = remove_prefix(sub_str, 'ASCENDC_TPL_TILING_STRUCT_{}_'.format(macro_type))
tiling_param.append(TilingTemplateParams(name, TilingParamType.TPL_TILING_STRUCT, [-1], 1))
if tiling_param:
res.append(tiling_param)
return res
GROUP_ID = 0
ID_LIST = []
group_map = {}
kernel_compile_dict = {}
def get_concated_tiling_key(template_param_list: List[TilingTemplateParams], \
index: int, tiling_key: str, tiling_args: List[int], encode_book: dict, data: dict, kernel_index: int) -> dict:
global ID_LIST
global GROUP_ID
global group_map
global kernel_compile_dict
result = dict()
if index == len(template_param_list):
data.update({"paramArgs": tiling_args})
if kernel_index != -1:
group_key = tuple(tiling_args[: kernel_index] + tiling_args[kernel_index + 1:])
if group_key in group_map:
group_id = group_map[group_key]
else:
GROUP_ID += 1
group_map[group_key] = GROUP_ID
group_id = GROUP_ID
data["groupId"] = group_id
ID_LIST.append(group_id)
result[int(tiling_key, 2)] = copy.copy(data)
return result
template_param = template_param_list[index]
for val in template_param.values:
if kernel_compile_dict and template_param.name in kernel_compile_dict \
and str(val) not in kernel_compile_dict[template_param.name]:
continue
if template_param.param_type == TilingParamType.TPL_TILING_STRUCT:
data["tilingStruct"] = template_param.name
encode_ = ""
tiling_args_temp = tiling_args
elif template_param.param_type == TilingParamType.TPL_KERNEL_TYPE:
data["kernelType"] = template_param.values[0]
encode_ = ""
tiling_args_temp = tiling_args
elif template_param.param_type == TilingParamType.TPL_DETERMINISTIC:
data['deterministic'] = template_param.values[0]
encode_ = ""
tiling_args_temp = tiling_args
elif template_param.param_type == TilingParamType.TPL_DATATYPE:
data["dtypeParams"].append(ASCENDC_COMPILE_DATATYPE_REVERT_MAP.get(val, "unknown"))
encode_ = encode_book[template_param.name][val]
tiling_args_temp = tiling_args + ["TypeFromId<{}>::type".format(str(val))]
elif template_param.param_type == TilingParamType.TPL_DATAFORMAT:
data["formatParams"].append(ASCENDC_COMPILE_DATAFORMAT_REVERT_MAP.get(val, "unknown"))
encode_ = encode_book[template_param.name][val]
tiling_args_temp = tiling_args + [str(val)]
else:
if template_param.param_type == TilingParamType.TPL_SHARED_KERNEL_TYPE:
data["kernelType"] = val
encode_ = encode_book[template_param.name][val]
tiling_args_temp = tiling_args + [str(val)]
result.update(
get_concated_tiling_key(
template_param_list, index + 1, encode_ + tiling_key, tiling_args_temp, encode_book, data, kernel_index
)
)
return result
def extract_decl_param_options(op_info: OpInfo, option_name="dtype"):
if op_info is None:
return [], []
support_option_dict = {
"dtype": TilingParamType.TPL_DATATYPE,
"format": TilingParamType.TPL_DATAFORMAT,
}
decl_input_indexes = []
decl_output_indexes = []
deck_select_indexes = []
checked_param_type = support_option_dict[option_name]
for declare_param in TILING_DECLARE_MAP:
if declare_param.param_type != checked_param_type:
continue
deck_select_indexes.append(False)
for val in declare_param.values:
if val >= ASCENDC_TPL_INPUT_BIAS and val < ASCENDC_TPL_OUTPUT_BIAS:
decl_input_indexes.append(int(val % ASCENDC_TPL_INPUT_BIAS))
deck_select_indexes[-1] = True
break
elif val >= ASCENDC_TPL_OUTPUT_BIAS:
decl_output_indexes.append(int(val % ASCENDC_TPL_OUTPUT_BIAS))
deck_select_indexes[-1] = True
break
def _check_indexes(indexes: List[int], value_list, desc=""):
if len(indexes) != 0 and (max(indexes) >= len(value_list) or min(indexes) < 0):
raise ValueError(
f"exist invalid {desc} index : {indexes}, please check it."
)
if len(indexes) != len(set(indexes)):
raise ValueError(f"exist duplicate {desc} index : {indexes}")
_check_indexes(decl_input_indexes, op_info.inputs, "input")
_check_indexes(decl_output_indexes, op_info.outputs, "output")
decl_input_options = [
op_info.inputs[idx][option_name] for idx in decl_input_indexes
]
decl_output_options = [
op_info.outputs[idx][option_name] for idx in decl_output_indexes
]
decl_output_indexes = [int(x + len(op_info.inputs)) for x in decl_output_indexes]
return decl_input_options + decl_output_options, deck_select_indexes
def check_kernel_template_input(kernel_compile_dict_input: dict):
tiling_map_dict = {item.name: item for item in TILING_DECLARE_MAP}
for key, values in kernel_compile_dict_input.items():
if key not in tiling_map_dict:
raise RuntimeError('kernel-template-input key {} is invalid!'.format(key))
valid_values = tiling_map_dict[key].values
valid_values_str = set(str(v) for v in valid_values)
for value in values:
if value not in valid_values_str:
raise RuntimeError('kernel-template-input key: {} value: {} is invalid!'.format(key, value))
def decode_tiling(tiling_key: int = None) -> dict:
global kernel_compile_dict
kernel_template_input = get_context().get_addition("kernel-template-input")
if kernel_template_input and "=" not in kernel_template_input:
raise RuntimeError('Invalid kernel-template-input format! Please use key value pairs!')
if kernel_template_input:
if kernel_template_input.startswith("'") and kernel_template_input.endswith("'"):
kernel_template_input = kernel_template_input[1:-1]
pairs = kernel_template_input.split(';')
for pair in pairs:
if pair.count("=") > 1:
raise RuntimeError('kernel-template-input value has invalid format!')
if '=' in pair:
key, value = pair.split('=', 1)
value_list = value.split(',')
format_value_list = []
for v in value_list:
if v in ASCENDC_KERNEL_TEMPLATE_INPUT_DATATYPE_MAP.keys():
format_value_list.append(str(ASCENDC_KERNEL_TEMPLATE_INPUT_DATATYPE_MAP[v]))
elif v == 'false':
format_value_list.append('0')
elif v == 'true':
format_value_list.append('1')
elif v in ASCENDC_KERNEL_TEMPLATE_INPUT_KERNEL_TYPE_MAP.keys():
format_value_list.append(str(ASCENDC_KERNEL_TEMPLATE_INPUT_KERNEL_TYPE_MAP[v]))
else:
format_value_list.append(v)
kernel_compile_dict[key] = format_value_list
check_kernel_template_input(kernel_compile_dict)
encode_book = dict()
for param in TILING_DECLARE_MAP:
encode_book[param.name] = param.get_encodes()
reversed_encode_book = {}
CommonUtility.print_compile_log("", "[Template Tiling] Encode Book is {}.".format(encode_book), \
AscendCLogLevel.LOG_INFO)
for k, v in encode_book.items():
reversed_book = dict()
for bin_encodes, actual_val in v.items():
reversed_book[actual_val] = bin_encodes
reversed_encode_book[k] = reversed_book
decode_results = dict()
for tiling_template_select_param in TILING_SELECT_MAP:
kernel_index = next((i for i, param in enumerate(tiling_template_select_param) if \
param.param_type == TilingParamType.TPL_SHARED_KERNEL_TYPE), -1)
decode_result = get_concated_tiling_key(
tiling_template_select_param, 0, "", [], reversed_encode_book, \
{"dtypeParams": [], "formatParams": []}, kernel_index
)
length_before = len(decode_result) + len(decode_results)
decode_results.update(decode_result)
if length_before != len(decode_results):
raise RuntimeError("ASCENDC_TPL_SELECT has duplicated definitions!")
if tiling_key is None:
return decode_results
if tiling_key not in decode_results.keys():
CommonUtility.print_compile_log("", "Cannot find any matched template tiling keys.", \
AscendCLogLevel.LOG_WARNING)
return {}
else:
return {tiling_key: decode_results[tiling_key]}