import torch
import torch.fx
import torch.utils._pytree as pytree

from . import is_ascend950


# ATen / prims op namespaces, aliased for the registrations below.
aten = torch.ops.aten
prims = torch.ops.prims

# Op overload -> DvmOpInfo; filled in by ``register_dvm_op``.
DVM_OP_REGISTRY = {}

# Dtypes accepted by the dtype-agnostic registrations in this file (the cast,
# expand, reshape and full entries, and the output of where).
DVM_SUPPORT_TYPE = [
    torch.bfloat16,
    torch.float16,
    torch.float32,
    torch.int32,
    torch.bool,
]


# Floating-point subset; the default input/output dtype allowlist used by
# ``DvmOpInfo`` and ``register_dvm_op``.
DVM_SUPPORT_FLOAT_TYPE = [
    torch.bfloat16,
    torch.float16,
    torch.float32,
]

# torch dtype -> DVM dtype name as spliced into generated code by
# ``to_dvm_dtype``. NOTE: int64 has a mapping here but is not listed in
# DVM_SUPPORT_TYPE, so allowlists built from that list reject int64 values.
DVM_DTYPE_MAP = {
    torch.bfloat16: "dvm.bfloat16",
    torch.float16: "dvm.float16",
    torch.float32: "dvm.float32",
    torch.int32: "dvm.int32",
    torch.int64: "dvm.int64",
    torch.bool: "dvm.bool_",
}


def to_dvm_dtype(dtype):
    """Translate a ``torch.dtype`` into the DVM dtype name used in generated code.

    Any value that is not a ``torch.dtype`` (including ``None``) is returned
    unchanged.

    Raises:
        NotImplementedError: ``dtype`` is a ``torch.dtype`` with no DVM mapping.
    """
    if not isinstance(dtype, torch.dtype):
        return dtype
    dvm_name = DVM_DTYPE_MAP.get(dtype)
    if dvm_name is None:
        raise NotImplementedError(f"Unsupported dtype for DVM: {dtype}")
    return dvm_name


def _check_dtype(inputs, supported_dtypes):
    for inp in inputs:
        if not isinstance(inp, torch.fx.Node):
            continue
        if "val" not in inp.meta:
            continue

        for meta in pytree.tree_leaves(inp.meta["val"]):
            if not isinstance(meta, torch._subclasses.FakeTensor):
                continue
            if meta.dtype not in supported_dtypes:
                return False
    return True


def where_rule(node: torch.fx.Node):
    """Support predicate for ``aten.where`` nodes lowered to ``k.select``.

    Only the ternary form ``where(condition, self, other)`` can be expressed
    as a select. The single-argument ``aten.where.default`` overload returns
    index tensors, so it is rejected here by requiring exactly three
    positional arguments. The condition must be a ``bool`` tensor and both
    branches must be floating point. Non-node arguments (Python scalars) are
    not dtype-checked, as in ``_check_dtype``.
    """
    if len(node.args) != 3:
        return False
    condition, *branches = node.args
    return _check_dtype([condition], [torch.bool]) and _check_dtype(
        branches, DVM_SUPPORT_FLOAT_TYPE
    )


def _is_last2_transpose_tensor(t: torch._subclasses.FakeTensor) -> bool:
    if t.dim() < 2:
        return False

    if t.is_contiguous():
        return False

    if not (t.stride(-2) == 1 and t.stride(-1) == t.size(-2)):
        return False

    batch = t.size(-1) * t.size(-2)
    for d in range(t.dim() - 3, -1, -1):
        if t.stride(d) != batch:
            return False
        batch *= t.size(d)

    return True


def mm_rule(node: torch.fx.Node) -> bool:
    """Decide whether a matmul-family node may be lowered to DVM matmul.

    Handles ``aten.mm``/``aten.bmm`` (operands ``args[0]``, ``args[1]``) and
    ``aten.addmm`` (operands ``args[1]``, ``args[2]``). Any other target is
    rejected. The node is accepted only when all of these hold:

    * its output dtype is ``float16`` or ``bfloat16``;
    * each operand has rank 2..4 and, unless ``is_ascend950`` is truthy, a
      static inner-axis length no larger than ``MAX_INNER``;
    * the output's last two dims are NOT all static and <= ``SMALL_OUTPUT_MAX``
      (NOTE(review): presumably such small outputs do not benefit from DVM;
      confirm);
    * it is not the case that both operands have a static contraction length
      of 1 while either operand is stored last-2-transposed.

    The operand/output checks short-circuit in the order written at the end
    of this function.

    Returns:
        ``True`` if every check passes, otherwise ``False``.
    """
    # 65535 - 255 = 65280. NOTE(review): presumably a hardware limit on the
    # inner-axis length; confirm the origin of this bound.
    UINT16_MAX = (1 << 16) - 1
    UINT8_MAX = (1 << 8) - 1
    MAX_INNER = UINT16_MAX - UINT8_MAX
    SMALL_OUTPUT_MAX = 256

    def inner_axis_length(t: torch._subclasses.FakeTensor):
        # For a last-2-transposed tensor the axis treated as innermost is
        # dim -2 (read via ``.mT``); otherwise it is dim -1.
        if _is_last2_transpose_tensor(t):
            return t.mT.size(-1)
        return t.size(-1)

    def check(input_node):
        # Per-operand checks. Assumes ``input_node`` is an fx Node whose
        # ``meta["val"]`` is a FakeTensor (KeyError otherwise) — TODO confirm
        # that callers guarantee this.
        t = input_node.meta["val"]
        if t.dim() > 4 or t.dim() < 2:
            return False
        inner_axis = inner_axis_length(t)
        # NOTE(review): this tests ``is_ascend950`` for truthiness. If it is a
        # function rather than a bool, ``not is_ascend950`` is always False and
        # the limits below are never applied — confirm its type.
        if not is_ascend950:
            if isinstance(inner_axis, torch.SymInt):
                return False
            if inner_axis > MAX_INNER:
                return False
        return True

    def check_output(output_node):
        # Reject outputs whose last two dims are all static and each
        # <= SMALL_OUTPUT_MAX; anything symbolic or larger is accepted.
        t = output_node.meta["val"]
        last_two_dims = t.shape[-2:]
        if all(not isinstance(dim, torch.SymInt) for dim in last_two_dims) and all(
            dim <= SMALL_OUTPUT_MAX for dim in last_two_dims
        ):
            return False
        return True

    def check_k1_fusion(lhs_node, rhs_node):
        # lhs_k is lhs's last dim and rhs_k is rhs's second-to-last dim (the
        # contraction lengths). Symbolic K is accepted as-is. When both are
        # statically 1, reject if either operand is last-2-transposed.
        # NOTE(review): the reason for this restriction is not visible here.
        lhs_t = lhs_node.meta["val"]
        rhs_t = rhs_node.meta["val"]
        lhs_k = lhs_t.shape[-1]
        rhs_k = rhs_t.shape[-2]
        if isinstance(lhs_k, torch.SymInt) or isinstance(rhs_k, torch.SymInt):
            return True
        if lhs_k == 1 and rhs_k == 1:
            return (not _is_last2_transpose_tensor(lhs_t)) and (
                not _is_last2_transpose_tensor(rhs_t)
            )
        return True

    # addmm carries its addend in args[0], so the matmul operands shift by one.
    if node.target in [aten.mm.default, aten.bmm.default]:
        lhs = node.args[0]
        rhs = node.args[1]
    elif node.target is aten.addmm.default:
        lhs = node.args[1]
        rhs = node.args[2]
    else:
        return False
    if node.meta["val"].dtype not in (torch.float16, torch.bfloat16):
        return False

    return (
        check(lhs) and check(rhs) and check_output(node) and check_k1_fusion(lhs, rhs)
    )


class DvmOpInfo:
    """Registry entry pairing a DVM code emitter with its support predicate.

    Attributes:
        func: callable that emits the DVM expression for the op.
        input_dtypes: allowed dtypes for FakeTensor values among the node's
            argument leaves, or ``None`` to skip the input dtype check.
        output_dtypes: allowed dtypes for the node's own value, or ``None``
            to skip the output dtype check.
        rule: optional extra predicate called as ``rule(node)``.
    """

    def __init__(
        self,
        func,
        input_dtypes=DVM_SUPPORT_FLOAT_TYPE,
        output_dtypes=DVM_SUPPORT_FLOAT_TYPE,
        rule=None,
    ):
        self.func, self.input_dtypes, self.output_dtypes, self.rule = (
            func,
            input_dtypes,
            output_dtypes,
            rule,
        )

    def is_supported(self, node: torch.fx.Node):
        """Run the input check, the output check, then ``rule``; stop at the first failure."""
        leaves = pytree.arg_tree_leaves(*node.args, **node.kwargs)
        if self.input_dtypes is not None and not _check_dtype(
            leaves, self.input_dtypes
        ):
            return False
        if self.output_dtypes is not None and not _check_dtype(
            [node], self.output_dtypes
        ):
            return False
        return self.rule is None or self.rule(node)

    def __iter__(self):
        # Allows ``func, is_supported = info`` unpacking.
        yield from (self.func, self.is_supported)


def register_dvm_op(
    *ops,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE,
    output_dtypes=DVM_SUPPORT_FLOAT_TYPE,
    rule=None,
):
    """Build a decorator that registers a DVM emitter for each overload in ``ops``.

    Every op in ``ops`` is mapped in ``DVM_OP_REGISTRY`` to one shared
    ``DvmOpInfo`` built from the decorated function and the given dtype
    allowlists and ``rule``. A later registration of the same op replaces
    an earlier one. The decorated function is returned unchanged, so it
    stays directly callable.
    """

    def decorator(func):
        entry = DvmOpInfo(
            func,
            input_dtypes=input_dtypes,
            output_dtypes=output_dtypes,
            rule=rule,
        )
        DVM_OP_REGISTRY.update((op, entry) for op in ops)
        return func

    return decorator


# Fallback op info with no emitter: it uses DvmOpInfo's default float-only
# input/output dtype allowlists and no extra rule.
_DEFAULT_OP_INFO = DvmOpInfo(None)


def common_rule(node: torch.fx.Node):
    """Check ``node`` against the default float-only input/output dtype constraints."""
    default_check = _DEFAULT_OP_INFO.is_supported
    return default_check(node)


def format_shape(shape):
    """Render a shape or dim list as a bracketed literal, e.g. ``[2, 3]``.

    A bare ``int`` or ``torch.SymInt`` is treated as a one-element list.
    Each entry is rendered with ``str``.
    """
    dims = [shape] if isinstance(shape, (int, torch.SymInt)) else shape
    return "[" + ", ".join(str(d) for d in dims) + "]"


@register_dvm_op(
    aten.add.Tensor,
    aten.add.Scalar,
    input_dtypes=[*DVM_SUPPORT_FLOAT_TYPE, torch.int32],
    output_dtypes=[*DVM_SUPPORT_FLOAT_TYPE, torch.int32],
)
def add(x, y, alpha=1):
    """Emit a DVM element-wise ``x + alpha * y``.

    ``alpha`` mirrors the ``alpha`` argument of ``aten.add.Tensor``
    (keyword-only) and ``aten.add.Scalar`` (third positional). With the
    default of 1 the result is exactly ``k.add(x, y)``.
    """
    if alpha != 1:
        # Fold the scale into a plain Python number; otherwise emit a multiply.
        y = y * alpha if isinstance(y, (int, float)) else mul(y, alpha)
    return f"k.add({x}, {y})"


@register_dvm_op(
    aten.sub.Tensor,
    aten.sub.Scalar,
    input_dtypes=[*DVM_SUPPORT_FLOAT_TYPE, torch.int32],
    output_dtypes=[*DVM_SUPPORT_FLOAT_TYPE, torch.int32],
)
def sub(x, y, alpha=1):
    """Emit a DVM element-wise ``x - alpha * y``.

    ``alpha`` mirrors the ``alpha`` argument of ``aten.sub.Tensor``
    (keyword-only) and ``aten.sub.Scalar`` (third positional). With the
    default of 1 the result is exactly ``k.sub(x, y)``.
    """
    if alpha != 1:
        # Fold the scale into a plain Python number; otherwise emit a multiply.
        y = y * alpha if isinstance(y, (int, float)) else mul(y, alpha)
    return f"k.sub({x}, {y})"


@register_dvm_op(
    aten.mul.Tensor,
    aten.mul.Scalar,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
    output_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
)
def mul(x, y):
    """Emit a DVM element-wise product ``x * y`` (float or int32 values)."""
    product = f"k.mul({x}, {y})"
    return product


@register_dvm_op(
    aten.div.Tensor,
    aten.div.Scalar,
)
def div(x, y):
    """Emit a DVM element-wise quotient ``x / y`` (default float-only dtypes)."""
    quotient = f"k.div({x}, {y})"
    return quotient


@register_dvm_op(
    aten.pow.Tensor_Tensor,
    aten.pow.Tensor_Scalar,
    aten.pow.Scalar,
)
def pow_op(x, y):
    """Emit a DVM element-wise power ``x ** y``.

    For ``aten.pow.Scalar`` the base ``x`` is the Python scalar and ``y``
    the exponent tensor.
    """
    power = f"k.pow({x}, {y})"
    return power


@register_dvm_op(
    aten.lt.Tensor,
    aten.lt.Scalar,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
    output_dtypes=[torch.bool],
)
def less(x, y):
    """Emit a DVM element-wise ``x < y`` producing a bool result."""
    result = f"k.less({x}, {y})"
    return result


@register_dvm_op(
    aten.le.Tensor,
    aten.le.Scalar,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
    output_dtypes=[torch.bool],
)
def less_equal(x, y):
    """Emit a DVM element-wise ``x <= y`` producing a bool result."""
    result = f"k.less_equal({x}, {y})"
    return result


@register_dvm_op(
    aten.gt.Tensor,
    aten.gt.Scalar,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
    output_dtypes=[torch.bool],
)
def greater(x, y):
    """Emit a DVM element-wise ``x > y`` producing a bool result."""
    result = f"k.greater({x}, {y})"
    return result


@register_dvm_op(
    aten.ge.Tensor,
    aten.ge.Scalar,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
    output_dtypes=[torch.bool],
)
def greater_equal(x, y):
    """Emit a DVM element-wise ``x >= y`` producing a bool result."""
    result = f"k.greater_equal({x}, {y})"
    return result


@register_dvm_op(
    aten.maximum.default,
)
def maximum(x, y):
    """Emit a DVM element-wise maximum of ``x`` and ``y``."""
    larger = f"k.maximum({x}, {y})"
    return larger


@register_dvm_op(
    aten.minimum.default,
)
def minimum(x, y):
    """Emit a DVM element-wise minimum of ``x`` and ``y``."""
    smaller = f"k.minimum({x}, {y})"
    return smaller


@register_dvm_op(
    aten.clamp_min.default,
)
def clamp_min(x, min_value):
    """Lower ``clamp_min`` to an element-wise ``maximum`` against ``min_value``."""
    lower_bound = min_value
    return maximum(x, lower_bound)


@register_dvm_op(
    aten.clamp_max.default,
)
def clamp_max(x, max_value):
    """Lower ``clamp_max`` to an element-wise ``minimum`` against ``max_value``."""
    upper_bound = max_value
    return minimum(x, upper_bound)


@register_dvm_op(
    aten.logical_and.default,
    aten.bitwise_and.Tensor,
    output_dtypes=[torch.bool],
    input_dtypes=[torch.bool],
)
def logical_and(x, y):
    """Emit a DVM element-wise AND; bitwise_and is accepted only on bool values."""
    conjunction = f"k.logical_and({x}, {y})"
    return conjunction


@register_dvm_op(
    aten.logical_or.default,
    aten.bitwise_or.Tensor,
    output_dtypes=[torch.bool],
    input_dtypes=[torch.bool],
)
def logical_or(x, y):
    """Emit a DVM element-wise OR; bitwise_or is accepted only on bool values."""
    disjunction = f"k.logical_or({x}, {y})"
    return disjunction


@register_dvm_op(
    aten.eq.Tensor,
    aten.eq.Scalar,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
    output_dtypes=[torch.bool],
)
def equal(x, y):
    """Emit a DVM element-wise ``x == y`` producing a bool result."""
    result = f"k.equal({x}, {y})"
    return result


@register_dvm_op(
    aten.ne.Tensor,
    aten.ne.Scalar,
    input_dtypes=DVM_SUPPORT_FLOAT_TYPE + [torch.int32],
    output_dtypes=[torch.bool],
)
def not_equal(x, y):
    """Emit a DVM element-wise ``x != y`` producing a bool result."""
    result = f"k.not_equal({x}, {y})"
    return result


@register_dvm_op(
    aten.sqrt.default,
)
def sqrt(x):
    """Emit a DVM element-wise square root of ``x``."""
    root = f"k.sqrt({x})"
    return root


@register_dvm_op(
    aten.rsqrt.default,
)
def rsqrt(x):
    """Lower ``rsqrt`` to ``1 / sqrt(x)`` built from the DVM sqrt and div."""
    root = sqrt(x)
    return div(1, root)


@register_dvm_op(
    aten.abs.default,
)
def abs_op(x):
    """Emit a DVM element-wise absolute value of ``x``."""
    magnitude = f"k.abs({x})"
    return magnitude


@register_dvm_op(
    aten.log.default,
)
def log(x):
    """Emit a DVM element-wise natural logarithm of ``x``."""
    logarithm = f"k.log({x})"
    return logarithm


@register_dvm_op(
    aten.exp.default,
)
def exp(x):
    """Emit a DVM element-wise exponential of ``x``."""
    exponential = f"k.exp({x})"
    return exponential


@register_dvm_op(
    aten.reciprocal.default,
)
def reciprocal(x):
    """Emit a DVM element-wise reciprocal ``1 / x``."""
    inverse = f"k.reciprocal({x})"
    return inverse


@register_dvm_op(
    aten.isfinite.default,
    output_dtypes=[torch.bool],
)
def is_finite(x):
    """Emit a DVM element-wise finiteness test on a float ``x`` (bool result)."""
    mask = f"k.is_finite({x})"
    return mask


@register_dvm_op(
    aten.logical_not.default,
    aten.bitwise_not.default,
    output_dtypes=[torch.bool],
    input_dtypes=[torch.bool],
)
def logical_not(x):
    """Emit a DVM element-wise NOT; bitwise_not is accepted only on bool values."""
    negation = f"k.logical_not({x})"
    return negation


@register_dvm_op(aten.round.default, aten.round.decimals)
def round_op(x, decimals=0):
    """Emit a DVM element-wise round of ``x`` to ``decimals`` decimal places.

    ``decimals`` mirrors the keyword-only argument of ``aten.round.decimals``.
    With the default of 0 the result is plain ``k.round(x)``. Otherwise the
    value is scaled by a power of ten, rounded and scaled back, using the
    same formula as PyTorch's ``round(decimals=...)``:

    * ``decimals > 0``: ``round(x * 10**decimals) / 10**decimals``
    * ``decimals < 0``: ``round(x / 10**-decimals) * 10**-decimals``
    """
    if decimals == 0:
        return f"k.round({x})"
    if decimals > 0:
        scale = 10**decimals
        return div(round_op(mul(x, scale)), scale)
    scale = 10 ** (-decimals)
    return mul(round_op(div(x, scale)), scale)


@register_dvm_op(
    aten.floor.default,
)
def floor(x):
    """Emit a DVM element-wise floor of ``x``."""
    rounded_down = f"k.floor({x})"
    return rounded_down


@register_dvm_op(
    aten.ceil.default,
)
def ceil(x):
    """Emit a DVM element-wise ceiling of ``x``."""
    rounded_up = f"k.ceil({x})"
    return rounded_up


@register_dvm_op(
    aten.trunc.default,
)
def trunc(x):
    """Emit a DVM element-wise truncation of ``x`` toward zero."""
    truncated = f"k.trunc({x})"
    return truncated


@register_dvm_op(
    aten._to_copy.default,
    prims.convert_element_type.default,
    torch.ops.npu.npu_dtype_cast.default,
    torch.ops.npu.npu_dtype_cast_backward.default,
    torch.ops.npu._npu_dtype_cast.default,
    torch.ops.npu._npu_dtype_cast_backward.default,
    output_dtypes=DVM_SUPPORT_TYPE,
    input_dtypes=DVM_SUPPORT_TYPE,
)
def cast(x, dtype):
    """Emit a DVM dtype conversion of ``x``.

    ``dtype`` goes through ``to_dvm_dtype``: a ``torch.dtype`` becomes its
    DVM name, and any other value is spliced in unchanged.
    """
    return f"k.cast({x}, {to_dvm_dtype(dtype)})"


@register_dvm_op(
    aten.expand.default,
    output_dtypes=DVM_SUPPORT_TYPE,
    input_dtypes=DVM_SUPPORT_TYPE,
)
def broadcast(x, shape):
    """Lower ``expand`` to a DVM broadcast of ``x`` to ``shape``."""
    return f"k.broadcast({x}, {format_shape(shape)})"


@register_dvm_op(
    aten.where.default,
    aten.where.self,
    rule=where_rule,
    output_dtypes=DVM_SUPPORT_TYPE,
    input_dtypes=None,
)
def select(x, y, z):
    """Lower ``aten.where`` to ``k.select(x, y, z)``.

    ``x`` is the condition and ``y``/``z`` are the two branches. The
    generic input dtype check is disabled (``input_dtypes=None``);
    ``where_rule`` vets the operands instead.
    """
    condition, on_true, on_false = x, y, z
    return f"k.select({condition}, {on_true}, {on_false})"


@register_dvm_op(aten.sum.dim_IntList, aten.sum.default)
def reduce_sum(x, dim=None, keepdim=False, dtype=None):
    """Emit a DVM sum-reduction of ``x`` over ``dim``.

    Args:
        x: value to reduce.
        dim: dim or dims to reduce; ``None`` is emitted as an empty list.
        keepdim: emitted verbatim as the third ``k.sum`` argument.
        dtype: optional result dtype accepted by ``aten.sum``. As in
            ``torch.sum``, ``x`` is cast to it before the reduction, so the
            emitted sum produces the requested dtype instead of the input's.
    """
    if dtype is not None:
        x = cast(x, dtype)
    dims = format_shape([] if dim is None else dim)
    return f"k.sum({x}, {dims}, {keepdim})"


@register_dvm_op(
    aten.amax.default,
)
def reduce_max(x, dim=None, keepdim=False):
    """Emit a DVM max-reduction of ``x`` over ``dim`` (``None`` -> empty list)."""
    dims = format_shape([] if dim is None else dim)
    return f"k.max({x}, {dims}, {keepdim})"


@register_dvm_op(
    aten.amin.default,
)
def reduce_min(x, dim=None, keepdim=False):
    """Emit a DVM min-reduction of ``x`` over ``dim`` (``None`` -> empty list)."""
    dims = format_shape([] if dim is None else dim)
    return f"k.min({x}, {dims}, {keepdim})"


@register_dvm_op(
    aten.view.default,
    aten.reshape.default,
    aten._unsafe_view.default,
    output_dtypes=DVM_SUPPORT_TYPE,
    input_dtypes=DVM_SUPPORT_TYPE,
)
def reshape(x, shape):
    """Lower view/reshape ops to a DVM reshape of ``x`` to ``shape``."""
    return f"k.reshape({x}, {format_shape(shape)})"


@register_dvm_op(
    aten.neg.default,
)
def neg(x):
    """Lower negation to a DVM multiply by ``-1``."""
    minus_one = -1
    return mul(x, minus_one)


@register_dvm_op(
    aten.relu.default,
)
def relu(x):
    """Lower ReLU to an element-wise ``maximum(x, 0)``."""
    zero = 0
    return maximum(x, zero)


def copy(x):
    """Emit a DVM copy of ``x``."""
    duplicate = f"k.copy({x})"
    return duplicate


@register_dvm_op(
    aten.clone.default,
)
def clone(x, memory_format=None):
    """Lower ``clone`` to a DVM copy; ``memory_format`` is accepted and ignored."""
    duplicate = copy(x)
    return duplicate


@register_dvm_op(
    aten.full.default,
    output_dtypes=DVM_SUPPORT_TYPE,
)
def full(size, fill_value, **kwargs):
    """Emit a DVM constant tensor of shape ``size`` filled with ``fill_value``.

    Only the ``dtype`` keyword is used (translated via ``to_dvm_dtype``; it
    stays ``None`` when absent). All other keywords are ignored.
    """
    shape_literal = format_shape(size)
    dvm_dtype = to_dvm_dtype(kwargs.get("dtype"))
    return f"k.full({fill_value}, {shape_literal}, {dvm_dtype})"


@register_dvm_op(
    aten.mm.default,
    aten.bmm.default,
    rule=mm_rule,
)
def matmul(x, y, trans_a, trans_b):
    """Emit a DVM matmul of ``x`` and ``y``; the transpose flags are emitted verbatim."""
    product = f"k.matmul({x}, {y}, {trans_a}, {trans_b})"
    return product


def matmul_bias(bias, x, y, trans_a, trans_b, beta=1, alpha=1):
    """Emit DVM code for ``alpha * (x @ y) + beta * bias``.

    With the default ``beta == alpha == 1`` this is the fused
    ``k.matmul(x, y, trans_a, trans_b, bias)`` form. That form has no
    scaling parameters, so any other ``beta``/``alpha`` is applied
    explicitly instead of being silently dropped. ``beta == 0`` ignores
    ``bias`` entirely, as ``torch.addmm`` does.
    """
    if beta == 1 and alpha == 1:
        return f"k.matmul({x}, {y}, {trans_a}, {trans_b},{bias})"
    mm = matmul(x, y, trans_a, trans_b)
    if alpha != 1:
        mm = mul(mm, alpha)
    if beta == 0:
        return mm
    if beta != 1:
        bias = mul(bias, beta)
    return add(mm, bias)


@register_dvm_op(aten.addmm.default, rule=mm_rule)
def addmm(z, x, y, trans_a, trans_b, use_bias, beta=1, alpha=1):
    """Emit DVM code for ``beta * z + alpha * (x @ y)`` (``aten.addmm``).

    Args:
        z: the addend (``input`` of ``addmm``).
        x, y: matmul operands; ``trans_a``/``trans_b`` go to ``matmul``.
        use_bias: when true and no scaling is requested, ``z`` is passed to
            the fused ``matmul_bias`` form.
        beta, alpha: scaling factors, as in ``torch.addmm``.

    The fused bias form only computes ``x @ y + z``, so it is used only when
    ``beta == 1`` and ``alpha == 1``; otherwise the scaling is applied
    explicitly. As in ``torch.addmm``, ``beta == 0`` drops ``z`` entirely, so
    NaN/Inf values in it are not propagated.
    """
    if use_bias and beta == 1 and alpha == 1:
        return matmul_bias(z, x, y, trans_a, trans_b)
    mm = matmul(x, y, trans_a, trans_b)
    if alpha != 1:
        mm = mul(mm, alpha)
    if beta == 0:
        return mm
    if beta != 1:
        z = mul(z, beta)
    return add(mm, z)


def load(shape, dtype):
    """Emit a DVM load of the given ``shape``; ``dtype`` goes through ``to_dvm_dtype``."""
    return f"k.load({shape}, {to_dvm_dtype(dtype)})"


def view_load(shape, stride, dtype):
    """Emit a DVM strided load; ``shape``/``stride`` are spliced in verbatim.

    ``dtype`` goes through ``to_dvm_dtype``.
    """
    return f"k.view_load({shape}, {stride}, {to_dvm_dtype(dtype)})"


def store(x, dtype=None):
    """Emit a DVM store of ``x``, optionally converting to ``dtype`` first.

    When ``dtype`` is ``None`` the one-argument ``k.store(x)`` form is
    emitted. Otherwise the dtype is translated via ``to_dvm_dtype``.
    """
    if dtype is not None:
        return f"k.store({x}, {to_dvm_dtype(dtype)})"
    return f"k.store({x})"