import math
import unittest

import numpy as np
import torch

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestNPUGatherSparseIndex(TestCase):
    """Check ``torch_npu.npu_gather_sparse_index`` against an embedding-lookup reference.

    Each case gathers rows of a random 2-D ``inputs`` tensor with a random
    integer ``index`` tensor and compares the result (and, in the backward test,
    the gradient w.r.t. ``inputs``) with ``torch.nn.functional.embedding``.
    """

    def generate_input_shape(self, dtype):
        """Return a random ``[H, W]`` shape holding at least 150 KiB of ``dtype`` elements.

        ``H`` is drawn uniformly from ``[1, limit]``, where ``limit`` is the number of
        ``dtype`` elements that fit in 150 KiB; ``W`` is then chosen so that
        ``H * W >= limit``.
        """
        # Integer bound for np.random.randint (the original passed a float).
        # Exact for itemsizes 1/2/4/8, and rounds up otherwise so the size stays >= 150 KiB.
        min_size_limit = math.ceil(150 * 1024 / dtype.itemsize)

        H = np.random.randint(low=1, high=min_size_limit + 1)
        W = math.ceil(min_size_limit / H)
        return [H, W]

    def generate_index_shape(self, dim):
        """Return a random index shape of rank ``dim``.

        Rank 1 always yields ``[961]``. For higher ranks the leading dims are random
        and the last dim is sized so that the total element count is at least 960.

        Raises:
            ValueError: if ``dim`` is smaller than 1.
        """
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")
        if dim == 1:
            return [961, ]

        min_size_limit = 960
        index_shape = []
        for _ in range(dim - 1):
            shape_value = np.random.randint(low=1, high=min_size_limit + 1)
            index_shape.append(shape_value)
            # Remaining factor still needed so that prod(index_shape) * limit >= 960.
            min_size_limit = math.ceil(min_size_limit / shape_value)

        # min_size_limit has already been divided by every chosen dim, so it is
        # exactly the remaining factor. Dividing by index_shape[-1] again would
        # shrink the index far below 960 elements.
        index_shape.append(min_size_limit)
        return index_shape

    def _generate_input(self, shape, dtype):
        """Return a random CPU tensor of ``shape`` and ``dtype``.

        ``torch.randn`` only has floating-point kernels, so integer and bool
        inputs are drawn with ``torch.randint`` instead.
        """
        if dtype == torch.bool:
            return torch.randint(0, 2, shape).to(torch.bool)
        if dtype.is_floating_point:
            return torch.randn(shape, dtype=dtype)
        # Keep values representable in every integer dtype under test (int8, uint8, ...).
        low = 0 if dtype == torch.uint8 else -100
        return torch.randint(low, 100, shape, dtype=dtype)

    def _prepare_inputs(self, dim, dtype):
        """Build identical golden/NPU copies of a random weight and index on the NPU.

        Returns:
            ``(inputs_golden, inputs_npu, index_golden, index_npu)``. The index values
            lie in ``[0, inputs.shape[0])``.
        """
        input_shape = self.generate_input_shape(dtype)
        index_shape = self.generate_index_shape(dim)
        H = input_shape[0]

        inputs_golden = self._generate_input(input_shape, dtype).npu()
        inputs_npu = inputs_golden.clone()
        index_golden = torch.randint(0, H, index_shape).npu()
        index_npu = index_golden.clone()
        return inputs_golden, inputs_npu, index_golden, index_npu

    def golden_function(self, weight, index):
        """Reference forward: gather rows of the 2-D ``weight`` at ``index`` via embedding lookup."""
        num_embeddings, _ = weight.shape
        self.assertTrue(
            index.max() < num_embeddings,
            f"index should be less than inputs.shape[0], get {index.max()} and {num_embeddings}")

        return torch.nn.functional.embedding(index, weight)

    def golden_function_backward(self, weight, index):
        """Reference forward + backward of an embedding lookup with a ``sum()`` loss.

        Returns:
            ``(output, grad)``, where ``grad`` is d(output.sum())/d(weight). ``weight``
            itself is not modified; the gradient is taken on a detached alias of it.
        """
        num_embeddings, _ = weight.shape
        self.assertTrue(
            index.max() < num_embeddings,
            f"index should be less than inputs.shape[0], get {index.max()} and {num_embeddings}")

        leaf = weight.detach().requires_grad_(True)
        output = torch.nn.functional.embedding(index, leaf)
        output.sum().backward()
        return output, leaf.grad

    @SupportedDevices(['Ascend910B'])
    def test_npu_gather_sparse_index(self):
        dim_list = [1, 2, 3, 4, 5, 6]
        dtype_list = [torch.float, torch.half, torch.bfloat16, torch.int32,
                      torch.int64, torch.int8, torch.uint8, torch.bool, torch.double]
        cases = [(dim, dtype) for dim in dim_list for dtype in dtype_list]
        for dim, dtype in cases:
            # subTest reports which (dim, dtype) combination failed.
            with self.subTest(dim=dim, dtype=dtype):
                inputs_golden, inputs_npu, index_golden, index_npu = self._prepare_inputs(dim, dtype)

                npu_out = torch_npu.npu_gather_sparse_index(inputs_npu, index_npu)
                golden_out = self.golden_function(inputs_golden, index_golden)

                self.assertEqual(npu_out, golden_out)

    @SupportedDevices(['Ascend910B'])
    def test_npu_gather_sparse_index_backward(self):
        dim_list = [1]
        dtype_list = [torch.float, torch.half, torch.bfloat16]
        cases = [(dim, dtype) for dim in dim_list for dtype in dtype_list]
        for dim, dtype in cases:
            with self.subTest(dim=dim, dtype=dtype):
                inputs_golden, inputs_npu, index_golden, index_npu = self._prepare_inputs(dim, dtype)

                inputs_npu.requires_grad = True
                npu_out = torch_npu.npu_gather_sparse_index(inputs_npu, index_npu)

                # Ones as the upstream gradient match the golden's sum() loss.
                npu_out.backward(torch.ones_like(npu_out))
                inputs_npu_grad = inputs_npu.grad

                golden_out, inputs_golden_grad = self.golden_function_backward(inputs_golden, index_golden)

                self.assertEqual(npu_out, golden_out)
                self.assertEqual(inputs_npu_grad, inputs_golden_grad)


def _main():
    """Run this module's tests through ``run_tests`` from ``torch_npu.testing.testcase``."""
    run_tests()


if __name__ == "__main__":
    _main()