# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


import functools
import hashlib
import os
import re
import shutil
import subprocess
import sysconfig
from pathlib import Path
import logging
from platform import python_version
from triton.backends.ascend.backend_register import backend_strategy_registry

import pybind11

backend_policy = None


def get_backend_func(name, *args, **kwargs):
    global backend_policy
    if backend_policy is None:
        try:
            import torch
            import torch_npu
            backend_policy = "torch_npu"
        except ImportError:
            backend_policy = "mindspore"
        print("the backend policy is {}".format(backend_policy))
    return backend_strategy_registry.execute_func(backend_policy, name, *args, **kwargs)


def get_logger(logger_name, logger_level_str):
    '''
    '''
    logging_level_mapping = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.setLevel(logging_level_mapping.get(logger_level_str.upper(), "INFO"))
    logger.addHandler(console_handler)
    return logger


def downgrade_llir(llir):
    llir = _downgrade_mem_attrs(llir)
    llir = _downgrade_stacksaverestore_intrinsics(llir)
    return llir


def _downgrade_mem_attrs(llir: str):
    memory_pattern = r"memory\([^()]*\)"

    def replace_mem_attr(m):
        attrs = m[0][7:-1].split(",")
        if len(attrs) == 0:
            return "readnone"
        loc_map = {"argmem": 1, "inaccessiblemem": 2, "other": 4}
        loc_attr = 0
        rw_map = {"readwrite": 3, "write": 2, "read": 1, "none": 0}
        rw_attr = 0
        for attr_pair in attrs:
            pair = attr_pair.split(":")
            assert len(pair) <= 2
            if len(pair) == 1:
                rw = rw_map[pair[0].strip()]
                loc = loc_map["other"]  # all location
            else:
                rw = rw_map[pair[1].strip()]
                loc_str = pair[0].strip()
                if loc_str == "argmem" or loc_str == "inaccessiblemem":
                    loc = loc_map[loc_str]
                else:
                    loc = loc_map["other"]
            if rw > 0:
                loc_attr = loc_attr | loc
                rw_attr = rw_attr | rw
        rev_rw_map = {0: "readnone", 1: "readonly", 2: "writeonly"}
        if rw_attr in rev_rw_map:
            rw_attr_str = rev_rw_map[rw_attr]
        else:
            rw_attr_str = ""
        rev_loc_map = {
            1: "argmemonly",
            2: "inaccessiblememonly",
            3: "inaccessiblemem_or_argmemonly",
        }
        if loc_attr in rev_loc_map:
            loc_attr_str = rev_loc_map[loc_attr]
        else:
            loc_attr_str = ""
        return rw_attr_str + " " + loc_attr_str

    return re.sub(memory_pattern, replace_mem_attr, llir)


def _downgrade_stacksaverestore_intrinsics(llir: str):
    llir = re.sub(r"llvm\.stacksave\.\w+", "llvm.stacksave", llir)
    llir = re.sub(r"llvm\.stackrestore\.\w+", "llvm.stackrestore", llir)
    return llir


def _get_triton_adapter_opt_path() -> str:
    path = os.path.dirname(__file__)
    path = os.path.join(path, "triton-adapter-opt")
    return path


def _get_mlir_path(path: str, *paths) -> str:
    root_path = os.getenv("MLIR_ROOT", "")
    if root_path == "":
        raise EnvironmentError("MLIR_ROOT is not set.")
    return os.path.join(root_path, path, *paths)


def _get_llvm_path(path: str, *paths) -> str:
    root_path = os.getenv("LLVM_ROOT", "")
    if root_path == "":
        raise EnvironmentError("LLVM_ROOT is not set.")
    return os.path.join(root_path, path, *paths)


def _get_npucompiler_path() -> str:
    npu_compiler_path = shutil.which("bishengir-compile")
    if npu_compiler_path is None:
        npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "")
        if npu_compiler_root is None:
            raise EnvironmentError(
                "Couldn't find executable bishengir-compile or TRITON_NPU_COMPILER_PATH."
            )
        npu_compiler_path = os.path.join(npu_compiler_root, "npuc")
    return npu_compiler_path


def _get_bisheng_path() -> str:
    bisheng_path = shutil.which("bisheng")
    if bisheng_path is None:
        npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "")
        if npu_compiler_root is None:
            raise EnvironmentError(
                "Couldn't find executable bisheng or TRITON_NPU_COMPILER_PATH"
            )
        bisheng_path = os.path.join(npu_compiler_root, "ccec")
    return bisheng_path


def _is_valid_bishengir_path(path: str) -> bool:
    if not path or not isinstance(path, str):
        return False
    if os.path.basename(path) != "bishengir-compile":
        return False
    if not os.path.isfile(path) or not os.access(path, os.X_OK):
        return False
    return True


# grep bishengir-compile's option limit-auto-multi-buffer-buffer to check
# if bishengir-compile is a newer version which does not generate kernel_reloc.o
# any more.
def _check_bishengir_api_change() -> bool:
    bishengir_path = _get_npucompiler_path()
    if not _is_valid_bishengir_path(bishengir_path):
        print(f"ERROR: Invalid bishengir path format: {bishengir_path}")
        return False
    try:
        result = subprocess.run(
            [bishengir_path, "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        if result.returncode == 0 and 'limit-auto-multi-buffer-buffer' in result.stdout:
            # bishengir-compile is newer version
            return True
        else:
            # bishengir-compile is older version
            return False
    except Exception as e:
        print(f"ERROR: {e}")
        return False


def _check_bishengir_is_regbased() -> bool:
    bishengir_path = _get_npucompiler_path()
    if not _is_valid_bishengir_path(bishengir_path):
        print(f"ERROR: Invalid bishengir path format: {bishengir_path}")
        return False
    try:
        result = subprocess.run(
            [bishengir_path, "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        if result.returncode == 0 and 'reg-based' in result.stdout:
            # bishengir-compile is regbased version
            return True
        else:
            # bishengir-compile is membased version
            return False
    except Exception as e:
        print(f"ERROR: {e}")
        return False


@functools.lru_cache(None)
def _get_ascend_path() -> Path:
    path = os.getenv("ASCEND_HOME_PATH", "")
    if path == "":
        raise EnvironmentError(
            "ASCEND_HOME_PATH is not set, source <ascend-toolkit>/set_env.sh first"
        )
    return Path(path)


def _is_ascend_sanitizer_enabled() -> bool:
    return os.getenv("TRITON_ENABLE_SANITIZER", "false").lower() in ("true", "1")


def _is_debug_line_info_disabled() -> bool:
    return os.getenv("TRITON_DISABLE_LINE_INFO", "true").lower() in ("true", "1")


def _is_auto_map_parallel_blocks_enabled() -> bool:
    return os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "false").lower() in ("true", "1")


def _enable_unpublished_feature() -> bool:
    return os.getenv("ENABLE_UNPUBLISHED_FEATURE", "false").lower() in ("true", "1")

def _enable_print_ub_bits() -> bool:
    return os.getenv("ENABLE_PRINT_UB_BITS", "false").lower() in ("true", "1")

def _get_cxx():
    cxx = os.environ.get("CC")
    if cxx is None:
        clangxx = shutil.which("clang++")
        gxx = shutil.which("g++")
        cxx = clangxx if clangxx is not None else gxx
        if cxx is None:
            raise RuntimeError("Failed to find C++ compiler")
    return cxx

def _get_cxx_precompiled(header_path):
    cc_cmd = []
    cxx = os.environ.get("CC")
    if cxx is None:
        clangxx = shutil.which("clang++")
        gxx = shutil.which("g++")
        if clangxx is not None:
            cc_cmd += [clangxx, "-include", header_path]
        elif gxx is not None:
            cc_cmd += [gxx]
        else:
            raise RuntimeError("Failed to find C++ compiler")
    else:
        cc_cmd += [cxx]
    return cc_cmd

def _precompile_npu_hash(header_src):
    import sys
    cxx = _get_cxx()
    py_version = sys.version
    asc_path = _get_ascend_path().name
    version_txt = [header_src, cxx, py_version, asc_path]
    version_txt += get_backend_func("version_hash")
    hash_txt = hashlib.sha256("_".join(version_txt).encode("utf-8")).hexdigest()
    return hash_txt

def _precompile_npu_ext(header_path):
    src_dir = os.path.dirname(header_path)
    gch_path = os.path.join(src_dir, "precompiled.h.gch")
    cxx = _get_cxx()

    cc_cmd = [cxx, "-x", "c++-header", header_path]
    # disable all warnings
    cc_cmd += [f"-w"]
    # find the python library
    if hasattr(sysconfig, "get_default_scheme"):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == "posix_local":
        scheme = "posix_prefix"
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
    cc_cmd += [f"-I{py_include_dir}"]
    # device_print.h
    cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"]
    # find the ascend library
    asc_path = _get_ascend_path()

    rt_path = os.path.join(asc_path, "include/experiment/runtime/runtime/rt.h")
    if not os.path.exists(rt_path):
        cc_cmd += [
            f"-I{os.path.join(asc_path, 'pkg_inc')}",
            f"-I{os.path.join(asc_path, 'pkg_inc/profiling')}",
        ]

    cc_cmd += [
        f"-I{os.path.join(asc_path, 'include')}",
        f"-I{os.path.join(asc_path, 'include/experiment')}",
        f"-I{os.path.join(asc_path, 'include/experiment/msprof')}",
        f"-I{pybind11.get_include()}",
    ]

    cc_cmd += get_backend_func("get_cc_cmd", build_pch=True)

    cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", gch_path]

    ret = subprocess.check_call(cc_cmd)

    if ret != 0:
        print(f"Unable to precompile header file, ret is: {ret}")

    return header_path

def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="torch", precompile=False) -> str:
    suffix = sysconfig.get_config_var("EXT_SUFFIX")
    src_dir = os.path.dirname(src_path)
    so_path = os.path.join(src_dir, f"{obj_name}{suffix}")
    if precompile:
        cc_cmd = _get_cxx_precompiled(header_path)
        cc_cmd += [src_path]
    else:
        cxx = _get_cxx()
        cc_cmd = [cxx, src_path]
    # disable all warnings
    cc_cmd += [f"-w"]
    # find the python library
    if hasattr(sysconfig, "get_default_scheme"):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
    # path changes to include 'local'. This change is required to use triton with system-wide python.
    if scheme == "posix_local":
        scheme = "posix_prefix"
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
    cc_cmd += [f"-I{py_include_dir}"]
    # device_print.h
    cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"]
    # find the ascend library
    asc_path = _get_ascend_path()
    if header_path is not None:
        cc_cmd += [f"-I{os.path.dirname(header_path)}"]

    rt_path = os.path.join(asc_path, "include/experiment/runtime/runtime/rt.h")
    if not os.path.exists(rt_path):
        cc_cmd += [
            f"-I{os.path.join(asc_path, 'pkg_inc')}",
            f"-I{os.path.join(asc_path, 'pkg_inc/profiling')}",
        ]

    cc_cmd += [
        f"-I{os.path.join(asc_path, 'include')}",
        f"-I{os.path.join(asc_path, 'include/experiment')}",
        f"-I{os.path.join(asc_path, 'include/experiment/msprof')}",
        f"-I{pybind11.get_include()}",
        f"-L{os.path.join(asc_path, 'lib64')}",
        "-lruntime",
        "-lascendcl",
    ]
    # FIXME: check why this condition works wrong in parall scene
    # if kernel_launcher == "torch":
    cc_cmd += get_backend_func("get_cc_cmd", build_pch=False)

    cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-Winvalid-pch", "-o", so_path]

    result = subprocess.run(cc_cmd, capture_output=True, text=True)

    if result.returncode == 0:
        return so_path
    else:
        if "precompiled.h.gch" in result.stderr:
            # only for clang++, when precompile invalid, fallback to normal compile
            return _build_npu_ext(obj_name, header_path, src_path, precompile=False)
        else:
            raise RuntimeError(f"Failed to compile {src_path}, error: {result.stderr}")


def _get_kernel_target(metadata: dict):
    if "target" not in metadata:
        raise Exception("No target provided!")
    sub_target = metadata["target"].arch
    assert isinstance(sub_target, str)
    if sub_target.startswith("Ascend910B"):
        mix_mode = metadata["mix_mode"]
        if mix_mode.lower().strip("_").startswith("aiv"):
            return "ascend_910b_vec", "c220-vec", "aiv"
        elif mix_mode.lower().strip("_").startswith("aic"):
            return "ascend_910b_cube", "c220-cube", "aic"
        else:
            return "ascend_910b", "c220", "mix"
    elif sub_target.startswith("Ascend910"):
        return "ascend_910", "c100", "mix"
    else:
        raise NotImplementedError(f"NPU subtarget {sub_target} not supported yet")


def _check_cxx11_abi():
    return get_backend_func("cxx_abi")


def convert_sigtype_to_int(sigty: str):
    MAP_SIGTYPE_TO_INT = {
        # Boolean
        "i1": 12,  # BOOL
        # Integer types
        "i8": 2,  # INT8
        "i16": 6,  # INT16
        "i32": 3,  # INT32
        "i64": 9,  # INT64
        # Unsigned integer types
        "u8": 4,  # UINT8
        "u16": 7,  # UINT16
        "u32": 8,  # UINT32
        "u64": 10,  # UINT64
        # Floating point types
        "fp16": 1,  # FLOAT16
        "bf16": 27,  # DT_BF16
        "fp32": 0,  # FLOAT
        "fp64": 11,  # DOUBLE
        "fp8e5": 35,  # FLOAT8_E5M2
        "fp8e4nv": 36,  # FLOAT8_E4M3FN
    }
    if sigty not in MAP_SIGTYPE_TO_INT:
        raise ValueError(f"Unsupported data type: {sigty}")

    return MAP_SIGTYPE_TO_INT[sigty]


def convert_dtype_to_numpy(dtype):
    return get_backend_func("type_convert")[dtype]


def _check_bishengir_able_save_ir() -> bool:
    bishengir_path = _get_npucompiler_path()
    if not _is_valid_bishengir_path(bishengir_path):
        print(f"ERROR: Invalid bishengir path format: {bishengir_path}")
        return False
    try:
        result = subprocess.run(
            [bishengir_path, "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        if result.returncode == 0 and 'save-linked-ir' in result.stdout:
            return True
        else:
            return False
    except Exception as e:
        print(f"ERROR: {e}")
        return False

def get_ascend_arch_from_env():
    arch = os.getenv("TRITON_ASCEND_ARCH", "")
    if arch == "":
        # User does not set arch by ENV. Thus directly return.
        return arch
    # User sets arch by ENV. Thus we need to check if it is supported.
    valid_arch_list = [
        "Ascend910B1",
        "Ascend910B2",
        "Ascend910B3",
        "Ascend910B4",
        "Ascend910_9362",
        "Ascend910_9372",
        "Ascend910_9381",
        "Ascend910_9382",
        "Ascend910_9391",
        "Ascend910_9392",
        "Ascend310B1",
        "Ascend310B2",
        "Ascend310B3",
        "Ascend310B4",
        "Ascend910_9579",
        "Ascend910_9581",
        "Ascend910_9589",
        "Ascend910_9599",
    ]
    is_valid = arch in valid_arch_list
    if not is_valid:
        valid_arch_str = ", ".join(valid_arch_list)
        raise ValueError(f"TRITON_ASCEND_ARCH = {arch} is invalid!"
                         f"Candidates are [{valid_arch_str}]")
    return arch


def is_ffts_supported(arch: str):
    '''
    Cases:
    - empty str: User does not specify arch, thus it runs on 910B/910D both of which support ffts. Return True.
    - Ascend310B4: 310B4 does not support ffts. Return False.
    - Other arch: 910B/910D supports ffts. Return True.
    '''
    if arch in ["Ascend910A", "Ascend310B4"]:
        return False
    return True


def force_disable_ffts():
    '''
    '''
    disable_ffts = os.getenv("TRITON_DISABLE_FFTS", "false").lower() in ("true", "1")
    return disable_ffts