import unittest
import numpy as np
import torch
import torch_npu
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving
import mx_driving.fused
# First 10 characters of NPU 0's device name; the unit-test skip condition
# compares this prefix with the string 'Ascend910B'.
DEVICE_NAME = torch_npu.npu.get_device_name(0)[0:10]
@golden_data_cache(__file__)
def cpu_gen_inputs(B, C, anchor, pts, numGroups):
    """Build random inputs for one single-camera, single-level DeformableAggregation case.

    Args:
        B: batch size.
        C: channels per feature-map location (``num_embeds``).
        anchor: number of anchors.
        pts: sampling points per anchor.
        numGroups: number of weight groups per sample.

    Returns:
        ``(feature_maps, spatial_shape, scale_start_index, sample_location, weights)``
        as NumPy arrays with shapes ``(B, 32 * 88, C)`` float32, ``(1, 1, 2)`` int32,
        ``(1, 1)`` int32, ``(B, anchor, pts, 1, 2)`` float32 and
        ``(B, anchor, pts, 1, 1, numGroups)`` float32. Float values come from
        ``np.random.rand`` (uniform on [0, 1)).
    """
    # One camera with a single 32 x 88 feature level whose first location is 0.
    spatial_shape = np.array([[[32, 88]]], dtype=np.int32)
    scale_start_index = np.array([[0]], dtype=np.int32)
    # Derive the flattened feature length (32 * 88 = 2816) from the level size
    # rather than hard-coding it, so the two cannot drift apart.
    num_feat = int(spatial_shape[0, 0, 0]) * int(spatial_shape[0, 0, 1])
    # Random draws stay in their original order so seeded data is unchanged.
    feature_maps = np.random.rand(B, num_feat, C).astype(np.float32)
    sample_location = np.random.rand(B, anchor, pts, 1, 2).astype(np.float32)
    weights = np.random.rand(B, anchor, pts, 1, 1, numGroups).astype(np.float32)
    return feature_maps, spatial_shape, scale_start_index, sample_location, weights
class TestDeformableAggregation(TestCase):
    """Check the NPU DeformableAggregation operator against a NumPy golden reference."""

    @golden_data_cache(__file__)
    def golden_deformable_aggregation(self, batch_size, num_anchors, num_pts, num_cams, num_scale, num_embeds,
                                      num_groups, num_feat, feature_maps, spatial_shape, scale_start_index,
                                      sample_location, weights):
        """CPU reference for deformable aggregation.

        Every (batch, anchor, point, camera, scale) sample bilinearly interpolates
        the camera/scale feature level at ``sample_location``. The interpolated
        channels are split into ``num_groups`` contiguous groups, each group is
        scaled by its own weight, and the result is accumulated into
        ``out[batch, anchor]``.

        Args:
            batch_size, num_anchors, num_pts, num_cams, num_scale: sample-grid sizes.
            num_embeds: channels per feature-map location.
            num_groups: weight groups per sample. Group ``g`` covers channels
                ``[g * d, (g + 1) * d)`` with ``d = num_embeds // num_groups``.
            num_feat: locations per batch item in ``feature_maps``.
            feature_maps: 1-D array. The value for batch ``b``, location ``f`` and
                channel ``c`` sits at ``(b * num_feat + f) * num_embeds + c``.
            spatial_shape: array indexed ``[cam, scale, 0|1]`` giving the level's
                height and width.
            scale_start_index: array indexed ``[cam, scale]`` giving the level's
                first location along the ``num_feat`` axis.
            sample_location: array indexed ``[b, anchor, pt, cam, 0|1]`` holding the
                normalized width (x) and height (y) coordinates. Samples with a
                coordinate ``<= 0`` or ``>= 1`` are skipped.
            weights: 1-D array laid out as ``(batch, anchor, pt, cam, scale, group)``
                in C order.

        Returns:
            float32 array of shape ``(batch_size, num_anchors, num_embeds)``.
        """
        out = np.zeros((batch_size, num_anchors, num_embeds), dtype=np.float32)
        if num_groups == 0:
            # No groups means nothing is ever accumulated.
            return out
        group_dims = num_embeds // num_groups
        # Only these leading channels are produced (all of them when
        # num_groups divides num_embeds).
        used_channels = num_groups * group_dims
        w_stride = num_embeds
        sample_grid = (batch_size, num_anchors, num_pts, num_cams, num_scale)
        # np.ndindex walks the grid in C order (scale fastest), so ``kernel`` is
        # the flat sample index used to address ``weights``.
        for kernel, (batch_index, anchor_index, pts_index, cam_index, scale_index) in enumerate(
                np.ndindex(*sample_grid)):
            loc_w = sample_location[batch_index, anchor_index, pts_index, cam_index, 0]
            loc_h = sample_location[batch_index, anchor_index, pts_index, cam_index, 1]
            # Samples on or beyond the normalized border contribute nothing.
            if loc_w <= 0 or loc_w >= 1 or loc_h <= 0 or loc_h >= 1:
                continue
            h = spatial_shape[cam_index, scale_index, 0]
            w = spatial_shape[cam_index, scale_index, 1]
            value_offset = (batch_index * num_feat + scale_start_index[cam_index, scale_index]) * num_embeds

            # Bilinear interpolation with pixel centres at half-integer coordinates.
            h_im = loc_h * h - 0.5
            w_im = loc_w * w - 0.5
            h_low = np.floor(h_im).astype(int)
            w_low = np.floor(w_im).astype(int)
            h_high = h_low + 1
            w_high = w_low + 1
            lh = h_im - h_low
            lw = w_im - w_low
            hh = 1 - lh
            hw = 1 - lw

            h_stride = w * w_stride
            h_low_ptr_offset = h_low * h_stride
            h_high_ptr_offset = h_low_ptr_offset + h_stride
            w_low_ptr_offset = w_low * w_stride
            w_high_ptr_offset = w_low_ptr_offset + w_stride

            # Corner values as (num_groups, group_dims) blocks; corners outside the
            # level contribute 0.
            v1 = 0
            if h_low >= 0 and w_low >= 0:
                ptr1 = value_offset + h_low_ptr_offset + w_low_ptr_offset
                v1 = feature_maps[ptr1:ptr1 + used_channels].reshape(num_groups, group_dims)
            v2 = 0
            if h_low >= 0 and w_high <= w - 1:
                ptr2 = value_offset + h_low_ptr_offset + w_high_ptr_offset
                v2 = feature_maps[ptr2:ptr2 + used_channels].reshape(num_groups, group_dims)
            v3 = 0
            if h_high <= h - 1 and w_low >= 0:
                ptr3 = value_offset + h_high_ptr_offset + w_low_ptr_offset
                v3 = feature_maps[ptr3:ptr3 + used_channels].reshape(num_groups, group_dims)
            v4 = 0
            if h_high <= h - 1 and w_high <= w - 1:
                ptr4 = value_offset + h_high_ptr_offset + w_high_ptr_offset
                v4 = feature_maps[ptr4:ptr4 + used_channels].reshape(num_groups, group_dims)
            w1 = hh * hw
            w2 = hh * lw
            w3 = lh * hw
            w4 = lh * lw

            # One weight per group, broadcast across that group's channels.
            group_weights = weights[kernel * num_groups:(kernel + 1) * num_groups].reshape(num_groups, 1)
            val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) * group_weights
            # A contiguous 1-D slice reshapes to a view, so the in-place add updates ``out``.
            out_groups = out[batch_index, anchor_index, :used_channels].reshape(num_groups, group_dims)
            out_groups += val
        return out

    @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `DeformableAggregation` is only supported on 910B, skip this ut!")
    def test_deformable_aggregation(self):
        """Compare both NPU entry points with the CPU golden over a grid of shapes."""
        np.random.seed(50051)
        batch_sizes = [1, 5, 10]
        channel_sizes = [32, 64]
        group_counts = [8, 16]
        anchor_counts = [10, 13, 18]
        point_counts = [10, 50, 31]
        # Same order as the nested loops (batch outermost, groups innermost), so the
        # seeded random stream yields the same inputs for every case.
        cases = [(B, C, pts, anchor, numGroups)
                 for B in batch_sizes
                 for C in channel_sizes
                 for pts in point_counts
                 for anchor in anchor_counts
                 for numGroups in group_counts]
        for B, C, pts, anchor, numGroups in cases:
            # subTest labels any failure with the offending configuration.
            with self.subTest(B=B, C=C, pts=pts, anchor=anchor, numGroups=numGroups):
                self._check_case(B, C, anchor, pts, numGroups)

    def _check_case(self, B, C, anchor, pts, numGroups):
        """Run one configuration through the CPU golden and both NPU APIs, then compare."""
        feature_maps, spatial_shape, scale_start_index, sample_location, weights = cpu_gen_inputs(
            B, C, anchor, pts, numGroups)
        npu_inputs = [torch.from_numpy(arr).npu()
                      for arr in (feature_maps, spatial_shape, scale_start_index, sample_location, weights)]
        batch_size = feature_maps.shape[0]
        num_feat = feature_maps.shape[1]
        num_embeds = feature_maps.shape[2]
        num_cams = spatial_shape.shape[0]
        num_scale = spatial_shape.shape[1]
        num_anchors = sample_location.shape[1]
        num_pts = sample_location.shape[2]
        num_groups = weights.shape[5]
        # The golden addresses feature maps and weights as flat 1-D arrays.
        out_cpu = self.golden_deformable_aggregation(batch_size, num_anchors, num_pts, num_cams,
                                                     num_scale, num_embeds, num_groups, num_feat,
                                                     feature_maps.flatten(), spatial_shape, scale_start_index,
                                                     sample_location, weights.flatten())
        out_npu = mx_driving.fused.npu_deformable_aggregation(*npu_inputs)
        self.assertRtolEqual(out_cpu, out_npu.cpu().numpy())
        out_npu_new = mx_driving.deformable_aggregation(*npu_inputs)
        self.assertRtolEqual(out_cpu, out_npu_new.cpu().numpy())
if __name__ == '__main__':
    # Hand control to torch_npu's test runner when this file is executed directly.
    run_tests()