import itertools

import numpy as np
import torch
import torch_npu
from torch_npu.testing.common_utils import get_npu_device
from torch_npu.testing.testcase import TestCase, run_tests


class TestRepeatInterleaveBackward(TestCase):
    """Check the NPU input gradient of ``repeat_interleave`` against a CPU autograd reference."""

    @staticmethod
    def _autograd_input_grad(inputs_data, repeats, output_grad, dim):
        """Back-propagate ``output_grad`` through ``torch.repeat_interleave``.

        Args:
            inputs_data: Tensor to differentiate with respect to; it is switched
                to ``requires_grad`` in place.
            repeats: Repeat count forwarded unchanged to ``torch.repeat_interleave``.
            output_grad: Upstream gradient handed to ``backward``.
            dim: Dimension forwarded to ``torch.repeat_interleave`` (may be ``None``).

        Returns:
            ``inputs_data.grad`` moved to CPU, cast to float32, as a numpy array.
        """
        inputs_data.requires_grad_(True)
        inputs_data.retain_grad()
        y = torch.repeat_interleave(inputs_data, repeats, dim)
        y.backward(output_grad)
        return inputs_data.grad.cpu().float().detach().numpy()

    def supported_op_exec(self, *args):
        """Compute the input gradient on the NPU.

        Args:
            *args: Exactly seven positional values, in this order:
                ``inputs_data``, ``repeats_data``, ``output_grad``, ``dim``,
                ``all_back``, ``repeat_int``, ``repeat_tensor``.
                ``all_back`` selects a direct call to the
                ``torch_npu.repeat_interleave_backward_*`` operator instead of
                running autograd. ``repeat_int`` selects the scalar-repeat
                variant (only ``repeats_data[0]`` is used). ``repeat_tensor``
                only matters on the scalar autograd path: when true the repeat
                count is passed as a 0-dim tensor, otherwise as a Python int.

        Returns:
            The input gradient as a float32 numpy array.
        """
        inputs_data, repeats_data, output_grad, dim, all_back, repeat_int, repeat_tensor = args
        # Work on detached device copies so the caller's CPU tensors (and the
        # reference gradient stored on them) are left untouched.
        inputs_data = inputs_data.detach().clone().npu()
        repeats_data = repeats_data.detach().clone().npu()
        output_grad = output_grad.detach().clone().npu()

        if all_back:
            # Call the backward operator directly; no autograd graph is built.
            if repeat_int:
                result = torch_npu.repeat_interleave_backward_int(
                    output_grad, inputs_data, repeats_data[0], dim)
            else:
                result = torch_npu.repeat_interleave_backward_tensor(
                    output_grad, inputs_data, repeats_data, dim)
            return result.cpu().float().detach().numpy()

        # Autograd path: only the form of the repeat argument differs.
        if not repeat_int:
            repeats = repeats_data
        elif repeat_tensor:
            repeats = repeats_data[0]
        else:
            repeats = int(repeats_data[0])
        return self._autograd_input_grad(inputs_data, repeats, output_grad, dim)

    def custom_op_exec(self, *args):
        """Build random CPU inputs and run the reference forward/backward pass.

        Args:
            *args: Exactly six positional values, in this order:
                ``input_shape``, ``repeat_shape``, ``axis``, ``data_type``,
                ``repeats_type``, ``repeat_int``.

        Returns:
            Tuple ``(inputs_data, repeats_data, y_grad)``. After this call
            ``inputs_data.grad`` holds the reference gradient produced by
            ``y.backward(y_grad)``.
        """
        input_shape, repeat_shape, axis, data_type, repeats_type, repeat_int = args
        inputs_data = torch.rand(input_shape, dtype=data_type, requires_grad=True)
        # torch.randint excludes the upper bound, so repeat counts lie in [2, 128].
        repeat_low, repeat_high = 2, 129
        repeats_data = torch.randint(repeat_low, repeat_high, repeat_shape, dtype=repeats_type)

        repeats = repeats_data[0] if repeat_int else repeats_data
        # The forward pass runs in float32 and the result is cast back to data_type.
        y = torch.repeat_interleave(inputs_data.to(torch.float), repeats, axis).to(data_type)

        y_grad = torch.rand(y.shape, dtype=data_type, requires_grad=True)
        y.backward(y_grad)
        return inputs_data, repeats_data, y_grad

    def test_rms_norm(self, device="npu"):
        """Compare NPU and CPU input gradients over every axis and dtype.

        NOTE(review): the method name does not describe what is tested
        (repeat_interleave backward); it is kept so existing test IDs stay
        stable. ``device`` is not used by the body.
        """
        data_type_all = (torch.half, torch.bfloat16, torch.float)
        repeats_type_all = (torch.int64,)
        test_shape_all = ((40, 1, 16),)
        # Only the scalar-repeat variant is exercised here.
        repeat_shape = (1,)
        repeat_int = True
        # The direct backward operator is called only for this exact version string.
        # NOTE(review): a build reporting a local suffix (e.g. "2.0.1+cpu") would
        # presumably not match - confirm that is intended.
        all_back = bool(torch.__version__ == "2.0.1")

        for input_shape in test_shape_all:
            # Every valid negative and positive axis, plus None.
            axis_all = list(range(-len(input_shape), len(input_shape))) + [None]
            cases = itertools.product(axis_all, data_type_all, repeats_type_all, (False, True))
            for axis, data_type, repeats_type, repeat_tensor in cases:
                # subTest reports the failing parameter combination and lets
                # the remaining combinations still run.
                with self.subTest(shape=input_shape, axis=axis, dtype=data_type,
                                  repeats_dtype=repeats_type, repeat_tensor=repeat_tensor):
                    inputs_data, repeats_data, output_grad = self.custom_op_exec(
                        input_shape, repeat_shape, axis, data_type, repeats_type, repeat_int)
                    input_grad = inputs_data.grad.float().detach().numpy()
                    result = self.supported_op_exec(inputs_data, repeats_data, output_grad,
                                                    axis, all_back,
                                                    repeat_int, repeat_tensor)

                    self.assertRtolEqual(result, input_grad)


# Run the tests through the torch_npu test runner when executed as a script.
if __name__ == "__main__":
    run_tests()