"""
this file is used to enhance the npu frontend API by set_option or other.
"""

__all__ = ["set_option", "set_aoe",
           "set_compile_mode", "set_mm_bmm_format_nd", "get_mm_bmm_format_nd",
           "is_jit_compile_false", "finalize_dump", "init_dump", "set_dump",
           "set_device_limit", "get_device_limit", "set_stream_limit",
           "reset_stream_limit", "get_stream_limit"]

import functools
import inspect
import os
import warnings
from enum import IntEnum, unique
from logging import exception

import torch_npu
import torch_npu._C
from torch_npu.utils._path_manager import PathManager
from torch_npu.utils._error_code import ErrCode, pta_error, prof_error

from .utils import _get_device_index

_option_map = {"ACL_PRECISION_MODE": ["allow_fp32_to_fp16", "must_keep_origin_dtype"],
               "ACL_OP_SELECT_IMPL_MODE": ["high_performance", "high_precision"],
               "ACL_AICORE_NUM": (lambda value: value.isdigit() and 1 <= int(value) <= 32),
               "ACL_OPTYPELIST_FOR_IMPLMODE": None,
               "ACL_OP_DEBUG_LEVEL": ["0", "1", "2", "3", "4"],
               "ACL_DEBUG_DIR": None,
               "ACL_OP_COMPILER_CACHE_MODE": ["disable", "enable", "force"],
               "ACL_OP_COMPILER_CACHE_DIR": None,
               "ACL_OP_DEBUG_OPTION": None}

_deprecated_option_set = {"ACL_OP_SELECT_IMPL_MODE", "ACL_OPTYPELIST_FOR_IMPLMODE"}


class _CubeMathType(IntEnum):
    KEEP_DTYPE = 0
    ALLOW_FP32_DOWN_PRECISION = 1
    USE_FP16 = 2
    USE_HF32 = 3
    FORCE_GRP_ACC_FOR_FP32 = 4 # deprecate, but use as a transition for now
    USE_FP32_ADD = 4


def _check_compile_option(name, value) -> bool:
    if name in _option_map.keys():
        if _option_map[name] is None:
            return True
        if callable(_option_map[name]):
            return _option_map[name](value)
        return value in _option_map[name]
    return True


def _describe_allowed_values(option_name):
    """Return a printable description of the values ``_option_map`` accepts for ``option_name``.

    List-valued rules are returned unchanged (the caller formats them). For a
    predicate rule its source text is returned. If that source cannot be
    retrieved (``inspect.getsource`` raises OSError/TypeError, e.g. in a
    bytecode-only install), the predicate's repr is used instead, so the
    caller still raises the intended ValueError.
    """
    allowed = _option_map[option_name]
    if not callable(allowed):
        return allowed
    try:
        return inspect.getsource(allowed)
    except (OSError, TypeError):
        return repr(allowed)


def set_option(option):
    """Validate NPU options and pass them to the backend via ``_npu_setOption``.

    If ``option["MM_BMM_ND_ENABLE"]`` is exactly "enable" or "disable", the
    Python-side mm/bmm ND flag is updated through ``set_mm_bmm_format_nd``.
    Every value is converted with ``str()`` and checked by
    ``_check_compile_option``. Options listed in ``_deprecated_option_set``
    emit a warning. The backend receives a new dict of the stringified
    values; the caller's ``option`` dict is not modified.

    Args:
        option: dict mapping option names to values.

    Raises:
        TypeError: if ``option`` is not a dict.
        ValueError: if a value is rejected by ``_check_compile_option``.
    """
    if not isinstance(option, dict):
        raise TypeError("npu option must be a dict." + pta_error(ErrCode.PARAM))

    mm_bmm_nd = option.get("MM_BMM_ND_ENABLE")
    if mm_bmm_nd == "enable":
        set_mm_bmm_format_nd(True)
    elif mm_bmm_nd == "disable":
        set_mm_bmm_format_nd(False)

    # Collect the stringified values in a fresh dict rather than rewriting the
    # caller's dict in place.
    str_option = {}
    for option_name, option_value in option.items():
        value_str = str(option_value)
        if not _check_compile_option(option_name, value_str):
            # _check_compile_option only rejects names that are present in _option_map.
            raise ValueError(f"value of {option_name} should be in {_describe_allowed_values(option_name)} "
                             f"but got {option_value}" + pta_error(ErrCode.PARAM))
        str_option[option_name] = value_str

        if option_name in _deprecated_option_set:
            # stacklevel=2 attributes the warning to the caller of set_option.
            warnings.warn(f"{option_name} will be deprecated in future version. The accuracy or performance "
                          f"may not be the optimal when configuring this option. We do not recommend setting it.",
                          stacklevel=2)

    torch_npu._C._npu_setOption(str_option)


def init_dump():
    """Turn the backend dump switch (``mdldumpswitch``) on."""
    torch_npu._C._npu_setOption({"mdldumpswitch": "enable"})


def set_dump(cfg_file):
    """Pass the dump configuration path to the backend as ``mdldumpconfigpath``.

    Args:
        cfg_file: path that must already exist; it is resolved with
            ``os.path.realpath`` before being handed over.

    Raises:
        AssertionError: if ``cfg_file`` does not exist.
    """
    if os.path.exists(cfg_file):
        resolved_cfg = os.path.realpath(cfg_file)
        torch_npu._C._npu_setOption({"mdldumpconfigpath": resolved_cfg})
        return
    raise AssertionError("cfg_file %s path does not exists." % (cfg_file) + pta_error(ErrCode.NOT_FOUND))


def finalize_dump():
    """Turn the backend dump switch (``mdldumpswitch``) off."""
    disable_switch = {"mdldumpswitch": "disable"}
    torch_npu._C._npu_setOption(disable_switch)


def set_compile_mode(jit_compile=False):
    """Set the backend ``jitCompile`` option to "enable" or "disable".

    When the NPU runtime is already initialized, ``torch_npu.npu.synchronize()``
    is called before the option is changed.

    Args:
        jit_compile: truthy selects "enable", falsy selects "disable".
    """
    npu = torch_npu.npu
    if npu.is_initialized():
        npu.synchronize()
    mode = "enable" if jit_compile else "disable"
    torch_npu._C._npu_setOption({"jitCompile": mode})


def set_aoe(dump_path):
    """Enable AOE auto-tuning and set the directory its graphs are dumped to.

    If ``dump_path`` does not exist yet, it is created with
    ``PathManager.make_dir_safety``. The path is then sent to the backend as
    ``autotunegraphdumppath`` together with ``autotune=enable``.

    Raises:
        TypeError: if creating the directory raises TypeError.
        OSError: if creating the directory raises OSError.
    """
    missing = not os.path.exists(dump_path)
    if missing:
        try:
            PathManager.make_dir_safety(dump_path)
        except TypeError:
            raise TypeError("Type of dump_path is invalid." + pta_error(ErrCode.TYPE)) from None
        except OSError:
            raise OSError("Value of dump_path is invalid." + pta_error(ErrCode.SYSCALL)) from None
    aoe_option = {
        "autotune": "enable",
        "autotunegraphdumppath": dump_path,
    }
    torch_npu._C._npu_setOption(aoe_option)


"""
This global flag control mm and bmm use ND format to compute, if the flag is True,
we use ND format for mm and bmm in Linear module

useage:
```
option = {}
option["MM_BMM_ND_ENABLE"] = "enable"
torch.npu.set_option(option)
```

Default: False
"""
# Module-level ND-format switch for mm/bmm. It is written by set_mm_bmm_format_nd()
# (also reached through set_option's "MM_BMM_ND_ENABLE" key) and read by
# get_mm_bmm_format_nd(). It starts out enabled.
_MM_BMM_ND_ENABLE = True


def set_mm_bmm_format_nd(is_nd=True):
    """Enable or disable ND format for mm/bmm.

    Args:
        is_nd: truthy enables ND format, falsy disables it. The value is
            normalized to a bool before it is stored in ``_MM_BMM_ND_ENABLE``.
    """
    global _MM_BMM_ND_ENABLE
    _MM_BMM_ND_ENABLE = bool(is_nd)


def get_mm_bmm_format_nd():
    """Report whether mm/bmm are currently set to compute in ND format.

    Returns:
        The current value of the module-level ``_MM_BMM_ND_ENABLE`` flag.
    """
    current = _MM_BMM_ND_ENABLE
    return current


def is_jit_compile_false() -> bool:
    """Return the backend's ``_npu_is_jit_compile_false()`` result.

    The NPU runtime is lazily initialized before the backend is queried.
    """
    torch_npu.npu._lazy_init()
    jit_off = torch_npu._C._npu_is_jit_compile_false()
    return jit_off


class _npuConfig:
    @classmethod
    def __setattr__(cls, name, value):
        if name == "allow_internal_format":
            option = {"ALLOW_INTERNAL_FORMAT": "enable" if value else "disable"}
            torch_npu._C._npu_setOption(option)


class _allowHF32Matmul:
    @classmethod
    def __setattr__(cls, name, value):
        if name == "allow_hf32":
            if not isinstance(value, bool):
                raise TypeError(
                    "allow_hf32 must be a bool, but got {}{}".format(
                        type(value).__name__, pta_error(ErrCode.TYPE)
                    )
                )
            option = {"ALLOW_MATMUL_HF32": "enable" if value else "disable"}
            torch_npu._C._npu_setOption(option)
        elif name == "cube_math_type":
            if(not isinstance(value, _CubeMathType)):
                raise TypeError(f"value should be one of Enum CubeMathType when setting cube_math_type, but got {type(value)}")
            torch_npu._C._npu_setOption({"CUBE_MATH_TYPE": str(value.value)})

    @classmethod
    def __getattr__(cls, name):
        if name == "allow_hf32":
            hf32_value = torch_npu._C._npu_getOption("ALLOW_MATMUL_HF32")
            return hf32_value is not None and hf32_value.decode() == "enable"
        elif name == "cube_math_type":
            cube_math_type_value = torch_npu._C._npu_getOption("CUBE_MATH_TYPE")
            if cube_math_type_value is not None and len(cube_math_type_value) > 0:
                return _CubeMathType(int(cube_math_type_value))
            # if cube_math_type is not None:
        return None


class _allowHF32Conv:
    @classmethod
    def __setattr__(cls, name, value):
        if name == "allow_hf32":
            if not isinstance(value, bool):
                raise TypeError(
                    "allow_hf32 must be a bool, but got {}{}".format(
                        type(value).__name__, pta_error(ErrCode.TYPE)
                    )
                )
            option = {"ALLOW_CONV_HF32": "enable" if value else "disable"}
            torch_npu._C._npu_setOption(option)

    @classmethod
    def __getattr__(cls, name):
        if name == "allow_hf32":
            hf32_value = torch_npu._C._npu_getOption("ALLOW_CONV_HF32")
            return (hf32_value is None) or (hf32_value.decode() == "") or (hf32_value.decode() == "enable")
        return None


class _call_once_class:
    def __init__(self, func):
        self.func = func
        self.called = False
        self.result = None

    def __call__(self, *args, **kwargs):
        if self.called:
            raise RuntimeError(f"Function '{self.func.__name__}' has already been called, You can only set this interface once.")

        self.called = True
        self.result = self.func(*args, **kwargs)
        return self.result


@_call_once_class
def set_device_limit(device, cube_num=-1, vector_num=-1):
    """Limit the cube and/or vector cores available on ``device``.

    Wrapped by ``_call_once_class``; see that class for the one-shot call
    restriction.

    Args:
        device: NPU device index (an int, not a bool) in ``[0, device_count())``.
        cube_num: cube-core limit (resource type 0); -1 leaves it untouched.
        vector_num: vector-core limit (resource type 1); -1 leaves it untouched.

    Raises:
        TypeError: if ``device`` is not an int.
        AssertionError: if ``device`` is out of range.
    """
    from torch_npu.npu import device_count
    if isinstance(device, bool) or not isinstance(device, int):
        raise TypeError(
            "device must be an int, but got {}{}".format(
                type(device).__name__, pta_error(ErrCode.TYPE)
            )
        )
    if not 0 <= device < device_count():
        raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
    torch_npu.npu._lazy_init()
    # Apply cube (type 0) first, then vector (type 1); -1 means "leave unchanged".
    for res_type, res_value in ((0, cube_num), (1, vector_num)):
        if res_value != -1:
            torch_npu._C._npu_set_device_res_limit(device, res_type, res_value)


def get_device_limit(device):
    """Return the cube and vector core limits configured for ``device``.

    Args:
        device: NPU device index (an int, not a bool) in ``[0, device_count())``.

    Returns:
        dict with keys ``"cube_core_num"`` (resource type 0) and
        ``"vector_core_num"`` (resource type 1).

    Raises:
        TypeError: if ``device`` is not an int.
        AssertionError: if ``device`` is out of range.
    """
    from torch_npu.npu import device_count
    if isinstance(device, bool) or not isinstance(device, int):
        raise TypeError(
            "device must be an int, but got {}{}".format(
                type(device).__name__, pta_error(ErrCode.TYPE)
            )
        )
    if not 0 <= device < device_count():
        raise AssertionError("Invalid device id" + pta_error(ErrCode.VALUE))
    torch_npu.npu._lazy_init()
    cube_limit = torch_npu._C._npu_get_device_res_limit(device, 0)
    vector_limit = torch_npu._C._npu_get_device_res_limit(device, 1)
    return {"cube_core_num": cube_limit, "vector_core_num": vector_limit}


def set_stream_limit(stream, cube_num=-1, vector_num=-1):
    """Limit the cube and/or vector cores used by a specific NPU stream.

    Args:
        stream: the ``torch_npu.npu.Stream`` to configure.
        cube_num: cube-core limit (resource type 0); -1 leaves it untouched.
        vector_num: vector-core limit (resource type 1); -1 leaves it untouched.

    Raises:
        AssertionError: if ``stream`` is None or not a ``torch_npu.npu.Stream``.
    """
    if stream is None:
        raise AssertionError("stream cannot be None" + pta_error(ErrCode.PARAM))
    if not isinstance(stream, torch_npu.npu.Stream):
        raise AssertionError(f"stream should be torch_npu.npu.Stream, could not be {type(stream)}" + pta_error(ErrCode.TYPE))
    torch_npu.npu._lazy_init()
    # Apply cube (type 0) first, then vector (type 1); -1 means "leave unchanged".
    for res_type, res_value in ((0, cube_num), (1, vector_num)):
        if res_value != -1:
            torch_npu._C._npu_set_stream_res_limit(stream_id=stream.stream_id,
                                                   device_index=stream.device_index,
                                                   device_type=stream.device_type,
                                                   type=res_type,
                                                   value=res_value)


def reset_stream_limit(stream):
    """Ask the backend to reset the core limits previously set on ``stream``.

    Args:
        stream: the ``torch_npu.npu.Stream`` whose limits are reset.

    Raises:
        AssertionError: if ``stream`` is None or not a ``torch_npu.npu.Stream``.
    """
    if stream is None:
        raise AssertionError("stream cannot be None" + pta_error(ErrCode.PARAM))
    if not isinstance(stream, torch_npu.npu.Stream):
        raise AssertionError(f"stream should be torch_npu.npu.Stream, could not be {type(stream)}" + pta_error(ErrCode.TYPE))
    torch_npu.npu._lazy_init()
    stream_key = dict(stream_id=stream.stream_id,
                      device_index=stream.device_index,
                      device_type=stream.device_type)
    torch_npu._C._npu_reset_stream_res_limit(**stream_key)


def get_stream_limit(stream):
    """Return the cube and vector core limits configured for ``stream``.

    Args:
        stream: the ``torch_npu.npu.Stream`` to query.

    Returns:
        dict with keys ``"cube_core_num"`` (resource type 0) and
        ``"vector_core_num"`` (resource type 1).

    Raises:
        AssertionError: if ``stream`` is None or not a ``torch_npu.npu.Stream``.
    """
    if stream is None:
        raise AssertionError("stream cannot be None" + pta_error(ErrCode.PARAM))
    if not isinstance(stream, torch_npu.npu.Stream):
        raise AssertionError(
            f"stream should be torch_npu.npu.Stream, could not be {type(stream)}" + pta_error(ErrCode.TYPE))
    torch_npu.npu._lazy_init()
    limits = {}
    for key, res_type in (("cube_core_num", 0), ("vector_core_num", 1)):
        limits[key] = torch_npu._C._npu_get_stream_res_limit(stream_id=stream.stream_id,
                                                             device_index=stream.device_index,
                                                             device_type=stream.device_type,
                                                             type=res_type)
    return limits