import dataclasses

import torch
import torch_npu

from mindspeed.te.pytorch.fp8.constants import TensorKey, FormatEnum
from mindspeed.te.pytorch.fp8.recipes.recipe import Recipe, RecipeScaling
from mindspeed.te.pytorch.fp8.tensor import is_fp8_tensor, Float8Tensor
from mindspeed.te.pytorch.fp8.reuse import reuse_or_quantize
from mindspeed.te.pytorch.utils import view_as_n_dim, get_quant_dtype


class CurrentScalingRecipe(Recipe):
    def quantization(self, tensor, key, colwise, rowwise):
        if tensor is None:
            return tensor
        if is_fp8_tensor(tensor):  # if dtype is fp8
            return tensor
        dst_type_max = self.fp8_format.max if self.quant_dtype == torch_npu.hifloat8 else 0
        quant_tensor, scale = self.run_quantizer(
            tensor,
            key,
            torch_npu.npu_dynamic_quant,
            dst_type=self.quant_dtype,
            quant_mode='pertensor',
            dst_type_max=dst_type_max,
        )
        return Float8Tensor(quant_tensor, self.quant_dtype, scale, dtype=tensor.dtype)


@dataclasses.dataclass
class Float8CurrentScaling(RecipeScaling):
    recipe = CurrentScalingRecipe


class TensorwiseMatMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, need_grad=True):
        qdtype = get_quant_dtype()
        dst_type_max = FormatEnum.HIF8_15.value.max if qdtype.x == torch_npu.hifloat8 else 0
        x_quant, x_scale = torch_npu.npu_dynamic_quant(
            view_as_n_dim(x), dst_type=qdtype.x, quant_mode='pertensor', dst_type_max=dst_type_max
        )
        w_quant, w_scale = reuse_or_quantize(
            weight,
            TensorKey.weight,
            torch_npu.npu_dynamic_quant,
            dst_type=qdtype.w,
            quant_mode='pertensor',
            dst_type_max=dst_type_max,
        )

        output = torch_npu.npu_quant_matmul(
            x_quant, w_quant.t(), w_scale, pertoken_scale=x_scale, output_dtype=x.dtype, **qdtype.mm_kwargs
        )
        if len(x.shape) != 2:
            output = output.reshape(*x.shape[:-1], *output.shape[1:])
        if weight.requires_grad:
            ctx.save_for_backward(x, weight)
        ctx.x_quant, ctx.x_scale, ctx.w_quant, ctx.w_scale = x_quant, x_scale, w_quant, w_scale
        ctx.output_dtype = x.dtype
        return output

    @staticmethod
    def backward(ctx, grads: torch.Tensor):
        qdtype = get_quant_dtype()
        dst_type_max = FormatEnum.HIF8_224.value.max if qdtype.x == torch_npu.hifloat8 else 0
        grads_quant, grads_scale = torch_npu.npu_dynamic_quant(
            view_as_n_dim(grads),
            dst_type=qdtype.grads,
            quant_mode='pertensor',
            dst_type_max=dst_type_max,
        )
        x_quant, x_scale, w_quant, w_scale = ctx.x_quant, ctx.x_scale, ctx.w_quant, ctx.w_scale
        dx = torch_npu.npu_quant_matmul(
            grads_quant, w_quant, w_scale, pertoken_scale=grads_scale, output_dtype=ctx.output_dtype, **qdtype.mm_kwargs
        )
        if len(grads.shape) != 2:
            dx = dx.reshape(*grads.shape[:-1], *dx.shape[1:])
        dw = torch_npu.npu_quant_matmul(
            grads_quant.T,
            x_quant,
            x_scale,
            pertoken_scale=grads_scale,
            output_dtype=ctx.output_dtype,
            **qdtype.mm_kwargs,
        )
        return dx, dw, None, None, None