"""generate json desc for graph kernel ops"""
import json
import json.decoder as jd
import traceback
from mindspore import log as logger
import mindspore._extends.graph_kernel.expanders as expanders
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
def create_expander(expand_info):
    """Look up the expander named in ``expand_info`` and call it.

    The op name is ``str(expand_info['name'])``. The attribute with that name
    on the ``expanders`` module is called with ``expand_info`` as its only
    argument, and the result of that call is returned.

    Raises:
        GraphKernelUnsupportedException: if ``expanders`` has no attribute
            with the requested name.
    """
    op_name = str(expand_info['name'])
    # A private sentinel separates "attribute is missing" from an attribute
    # that exists but holds None or another falsy value.
    not_found = object()
    expander_factory = getattr(expanders, op_name, not_found)
    if expander_factory is not_found:
        raise GraphKernelUnsupportedException("Generator do not support op: {}".format(op_name))
    return expander_factory(expand_info)
def extract_expand_info(kernel_info):
    """Reshape a parsed kernel description into a flatter dict.

    * ``input_desc``: every element of every entry in
      ``kernel_info['input_desc']`` is concatenated into one list; an absent
      or empty value yields ``[]``.
    * ``attr``: the entries of ``kernel_info['attr']`` are turned into a
      ``{entry["name"]: entry["value"]}`` mapping; an absent or empty value
      yields ``{}``.
    * ``name``, ``output_desc`` and ``process`` are copied through unchanged
      and must be present in ``kernel_info``.
    """
    has_inputs = 'input_desc' in kernel_info and kernel_info['input_desc']
    if has_inputs:
        # One level of nesting is removed: a list of groups becomes a list of items.
        flat_inputs = [item for group in kernel_info['input_desc'] for item in group]
    else:
        flat_inputs = []

    has_attrs = 'attr' in kernel_info and kernel_info['attr']
    if has_attrs:
        # Later entries with a repeated name overwrite earlier ones.
        attr_map = {entry["name"]: entry["value"] for entry in kernel_info["attr"]}
    else:
        attr_map = {}

    return {
        "name": kernel_info["name"],
        "input_desc": flat_inputs,
        "output_desc": kernel_info["output_desc"],
        "attr": attr_map,
        "process": kernel_info["process"],
    }
def get_op_expander(json_str: str):
    """Expand the op described by ``json_str`` and return the result as JSON.

    The string is parsed, reshaped by ``extract_expand_info``, and passed to
    ``create_expander``. The resulting expander's ``run()`` output is dumped
    via its ``dump()`` method and serialized with ``json.dumps``.

    Returns:
        The serialized dump on success; ``None`` if a ``JSONDecodeError`` is
        raised (logged at error level together with the traceback); ``""`` if
        a ``GraphKernelUnsupportedException`` is raised (its message is logged
        at info level).
    """
    result = None
    try:
        op_info = extract_expand_info(json.loads(json_str))
        expanded_graph = create_expander(op_info).run()
        result = json.dumps(expanded_graph.dump())
    except jd.JSONDecodeError:
        # result stays None for a decode failure.
        logger.error("Failed to generate graph kernel op")
        logger.error(traceback.format_exc())
    except GraphKernelUnsupportedException as unsupported:
        logger.info(unsupported.message)
        # An empty string, not None, marks the unsupported-op case.
        result = ""
    return result