# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# See LICENSE for license information.
from typing import Optional

import torch
import torch_npu

from ..constants import USAGE_WITH_TRANS
from ..quantization import FP8GlobalStateManager
from ..quantized_tensor import QuantizedTensor
from ..tensor.grouped_tensor import GroupedTensor
from ..utils import get_quant_dtype, view_as_n_dim


def _align_group_list_device(group_list, device):
    """Keep grouped GEMM split metadata on the same device as input tensors."""
    if isinstance(group_list, torch.Tensor) and group_list.device != device:
        return group_list.to(device=device)
    return group_list


class MXFP8MatMul(torch.autograd.Function):
    """MXFP8-quantized ``x @ weight.t()`` with a quantized custom backward.

    Both operands are dynamically MX-quantized with ``torch_npu`` ops and
    multiplied by ``torch_npu.npu_quant_matmul`` using ``float8_e8m0fnu``
    block scales. Element formats come from the active FP8 recipe
    (``FP8GlobalStateManager.get_fp8_recipe()``):
    - ``qdtype.fwd`` for activations and weights;
    - ``qdtype.bwd`` for the incoming gradient in backward.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, need_grad: bool = True) -> torch.Tensor:
        """Compute ``x @ weight.t()`` on MXFP8-quantized operands.

        Args:
            ctx: autograd context.
            x: input activations. Flattened with ``view_as_n_dim`` before the
                GEMM (assumed to collapse leading dims to 2-D — TODO confirm).
                When ``x`` is not 2-D, the output is reshaped back to
                ``x.shape[:-1]``.
            weight: weight matrix; used transposed in the GEMM.
            need_grad: when True, dual-axis quantization is used. The second
                quantized copies (and scales) of ``x`` and ``weight`` are
                stored on ``ctx`` for reuse in backward. When False, only
                ``axis=-1`` quantization happens here. The raw ``x``/``weight``
                are then saved, and backward re-quantizes them along ``axis=-2``.

        Returns:
            The product, in ``x.dtype``.
        """
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        qdtype = get_quant_dtype(recipe.fp8_format)
        dst_type = qdtype.fwd
        # Kept so backward uses the same recipe formats (qdtype.bwd / qdtype.fwd).
        ctx.qdtype = qdtype
        x_2d = view_as_n_dim(x)
        # Both dx and dw are produced in this dtype in backward.
        ctx.output_dtype = x.dtype
        if need_grad:
            # Dual-axis quantization. The first (quant, scale) pair feeds this
            # GEMM. The second pair is kept on ctx as plain attributes (not via
            # save_for_backward) and consumed directly by backward.
            x_quant, x_scale, ctx.x, ctx.x_scale = torch_npu.npu_dynamic_mx_quant_with_dual_axis(
                x_2d, dst_type=dst_type
            )
            w_quant, w_scale, ctx.w, ctx.w_scale = torch_npu.npu_dynamic_mx_quant_with_dual_axis(
                weight, dst_type=dst_type
            )
        else:
            # Single-axis quantization only. Backward re-quantizes the saved
            # raw tensors.
            x_quant, x_scale = torch_npu.npu_dynamic_mx_quant(x_2d, axis=-1, dst_type=dst_type)
            w_quant, w_scale = torch_npu.npu_dynamic_mx_quant(weight, axis=-1, dst_type=dst_type)
            ctx.save_for_backward(x, weight)
        # group_sizes=[1, 1, 32]: presumably the 32-element MX scaling block
        # along K — confirm against the torch_npu npu_quant_matmul docs.
        output = torch_npu.npu_quant_matmul(
            x_quant,
            w_quant.t(),
            w_scale.transpose(0, 1),
            pertoken_scale=x_scale,
            output_dtype=x.dtype,
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            group_sizes=[1, 1, 32],
        )
        # Restore the leading dims of a non-2-D input.
        if len(x.shape) != 2:
            output = output.reshape(*x.shape[:-1], *output.shape[1:])
        # NOTE(review): forces the output to require grad whenever weight does.
        # The intent is not visible here — confirm it is still needed.
        if weight.requires_grad:
            output.requires_grad = True
        return output

    @staticmethod
    def backward(ctx, grads: torch.Tensor):
        """Compute gradients of ``x`` and ``weight`` with MXFP8-quantized GEMMs.

        The incoming gradient is dual-axis quantized with ``qdtype.bwd``:
        - one copy is multiplied with the quantized weight to form ``dx``;
        - the other, transposed, is multiplied with the quantized input to
          form ``dw``.
        Both results are in the forward input's dtype (``ctx.output_dtype``).

        Returns:
            ``(dx, dw, None, None, None)``. ``forward`` takes three inputs
            (x, weight, need_grad), so the last two ``None`` entries are
            surplus. Autograd is expected to tolerate extra trailing ``None``
            gradients — TODO(review) confirm or trim.
        """
        qdtype = ctx.qdtype
        grads_dx, grads_dx_scale, grads_dw, grads_dw_scale = (
            torch_npu.npu_dynamic_mx_quant_with_dual_axis(view_as_n_dim(grads), dst_type=qdtype.bwd)
        )

        if hasattr(ctx, "x"):
            # need_grad=True: reuse the second-axis quantizations cached by forward.
            x_quant, x_scale, w_quant, w_scale = ctx.x, ctx.x_scale, ctx.w, ctx.w_scale
        else:
            # need_grad=False: re-quantize the saved raw tensors along axis=-2.
            x, weight = ctx.saved_tensors
            w_quant, w_scale = torch_npu.npu_dynamic_mx_quant(weight, axis=-2, dst_type=qdtype.fwd)
            x_quant, x_scale = torch_npu.npu_dynamic_mx_quant(
                view_as_n_dim(x), axis=-2, dst_type=qdtype.fwd
            )

        # dx = grads @ weight (quantized operands).
        dx = torch_npu.npu_quant_matmul(
            grads_dx,
            w_quant,
            w_scale,
            pertoken_scale=grads_dx_scale,
            output_dtype=ctx.output_dtype,
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            group_sizes=[1, 1, 32],
        )
        # Restore the leading dims of a non-2-D gradient.
        if len(grads.shape) != 2:
            dx = dx.reshape(*grads.shape[:-1], *dx.shape[1:])

        # dw = grads^T @ x (quantized operands). Note that the dtype is x's, not weight's.
        dw = torch_npu.npu_quant_matmul(
            grads_dw.t(),
            x_quant,
            x_scale,
            pertoken_scale=grads_dw_scale.transpose(0, 1),
            output_dtype=ctx.output_dtype,
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            group_sizes=[1, 1, 32],
        )
        return dx, dw, None, None, None


def general_gemm(A, B, usage_a, usage_b, out_dtype, bias=None):
    """Multiply ``A`` by ``B``, optionally adding ``bias``.

    If ``A`` is a ``QuantizedTensor``, the work is delegated to
    ``A.matmul(B, usage_a, usage_b, out_dtype)``. Otherwise, each operand whose
    usage is in ``USAGE_WITH_TRANS`` is transposed with ``.t()`` and the two are
    multiplied with ``torch.matmul``. ``out_dtype`` is not applied on that
    dense path. ``bias``, when not None, is added out of place to the result.
    """
    if isinstance(A, QuantizedTensor):
        result = A.matmul(B, usage_a, usage_b, out_dtype)
    else:
        lhs = A.t() if usage_a in USAGE_WITH_TRANS else A
        rhs = B.t() if usage_b in USAGE_WITH_TRANS else B
        result = torch.matmul(lhs, rhs)

    if bias is not None:
        result = result + bias
    return result


def general_gemm_add(main_grad, A, B, usage_a, usage_b, out_dtype):
    """Accumulate ``op(A) @ op(B)`` into ``main_grad`` in place.

    Returns without doing anything when any dimension of ``A`` or ``B`` is zero.
    A ``QuantizedTensor`` ``A`` handles the accumulation itself through
    ``A.matmul_add(main_grad, B, usage_a, usage_b, out_dtype)``. Otherwise each
    operand whose usage is in ``USAGE_WITH_TRANS`` is transposed with ``.t()``,
    and ``main_grad.addmm_`` is used. ``out_dtype`` is unused on that path.
    """
    # An empty dimension (e.g. an empty total_input / grad_output) means
    # there is nothing to accumulate.
    if 0 in A.shape or 0 in B.shape:
        return

    if isinstance(A, QuantizedTensor):
        A.matmul_add(main_grad, B, usage_a, usage_b, out_dtype)
        return

    lhs = A.t() if usage_a in USAGE_WITH_TRANS else A
    rhs = B.t() if usage_b in USAGE_WITH_TRANS else B
    main_grad.addmm_(lhs, rhs)


def matmul_add(main_grad, inp, grad_out):
    """Accumulate ``grad_out.t() @ inp`` into ``main_grad`` in place.

    Does nothing when any dimension of ``inp`` or ``grad_out`` is zero.
    """
    # An empty input / grad_output contributes nothing.
    if 0 in inp.shape or 0 in grad_out.shape:
        return
    main_grad.addmm_(grad_out.t(), inp)
    # TODO(@Muu): add adaptation for A3 here, e.g. by loading
    # matmul_add_op_builder and calling its fused
    # npu_matmul_add_fp32(grad_out, inp, main_grad) kernel.


def general_grouped_gemm(
    B,
    A,
    group_split: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    layout="TN",
    use_bias: bool = False,
    biases=None,
    group_type: int = 0,
    group_list_type: int = 1,
    split_item: int = 3,
    out_dtype: torch.dtype = torch.bfloat16,
):
    """Grouped GEMM front-end for GroupedTensor and plain-tensor operands.

    When both ``A`` and ``B`` are GroupedTensors, the call is delegated:
    - with ``out`` given, to ``general_grouped_gemm_add_``, which accumulates
      into ``out``; this function then returns None;
    - otherwise, to ``general_grouped_gemm_for_grouped_tensor``.

    Otherwise ``A`` is a plain tensor and ``B`` is a tensor or a list of
    tensors:
    - ``layout[0] == "T"`` transposes the last two dims of every ``B`` entry;
    - ``layout[1] == "T"`` transposes (a contiguous copy of) ``A``;
    - then ``torch_npu.npu_grouped_matmul([A], B, ...)`` is run.
    ``out`` is ignored on this path.

    Args:
        B: weight operand(s): a GroupedTensor, a tensor, or a list of tensors.
        A: input operand: a GroupedTensor or a tensor.
        group_split: grouped-GEMM split metadata, passed as ``group_list``.
        out: accumulation target; only used on the GroupedTensor path.
        layout: two characters; "T" at index 0 / 1 transposes B / A.
        use_bias: add ``biases`` per group when True and ``biases`` is not None.
        biases: per-group biases, applied via ``out_add_biases``.
        group_type: forwarded to ``torch_npu.npu_grouped_matmul``.
        group_list_type: forwarded to ``torch_npu.npu_grouped_matmul``.
        split_item: forwarded to ``torch_npu.npu_grouped_matmul``.
        out_dtype: output dtype requested from the kernel.

    Returns:
        The GEMM result, or None when accumulating into ``out``. When ``A``'s
        last dim and ``B[0]``'s first dim are both zero, a zero tensor of shape
        ``(len(group_split), *A.shape[:-1], *B[0].shape[1:])`` is returned
        without calling the kernel.
    """
    if isinstance(A, GroupedTensor) and isinstance(B, GroupedTensor):
        if out is not None:
            general_grouped_gemm_add_(out, B, A, group_split, layout, group_list_type)
            return None
        return general_grouped_gemm_for_grouped_tensor(
            B,
            A,
            group_split,
            layout,
            use_bias,
            biases,
            group_type,
            group_list_type,
            split_item,
            out_dtype,
        )

    if isinstance(B, torch.Tensor):
        B = [B]

    if layout[0] == "T":
        B = [b.transpose(-1, -2) for b in B]
    if layout[1] == "T":
        if not A.is_contiguous():
            A = A.contiguous()
        A = A.t()

    if A.shape[-1] == 0 and B[0].shape[0] == 0:
        return torch.zeros(
            len(group_split),
            *A.shape[:-1],
            *B[0].shape[1:],
            device=A.device,
            dtype=out_dtype,
        )

    # Match the GroupedTensor paths: hand the kernel its split metadata on the
    # same device as the input. The caller's ``group_split`` is still used for
    # the bias split below, so host-side metadata does not take a device
    # round trip.
    group_list = _align_group_list_device(group_split, A.device)
    result = torch_npu.npu_grouped_matmul(
        [A],
        B,
        group_list=group_list,
        output_dtype=out_dtype,
        group_type=group_type,
        group_list_type=group_list_type,
        split_item=split_item,
    )[0]

    if use_bias and biases is not None:
        result = out_add_biases(result, biases, group_split)
    return result


def general_grouped_gemm_for_grouped_tensor(
    B: GroupedTensor,
    A: GroupedTensor,
    group_split: torch.Tensor,
    layout="TN",
    use_bias: bool = False,
    biases=None,
    group_type: int = 0,
    group_list_type: int = 1,
    split_item: int = 3,
    out_dtype: torch.dtype = torch.bfloat16,
):
    """Run ``torch_npu.npu_grouped_matmul`` on two GroupedTensors (dense or quantized).

    ``A`` and ``B`` must both be dense or both be quantized. Their data (and
    scales, when quantized) are fetched with ``get_data``:
    - ``"L" + layout[1]`` for A;
    - ``"R" + layout[0]`` for B.
    How those keys map to an orientation is defined by
    ``GroupedTensor.get_data``, which is not visible here.

    For quantized inputs, B's scale is passed as ``scale`` and A's scale as
    ``per_token_scale``. Extra dtype kwargs are chosen from A's recipe:
    - MXFP8 and MXFP4 set e8m0 scale dtypes;
    - MXFP4 without RHT on both quantizers also sets packed float4 element
      dtypes;
    - a hifloat8 quantizer dtype sets hifloat8 element dtypes.

    Args:
        B: right operand GroupedTensor.
        A: left operand GroupedTensor.
        group_split: split metadata, moved to A's data device before the kernel.
        layout: two characters selecting the get_data variants (see above).
        use_bias: add ``biases`` per group when True and ``biases`` is not None.
        biases: per-group biases, applied via ``out_add_biases``.
        group_type: forwarded to ``torch_npu.npu_grouped_matmul``.
        group_list_type: forwarded to ``torch_npu.npu_grouped_matmul``.
        split_item: forwarded to ``torch_npu.npu_grouped_matmul``.
        out_dtype: output dtype requested from the kernel.

    Returns:
        The grouped GEMM output in ``out_dtype``. When ``a_data``'s last dim
        and ``b_data``'s first dim are both zero and the recipe is not MXFP4,
        the kernel is not called. A zero tensor of shape
        ``(len(group_split), *a_data.shape[:-1], *b_data.shape[1:])`` is
        returned instead.

    Raises:
        TypeError: if A or B is not a GroupedTensor.
        RuntimeError: if exactly one of A / B is quantized.
    """
    # For now only the single-single-single mode is supported (presumably one
    # x, one weight and one output tensor, matching the one-element lists
    # passed to npu_grouped_matmul below).

    if not isinstance(A, GroupedTensor) or not isinstance(B, GroupedTensor):
        raise TypeError(
            "general_grouped_gemm_for_grouped_tensor expects GroupedTensor inputs, "
            f"but got {type(A)}(A) with {type(B)}(B)"
        )

    if (A.quantizer is None) != (B.quantizer is None):
        raise RuntimeError("Mixed dense and quantized GroupedTensor GEMM is not supported")

    # Scales are only meaningful for quantized tensors.
    a_data, a_scale = A.get_data(f"L{layout[1]}")
    b_data, b_scale = B.get_data(f"R{layout[0]}")

    is_low_precision = A.quantizer is not None and B.quantizer is not None
    # Empty-operand shortcut. It is skipped for MXFP4 recipes; the reason is
    # not visible here (NOTE(review): confirm MXFP4 handles empty groups itself).
    if (
        (not is_low_precision or not A.quantizer._get_compatible_recipe().mxfp4())
        and a_data.shape[-1] == 0
        and b_data.shape[0] == 0
    ):
        return torch.zeros(
            len(group_split),
            *a_data.shape[:-1],
            *b_data.shape[1:],
            device=A.device,
            dtype=out_dtype,
        )

    group_split = _align_group_list_device(group_split, a_data.device)
    gmm_kwargs = {}
    if is_low_precision:
        gmm_kwargs["scale"] = [b_scale]
        gmm_kwargs["per_token_scale"] = [a_scale]
        recipe = A.quantizer._get_compatible_recipe()
        if recipe.mxfp8():
            gmm_kwargs["scale_dtype"] = torch_npu.float8_e8m0fnu
            gmm_kwargs["per_token_scale_dtype"] = torch_npu.float8_e8m0fnu
        elif recipe.mxfp4():
            gmm_kwargs["scale_dtype"] = torch_npu.float8_e8m0fnu
            gmm_kwargs["per_token_scale_dtype"] = torch_npu.float8_e8m0fnu
            # Packed FP4 element dtypes are only declared when RHT is not
            # enabled on both quantizers.
            if not (A.quantizer.with_rht and B.quantizer.with_rht):
                gmm_kwargs["x_dtype"] = torch_npu.float4_e2m1fn_x2
                gmm_kwargs["weight_dtype"] = torch_npu.float4_e2m1fn_x2
        # A hifloat8 quantizer dtype overrides any element dtypes set above.
        if A.quantizer.dtype == torch_npu.hifloat8:
            gmm_kwargs["x_dtype"] = torch_npu.hifloat8
            gmm_kwargs["weight_dtype"] = torch_npu.hifloat8

    out = torch_npu.npu_grouped_matmul(
        [a_data],
        [b_data],
        group_list=group_split,
        output_dtype=out_dtype,
        group_type=group_type,
        group_list_type=group_list_type,
        split_item=split_item,
        **gmm_kwargs,
    )[0]

    if use_bias and biases is not None:
        out = out_add_biases(out, biases, group_split)
    return out


def general_grouped_gemm_add_(
    out: torch.Tensor,
    B: GroupedTensor,
    A: GroupedTensor,
    group_split: torch.Tensor,
    layout="TN",
    group_list_type: int = 1,
):
    """Accumulate a grouped GEMM of two GroupedTensors into ``out`` in place.

    Operand data and scales come from ``A.get_data("L" + layout[1])`` and
    ``B.get_data("R" + layout[0])``. ``out`` is viewed as
    ``(len(group_split), b_data.shape[-1], -1)`` for the kernel:
    - dense operands use ``torch_npu.npu_grouped_matmul_add_``;
    - MXFP8 operands use ``torch_npu.npu_add_quant_gmm_`` with e8m0 scales.
    Nothing happens when ``a_data``'s last dim and ``b_data``'s first dim
    are both zero.

    Raises:
        TypeError: if A or B is not a GroupedTensor.
        RuntimeError: if exactly one of A / B is quantized, or if the recipe
            of quantized inputs is not MXFP8.
    """
    if not (isinstance(A, GroupedTensor) and isinstance(B, GroupedTensor)):
        raise TypeError("general_grouped_gemm_add_ expects GroupedTensor inputs")

    a_is_dense = A.quantizer is None
    b_is_dense = B.quantizer is None
    if a_is_dense != b_is_dense:
        raise RuntimeError("Mixed dense and quantized GroupedTensor GEMM is not supported")

    lhs, lhs_scale = A.get_data(f"L{layout[1]}")
    rhs, rhs_scale = B.get_data(f"R{layout[0]}")

    # Both operands empty along these dims: nothing to accumulate.
    if lhs.shape[-1] == 0 and rhs.shape[0] == 0:
        return

    group_split = _align_group_list_device(group_split, lhs.device)
    num_groups = len(group_split)

    # Past the mixed-precision check, A and B are either both dense or both quantized.
    if a_is_dense:
        torch_npu.npu_grouped_matmul_add_(
            out.view(num_groups, rhs.shape[-1], -1), rhs, lhs, group_split
        )
        return

    if not A.quantizer._get_compatible_recipe().mxfp8():
        raise RuntimeError(
            "Currently, general_grouped_gemm_add_ only supports BF16 and MXFP8 formats"
        )

    torch_npu.npu_add_quant_gmm_(
        out.view(num_groups, rhs.shape[-1], -1),
        rhs,
        lhs,
        lhs_scale,
        x1_scale=rhs_scale,
        group_list_type=group_list_type,
        group_list=group_split,
        x1_scale_dtype=torch_npu.float8_e8m0fnu,
        x2_scale_dtype=torch_npu.float8_e8m0fnu,
    )


def out_add_biases(out, biases, group_split):
    """Add one bias per group to consecutive row slices of ``out``.

    ``out`` is split along dim 0 into chunks sized by ``group_split``; a tensor
    is converted to a Python list first. The i-th entry of ``biases`` is added
    to the i-th chunk, and the chunks are concatenated back along dim 0.
    Chunks without a matching bias are kept unchanged, and ``out`` itself is
    not modified.

    NOTE(review): ``group_split`` is used directly as split sizes, i.e. per-group
    counts. Confirm callers never pass cumulative offsets.
    """
    if isinstance(group_split, torch.Tensor):
        split_sizes = group_split.tolist()
    else:
        split_sizes = group_split

    chunks = list(torch.split(out, split_sizes))
    for idx, bias in enumerate(biases):
        # Out-of-place add: the chunks are views into ``out``.
        chunks[idx] = chunks[idx] + bias
    return torch.cat(chunks, dim=0)