import dataclasses

import torch

import torch_npu
from mindspeed.te.pytorch.fp8.constants import TensorKey
from mindspeed.te.pytorch.fp8.recipes.recipe import Recipe, RecipeScaling, BlockDim
from mindspeed.te.pytorch.fp8.tensor import is_fp8_tensor
from mindspeed.te.pytorch.fp8.tensor.float8_block_tensor import Float8BlockTensor
from mindspeed.te.pytorch.fp8.reuse import reuse_or_quantize
from mindspeed.te.pytorch.utils import view_as_n_dim, get_quant_dtype


class Float8BlockRecipe(Recipe):
    """Recipe that block-quantizes tensors into ``Float8BlockTensor`` objects.

    ``quant_dim`` maps a ``(TensorKey, layout flag)`` pair to the ``BlockDim``
    that is unpacked as keyword arguments into
    ``torch_npu.npu_dynamic_block_quant``.
    """

    # Block sizes: 1 x 128 for the "left" operand, 128 x 128 for the "right" one.
    left_dim = BlockDim(row_block_size=1, col_block_size=128)
    right_dim = BlockDim(row_block_size=128, col_block_size=128)

    # Boolean tags used as the second element of every ``quant_dim`` key.
    rowwise = False
    colwise = True

    quant_dim: dict[tuple[TensorKey, bool], BlockDim] = {
        (TensorKey.inputs, rowwise): right_dim,
        (TensorKey.inputs, colwise): left_dim,
        (TensorKey.weight, rowwise): right_dim,
        (TensorKey.weight, colwise): right_dim,
        (TensorKey.grads, rowwise): left_dim,
        (TensorKey.grads, colwise): left_dim,
    }

    def quantization(self, tensor: torch.Tensor, key, colwise, rowwise):
        """Quantize ``tensor`` into a ``Float8BlockTensor``.

        Args:
            tensor: Tensor to quantize. ``None`` and tensors for which
                ``is_fp8_tensor`` is true are returned unchanged.
            key: ``TensorKey`` used to pick block sizes from ``quant_dim``; it is
                also forwarded to ``run_quantizer`` and ``Float8BlockTensor``.
            colwise: When truthy, produce the column-wise data and scale.
            rowwise: When truthy, produce the row-wise data and scale.

        Returns:
            The input itself for the pass-through cases, otherwise a new
            ``Float8BlockTensor``. A layout that was not requested is stored
            as ``None`` data and ``None`` scale.

        Raises:
            KeyError: If ``key`` has no entry in ``quant_dim``.
        """
        # Pass-through: nothing to quantize, or already quantized.
        if tensor is None or is_fp8_tensor(tensor):
            return tensor

        # Presumably a 2-D view of the input -- TODO confirm against view_as_n_dim.
        flat = view_as_n_dim(tensor)
        result = Float8BlockTensor(self.fp8_format_dtype, tensor.shape, tensor.device, tensor.dtype, key=key)

        # Both block shapes are looked up up front, whichever layouts are requested.
        # ``self.colwise`` / ``self.rowwise`` are the class-level tags, not the
        # same-named boolean parameters of this method.
        dims_for_col = self.quant_dim[(key, self.colwise)]
        dims_for_row = self.quant_dim[(key, self.rowwise)]

        def quantize(source, block_dim):
            # The original (un-flattened) tensor is what is passed as reuse_identity.
            return self.run_quantizer(
                source,
                key,
                torch_npu.npu_dynamic_block_quant,
                reuse_identity=tensor,
                dst_type=self.quant_dtype,
                **block_dim,
            )

        col_data = col_scale = None
        if colwise:
            col_data, col_scale = quantize(flat, dims_for_col)

        row_data = row_scale = None
        if rowwise:
            # Gradients are quantized from the transposed view; everything else as is.
            row_source = flat.T if key == TensorKey.grads else flat
            row_data, row_scale = quantize(row_source, dims_for_row)

        result.set_col_data(col_data, col_scale, key == TensorKey.weight)
        result.set_row_data(row_data, row_scale)
        return result


@dataclasses.dataclass
class Float8BlockScaling(RecipeScaling):
    """Scaling configuration whose ``recipe`` is ``Float8BlockRecipe``."""

    # Holds the recipe class itself, not an instance. The assignment carries no
    # annotation, so ``dataclasses.dataclass`` does not register it as a new
    # field on this class. NOTE(review): whether ``RecipeScaling`` already
    # declares a ``recipe`` field is not visible here -- confirm.
    recipe = Float8BlockRecipe


class Float8BlockMatMul(torch.autograd.Function):
    """Autograd function multiplying ``x`` by transposed ``weight`` on block-quantized operands.

    Both passes quantize their operands with ``torch_npu.npu_dynamic_block_quant``
    and multiply them with ``torch_npu.npu_quant_matmul``. The first matmul
    operand is always quantized with ``Float8BlockRecipe.left_dim`` and the
    second with ``Float8BlockRecipe.right_dim``.
    """

    @staticmethod
    def forward(ctx, x, weight, need_grad=True):
        """Compute the quantized product of ``x`` and ``weight`` transposed.

        Args:
            ctx: Autograd context; ``x`` and ``weight`` are saved on it.
            x: Input tensor. It is passed through ``view_as_n_dim`` before
                quantization (presumably flattened to 2-D -- TODO confirm).
            weight: Weight tensor, quantized through ``reuse_or_quantize``
                (presumably reusing an existing quantization when one is
                available -- TODO confirm).
            need_grad: Accepted for call compatibility; not read here.

        Returns:
            The matmul result in ``x.dtype``. When ``x`` is not 2-D, the leading
            dimensions of ``x`` are restored on the output.
        """
        qdtype = get_quant_dtype()
        x_quant, x_scale = torch_npu.npu_dynamic_block_quant(
            view_as_n_dim(x), dst_type=qdtype.x, **Float8BlockRecipe.left_dim)
        w_quant, w_scale = reuse_or_quantize(
            weight,
            TensorKey.weight,
            torch_npu.npu_dynamic_block_quant,
            dst_type=qdtype.w,
            **Float8BlockRecipe.right_dim,
        )
        output = torch_npu.npu_quant_matmul(
            x_quant, w_quant.t(), w_scale.transpose(0, 1), pertoken_scale=x_scale,
            output_dtype=x.dtype, group_sizes=[1, 128, 128])
        if len(x.shape) != 2:
            # Put back the leading dimensions that were collapsed for the matmul.
            output = output.reshape(*x.shape[:-1], *output.shape[1:])
        if weight.requires_grad:
            output.requires_grad = True
        ctx.save_for_backward(x, weight)
        return output

    @staticmethod
    def backward(ctx, grads: torch.Tensor):
        """Compute the gradients with respect to ``x`` and ``weight``.

        Only the gradients that autograd asks for (``ctx.needs_input_grad``)
        are computed; the others are returned as ``None`` so that the
        corresponding quantization and matmul are skipped.

        Args:
            ctx: Autograd context populated by ``forward``.
            grads: Gradient of the loss with respect to the forward output.

        Returns:
            ``(dx, dw, None)`` -- one entry per ``forward`` input; the last
            one belongs to the non-tensor ``need_grad`` argument.
        """
        x, weight = ctx.saved_tensors
        need_dx, need_dw = ctx.needs_input_grad[:2]
        qdtype = get_quant_dtype()
        grads_2d = view_as_n_dim(grads)

        dx = None
        if need_dx:
            grads_quant, grads_scale = torch_npu.npu_dynamic_block_quant(
                grads_2d, dst_type=qdtype.grads, **Float8BlockRecipe.left_dim)
            w_quant, w_scale = reuse_or_quantize(
                weight.t(),
                TensorKey.weight,
                torch_npu.npu_dynamic_block_quant,
                dst_type=qdtype.w,
                **Float8BlockRecipe.right_dim,
            )
            dx = torch_npu.npu_quant_matmul(
                grads_quant, w_quant.t(), w_scale.transpose(0, 1), pertoken_scale=grads_scale,
                output_dtype=x.dtype, group_sizes=[1, 128, 128])
            if len(grads.shape) != 2:
                # Mirror the reshape done on the forward output.
                dx = dx.reshape(*grads.shape[:-1], *dx.shape[1:])

        dw = None
        if need_dw:
            # The weight gradient multiplies the transposed gradient by the
            # input, so both operands are quantized from transposed views.
            grads_t_quant, grads_t_scale = torch_npu.npu_dynamic_block_quant(
                grads_2d.t(), dst_type=qdtype.grads, **Float8BlockRecipe.left_dim)
            x_quant, x_scale = torch_npu.npu_dynamic_block_quant(
                view_as_n_dim(x).t(), dst_type=qdtype.x, **Float8BlockRecipe.right_dim)
            dw = torch_npu.npu_quant_matmul(
                grads_t_quant, x_quant.t(), x_scale.transpose(0, 1), pertoken_scale=grads_t_scale,
                output_dtype=x.dtype, group_sizes=[1, 128, 128])

        return dx, dw, None