# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

from __future__ import annotations

from typing import ClassVar, overload

from .loader import Loader


class ProxyMeta(type):
    """Metaclass that redirects a proxy class to a same-named loaded class.

    The target is looked up with ``Loader.get_attr`` using the proxy class's
    own ``__name__`` each time it is needed. Then:

    * calling the proxy class constructs and returns an instance of the
      loaded class (not of the proxy);
    * every attribute read on the proxy class itself, dunder names included,
      is answered by ``getattr`` on the loaded class.

    No ``__instancecheck__`` is defined here, so ``isinstance(obj, ProxyCls)``
    uses the default check against the proxy class. That check is False for
    objects returned by a call unless the loaded class derives from the proxy.
    """

    def __call__(cls, *args, **kwargs):
        """Construct an instance of the loaded class named like ``cls``.

        ``args`` and ``kwargs`` are passed through unchanged. Anything the
        loaded class's constructor raises propagates to the caller.
        """
        # ``type.__getattribute__`` reads the proxy's *own* name. A plain
        # ``cls.__name__`` would be routed through ``__getattribute__`` below
        # and return the loaded class's ``__name__`` instead. That means an
        # extra lookup, and the wrong key if the two names ever differ.
        cpp_class = Loader.get_attr(type.__getattribute__(cls, "__name__"))
        # Call the class rather than invoking ``__new__``/``__init__`` by hand.
        # This honours the loaded class's own metaclass ``__call__``. It also
        # keeps the standard protocol: ``__init__`` is skipped when
        # ``__new__`` returns an object of another type, and ``__init__`` must
        # return None.
        return cpp_class(*args, **kwargs)

    def __getattribute__(cls, name: str):
        """Resolve ``name`` on the loaded class instead of on the proxy.

        Raises whatever ``getattr`` raises on the loaded class, e.g.
        ``AttributeError`` for a missing name.
        """
        # Same raw-name read as in ``__call__``; ``cls.__name__`` would recurse
        # back into this method.
        cpp_class = Loader.get_attr(type.__getattribute__(cls, "__name__"))
        return getattr(cpp_class, name)


class ProxyBase(metaclass=ProxyMeta):
    """Root of the proxy classes in this module.

    Declares no members of its own. Subclassing it attaches ``ProxyMeta``, so
    each subclass is resolved by its class name through ``Loader.get_attr``.
    """


class MatmulConfigParams(ProxyBase):
    """Signature stub for the loaded ``MatmulConfigParams`` class.

    Only the constructor signature is declared, for IDEs and type checkers.
    ``MatmulConfigParams(...)`` goes through ``ProxyBase``'s metaclass. That
    metaclass passes the arguments to the loaded class and returns its
    instance, so the body below never runs.
    """

    def __init__(
        self,
        mm_config_type: int = ...,
        enable_l1_cache_ub: bool = ...,
        schedule_type: ScheduleType = ...,
        traverse: MatrixTraverse = ...,
        en_vec_nd2nz: bool = ...,
    ) -> None:
        ...


class MatmulApiTilingBase(ProxyBase):
    """Signature stub for the loaded ``MatmulApiTilingBase`` class.

    The methods below exist so IDEs and type checkers know the API; their
    ``...`` bodies never run. Both effects come from ``ProxyBase``'s
    metaclass: constructing a proxy returns an instance of the loaded class,
    and attribute reads on the proxy class are forwarded to the loaded class.

    NOTE(review): most setters are annotated ``-> int``. Presumably this is a
    status code from the native API; confirm against the C++ binding.
    """

    def __init__(self, *args, **kwargs) -> None:
        ...

    def enable_bias(self, is_bias_in: bool = ...) -> int:
        ...

    def get_base_k(self) -> int:
        ...

    def get_base_m(self) -> int:
        ...

    def get_base_n(self) -> int:
        ...

    def get_tiling(self, tiling: object) -> int:
        ...

    def set_a_layout(self, b: int, s: int, n: int, g: int, d: int) -> int:
        ...

    def set_a_type(self, pos: TPosition, type: CubeFormat, data_type: DataType, is_trans: bool) -> int:
        ...

    def set_b_layout(self, b: int, s: int, n: int, g: int, d: int) -> int:
        ...

    def set_b_type(self, pos: TPosition, type: CubeFormat, data_type: DataType, is_trans: bool) -> int:
        ...

    def set_batch_info_for_normal(self, batch_a: int, batch_b: int, m: int, n: int, k: int) -> int:
        ...

    def set_batch_num(self, batch: int) -> int:
        ...

    def set_bias_type(self, pos: TPosition, type: CubeFormat, data_type: DataType) -> int:
        ...

    def set_buffer_space(self, l1_size: int = ..., l0_c_size: int = ..., ub_size: int = ..., bt_size: int = ...) -> int:
        ...

    def set_c_layout(self, b: int, s: int, n: int, g: int, d: int) -> int:
        ...

    def set_c_type(self, pos: TPosition, type: CubeFormat, data_type: DataType) -> int:
        ...

    def set_dequant_type(self, dequant_type: DequantType) -> int:
        ...

    def set_double_buffer(self, a: bool, b: bool, c: bool, bias: bool, trans_nd2nz: bool = ...,
                          trans_nz2nd: bool = ...) -> int:
        ...

    # Single signature, so no ``@overload``: type checkers reject a lone
    # overload variant.
    def set_fix_split(self, base_m_in: int = ..., base_n_in: int = ..., base_k_in: int = ...) -> int:
        ...

    def set_mad_type(self, mad_type: MatrixMadType) -> int:
        ...

    @overload
    def set_matmul_config_params(self, mm_config_type_in: int = ..., enable_l1_cache_ub_in: bool = ...,
                                 schedule_type_in: ScheduleType = ..., traverse_in: MatrixTraverse = ...,
                                 en_vec_nd2nz_in: bool = ...) -> None:
        ...

    @overload
    def set_matmul_config_params(self, config_params: MatmulConfigParams) -> None:
        ...

    # Implementation signature for the two overloads above; required for
    # ``@overload`` outside a stub file. It accepts every overload's arguments.
    def set_matmul_config_params(self, *args, **kwargs) -> None:
        ...

    @overload
    def set_org_shape(self, org_m_in: int, org_n_in: int, org_k_in: int) -> int:
        ...

    @overload
    def set_org_shape(self, org_m_in: int, org_n_in: int, org_ka_in: int, org_kb_in: int) -> int:
        ...

    # Implementation signature for the two overloads above (see the note on
    # ``set_matmul_config_params``).
    def set_org_shape(self, *args, **kwargs) -> int:
        ...

    def set_shape(self, m: int, n: int, k: int) -> int:
        ...

    def set_sparse(self, is_sparse_in: bool = ...) -> int:
        ...

    def set_split_range(self, max_base_m: int = ..., max_base_n: int = ..., max_base_k: int = ...,
                        min_base_m: int = ..., min_base_n: int = ..., min_base_k: int = ...) -> int:
        ...

    def set_traverse(self, traverse: MatrixTraverse) -> int:
        ...


class BatchMatmulTiling(MatmulApiTilingBase):
    """Signature stub for the loaded ``BatchMatmulTiling`` class.

    Declares the constructor and the one method added on top of
    ``MatmulApiTilingBase``. At runtime ``BatchMatmulTiling(...)`` returns an
    instance of the loaded class, so these bodies are placeholders only.
    """

    def __init__(self, arg0) -> None:
        """Placeholder constructor; ``arg0`` is passed through to the loaded class."""

    def get_core_num(self) -> object:
        """Placeholder; the return type is not narrowed here (annotated ``object``)."""


class CubeFormat(ProxyBase):
    """Enumeration-style stub for the loaded ``CubeFormat`` type.

    The members are declared for IDEs and type checkers only. Reading e.g.
    ``CubeFormat.ND`` is forwarded by ``ProxyBase``'s metaclass to the loaded
    class, so the ``...`` placeholders are never returned.
    """

    COLUMN_MAJOR: ClassVar[CubeFormat] = ...
    ND: ClassVar[CubeFormat] = ...
    ND_ALIGN: ClassVar[CubeFormat] = ...
    NN: ClassVar[CubeFormat] = ...
    NZ: ClassVar[CubeFormat] = ...
    ROW_MAJOR: ClassVar[CubeFormat] = ...
    SCALAR: ClassVar[CubeFormat] = ...
    VECTOR: ClassVar[CubeFormat] = ...
    ZN: ClassVar[CubeFormat] = ...
    ZZ: ClassVar[CubeFormat] = ...

    @property
    def name(self) -> str:
        """Name of this member, as reported by the loaded type."""

    @property
    def value(self) -> int:
        """Integer value of this member, as reported by the loaded type."""


class DataType(ProxyBase):
    """Enumeration-style stub for the loaded ``DataType`` type.

    The members are declared for IDEs and type checkers only. Reading e.g.
    ``DataType.DT_FLOAT`` is forwarded by ``ProxyBase``'s metaclass to the
    loaded class, so the ``...`` placeholders are never returned.
    """

    DT_BF16: ClassVar[DataType] = ...
    DT_BFLOAT16: ClassVar[DataType] = ...
    DT_BOOL: ClassVar[DataType] = ...
    DT_COMPLEX128: ClassVar[DataType] = ...
    DT_COMPLEX64: ClassVar[DataType] = ...
    DT_DOUBLE: ClassVar[DataType] = ...
    DT_DUAL: ClassVar[DataType] = ...
    DT_FLOAT: ClassVar[DataType] = ...
    DT_FLOAT16: ClassVar[DataType] = ...
    DT_INT16: ClassVar[DataType] = ...
    DT_INT2: ClassVar[DataType] = ...
    DT_INT32: ClassVar[DataType] = ...
    DT_INT4: ClassVar[DataType] = ...
    DT_INT64: ClassVar[DataType] = ...
    DT_INT8: ClassVar[DataType] = ...
    DT_MAX: ClassVar[DataType] = ...
    DT_QINT16: ClassVar[DataType] = ...
    DT_QINT32: ClassVar[DataType] = ...
    DT_QINT8: ClassVar[DataType] = ...
    DT_QUINT16: ClassVar[DataType] = ...
    DT_QUINT8: ClassVar[DataType] = ...
    DT_STRING: ClassVar[DataType] = ...
    DT_STRING_REF: ClassVar[DataType] = ...
    DT_UINT1: ClassVar[DataType] = ...
    DT_UINT16: ClassVar[DataType] = ...
    DT_UINT32: ClassVar[DataType] = ...
    DT_UINT64: ClassVar[DataType] = ...
    DT_UINT8: ClassVar[DataType] = ...
    DT_UNDEFINED: ClassVar[DataType] = ...
    DT_VARIANT: ClassVar[DataType] = ...

    @property
    def name(self) -> str:
        """Name of this member, as reported by the loaded type."""

    @property
    def value(self) -> int:
        """Integer value of this member, as reported by the loaded type."""


class DequantType(ProxyBase):
    """Enumeration-style stub for the loaded ``DequantType`` type.

    ``SCALAR`` and ``TENSOR`` are declared for IDEs and type checkers only.
    Reading them is forwarded by ``ProxyBase``'s metaclass to the loaded
    class.
    """

    SCALAR: ClassVar[DequantType] = ...
    TENSOR: ClassVar[DequantType] = ...

    @property
    def name(self) -> str:
        """Name of this member, as reported by the loaded type."""

    @property
    def value(self) -> int:
        """Integer value of this member, as reported by the loaded type."""


class MatmulApiTiling(MatmulApiTilingBase):
    """Signature stub for the loaded ``MatmulApiTiling`` class.

    Only the constructor is declared here; the remaining methods are declared
    on ``MatmulApiTilingBase``.
    """

    def __init__(self, arg0) -> None:
        """Placeholder; ``MatmulApiTiling(arg0)`` builds the loaded class instead."""


class MatrixMadType(ProxyBase):
    """Enumeration-style stub for the loaded ``MatrixMadType`` type.

    The members are declared for IDEs and type checkers only. Reading them is
    forwarded by ``ProxyBase``'s metaclass to the loaded class.
    """

    HF32: ClassVar[MatrixMadType] = ...
    MXMODE: ClassVar[MatrixMadType] = ...
    NORMAL: ClassVar[MatrixMadType] = ...

    @property
    def name(self) -> str:
        """Name of this member, as reported by the loaded type."""

    @property
    def value(self) -> int:
        """Integer value of this member, as reported by the loaded type."""


class MatrixTraverse(ProxyBase):
    """Enumeration-style stub for the loaded ``MatrixTraverse`` type.

    The members are declared for IDEs and type checkers only. Reading them is
    forwarded by ``ProxyBase``'s metaclass to the loaded class.
    """

    FIRSTM: ClassVar[MatrixTraverse] = ...
    FIRSTN: ClassVar[MatrixTraverse] = ...
    NOSET: ClassVar[MatrixTraverse] = ...

    @property
    def name(self) -> str:
        """Name of this member, as reported by the loaded type."""

    @property
    def value(self) -> int:
        """Integer value of this member, as reported by the loaded type."""


class MultiCoreMatmulTiling(MatmulApiTilingBase):
    """Signature stub for the loaded ``MultiCoreMatmulTiling`` class.

    Declares the members added on top of ``MatmulApiTilingBase``. At runtime
    ``MultiCoreMatmulTiling(arg0)`` returns an instance of the loaded class,
    so none of the ``...`` bodies below are executed.
    """

    def __init__(self, arg0) -> None:
        ...

    def enable_multi_core_split_k(self, flag: bool) -> None:
        ...

    def get_core_num(self) -> object:
        ...

    def get_single_shape(self) -> object:
        ...

    def set_align_split(self, align_m: int, align_n: int, align_k: int) -> int:
        ...

    def set_dim(self, dim: int) -> int:
        ...

    def set_single_range(
        self,
        max_m: int = ...,
        max_n: int = ...,
        max_k: int = ...,
        min_m: int = ...,
        min_n: int = ...,
        min_k: int = ...,
    ) -> int:
        ...

    def set_single_shape(
        self,
        single_m_in: int = ...,
        single_n_in: int = ...,
        single_k_in: int = ...,
    ) -> int:
        ...


class PlatformAscendC(ProxyBase):
    """Signature stub for the loaded ``PlatformAscendC`` class.

    No methods are declared here. The constructor accepts any arguments, and
    ``ProxyBase``'s metaclass hands them unchanged to the loaded class.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Placeholder constructor; the real signature lives on the loaded class."""


class PlatformAscendCManager(ProxyBase):
    """Signature stub for the loaded ``PlatformAscendCManager`` class.

    ``PlatformAscendCManager.get_instance`` is read through ``ProxyBase``'s
    metaclass, so it is resolved on the loaded class. The declarations below
    exist only for IDEs and type checkers.
    """

    def __init__(self, *args, **kwargs) -> None:
        ...

    @overload
    @staticmethod
    def get_instance() -> PlatformAscendC:
        ...

    @overload
    @staticmethod
    def get_instance(soc_version: str) -> PlatformAscendC:
        ...

    # Implementation signature for the two overloads above. ``@overload``
    # outside a stub file requires one, and it must accept both call forms.
    # Calls are served by the loaded class, so this body never runs.
    @staticmethod
    def get_instance(*args, **kwargs) -> PlatformAscendC:
        ...


class ScheduleType(ProxyBase):
    """Enumeration-style stub for the loaded ``ScheduleType`` type.

    ``INNER_PRODUCT`` and ``OUTER_PRODUCT`` are declared for IDEs and type
    checkers only. Reading them is forwarded by ``ProxyBase``'s metaclass to
    the loaded class.
    """

    INNER_PRODUCT: ClassVar[ScheduleType] = ...
    OUTER_PRODUCT: ClassVar[ScheduleType] = ...

    @property
    def name(self) -> str:
        """Name of this member, as reported by the loaded type."""

    @property
    def value(self) -> int:
        """Integer value of this member, as reported by the loaded type."""


class TPosition(ProxyBase):
    """Enumeration-style stub for the loaded ``TPosition`` type.

    The members are declared for IDEs and type checkers only. Reading e.g.
    ``TPosition.GM`` is forwarded by ``ProxyBase``'s metaclass to the loaded
    class, so the ``...`` placeholders are never returned.
    """

    A1: ClassVar[TPosition] = ...
    A2: ClassVar[TPosition] = ...
    B1: ClassVar[TPosition] = ...
    B2: ClassVar[TPosition] = ...
    C1: ClassVar[TPosition] = ...
    C2: ClassVar[TPosition] = ...
    CO1: ClassVar[TPosition] = ...
    CO2: ClassVar[TPosition] = ...
    GM: ClassVar[TPosition] = ...
    LCM: ClassVar[TPosition] = ...
    MAX: ClassVar[TPosition] = ...
    SHM: ClassVar[TPosition] = ...
    SPM: ClassVar[TPosition] = ...
    TSCM: ClassVar[TPosition] = ...
    VECCALC: ClassVar[TPosition] = ...
    VECIN: ClassVar[TPosition] = ...
    VECOUT: ClassVar[TPosition] = ...

    @property
    def name(self) -> str:
        """Name of this member, as reported by the loaded type."""

    @property
    def value(self) -> int:
        """Integer value of this member, as reported by the loaded type."""