#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------

import os
import tempfile
import argparse
from typing import List
from autofuse import ascendc_compile
import re


def str2bool(v):
    v_lower = v.lower()
    if v_lower in ['true', '1', 'yes', 'y']:
        return True
    elif v_lower in ['false', '0', 'no', 'n']:
        return False
    else:
        raise ValueError(f"Invalid boolean value: '{v}'")


def camel_to_snake(camel_str):
    # 使用正则表达式匹配大写字母
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel_str)
    # 使用正则表达式匹配小写字母后跟大写字母的情况
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def gen_valid_name(t_name):
    result = []
    last_was_underscore = False

    for c in t_name:
        if c.isalnum():
            result.append(c)
            last_was_underscore = False
        else:
            if not last_was_underscore:
                result.append('_')
                last_was_underscore = True

    ret_name = ''.join(result)

    # 删除开头的下划线
    if ret_name and ret_name[0] == '_':
        ret_name = ret_name[1:]

    # 如果以数字开头,添加前缀
    if ret_name and ret_name[0].isdigit():
        ret_name = "t_" + ret_name

    return ret_name


def parse_compile_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--graph_name', default='autofuse', type=str, help='Graph name.')
    parser.add_argument('--output_file', required=True, type=str, help='Destination directory.')
    parser.add_argument('--output_path', default='', type=str, help='Output directory.')
    parser.add_argument('--force_unknown', default=False, type=str2bool, help='force unknown shape.')
    parser.add_argument('--config_file', default='', type=str, help='PGO tiling config file after turning.')
    parser.add_argument('--soc_version', default='Ascend910B', type=str, help='chip soc version.')
    parser.add_argument('--compile_options', default='', type=str, help='Compile options of tiling and kernel.')
    return parser.parse_args(argv)


def generate_file(dst_dir, file_name, text):
    os.makedirs(dst_dir, exist_ok=True)
    file_path = os.path.join(dst_dir, file_name)
    with open(file_path, "w") as file:
        file.write(text)


def parse_env_flags(env_name):
    result = {}
    flags = os.getenv(env_name)
    if not flags:
        return result
    params = flags.split(';')
    for param in params:
        if '=' in param:
            key_part, value_part = param.split('=', 1)
            key = key_part.lstrip('-')
            result[key] = value_part
    return result


def get_dfx_env_result():
    return parse_env_flags('AUTOFUSE_DFX_FLAGS')


def get_debug_flag():
    dfx_dict = get_dfx_env_result()
    return dfx_dict.get('codegen_compile_debug', "false").lower() == 'true'


def get_pgo_topn():
    default_topn = 5
    dfx_dict = get_dfx_env_result()
    topn_str = dfx_dict.get('autofuse_pgo_topn', str(default_topn))
    try:
        topn = int(topn_str)
        if topn < 0:
            return default_topn
        return topn
    except ValueError:
        return default_topn


def get_pgo_env_flag():
    result = parse_env_flags('AUTOFUSE_FLAGS')
    return result.get('autofuse_enable_pgo', "false").lower() == 'true'


def prepare_compile_context(argv, stage, tiling_repr):
    args = parse_compile_args(argv)
    args.stage = stage
    args.tiling_repr = tiling_repr
    if stage == 'host':
        args.compile_options = (args.compile_options + " -D_GLIBCXX_USE_CXX11_ABI=0").strip()

    args.graph_name = camel_to_snake(gen_valid_name(args.graph_name))
    auto_cleanup = not args.output_path and not get_debug_flag()

    if auto_cleanup:
        temp_dir_ctx = tempfile.TemporaryDirectory()
        args.temp_dir = temp_dir_ctx.name
        return args, temp_dir_ctx, True
    args.temp_dir = args.output_path if args.output_path else tempfile.mkdtemp()
    return args, None, False


def execute_compile(sources, args):
    tiling_def_file = "autofuse_tiling_data.h"
    base_host_file = args.graph_name + "_tiling_func.cpp"
    base_device_file = args.graph_name + "_op_kernel.cpp"
    if args.stage in ['all', 'host']:
        host_file_path = os.path.join(args.temp_dir, "host")
        generate_file(host_file_path, tiling_def_file, sources['tiling_struct_code'])
        generate_file(host_file_path, base_host_file, sources['host_impl_code'])
        args.host_files = os.path.join(host_file_path, base_host_file)
    if args.stage in ['all', 'device']:
        device_file_path = os.path.join(args.temp_dir, "device")
        generate_file(device_file_path, tiling_def_file, sources['tiling_struct_code'])
        generate_file(device_file_path, base_device_file, sources['kernel_impl_code'])
        args.device_files = os.path.join(device_file_path, base_device_file)

    ascendc_compile.main(args)
    return args.temp_dir


def compile_core(sources, argv: List[str], stage='all', tiling_repr=None):
    args, temp_dir_ctx, auto_cleanup = prepare_compile_context(argv, stage, tiling_repr)
    if not auto_cleanup:
        return execute_compile(sources, args)
    with temp_dir_ctx:
        return execute_compile(sources, args)


def jit_compile(tiling_def, host_tiling, op_kernel, argv: List[str]):
    return compile_core({
        'tiling_struct_code': tiling_def,
        'host_impl_code': host_tiling,
        'kernel_impl_code': op_kernel
    }, argv)


def host_compile(tiling_def_code, tiling_impl_code, argv: List[str]):
    return compile_core({
        'tiling_struct_code': tiling_def_code,
        'host_impl_code': tiling_impl_code,
        'kernel_impl_code': None
    }, argv, 'host')


def kernel_compile(tiling_def_code, kernel_impl_code, argv: List[str], *, tiling_repr=None):
    return compile_core({
        'tiling_struct_code': tiling_def_code,
        'host_impl_code': None,
        'kernel_impl_code': kernel_impl_code
    }, argv, 'device', tiling_repr)


def extract_time(line):
    try:
        time_str = line.split('#')[-1].strip()
        if time_str == '1.79769e+308': # 采样失败的返回值
            return float('inf')
        return float(time_str)
    except (ValueError, IndexError):
        return float('inf')


def pgo_get_top_result(search_path, top_n=5):
    with open(search_path, 'r') as file:
        lines = [line.strip() for line in file if line.strip()]

    if not lines:
        return None, None, None

    origin_line = lines[-1]
    solution_set_line = lines[:-1]

    sorted_lines = sorted(solution_set_line, key=extract_time)
    if top_n == 0 or top_n > len(sorted_lines):
        top_lines = sorted_lines
    else:
        top_lines = sorted_lines[:top_n]

    return top_lines, origin_line, top_n


def pgo_write_config(config_path, tiling_data, is_last_result=False):
    # 写入配置文件
    # 只有调优结束后,才写1标记内存复用,否则写0强制每次读文件
    with open(config_path, 'w') as file:
        if is_last_result:
            file.write(f'1\n')
        else:
            file.write(f'0\n')
        file.write(f"{tiling_data}\n")
        file.flush()


def pgo_generate_config(search_path, config_path, topn=5):
    with open(search_path, 'r') as file:
        lines = [line.strip() for line in file if line.strip()]

    target_lines = lines[-(topn + 1):]

    result = min(target_lines, key=extract_time)
    if extract_time(result) == float('inf'):
        result = lines[-1]
    pgo_write_config(config_path, result, is_last_result=True)