import unittest
from dataclasses import dataclass
from typing import List
import numpy as np
import torch
import torch_npu
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving
import mx_driving.detection
@dataclass
class KernelParams:
score: torch.Tensor
mask: torch.Tensor
embedding: torch.Tensor
kernel_label: torch.Tensor
kernel_contour: torch.Tensor
kernel_region_num: int
distance_threshold: float
@golden_data_cache(__file__)
def pixel_group_cpu_golden(params: KernelParams):
score = params.score
mask = params.mask
embedding = params.embedding
kernel_label = params.kernel_label
kernel_region_num = params.kernel_region_num
distance_threshold = params.distance_threshold
embedding_dim = embedding.shape[2]
kernel_vector = torch.zeros((kernel_region_num, embedding_dim), dtype=torch.float32)
for label in range(1, kernel_region_num):
label_mask = kernel_label == label
label_embeddings = embedding[label_mask]
kernel_vector[label, :] = label_embeddings.sum(dim=0)
vector_sum = label_mask.sum()
kernel_vector[label, :] /= vector_sum
kernel_cv = kernel_vector[label, :]
valid_mask = (mask == 1) & (kernel_label == 0)
valid_embeddings = embedding[valid_mask]
distances = torch.sum((valid_embeddings - kernel_cv) ** 2, dim=1)
within_threshold = distances < distance_threshold**2
kernel_label[valid_mask] = torch.where(within_threshold, label, kernel_label[valid_mask])
point_vector = torch.zeros((kernel_region_num, 2), dtype=torch.float32)
label_flat = kernel_label.flatten()
score_flat = score.flatten()
mask = label_flat > 0
valid_labels = label_flat[mask]
valid_scores = score_flat[mask]
point_vector.index_add_(
0, valid_labels, torch.stack((valid_scores, torch.ones_like(valid_scores)), dim=1),
)
valid_mask = point_vector[:, 1] > 0
point_vector[valid_mask, 0] /= point_vector[valid_mask, 1]
point_vector_list = point_vector.tolist()
for index in range(1, kernel_region_num):
coords = (kernel_label == index).nonzero(as_tuple=False).float()
coords = coords[:, [1, 0]]
point_vector_list[index].extend(coords.flatten().tolist())
return point_vector_list
def pixel_group_npu_golden(params: KernelParams):
output1 = mx_driving.pixel_group(
params.score.npu(),
params.mask.npu(),
params.embedding.npu(),
params.kernel_label.npu(),
params.kernel_contour.npu(),
params.kernel_region_num,
params.distance_threshold,
)
output2 = mx_driving.detection.pixel_group(
params.score.npu(),
params.mask.npu(),
params.embedding.npu(),
params.kernel_label.npu(),
params.kernel_contour.npu(),
params.kernel_region_num,
params.distance_threshold,
)
return output1, output2
@golden_data_cache(__file__)
def generate_data(H, W, dim, num):
score = np.random.uniform(0, 1, [H, W]).astype(np.float32)
score = torch.from_numpy(score)
mask = (score) > 0.5
embedding = np.random.uniform(0, 10, [H, W, dim]).astype(np.float32)
embedding = torch.from_numpy(embedding)
kernel_label = np.random.uniform(0, num, [H, W]).astype(np.int32)
kernel_label = torch.from_numpy(kernel_label)
kernel_contour = np.random.uniform(0, 1, [H, W]).astype(np.uint8)
kernel_contour = torch.from_numpy(kernel_contour)
kernel_region_num = num
distance_threshold = float(0.8)
input_data = [
score,
mask,
embedding,
kernel_label,
kernel_contour,
kernel_region_num,
distance_threshold,
]
return input_data
class TestNpuPixelGroup(TestCase):
seed = 1024
np.random.seed(seed)
def test_pixel_group(self, device="npu"):
shapes = [
[10, 10, 8, 3],
[100, 100, 10, 5],
[200, 100, 15, 6],
[256, 256, 10, 8],
[500, 1000, 15, 10],
]
for shape in shapes:
H, W, dim, num = shape
data_input = generate_data(H, W, dim, num)
params = KernelParams(*data_input)
cpu_output = pixel_group_cpu_golden(params)
npu_output1, npu_output2 = pixel_group_npu_golden(params)
self.assertRtolEqual(cpu_output, npu_output1)
self.assertRtolEqual(cpu_output, npu_output2)
if __name__ == "__main__":
run_tests()