from unittest.mock import Mock

import numpy as np
import torch
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

from mx_driving.ops.group_points import AdsGroupPoints


@golden_data_cache(__file__)
def cpu_gen_inputs(B, C, N, npoints, nsample):
    np_grad_out = np.random.rand(B, C, npoints, nsample).astype(np.float32)
    np_indices = np.random.randint(0, N, (B, npoints, nsample)).astype(np.int32)
    np_grad_features = np.zeros((B, C, N)).astype(np.float32)

    return np_grad_out, np_indices, np_grad_features


class TestGroupPointsGrad(TestCase):
    @golden_data_cache(__file__)
    def golden_group_points_grad(self, np_grad_out, np_indices, np_grad_features, B, npoints, nsample):
        np_grad_out = np_grad_out.transpose(0, 2, 3, 1)
        np_grad_features = np_grad_features.transpose(0, 2, 1)

        for b in range(B):
            for npo in range(npoints):
                for nsa in range(nsample):
                    idx_offset = np_indices[b, npo, nsa]
                    np_grad_features[b, idx_offset, :] += np_grad_out[b, npo, nsa, :]

        np_grad_features = np_grad_features.transpose(0, 2, 1)
        return np_grad_features

    def test_group_points_grad(self):
        np.random.seed(50051)
        B_list = [16, 32, 64]
        C_list = [16, 31, 32, 35, 64, 512]
        N_list = [64]
        npoints_list = [16, 100]
        nsample_list = [32, 50]

        for B in B_list:
            for C in C_list:
                for N in N_list:
                    for npoints in npoints_list:
                        for nsample in nsample_list:
                            np_grad_out, np_indices, np_grad_features = cpu_gen_inputs(B, C, N, npoints, nsample)

                            torch_grad_out = torch.from_numpy(np_grad_out).npu()
                            torch_indices = torch.from_numpy(np_indices).npu()

                            golden_grad_features = self.golden_group_points_grad(
                                np_grad_out, np_indices, np_grad_features, B, npoints, nsample
                            )

                            ctx = Mock()
                            ctx.for_backwards = (torch_indices, N)
                            npu_grad_features, _ = AdsGroupPoints.backward(ctx, torch_grad_out)

                            self.assertRtolEqual(golden_grad_features, npu_grad_features.cpu().numpy())

    def test_group_points_backward_empty_grad(self):
        """反向传播异常分支测试:覆盖grad_out为空的情况"""
        grad_out = torch.empty((0, 0, 0, 0), dtype=torch.float32).npu()
        indices = torch.randint(0, 10, (2, 3, 4), dtype=torch.int32).npu()
        N = 10
        ctx = Mock()
        ctx.for_backwards = (indices, N)

        with self.assertRaises(Exception) as cm:
            AdsGroupPoints.backward(ctx, grad_out)
        self.assertEqual(str(cm.exception), "Error! Input Tensor can not be a empty Tensor.\n")


if __name__ == "__main__":
    run_tests()