import unittest

import numpy as np
import torch
import torch_npu
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

import mx_driving.point
from mx_driving import Voxelization


class TestHardVoxelize(TestCase):
    """Compare the NPU ``Voxelization`` operator with a NumPy reference model.

    Random points are voxelized twice on the NPU (through
    ``mx_driving.Voxelization`` and ``mx_driving.point.Voxelization``) and once
    on the CPU by ``golden_hard_voxelize``; the test asserts that both NPU
    results agree with the CPU reference.
    """

    seed = 1024
    point_nums = [1, 7, 6134, 99999]

    # Settings shared by the NPU operator and the reference model. They live in
    # one place so the two implementations cannot silently drift apart.
    voxel_size = (0.075, 0.075, 0.2)
    point_cloud_range = (-54, -54, -5, 54, 54, 5)
    # Third positional argument of Voxelization. The reference model does not
    # use it (presumably the per-voxel point cap -- confirm against mx_driving).
    max_num_points = 10
    # Fourth positional argument of Voxelization; the reference model stops
    # after collecting this many distinct voxels.
    max_voxels = 1000
    # Cells per axis: range extent / voxel size (108/0.075, 108/0.075, 10/0.2).
    grid_size = (1440, 1440, 50)

    # Seeded once, when the class body is executed (i.e. at import time).
    np.random.seed(seed)

    @golden_data_cache(__file__)
    def gen(self, point_num):
        """Return ``point_num`` random points as a ``(point_num, 3)`` float64 array.

        x and y are drawn uniformly from [-54, 54) and z from [-5, 5).
        """
        x = 108 * np.random.rand(point_num) - 54
        y = 108 * np.random.rand(point_num) - 54
        z = 10 * np.random.rand(point_num) - 5
        return np.stack([x, y, z], axis=-1)

    def npu_hard_voxelize(self, points):
        """Run both ``Voxelization`` entry points on the NPU for ``points``.

        The points are converted to float32 before being moved to the device.

        Returns:
            ``(cnt, voxs, cnt1, voxs1)`` where ``cnt``/``voxs`` are the first
            and third outputs of ``mx_driving.point.Voxelization`` and
            ``cnt1``/``voxs1`` those of ``mx_driving.Voxelization``. The
            ``voxs*`` values are copied to the host as NumPy arrays; the
            ``cnt*`` values are returned as produced by the operator.
        """
        points_npu = torch.from_numpy(points.astype(np.float32)).npu()

        def build(voxelization_cls):
            # Fresh list copies per instance so the operators never share (or
            # mutate) the class-level settings.
            return voxelization_cls(
                list(self.voxel_size),
                list(self.point_cloud_range),
                self.max_num_points,
                self.max_voxels,
            )

        # Only the first and third outputs are checked by this test.
        cnt1, _, voxs1, _ = build(Voxelization)(points_npu)
        cnt, _, voxs, _ = build(mx_driving.point.Voxelization)(points_npu)
        return cnt, voxs.cpu().numpy(), cnt1, voxs1.cpu().numpy()

    @golden_data_cache(__file__)
    def golden_hard_voxelize(self, points):
        """CPU reference: distinct voxel coordinates in first-seen order.

        Each point is mapped to an integer cell with
        ``floor((p - range_min) / voxel_size)`` computed in float64. Points
        whose cell falls outside ``grid_size`` are dropped, duplicates are
        removed keeping the first occurrence, and collection stops once
        ``max_voxels`` distinct cells have been found.

        Returns:
            ``(count, coords)`` where ``count`` is a Python ``int`` and
            ``coords`` is a ``(count, 3)`` int32 array of ``[x, y, z]`` cells.
        """
        lower = np.array(self.point_cloud_range[:3], dtype=np.float64)
        cell = np.array(self.voxel_size, dtype=np.float64)
        grid = np.array(self.grid_size, dtype=np.int32)

        xyz = points.astype(np.float64)[:, :3]
        coords = np.floor((xyz - lower) / cell).astype(np.int32)
        inside = np.all((coords >= 0) & (coords < grid), axis=1)

        # A tuple key is used instead of a bit-packed integer so that
        # deduplication does not depend on the grid fitting a fixed bit layout.
        seen = set()
        voxels = []
        for voxel in map(tuple, coords[inside].tolist()):
            if voxel in seen:
                continue
            seen.add(voxel)
            voxels.append(voxel)
            if len(voxels) == self.max_voxels:
                break

        # Explicit dtype/shape keep the result (N, 3) int32 even when N == 0.
        return len(voxels), np.array(voxels, dtype=np.int32).reshape(-1, 3)

    def test_hard_voxelize(self):
        """Both NPU entry points must match the CPU reference for every size."""
        for point_num in self.point_nums:
            points = self.gen(point_num)
            cnt_cpu, res_cpu = self.golden_hard_voxelize(points)
            cnt_npu, res_npu, cnt_npu1, res_npu1 = self.npu_hard_voxelize(points)
            self.assertRtolEqual(cnt_cpu, cnt_npu)
            self.assertRtolEqual(cnt_cpu, cnt_npu1)
            self.assertRtolEqual(res_cpu, res_npu)
            self.assertRtolEqual(res_cpu, res_npu1)


# When executed as a script, hand control to the test runner imported from
# torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()