# Copyright 2023-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""AKG-MLIR Driver for MindSpore."""
import argparse
import hashlib
import json
import logging
import os
import ctypes
import pathlib
import subprocess
import shutil
import numpy as np

from .utils.cpu_profiling_wrapper import wrap_timer_func
from .utils.dynamic_utils import get_device_shape
from .utils.gen_runtime_code import ProfilingParams, gen_cuda_runtime_code
from .backends.ascend import akg_opt, write_code, bisheng_compile, get_block_dim_from_mlir
from .backends.ascend import benchmark_launch as launch

# Keys of the dynamic-shape section written into GPU kernel JSON meta files.
HOST_SHAPES = "hostShapes"
DEVICE_SHAPES = "deviceShapes"
RUNTIME_VARS = "runtimeVars"
SUPPORT_INFO = "SupportInfo"
TARGET_INFO = "targetInfo"
# Keys common to the generated kernel JSON meta files.
KERNEL_NAME = "kernelName"
SHA256 = "sha256"
# GPU meta only: whether the kernel was compiled for dynamic shapes.
DYNAMIC = "is_dynamic"
# GPU meta only: meta of the optional "<kernel>_static" static-tile kernel.
STATIC_TILE_IMPL = "StaticTileImpl"

def get_kernel_meta_path():
    """Build the directory path where kernel meta files are written.

    The directory name comes from ``KERNEL_META_DIR`` (default ``akg_kernel_meta``). It is
    placed under the resolved ``MS_COMPILER_CACHE_PATH``, which is the current working
    directory when the variable is unset.
    """
    cache_root = os.path.realpath(os.getenv("MS_COMPILER_CACHE_PATH", ""))
    meta_dir_name = os.getenv("KERNEL_META_DIR", default="akg_kernel_meta")
    return os.path.join(cache_root, meta_dir_name)

def _is_single_op(desc_d):
    """Return the number of desc op is 1."""
    input_lists = desc_d.get("op_desc", [])
    return len(input_lists) <= 1

def generate_unique_hash(input_str):
    """Return the hex MD5 digest of *input_str* encoded as UTF-8.

    Used as a content identifier only; MD5 is not suitable for security purposes.
    """
    digest = hashlib.md5()
    digest.update(input_str.encode("utf8"))
    return digest.hexdigest()

def deal_input(desc):
    """Reset to 0 the ``value`` of every input whose first entry has shape ``[1]``.

    *desc* is modified in place. A missing or ``None`` ``input_desc`` is a no-op.
    """
    inputs = desc.get("input_desc")
    if inputs is None:
        return
    for entry in inputs:
        first = entry[0]
        shape = first["shape"]
        if len(shape) == 1 and shape[0] == 1 and "value" in first:
            first["value"] = 0

def del_value(desc):
    """Normalise an op description in place.

    Every entry of ``op_desc`` goes through ``deal_input`` (scalar input values reset),
    then the ``op`` name is cleared.
    """
    op_entries = desc["op_desc"]
    for op_entry in op_entries:
        deal_input(op_entry)
    desc["op"] = ""

def get_npucompiler_path():
    """Locate the ``bishengir-compile`` executable on ``PATH``.

    Raises:
        EnvironmentError: if the executable cannot be found.
    """
    compiler = shutil.which("bishengir-compile")
    if compiler is not None:
        return compiler
    raise EnvironmentError("Couldn't find executable bishengir-compile.")

def _compile_lib(kernel_name, file_path="./tmp_files/"):
    """Build ``gen_func_<kernel_name>.cu`` in *file_path* into a shared library with nvcc.

    The output is ``gen_func_<kernel_name>.so`` in the same directory.

    Raises:
        RuntimeError: when nvcc exits with a non-zero status (stderr is logged first).
    """
    stem = os.path.join(file_path, "gen_func_" + kernel_name)
    nvcc_cmd = [
        "nvcc", "-o", stem + ".so", stem + ".cu", "--shared",
        "-Xcompiler", "-fPIC", "-lcudart", "-lcuda", "-O3",
    ]
    try:
        subprocess.run(nvcc_cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as err:
        logging.error("run compile lib failed! cmd:\n %s \nerror message:\n %s", err.cmd, err.stderr)
        raise RuntimeError("nvcc compile failed in converting the case: " + kernel_name + "!\n") from err

def create_executable(kernel_name,
                      input_for_mod,
                      output_indexes,
                      is_dyn_shape):
    """Generate, build and load the CUDA runtime wrapper library for *kernel_name*.

    The wrapper source is generated under ``get_kernel_meta_path()``. It is compiled into
    ``tmp_files/gen_func_<kernel_name>.so`` there and loaded with ``ctypes``.

    Raises:
        RuntimeError: if compiling or loading the wrapper library fails.
    """
    meta_dir = get_kernel_meta_path()
    lib_dir = os.path.join(meta_dir, "tmp_files")
    lib_path = os.path.join(lib_dir, "gen_func_" + kernel_name + ".so")
    gen_cuda_runtime_code(kernel_name,
                          input_for_mod,
                          output_indexes,
                          is_dyn_shape,
                          [],  # no fake output indices
                          path=meta_dir)
    try:
        _compile_lib(kernel_name, file_path=lib_dir)
    except Exception as err:
        raise RuntimeError("Compile cuda runtime lib fail") from err
    try:
        return ctypes.cdll.LoadLibrary(lib_path)
    except Exception as err:
        raise RuntimeError("Load cuda runtime lib fail") from err

def transform_data_to_ctypes(data,
                             kernel_name,
                             is_dyn_shape=False,
                             backend="gpu",
                             is_profile_params=False,
                             ):
    """Convert host-side kernel arguments into ctypes objects.

    Args:
        data: sequence of kernel arguments. Each item must be a Python ``int`` or a
            ``numpy.ndarray``.
        kernel_name: kernel name forwarded to ``get_device_shape``.
        is_dyn_shape: whether the kernel was compiled for dynamic shapes.
        backend: ``"gpu"``, ``"cpu"`` or another backend name.
        is_profile_params: True when *data* holds profiling parameters rather than
            kernel tensors.

    Returns:
        * ``[]`` when *data* is empty.
        * The flat list of per-argument ctypes objects for the GPU backend or for
          profiling parameters.
        * ``[packed_tensors]`` for static-shape CPU kernels.
        * ``[packed_tensors, packed_shape_lists]`` otherwise. Each shape list is
          ``[0, *shape, *strides]`` zero-padded to ``1 + 2 * max_rank`` entries.

    Raises:
        TypeError: if an element of *data* is neither ``int`` nor ``numpy.ndarray``.
    """
    data_ctypes = []
    if len(data) == 0:
        # dynamic shape info cannot generate inputs while compilation
        return data_ctypes
    shape_arg_list = []
    int_p = ctypes.POINTER(ctypes.c_int)
    device_shape, _, _ = get_device_shape(data, kernel_name, is_dyn_shape and not is_profile_params)
    # All shape lists share one layout: a leading 0 slot plus shape and stride entries
    # sized for the highest-rank argument. Profiling params only need the leading slot.
    max_rank = 0 if is_profile_params else max((len(shape) for shape in device_shape), default=0)
    padded_len = 1 + 2 * max_rank

    for data_idx, d in enumerate(data):
        shape_list = [0]
        data_shape = device_shape[data_idx]
        if isinstance(d, int):
            data_ctypes.append(ctypes.c_int(d))
        elif isinstance(d, np.ndarray):
            data_ctypes.append(d.ctypes.data_as(int_p))
            shape_list += list(data_shape)
            # for tensor (m, n, k), strides is [n*k, k, 1]
            stride_list = [1] * len(data_shape)
            for idx in range(len(data_shape) - 2, -1, -1):
                stride_list[idx] = stride_list[idx + 1] * data_shape[idx + 1]
            shape_list += stride_list
        else:
            raise TypeError(f"wrong data to ctypes, current type is '{type(d)}'")
        shape_list += [0] * (padded_len - len(shape_list))
        shape_arg_list.append(shape_list)
    if is_profile_params or backend == "gpu":
        return data_ctypes
    # pack parameters into an array of pointers
    # static shape: array of pointers of data
    packed_tensors = (int_p * len(data))()
    packed_tensors[:] = [
        ctypes.cast(data_ctype, int_p) for data_ctype in data_ctypes
    ]

    if backend == "cpu" and not is_dyn_shape:
        return [packed_tensors]

    # dynamic shape: array of pointers of data, array of [0, shape, stride] of data
    # tensor_num * [0, shape_list 1,2,...,n, strides 0,1,2,...,n]
    # Each integer is stored as a pointer-sized value (ctypes.cast of an int).
    packed_shape_lists = (int_p * len(data))()
    for idx, shape_list in enumerate(shape_arg_list):
        packed_shapes = (int_p * len(shape_list))()
        packed_shapes[:] = [
            ctypes.cast(shape, int_p) for shape in shape_list
        ]
        packed_shape_lists[idx] = ctypes.cast(packed_shapes, int_p)
    return [packed_tensors, packed_shape_lists]


class MlirDriver:
    """Drive the AKG-MLIR compile and launch flow for a single fused kernel.

    ``compile()`` lowers the input description into a backend binary inside
    ``output_dir``. The input is an op-description JSON produced by MindSpore
    GraphKernel, or an ``.mlir`` file. ``run()`` launches the produced binary with
    host data. Supported backends: ``"cpu"``, ``"gpu"`` and ``"ascend"``.
    """

    def __init__(
        self,
        kernel_name: str,
        input_file: str,
        output_dir: str = "",
        akg_tools_dir: str = "",
        llvm_tools_dir: str = "",
        dynamic_shape: bool = False,
        log_level: str = "INFO",
        dump_ir=False,
        mlir_timing=False,
        repo_path: str = "",
        profiling_trails=0,
        runtime_provider="MindSpore",
        enable_loop_fusion=True,
        arch="Ascend910B1",
        backend="ascend",
    ):
        """Store the driver configuration.

        Args:
            kernel_name: base name of every generated file (``<kernel_name>.mlir``,
                ``.so``, ``.json``, ...).
            input_file: op-description JSON file, or an ``.mlir`` file that is copied as-is.
            output_dir: directory for generated files. Defaults to ``get_kernel_meta_path()``.
            akg_tools_dir: directory containing ``bin/akg-opt`` and the other AKG tools.
                Defaults to the directory of this file.
            llvm_tools_dir: LLVM build directory whose ``lib/`` is linked for the MLIR
                runner utils. Defaults to ``../../third-party/llvm-project/build/``
                relative to this file.
            dynamic_shape: compile for dynamic shapes.
            log_level: stored on the instance; not read by this class.
            dump_ir: dump the IR after every pass into log files under ``output_dir``.
            mlir_timing: pass ``--mlir-timing`` to the MLIR pipelines.
            repo_path: GPU global config file. Used only if the path exists.
            profiling_trails: number of profiling repetitions. ``0`` disables profiling.
            runtime_provider: ``"MindSpore"`` or ``"MLIR"``. Selects the CPU outlining
                platform and extra link-time runtime libraries.
            enable_loop_fusion: Ascend only. When False, bishengir hfusion compile is enabled.
            arch: Ascend target architecture passed to ``akg_opt``.
            backend: one of ``"cpu"``, ``"gpu"``, ``"ascend"``.
        """
        super().__init__()

        self.kernel_name = kernel_name
        self.input_file = input_file
        self.output_dir = get_kernel_meta_path() if output_dir == "" else output_dir
        self.akg_tools_dir = (
            os.path.dirname(os.path.abspath(__file__))
            if akg_tools_dir == ""
            else akg_tools_dir
        )
        self.llvm_tools_dir = (
            os.path.join(pathlib.Path(__file__).absolute().parent, "../../third-party/llvm-project/build/")
            if llvm_tools_dir == ""
            else llvm_tools_dir
        )
        self.log_level = log_level
        self.target_info = ""
        self.dump_ir = dump_ir
        self.mlir_timing = mlir_timing
        self.repo_path = repo_path
        self.profiling_trails = profiling_trails
        self.runtime_provider = runtime_provider
        self.enable_loop_fusion = enable_loop_fusion
        self.arch = arch
        self.backend = backend
        self.dynamic_shape = dynamic_shape

    @staticmethod
    def _write_text(path, content):
        """Create or overwrite *path* with *content*.

        The file is truncated, so no bytes from a previous, longer version survive.
        """
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        with os.fdopen(os.open(path, flags, 0o755), "w") as f:
            f.write(content)

    def _write_lib_meta_json(self, kernel_name):
        """Write ``<kernel_name>.json`` describing ``<kernel_name>.so``.

        The JSON holds the kernel name, thread number and the sha256 of the library.
        """
        lib_file = os.path.join(self.output_dir, kernel_name + ".so")
        lib_sha256 = hashlib.sha256()
        with open(lib_file, "rb") as f:
            lib_sha256.update(f.read())
        title_dict = {
            KERNEL_NAME: kernel_name + "_kernel",
            "threadNumber": "null",
            SHA256: lib_sha256.hexdigest(),
        }
        json_file = os.path.join(self.output_dir, kernel_name + ".json")
        write_code(title_dict, json_file)

    def compile(self):
        """
        compile interface of akg-mlir. The input is the OP description json file
        generated by MindSpore GraphKernel Module

        Raises:
            RuntimeError: for an unsupported backend or when a compile step fails.
        """
        if self.backend == "cpu":
            self.compile_cpu()
        elif self.backend == "gpu":
            self.compile_gpu()
        elif self.backend == "ascend":
            self.compile_ascend()
        else:
            raise RuntimeError("Unsupported backend: " + self.backend + "!\n")

    def run(self, input_for_mod, output_indexes=None):
        """Launch the compiled kernel on the configured backend.

        Returns whatever the backend-specific ``run_*`` method returns: *input_for_mod*
        for Ascend, ``None`` for CPU and GPU.

        Raises:
            RuntimeError: for an unsupported backend.
        """
        if self.backend == "cpu":
            return self.run_cpu(input_for_mod, output_indexes)
        if self.backend == "gpu":
            return self.run_gpu(input_for_mod, output_indexes)
        if self.backend == "ascend":
            return self.run_ascend(input_for_mod, output_indexes)
        raise RuntimeError("Unsupported backend: " + self.backend + "!\n")

    def compile_ascend(self):
        """Compile an Ascend kernel: info -> MLIR -> optimized MLIR -> ``.so`` via bishengir."""
        self._run_mlir_convert()
        self._run_mlir_ascend_pipeline(self.dynamic_shape, self.kernel_name)
        self._run_ascend_generate_binary(self.kernel_name)

    def compile_cpu(self):
        """Compile a CPU kernel: info -> MLIR -> LLVM IR -> ``.so`` plus its JSON meta file."""
        self._run_mlir_convert()
        self._run_mlir_cpu_pipeline(self.dynamic_shape, self.kernel_name)
        self._run_mlir_to_llvm(self.kernel_name)
        self._run_cpu_generate_binary(self.kernel_name)

    def compile_gpu(self):
        """Compile a GPU kernel: info -> MLIR -> NVVM -> PTX plus its JSON meta file.

        When ``dynamic_shape`` is set and ``MLIR_TILING_MODE`` is ``"both"``, an extra
        static-tile kernel named ``<kernel_name>_static`` is built first. The dynamic
        kernel then uses the ``"auto"`` tiling mode.

        Raises:
            RuntimeError: if any compile step fails.
        """

        def _build(is_dyn, kernel_name, tiling_mode):
            """Run every GPU lowering step for an already converted ``<kernel_name>.mlir``."""
            self._run_mlir_gpu_pipeline(is_dyn, kernel_name, tiling_mode)
            self._run_mlir_gpu_codegen(kernel_name)
            self._run_gpu_translate(kernel_name)
            self._run_ptx_replace(is_dyn, kernel_name)
            self._run_ptx_dump_json(is_dyn, kernel_name)

        def _gen_static_tile_kernel(kernel_name):
            """Build ``<kernel_name>_static`` from a renamed copy of the input description."""
            sub_kernel_name = kernel_name + "_static"
            try:
                sub_input_file = os.path.join(self.output_dir, sub_kernel_name + ".info")
                with open(self.input_file, "r", encoding="utf-8") as f:
                    sub_input_file_desc = json.loads(f.read())
                sub_input_file_desc["op"] = sub_kernel_name
                self._write_text(sub_input_file, json.dumps(sub_input_file_desc))
                self._run_mlir_convert(sub_kernel_name, sub_input_file)
                logging.debug("Start to build %s", sub_kernel_name)
                _build(True, sub_kernel_name, "static")
                logging.debug("Success to build %s", sub_kernel_name)
            except RuntimeError as exc:
                logging.error("Fail to build %s : %s", sub_kernel_name, exc)
                raise RuntimeError(f"Compile error, kernel: {sub_kernel_name} is not generated") from exc

        default_tiling_mode = None
        if self.dynamic_shape and os.environ.get("MLIR_TILING_MODE", "auto") == "both":
            _gen_static_tile_kernel(self.kernel_name)
            default_tiling_mode = "auto"
        try:
            self._run_mlir_convert()
            _build(self.dynamic_shape, self.kernel_name, default_tiling_mode)
        except RuntimeError as exc:
            logging.error("Compile error, kernel: %s", self.kernel_name)
            raise RuntimeError(f"Compile error, kernel: {self.kernel_name} is not generated") from exc

    def run_ascend(self, input_for_mod, output_indexes=None):
        """Launch the Ascend kernel, optionally under profiling, and return *input_for_mod*.

        All preprocessing (bf16, output_indexes, device_shape) is done inside ``launch()``.
        The device is taken from the ``DEVICE_ID`` environment variable (default 0).
        """
        device_id = int(os.environ.get("DEVICE_ID", 0))

        def _launch_once():
            launch(self.output_dir, self.kernel_name, device_id, self.dynamic_shape, *input_for_mod,
                   use_mem_pool=True, output_indexes=output_indexes)

        if self.profiling_trails == 0:
            _launch_once()
            return input_for_mod

        # NOTE(review): akgProfileMgr, profiling_analyse and PROF_ERROR_CODE are not imported
        # by this module as visible here. This branch raises NameError unless they are
        # provided elsewhere.
        akgProfileMgr.ascend_start_profiling(device_id)
        for _ in range(5):
            _launch_once()
        akgProfileMgr.ascend_stop_profiling()
        # analysis
        cycle = profiling_analyse(None)
        logging.info('=====Task Duration(us)==============================')
        if cycle != PROF_ERROR_CODE:
            logging.info(cycle)
        else:
            logging.error("OOPS, can't correctly Task Duration!")
        logging.info('=====Task Duration(us)==============================')
        logging.info("%s Task Duration(us) : %s", self.kernel_name, str(cycle))
        return input_for_mod

    def run_cpu(self, input_for_mod, output_indexes=None):
        """Run the CPU kernel ``<output_dir>/<kernel_name>.so`` on *input_for_mod*.

        With profiling enabled, the library's ``main`` symbol is called. This is presumably
        the wrapper produced by ``wrap_timer_func``. The average time per trial is then
        logged. Otherwise the ``<kernel_name>`` symbol is called. ``output_indexes`` is
        accepted for interface symmetry and is unused.
        """
        input_for_mod_ctypes = transform_data_to_ctypes(
            input_for_mod,
            self.kernel_name,
            self.dynamic_shape,
            "cpu")
        cur = ctypes.cdll.LoadLibrary(os.path.join(self.output_dir, self.kernel_name + ".so"))
        if self.profiling_trails > 0:
            func = getattr(cur, "main")
            # The callee stores the accumulated time in nanoseconds into this buffer.
            np_timers_ns = np.array([0], dtype=np.int64)
            input_for_mod_ctypes.append(np_timers_ns.ctypes.data_as(ctypes.POINTER(ctypes.c_longlong)))
            func(*input_for_mod_ctypes)
            logging.info("%s : Running %s times, the average execution time is %s ms.",
                         self.kernel_name,
                         self.profiling_trails,
                         np_timers_ns[0] / 1000000 / self.profiling_trails)
        else:
            func = getattr(cur, self.kernel_name)
            # Run executable and compare results
            func(*input_for_mod_ctypes)

    def run_gpu(self, input_for_mod, output_indexes=None):
        """Run the GPU kernel through a generated CUDA runtime wrapper library.

        With ``profiling_trails > 0``, ``cuda_runtime_profiling`` is called with the
        profiling parameters. Otherwise ``cuda_runtime_exec`` is called.
        """
        input_for_mod_ctypes = transform_data_to_ctypes(
            input_for_mod,
            self.kernel_name,
            self.dynamic_shape,
            "gpu")
        if self.profiling_trails == 0:
            # Run executable
            lib = create_executable(self.kernel_name, input_for_mod, output_indexes, self.dynamic_shape)
            lib.cuda_runtime_exec(*input_for_mod_ctypes)
        else:
            # Profiling
            prof_params = ProfilingParams(number=10, repeat=self.profiling_trails, min_repeat_ms=0)
            prof_params_ctypes = transform_data_to_ctypes(
                prof_params.get_data(),
                self.kernel_name,
                self.dynamic_shape,
                "gpu",
                is_profile_params=True)
            lib = create_executable(self.kernel_name, input_for_mod, output_indexes, self.dynamic_shape)
            lib.cuda_runtime_profiling(*input_for_mod_ctypes, *prof_params_ctypes)

    def _run_mlir_convert(self, kernel_name=None, input_file=None):
        """Convert the input description into ``<output_dir>/<kernel_name>.mlir``.

        An ``.mlir`` input is copied unchanged. Any other input is translated with
        ``bin/mindspore-translate -json-to-mindspore``.

        Raises:
            RuntimeError: if the translation tool fails.
        """
        if kernel_name is None:
            kernel_name = self.kernel_name
        if input_file is None:
            input_file = self.input_file
        out_file = os.path.join(self.output_dir, kernel_name + ".mlir")
        if input_file.endswith(".mlir"):
            shutil.copy(input_file, out_file)
            return
        cmd = [
            os.path.join(self.akg_tools_dir, "bin/mindspore-translate"),
            "-json-to-mindspore",
            input_file,
            "-o",
            out_file,
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("run mlir pipeline failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("mlir pipeline failed in converting the case: " + kernel_name + "!\n") from e

    def _run_mlir_cpu_pipeline(self, dyn_shape, kernel_name):
        """Run ``akg-opt --cpu-opt`` on ``<kernel_name>.mlir``, producing ``<kernel_name>_out.mlir``."""
        input_file = os.path.join(self.output_dir, kernel_name + ".mlir")
        out_file = os.path.join(self.output_dir, kernel_name + "_out.mlir")
        cpu_opt_option = "--cpu-opt"
        if dyn_shape:
            cpu_opt_option += "=dynamic-shape=true"
        else:
            cpu_opt_option += "=cpu-outlining=false outlining-platform=" + self.runtime_provider
        cmd = [os.path.join(self.akg_tools_dir, "bin/akg-opt"), input_file, cpu_opt_option, "-o", out_file]
        if self.dump_ir:
            cmd.append("--mlir-print-ir-after-all")
        if self.mlir_timing:
            cmd.append("--mlir-timing")

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            logging.error("run akg-opt failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("mlir pipeline failed in case: " + kernel_name + "!\n") from e
        if self.dump_ir:
            # --mlir-print-ir-after-all writes the IR dumps to stderr.
            self._write_text(os.path.join(self.output_dir, kernel_name + "_dump_cpu.log"), result.stderr)

        logging.info("mlir pipeline success")

    def _run_mlir_ascend_pipeline(self, dyn_shape, kernel_name):
        """Run ``akg_opt`` on ``<kernel_name>.mlir``, producing ``<kernel_name>_opt.mlir``."""
        input_file = os.path.join(self.output_dir, kernel_name + ".mlir")
        out_file = os.path.join(self.output_dir, kernel_name + "_opt.mlir")

        dump_ir_path = None
        if self.dump_ir:
            dump_ir_path = os.path.join(self.output_dir, kernel_name + "_opt.log")

        akg_opt(
            input_file=input_file,
            output_file=out_file,
            dyn_shape=dyn_shape,
            enable_loop_fusion=self.enable_loop_fusion,
            arch=self.arch,
            dump_ir=self.dump_ir,
            mlir_timing=self.mlir_timing,
            dump_ir_path=dump_ir_path,
        )

    def _run_ascend_generate_binary(self, kernel_name):
        """Compile ``<kernel_name>_opt.mlir`` into ``<kernel_name>.so`` with bishengir."""
        logging.info("bishengir-compile code generator:")
        input_file = os.path.join(self.output_dir, kernel_name + "_opt.mlir")
        so_file = os.path.join(self.output_dir, kernel_name + ".so")
        dump_log = os.path.join(self.output_dir, kernel_name + "_bisheng.log")

        block_dim = get_block_dim_from_mlir(input_file)
        bisheng_compile(
            input_file=input_file,
            output_file=so_file,
            enable_hfusion_compile=not self.enable_loop_fusion,
            block_dim=block_dim,
            dump_ir=self.dump_ir,
            dump_ir_path=dump_log,
        )

    def _run_mlir_to_llvm(self, kernel_name):
        """Translate ``<kernel_name>_out.mlir`` (timer-wrapped when profiling) to ``<kernel_name>.ll``."""
        logging.info("mlir to llvm:")
        input_file = os.path.join(self.output_dir, kernel_name + "_out.mlir")
        if self.profiling_trails > 0:
            input_file = wrap_timer_func(input_file, kernel_name, self.profiling_trails)
        out_file = os.path.join(self.output_dir, kernel_name + ".ll")
        cmd = ["mlir-translate", input_file, "--mlir-to-llvmir", "-o", out_file]
        logging.debug("_run_mlir_to_llvm cmd: %s", cmd)
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("run mlir to llvm failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("mlir to llvm failed in case: " + input_file + "!\n") from e
        logging.info("mlir to llvm success")

    def _run_cpu_generate_binary(self, kernel_name):
        """Compile ``<kernel_name>.ll`` to ``.s`` with llc, then link ``<kernel_name>.so`` with clang++.

        Also writes the ``<kernel_name>.json`` meta file for the library.
        """
        input_file = os.path.join(self.output_dir, kernel_name + ".ll")
        out_file = os.path.join(self.output_dir, kernel_name + ".s")
        bin_file = os.path.join(self.output_dir, kernel_name + ".so")
        cmd = [
            "llc",
            input_file,
            "-relocation-model=pic",
            "-O3",
            "-o",
            out_file,
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("generate .s failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("generate .s failed in case " + input_file + "!\n") from e

        cmd = [
            "clang++",
            out_file,
            "--rtlib=compiler-rt",
            "-fopenmp",
            "-O3",
            "--shared",
            "-fPIC",
            "-o",
            bin_file,
            "-L",
            os.path.join(self.llvm_tools_dir, "lib/"),
            "-lmlir_c_runner_utils",
        ]
        if self.runtime_provider == "MLIR":
            cmd.extend(["-L", self.akg_tools_dir, "-lmlir_akgParallelLaunch_runtime",
                        f"-Wl,-rpath,{self.akg_tools_dir}"])
        if self.profiling_trails > 0:
            cmd.extend(["-L", os.path.join(self.llvm_tools_dir, "lib/"), "-lmlir_runner_utils",
                        f"-Wl,-rpath,{self.llvm_tools_dir}/lib"])
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("generate .so failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("generate .so failed in case " + input_file + "!\n") from e

        logging.info("generate cpu binary .so success")
        self._write_lib_meta_json(kernel_name)

    def _run_ascend_generate_binary_(self, kernel_name):
        """Compile ``<kernel_name>.ll`` into ``<kernel_name>.so`` with llc/clang++ and write its meta JSON.

        NOTE(review): not called from this class; ``compile_ascend`` uses
        ``_run_ascend_generate_binary`` (bishengir) instead.
        """
        input_file = os.path.join(self.output_dir, kernel_name + ".ll")
        out_file = os.path.join(self.output_dir, kernel_name + ".s")
        bin_file = os.path.join(self.output_dir, kernel_name + ".so")
        cmd = [
            "llc",
            input_file,
            "-relocation-model=pic",
            "-O3",
            "-o",
            out_file,
        ]
        logging.debug("_run_ascend_generate_binary step 0 cmd: %s", cmd)
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("generate .s failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("generate .s failed in case " + input_file + "!\n") from e

        cmd = [
            "clang++",
            out_file,
            "--rtlib=compiler-rt",
            "-O3",
            "--shared",
            "-fPIC",
            "-o",
            bin_file,
            "-L",
            os.path.join(self.llvm_tools_dir, "lib/"),
            "-lmlir_c_runner_utils",
        ]
        if self.runtime_provider == "MLIR":
            cmd.extend(["-L", os.path.join(self.akg_tools_dir, "lib/"), "-lmlir_akgParallelLaunch_runtime"])
        if self.profiling_trails > 0:
            cmd.extend(["-L", os.path.join(self.llvm_tools_dir, "lib/"), "-lmlir_runner_utils"])
        logging.debug("_run_ascend_generate_binary step 1 cmd: %s", cmd)
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("generate .so failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("generate .so failed in case " + input_file + "!\n") from e

        logging.info("generate ascend binary .so success")
        self._write_lib_meta_json(kernel_name)

    def has_reduce(self):
        """Return True if any op in the input description has "reduce" in its name.

        The match is case-insensitive. A missing ``op_desc`` list or op ``name`` is
        treated as containing no reduction.
        """
        with open(self.input_file, "r", encoding='utf-8') as f:
            desc_d = json.load(f)
        return any("reduce" in (op.get("name") or "").lower()
                   for op in desc_d.get("op_desc") or [])

    def _run_mlir_gpu_pipeline(self, dyn_shape, kernel_name, tiling_mode=None):
        """Run the akg-opt GPU pipeline on ``<kernel_name>.mlir``, producing ``<kernel_name>_gpu.mlir``.

        For dynamic shapes the tiling mode defaults to ``MLIR_TILING_MODE`` (or ``"auto"``).
        ``repo_path`` is passed as ``global-config-file`` when it exists.
        """
        input_file = os.path.join(self.output_dir, kernel_name + ".mlir")
        out_file = os.path.join(self.output_dir, kernel_name + "_gpu.mlir")
        opt_pipeline = "--gpu-dyn-opt" if dyn_shape else "--gpu-opt"
        opt_options = []
        if dyn_shape:
            if tiling_mode is None:
                tiling_mode = os.environ.get("MLIR_TILING_MODE", "auto")
            opt_options.append("tiling-mode=" + tiling_mode)

        if os.path.exists(self.repo_path):
            opt_options.append("global-config-file=" + self.repo_path)

        if opt_options:
            # Pass options are space separated; joining avoids a stray leading space.
            opt_pipeline += "=" + " ".join(opt_options)

        cmd = [os.path.join(self.akg_tools_dir, "bin/akg-opt"), input_file, opt_pipeline, "-o", out_file]
        if self.dump_ir:
            cmd.append("--mlir-print-ir-after-all")
        if self.mlir_timing:
            cmd.append("--mlir-timing")

        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            logging.error("mlir gpu pipeline failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("mlir gpu pipeline failed in case: " + kernel_name + "!\n") from e
        if self.dump_ir:
            # --mlir-print-ir-after-all writes the IR dumps to stderr.
            self._write_text(os.path.join(self.output_dir, kernel_name + "_dump_gpu.log"), result.stderr)
        logging.info("mlir gpu pipeline success: %s", kernel_name)

    def _run_mlir_gpu_codegen(self, kernel_name):
        """Lower ``<kernel_name>_gpu.mlir`` to ``<kernel_name>_nvvm.mlir`` with ``--gpu-codegen``."""
        logging.info("gpu_codegen:")
        input_file = os.path.join(self.output_dir, kernel_name + "_gpu.mlir")
        out_file = os.path.join(self.output_dir, kernel_name + "_nvvm.mlir")
        cmd = [os.path.join(self.akg_tools_dir, "bin/akg-opt"), input_file, "--gpu-codegen", "-o", out_file]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("gpu_codegen failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("gpu_codegen failed in case: " + kernel_name + "!\n") from e
        logging.info("gpu_codegen success: %s", kernel_name)

    def _run_gpu_translate(self, kernel_name):
        """Translate ``<kernel_name>_nvvm.mlir`` to PTX (sm_70) with output prefix ``<kernel_name>_init``."""
        logging.info("mlir to ptx:")
        input_file = os.path.join(self.output_dir, kernel_name + "_nvvm.mlir")
        out_prefix = os.path.join(self.output_dir, kernel_name + "_init")
        cmd = [
            os.path.join(self.akg_tools_dir, "bin/akg-translate"),
            "-gen-ptx",
            "-arch=sm_70",
            input_file,
            "--kernel-name=" + out_prefix,
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("mlir to ptx failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("mlir to ptx failed in case: " + kernel_name + "!\n") from e
        logging.info("mlir to ptx success: %s", kernel_name)

    def _run_ptx_replace(self, dyn_shape, kernel_name):
        """Run ``akg-ptx-replace`` on ``<kernel_name>_init.ptx``, producing ``<kernel_name>.ptx``."""
        logging.info("ptx replacement")
        input_file = os.path.join(self.output_dir, kernel_name + "_init.ptx")
        out_prefix = os.path.join(self.output_dir, kernel_name)
        out_file = out_prefix + ".ptx"
        shape_arg_file = os.path.join(self.output_dir, kernel_name + "_shape_arg.txt")
        cmd = [os.path.join(self.akg_tools_dir, "bin/akg-ptx-replace"), input_file, shape_arg_file, out_file]
        if dyn_shape:
            cmd += ["none", "dynamic_shape"]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            logging.error("ptx replacement failed! cmd:\n %s \nerror message:\n %s", e.cmd, e.stderr)
            raise RuntimeError("ptx replacement failed in case: " + input_file + "!\n") from e
        logging.info("ptx replacement success: %s", kernel_name)

    def _run_ptx_dump_json(self, dyn_shape, kernel_name):
        """Rewrite ``<kernel_name>.json`` with the final PTX meta (shapes, name, sha256).

        Raises:
            RuntimeError: for dynamic shapes when ``<kernel_name>_shape_info.json`` is missing.
        """
        title_dict = {}
        json_file = os.path.join(self.output_dir, kernel_name + ".json")
        with open(json_file, "rb") as f:
            params = json.load(f)
            # Entries like ["32", 1] are (extent, step) pairs: drop them for "Seq" keys,
            # otherwise replace them with their ceil-divided count.
            for k, v in params.items():
                if not (isinstance(v, list) and len(v) == 2 and isinstance(v[0], str) and v[0].isdigit()):
                    title_dict[k] = v
                elif "Seq" not in k:
                    title_dict[k] = (int(v[0]) - 1) // v[1] + 1

        out_file = os.path.join(self.output_dir, kernel_name + ".ptx")

        if dyn_shape:
            shape_info_json = os.path.join(self.output_dir, kernel_name + "_shape_info.json")
            if not os.path.exists(shape_info_json):
                raise RuntimeError(f"Dynamic shape needs file {shape_info_json} to get the device shape. "
                                   "Otherwise, the result may be incorrect.")

            with os.fdopen(os.open(shape_info_json, os.O_RDONLY, 0o755), "rb") as f:
                shape_params = json.load(f)
                title_dict[HOST_SHAPES] = shape_params.get(HOST_SHAPES, [])
                title_dict[DEVICE_SHAPES] = shape_params.get(DEVICE_SHAPES, [])
                title_dict[RUNTIME_VARS] = shape_params.get(RUNTIME_VARS, [])
                title_dict[SUPPORT_INFO] = shape_params.get(SUPPORT_INFO, [])
            title_dict[TARGET_INFO] = self.target_info
            if len(title_dict[RUNTIME_VARS]) > 0:
                self._dump_static_tile_kernel_impl(kernel_name, title_dict, out_file)

        title_dict[KERNEL_NAME] = kernel_name + "_kernel"
        title_dict[DYNAMIC] = dyn_shape
        # sha256 of the final PTX (after any static kernel was appended)
        lib_sha256 = hashlib.sha256()
        with os.fdopen(os.open(out_file, os.O_RDONLY, 0o755), "rb") as f:
            lib_sha256.update(f.read())
        title_dict[SHA256] = lib_sha256.hexdigest()

        write_code(title_dict, json_file)

    def _dump_static_tile_kernel_impl(self, kernel_name, title_dict, out_file):
        """Merge the optional ``<kernel_name>_static`` kernel into the dynamic kernel outputs.

        The static meta JSON is stored under ``STATIC_TILE_IMPL`` in *title_dict*. The
        static PTX, from its first ``.entry`` line onward, is appended to *out_file*.
        Does nothing if either static file is missing.
        """
        static_json_file = os.path.join(self.output_dir, kernel_name + "_static.json")
        static_ptx_file = os.path.join(self.output_dir, kernel_name + "_static.ptx")
        if not (os.path.exists(static_json_file) and os.path.exists(static_ptx_file)):
            return

        with os.fdopen(os.open(static_json_file, os.O_RDONLY, 0o755), "r") as sf:
            title_dict[STATIC_TILE_IMPL] = json.load(sf)
        static_kernel_lines = []
        with os.fdopen(os.open(static_ptx_file, os.O_RDONLY, 0o755), "r") as sf:
            in_kernel = False
            for line in sf:
                if ".entry" in line:
                    in_kernel = True
                if in_kernel:
                    static_kernel_lines.append(line)
        # Append rather than overwrite from offset 0: the dynamic PTX must keep its own
        # header (.version/.target) and kernel, and gains the static kernel after them.
        with os.fdopen(os.open(out_file, os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0o755), "a") as f:
            f.writelines(static_kernel_lines)



if __name__ == "__main__":

    def _str_to_bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

        ``argparse`` with ``type=bool`` treats any non-empty string (even "False") as
        True, so values are parsed explicitly.
        """
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got '{value}'")

    parser = argparse.ArgumentParser(description="Run akg-mlir End to End")
    parser.add_argument("-f", type=str, required=True, help="Run single file.")
    parser.add_argument("-o", type=str, help="output dir.", default="")
    parser.add_argument("-akg-tools-dir", type=str, help="akg-mlir tools build dir.", default="")
    parser.add_argument("-llvm-tools-dir", type=str, help="llvm tools build dir", default="")
    # Accepted for command-line compatibility; not used by the driver.
    parser.add_argument("-bisheng-tools-dir", type=str, help="bisheng cpp tools build dir", default="")
    parser.add_argument("-d", "--dynamic-shape", type=_str_to_bool, nargs="?", const=True, default=False,
                        help="Specifies dynamic shape or not (e.g. -d, -d true, -d false)")
    args = parser.parse_args()
    logging.info(args)

    # Generated files are named after the kernel. Derive the name from the input file
    # so outputs are not written as hidden files such as ".mlir" or ".so".
    input_kernel_name = os.path.splitext(os.path.basename(args.f))[0]

    driver = MlirDriver(
        kernel_name=input_kernel_name,
        input_file=args.f,
        output_dir=args.o,
        akg_tools_dir=args.akg_tools_dir,
        llvm_tools_dir=args.llvm_tools_dir,
        dynamic_shape=args.dynamic_shape,
    )
    driver.compile()