#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
#-------------------------------------------------------------------
# -----------------------------------------------------------------------------------------------------------
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------

import os
import subprocess
import shutil
import tempfile
import platform
import json
import time

from tbe.common.buildcfg import get_current_build_config
from tbe.tikcpp.compile_op import CommonUtility, AscendCLogLevel
from asc_op_compile_base.common.platform.platform_info import get_soc_spec
import tbe.common.utils.log as logger
# Python3 lib pyautofuse.so
from .pyautofuse import Schedule, CodeGen, ascir
from .ascbc_kernel_compile import ascbc_kernel_compile, camel_to_snake
from .compile_adapter import get_pgo_env_flag, get_pgo_topn
from tbe.tikcpp.get_op_tiling import TilingInfo, _change_param_name_to_name, gen_static_shape_v2
from tbe.common.utils.op_tiling import do_op_tiling
from tbe.common.context import get_context
from tbe.tikcpp import OpInfo

PYF_PATH = os.path.dirname(os.path.realpath(__file__))
ASCEND_PATH = os.path.join(PYF_PATH, "..", "..", "..")
timestamp_list = []
CV_COMMON_MIX_WHITE_LIST = [
    4096, 4097, 4112, 4113, 4160, 4161, 4176, 4177,
    8192, 8208, 8256, 8272,
    1048576, 1048577, 1048592, 1048593, 1048640, 1048641, 1048656, 1048657,
    1114113, 1114129, 1114177, 1114193, 1179649, 1179665, 1179713, 1179729,
    2097152, 2097153, 2097168, 2097169, 2097216, 2097217, 2097232, 2097233,
    2101248, 2101249, 2101264, 2101265, 2101312, 2101313, 2101328, 2101329,
    2162689, 2162705, 2162753, 2162769, 2228225, 2228241, 2228289, 2228305
]

CV_COMMON_BMM_MIX_WHITE_LIST = [
    513,
    8192,
    2097409, 2097425, 2097473, 2097489
]


def generate_cmake_lists(asc_graph_name, kernel_name, host_build_dir, is_last_compile, is_static_shape, is_cube=False):
    source = f"################ {kernel_name}.so ################\n"

    machine = platform.machine()
    source += "cmake_minimum_required(VERSION 3.16.0)\n"
    source += "project(asc_codegen)\n"
    source += f"set(CMAKE_CXX_STANDARD 17)\n"
    source += f"link_directories({ASCEND_PATH}/{machine}-linux/lib64/\n"
    source += f")\n\n"

    source += f"file(GLOB ALL_CPP_SRCS\n"
    source += f"    *tiling_func*.cpp\n"
    source += f"    *infershape*.cpp\n"
    if not is_static_shape and is_cube:
        source += f"    cube_kernel_tiling_wrapper.cpp\n"
    if is_last_compile:
        source += f"    *get_kernel*.cpp\n"
    source += ")\n\n"
    source += f"add_library({kernel_name} SHARED ${{ALL_CPP_SRCS}})\n\n"

    source += f"target_compile_options({kernel_name} PRIVATE\n"
    if is_static_shape:
        source += ("    -O0 -fno-common -Werror -Wextra -Wfloat-equal -fvisibility=default -DLOG_CPP"
                   " -ffile-prefix-map=${CMAKE_CURRENT_SOURCE_DIR}/=\n")
    else:
        source += ("    -O2 -fno-common -Werror -Wextra -Wfloat-equal -fvisibility=default -DLOG_CPP"
                   " -ffile-prefix-map=${CMAKE_CURRENT_SOURCE_DIR}/=\n")

    source += ")\n\n"
    source += f"message(STATUS \"Using environment variable ASCEND_INSTALL_PATH: {ASCEND_PATH}\")\n"
    source += (f"target_link_libraries({kernel_name} PRIVATE c_sec ascendalog platform error_manager"
               f" tiling_api graph_base register)\n")
    source += f"target_include_directories({kernel_name} PRIVATE\n"
    source += f"    {ASCEND_PATH}/include\n"
    source += f"    {ASCEND_PATH}/include/base\n"
    source += f"    {ASCEND_PATH}/include/experiment\n"
    source += f"    {ASCEND_PATH}/{machine}-linux/include\n"
    source += f"    {ASCEND_PATH}/{machine}-linux/pkg_inc/base\n"
    source += f"    {ASCEND_PATH}/{machine}-linux/ascendc/include/highlevel_api/tiling/platform\n"
    source += f"    {ASCEND_PATH}/{machine}-linux/ascendc/include/highlevel_api\n"
    if is_cube:
        source += f"    {ASCEND_PATH}/opp/built-in/op_impl/ai_core/tbe/impl/ascendc/mat_mul_v3\n"
        source += f"    {ASCEND_PATH}/opp/built-in/op_impl/ai_core/tbe/impl/ops_nn/ascendc/mat_mul_v3\n"
        source += f"    {ASCEND_PATH}/opp/built-in/op_impl/ai_core/tbe/impl/ascendc/conv2d_v2\n"
        source += f"    {ASCEND_PATH}/opp/built-in/op_impl/ai_core/tbe/impl/ops_nn/ascendc/conv2d_v2\n"
        source += f"    {ASCEND_PATH}/opp/built-in/op_impl/ai_core/tbe/impl/ops_nn/ascendc/common\n"
    source += "    )\n\n"

    with open(host_build_dir + "/CMakeLists.txt", "w") as cmake_file:
        cmake_file.write(source)


def generate_file(dst_dir, file_name, text):
    os.makedirs(dst_dir, exist_ok=True)
    file_path = os.path.join(dst_dir, file_name)
    with open(file_path, "w") as file:
        file.write(text)


def ascbc_host_compile(graph_name, kernel_name, host_build_dir, is_last_compile, is_static_shape, is_cube=False):
    generate_cmake_lists(graph_name, kernel_name, host_build_dir, is_last_compile, is_static_shape, is_cube)
    ori_directory = os.getcwd()
    # 切换到临时工作目录
    os.chdir(host_build_dir)

    cmake_command = ["cmake", "-S", ".", "-B", "./", "-DCMAKE_C_COMPILER=gcc", "-DCMAKE_CXX_COMPILER=g++"]

    # 运行CMake
    cmake_ret = subprocess.run(cmake_command, capture_output=True, text=True)
    if cmake_ret.returncode != 0:
        os.chdir(ori_directory)
        CommonUtility.print_compile_log("", f"cmake fail: {cmake_ret.stderr}", AscendCLogLevel.LOG_ERROR)
        error_msg = f"execute cmake failed with return code {cmake_ret.returncode}\n"
        error_msg += f"Standard Output:\n{cmake_ret.stdout}\n"
        error_msg += f"Standard Error:\n{cmake_ret.stderr}\n"
        raise Exception(error_msg)

    # 定义 CMake 编译命令
    make_command = ["make", "-C", "./", '-j']
    # 编译生成so
    make_ret = subprocess.run(make_command, capture_output=True, text=True)
    if make_ret.returncode != 0:
        os.chdir(ori_directory)
        CommonUtility.print_compile_log("", f"make fail: {make_ret.stderr}", AscendCLogLevel.LOG_ERROR)
        error_msg = f"execute cmake failed with return code {make_ret.returncode}\n"
        error_msg += f"Standard Output:\n{make_ret.stdout}\n"
        error_msg += f"Standard Error:\n{make_ret.stderr}\n"
        raise Exception(error_msg)
    # 回退到原始工作目录
    os.chdir(ori_directory)


def gen_valid_name(t_name):
    result = []
    last_was_underscore = False

    for c in t_name:
        if c.isalnum():
            result.append(c)
            last_was_underscore = False
        else:
            if not last_was_underscore:
                result.append('_')
                last_was_underscore = True

    ret_name = ''.join(result)

    # 删除开头的下划线
    if ret_name and ret_name[0] == '_':
        ret_name = ret_name[1:]

    # 如果以数字开头,添加前缀
    if ret_name and ret_name[0].isdigit():
        ret_name = "t_" + ret_name

    return ret_name


def is_valid_path(path):
    """
    判断路径是否合法。
    
    :param path: 路径字符串
    :return: 如果路径合法返回 True,否则返回 False
    """
    try:
        # 尝试获取路径的绝对路径
        os.path.abspath(path)
        return True
    except (TypeError, ValueError):
        # 如果路径无效,会抛出异常
        return False


def check_keys_in_dict(d, keys):
    """
    检查字典中是否缺少指定的键。如果所有指定的键都存在,返回True,
    
    :param d: 字典
    :param keys: 需要检查的键列表
    """
    # 检查每个键是否在字典中
    for key in keys:
        if key not in d:
            return False  # 如果任何一个键不存在,直接返回false

    # 如果所有键都存在
    return True


def modify_json_file(json_file, host_so):
    """
    修改json文件中的binFileName/binFileSuffix
    """
    with open(json_file, 'r') as f:
        data = json.load(f)
    kernenl_file_name = data['binFileName']
    lib_file_name = "lib" + kernenl_file_name
    lib_file_suffix = ".so"
    data['binFileName'] = lib_file_name
    data['binFileSuffix'] = lib_file_suffix
    import ctypes
    # 加载共享库
    lib = ctypes.CDLL(host_so)

    # 定义函数参数和返回类型
    lib.GetTilingDataSize.argtypes = []
    lib.GetTilingDataSize.restype = ctypes.c_size_t

    # 调用函数
    data['opParaSize'] = int(lib.GetTilingDataSize())
    CommonUtility.print_compile_log("", f"{kernenl_file_name} tiling size: {data['opParaSize']}",
                                    AscendCLogLevel.LOG_INFO)
    # 自动融合 workspace 默认只有一个 
    data['workspace']['num'] = int(1)
    data['workspace']['size'] = list([-1])
    data['workspace']['type'] = list([0])
    with open(json_file, 'w') as f:
        json.dump(data, f, indent=4)


def timestamp_set(is_start_stamp, graph_name, stage, need_report=False):
    """
    设置时间戳列表
    """
    timestamp = time.time_ns()
    if is_start_stamp:
        timestamp_list.append(timestamp)
    else:
        start = timestamp_list[-1]
        cost_time = (timestamp - start)
        timestamp_list[-1] = cost_time
        ascir.utils.duration_record([stage, graph_name], int(start), int(cost_time))
    if need_report:
        ascir.utils.report_durations()

def ascendc_clean(temp_dir):
    src_directory = os.getcwd()
    keep_dirs = {'host', 'device'}
    # 遍历目录中的所有条目
    for entry in os.listdir(temp_dir):
        entry_path = os.path.join(temp_dir, entry)
        # 检查是否是文件或目录
        if os.path.isfile(entry_path):
            # 如果是文件,直接删除
            os.remove(entry_path)
            CommonUtility.print_compile_log("", f"delete file: {entry_path}", AscendCLogLevel.LOG_INFO)
        else:
            # 如果是目录,检查是否需要保留
            if entry not in keep_dirs:
                shutil.rmtree(entry_path)
                CommonUtility.print_compile_log("", f"delete dir: {entry_path}", AscendCLogLevel.LOG_INFO)
    os.chdir(src_directory)


def static_shape_kernel_proc(kernel_src, temp_dir, kernel_type, use_cv_common=None):
    import re
    base_device_files = os.path.basename(kernel_src)
    if use_cv_common and use_cv_common[0]:
        kernel_file = os.path.join(temp_dir, "device", "cv_common", base_device_files)
    else:
        kernel_file = os.path.join(temp_dir, "device", base_device_files)

    with open(kernel_file, 'r', encoding='utf-8') as file:
        code = file.read()

    # 使用正则表达式进行替换
    code = re.sub(r'GET_TILING_DATA\(t, gm_tiling_data\);', 'const AutofuseTilingData t;', code)
    code = re.sub(r'GM_ADDR workspace, AutofuseTilingData& t', 'GM_ADDR workspace, const AutofuseTilingData& t', code)
    if kernel_type is not None:
        code = re.sub(r'KERNEL_TASK_TYPE_DEFAULT\(.*\)', f'KERNEL_TASK_TYPE_DEFAULT({kernel_type})', code)

    # 将修改后的内容写回文件
    with open(kernel_file, 'w', encoding='utf-8') as file:
        file.write(code)


def static_shape_compile(kernel_name, temp_dir, graph_name, tiling_key_list=None, kernel_type_list=None,
                         use_cv_common=None, is_cube=False):
    if use_cv_common and use_cv_common[0]:
        host_build_dir = os.path.join(temp_dir, "host", "cv_common")
    else:
        host_build_dir = os.path.join(temp_dir, "host")
    ascbc_host_compile(graph_name, kernel_name, host_build_dir, False, True, is_cube)

    kernel_src = graph_name + "_op_kernel.cpp"
    host_so = os.path.join(host_build_dir, f"lib{kernel_name}.so")
    pgo_config_path = os.path.abspath(
        os.path.join(temp_dir, "..", "..", "pgo", f"{graph_name}_config.txt")
    )

    import ctypes
    lib = ctypes.CDLL(host_so)

    CommonUtility.print_compile_log("", f"static shape compile", AscendCLogLevel.LOG_INFO)
    ascendc_clean(temp_dir)
    result = ""
    if hasattr(lib, 'GenConstTilingData'):
        lib.GenConstTilingData.argtypes = []
        lib.GenConstTilingData.restype = ctypes.c_char_p
        result = lib.GenConstTilingData(ctypes.c_char_p(pgo_config_path.encode('utf-8')),
                                        ctypes.c_int(
                                            int(get_soc_spec('vector_core_cnt'))),
                                        ctypes.c_int(int(get_soc_spec('ub_size'))))

    const_tiling_data = result.decode('utf-8')
    if hasattr(lib, 'GetCVUBFusionStageSizeName'):
        stage_size_name = get_cv_ub_fusion_stage_size_name(kernel_name=kernel_name, temp_dir=temp_dir,
                                                           graph_name=graph_name)
        const_tiling_data = const_tiling_data + f"\n#define STAGE_SIZE_NAME " + stage_size_name + "\n"

    if use_cv_common and use_cv_common[0]:
        tiling_data = os.path.join(temp_dir, "device", "cv_common", "autofuse_tiling_data.h")
        tiling_data_bak = os.path.join(temp_dir, "device", "cv_common", "autofuse_tiling_data_bak.h")
    else:
        tiling_data = os.path.join(temp_dir, "device", "autofuse_tiling_data.h")
        tiling_data_bak = os.path.join(temp_dir, "device", "autofuse_tiling_data_bak.h")

    # 备份原tilingdata文件
    shutil.copy(tiling_data, tiling_data_bak)
    with open(tiling_data, "w") as file:
        file.write(const_tiling_data)

    # 静态shape可以在编译期间获取tiling_key
    tiling_key = None
    if hasattr(lib, 'GetTilingKeyForStatic'):
        lib.GetTilingKeyForStatic.argtypes = []
        lib.GetTilingKeyForStatic.restype = ctypes.c_int64  # int64_t对应c_int64
        tiling_key = lib.GetTilingKeyForStatic()
    if (tiling_key_list is not None and tiling_key is not None):
        tiling_key_list[0] = tiling_key
    # 静态shape可以在编译期间获取kernel_type
    kernel_type = None
    if hasattr(lib, 'GetTilingKeyKernelTypeForStatic'):
        lib.GetTilingKeyKernelTypeForStatic.argtypes = []
        lib.GetTilingKeyKernelTypeForStatic.restype = ctypes.c_char_p
        kernel_type_result = lib.GetTilingKeyKernelTypeForStatic()
        kernel_type = kernel_type_result.decode('utf-8')
    if (kernel_type_list is not None and kernel_type is not None):
        kernel_type_list[0] = kernel_type
    # 修改kernel文件
    static_shape_kernel_proc(kernel_src, temp_dir, kernel_type, use_cv_common=use_cv_common)


def dynamic_shape_compile(kernel_name, temp_dir, graph_name, use_cv_common=None, is_cube=False):
    if use_cv_common and use_cv_common[0]:
        host_build_dir = os.path.join(temp_dir, "host", "cv_common")
    else:
        host_build_dir = os.path.join(temp_dir, "host")
    ascbc_host_compile(graph_name, kernel_name, host_build_dir, False, True, is_cube)

    host_so = os.path.join(host_build_dir, f"lib{kernel_name}.so")


    import ctypes
    lib = ctypes.CDLL(host_so)

    CommonUtility.print_compile_log("", f"dynamic shape compile", AscendCLogLevel.LOG_INFO)
    ascendc_clean(temp_dir)

    const_tiling_data = ""
    if hasattr(lib, 'GetCVUBFusionStageSizeName'):
        stage_size_name = get_cv_ub_fusion_stage_size_name(kernel_name=kernel_name, temp_dir=temp_dir,
                                                            graph_name=graph_name)
        const_tiling_data = const_tiling_data + f"\n#define STAGE_SIZE_NAME " + stage_size_name + "\n"

    if use_cv_common and use_cv_common[0]:
        tiling_data = os.path.join(temp_dir, "device", "cv_common", "autofuse_tiling_data.h")
        tiling_data_bak = os.path.join(temp_dir, "device", "cv_common", "autofuse_tiling_data_bak.h")
    else:
        tiling_data = os.path.join(temp_dir, "device", "autofuse_tiling_data.h")
        tiling_data_bak = os.path.join(temp_dir, "device", "autofuse_tiling_data_bak.h")

    # 备份原tilingdata文件
    shutil.copy(tiling_data, tiling_data_bak)
    with open(tiling_data, "w") as file:
        file.write(const_tiling_data)


def static_shape_cv_compile(kernel_name, temp_dir, graph_name):
    host_build_dir = os.path.join(temp_dir, "host")
    ascbc_host_compile(graph_name, kernel_name, host_build_dir, False, True, True)
    host_so = os.path.join(host_build_dir, f"lib{kernel_name}.so")
    pgo_config_path = os.path.abspath(
        os.path.join(temp_dir, "..", "..", "pgo", f"{graph_name}_config.txt")
    )

    import ctypes
    lib = ctypes.CDLL(host_so)
    CommonUtility.print_compile_log("", f"static shape cv compile", AscendCLogLevel.LOG_INFO)
    ascendc_clean(temp_dir)
    result = -1
    if hasattr(lib, 'GenCVFusionTilingKey'):
        result = lib.GenCVFusionTilingKey(ctypes.c_char_p(pgo_config_path.encode('utf-8')),
                                          ctypes.c_int(int(get_soc_spec('vector_core_cnt'))),
                                          ctypes.c_int(int(get_soc_spec('ub_size'))))
    return result


def get_cv_ub_fusion_stage_size_name(kernel_name, temp_dir, graph_name):
    host_build_dir = os.path.join(temp_dir, "host")
    host_so = os.path.join(host_build_dir, f"lib{kernel_name}.so")
    pgo_config_path = os.path.abspath(
        os.path.join(temp_dir, "..", "..", "pgo", f"{graph_name}_config.txt")
    )

    import ctypes
    lib = ctypes.CDLL(host_so)
    CommonUtility.print_compile_log("", f"get cv ub fusion stage size", AscendCLogLevel.LOG_INFO)
    ascendc_clean(temp_dir)
    lib.GetCVUBFusionStageSizeName.argtypes = []
    lib.GetCVUBFusionStageSizeName.restype = ctypes.c_char_p
    result = lib.GetCVUBFusionStageSizeName()
    stage_size_name = result.decode('utf-8')
    return stage_size_name


def static_shape_cv_common_compile(kernel_name, temp_dir, graph_name):
    host_build_dir = os.path.join(temp_dir, "host", "cv_common")
    ascbc_host_compile(graph_name, kernel_name, host_build_dir, False, True, True)
    host_so = os.path.join(host_build_dir, f"lib{kernel_name}.so")
    pgo_config_path = os.path.abspath(
        os.path.join(temp_dir, "..", "..", "pgo", f"{graph_name}_config.txt")
    )

    import ctypes
    lib = ctypes.CDLL(host_so)
    CommonUtility.print_compile_log("", f"static shape cv compile", AscendCLogLevel.LOG_INFO)
    ascendc_clean(temp_dir)
    vec_block_dim = -1
    wss = -1
    if hasattr(lib, 'GenTilingDataValueBlockDimAndWss'):
        # 创建变量接收返回值
        workspace_size = ctypes.c_uint32()
        block_dim = ctypes.c_uint32()
        ret = lib.GenTilingDataValueBlockDimAndWss(ctypes.c_char_p(pgo_config_path.encode('utf-8')),
                                                   ctypes.c_uint32(int(get_soc_spec('vector_core_cnt'))),
                                                   ctypes.c_uint32(int(get_soc_spec('ub_size'))),
                                                   ctypes.byref(workspace_size),
                                                   ctypes.byref(block_dim))

        if ret != -1:
            vec_block_dim = block_dim.value
            wss = workspace_size.value
    return vec_block_dim, wss


def is_static_compile(params, tiling_func_srcs):
    import re
    force_unknown = params.get('force_unknown')
    if force_unknown:
        return False;

    static_shape = False;
    pattern = r'AutofuseIsStaticShape\s*\(\s*\)\s*\{[^{}]*return\s+(true|false)\s*;'
    for tiling_func_src in tiling_func_srcs.values():
        match = re.search(pattern, tiling_func_src, re.IGNORECASE)
        if match:
            return_value = match.group(1).lower()
            static_shape = (return_value == "true")
    return static_shape


def get_pgo_dir(temp_dir):
    # 获取pgo缓存路径
    pgo_dir = os.path.abspath(os.path.join(temp_dir, "..", "..", "pgo"))
    return pgo_dir


# 处理tiling数据字典
def _process_tiling_data_dict(tiling_dict, kernel_build_dir, host_build_dir):
    for key, tiling_data_src in tiling_dict.items():
        if key == "common":
            cv_common_dir = os.path.join(kernel_build_dir, "cv_common")
            generate_file(cv_common_dir, "autofuse_tiling_data.h", tiling_data_src)
            host_cv_common_dir = os.path.join(host_build_dir, "cv_common")
            generate_file(host_cv_common_dir, "autofuse_tiling_data.h", tiling_data_src)
        else:
            generate_file(kernel_build_dir, "autofuse_tiling_data.h", tiling_data_src)
            generate_file(host_build_dir, "autofuse_tiling_data.h", tiling_data_src)


# 处理kernel字典
def _process_kernel_dict(kernel_dict, graph_name, kernel_build_dir):
    ret_op_kernel_src = None
    for key, op_kernel_src in kernel_dict.items():
        if key == "common":
            cv_common_dir = os.path.join(kernel_build_dir, "cv_common")
            generate_file(cv_common_dir, f"{graph_name}_op_kernel.cpp", op_kernel_src)
        else:
            generate_file(kernel_build_dir, f"{graph_name}_op_kernel.cpp", op_kernel_src)
            ret_op_kernel_src = op_kernel_src
    return ret_op_kernel_src


# 处理tiling函数和infershape文件
def _process_tiling_funcs_and_infershape(tiling_func_srcs, graph_name, host_build_dir, infershape_src):
    ret_tiling_func_srcs = None
    common_template_processed = False

    for template_key, template_dict in tiling_func_srcs.items():
        if template_key == "common":
            template_dir = os.path.join(host_build_dir, "cv_common")
            # 生成cv_common目录下的infershape文件
            generate_file(template_dir, graph_name + "_infershape.cpp", infershape_src)
            common_template_processed = True
            ret_tiling_func_srcs = template_dict
        else:
            template_dir = host_build_dir
            ret_tiling_func_srcs = template_dict

        for key, value in template_dict.items():
            if key == "TilingHead":
                generate_file(template_dir, "autofuse_tiling_func_common.h", value)
            elif key == "CubeKernelTilingWrapperHpp":
                generate_file(template_dir, "cube_kernel_tiling_wrapper.h", value)
            elif key == "CubeKernelTilingWrapperCpp":
                generate_file(template_dir, "cube_kernel_tiling_wrapper.cpp", value)
            elif "TilingData" not in key:
                generate_file(template_dir, graph_name + "_tiling_func_" + key + ".cpp", value)

    # 确保始终生成host_build_dir目录下的infershape文件
    if not common_template_processed or ret_tiling_func_srcs is not None:
        generate_file(host_build_dir, graph_name + "_infershape.cpp", infershape_src)

    return ret_tiling_func_srcs


def generate_device_and_host_code(graph_name, temp_dir, params, code_gen):
    # 生成device代码
    timestamp_set(True, graph_name, "GenDevice")
    schedule_results = params['schedule_results']
    vector_core_num = "0" if params['vector_core_num'] is None else str(params['vector_core_num'])
    tiling_dict, kernel_dict = code_gen.device_code_generator(schedule_results)
    kernel_build_dir = os.path.join(temp_dir, "device")
    pgo_dir = get_pgo_dir(temp_dir)
    host_build_dir = os.path.join(temp_dir, "host")

    # 处理tiling数据字典
    _process_tiling_data_dict(tiling_dict, kernel_build_dir, host_build_dir)

    # 处理kernel字典,获取返回的op_kernel_src
    ret_op_kernel_src = _process_kernel_dict(kernel_dict, graph_name, kernel_build_dir)
    timestamp_set(False, graph_name, "GenDevice")

    # 生成host代码
    if not check_keys_in_dict(params, ['output_symbol_shape']):
        CommonUtility.print_compile_log("", f"output_symbol_shape is not exist", AscendCLogLevel.LOG_ERROR)
        raise Exception("An error occurred autofuse compile for check extra_params")
    CommonUtility.print_compile_log("", f"params output_symbol_shape : {params['output_symbol_shape']}",
                                    AscendCLogLevel.LOG_INFO)
    timestamp_set(True, graph_name, "GenHost")
    shape_info = params.get('symbol_source_info')
    output_symbol = json.loads(params.get('output_symbol_shape'))
    tiling_func_srcs, infershape_src = \
        code_gen.host_code_generator(
            schedule_results,
            shape_info,
            output_symbol,
            pgo_dir,
            vector_core_num)

    # 处理tiling函数和infershape文件
    ret_tiling_func_srcs = _process_tiling_funcs_and_infershape(tiling_func_srcs, graph_name, host_build_dir,
                                                                infershape_src)

    timestamp_set(False, graph_name, "GenHost")

    return ret_op_kernel_src, ret_tiling_func_srcs


def pgo_kernel_compile(*args, temp_dir, params, op_kernel_src, code_gen):
    """PGO依赖的动态so编译"""
    schedule_results = params['schedule_results']
    kernel_name = args[-1]
    input_num = schedule_results.get_input_num()
    output_num = schedule_results.get_output_num()
    kernel_build_dir = os.path.join(temp_dir, "device")
    host_build_dir = os.path.join(temp_dir, "host")
    use_list_tensor_desc = op_kernel_src.find('kernel_operator_list_tensor_intf.h') > 0
    enable_parallel_compile = op_kernel_src.rfind('void fake_tiling_ids()') > 0
    pgo_dir = get_pgo_dir(temp_dir)
    graph_name = gen_valid_name(schedule_results.get_name())
    graph_name = camel_to_snake(graph_name)

    kernel_file_path = None
    json_file_path = None
    target_lib_path = None

    try:
        kernel_file, json_file = ascbc_kernel_compile(args, graph_name=graph_name, kernel_name=kernel_name,
            input_num=input_num, output_num=output_num, temp_build_dir=kernel_build_dir,
            impl_mode=params.get('impl_mode'), use_list_tensor_desc=use_list_tensor_desc,
            enable_parallel_compile=enable_parallel_compile)
        kernel_file_path = os.path.join(pgo_dir, os.path.basename(kernel_file))
        json_file_path = os.path.join(pgo_dir, os.path.basename(json_file))
        shutil.copy(kernel_file, kernel_file_path)
        shutil.copy(json_file, json_file_path)
    except Exception as e:
        msg = f"pgo kernel compile for {kernel_name} failed: {str(e)}. skip pgo tune."
        CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_ERROR)
        return False, []
    get_kernel_src = code_gen.get_kernel_and_json_generator(kernel_file, json_file)
    generate_file(host_build_dir, graph_name + "_get_kernel.cpp", get_kernel_src)
    ascbc_host_compile(graph_name, kernel_name, host_build_dir, True, False)
    source_lib_path = os.path.join(host_build_dir, f"lib{kernel_name}.so")
    target_lib_path = os.path.join(pgo_dir, f"lib{graph_name}.so")
    shutil.copy(source_lib_path, target_lib_path)
    return True, [kernel_file_path, json_file_path, target_lib_path]


def pgo_cleanup_kernel_and_json(pgo_temp_files, config_path=None):
    """清理PGO流程中生成的其他中间文件"""
    if config_path is not None:
        if os.path.exists(config_path):
            os.remove(config_path)

    from autofuse.compile_adapter import get_debug_flag
    compile_debug = get_debug_flag()
    if compile_debug:
        return
    if not pgo_temp_files:
        return
    for item in pgo_temp_files:
        if not item:
            continue
        try:
            if os.path.exists(item):
                os.remove(item)
                logger.info("[PGO] cleanup file: %s", item)
        except Exception as e:
            logger.warn("[PGO] cleanup file failed: %s, err: %s", item, str(e))


def check_dir_permissions(path):
    import stat
    # 检查是否为软链接
    if path.is_symlink():
        raise Exception(f"Warning: {path} is a symbolic link. Please configure --replace_kernel to a real directory.")

    # 检查目录是否存在
    if not path.exists():
        raise FileNotFoundError(
            f"Directory {path} does not exist. "
            f"Please configure --replace_kernel to an existing directory via AUTOFUSE_DFX_FLAGS."
        )

    # 检查读写执行权限
    mode = path.stat().st_mode
    if not (mode & stat.S_IRUSR):
        raise PermissionError(
            f"No read permission for {path}. "
            f"Please check the directory permissions configured by --replace_kernel in AUTOFUSE_DFX_FLAGS."
        )
    if not (mode & stat.S_IWUSR):
        raise PermissionError(
            f"No write permission for {path}. "
            f"Please check the directory permissions configured by --replace_kernel in AUTOFUSE_DFX_FLAGS."
        )
    if not (mode & stat.S_IXUSR):
        raise PermissionError(
            f"No execute permission for {path}. "
            f"Please check the directory permissions configured by --replace_kernel in AUTOFUSE_DFX_FLAGS."
        )


def get_replace_kernel_root():
    import re
    from pathlib import Path
    autofuse_dfx_flags_env = os.getenv("AUTOFUSE_DFX_FLAGS", "")
    if not autofuse_dfx_flags_env or "replace_kernel=" not in autofuse_dfx_flags_env:
        return None
    pattern = r'replace_kernel=([^";]+)'
    match = re.search(pattern, autofuse_dfx_flags_env)
    if not match:
        logger.info("match env replace_kernel failed. AUTOFUSE_DFX_FLAGS is %s: ", autofuse_dfx_flags_env)
        return None
    replace_kernel_path = match.group(1)
    kernel_path = Path(replace_kernel_path).expanduser().absolute()
    check_dir_permissions(kernel_path)
    return kernel_path


def replace_device_kernel(replace_root, kernel_build_dir, graph_name):
    src_kernel_path = os.path.join(str(replace_root), f"{graph_name}_op_kernel.cpp")
    if not os.path.exists(src_kernel_path):
        return
    dst_kernel_path = os.path.join(kernel_build_dir, f"{graph_name}_op_kernel.cpp")
    shutil.copy(src_kernel_path, dst_kernel_path)
    logger.info("replace kernel: %s to %s", src_kernel_path, dst_kernel_path)


def find_host_replace_source_dir(replace_root, graph_name):
    from pathlib import Path
    matched_dirs = sorted({candidate.parent for candidate in Path(replace_root).rglob(f"{graph_name}*_tiling_func*.cpp")
                           if candidate.is_file()})
    if len(matched_dirs) > 1:
        raise RuntimeError(f"Multiple host replace source dirs matched for {graph_name}: "
                           f"{[str(path) for path in matched_dirs]}")
    return matched_dirs[0] if matched_dirs else None


def cleanup_host_tiling_files(host_build_dir, graph_name):
    from pathlib import Path
    removed_files = []
    for old_file in Path(host_build_dir).glob(f"{graph_name}*_tiling_func*.cpp"):
        old_file.unlink()
        removed_files.append(old_file.name)
    return removed_files


def replace_host_files(replace_root, host_build_dir, graph_name):
    from pathlib import Path
    source_dir = find_host_replace_source_dir(replace_root, graph_name)
    if source_dir is None:
        return
    logger.info("replace host source dir: %s", source_dir)
    logger.info("replace host target dir: %s", host_build_dir)
    host_build_path = Path(host_build_dir)
    if source_dir.resolve() == host_build_path.resolve():
        logger.info("skip host replacement because source dir equals target dir")
        return
    removed_files = cleanup_host_tiling_files(host_build_dir, graph_name)
    logger.info("cleanup stale host tiling files: %s", removed_files)
    copied_files = []
    for src in source_dir.glob(f"{graph_name}*_tiling_func*.cpp"):
        dst = host_build_path / src.name
        if src.resolve() == dst.resolve():
            continue
        shutil.copy(str(src), str(dst))
        copied_files.append(src.name)
    for header_name in ["autofuse_tiling_data.h", "autofuse_tiling_func_common.h", "autofuse_cube_tiling_data.h"]:
        header_path = source_dir / header_name
        if header_path.exists():
            dst = host_build_path / header_name
            if header_path.resolve() == dst.resolve():
                continue
            shutil.copy(str(header_path), str(dst))
            copied_files.append(header_name)
    logger.info("copied host files: %s", copied_files)


def replace_kernel(kernel_build_dir, graph_name):
    replace_root = get_replace_kernel_root()
    if replace_root is None:
        return
    replace_device_kernel(replace_root, kernel_build_dir, graph_name)


def get_host_build_dir(temp_dir, use_cv_common):
    if use_cv_common and use_cv_common[0]:
        return os.path.join(temp_dir, "host", "cv_common")
    return os.path.join(temp_dir, "host")


def get_device_build_dir(temp_dir, use_cv_common):
    if use_cv_common and use_cv_common[0]:
        return os.path.join(temp_dir, "device", "cv_common")
    return os.path.join(temp_dir, "device")


def replace_host_files_if_needed(replace_root, temp_dir, use_cv_common, graph_name):
    if replace_root is None:
        return
    host_build_dir = get_host_build_dir(temp_dir, use_cv_common)
    replace_host_files(replace_root, host_build_dir, graph_name)


def generate_pgo_code(params, code_gen, pgo_dir, host_build_dir):
    """生成PGO代码"""
    schedule_results = params['schedule_results']
    graph_name = gen_valid_name(schedule_results.get_name())
    graph_name = camel_to_snake(graph_name)

    pgo_src = code_gen.pgo_code_generator(schedule_results, pgo_dir)
    generate_file(host_build_dir, graph_name + "_pgo.cpp", pgo_src)


def pgo_get_mspti_config():
    """获取PGO依赖的mspti相关配置"""
    mspti_dir = os.path.join(ASCEND_PATH, "tools", "mspti")
    mspti_lib64_dir = os.path.join(mspti_dir, "lib64")
    mspti_so = os.path.join(mspti_lib64_dir, "libmspti.so")
    libprof_common_so = os.path.join(mspti_lib64_dir, "libprof_common.so")

    if not os.path.exists(mspti_so):
        return None

    has_prof_common = os.path.exists(libprof_common_so)
    preload_so_paths = [libprof_common_so, mspti_so] if has_prof_common else [mspti_so]

    link_flags = [f"-L{mspti_lib64_dir}", "-lmspti"]
    if has_prof_common:
        link_flags.append("-lprof_common")

    return mspti_dir, preload_so_paths, link_flags


def build_pgo_compile_command(source_file, output_file, mspti_dir, mspti_link_flags):
    """构建PGO编译命令"""
    machine = platform.machine()
    base_cmd = ["g++", "-std=c++17", "-O2", "-fPIC"]
    includes = [
        f"-I{ASCEND_PATH}/include/",
        f"-I{ASCEND_PATH}/include/base",
        f"-I{ASCEND_PATH}/include/experiment/runtime",
        f"-I{ASCEND_PATH}/include/experiment/msprof",
        f"-I{ASCEND_PATH}/{machine}-linux/include/toolchain",
        f"-I{ASCEND_PATH}/{machine}-linux/pkg_inc/base",
        f"-I{mspti_dir}/include"
    ]
    libs = [
        f"-L{ASCEND_PATH}/lib64", "-lascendcl", "-lruntime", "-lunified_dlog",
        "-lc_sec",
        *mspti_link_flags,
        "-ldl"
    ]
    cmd = base_cmd + [source_file, "-o", output_file] + includes + libs
    return cmd


def pgo_compile(graph_name, host_build_dir, mspti_dir, mspti_link_flags):
    """编译PGO代码"""
    pgo_source_file = os.path.join(host_build_dir, f"{graph_name}_pgo.cpp")
    output_file = os.path.join(host_build_dir, "pgo")

    cmd = build_pgo_compile_command(pgo_source_file, output_file, mspti_dir, mspti_link_flags)
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        msg = "cmd = " + ' '.join(cmd) + "\n"
        msg += f"stdout: {result.stdout}\n"
        msg += f"stderr: {result.stderr}\n"
        CommonUtility.print_compile_log("", f"pgo compile fail: {msg}", AscendCLogLevel.LOG_ERROR)
        return None
    return output_file


def pgo_program_exec(temp_dir, exec_type=0, *, pgo_exec_params, mspti_so_list):
    """执行PGO, 0表示首次动态调优,1表示二次调优"""
    host_build_dir = os.path.join(temp_dir, "host")
    pgo_file = os.path.join(host_build_dir, "pgo")
    device_id = pgo_exec_params.get("device_id")
    aiv_num = pgo_exec_params.get("aiv_num")
    ub_size = pgo_exec_params.get("ub_size")
    kernel_name = pgo_exec_params.get("kernel_name")
    try:
        graph_so = f"{ASCEND_PATH}/lib64/libgraph_af.so"
        ld_preload_items = [graph_so] + mspti_so_list
        existing_ld_preload = os.environ.get("LD_PRELOAD")
        if existing_ld_preload:
            ld_preload_items.append(existing_ld_preload)
        ld_preload = ":".join([p for p in ld_preload_items if p])
        result = subprocess.run(
            [
                pgo_file,
                str(exec_type),
                str(device_id),
                str(aiv_num),
                str(ub_size),
                str(kernel_name),
            ],
            text=True,
            env={
                **dict(os.environ),
                "LD_PRELOAD": ld_preload,
            },
            timeout=1800 # 暂定,设置超时时间为1800秒,防止大解集case提前退出
        )
        if result.returncode != 0:
            msg = f"pgo exec fail, return code: {result.returncode}\n"
            # 执行失败时,回退原始解,不打断模型正常执行流程
            CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_ERROR)
            return False
        return True
    except Exception as e:
        CommonUtility.print_compile_log("", f"pgo exec fail: {str(e)}", AscendCLogLevel.LOG_ERROR)
        return False


def get_pgo_exec_params(*args, params):
    """提取PGO执行所需的基础参数。

    Returns:
        dict: {"device_id": str, "aiv_num": str, "ub_size": str, "kernel_name": str}
    """
    device_id = params.get('device_id')
    device_id = "0" if device_id is None or not str(device_id).isdigit() else str(device_id)

    soc_vector_core_cnt = int(get_soc_spec('vector_core_cnt'))
    param_vector_core_num = soc_vector_core_cnt if params['vector_core_num'] is None else int(params['vector_core_num'])
    aiv_num = str(min(param_vector_core_num, soc_vector_core_cnt))
    ub_size = str(get_soc_spec('ub_size'))
    kernel_name = args[-1]

    return {
        "device_id": device_id,
        "aiv_num": aiv_num,
        "ub_size": ub_size,
        "kernel_name": kernel_name,
    }


def pgo_second_optimization(*args, temp_dir, params, op_kernel_src, mspti_so_list):
    """静态tiling的二次调优"""
    from autofuse.compile_adapter import pgo_get_top_result, pgo_write_config, pgo_generate_config

    pgo_exec_params = get_pgo_exec_params(*args, params=params)
    kernel_name = pgo_exec_params["kernel_name"]
    schedule_results = params['schedule_results']
    graph_name = camel_to_snake(gen_valid_name(schedule_results.get_name()))
    input_num = schedule_results.get_input_num()
    output_num = schedule_results.get_output_num()
    use_list_tensor_desc = op_kernel_src.find('kernel_operator_list_tensor_intf.h') > 0
    enable_parallel_compile = op_kernel_src.rfind('void fake_tiling_ids()') > 0
    kernel_build_dir = os.path.join(temp_dir, "device")
    pgo_dir = get_pgo_dir(temp_dir)
    config_path = os.path.join(pgo_dir, f"{graph_name}_config.txt")
    search_path = os.path.join(pgo_dir, f"{graph_name}_search.txt")
    top_n = get_pgo_topn()

    top_lines, origin_line, top_n = pgo_get_top_result(search_path, top_n)
    if top_lines is None or origin_line is None or top_n is None:
        return False

    def comile_and_exec(config_line):
        tiling_key_list, kernel_type_list = [-1], ["KERNEL_TYPE_AIV_ONLY"]
        pgo_write_config(config_path, config_line)
        static_shape_compile(kernel_name=kernel_name, temp_dir=temp_dir, graph_name=graph_name,
                             tiling_key_list=tiling_key_list, kernel_type_list=kernel_type_list)
        try:
            kernel_file, json_file = ascbc_kernel_compile(args, graph_name=graph_name, kernel_name=kernel_name,
                input_num=input_num, output_num=output_num, temp_build_dir=kernel_build_dir,
                impl_mode=params.get('impl_mode'), use_list_tensor_desc=use_list_tensor_desc,
                enable_parallel_compile=enable_parallel_compile, tiling_key=tiling_key_list[0],
                kernel_type=kernel_type_list[0], is_cube=False)
            shutil.copy(kernel_file, os.path.join(pgo_dir, os.path.basename(kernel_file)))
            shutil.copy(json_file, os.path.join(pgo_dir, os.path.basename(json_file)))
        except Exception as e:
            logger.error("[pgo] kernel compile for %s failed: %s. revert pgo tune.", kernel_name, str(e))
            return False

        result = pgo_program_exec(temp_dir=temp_dir, exec_type=1, mspti_so_list=mspti_so_list,
                      pgo_exec_params=pgo_exec_params)
        return result

    for line in top_lines:
        result = comile_and_exec(line)
        if result is False:
            return False

    result = comile_and_exec(origin_line)
    if result is False:
        return False

    pgo_generate_config(search_path, config_path, top_n)
    return True


def asc_pgo_exec(*args, temp_dir, params, op_kernel_src, code_gen):
    """执行PGO"""
    schedule_results = params['schedule_results']
    graph_name = camel_to_snake(gen_valid_name(schedule_results.get_name()))
    logger.info(f"[PGO] Start PGO tuning for graph: {graph_name}")
    mspti_cfg = pgo_get_mspti_config()
    if mspti_cfg is None:
        logger.warn("[PGO] libmspti.so not installed, skip pgo tuning")
        return
    mspti_dir, mspti_so_list, mspti_link_flags = mspti_cfg

    pgo_dir = get_pgo_dir(temp_dir)
    os.makedirs(pgo_dir, exist_ok=True)
    host_build_dir = os.path.join(temp_dir, "host")

    pgo_temp_files = []

    config_path = os.path.join(pgo_dir, f"{graph_name}_config.txt")
    if os.path.exists(config_path):
        logger.info(f"[PGO] {config_path} exist, skip pgo tuning")
        return

    timestamp_set(True, graph_name, "CompileKernelForPGO")
    result, pgo_temp_files = pgo_kernel_compile(*args, temp_dir=temp_dir, params=params, op_kernel_src=op_kernel_src,
                                                code_gen=code_gen)
    timestamp_set(False, graph_name, "CompileKernelForPGO")
    if result is False:
        pgo_cleanup_kernel_and_json(pgo_temp_files)
        return

    timestamp_set(True, graph_name, "GenerateForPGO")
    generate_pgo_code(params=params, code_gen=code_gen, pgo_dir=pgo_dir, host_build_dir=host_build_dir)
    timestamp_set(False, graph_name, "GenerateForPGO")

    timestamp_set(True, graph_name, "CompilePGO")
    output = pgo_compile(graph_name=graph_name, host_build_dir=host_build_dir,
                         mspti_dir=mspti_dir, mspti_link_flags=mspti_link_flags)
    timestamp_set(False, graph_name, "CompilePGO")
    if output is None:
        pgo_cleanup_kernel_and_json(pgo_temp_files)
        return

    timestamp_set(True, graph_name, "RunForPGO")
    pgo_exec_params = get_pgo_exec_params(*args, params=params)
    result = pgo_program_exec(temp_dir=temp_dir, exec_type=0, mspti_so_list=mspti_so_list,
                              pgo_exec_params=pgo_exec_params)
    timestamp_set(False, graph_name, "RunForPGO")
    if result is False:
        pgo_cleanup_kernel_and_json(pgo_temp_files, config_path=config_path)
        return

    timestamp_set(True, graph_name, "SecondOptimizeForPGO")
    result = pgo_second_optimization(*args, temp_dir=temp_dir, params=params, op_kernel_src=op_kernel_src,
                                     mspti_so_list=mspti_so_list)
    timestamp_set(False, graph_name, "SecondOptimizeForPGO")
    if result is False:
        pgo_cleanup_kernel_and_json(pgo_temp_files, config_path=config_path)
        return

    pgo_cleanup_kernel_and_json(pgo_temp_files)


def _build_mm_args(args_list, input_num, mm_attr1, mm_attr2):
    _inputs_ = []
    _origin_inputs_ = []
    _origin_outputs_ = []
    _m = 0
    _n = 0
    logger.info("CV fusion op, matmul input(%s)", input_num)
    # 遍历args中的每个元素
    for i, arg in enumerate(args_list[:input_num]):
        if i >= input_num:
            continue
        msg = "Processing input " + str(i) + ":" + str(arg)
        logger.info("CV fusion op, matmul input info: %s", msg)
        CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO)
        _origin_inputs_.append(arg)
        if arg is not None:
            if isinstance(arg, (list, tuple)) and len(arg) != 0:
                _inputs_.append(arg[0])
            else:
                _inputs_.append(arg)
        else:
            _inputs_.append(arg)
        _inputs_[-1]["param_name"] = "input" + str(i)
        shape = _inputs_[-1]["shape"]
        _inputs_[-1]["shape"] = shape
        _inputs_[-1]["ori_shape"] = shape
        if i == 0:
            write_shape = list(shape)
            if mm_attr1['value']: # a矩阵是否有transpose
                _m = shape[-1]
            else:
                _m = shape[-2]
        if i == 1:
            if mm_attr2['value']: # b矩阵是否有transpose
                _n = shape[-2]
            else:
                _n = shape[-1]

    msg = "Processing output " + ":" + str(args_list[-2])
    logger.info("CV fusion op, matmul output info: %s", msg)
    CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO)
    _origin_outputs_.append(args_list[-2])
    _origin_outputs_[-1]["param_name"] = "output0"
    write_shape[-1] = _n
    write_shape[-2] = _m
    _origin_outputs_[-1]["shape"] = tuple(write_shape)
    _origin_outputs_[-1]["dtype"] = _inputs_[-1]["dtype"]
    _origin_outputs_[-1]["ori_shape"] = tuple(write_shape)
    logger.info("CV fusion op, m=%s, n=%s new matmul output info: %s", _m, _n, str(_origin_outputs_[-1]))
    return _origin_inputs_, _origin_outputs_, _inputs_


def build_conv_args(args_list, input_num, data_format):
    _inputs_ = []
    _origin_inputs_ = []
    _origin_outputs_ = []
    logger.info("CV fusion op, conv input(%s)", input_num)

    format_shape_transfer_map = {
        ("HWCN", "NCHW"): [3, 2, 0, 1],
        ("NCHW", "HWCN"): [2, 3, 1, 0],
        ("NHWC", "NCHW"): [0, 3, 1, 2],
        ("NCHW", "NHWC"): [0, 2, 3, 1],
    }

    def transfer_shape_between_formats(shape, old_format, new_format):
        """格式转换时对应的 shape 维度重排,参考 TransferShapeBetweenHwcnNchw"""
        if len(shape) != 4:
            return shape
        if old_format == new_format:
            return shape
        transfer_key = (old_format.upper(), new_format.upper())
        if transfer_key in format_shape_transfer_map:
            indices = format_shape_transfer_map[transfer_key]
            return tuple(shape[i] for i in indices)
        return shape

    # 遍历args中的每个元素
    for i, input_arg in enumerate(args_list[:input_num]):
        if i >= input_num:
            continue
        op_msg = "Processing input " + str(i) + ":" + str(input_arg)
        logger.info("CV fusion op, conv input info: %s", op_msg)
        CommonUtility.print_compile_log("", op_msg, AscendCLogLevel.LOG_INFO)
        _origin_inputs_.append(input_arg)
        if input_arg is not None:
            if isinstance(input_arg, (list, tuple)) and len(input_arg) != 0:
                _inputs_.append(input_arg[0])
            else:
                _inputs_.append(input_arg)
        else:
            _inputs_.append(input_arg)
        _inputs_[-1]["param_name"] = "input" + str(i)
        shape = _inputs_[-1]["shape"]
        src_format = _inputs_[-1]["format"]
        # 如果输入格式与目标格式不匹配,需要进行转换
        if src_format.upper() != data_format.upper():
            # 根据 format 转换规则重排 shape 维度
            new_shape = transfer_shape_between_formats(shape, src_format, data_format)
            _inputs_[-1]["shape"] = new_shape
            _inputs_[-1]["ori_shape"] = new_shape
        else:
            _inputs_[-1]["shape"] = shape
            _inputs_[-1]["ori_shape"] = shape
        _inputs_[-1]["ori_format"] = data_format
        _inputs_[-1]["format"] = data_format

    op_msg = "Processing output " + ":" + str(args_list[-2])
    logger.info("CV fusion op, conv output info: %s", op_msg)
    CommonUtility.print_compile_log("", op_msg, AscendCLogLevel.LOG_INFO)
    _origin_outputs_.append(args_list[-2])
    _origin_outputs_[-1]["param_name"] = "output0"
    _origin_outputs_[-1]["shape"] = args_list[-2]["shape"]
    _origin_outputs_[-1]["dtype"] = _inputs_[-1]["dtype"]
    _origin_outputs_[-1]["ori_shape"] = args_list[-2]["shape"]
    _origin_outputs_[-1]["ori_format"] = data_format
    _origin_outputs_[-1]["format"] = data_format
    logger.info("CV fusion op, new conv output info: %s", str(_origin_outputs_[-1]))
    return _origin_inputs_, _origin_outputs_, _inputs_


def _process_tiling_info(is_batch, compile_info, origin_inputs, origin_outputs, attrs, tiling_key_list):
    tiling_info = TilingInfo()
    op_type = "BatchMatMulV3" if is_batch else "MatMulV3"
    tiling_data_type = f"{op_type}TilingData"

    run_info = do_op_tiling(op_type, compile_info, origin_inputs, origin_outputs, None, None, attrs)
    tiling_info.tiling_data = run_info["tiling_data"]
    tiling_info.tiling_key = run_info['tiling_key']
    tiling_key_list[0] = tiling_info.tiling_key
    cube_block_dim = run_info['block_dim']
    tiling_info.file_content = gen_static_shape_v2(op_type, tiling_data_type, run_info["tiling_data"])

    return tiling_info, cube_block_dim


def is_matmul_relu_fixpip(tiling_info, cube_info):
    has_relu = cube_info[4]
    tiling_key_transpose_and_full_load_mask = 0xF10F0  # transpose/全载模式/基础或者切k模板
    cube_tiling_key_fixpip = tiling_info.tiling_key & ~tiling_key_transpose_and_full_load_mask
    return has_relu and (cube_tiling_key_fixpip == 1)


def template_decider(kernel_name, temp_dir, graph_name, tiling_info, cube_info, is_conv=False):
    _, is_batch, cube_block_dim, use_cv_common, has_relu = cube_info[:5]
    tiling_key = static_shape_cv_compile(kernel_name=kernel_name, temp_dir=temp_dir,
                                         graph_name=graph_name)
    logger.info("CV fusion op, get vector tilingkey(%s)", tiling_key)
    use_cv_common = use_cv_common or [False]
    tiling_key_transpose_mask = 0xF0
    cube_tiling_key_ub = tiling_info.tiling_key & ~tiling_key_transpose_mask
    if (is_matmul_relu_fixpip(tiling_info, cube_info) and not is_conv):
        logger.info("CV fusion op, entering fixpip fusion mode.")
        tiling_info.file_content += "\n#define CV_UB_NO_DB 1\n" # 防止编译问题
    elif ((tiling_key == -1 or cube_tiling_key_ub != 1) and not is_conv):
        is_in_mix_white_list = (is_batch and tiling_info.tiling_key in CV_COMMON_BMM_MIX_WHITE_LIST) or (
                not is_batch and tiling_info.tiling_key in CV_COMMON_MIX_WHITE_LIST)
        if is_in_mix_white_list:
            tiling_info.file_content += "\n#define CV_SAFETY_FUSION_MIX_MODE 1\n"
        logger.info("CV fusion op, entering safety fusion mode. vector_tiling_key=%s, is_batch=%s, cube_tiling_key=%s",
                    tiling_key, is_batch, tiling_info.tiling_key)
        tiling_info.file_content += "\n#define CV_SAFETY_FUSION 1\n"
        use_cv_common[0] = True
        vec_block_dim, wss = static_shape_cv_common_compile(kernel_name=kernel_name, temp_dir=temp_dir,
                                                            graph_name=graph_name)
        logger.info("CV fusion op, CV_AIC_NUM=[%s] CV_AIV_NUM=[%s] CV_VEC_WSS=[%s]", str(cube_block_dim),
                    str(vec_block_dim), str(wss))
        for name, value in [("CV_AIC_NUM", cube_block_dim), ("CV_AIV_NUM", vec_block_dim), ("CV_VEC_WSS", wss)]:
            if value >= 0:
                tiling_info.file_content += f"\n#define {name} {value}\n"
    else:
        logger.info("CV fusion op, entering UB fusion mode, cube_tiling_key=%s, vector tilingkey=%s.",
                    tiling_info.tiling_key, tiling_key)
        tiling_info.file_content += "\n#define CV_UB_FUSION 1\n"
        if (tiling_key == 0):
            logger.info("CV fusion op, entering UB fusion mode with no db.")
            tiling_info.file_content += "\n#define CV_UB_NO_DB 1\n"
        else:  # tiling_key 为1表示UB复用循环模板(非全载模板)
            logger.info("CV fusion op, entering UB fusion mode with db.")
            tiling_info.file_content += "\n#define CV_UB_DB 1\n"


def map_dtype_to_string(dtype):
    dtype_map = {
        "bfloat16": "bfloat16_t",
        "float16": "half",
        "float32": "float"
    }

    return dtype_map.get(dtype.lower(), dtype)


def create_matmul_tiling_data(kernel_name, temp_dir, graph_name, tiling_info, cube_info):
    cube_output_type_size, is_batch, _, _, has_relu, origin_inputs, origin_outputs = cube_info[:7]

    # 根据is_batch设置结构体名称和数据访问路径
    struct_name = "BatchMatMulV3BasicTilingData" if is_batch else "MatMulV3BasicTilingData"
    data_prefix = "tmpTilingData->matMulTilingData" if is_batch else "(*tmpTilingData)"

    # 生成host端tiling数据
    host_tiling_data = f"""
#include "arch35/mat_mul_tiling_data.h"
const int32_t cube_output_type_size = {cube_output_type_size};
GET_TILING_DATA_PTR_WITH_STRUCT({struct_name}, tmpTilingData, tmpTilingGM);
"""

    # 构建class_body,使用data_prefix处理不同的数据访问路径
    class_body = f"const int32_t ub_align_value = 32 / cube_output_type_size;\n"
    class_body += f"const int32_t basen_align = ({data_prefix}.baseN + ub_align_value - 1) " \
                  f"/ ub_align_value * ub_align_value;\n"
    class_body += f"const int32_t basen_basem_align = ({data_prefix}.baseM * basen_align) / 2 + basen_align;\n"
    if (is_matmul_relu_fixpip(tiling_info, cube_info)):
        class_body += f"#define CV_RELU_FIXPIP_MODE 1\n"

    # 写入host端文件
    host_tiling_content = tiling_info.file_content + host_tiling_data + class_body
    generate_file(os.path.join(temp_dir, "host"), "autofuse_cube_tiling_data.h", host_tiling_content)
    generate_file(os.path.join(temp_dir, "host", "cv_common"), "autofuse_cube_tiling_data.h", host_tiling_content)

    # 生成决定走哪一个vector模板的宏
    template_decider(kernel_name, temp_dir, graph_name, tiling_info, cube_info)

    # 写入device端文件
    device_tiling_data = f"""\n#include "arch35/mat_mul_tiling_data.h"
#define IS_ENABLE_RELU {str(has_relu).lower()}
#define OP_TYPE_RELU_VALUE {5 if has_relu else 0}UL // 自动融合新增
#define DTYPE_X1 {map_dtype_to_string(origin_inputs[-1]["dtype"])}
#define DTYPE_X2 {map_dtype_to_string(origin_inputs[-1]["dtype"])}
#define DTYPE_Y {map_dtype_to_string(origin_outputs[-1]["dtype"])}
#define DTYPE_BIAS {map_dtype_to_string(origin_outputs[-1]["dtype"])}
"""

    tiling_info_undef = f"""
#ifndef __UNDEF_MATMULV3_HEADER__
#define __UNDEF_MATMULV3_HEADER__
#undef GET_TILING_DATA_PTR_WITH_STRUCT
#undef COPY_TILING_WITH_STRUCT
#undef COPY_TILING_WITH_ARRAY
#undef GET_TILING_DATA
#undef GET_TILING_DATA_MEMBER
#undef GET_TILING_DATA_WITH_STRUCT
#undef __tiling_data_ptr__
#endif
"""

    tiling_info.file_content += device_tiling_data
    device_tiling_content = tiling_info_undef
    device_tiling_content += tiling_info.file_content
    generate_file(os.path.join(temp_dir, "device"), "autofuse_cube_tiling_data.h", device_tiling_content)
    generate_file(os.path.join(temp_dir, "device", "cv_common"), "autofuse_cube_tiling_data.h",
                  device_tiling_content)


def create_matmul_dynamic_tiling_data(temp_dir, cube_info):
    _, is_batch, _, has_relu, origin_inputs, origin_outputs = cube_info[:7]
    # 生成host端tiling数据
    struct_name = "BatchMatMulV3BasicTilingData" if is_batch else "MatMulV3BasicTilingData"
    head_tiling_data = f"""
#ifndef DYNAMIC_MM_TILING_DATA
#define DYNAMIC_MM_TILING_DATA
#include "arch35/mat_mul_tiling_data.h"
#include "autofuse_tiling_data.h"
struct CVTilingData {{
    uint32_t fusion_mode; // 0:ub; 1:safety
    uint32_t ub_mode; // 0:no db; 1:db
    uint32_t cv_aic_num;
    uint32_t cv_aiv_num;
    uint32_t cv_vec_wss;
    uint32_t mix_mode;
}};
struct CVAutofuseTilingData {{
    AutofuseTilingData tiling_data;
    CVTilingData  cv_tiling_data;
    {struct_name} matmul_tiling_data;
}};
"""

    # 写入host端文件
    host_tiling_content = head_tiling_data
    host_tiling_content += "#endif"
    generate_file(os.path.join(temp_dir, "host"), "autofuse_cube_tiling_data.h", host_tiling_content)
    generate_file(os.path.join(temp_dir, "host", "cv_common"), "autofuse_cube_tiling_data.h", host_tiling_content)

    # 写入device端文件
    device_tiling_data = head_tiling_data
    device_tiling_data += f"""
#define CV_UB_FUSION 1
#define CV_UB_NO_DB 1
#define IS_ENABLE_RELU {str(has_relu).lower()}
#define OP_TYPE_RELU_VALUE {5 if has_relu else 0}UL // 自动融合新增
#define DTYPE_X1 {map_dtype_to_string(origin_inputs[-1]["dtype"])}
#define DTYPE_X2 {map_dtype_to_string(origin_inputs[-1]["dtype"])}
#define DTYPE_Y {map_dtype_to_string(origin_outputs[-1]["dtype"])}
#define DTYPE_BIAS {map_dtype_to_string(origin_outputs[-1]["dtype"])}
#endif
"""

    device_tiling_content = device_tiling_data
    generate_file(os.path.join(temp_dir, "device"), "autofuse_cube_tiling_data.h", device_tiling_content)
    generate_file(os.path.join(temp_dir, "device", "cv_common"), "autofuse_cube_tiling_data.h",
                  device_tiling_content)


def create_conv_tiling_data(kernel_name, temp_dir, graph_name, tiling_info, cube_info, cube_attributes):
    cube_output_type_size, _, _, is_bias, is_offset_w, origin_inputs, origin_outputs = cube_info[:7]
    struct_name = "Conv2DTilingData"
    data_prefix = "(*tmpTilingData).conv2dApiTiling"

    host_tiling_data = f"""
// conv2d
#include "arch35/conv2d_v2_tiling_def.h"
#include <algorithm>
const int32_t cube_output_type_size = {cube_output_type_size};
GET_TILING_DATA_PTR_WITH_STRUCT({struct_name}, tmpTilingData, tmpTilingGM);
"""

    class_body = f"const int32_t ub_align_value = 32 / cube_output_type_size;\n"
    class_body += f"const int32_t basen_align = ({data_prefix}.hoL0 + ub_align_value - 1) " \
                  f"/ ub_align_value * ub_align_value;\n"
    npu_arch = cube_attributes.get("npu_arch", "5102")
    vec_num = 1
    if npu_arch != "5102":
        vec_num = 2
    class_body += f"const int32_t basen_basem_align = std::max({data_prefix}.nL0 / {vec_num}, 16u) * basen_align;\n"

    host_tiling_content = tiling_info.file_content + host_tiling_data + class_body
    generate_file(os.path.join(temp_dir, "host"), "autofuse_cube_tiling_data.h", host_tiling_content)
    generate_file(os.path.join(temp_dir, "host", "cv_common"), "autofuse_cube_tiling_data.h", host_tiling_content)

    # 生成决定走哪一个vector模板的宏
    template_decider(kernel_name, temp_dir, graph_name, tiling_info, cube_info, True)

    device_tiling_data = f"""\n#include "arch35/conv2d_v2_tiling_def.h"
#define IS_ENABLE_BIAS {str(is_bias).lower()}
#define IS_ENABLE_OFFSET_W {str(is_offset_w).lower()}
#define DTYPE_X1 {map_dtype_to_string(origin_inputs[0]["dtype"])}
#define DTYPE_X2 {map_dtype_to_string(origin_inputs[1]["dtype"])}
#define DTYPE_Y {map_dtype_to_string(origin_outputs[0]["dtype"])}
#define DTYPE_BIAS {map_dtype_to_string(origin_outputs[0]["dtype"])}
"""

    tiling_data_undef = f"""
#ifndef __UNDEF_CONV2D_HEADER__
#define __UNDEF_CONV2D_HEADER__
#undef GET_TILING_DATA_PTR_WITH_STRUCT
#undef COPY_TILING_WITH_STRUCT
#undef COPY_TILING_WITH_ARRAY
#undef GET_TILING_DATA
#undef GET_TILING_DATA_MEMBER
#undef GET_TILING_DATA_WITH_STRUCT
#undef __tiling_data_ptr__
#endif
"""

    tiling_info.file_content += device_tiling_data
    device_tiling_content = tiling_data_undef
    device_tiling_content += tiling_info.file_content
    generate_file(os.path.join(temp_dir, "device"), "autofuse_cube_tiling_data.h", device_tiling_content)
    generate_file(os.path.join(temp_dir, "device", "cv_common"), "autofuse_cube_tiling_data.h",
                  device_tiling_content)


def ascbc_conv_kernel_tiling_pro(
    *args,
    temp_dir,
    graph_name,
    kernel_name,
    input_num,
    output_num,
    use_list_tensor_desc,
    cube_attrs,
    tiling_key_list,
    use_cv_common=None
):
    graph_name = camel_to_snake(graph_name)
    args_list = args[0]
    if use_list_tensor_desc:
        inputs = args_list[:input_num]
        outputs = args_list[input_num: input_num + output_num]
        args_list = [inputs, outputs]
        input_num = 1
        output_num = 1

    if cube_attrs is None:
        logger.error("kernel_name=[%s] can't find cube attrs", kernel_name)
        return
    cube_attributes = cube_attrs.get("cube_attributes", {})
    if not cube_attributes:
        logger.error("kernel_name=[%s] can't find cube attributes", kernel_name)
        return

    is_bias = cube_attributes.get("is_bias", False)
    is_offset_w = cube_attributes.get("is_offset_w", False)
    enable_hf32 = cube_attributes.get("enable_hf32", False)
    offset_x = cube_attributes.get("offset_x", 0)
    groups = cube_attributes.get("groups", 1)
    pad_mode = cube_attributes.get("pad_mode", "SPECIFIC")
    data_format = cube_attributes.get("data_format", "NCHW")
    strides = cube_attributes.get("strides", [1, 1])
    pads = cube_attributes.get("pads", [0, 0, 0, 0])
    dilations = cube_attributes.get("dilations", [1, 1])

    conv_attr1 = {"name": "strides", "dtype": "list_int", "value": strides}
    conv_attr2 = {"name": "pads", "dtype": "list_int", "value": pads}
    conv_attr3 = {"name": "dilations", "dtype": "list_int", "value": dilations}
    conv_attr4 = {"name": "groups", "dtype": "int", "value": groups}
    conv_attr5 = {"name": "data_format", "dtype": "str", "value": data_format}
    conv_attr6 = {"name": "offset_x", "dtype": "int", "value": offset_x}
    conv_attr7 = {"name": "pad_mode", "dtype": "str", "value": pad_mode}
    conv_attr8 = {"name": "enable_hf32", "dtype": "bool", "value": enable_hf32}
    conv_attr9 = {"name": "is_bias", "dtype": "bool", "value": is_bias}
    conv_attr10 = {"name": "is_offset_w", "dtype": "bool", "value": is_offset_w}

    conv_input_num = cube_attributes.get("input_num", 0)
    _origin_inputs_, _origin_outputs_, _inputs_ = \
        build_conv_args(args_list, input_num=conv_input_num, data_format=data_format)

    attrs = [conv_attr1, conv_attr2, conv_attr3, conv_attr4, conv_attr5, conv_attr6,
             conv_attr7, conv_attr8, conv_attr9, conv_attr10]
    tiling_info = TilingInfo()
    context = get_context()
    _change_param_name_to_name(_inputs_)
    _change_param_name_to_name(_origin_inputs_)
    compile_info = context.get_compile_info()
    tiling_config = {"name": "ascendc_op_para_size", "dtype": "int", "value": 2 * 1024 * 1024}
    attrs.append(tiling_config)

    op_type = "Conv2DV2"
    tiling_data_type = "Conv2DTilingData"

    run_info = do_op_tiling(op_type, compile_info, _origin_inputs_, _origin_outputs_, None, None, attrs)
    tiling_info.tiling_data = run_info["tiling_data"]
    tiling_info.tiling_key = run_info['tiling_key']
    tiling_key_list[0] = tiling_info.tiling_key
    cube_block_dim = run_info['block_dim']
    tiling_info.file_content = gen_static_shape_v2(op_type, tiling_data_type, run_info["tiling_data"])

    logger.info("kernel_name=[%s], cube tiling_key[%s], tiling_data=[%s], tiling_file_context=[%s]", kernel_name, str(
        tiling_info.tiling_key), str(tiling_info.tiling_data), str(tiling_info.file_content))
    cube_output_type_size = cube_attributes.get("type_size", 4)
    cube_info = [cube_output_type_size, cube_block_dim, use_cv_common, is_bias, is_offset_w,
                 _origin_inputs_, _origin_outputs_]
    create_conv_tiling_data(kernel_name, temp_dir, graph_name, tiling_info, cube_info, cube_attributes)


def ascbc_matmul_kernel_tiling_pro(
    *args,
    temp_dir,
    graph_name,
    kernel_name,
    input_num,
    output_num,
    use_list_tensor_desc,
    cube_attrs,
    tiling_key_list,
    use_cv_common=None
):
    graph_name = camel_to_snake(graph_name)
    args_list = args[0]
    if use_list_tensor_desc:
        inputs = args_list[:input_num]
        outputs = args_list[input_num: input_num + output_num]
        args_list = [inputs, outputs]
        input_num = 1
        output_num = 1

    if cube_attrs is None:
        logger.error("kernel_name=[%s] can't find cube attrs", kernel_name)
        return
    cube_attributes = cube_attrs.get("cube_attributes", {})
    if not cube_attributes:
        logger.error("kernel_name=[%s] can't find cube attributes", kernel_name)
        return

    is_batch = cube_attributes.get("is_batch", False)
    has_relu = cube_attributes.get("has_relu", False)

    mm_attr1 = {"name":"transpose_x1", "dtype":"bool", "value":cube_attributes.get("transpose_x1", False)}
    mm_attr2 = {"name":"transpose_x2", "dtype":"bool", "value":cube_attributes.get("transpose_x2", False)}
    mm_attr3 = {"name":"offset_x", "dtype":"int", "value":cube_attributes.get("offset_x", 0)}
    if is_batch:
        mm_attr4 = {"name":"enable_hf32", "dtype":"bool", "value":cube_attributes.get("enable_hf32", False)}
    else:
        mm_attr4 = {"name":"opImplMode", "dtype":"int", "value":cube_attributes.get("enable_hf32", 0x1)}

    mm_input_num = cube_attributes.get("input_num", 0)
    _origin_inputs_, _origin_outputs_, _inputs_ = \
        _build_mm_args(args_list, input_num=mm_input_num, mm_attr1=mm_attr1, mm_attr2=mm_attr2)

    attrs = [mm_attr1, mm_attr2, mm_attr3, mm_attr4]
    tiling_info = TilingInfo()
    context = get_context()
    _change_param_name_to_name(_inputs_)
    _change_param_name_to_name(_origin_inputs_)
    compile_info = context.get_compile_info()
    tiling_config = {"name":"ascendc_op_para_size", "dtype":"int", "value":2 * 1024 * 1024}
    attrs.append(tiling_config)
    is_batch = cube_attributes.get("is_batch", False)
    tiling_info, cube_block_dim = _process_tiling_info(is_batch, compile_info, _origin_inputs_, _origin_outputs_, attrs,
                                                       tiling_key_list)

    logger.info("kernel_name=[%s], cube tiling_key[%s], tiling_data=[%s], tiling_file_context=[%s]", kernel_name, str(
        tiling_info.tiling_key), str(tiling_info.tiling_data), str(tiling_info.file_content))
    cube_output_type_size = cube_attributes.get("type_size", 4)
    cube_info = [cube_output_type_size, is_batch, cube_block_dim, use_cv_common, has_relu, _origin_inputs_,
                 _origin_outputs_]
    create_matmul_tiling_data(kernel_name, temp_dir, graph_name, tiling_info, cube_info)


def ascbc_matmul_kernel_dynamic_tiling_pro(
    *args,
    temp_dir,
    graph_name,
    kernel_name,
    input_num,
    output_num,
    use_list_tensor_desc,
    cube_attrs,
    use_cv_common=None
):
    graph_name = camel_to_snake(graph_name)
    args_list = args[0]
    if use_list_tensor_desc:
        inputs = args_list[:input_num]
        outputs = args_list[input_num: input_num + output_num]
        args_list = [inputs, outputs]
        input_num = 1
        output_num = 1

    if cube_attrs is None:
        logger.error("kernel_name=[%s] can't find cube attrs", kernel_name)
        return
    cube_attributes = cube_attrs.get("cube_attributes", {})
    if not cube_attributes:
        logger.error("kernel_name=[%s] can't find cube attributes", kernel_name)
        return

    is_batch = cube_attributes.get("is_batch", False)
    has_relu = cube_attributes.get("has_relu", False)

    mm_attr1 = {"name":"transpose_x1", "dtype":"bool", "value":cube_attributes.get("transpose_x1", False)}
    mm_attr2 = {"name":"transpose_x2", "dtype":"bool", "value":cube_attributes.get("transpose_x2", False)}

    mm_input_num = cube_attributes.get("input_num", 0)
    _origin_inputs_, _origin_outputs_, _ = \
        _build_mm_args(args_list, input_num=mm_input_num, mm_attr1=mm_attr1, mm_attr2=mm_attr2)

    cube_output_type_size = cube_attributes.get("type_size", 4)
    cube_info = [cube_output_type_size, is_batch, use_cv_common, has_relu, _origin_inputs_,
                 _origin_outputs_]
    create_matmul_dynamic_tiling_data(temp_dir, cube_info)


def asc_graph_compile_post(
    host_build_dir,
    code_gen,
    graph_name,
    kernel_file,
    json_file,
    kernel_name,
    static_compile_flag,
    is_cube=False
):
    # 生成get_kernel.cpp代码
    timestamp_set(True, graph_name, "GenGetKernel")
    get_kernel_src = code_gen.get_kernel_and_json_generator(kernel_file, json_file)
    generate_file(host_build_dir, graph_name + "_get_kernel.cpp", get_kernel_src)
    timestamp_set(False, graph_name, "GenGetKernel")

    #重新编译host so
    timestamp_set(True, graph_name, "CompileSecondHost")
    ascbc_host_compile(graph_name, kernel_name, host_build_dir, True, static_compile_flag, is_cube)
    timestamp_set(False, graph_name, "CompileSecondHost", True)

    #拷贝so和json
    host_so = os.path.join(host_build_dir, f"lib{kernel_name}.so")
    kernel_meta_dir = get_current_build_config("kernel_meta_parent_dir")
    shutil.copy(host_so, os.path.join(kernel_meta_dir, "kernel_meta"))
    modify_json_file(json_file, host_so)
    if os.path.exists(kernel_file):
        os.remove(kernel_file)


def get_graph_basic_info(params, args) -> tuple:
    schedule_results = params['schedule_results']
    graph_name = camel_to_snake(gen_valid_name(schedule_results.get_name()))
    input_num = schedule_results.get_input_num()
    output_num = schedule_results.get_output_num()
    is_cube = schedule_results.is_cube_type()
    cube_attrs = schedule_results.get_cube_attributes()
    is_conv = schedule_results.is_conv_type()
    kernel_name = args[-1]
    vector_core_num = params['vector_core_num']
    msg = (
        f"graph_name:{graph_name}, input_num:{input_num}, output_num:{output_num},"
        f"kernel_name:{kernel_name}, vector_core_num:{vector_core_num}, is_cube:{is_cube}"
    )
    CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_DEBUG)
    return graph_name, input_num, output_num, is_cube, is_conv, cube_attrs


def create_compile_dirs(temp_dir) -> tuple:
    """创建device/host/pgo目录"""
    kernel_build_dir = os.path.join(temp_dir, "device")
    host_build_dir = os.path.join(temp_dir, "host")
    for dir_path in [kernel_build_dir, host_build_dir]:
        os.makedirs(dir_path, exist_ok=True)
    return kernel_build_dir, host_build_dir


def copy_so_and_modify_json(host_build_dir, kernel_name, json_file):
    host_so = os.path.join(host_build_dir, f"lib{kernel_name}.so")
    kernel_meta_dir = get_current_build_config("kernel_meta_parent_dir")
    shutil.copy(host_so, os.path.join(kernel_meta_dir, "kernel_meta"))
    modify_json_file(json_file, host_so)


def asc_graph_compile(*args, temp_dir, params):
    """入口为asc graph场景"""
    graph_name, input_num, output_num, is_cube, is_conv, cube_attrs = get_graph_basic_info(params, args)
    tiling_key_list, kernel_type_list = [-1], ["KERNEL_TYPE_AIV_ONLY"]
    _, host_build_dir = create_compile_dirs(temp_dir)
    kernel_name = args[-1]
    code_gen = CodeGen()

    op_kernel_src, tiling_func_srcs = generate_device_and_host_code(
        graph_name=graph_name, temp_dir=temp_dir, params=params, code_gen=code_gen)
    static_compile_flag = is_static_compile(params, tiling_func_srcs)
    use_list_tensor_desc = op_kernel_src.find('kernel_operator_list_tensor_intf.h') > 0
    enable_parallel_compile = op_kernel_src.rfind('void fake_tiling_ids()') > 0
    use_cv_common = [False]
    replace_root = get_replace_kernel_root()

    if is_cube:
        if (static_compile_flag):
            if (is_conv):
                ascbc_conv_kernel_tiling_pro(args, temp_dir=temp_dir, graph_name=graph_name, kernel_name=kernel_name,
                                             input_num=input_num, output_num=output_num,
                                             use_list_tensor_desc=use_list_tensor_desc, cube_attrs=cube_attrs,
                                             tiling_key_list=tiling_key_list, use_cv_common=use_cv_common)
            else:
                ascbc_matmul_kernel_tiling_pro(args, temp_dir=temp_dir, graph_name=graph_name, kernel_name=kernel_name,
                                               input_num=input_num, output_num=output_num,
                                               use_list_tensor_desc=use_list_tensor_desc, cube_attrs=cube_attrs,
                                               tiling_key_list=tiling_key_list, use_cv_common=use_cv_common)
            replace_host_files_if_needed(
                replace_root, temp_dir, use_cv_common, graph_name)
            static_shape_compile(kernel_name=kernel_name, temp_dir=temp_dir, graph_name=graph_name,
                                 tiling_key_list=tiling_key_list, kernel_type_list=kernel_type_list,
                                 use_cv_common=use_cv_common, is_cube=is_cube)
        else:
            if (not is_conv):
                replace_host_files_if_needed(
                    replace_root, temp_dir, use_cv_common, graph_name)
                ascbc_matmul_kernel_dynamic_tiling_pro(args, temp_dir=temp_dir, graph_name=graph_name,
                                                       kernel_name=kernel_name, input_num=input_num,
                                                       output_num=output_num,
                                                       use_list_tensor_desc=use_list_tensor_desc,
                                                       cube_attrs=cube_attrs, use_cv_common=use_cv_common)
                dynamic_shape_compile(kernel_name=kernel_name, temp_dir=temp_dir, graph_name=graph_name,
                                      use_cv_common=use_cv_common, is_cube=is_cube)

    elif static_compile_flag:
        pgo_env = get_pgo_env_flag()
        if pgo_env and not enable_parallel_compile and not use_list_tensor_desc:
            asc_pgo_exec(*args, temp_dir=temp_dir, params=params, op_kernel_src=op_kernel_src, code_gen=CodeGen())
        replace_host_files_if_needed(replace_root, temp_dir, use_cv_common, graph_name)
        timestamp_set(True, graph_name, "CompileHost")
        static_shape_compile(kernel_name=kernel_name, temp_dir=temp_dir, graph_name=graph_name,
                             tiling_key_list=tiling_key_list, kernel_type_list=kernel_type_list)
        timestamp_set(False, graph_name, "CompileHost")

    kernel_build_dir = get_device_build_dir(temp_dir, use_cv_common)
    replace_kernel(kernel_build_dir=kernel_build_dir, graph_name=graph_name)
    timestamp_set(True, graph_name, "CompileDevice")
    kernel_file, json_file = ascbc_kernel_compile(args, graph_name=graph_name, kernel_name=kernel_name,
                                                  input_num=input_num, output_num=output_num,
                                                  temp_build_dir=kernel_build_dir, impl_mode=params.get('impl_mode'),
                                                  use_list_tensor_desc=use_list_tensor_desc,
                                                  enable_parallel_compile=enable_parallel_compile,
                                                  tiling_key=tiling_key_list[0], kernel_type=kernel_type_list[0],
                                                  is_cube=is_cube, is_conv=is_conv)
    timestamp_set(False, graph_name, "CompileDevice")
    host_build_dir = get_host_build_dir(temp_dir, use_cv_common)
    asc_graph_compile_post(host_build_dir, code_gen, graph_name, kernel_file,
                           json_file, kernel_name, static_compile_flag, is_cube)


def compute_graph_compile(*args, temp_dir, params, vector_core_num, device_id):
    """入口为compute graph场景"""
    timestamp_set(True, "start", "Schedule")
    compute_graph = ascir.utils.deserialize("compute_graph", params['compute_graph'])
    logger.info("[AUTOFUSE_BACKEND]compile start graph[%s], build dir[%s], kernel_name[%s], "
                "symbol_source_info[%s], output_symbol_shape[%s]",
                compute_graph.get_name(), temp_dir, args[-1], params.get('symbol_source_info', "None"),
                params.get('output_symbol_shape', "[]"))
    value = params.get('symbol_source_info')
    CommonUtility.print_compile_log("", f"symbol_source_info is {type(value)}, {value}", AscendCLogLevel.LOG_DEBUG)
    if value is None or value == 'null':
        CommonUtility.print_compile_log("", "maybe a static shape", AscendCLogLevel.LOG_WARNING)
        symbol_source_info = None
    else:
        symbol_source_info = ascir.utils.deserialize("symbol_source_info", value)
    #Do schedule
    sched = Schedule()
    schedule_results = sched.scheduleV2(compute_graph)
    timestamp_set(False, gen_valid_name(compute_graph.get_name()), "Schedule")
    #构造ascgraph参数
    asc_param = {}
    asc_param['schedule_results'] = schedule_results
    asc_param['symbol_source_info'] = symbol_source_info
    asc_param['output_symbol_shape'] = params.get('output_symbol_shape', "[]")
    asc_param['vector_core_num'] = vector_core_num
    asc_param['device_id'] = device_id
    asc_graph_compile(*args, temp_dir=temp_dir, params=asc_param)


def asc_codegen_compile(*args, **kwargs):
    """自动融合算子注册到tefusion的统一编译入口,编译完成后,将编译后的`so`和`json`写入`kernel meta`目录
    **args**
    * inputs:包含输入shape、dtype
    * outputs:包含输出shape、dtype
    * attrs:对于自动融合节点,不存在原型属性,因此该字段不存在
    * kernel_name: kernel name用于生成so名称 确保在最后一个参数中

    **op_info.extra_params**
    * fused_graph: compute graph序列化后的字符串 <string>
    * symbol_source_info: shape的符号化信息,存储了如`s0`是第0个输入的第2个维度信息 <dict>
    * output_symbol_shape: infer shape所需符号化表达 <list<str>>
    """

    def asc_codegen_compile_with_tmpdir(*args, temp_dir, **kwargs):
        import tbe.common.context.op_context as op_context
        context = op_context.get_context()
        op_info = context.get_op_info()
        extra_params = op_info[0].extra_params
        vector_core_num = op_context.get_context().get_addition("_op_vectorcore_num")
        device_id = op_context.get_context().get_addition("device_id")

        # 设置 platform 到 context
        npu_arch = get_soc_spec('NpuArch')
        if npu_arch:
            ascir.utils.set_platform(npu_arch, int(get_soc_spec('vector_core_cnt')), int(get_soc_spec('ub_size')))
            CommonUtility.print_compile_log("", f"Set platform from get_soc_spec: {npu_arch}",
                                            AscendCLogLevel.LOG_DEBUG)

        CommonUtility.print_compile_log("", f"params type: {type(extra_params)}, params: {extra_params}",
                                            AscendCLogLevel.LOG_DEBUG)
        #反序列化graph和symbol_source_info
        if not check_keys_in_dict(extra_params, ['compute_graph', 'symbol_source_info']):
            CommonUtility.print_compile_log("", f"compute_graph and symbol_source_info do not exist",
                                            AscendCLogLevel.LOG_ERROR)
            raise Exception("An error occurred autofuse compile for check extra_params")
        compute_graph_compile(*args,
                              temp_dir=temp_dir,
                              params=extra_params,
                              vector_core_num=vector_core_num,
                              device_id=device_id)


    kernel_meta_dir = get_current_build_config("kernel_meta_parent_dir")
    if not os.path.isabs(kernel_meta_dir):
        kernel_meta_dir = os.path.abspath(kernel_meta_dir)

    if not is_valid_path(kernel_meta_dir):
        CommonUtility.print_compile_log("", f"invalid kernel meta path : `{kernel_meta_dir}' ",
                                        AscendCLogLevel.LOG_ERROR)
        raise Exception("An error occurred autofuse compile for check kernel_meta path")

    if not os.path.exists(kernel_meta_dir):
        CommonUtility.print_compile_log("", f"kernel meta parent dir is not exist : `{kernel_meta_dir}' ",
                                        AscendCLogLevel.LOG_ERROR)
        raise Exception("An error occurred autofuse compile for check kernel_meta path")

    kernel_name = args[-1]
    from autofuse.compile_adapter import get_debug_flag
    compile_debug = get_debug_flag()
    if not compile_debug:
        with tempfile.TemporaryDirectory(prefix=kernel_name, dir=kernel_meta_dir) as temp_dir:
            msg = f"kernel name:{kernel_name}, kernel_meta_dir:{kernel_meta_dir}, tmp:{temp_dir}"
            CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO)
            asc_codegen_compile_with_tmpdir(*args, temp_dir=temp_dir, **kwargs)
    else:
        os.environ['ASCEND_OP_COMPILE_SAVE_KERNEL_META'] = '1'
        temp_dir = tempfile.mkdtemp(prefix=kernel_name, dir=kernel_meta_dir)
        msg = f"kernel name:{kernel_name}, kernel_meta_dir:{kernel_meta_dir}, tmp:{temp_dir}"
        CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO)
        asc_codegen_compile_with_tmpdir(*args, temp_dir=temp_dir, **kwargs)
        os.environ.pop('ASCEND_OP_COMPILE_SAVE_KERNEL_META')

    CommonUtility.print_compile_log("", "asc_codegen_compile finish", AscendCLogLevel.LOG_INFO)