#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# ----------------------------------------------------------------------------------------------------------
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------------------------------------

import os
import sys
import glob
import shutil
import argparse
import subprocess
from asc_op_compile_base.common.utils.log_utils import LogUtil, AscendCLogLevel


def run_command(args, **others):
    """Execute an external command, logging (not raising) a non-zero exit.

    :param args: argument vector handed to ``subprocess.run`` (no shell).
    :param others: extra keyword arguments forwarded to ``subprocess.run``,
        e.g. ``cwd``.

    A ``CalledProcessError`` is logged at error level and then swallowed, so
    the caller keeps going. Anything else raised by ``subprocess.run`` (for
    example ``FileNotFoundError`` when the executable is missing) propagates.
    NOTE(review): because failures are swallowed, the script can finish with
    exit status 0 after a failed ``ar``/``ld`` step -- confirm this is intended.
    """
    try:
        subprocess.run(args, check=True, **others)
    except subprocess.CalledProcessError as failure:
        error_text = f"Command failed: {failure}!"
        LogUtil.print_compile_log(
            "", error_text, AscendCLogLevel.LOG_ERROR, LogUtil.Option.NON_SOC
        )


def args_parse(argv=None):
    """Parse the command-line options of this packing script.

    :param argv: argument strings to parse; ``None`` (the default) parses
        ``sys.argv[1:]`` exactly as before, other values ease testing/reuse.
    :return: ``argparse.Namespace`` with ``object_files_dir``,
        ``kernel_libs_dir``, ``tmp_obj_dir``, ``output_file``,
        ``remove_tmp_files`` and ``package_name``.

    ``-f``, ``-k``, ``-t`` and ``-o`` are required. Every later step joins
    them with path fragments, so omitting one used to surface as a
    ``TypeError`` deep inside the script (``-k`` even defaulted to an empty
    list, which cannot be concatenated with a string). argparse now rejects
    such an invocation up front with a usage message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--object-files-dir", required=True,
        help="Output files from host object targets, generated by cmake."
    )

    parser.add_argument(
        "-k", "--kernel-libs-dir", required=True,
        help="Output files from kernel libs."
    )

    parser.add_argument(
        "-t", "--tmp-obj-dir", required=True,
        help="Temporary dir for combining object files to static library."
    )

    parser.add_argument(
        "-o", "--output-file", required=True,
        help="Output static library of current customize operator project."
    )

    # Kept as a string flag: the value "1" (default) requests cleanup.
    parser.add_argument(
        "-r", "--remove-tmp-files", default="1", help="Whether to remove temporary files."
    )

    parser.add_argument(
        "-p", "--package-name", default="", help="package name."
    )
    return parser.parse_args(argv)


def get_object_list(obj_file, ori_object_list):
    """Read one object-list file and record its entries under the target name.

    :param obj_file: path to a text file listing one object-file path per line.
    :param ori_object_list: dict updated in place; the key is the part of
        ``obj_file``'s basename before its first ``.`` and the value is the
        list of object paths read from the file.
    :raises RuntimeError: if a listed object file does not exist.

    Blank lines (e.g. a trailing empty line) are skipped rather than being
    reported as a missing object file named ``""``.
    """
    # "foo.objs.txt" -> "foo". partition() keeps the whole name when there is
    # no dot, whereas slicing at find(".") == -1 dropped the last character.
    cur_obj_target = os.path.basename(obj_file).partition(".")[0]

    cur_obj_list = []
    with open(obj_file) as f:
        for line in f:
            cur_obj_path = line.strip("\n \r")
            if not cur_obj_path:
                continue
            if not os.path.exists(cur_obj_path):
                raise RuntimeError(f"object file {cur_obj_path} doesn't exist.")
            cur_obj_list.append(cur_obj_path)
    ori_object_list[cur_obj_target] = cur_obj_list


def collect_object_from_files(object_files_dir, tmp_obj_dir):
    """Copy every host object listed under *object_files_dir* into *tmp_obj_dir*.

    :param object_files_dir: directory containing ``*.txt`` object lists,
        each parsed by ``get_object_list``.
    :param tmp_obj_dir: scratch directory; it is removed and recreated.
    :return: list of copied file names, relative to ``tmp_obj_dir``.
    :raises RuntimeError: if no ``*.txt`` list exists, if ``tmp_obj_dir`` is
        a regular file, if a listed object is missing, or if nothing was parsed.
    """
    # Escaped so glob metacharacters in the directory name are literal;
    # sorted so the link order (and thus the output) is reproducible.
    list_pattern = os.path.join(glob.escape(object_files_dir), "*.txt")
    object_files_parts = sorted(glob.glob(list_pattern))
    if not object_files_parts:
        raise RuntimeError("-f/--object-files-dir has no *.txt object lists, please check and reset.")

    if os.path.isfile(tmp_obj_dir):
        raise RuntimeError("-t/--tmp-obj-dir is an existing file, it must be a directory, please check and reset.")

    if os.path.exists(tmp_obj_dir):
        shutil.rmtree(tmp_obj_dir)
    os.makedirs(tmp_obj_dir, exist_ok=True)

    ori_object_list = {}
    for obj_file in object_files_parts:
        get_object_list(obj_file, ori_object_list)

    if not ori_object_list:
        raise RuntimeError("object parsed from file is empty.")

    dst_object_list = []
    used_names = set()
    for target, obj_list in ori_object_list.items():
        for obj in obj_list:
            base_name = os.path.basename(obj)
            file_name = f"{target}_{base_name}"
            # Objects of one target may share a basename (a/foo.cpp.o and
            # b/foo.cpp.o). Without a unique name the second copy would
            # overwrite the first and that file would be linked twice.
            suffix = 1
            while file_name in used_names:
                file_name = f"{target}_{suffix}_{base_name}"
                suffix += 1
            used_names.add(file_name)
            shutil.copyfile(obj, os.path.join(tmp_obj_dir, file_name))
            dst_object_list.append(file_name)

    return dst_object_list


def _find_unique_library(kernel_libs_dir, lib_name):
    """Return the single file named *lib_name* anywhere under *kernel_libs_dir*.

    :raises RuntimeError: if zero or more than one match is found.
    """
    # "**" with recursive=True also matches zero directories, so a library
    # sitting directly in kernel_libs_dir is found as well.
    pattern = os.path.join(glob.escape(kernel_libs_dir), "**", lib_name)
    matches = glob.glob(pattern, recursive=True)
    if len(matches) != 1:
        raise RuntimeError(
            f"expected exactly one {lib_name} under {kernel_libs_dir}, "
            f"found {len(matches)}, please check and reset."
        )
    return matches[0]


def unpack_kernel_library(kernel_libs_dir, objects, tmp_obj_dir):
    """Extract the kernel archives into *tmp_obj_dir* and list their members.

    :param kernel_libs_dir: directory searched recursively for exactly one
        ``libopregistry.a`` and exactly one ``libkernels.a``.
    :param objects: list extended IN PLACE with the member names reported by
        ``ar -t`` for each archive (the caller keeps using this same list).
    :param tmp_obj_dir: directory the archives are extracted into via ``ar x``.
    :raises RuntimeError: if ``kernel_libs_dir`` is unset/empty or a library
        is missing or ambiguous.
    :raises subprocess.CalledProcessError: if ``ar -t`` fails. A failing
        ``ar x`` is only logged by ``run_command``.
    """
    # Guards against a missing -k option, which previously crashed with a
    # TypeError when a list/None was concatenated with a path string.
    if not kernel_libs_dir:
        raise RuntimeError("-k/--kernel-libs-dir is not set, please check and reset.")

    kernel_libs_parts = [
        _find_unique_library(kernel_libs_dir, "libopregistry.a"),
        _find_unique_library(kernel_libs_dir, "libkernels.a"),
    ]

    for lib in kernel_libs_parts:
        output = subprocess.check_output(["ar", "-t", lib])
        objects.extend(
            member.strip()
            for member in output.decode("utf-8").splitlines()
            if member.strip()
        )
        run_command(["ar", "x", lib], cwd=tmp_obj_dir)


def pack_static_library(output_file, objects, tmp_obj_dir, package_name):
    """Relocatably link *objects* into ``<package_name>_ascendc_final.o``.

    :param output_file: path whose directory receives the generated objects;
        an existing file at this exact path is deleted first.
    :param objects: object file names, relative to ``tmp_obj_dir``.
    :param tmp_obj_dir: working directory for every ``ld`` invocation.
    :param package_name: prefix for the generated object names.

    The objects are linked with ``ld -r`` in batches of 30 into
    ``<package_name>_ascendc_tmp<start>.o`` files, which are then merged by a
    final ``ld -r``. Linker failures are only logged (see ``run_command``).
    NOTE(review): ``output_file`` is deleted but not recreated here, and the
    intermediate ``_ascendc_tmp*.o`` files are left in the output directory --
    presumably a later build step consumes the final ``.o``; confirm.
    """
    target_path = os.path.abspath(output_file)
    if os.path.exists(target_path):
        os.remove(target_path)

    out_dir = os.path.dirname(target_path)
    final_object = os.path.join(out_dir, package_name + "_ascendc_final.o")

    batch = 30
    batch_starts = range(0, len(objects), batch)
    partial_objects = [
        os.path.join(out_dir, package_name + f"_ascendc_tmp{start}.o")
        for start in batch_starts
    ]
    for start, partial in zip(batch_starts, partial_objects):
        link_inputs = objects[start: start + batch]
        run_command(["ld", "-r", "-o", partial] + link_inputs, cwd=tmp_obj_dir)

    run_command(["ld", "-r", "-o", final_object] + partial_objects, cwd=tmp_obj_dir)


def remove_temporary_files(tmp_obj_dir):
    """Recursively delete the scratch directory *tmp_obj_dir*.

    A path that does not exist is silently ignored.
    """
    if not os.path.exists(tmp_obj_dir):
        return
    shutil.rmtree(tmp_obj_dir)


def main():
    """Collect host objects, unpack the kernel libraries, link, then clean up.

    Exceptions propagate unchanged, so the interpreter prints the traceback
    and exits non-zero. The former ``except Exception as e: raise (e)``
    wrapper was a no-op re-raise and has been removed.
    """
    args = args_parse()

    objects = collect_object_from_files(args.object_files_dir, args.tmp_obj_dir)
    # Appends the kernel-library member names to ``objects`` in place.
    unpack_kernel_library(args.kernel_libs_dir, objects, args.tmp_obj_dir)
    pack_static_library(args.output_file, objects, args.tmp_obj_dir, args.package_name)

    # Only the exact string "1" triggers cleanup; any other -r value keeps
    # the scratch directory (useful for inspecting intermediate objects).
    if args.remove_tmp_files == "1":
        remove_temporary_files(args.tmp_obj_dir)


if __name__ == "__main__":
    # main() returns None, so sys.exit() ends with status 0 on success; an
    # uncaught exception from main() still terminates with a traceback.
    sys.exit(main())