import itertools

import torch
from torch.distributed._tensor import distribute_tensor, Replicate, Shard
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

import torch_npu
from torch_npu.testing._internal.common_dtensor import NPUDTensorTestBase
from torch_npu.testing.common_distributed import with_comms, skipIfUnsupportMultiNPU


def get_shape_from_layout(batch: int, num_head: int, seq_length: int, dimension: int, layout: str):
    """Translate a layout string such as ``"BNSD"`` into a concrete shape tuple.

    Each character of ``layout`` selects one axis size, in order:
    ``"B"`` -> ``batch``, ``"N"`` -> ``num_head``, ``"S"`` -> ``seq_length``,
    ``"D"`` -> ``dimension`` and ``"1"`` -> a size-1 axis.

    Returns:
        A tuple with one entry per character of ``layout``.

    Raises:
        ValueError: if ``layout`` contains a character other than
            ``B``, ``N``, ``S``, ``D`` or ``1``.
    """
    axis_sizes = dict(zip("BNSD1", (batch, num_head, seq_length, dimension, 1)))

    # Validate up front so the first unknown character is the one reported.
    for axis in layout:
        if axis not in axis_sizes:
            raise ValueError(f"Invalid layout character: {axis}")

    return tuple(axis_sizes[axis] for axis in layout)


class TestMathOps(NPUDTensorTestBase):
    """DTensor parity tests for NPU normalization and rotary-embedding ops.

    Every test computes a reference result on plain tensors, repeats the same
    call on DTensors built with ``distribute_tensor``, and asserts that the
    gathered (``full_tensor()``) outputs and gradients equal the reference.
    """

    # (rotary_mode, input_layout, sin_cos_layout) cases shared by the rotary
    # forward and backward tests. Layout strings are decoded by
    # get_shape_from_layout; a "1" marks a size-1 axis in sin/cos.
    _ROTARY_MUL_CASES = (
        ("half", "BNSD", "11SD"),
        ("half", "BNSD", "B1SD"),
        ("half", "BNSD", "BNSD"),
        ("half", "BSND", "1S1D"),
        ("half", "BSND", "BS1D"),
        ("half", "BSND", "BSND"),
        ("half", "SBND", "S11D"),
        ("half", "SBND", "SB1D"),
        ("half", "SBND", "SBND"),
        ("interleave", "BNSD", "11SD"),
        ("interleave", "BSND", "1S1D"),
        ("interleave", "SBND", "S11D"),
    )

    @staticmethod
    def _rotary_placement_combos(sin_cos_shape):
        """Return the (x, sin, cos) placement lists swept by the rotary tests.

        x is placed with Shard(0), Shard(1), Shard(2) and Replicate() in turn.
        sin/cos use the same placement as x, except when x is sharded along an
        axis where sin/cos have size 1; in that case sin/cos are replicated.
        """
        combos = []
        for placement in (Shard(0), Shard(1), Shard(2), Replicate()):
            if isinstance(placement, Shard) and sin_cos_shape[placement.dim] == 1:
                sin_cos_placement = Replicate()
            else:
                sin_cos_placement = placement
            combos.append(([placement], [sin_cos_placement], [sin_cos_placement]))
        return combos

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_npu_rms_norm_forward(self):
        """npu_rms_norm forward on a Shard(1) input matches the plain-tensor result."""
        device_mesh = self.build_device_mesh()

        x = torch.randn((1, 128, 64), dtype=torch.float32).npu()
        gamma = torch.randn(64, dtype=torch.float32).npu()

        y, rstd = torch_npu.npu_rms_norm(x, gamma)

        dist_x = distribute_tensor(x, device_mesh, [Shard(1)])
        dist_gamma = distribute_tensor(gamma, device_mesh, [Replicate()])

        dist_y, dist_rstd = torch_npu.npu_rms_norm(dist_x, dist_gamma)

        self.assertEqual(dist_y.full_tensor(), y)
        # Check the second output as well; previously it was computed but
        # never compared against the reference.
        self.assertEqual(dist_rstd.full_tensor(), rstd)
        self.assertEqual(dist_gamma.full_tensor(), gamma)

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_npu_rms_norm_backward(self):
        """npu_rms_norm outputs and input/gamma gradients match with a Shard(2) input."""
        device_mesh = self.build_device_mesh()

        x = torch.randn((1, 128, 64), dtype=torch.float32).npu()
        gamma = torch.randn(64, dtype=torch.float32).npu()
        grad_y = torch.randn((1, 128, 64), dtype=torch.float32).npu()
        x.requires_grad_(True)
        gamma.requires_grad_(True)

        # Reference forward/backward on plain tensors.
        y, rstd = torch_npu.npu_rms_norm(x, gamma, epsilon=1e-06)
        y.backward(grad_y)
        dx = x.grad
        dw = gamma.grad

        dist_x = distribute_tensor(x, device_mesh, [Shard(2)])
        dist_gamma = distribute_tensor(gamma, device_mesh, [Replicate()])

        dist_y, dist_rstd = torch_npu.npu_rms_norm(dist_x, dist_gamma, epsilon=1e-06)
        # The upstream gradient is laid out exactly like the output it feeds.
        dist_grad_y = distribute_tensor(grad_y, device_mesh, dist_y.placements)
        dist_y.backward(dist_grad_y)
        dist_dx = dist_x.grad
        dist_dw = dist_gamma.grad

        self.assertEqual(dist_y.full_tensor(), y)
        self.assertEqual(dist_rstd.full_tensor(), rstd)
        self.assertEqual(dist_gamma.full_tensor(), gamma)

        self.assertEqual(dist_dx.full_tensor(), dx)
        self.assertEqual(dist_dw.full_tensor(), dw)

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    def test_npu_add_rms_norm_forward(self):
        """npu_add_rms_norm matches the reference for every x1/x2 placement pair."""
        device_mesh = self.build_device_mesh()

        x1 = torch.randn((1, 128, 64), dtype=torch.float32).npu()
        x2 = torch.randn((1, 128, 64), dtype=torch.float32).npu()
        gamma = torch.randn(64, dtype=torch.float32).npu()

        y, rstd, x = torch_npu.npu_add_rms_norm(x1, x2, gamma)

        def check_placement_comb(placements1, placements2):
            # gamma is always replicated; only the two addends vary.
            dist_x1 = distribute_tensor(x1, device_mesh, placements1)
            dist_x2 = distribute_tensor(x2, device_mesh, placements2)
            dist_gamma = distribute_tensor(gamma, device_mesh, [Replicate()])
            dist_y, dist_rstd, dist_x = torch_npu.npu_add_rms_norm(dist_x1, dist_x2, dist_gamma)
            self.assertEqual(dist_y.full_tensor(), y)
            self.assertEqual(dist_rstd.full_tensor(), rstd)
            self.assertEqual(dist_x.full_tensor(), x)

        placement = [Shard(0), Shard(1), Shard(2), Replicate()]
        # Full cross product: x1 and x2 may be placed independently.
        for placement1, placement2 in itertools.product(placement, placement):
            check_placement_comb([placement1], [placement2])

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    @parametrize("rotary_mode,input_layout,sin_cos_layout", _ROTARY_MUL_CASES)
    def test_npu_rotary_mul_forward(self, rotary_mode, input_layout, sin_cos_layout):
        """npu_rotary_mul forward matches the reference across the placement sweep."""
        device_mesh = self.build_device_mesh()

        B = 8
        N = 8
        S = 64
        D = 32
        x_shape = get_shape_from_layout(B, N, S, D, input_layout)
        x = torch.randn(x_shape, dtype=torch.float32, device="npu")
        sin_cos_shape = get_shape_from_layout(B, N, S, D, sin_cos_layout)
        sin = torch.randn(sin_cos_shape, dtype=torch.float32, device="npu") * 2 - 1
        cos = torch.randn(sin_cos_shape, dtype=torch.float32, device="npu") * 2 - 1

        y = torch_npu.npu_rotary_mul(x, cos, sin, rotary_mode=rotary_mode)

        def check_placement_comb(x_placements, sin_placements, cos_placements):
            dist_x = distribute_tensor(x, device_mesh, x_placements)
            dist_sin = distribute_tensor(sin, device_mesh, sin_placements)
            dist_cos = distribute_tensor(cos, device_mesh, cos_placements)
            dist_y = torch_npu.npu_rotary_mul(dist_x, dist_cos, dist_sin, rotary_mode=rotary_mode)
            self.assertEqual(dist_y.full_tensor(), y)

        for combo in self._rotary_placement_combos(sin_cos_shape):
            check_placement_comb(*combo)

    @skipIfUnsupportMultiNPU(4)
    @with_comms
    @parametrize("rotary_mode,input_layout,sin_cos_layout", _ROTARY_MUL_CASES)
    def test_npu_rotary_mul_backward(self, rotary_mode, input_layout, sin_cos_layout):
        """npu_rotary_mul output and x/sin/cos gradients match across the placement sweep."""
        device_mesh = self.build_device_mesh()

        B = 8
        N = 8
        S = 64
        D = 32
        x_shape = get_shape_from_layout(B, N, S, D, input_layout)
        x = torch.randn(x_shape, dtype=torch.float32, device="npu", requires_grad=True)
        sin_cos_shape = get_shape_from_layout(B, N, S, D, sin_cos_layout)
        sin = torch.randn(sin_cos_shape, dtype=torch.float32, device="npu") * 2 - 1
        cos = torch.randn(sin_cos_shape, dtype=torch.float32, device="npu") * 2 - 1
        sin.requires_grad = True
        cos.requires_grad = True

        # Reference forward/backward on plain tensors; x.grad, sin.grad and
        # cos.grad are filled once here and compared in every sweep iteration.
        y = torch_npu.npu_rotary_mul(x, cos, sin, rotary_mode=rotary_mode)
        grad_y = torch.ones_like(y, dtype=torch.float32, device="npu")
        y.backward(grad_y)

        def check_placement_comb(x_placements, sin_placements, cos_placements):
            dist_x = distribute_tensor(x, device_mesh, x_placements)
            dist_sin = distribute_tensor(sin, device_mesh, sin_placements)
            dist_cos = distribute_tensor(cos, device_mesh, cos_placements)
            dist_y = torch_npu.npu_rotary_mul(dist_x, dist_cos, dist_sin, rotary_mode=rotary_mode)
            # The upstream gradient is laid out exactly like the output it feeds.
            dist_grad_y = distribute_tensor(grad_y, device_mesh, dist_y.placements)
            dist_y.backward(dist_grad_y)
            self.assertEqual(dist_y.full_tensor(), y)
            self.assertEqual(dist_x.grad.full_tensor(), x.grad)
            self.assertEqual(dist_sin.grad.full_tensor(), sin.grad)
            self.assertEqual(dist_cos.grad.full_tensor(), cos.grad)

        for combo in self._rotary_placement_combos(sin_cos_shape):
            check_placement_comb(*combo)


# Expand the @parametrize-decorated methods of TestMathOps into concrete,
# individually named test methods (one per parameter tuple).
instantiate_parametrized_tests(TestMathOps)


# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()