# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import logging
from typing import Optional, NamedTuple

import torch
import torch_npu

from megatron.training import get_args

from mindspeed.te.pytorch.fp8 import get_matmul_wise_by_tensor_key, MatmulKey
from mindspeed.te.pytorch.fp8.constants import TensorKey
from mindspeed.te.pytorch.fp8.state_manager import FP8GlobalStateManager
from mindspeed.te.pytorch.module_typing import FP8Metadata
from mindspeed.te.pytorch.utils import get_hccl_comm_name, all_gather_along_dim, get_quant_dtype, view_as_n_dim

logger = logging.getLogger(__name__)


class Float8Tensor:
    def __init__(
        self,
        data: torch.Tensor,
        fp8_dtype: torch.dtype,
        fp8_scale: Optional[torch.Tensor] = None,
        dtype: torch.dtype = torch.float32,
    ):
        self.data = data
        self.fp8_dtype = fp8_dtype
        self.fp8_scale = fp8_scale
        self._dtype = dtype

    @property
    def shape(self):
        return self.data.shape

    @property
    def device(self):
        return self.data.device

    @property
    def dtype(self):
        return self._dtype

    def reshape(self, *args):
        self.data = self.data.reshape(*args)
        return self

    def view(self, *args):
        return self.__class__(
            data=self.data.view(*args),
            fp8_dtype=self.fp8_dtype,
            fp8_scale=self.fp8_scale,
            dtype=self.dtype,
        )

    def t(self):
        data = self.data.t()
        fp8_scale = self.fp8_scale
        return Float8Tensor(
            data=data,
            fp8_dtype=self.fp8_dtype,
            fp8_scale=fp8_scale,
            dtype=self.dtype,
        )

    def get_quant_data(self):
        return self.data, self.fp8_scale

    def quant_matmul(self, other: 'Float8Tensor', is_rowwise: tuple[bool, bool], key: MatmulKey):
        x1 = self.t() if is_rowwise[0] else self
        x2 = other.t() if is_rowwise[1] else other
        qdtype = get_quant_dtype()
        output = torch_npu.npu_quant_matmul(
            x1.data, x2.data, x2.fp8_scale, pertoken_scale=x1.fp8_scale, output_dtype=x1.dtype, **qdtype.mm_kwargs
        )
        # te cpu compare
        args = get_args()
        if args.te_comparison_with_cpu:
            from mindspeed.te.pytorch.fp8 import te_online_comparison_cpu

            te_online_comparison_cpu(x1, x2, output)
        if args.te_comparison_with_bf16:
            from mindspeed.te.pytorch.fp8 import te_online_comparison_bf16

            te_online_comparison_bf16(x1, x2, output)
        return output

    def all_gather_matmul(self, other: 'Float8Tensor', bias, fp8_meta: FP8Metadata, key: MatmulKey):
        x1_need_transpose, x2_need_transpose = get_matmul_wise_by_tensor_key(self, key)
        _, x1_scale = all_gather_along_dim(self.fp8_scale)
        x1 = view_as_n_dim(self.data).t() if x1_need_transpose else view_as_n_dim(self.data)
        x2 = view_as_n_dim(other.data).t() if x2_need_transpose else view_as_n_dim(other.data)
        # x1 scale 因为是单标量tensor 与之前allgather输出 tensor相同 可以省一次allgather
        hcomm_name = get_hccl_comm_name(fp8_meta.tp_group, fp8_meta.tp_rank)
        output, gather_out, _ = torch_npu.npu_all_gather_quant_mm(
            x1,
            x2,
            hcomm_name,
            fp8_meta.tp_world_size,
            bias=bias,
            x1_scale=self.fp8_scale,
            x2_scale=other.fp8_scale,
            y_dtype=self.dtype,
        )
        gather_out = Float8Tensor(gather_out, self.fp8_dtype, x1_scale, self.dtype)
        return output.view(-1, self.shape[1], output.shape[1]), gather_out

    def matmul_reduce_scatter(self, other: 'Float8Tensor', bias, fp8_meta: FP8Metadata, key: MatmulKey):
        x1_need_transpose, x2_need_transpose = get_matmul_wise_by_tensor_key(self, key)
        x1 = view_as_n_dim(self.data).t() if x1_need_transpose else view_as_n_dim(self.data)
        x2 = view_as_n_dim(other.data).t() if x2_need_transpose else view_as_n_dim(other.data)

        hcomm_name = get_hccl_comm_name(fp8_meta.tp_group, fp8_meta.tp_rank)
        output, _ = torch_npu.npu_quant_mm_reduce_scatter(
            x1,
            x2,
            hcomm_name,
            fp8_meta.tp_world_size,
            bias=bias,
            reduce_op='sum',
            x1_scale=self.fp8_scale,
            x2_scale=other.fp8_scale,
            **get_quant_dtype().mm_kwargs,
            y_dtype=self.dtype,
        )
        return output.view(-1, self.shape[1], output.shape[1])


class QuantTensorMeta(NamedTuple):
    data: torch.Tensor
    scale: torch.Tensor

    def t(self):
        return self.data.T, self.scale.transpose(0, 1)


class Float8Tensor2D:
    col_tensor: QuantTensorMeta
    row_tensor: QuantTensorMeta

    def __init__(
        self,
        fp8_dtype: torch.dtype,
        origin_shape: torch.Size,
        device: 'torch.device',
        dtype: torch.dtype = torch.float32,
        key: TensorKey = None,
    ):
        self.fp8_dtype = fp8_dtype
        self.origin_shape = origin_shape
        self.device = device
        self.dtype = dtype
        self.key = key

    def set_col_data(self, data, scale, t=False):
        if data is None:
            return
        self.col_tensor = QuantTensorMeta(data.T if t else data, scale.transpose(0, 1) if t else scale)

    def set_row_data(self, data, scale, t=False):
        if data is None:
            return
        self.row_tensor = QuantTensorMeta(data.T if t else data, scale.transpose(0, 1) if t else scale)

    def get_quant_data(self, is_rowwise=False):
        return self.row_tensor if is_rowwise else self.col_tensor

    def t(self):
        raise ValueError(f'{self.__class__.__name__} not support transpose')

    def quant_matmul(self, other: 'Float8Tensor2D', is_rowwise, key: MatmulKey):
        raise NotImplementedError()

    def restore_reshape(self, other: 'Float8Tensor2D', output: torch.Tensor):
        if len(self.origin_shape) == len(other.origin_shape):
            return output
        return output.reshape(*self.origin_shape[:-1], *output.shape[1:])

    def release(self, data: torch.Tensor, scale: torch.Tensor, matmul_key: MatmulKey = None) -> None:
        if self.key == TensorKey.weight and FP8GlobalStateManager.is_weight_quantization_reuse_configured():
            return
        from mindspeed.te.pytorch.fp8.recipes import MXFP832x32BlockScaling

        if (
            self.key == TensorKey.weight
            and matmul_key == MatmulKey.forward
            and isinstance(FP8GlobalStateManager.get_fp8_recipe(), MXFP832x32BlockScaling)
        ):
            scale.untyped_storage().resize_(0)
            return
        data.untyped_storage().resize_(0)
        scale.untyped_storage().resize_(0)


def te_cast_comparison(fp8_format, tensor, quant_tensor):
    from mindspeed.te.pytorch.fp8 import cast_to_fp8_cpu

    if fp8_format.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
        raise ValueError(
            f"TE online comparison only supports e4m3 and e5m2 formats, but fp8_dtype is {fp8_format.dtype}"
        )
    tensor_cpu = tensor.cpu()
    quant_tensor_cpu = cast_to_fp8_cpu(tensor_cpu, fp8_format)

    quant_tensor_cpu = quant_tensor_cpu.npu()
    abs_error = torch.abs(quant_tensor_cpu.to(torch.float32) - quant_tensor.to(torch.float32))
    rel_error = abs_error / torch.abs(quant_tensor_cpu.to(torch.float32))
    max_abs_error = torch.max(abs_error)
    max_rel_error = torch.max(rel_error)

    logger.info("The error of cast to fp8: ")
    logger.info("[%s] Max Absolute Error: %s", quant_tensor.device, max_abs_error.item())
    logger.info("[%s] Max Relative Error: %s", quant_tensor.device, max_rel_error.item())
    if max_rel_error > 0.0:
        raise ValueError(f"The error of cast exceeds tolerance: {max_rel_error.item()}")