# fd227bb4 - created on 2021-04-26 (commit history)
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate json desc for graph kernel ops"""
import json
import json.decoder as jd
import traceback
from mindspore import log as logger
import mindspore._extends.graph_kernel.expanders as expanders
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException


def create_expander(expand_info):
    """Instantiate the expander registered under the op name in expand_info.

    Args:
        expand_info: mapping whose 'name' entry (converted with str()) selects
            an attribute of the `expanders` module.

    Returns:
        The result of calling that attribute with expand_info.

    Raises:
        GraphKernelUnsupportedException: if `expanders` has no attribute with
            the requested name.
    """
    op_name = str(expand_info['name'])
    supported = hasattr(expanders, op_name)
    if not supported:
        message = "Generator do not support op: {}".format(op_name)
        raise GraphKernelUnsupportedException(message)
    expander_factory = getattr(expanders, op_name)
    return expander_factory(expand_info)


def extract_expand_info(kernel_info):
    """Reshape the parsed kernel json into the layout used to build an expander.

    Args:
        kernel_info: mapping that must contain 'name', 'output_desc' and
            'process'; 'input_desc' and 'attr' are optional and are treated
            as empty when missing or falsy.

    Returns:
        A new dict with keys 'name', 'input_desc', 'output_desc', 'attr' and
        'process'. 'input_desc' is the concatenation of every entry of the
        original 'input_desc'; 'attr' maps each attribute's "name" to its
        "value".
    """
    has_inputs = 'input_desc' in kernel_info and kernel_info['input_desc']
    # Each entry of 'input_desc' is itself iterable; flatten one level.
    flat_inputs = [item for group in kernel_info['input_desc'] for item in group] if has_inputs else []

    attr_map = {}
    has_attrs = 'attr' in kernel_info and kernel_info['attr']
    if has_attrs:
        for entry in kernel_info["attr"]:
            # Read the value before the name so lookups happen in the same
            # order as a plain subscript assignment would perform them.
            value = entry["value"]
            attr_map[entry["name"]] = value

    return {
        "name": kernel_info["name"],
        "input_desc": flat_inputs,
        "output_desc": kernel_info["output_desc"],
        "attr": attr_map,
        "process": kernel_info["process"],
    }


def get_op_expander(json_str: str):
    """Expand the op described by a json string and return the result as json.

    The input is parsed with json.loads, reshaped by extract_expand_info,
    turned into an expander by create_expander, and the expander's run()
    result is dumped and serialized.

    Args:
        json_str: json text describing the kernel to expand.

    Returns:
        The json-serialized dump of the expanded graph on success; None when
        a JSONDecodeError is raised (the failure and traceback are logged at
        error level); an empty string when a GraphKernelUnsupportedException
        is raised (its message is logged at info level).
    """
    try:
        info = extract_expand_info(json.loads(json_str))
        expanded_graph = create_expander(info).run()
        # Serialize the dumped graph description back to json text.
        return json.dumps(expanded_graph.dump())
    except jd.JSONDecodeError:
        logger.error("Failed to generate graph kernel op")
        logger.error(traceback.format_exc())
        return None
    except GraphKernelUnsupportedException as unsupported:
        logger.info(unsupported.message)
        return ""