import importlib.util
import itertools
import unittest

import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


def numpy_hifloat8():
    """Return the NumPy ``hifloat8`` scalar type provided by ``en_dtypes``.

    The import is done lazily so that the module can be loaded on machines
    that do not have ``en_dtypes`` installed.

    Returns:
        The ``hifloat8`` type exported by ``en_dtypes``.

    Raises:
        RuntimeError: if ``en_dtypes`` is not installed, or if the installed
            version does not export ``hifloat8``. The original import error
            is attached as ``__cause__``.
    """
    try:
        from en_dtypes import hifloat8
    except ModuleNotFoundError as err:
        # Must be handled before ImportError: ModuleNotFoundError is its subclass.
        raise RuntimeError("en_dtypes is needed to support hifloat8 dtype!!! "
                           "Please install with `pip3 install en-dtypes`") from err
    except ImportError as err:
        # The package exists but is too old to export the name.
        raise RuntimeError("Please upgrade en_dtypes to v0.0.3 at least to support hifloat8 dtype!!! "
                           "Command is `pip3 install --upgrade en-dtypes`") from err
    return hifloat8


def numpy_float4_e2m1fn():
    """Return the NumPy ``float4_e2m1fn`` scalar type provided by ``ml_dtypes``.

    The import is done lazily so that the module can be loaded on machines
    that do not have ``ml_dtypes`` installed.

    Returns:
        The ``float4_e2m1fn`` type exported by ``ml_dtypes``.

    Raises:
        RuntimeError: if ``ml_dtypes`` is not installed, or if the installed
            version does not export ``float4_e2m1fn``. The original import
            error is attached as ``__cause__``.
    """
    try:
        from ml_dtypes import float4_e2m1fn
    except ModuleNotFoundError as err:
        # Must be handled before ImportError: ModuleNotFoundError is its subclass.
        raise RuntimeError("ml_dtypes is needed to support float4_e2m1fn dtype!!! "
                           "Please install with `pip3 install ml-dtypes`") from err
    except ImportError as err:
        # The package exists but is too old to export the name.
        raise RuntimeError("Please upgrade ml_dtypes to support float4_e2m1fn dtype!!! "
                           "Command is `pip3 install --upgrade ml-dtypes`") from err
    return float4_e2m1fn


def numpy_to_torch(np_arr):
    """Convert a NumPy array to a torch tensor, bit-casting low-precision dtypes.

    Arrays whose dtype name (``str(np_arr.dtype)``) appears in the table below
    are reinterpreted as ``uint8``, wrapped with ``torch.from_numpy`` and then
    viewed as the matching torch dtype. Every other array goes straight
    through ``torch.from_numpy``.

    Args:
        np_arr: the NumPy array to convert.

    Returns:
        A torch tensor built from ``np_arr`` via ``torch.from_numpy``.

    Raises:
        TypeError: if the array dtype is ``float8_e8m0`` but the installed
            torch build has no ``float8_e8m0fnu`` dtype to map it to.
    """
    bitcast_dtypes = {
        "bfloat16": torch.bfloat16,
        "float8_e5m2": torch.float8_e5m2,
        "float8_e4m3fn": torch.float8_e4m3fn,
        # Not every torch build ships this dtype; None marks it unavailable.
        "float8_e8m0": getattr(torch, "float8_e8m0fnu", None),
        "hifloat8": torch_npu.hifloat8,
    }

    dtype_name = str(np_arr.dtype)
    if dtype_name not in bitcast_dtypes:
        return torch.from_numpy(np_arr)

    torch_dtype = bitcast_dtypes[dtype_name]
    if torch_dtype is None:
        # Fail with a clear message instead of letting Tensor.view(None) blow up.
        raise TypeError(
            f"NumPy dtype '{dtype_name}' has no matching dtype in this torch build")

    # Reinterpret the raw bytes: NumPy -> uint8 view -> torch -> target dtype view.
    raw_bytes = torch.from_numpy(np_arr.view(np.uint8))
    return raw_bytes.view(torch_dtype)


def assert_tensors_close(x: torch.Tensor, y: torch.Tensor, rtol=1e-3, atol=1e-5, label="Tensor"):
    """Assert that two tensors match element-wise within a tolerance.

    Both tensors are moved to the CPU and compared in float32. NaNs must sit
    at the same positions, infinities must sit at the same positions with the
    same sign, and every remaining element must satisfy
    ``|x - y| <= atol + rtol * |y|``.

    Args:
        x: the tensor under test.
        y: the reference tensor; the relative tolerance is scaled by ``|y|``.
        rtol: relative tolerance.
        atol: absolute tolerance.
        label: prefix used in the failure message.

    Returns:
        None when the tensors are close.

    Raises:
        AssertionError: on a shape, NaN or Inf mismatch, or when any finite
            element exceeds the tolerance.
    """
    # Without this check a shape difference would be reported as a NaN mismatch.
    if x.shape != y.shape:
        raise AssertionError(
            f"[{label}] Shape mismatch: x is {tuple(x.shape)}, y is {tuple(y.shape)}")

    x_f32 = x.cpu().to(torch.float32)
    y_f32 = y.cpu().to(torch.float32)

    nan_mask_x = torch.isnan(x_f32)
    nan_mask_y = torch.isnan(y_f32)
    if not torch.equal(nan_mask_x, nan_mask_y):
        raise AssertionError(
            f"[{label}] NaN mismatch: x has NaNs at {torch.where(nan_mask_x)}, y has NaNs at {torch.where(nan_mask_y)}")

    inf_mask_x = torch.isinf(x_f32)
    inf_mask_y = torch.isinf(y_f32)
    # Same positions AND same sign: +inf versus -inf is a real mismatch.
    if not torch.equal(inf_mask_x, inf_mask_y) or not torch.equal(x_f32[inf_mask_x], y_f32[inf_mask_x]):
        raise AssertionError(f"[{label}] Inf mismatch.")

    finite_mask = ~(nan_mask_x | inf_mask_x)
    x_final = x_f32[finite_mask]
    y_final = y_f32[finite_mask]

    if x_final.numel() == 0:
        return

    diff = torch.abs(x_final - y_final)
    tolerance = atol + (rtol * torch.abs(y_final))

    if torch.any(diff > tolerance):
        max_diff = diff.max().item()
        worst = torch.argmax(diff).item()
        # Boolean-mask selection and torch.nonzero both walk the tensor in
        # row-major order, so this maps the filtered position back to the
        # coordinates in the original tensor.
        max_idx = tuple(torch.nonzero(finite_mask)[worst].tolist())

        # Avoid dividing by zero when reporting the relative error.
        y_safe = y_final.clone()
        y_safe[y_safe == 0] = 1e-12
        rel_error = (diff / torch.abs(y_safe)).max().item()

        raise AssertionError(
            f"[{label}] Tensors not close!\n"
            f"  Max absolute diff: {max_diff:.6e} at index {max_idx}\n"
            f"  Max relative error: {rel_error:.6e}\n"
            f"  Tolerance: atol={atol}, rtol={rtol}\n"
            f"  Shape: {x.shape}")


class TestMoeReRouting(TestCase):
    """Tests for ``torch_npu.npu_moe_re_routing`` on fp16, bf16, hifloat8 and fp4 tokens.

    NOTE(review): the reference model in ``golden_calc`` and the input
    contract enforced by ``generate_inputs`` follow the public operator
    documentation (tokens arrive grouped by rank then expert and leave
    grouped by expert then rank; the count table sums to the token count).
    They could not be run against hardware here - confirm on a device.
    """

    def generate_inputs(self, bs, hidden_dim, dtype, expert_num=16, rank_num=2):
        """Build random tokens plus a ``(rank_num, expert_num)`` token-count table.

        Args:
            bs: number of token rows.
            hidden_dim: logical hidden size; fp4 tokens use ``hidden_dim // 2``
                columns.
            dtype: ``torch_npu.float4_e2m1fn_x2`` or ``torch_npu.hifloat8``
                select the matching NumPy low-precision dtype; anything else
                yields ``np.float16`` tokens (callers cast further if needed).
            expert_num: number of columns of the count table.
            rank_num: number of rows of the count table.

        Returns:
            ``(tokens, expert_token_num_per_rank)`` as NumPy arrays; the count
            table is ``int64``, strictly positive and sums to ``bs``.

        Raises:
            ValueError: if ``bs`` is too small to give every (rank, expert)
                cell at least one token.
        """
        cell_num = rank_num * expert_num
        if bs < cell_num:
            raise ValueError(
                f"bs ({bs}) must be at least rank_num * expert_num ({cell_num}) "
                "so that every (rank, expert) cell owns a token")

        # Give every cell one token, then drop the remaining tokens into random
        # cells, so the table always describes exactly bs tokens.
        extra_cells = np.random.randint(0, cell_num, size=bs - cell_num)
        counts = 1 + np.bincount(extra_cells, minlength=cell_num)
        expert_token_num_per_rank = counts.reshape(rank_num, expert_num).astype(np.int64)

        if dtype == torch_npu.float4_e2m1fn_x2:
            fp4_dtype = numpy_float4_e2m1fn()
            tokens = np.random.randn(bs, hidden_dim // 2).astype(fp4_dtype)
        elif dtype == torch_npu.hifloat8:
            tokens = np.random.randn(bs, hidden_dim).astype(np.float16)
            tokens = tokens.astype(numpy_hifloat8())
        else:
            tokens = np.random.randn(bs, hidden_dim).astype(np.float16)

        return tokens, expert_token_num_per_rank

    def custom_op_exec(self, tokens_npu, expert_token_num_per_rank_npu, per_token_scales_npu,
                       expert_token_num_type, idx_type, tokens_dtype=None):
        """Call ``torch_npu.npu_moe_re_routing`` and return its result unchanged."""
        return torch_npu.npu_moe_re_routing(
            tokens_npu,
            expert_token_num_per_rank_npu,
            per_token_scales=per_token_scales_npu,
            expert_token_num_type=expert_token_num_type,
            idx_type=idx_type,
            tokens_dtype=tokens_dtype
        )

    def golden_calc(self, tokens, expert_token_num_per_rank, expert_token_num_type=1):
        """Compute the CPU reference for re-routing ``tokens``.

        The input rows are read as consecutive segments ordered by rank, then
        by expert; the output concatenates the same segments ordered by
        expert, then by rank.

        Args:
            tokens: NumPy array with one token per row.
            expert_token_num_per_rank: ``(rank_num, expert_num)`` integer table
                with the length of every (rank, expert) segment.
            expert_token_num_type: ``0`` returns the per-expert totals as a
                running sum, any other value returns plain totals.
                NOTE(review): the 0 = cumsum / 1 = count meaning is taken from
                the operator documentation - confirm.

        Returns:
            ``(permute_tokens, permute_per_token_scales, permute_token_idx,
            expert_token_num)``. ``permute_token_idx`` is the ``int32`` source
            row of every output row; ``permute_per_token_scales`` is all ones
            because this helper receives no scales.

        Raises:
            ValueError: if the count table does not sum to the number of rows.
        """
        counts = np.asarray(expert_token_num_per_rank, dtype=np.int64)
        rank_num, expert_num = counts.shape
        bs = tokens.shape[0]
        if int(counts.sum()) != bs:
            raise ValueError(
                f"expert_token_num_per_rank sums to {int(counts.sum())}, expected {bs} tokens")

        # First input row of every (rank, expert) segment in rank-major order.
        flat_counts = counts.reshape(-1)
        seg_start = (np.cumsum(flat_counts) - flat_counts).reshape(rank_num, expert_num)

        # Walk the segments expert-major to obtain the gather order.
        segments = [
            np.arange(seg_start[rank, expert], seg_start[rank, expert] + counts[rank, expert])
            for expert in range(expert_num)
            for rank in range(rank_num)
        ]
        permute_token_idx = np.concatenate(segments).astype(np.int32)
        permute_tokens = tokens[permute_token_idx]
        permute_per_token_scales = np.ones((bs,), dtype=np.float32)

        expert_token_num = counts.sum(axis=0)
        if expert_token_num_type == 0:
            expert_token_num = np.cumsum(expert_token_num)

        return permute_tokens, permute_per_token_scales, permute_token_idx, expert_token_num

    @SupportedDevices(['Ascend910B'])
    def test_npu_moe_re_routing_fp16(self, device="npu"):
        """fp16 tokens: check output shape, re-routed rows and per-expert totals."""
        bs_list = [32, 128]
        hidden_dim_list = [4096]
        dtype_list = [torch.float16]

        # product, not zip: zip stops at the shortest list and would skip bs=128.
        for bs, hidden_dim, dtype in itertools.product(bs_list, hidden_dim_list, dtype_list):
            tokens, expert_token_num_per_rank = self.generate_inputs(bs, hidden_dim, dtype)
            tokens_npu = torch.from_numpy(tokens).to(dtype).npu()
            expert_token_num_per_rank_npu = torch.from_numpy(expert_token_num_per_rank).npu()

            permute_tokens, permute_scales, permute_idx, expert_token_num = \
                self.custom_op_exec(tokens_npu, expert_token_num_per_rank_npu, None, 1, 0)

            golden_tokens, _, _, golden_expert_num = \
                self.golden_calc(tokens, expert_token_num_per_rank)

            self.assertEqual(permute_tokens.shape[0], bs)
            self.assertEqual(permute_tokens.shape[1], hidden_dim)

            # Rows are only moved, never recomputed, so the match must be exact.
            assert_tensors_close(permute_tokens, torch.from_numpy(golden_tokens), rtol=0, atol=0,
                                 label=f"fp16_tokens_bs={bs}_h={hidden_dim}")
            self.assertEqual(expert_token_num.cpu().tolist(), golden_expert_num.tolist())

    @SupportedDevices(['Ascend910B'])
    def test_npu_moe_re_routing_bf16(self, device="npu"):
        """bf16 tokens: check output shape and dtype."""
        bs_list = [32]
        hidden_dim_list = [4096]
        dtype_list = [torch.bfloat16]

        for bs, hidden_dim, dtype in itertools.product(bs_list, hidden_dim_list, dtype_list):
            tokens, expert_token_num_per_rank = self.generate_inputs(bs, hidden_dim, dtype)
            # generate_inputs yields float16 data; cast so the op really sees bf16.
            tokens_npu = torch.from_numpy(tokens).to(dtype).npu()
            expert_token_num_per_rank_npu = torch.from_numpy(expert_token_num_per_rank).npu()

            permute_tokens, permute_scales, permute_idx, expert_token_num = \
                self.custom_op_exec(tokens_npu, expert_token_num_per_rank_npu, None, 1, 0)

            # NOTE(review): assumes the op keeps the token dtype - confirm.
            self.assertEqual(permute_tokens.dtype, dtype)
            self.assertEqual(permute_tokens.shape[0], bs)
            self.assertEqual(permute_tokens.shape[1], hidden_dim)

    @unittest.skipIf(
        importlib.util.find_spec("en_dtypes") is None,
        "Unittest for hif8 need package en_dtypes"
    )
    @SupportedDevices(['Ascend950'])
    def test_npu_moe_re_routing_hif8(self, device="npu"):
        """hifloat8 tokens passed as raw bytes: check dtype, shape and re-routed bytes."""
        bs_list = [32, 128]
        hidden_dim_list = [4096, 7168]
        expert_token_num_type_list = [1]
        idx_type_list = [0, 1]

        # product, not zip: zip stops at the shortest list and would run one case only.
        for bs, hidden_dim, expert_token_num_type, idx_type in itertools.product(
                bs_list, hidden_dim_list, expert_token_num_type_list, idx_type_list):
            tokens, expert_token_num_per_rank = self.generate_inputs(bs, hidden_dim, torch_npu.hifloat8)

            tokens_uint8 = tokens.view(np.uint8)
            tokens_npu = torch.from_numpy(tokens_uint8).npu()
            expert_token_num_per_rank_npu = torch.from_numpy(expert_token_num_per_rank).npu()

            permute_tokens, permute_scales, permute_idx, expert_token_num = \
                self.custom_op_exec(tokens_npu, expert_token_num_per_rank_npu, None,
                                    expert_token_num_type, idx_type, tokens_dtype=290)

            self.assertEqual(permute_tokens.dtype, torch.uint8)
            self.assertEqual(permute_tokens.shape[0], bs)
            self.assertEqual(permute_tokens.shape[1], hidden_dim)

            # Build the reference on the same raw bytes the op received.
            golden_tokens, _, _, golden_expert_num = \
                self.golden_calc(tokens_uint8, expert_token_num_per_rank, expert_token_num_type)

            # The op output is uint8, so compare bit patterns exactly rather
            # than mixing raw bytes with decoded hifloat8 values.
            assert_tensors_close(permute_tokens.cpu(), torch.from_numpy(golden_tokens), rtol=0, atol=0,
                                 label=f"hif8_tokens_bs={bs}_h={hidden_dim}")
            self.assertEqual(expert_token_num.cpu().tolist(), golden_expert_num.tolist())

    @unittest.skipIf(
        importlib.util.find_spec("en_dtypes") is None,
        "Unittest for hif8 need package en_dtypes"
    )
    @SupportedDevices(['Ascend950'])
    def test_npu_moe_re_routing_hif8_with_scales(self, device="npu"):
        """hifloat8 tokens with per-token scales: check output dtypes and shape."""
        bs_list = [32]
        hidden_dim_list = [4096]
        expert_token_num_type = 1
        idx_type = 0

        for bs, hidden_dim in itertools.product(bs_list, hidden_dim_list):
            tokens, expert_token_num_per_rank = self.generate_inputs(bs, hidden_dim, torch_npu.hifloat8)

            per_token_scales = np.random.randn(bs).astype(np.float32)

            tokens_uint8 = tokens.view(np.uint8)
            tokens_npu = torch.from_numpy(tokens_uint8).npu()
            expert_token_num_per_rank_npu = torch.from_numpy(expert_token_num_per_rank).npu()
            per_token_scales_npu = torch.from_numpy(per_token_scales).npu()

            permute_tokens, permute_scales, permute_idx, expert_token_num = \
                self.custom_op_exec(tokens_npu, expert_token_num_per_rank_npu, per_token_scales_npu,
                                    expert_token_num_type, idx_type, tokens_dtype=290)

            self.assertEqual(permute_tokens.dtype, torch.uint8)
            self.assertEqual(permute_scales.dtype, torch.float32)
            self.assertEqual(permute_idx.dtype, torch.int32)

            self.assertEqual(permute_tokens.shape[0], bs)
            self.assertEqual(permute_tokens.shape[1], hidden_dim)

    @unittest.skipIf(
        importlib.util.find_spec("ml_dtypes") is None,
        "Unittest for fp4 need package ml_dtypes"
    )
    @SupportedDevices(['Ascend950'])
    def test_npu_moe_re_routing_fp4_e2m1(self, device="npu"):
        """fp4 tokens passed as raw bytes: check output dtype and shape."""
        bs_list = [32, 128]
        hidden_dim_list = [4096, 7168]
        expert_token_num_type = 1
        idx_type = 0

        for bs, hidden_dim in itertools.product(bs_list, hidden_dim_list):
            tokens, expert_token_num_per_rank = self.generate_inputs(bs, hidden_dim, torch_npu.float4_e2m1fn_x2)

            tokens_uint8 = tokens.view(np.uint8)
            tokens_npu = torch.from_numpy(tokens_uint8).npu()
            expert_token_num_per_rank_npu = torch.from_numpy(expert_token_num_per_rank).npu()

            permute_tokens, permute_scales, permute_idx, expert_token_num = \
                self.custom_op_exec(tokens_npu, expert_token_num_per_rank_npu, None,
                                    expert_token_num_type, idx_type, tokens_dtype=296)

            self.assertEqual(permute_tokens.dtype, torch.uint8)
            self.assertEqual(permute_tokens.shape[0], bs)
            self.assertEqual(permute_tokens.shape[1], hidden_dim // 2)

    @unittest.skipIf(
        importlib.util.find_spec("ml_dtypes") is None,
        "Unittest for fp4 need package ml_dtypes"
    )
    @SupportedDevices(['Ascend950'])
    def test_npu_moe_re_routing_fp4_e2m1_with_scales(self, device="npu"):
        """fp4 tokens with per-token scales: check output dtypes and shape."""
        bs_list = [32]
        hidden_dim_list = [4096]
        expert_token_num_type = 1
        idx_type = 0

        for bs, hidden_dim in itertools.product(bs_list, hidden_dim_list):
            tokens, expert_token_num_per_rank = self.generate_inputs(bs, hidden_dim, torch_npu.float4_e2m1fn_x2)

            per_token_scales = np.random.randn(bs).astype(np.float32)

            tokens_uint8 = tokens.view(np.uint8)
            tokens_npu = torch.from_numpy(tokens_uint8).npu()
            expert_token_num_per_rank_npu = torch.from_numpy(expert_token_num_per_rank).npu()
            per_token_scales_npu = torch.from_numpy(per_token_scales).npu()

            permute_tokens, permute_scales, permute_idx, expert_token_num = \
                self.custom_op_exec(tokens_npu, expert_token_num_per_rank_npu, per_token_scales_npu,
                                    expert_token_num_type, idx_type, tokens_dtype=296)

            self.assertEqual(permute_tokens.dtype, torch.uint8)
            self.assertEqual(permute_scales.dtype, torch.float32)
            self.assertEqual(permute_idx.dtype, torch.int32)

            self.assertEqual(permute_tokens.shape[0], bs)
            self.assertEqual(permute_tokens.shape[1], hidden_dim // 2)


# Run the tests through the torch_npu runner when this file is executed as a script.
if __name__ == "__main__":
    run_tests()