#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# The code snippet comes from Huawei's open-source Ascend project.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# ----------------------------------------------------------------------------

import json
import os
import stat
import sys


def parse_ini_files(ini_files):
    """
    Parse several ini files and merge them into one ops-info dict.

    Paths that do not exist are reported on stdout and skipped; every
    existing file is parsed, in order, into the same shared dict.

    Parameters:
    ----------------
    ini_files: iterable of ini file paths
    return: dict mapping op name -> parsed op info
    ----------------
    """
    merged_info = {}
    for path in ini_files:
        if os.path.exists(path):
            parse_ini_to_obj(path, merged_info)
        else:
            print("ini file {} not exists!".format(path))
    return merged_info


def parse_ini_to_obj(ini_file, tbe_ops_info):
    """
    Parse one ini file and merge its op sections into tbe_ops_info in place.

    A ``[OpName]`` header starts a new op entry (replacing any entry already
    stored under that name). Each following ``section.key = value`` line is
    stored as ``tbe_ops_info[OpName][section][key] = value``. Blank lines and
    lines starting with ``#`` or ``;`` are ignored. Key lines that appear
    before the first header are parsed but not stored anywhere.

    Parameters:
    ----------------
    ini_file: ini file path (read as UTF-8)
    tbe_ops_info: dict updated in place, keyed by op name
    ----------------
    Raises:
    ----------------
    RuntimeError: the same section.key appears twice within one op section
    ValueError: a line is malformed (unterminated header, no "=", or a key
                that is not exactly of the form "section.key")
    ----------------
    """
    op = {}
    op_name = ""
    with open(ini_file, encoding="utf-8") as ini_fp:
        for line_no, raw_line in enumerate(ini_fp, start=1):
            line = raw_line.strip()
            # Skip blank lines and INI-style comment lines.
            if not line or line.startswith(("#", ";")):
                continue
            if line.startswith("["):
                if not line.endswith("]"):
                    raise ValueError("{}:{}: unterminated section header: {}".format(
                        ini_file, line_no, line))
                op_name = line[1:-1]
                op = {}
                tbe_ops_info[op_name] = op
                continue

            # Split on the first "=" only; values may themselves contain "=".
            key, sep, value = line.partition("=")
            if not sep:
                raise ValueError("{}:{}: expected 'section.key = value', got: {}".format(
                    ini_file, line_no, line))
            key1 = key.strip()
            key2 = value.strip()
            key_parts = key1.split(".")
            if len(key_parts) != 2:
                raise ValueError("{}:{}: key must be of the form 'section.key', got: {}".format(
                    ini_file, line_no, key1))
            key1_0, key1_1 = key_parts
            section = op.setdefault(key1_0, {})
            if key1_1 in section:
                raise RuntimeError("Op:" + op_name + " " + key1_0 + " " + key1_1 + " is repeated!")
            section[key1_1] = key2


def check_op_info(tbe_ops):
    """
    Check that the parsed ops info is complete and well-formed.

    Rules applied to every op:
    - each ``input*`` / ``output*`` section must contain ``name``;
    - each such section must have a ``paramType`` in [dynamic, optional,
      required]; AI CPU ops (``opInfo.engine == "DNN_VM_AICPU"``) may omit
      ``paramType`` entirely;
    - AI CPU ops additionally have their optional ``opInfo`` fields
      ``subTypeOfInferShape``, ``opsFlag``, ``workspaceSize`` and ``kernelSo``
      validated.
    Every problem is printed and checking continues, so one run reports all
    problems.

    Parameters:
    ----------------
    tbe_ops: dict mapping op name -> {section name -> {key -> value}}
    return: True if no problem was found, otherwise False
    ----------------
    """
    print("\n\n==============check valid for ops info start==============")
    required_op_input_info_keys = ["name"]
    required_op_output_info_keys = ["name"]
    param_type_valid_value = ["dynamic", "optional", "required"]
    infer_shape_subtype_valid_value = ["1", "2", "3", "4"]
    ops_flag_valid_value = ["OPS_FLAG_OPEN", "OPS_FLAG_CLOSE"]
    is_valid = True

    def is_aicpu_op(op):
        op_info = op.get("opInfo", {})
        return op_info.get("engine") == "DNN_VM_AICPU"

    def check_param_type(op_key, op_info_key, op_io_info, aicpu_mode):
        # paramType is mandatory except for AI CPU ops, where it may be absent.
        if "paramType" not in op_io_info:
            if not aicpu_mode:
                print("op: " + op_key + " " + op_info_key + " missing: paramType")
                return False
            return True
        if op_io_info["paramType"] not in param_type_valid_value:
            print("op: " + op_key + " " + op_info_key +
                  " paramType not valid, valid key:[dynamic, optional, required]")
            return False
        return True

    def check_io_section(op_key, op_info_key, op_io_info, required_keys, aicpu_mode):
        # Shared validation for one input*/output* section.
        valid = True
        missing_keys = [key for key in required_keys if key not in op_io_info]
        if missing_keys:
            print("op: " + op_key + " " + op_info_key + " missing: " + ",".join(missing_keys))
            valid = False
        if not check_param_type(op_key, op_info_key, op_io_info, aicpu_mode):
            valid = False
        return valid

    def check_aicpu_extend_cfg(op_key, op):
        op_info = op.get("opInfo", {})
        valid = True
        subtype = op_info.get("subTypeOfInferShape")
        if subtype is not None and subtype not in infer_shape_subtype_valid_value:
            print("op: " + op_key +
                  " opInfo.subTypeOfInferShape not valid, valid key:[1, 2, 3, 4]")
            valid = False

        ops_flag = op_info.get("opsFlag")
        if ops_flag is not None and ops_flag not in ops_flag_valid_value:
            print("op: " + op_key +
                  " opInfo.opsFlag not valid, valid key:[OPS_FLAG_OPEN, OPS_FLAG_CLOSE]")
            valid = False

        workspace_size = op_info.get("workspaceSize")
        if workspace_size is not None:
            # isdecimal, not isdigit: isdigit also accepts characters such as
            # superscript digits that int() rejects, which would crash here.
            if not workspace_size.isdecimal():
                print("op: " + op_key + " opInfo.workspaceSize not valid, should be integer in [100, 500]")
                valid = False
            else:
                value = int(workspace_size)
                if value < 100 or value > 500:
                    print("op: " + op_key + " opInfo.workspaceSize out of range, expected [100, 500]")
                    valid = False

        kernel_so = op_info.get("kernelSo")
        if kernel_so is not None and not kernel_so.endswith(".so"):
            print("op: " + op_key + " opInfo.kernelSo not valid, should end with .so")
            valid = False
        return valid

    for op_key, op in tbe_ops.items():
        aicpu_mode = is_aicpu_op(op)

        if aicpu_mode and not check_aicpu_extend_cfg(op_key, op):
            is_valid = False

        for op_info_key, op_io_info in op.items():
            if op_info_key.startswith("input"):
                required_keys = required_op_input_info_keys
            elif op_info_key.startswith("output"):
                required_keys = required_op_output_info_keys
            else:
                # Other sections (e.g. opInfo) are not input/output descriptors.
                continue
            if not check_io_section(op_key, op_info_key, op_io_info, required_keys, aicpu_mode):
                is_valid = False
    print("==============check valid for ops info end================\n\n")
    return is_valid


def write_json_file(tbe_ops_info, json_file_path):
    """
    Save ops info to a json file that only its owner and group can access.

    The file is created through os.open with mode 0o660. Its permissions are
    therefore never wider than owner/group read-write, not even briefly.
    os.chmod is still applied, so a pre-existing file is set to the same mode.

    Parameters:
    ----------------
    tbe_ops_info: ops_info
    json_file_path: json file path (symlinks are resolved first)
    ----------------
    """
    json_file_real_path = os.path.realpath(json_file_path)
    # Only the owner and group have rights
    file_mode = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    fd = os.open(json_file_real_path, flags, file_mode)
    with os.fdopen(fd, "w") as f:
        # os.open's mode only applies on creation (and is reduced by the
        # umask), so chmod enforces the exact mode for existing files too.
        os.chmod(json_file_real_path, file_mode)
        json.dump(tbe_ops_info, f, sort_keys=True, indent=4, separators=(',', ':'))
    print("Compile op info cfg successfully.")


def parse_ini_to_json(ini_file_paths, outfile_path):
    """
    Convert ini op-info files into one validated json file.

    The ini files are parsed and merged, then validated. The json file is
    written only if validation passes; otherwise a failure message is printed
    and nothing is written.

    Parameters:
    ----------------
    ini_file_paths: list of ini file paths
    outfile_path: output json file path
    return: True on success, False if validation failed
    ----------------
    """
    ops_info = parse_ini_files(ini_file_paths)
    if check_op_info(ops_info):
        write_json_file(ops_info, outfile_path)
        return True
    print("Compile op info cfg failed.")
    return False


def parse_cli_args(argv):
    """
    Work out the ini inputs and the json output path from command-line args.

    Arguments ending in ``.ini`` are input files. An argument ending in
    ``.json`` explicitly names the output, and the last such argument wins
    regardless of its position. Without an explicit output, the output is the
    last ini path with its extension replaced by ``.json``. If no ini path is
    given, ``tbe_ops_info.ini`` -> ``tbe_ops_info.json`` is used.

    Parameters:
    ----------------
    argv: full argument vector (argv[0], the program name, is skipped)
    return: (list of ini file paths, output json file path)
    ----------------
    """
    ini_file_path_list = []
    ini_based_output = None
    explicit_output = None
    for arg in argv[1:]:
        if arg.endswith(".ini"):
            ini_file_path_list.append(arg)
            # Replace only the extension, not every ".ini" in the path.
            ini_based_output = os.path.splitext(arg)[0] + ".json"
        elif arg.endswith(".json"):
            explicit_output = arg

    if not ini_file_path_list:
        ini_file_path_list.append("tbe_ops_info.ini")

    if explicit_output is not None:
        output_file_path = explicit_output
    elif ini_based_output is not None:
        output_file_path = ini_based_output
    else:
        output_file_path = "tbe_ops_info.json"
    return ini_file_path_list, output_file_path


if __name__ == '__main__':
    ini_paths, output_path = parse_cli_args(sys.argv)
    if not parse_ini_to_json(ini_paths, output_path):
        sys.exit(1)
    sys.exit(0)