import unittest

import torch
import torch_npu
import numpy as np

from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests

from mx_driving import npu_draw_gaussian_to_heatmap


# pylint: disable=too-many-arguments,huawei-too-many-arguments
@golden_data_cache(__file__)
def golden_draw_gaussian_to_heatmap(
    boxes,
    cur_class_id_tensor,
    num_classes,
    out_size_factor,
    gaussian_overlap,
    min_radius,
    voxel_size,
    pc_range,
    feature_map_size,
    max_objs=500
):
    """CPU reference that splats one Gaussian per box onto per-class heatmaps.

    Args:
        boxes: 2-D tensor with one row per box; column 0/1 hold the x/y centre
            and column 3/4 hold the x/y extents (dx, dy).
        cur_class_id_tensor: per-box class ids, 1-based (1 is subtracted
            before indexing the heatmap).
        num_classes: number of heatmap channels.
        out_size_factor: divisor applied after the voxel-size division to
            map coordinates and extents to feature-map cells.
        gaussian_overlap: ``min_overlap`` forwarded to ``gaussian_radius``.
        min_radius: lower bound applied to the integer radius.
        voxel_size: sequence whose first two entries are the x/y voxel sizes.
        pc_range: sequence whose first two entries are the x/y origin.
        feature_map_size: ``[size_x, size_y]`` of the feature map.
        max_objs: length of ``mask`` and upper bound on processed boxes.

    Returns:
        Tuple ``(heatmap, mask, center_int, radius)``:
        ``heatmap`` is ``(num_classes, size_y, size_x)`` with the dtype of
        ``boxes``; ``mask`` is a uint8 vector of length ``max_objs`` with 1
        for each box that was drawn; ``center_int`` is the ``(N, 2)`` int32
        truncated centre; ``radius`` is the ``(N,)`` int32 clamped radius.
    """
    heatmap = boxes.new_zeros(num_classes, feature_map_size[1], feature_map_size[0])
    mask = boxes.new_zeros(max_objs, dtype=torch.uint8)

    # Box centres in feature-map cells, clamped so that truncation stays
    # inside the map.
    x, y = boxes[:, 0], boxes[:, 1]
    coord_x = (x - pc_range[0]) / voxel_size[0] / out_size_factor
    coord_y = (y - pc_range[1]) / voxel_size[1] / out_size_factor
    coord_x = torch.clamp(coord_x, min=0, max=feature_map_size[0] - 0.5)
    coord_y = torch.clamp(coord_y, min=0, max=feature_map_size[1] - 0.5)
    center = torch.cat((coord_x[:, None], coord_y[:, None]), dim=-1)
    center_int = center.int()

    # Box extents in feature-map cells drive the Gaussian radius.
    dx, dy = boxes[:, 3], boxes[:, 4]
    dx = dx / voxel_size[0] / out_size_factor
    dy = dy / voxel_size[1] / out_size_factor
    radius = gaussian_radius(dx, dy, min_overlap=gaussian_overlap)
    radius = torch.clamp_min(radius.int(), min=min_radius)

    for k in range(min(max_objs, boxes.shape[0])):
        # Degenerate boxes are skipped and keep mask[k] == 0.
        if dx[k] <= 0 or dy[k] <= 0:
            continue
        # Defensive guard; the centres were already clamped above.
        if not (0 <= center_int[k][0] <= feature_map_size[0] and 0 <= center_int[k][1] <= feature_map_size[1]):
            continue
        cur_class_id = (cur_class_id_tensor[k] - 1).long()
        # heatmap[cur_class_id] is a view, so the draw updates heatmap in place.
        draw_gaussian_to_heatmap(heatmap[cur_class_id], center[k], radius[k].item())
        mask[k] = 1

    return heatmap, mask, center_int, radius


def gaussian_radius(height, width, min_overlap=0.5):
    """Return the element-wise smallest of three candidate Gaussian radii.

    Each candidate comes from a quadratic with coefficients ``(a, b, c)``
    and is evaluated as ``(b + sqrt(b**2 - 4*a*c)) / 2``. Note that the
    result is divided by 2, not by ``2 * a``.

    Args:
        height: (N) tensor of box heights.
        width: (N) tensor of box widths.
        min_overlap: overlap factor used in the coefficients.
    Returns:
        Tensor with the element-wise minimum of the three candidates.
    """
    def candidate(a, b, c):
        discriminant = b ** 2 - 4 * a * c
        return (b + discriminant.sqrt()) / 2

    size_sum = height + width

    first = candidate(
        1,
        size_sum,
        width * height * (1 - min_overlap) / (1 + min_overlap),
    )
    second = candidate(
        4,
        2 * size_sum,
        (1 - min_overlap) * width * height,
    )
    third = candidate(
        4 * min_overlap,
        -2 * min_overlap * size_sum,
        (min_overlap - 1) * width * height,
    )
    return torch.min(torch.min(first, second), third)


def gaussian2D(shape, sigma=1):
    """Build an unnormalised 2-D Gaussian kernel centred in ``shape``.

    Args:
        shape: pair ``(rows, cols)`` giving the kernel size.
        sigma: standard deviation of the Gaussian.
    Returns:
        NumPy array of the given shape with peak value 1 at the centre;
        entries below ``eps * max`` are set to exactly 0.
    """
    half_rows, half_cols = ((extent - 1.) / 2. for extent in shape)
    row_offsets, col_offsets = np.ogrid[
        -half_rows:half_rows + 1,
        -half_cols:half_cols + 1,
    ]
    squared_dist = col_offsets * col_offsets + row_offsets * row_offsets
    kernel = np.exp(-squared_dist / (2 * sigma * sigma))

    # Flush numerically negligible tails to zero.
    cutoff = np.finfo(kernel.dtype).eps * kernel.max()
    kernel[kernel < cutoff] = 0
    return kernel


def draw_gaussian_to_heatmap(heatmap, center, radius, k=1, valid_mask=None):
    """Merge a Gaussian blob into ``heatmap`` in place via element-wise max.

    Args:
        heatmap: 2-D tensor indexed as ``[y, x]``; modified in place.
        center: pair ``(x, y)``; each value is truncated with ``int``.
        radius: integer radius; the kernel is ``2 * radius + 1`` wide.
        k: scale applied to the Gaussian before merging.
        valid_mask: optional tensor indexed like ``heatmap``; the Gaussian
            is multiplied by it before merging.
    Returns:
        The same ``heatmap`` object.
    """
    diameter = 2 * radius + 1
    kernel = gaussian2D((diameter, diameter), sigma=diameter / 6)

    cx = int(center[0])
    cy = int(center[1])
    height, width = heatmap.shape[0:2]

    # Clip the kernel footprint against the heatmap borders.
    left = min(cx, radius)
    right = min(width - cx, radius + 1)
    top = min(cy, radius)
    bottom = min(height - cy, radius + 1)

    rows = slice(cy - top, cy + bottom)
    cols = slice(cx - left, cx + right)
    heatmap_patch = heatmap[rows, cols]
    kernel_patch = torch.from_numpy(
        kernel[radius - top:radius + bottom, radius - left:radius + right]
    ).to(heatmap.device).float()

    # Nothing to draw when either patch is empty.
    if min(kernel_patch.shape) <= 0 or min(heatmap_patch.shape) <= 0:
        return heatmap

    if valid_mask is not None:
        kernel_patch = kernel_patch * valid_mask[rows, cols].float()

    # heatmap_patch is a view, so writing into it updates heatmap itself.
    torch.max(heatmap_patch, kernel_patch * k, out=heatmap_patch)
    return heatmap


class TestDrawGaussianToHeatmap(TestCase):
    """Checks the NPU heatmap operator against the CPU golden implementation."""

    seed = 1024
    # Applied once, when the class body is executed.
    torch.manual_seed(seed)

    def test_draw_gaussian_to_heatmap(self):
        """Compare NPU and golden heatmaps for several random box counts."""
        out_size_factor = 8
        gaussian_overlap = 0.1
        min_radius = 2
        voxel_size = [0.075, 0.075, 0.2]
        pc_range = [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
        feature_map_size = [180, 180]
        num_classes = 10

        # Each case is [number of boxes, values per box, max_objs].
        cases = [
            [120, 9, 120],
            [100, 9, 100],
            [80, 9, 80],
            [60, 9, 60],
            [40, 9, 40],
            [20, 9, 20],
            [5, 9, 5],
        ]

        for num_boxes, box_dim, max_obj in cases:
            boxes = 50 * torch.rand((num_boxes, box_dim), dtype=torch.float32)
            class_ids = torch.randint(
                low=1, high=num_classes + 1, size=(num_boxes,), dtype=torch.int32)

            golden_outputs = golden_draw_gaussian_to_heatmap(
                boxes, class_ids, num_classes, out_size_factor, gaussian_overlap,
                min_radius, voxel_size, pc_range, feature_map_size, max_obj)
            expected_heatmap, mask_cpu, center_int_cpu, radius_cpu = golden_outputs

            # The operator receives the golden's intermediates; the centres
            # are passed transposed and made contiguous.
            class_ids_npu = class_ids.clone().to("npu")
            mask_npu = mask_cpu.clone().to("npu")
            center_int_npu = center_int_cpu.clone().to("npu").t().contiguous()
            radius_npu = radius_cpu.clone().to("npu")

            actual_heatmap = npu_draw_gaussian_to_heatmap(
                mask_npu, class_ids_npu, center_int_npu, radius_npu,
                feature_map_size[0], feature_map_size[1], num_classes)
            self.assertRtolEqual(expected_heatmap.numpy(), actual_heatmap.cpu().numpy())


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()