"""
"""
import json
from enum import Enum
# For each reuse/merge kind, the pair of op_attr keys that hold its
# (hash order, subgraph count). The iteration order of this dict determines the
# order of the tuple built by _extract_hash_order_info_from_op.
_HASH_ORDER_KEYS = {
    "l1_reuse": ("l1_reuse_hashOrder", "l1_reuse_subgraphCount"),
    "cube_merge": ("cube_merge_hashOrder", "cube_merge_subgraphCount"),
    "vec_merge": ("vec_merge_hashOrder", "vec_merge_subgraphCount"),
}
def get_sematic(func_index, opmagic, func_data):
    """Return the semantic label of the operation identified by ``opmagic``.

    Args:
        func_index: Index of the function in func_data
        func_data: List of functions, each with an "operations" list
        opmagic: Value compared against each operation's "opmagic" entry

    Returns:
        str: ``semantic_label["label"]`` of the first matching operation, or ""
        when func_index is out of range, no operation matches, or the matching
        operation has no "semantic_label" entry
    """
    if func_index >= len(func_data):
        return ""
    # Only the first operation carrying this opmagic is considered.
    matching_op = next(
        (op for op in func_data[func_index]["operations"] if op["opmagic"] == opmagic),
        None,
    )
    if matching_op is None or "semantic_label" not in matching_op:
        return ""
    return matching_op["semantic_label"]["label"]
def _extract_hash_order_info_from_op(op):
    """Collect the hash-order info stored in one operation's ``op_attr``.

    Args:
        op: Operation dict; its optional "op_attr" dict is inspected

    Returns:
        tuple: One entry per item of _HASH_ORDER_KEYS, in that dict's order.
        Each entry is {"hashOrder": ..., "subgraphCount": ...} when both
        attributes are present and not None, otherwise None
    """
    attrs = op.get("op_attr", {})

    def _info_for(hash_key, count_key):
        # An entry is reported only when both halves of the pair are set.
        order = attrs.get(hash_key)
        count = attrs.get(count_key)
        if order is None or count is None:
            return None
        return {"hashOrder": order, "subgraphCount": count}

    return tuple(_info_for(*key_pair) for key_pair in _HASH_ORDER_KEYS.values())
def get_hash_order_info(func_index, func_data):
    """Get HashOrderInfo from program.json by func_index.

    Scans the function's operations in order and returns the info of the first
    operation whose op_attr yields at least one non-None entry.
    Format: op_attr contains "l1_reuse_hashOrder", "l1_reuse_subgraphCount", etc.

    Args:
        func_index: Index of the function in func_data (should be leaf function for hashOrder)
        func_data: List of functions from program.json

    Returns:
        tuple: (l1ReuseInfo, cubeMergeInfo, vecMergeInfo)
        Each item is {"hashOrder": ..., "subgraphCount": ...} or None if not set.
        (None, None, None) when func_index is out of range or no operation
        carries any info
    """
    nothing_found = (None, None, None)
    if func_index >= len(func_data):
        return nothing_found
    operations = func_data[func_index]["operations"]
    # Lazy: extraction stops at the first operation that has any info.
    per_op_infos = (_extract_hash_order_info_from_op(op) for op in operations)
    return next(
        (tuple(infos) for infos in per_op_infos if any(infos)),
        nothing_found,
    )
def get_hash_order_info_by_opmagic(func_index, opmagic, func_data):
    """Get HashOrderInfo from a specific operation in program.json.

    Args:
        func_index: Index of the function in func_data
        opmagic: Value compared against each operation's "opmagic" entry
        func_data: List of functions from program.json

    Returns:
        tuple: Info extracted from the first operation whose "opmagic" equals
        opmagic, or (None, None, None) when func_index is out of range or no
        operation matches
    """
    if func_index < len(func_data):
        for candidate in func_data[func_index]["operations"]:
            if candidate.get("opmagic") == opmagic:
                return _extract_hash_order_info_from_op(candidate)
    return None, None, None
def get_hash_order_info_for_task(root_index, opmagic, leaf_index, func_data):
    """Get task HashOrderInfo, preferring root CALL op metadata over shared leaf metadata.

    Args:
        root_index: Index of the root function in func_data
        opmagic: opmagic of the operation to look up in the root function
        leaf_index: Index of the leaf function used as fallback
        func_data: List of functions from program.json

    Returns:
        tuple: The info found on the root operation when any entry is set,
        otherwise the result of get_hash_order_info for the leaf function
    """
    root_infos = get_hash_order_info_by_opmagic(root_index, opmagic, func_data)
    if not any(root_infos):
        # Nothing recorded on the root op: fall back to the leaf function.
        return get_hash_order_info(leaf_index, func_data)
    return root_infos
def get_hash_order(func_index, func_data):
    """Get hashOrder info from program.json by func_index (backward compatibility).

    Args:
        func_index: Index of the function in func_data
        func_data: List of functions from program.json

    Returns:
        tuple: (l1_hash, cube_hash, vec_hash); each is the "hashOrder" value of
        the corresponding info dict, or None when that info is not set
    """
    l1_info, cube_info, vec_info = get_hash_order_info(func_index, func_data)
    return tuple(
        info.get("hashOrder") if info else None
        for info in (l1_info, cube_info, vec_info)
    )
def get_subgraph_count(func_index, func_data):
    """Get subgraphCount info from program.json by func_index (backward compatibility).

    Args:
        func_index: Index of the function in func_data
        func_data: List of functions from program.json

    Returns:
        tuple: (l1_count, cube_count, vec_count); each is the "subgraphCount"
        value of the corresponding info dict, or None when that info is not set
    """
    l1_info, cube_info, vec_info = get_hash_order_info(func_index, func_data)
    return tuple(
        info.get("subgraphCount") if info else None
        for info in (l1_info, cube_info, vec_info)
    )
def get_func_magic(func_index, func_data):
    """Get function magic number from program.json by func_index.

    The number is taken from the text after the last "_" of "func_magicname"
    when that text parses as an integer; otherwise the function's "magic"
    entry is used.

    Args:
        func_index: Index of the function in func_data
        func_data: List of functions from program.json

    Returns:
        int: function magic number, or -1 if not found
    """
    if func_index >= len(func_data):
        return -1
    func = func_data[func_index]
    magic_name = func.get("func_magicname", "")
    if magic_name:
        _, separator, suffix = magic_name.rpartition("_")
        # An empty separator means the name contains no "_" at all.
        if separator:
            try:
                return int(suffix)
            except ValueError:
                # Non-numeric suffix: fall through to the "magic" entry.
                pass
    return func.get("magic", -1)
def format_hash_order(hash_order, func_magic):
    """Format hashOrder as 'func{magic}_{order}' or return original value.

    DEPRECATED: This function is no longer needed as hashOrder is now stored
    directly as the full string format or null in the JSON.

    Args:
        hash_order: The hashOrder value (string, int, or None)
        func_magic: The function magic number

    Returns:
        str or None: A non-empty string is returned unchanged; a non-negative
        int combined with a non-negative func_magic is formatted; anything
        else yields None
    """
    if hash_order is None:
        return None
    if isinstance(hash_order, str):
        # Empty strings are treated as "not set".
        return hash_order or None
    can_format = isinstance(hash_order, int) and hash_order >= 0 and func_magic >= 0
    return f"func{func_magic}_{hash_order}" if can_format else None
class DataType(Enum):
    """Named integer codes for data types.

    The values are the integers looked up through data_type_dict (built from
    this enum) when a rawtensor's "datatype" field is rendered as text, so the
    numbering must not be changed.
    """

    INT4 = 0
    INT8 = 1
    INT16 = 2
    INT32 = 3
    INT64 = 4
    FP8 = 5
    FP16 = 6
    FP32 = 7
    BF16 = 8
    HF4 = 9
    HF8 = 10
    UINT8 = 11
    UINT16 = 12
    UINT32 = 13
    UINT64 = 14
    BOOL = 15
    DOUBLE = 16
    FP8E5M2 = 17
    FP8E4M3 = 18
    FP8E8M0 = 19
    # NOTE(review): presumably an end-of-range sentinel rather than a real type - confirm.
    BOTTOM = 20
class MemType(Enum):
    """Named integer codes for memory types.

    The values are the integers looked up through mem_type_dict (built from
    this enum) when an operand's mem_type["tobe"] field is rendered as text,
    so the numbering must not be changed.
    """

    UB = 0
    L1 = 1
    L0A = 2
    L0B = 3
    L0C = 4
    FIX = 5
    FIX_QUANT_PRE = 6
    FIX_RELU_PRE = 7
    FIX_RELU_POST = 8
    FIX_QUANT_POST = 9
    FIX_ELT_ANTIQ = 10
    FIX_MTE2_ANTIQ = 11
    BT = 12
    L2 = 13
    L3 = 14
    DEVICE_DDR = 15
    HOST1 = 16
    FAR1 = 17
    FAR2 = 18
    WORKSPACE = 19
    VECTOR_REG = 20
# Reverse lookup tables: integer code -> enum member name, used by
# get_data_type_str and get_mem_type_str.
data_type_dict = {member.value: member.name for member in DataType}
mem_type_dict = {member.value: member.name for member in MemType}
def get_data_type_str(data_type):
    """Return the DataType member name for an integer code.

    Args:
        data_type: Code to look up in data_type_dict

    Returns:
        str: The member name, or "DataType<code>" when the code is unknown
    """
    known_name = data_type_dict.get(data_type)
    if known_name is not None:
        return known_name
    return "DataType" + str(data_type)
def get_mem_type_str(mem_type):
    """Return the MemType member name for an integer code.

    Args:
        mem_type: Code to look up in mem_type_dict

    Returns:
        str: The member name, or "MemType<code>" when the code is unknown
    """
    known_name = mem_type_dict.get(mem_type)
    if known_name is not None:
        return known_name
    return "MemType" + str(mem_type)
def convert_rawtensor_to_str(rawtensor):
    """Render a rawtensor dict as a single space-separated line.

    Args:
        rawtensor: Dict with "rawmagic", "datatype", "rawshape" and an
            optional "symbol" entry

    Returns:
        str: "raw@<rawmagic> <datatype name> <rawshape>", followed by
        " <symbol>" when a "symbol" entry is present
    """
    pieces = [
        "raw@" + str(rawtensor["rawmagic"]),
        get_data_type_str(rawtensor["datatype"]),
        str(rawtensor["rawshape"]),
    ]
    if "symbol" in rawtensor:
        # Appended as-is: the symbol is expected to already be a string.
        pieces.append(rawtensor["symbol"])
    return " ".join(pieces)
def convert_operand_to_str(operand):
    """Render an operand dict as a single space-separated line.

    Args:
        operand: Dict with "magic", "mem_type" (containing "tobe"), "shape",
            "offset" and "rawtensor" entries

    Returns:
        str: "tensor@<magic> <mem type name> <shape> <offset> <rawtensor text>"
    """
    fields = (
        "tensor@" + str(operand["magic"]),
        get_mem_type_str(operand["mem_type"]["tobe"]),
        str(operand["shape"]),
        str(operand["offset"]),
        convert_rawtensor_to_str(operand["rawtensor"]),
    )
    return " ".join(fields)
def convert_operands_to_str(operands):
    """Render several operands as one comma-separated string without duplicates.

    Operands that render to the same text are listed once. Entries appear in
    the order of their first occurrence in ``operands``, so the result is the
    same on every run (joining a ``set`` made the order depend on string
    hashing, which can change between interpreter runs).

    Args:
        operands: Iterable of operand dicts accepted by convert_operand_to_str

    Returns:
        str: The distinct operand texts joined by ", "; "" for no operands
    """
    rendered = (convert_operand_to_str(operand) for operand in operands)
    # dict.fromkeys de-duplicates while keeping insertion order.
    return ", ".join(dict.fromkeys(rendered))
def get_in_out_operand_str(is_inoperand, func_index, opmagic, func_data):
    """Look up the input or output operands of the operation with ``opmagic``.

    NOTE(review): despite the ``_str`` name, a match returns the list produced
    by convert_operands_data, while the not-found paths return "". Confirm with
    callers whether convert_operands_to_str was intended here.

    Args:
        is_inoperand: Truthy to read 'ioperands', falsy to read 'ooperands'
        func_index: Index of the function in func_data
        opmagic: Value compared against each operation's "opmagic" entry
        func_data: List of functions, each with an "operations" list

    Returns:
        The converted operands of the first matching operation, or "" when
        func_index is out of range or no operation matches
    """
    if func_index >= len(func_data):
        return ""
    operands_key = 'ioperands' if is_inoperand else 'ooperands'
    for call_op in func_data[func_index]["operations"]:
        if call_op["opmagic"] != opmagic:
            continue
        return convert_operands_data(call_op[operands_key])
    return ""
def get_in_out_operands_data(is_inoperand, func_index, opmagic, func_data):
    """Return the converted input or output operands of the operation with ``opmagic``.

    Args:
        is_inoperand: Truthy to read 'ioperands', falsy to read 'ooperands'
        func_index: Index of the function in func_data
        opmagic: Value compared against each operation's "opmagic" entry
        func_data: List of functions, each with an "operations" list

    Returns:
        list: Result of convert_operands_data for the first matching operation,
        or [] when func_index is out of range or no operation matches
    """
    if func_index >= len(func_data):
        return []
    operands_key = 'ioperands' if is_inoperand else 'ooperands'
    for call_op in func_data[func_index]["operations"]:
        if call_op["opmagic"] != opmagic:
            continue
        return convert_operands_data(call_op[operands_key])
    return []
def convert_operand_data(operand):
    """Summarise one operand as ``{magic: {shape, dtype, rawmagic}}``.

    Args:
        operand: Dict with "magic", "shape" and a "rawtensor" dict that has
            "datatype" and "rawmagic" entries

    Returns:
        dict: Single-entry dict keyed by the operand's magic; the value holds
        the operand's shape plus the rawtensor's datatype (as 'dtype') and
        rawmagic
    """
    tensor_info = {
        'shape': operand['shape'],
        'dtype': operand['rawtensor']['datatype'],
        'rawmagic': operand['rawtensor']['rawmagic'],
    }
    return {operand['magic']: tensor_info}
def convert_operands_data(operands):
    """Convert each operand with convert_operand_data, preserving order.

    Args:
        operands: Iterable of operand dicts

    Returns:
        list: One single-entry dict per operand, in input order
    """
    return [convert_operand_data(operand) for operand in operands]
def get_tensors(func_hash, func_hash_data):
    """Index a function's tensors by their magic.

    Args:
        func_hash: Key of the function in func_hash_data
        func_hash_data: Mapping from function hash to a dict with a 'tensors' list

    Returns:
        dict: magic -> tensor dict; empty when func_hash is unknown. When
        several tensors share a magic, the last one wins
    """
    if func_hash not in func_hash_data:
        return dict()
    return {
        tensor['magic']: tensor
        for tensor in func_hash_data[func_hash]['tensors']
    }
def get_rawtensors(func_hash, func_hash_data):
    """Index a function's rawtensors by their rawmagic.

    Args:
        func_hash: Key of the function in func_hash_data
        func_hash_data: Mapping from function hash to a dict with a 'rawtensors' list

    Returns:
        dict: rawmagic -> rawtensor dict; empty when func_hash is unknown. When
        several rawtensors share a rawmagic, the last one wins
    """
    if func_hash not in func_hash_data:
        return dict()
    return {
        rawtensor['rawmagic']: rawtensor
        for rawtensor in func_hash_data[func_hash]['rawtensors']
    }