import itertools
import unittest
from unittest.mock import Mock

import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests

from data_cache import golden_data_cache
import mx_driving
import mx_driving.fused
from mx_driving.ops.npu_deformable_aggregation import AdsDeformableAggregation
# First 10 characters of NPU 0's device name; the skipIf guards below compare it
# against "Ascend910B". When no NPU is visible, fall back to "" so the module still
# imports and the device-specific tests are skipped rather than crashing collection.
DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] if torch_npu.npu.is_available() else ""
@golden_data_cache(__file__)
def gen_inputs(B, C, input_h, input_w, anchor, pts, numGroups):
    """Create one random deformable-aggregation test case as numpy arrays.

    Returns:
        ``(feature_maps, spatial_shape, scale_start_index, sample_location, weights)``
        with shapes ``(B, input_h * input_w, C)``, ``(1, 1, 2)``, ``(1, 1)``,
        ``(B, anchor, pts, 1, 2)`` and ``(B, anchor, pts, 1, 1, numGroups)``.
        Float arrays are float32 in ``[0, 1)``; index arrays are int32.
    """
    num_cells = input_h * input_w
    # The caller seeds numpy's global RNG, so these three draws must stay in this order.
    feature_maps = np.random.rand(B, num_cells, C).astype(np.float32)
    sample_location = np.random.rand(B, anchor, pts, 1, 2).astype(np.float32)
    weights = np.random.rand(B, anchor, pts, 1, 1, numGroups).astype(np.float32)
    # One level: its (h, w) extent, starting at offset 0 of the flattened feature axis.
    spatial_shape = np.array([[[input_h, input_w]]], dtype=np.int32)
    scale_start_index = np.array([[0]], dtype=np.int32)
    return feature_maps, spatial_shape, scale_start_index, sample_location, weights
class TestDeformableAggregation(TestCase):
    """Tests for ``AdsDeformableAggregation`` forward/backward on an Ascend NPU."""

    # Relative tolerance (2 ** -11) used for every gradient comparison.
    _GRAD_PREC = 0.00048828125

    @golden_data_cache(__file__)
    def golden_deformable_aggregation_grad(
        self,
        batch_size,
        num_anchors,
        num_pts,
        num_cams,
        num_scale,
        num_embeds,
        num_groups,
        num_feat,
        feature_maps,
        spatial_shape,
        scale_start_index,
        sample_location,
        weights
    ):
        """Compute reference gradients of deformable aggregation with a scalar CPU loop.

        The upstream gradient is taken to be all ones with shape
        ``(batch_size, num_anchors, num_embeds)``. Each sample point is bilinearly
        sampled from ``feature_maps``; corners outside the feature map contribute
        nothing.

        Returns:
            ``(grad_mc_ms_feat, grad_sampling_location, grad_weights)`` as *flattened*
            float arrays; callers reshape them to the corresponding input shapes.
        """
        grad_mc_ms_feat = np.zeros_like(feature_maps).flatten()
        grad_sampling_location = np.zeros_like(sample_location).flatten()
        grad_weights = np.zeros_like(weights).flatten()
        grad_output = np.ones((batch_size, num_anchors, num_embeds), dtype=np.float32).flatten()
        feature_maps = feature_maps.flatten()
        spatial_shape = spatial_shape.flatten()
        scale_start_index = scale_start_index.flatten()
        sample_location = sample_location.flatten()
        weights = weights.flatten()

        # Number of consecutive channels that share a single attention weight.
        channels_per_group = num_embeds // num_groups
        num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale
        for idx in range(num_kernels):
            # The flat weight index is ((((b*A + a)*P + p)*CAM + c)*S + s)*G + group,
            # which equals idx // channels_per_group given the decomposition below.
            weights_ptr = idx // channels_per_group
            # Decompose idx: channel varies fastest, then scale, cam, pt, anchor, batch.
            rem = idx
            channel_index = rem % num_embeds
            rem //= num_embeds
            scale_index = rem % num_scale
            rem //= num_scale
            cam_index = rem % num_cams
            rem //= num_cams
            pts_index = rem % num_pts
            rem //= num_pts
            anchor_index = rem % num_anchors
            rem //= num_anchors
            batch_index = rem % batch_size
            # Anchor index counted across the whole batch.
            anchor_index = batch_index * num_anchors + anchor_index

            # sample_location stores (w, h) pairs; points outside the open (0, 1) range are skipped.
            loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1
            loc_w = sample_location[loc_offset]
            if loc_w <= 0 or loc_w >= 1:
                continue
            loc_h = sample_location[loc_offset + 1]
            if loc_h <= 0 or loc_h >= 1:
                continue

            grad = grad_output[anchor_index * num_embeds + channel_index]
            cam_scale_index = cam_index * num_scale + scale_index
            value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index
            # spatial_shape stores (h, w) pairs per (cam, scale).
            cam_scale_index = cam_scale_index << 1
            h = spatial_shape[cam_scale_index]
            w = spatial_shape[cam_scale_index + 1]
            # Continuous pixel coordinates (pixel centers at integer + 0.5).
            h_im = loc_h * h - 0.5
            w_im = loc_w * w - 0.5
            weight = weights[weights_ptr]
            h_low = np.floor(h_im).astype(int)
            w_low = np.floor(w_im).astype(int)
            h_high = h_low + 1
            w_high = w_low + 1
            lh = h_im - h_low
            lw = w_im - w_low
            hh = 1 - lh
            hw = 1 - lw
            w_stride = num_embeds
            h_stride = w * w_stride
            h_low_ptr_offset = h_low * h_stride
            h_high_ptr_offset = h_low_ptr_offset + h_stride
            w_low_ptr_offset = w_low * w_stride
            w_high_ptr_offset = w_low_ptr_offset + w_stride
            # Bilinear weights of the four neighbouring corners.
            w1 = hh * hw
            w2 = hh * lw
            w3 = lh * hw
            w4 = lh * lw
            top_grad_mc_ms_feat = grad * weight
            grad_h_weight = 0
            grad_w_weight = 0

            # Each corner contributes only when it lies inside the feature map.
            v1 = 0
            if h_low >= 0 and w_low >= 0:
                ptr1 = value_offset + h_low_ptr_offset + w_low_ptr_offset
                v1 = feature_maps[ptr1]
                grad_h_weight -= hw * v1
                grad_w_weight -= hh * v1
                grad_mc_ms_feat[ptr1] += w1 * top_grad_mc_ms_feat
            v2 = 0
            if h_low >= 0 and w_high <= w - 1:
                ptr2 = value_offset + h_low_ptr_offset + w_high_ptr_offset
                v2 = feature_maps[ptr2]
                grad_h_weight -= lw * v2
                grad_w_weight += hh * v2
                grad_mc_ms_feat[ptr2] += w2 * top_grad_mc_ms_feat
            v3 = 0
            if h_high <= h - 1 and w_low >= 0:
                ptr3 = value_offset + h_high_ptr_offset + w_low_ptr_offset
                v3 = feature_maps[ptr3]
                grad_h_weight += hw * v3
                grad_w_weight -= lh * v3
                grad_mc_ms_feat[ptr3] += w3 * top_grad_mc_ms_feat
            v4 = 0
            if h_high <= h - 1 and w_high <= w - 1:
                ptr4 = value_offset + h_high_ptr_offset + w_high_ptr_offset
                v4 = feature_maps[ptr4]
                grad_h_weight += lw * v4
                grad_w_weight += lh * v4
                grad_mc_ms_feat[ptr4] += w4 * top_grad_mc_ms_feat

            val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4)
            grad_weights[weights_ptr] += grad * val
            # Chain rule through h_im = loc_h * h - 0.5 (and likewise for w).
            grad_sampling_location[loc_offset] += w * grad_w_weight * top_grad_mc_ms_feat
            grad_sampling_location[loc_offset + 1] += h * grad_h_weight * top_grad_mc_ms_feat
        return grad_mc_ms_feat, grad_sampling_location, grad_weights

    @staticmethod
    def _to_npu(array, requires_grad=False):
        """Copy a numpy array to the NPU, optionally marking it as requiring grad."""
        tensor = torch.from_numpy(array).npu()
        if requires_grad:
            tensor.requires_grad = True
        return tensor

    def _run_npu_grads(self, feature_maps, spatial_shape, scale_start_index, sample_location, weights):
        """Run the NPU forward then backward on fresh device copies of the inputs.

        The upstream gradient is all ones with shape ``(batch, num_anchors, num_embeds)``.

        Returns:
            ``(grad_mc_ms_feat, grad_sampling_location, grad_weights)`` as numpy arrays.
        """
        inputs = (
            self._to_npu(feature_maps, requires_grad=True),
            self._to_npu(spatial_shape),
            self._to_npu(scale_start_index),
            self._to_npu(sample_location, requires_grad=True),
            self._to_npu(weights, requires_grad=True),
        )
        grad_output = torch.ones(
            (feature_maps.shape[0], sample_location.shape[1], feature_maps.shape[2]),
            dtype=torch.float32,
        ).npu()
        # The forward result is not compared; forward still runs before backward on
        # the same tensors.
        AdsDeformableAggregation.forward(Mock(), *inputs)
        ctx = Mock()
        ctx.saved_tensors = inputs
        # backward returns one gradient per forward input; spatial_shape and
        # scale_start_index have no gradient.
        grad_feat, _, _, grad_loc, grad_w = AdsDeformableAggregation.backward(ctx, grad_output)
        return grad_feat.cpu().numpy(), grad_loc.cpu().numpy(), grad_w.cpu().numpy()

    def _assert_grads_close(self, expected, actual):
        """Compare (feature, sampling-location, weight) gradients pairwise."""
        for golden, npu in zip(expected, actual):
            self.assertRtolEqual(golden, npu, prec=self._GRAD_PREC)

    @unittest.skipIf(
        DEVICE_NAME != 'Ascend910B',
        "OP `DeformableAggregationGrad` is only supported on 910B, skip this ut!",
    )
    def test_deformable_aggregation(self):
        """NPU gradients must match the CPU golden over a grid of input sizes."""
        np.random.seed(50051)
        input_h, input_w = 16, 22
        # Same iteration order as nested loops over B, C, pts, anchor, numGroups,
        # which keeps the seeded RNG sequence unchanged.
        configs = itertools.product(
            [1, 5],               # B
            [8 * 8, 8 * 8 * 2],   # C
            [10, 21],             # pts
            [10, 13],             # anchor
            [8],                  # numGroups
        )
        for B, C, pts, anchor, numGroups in configs:
            feature_maps, spatial_shape, scale_start_index, sample_location, weights = gen_inputs(
                B, C, input_h, input_w, anchor, pts, numGroups
            )
            batch_size, num_feat, num_embeds = feature_maps.shape
            num_cams, num_scale = spatial_shape.shape[0], spatial_shape.shape[1]
            num_anchors, num_pts = sample_location.shape[1], sample_location.shape[2]
            num_groups = weights.shape[5]

            golden_feat, golden_loc, golden_w = self.golden_deformable_aggregation_grad(
                batch_size,
                num_anchors,
                num_pts,
                num_cams,
                num_scale,
                num_embeds,
                num_groups,
                num_feat,
                feature_maps,
                spatial_shape,
                scale_start_index,
                sample_location,
                weights
            )
            expected = (
                golden_feat.reshape(feature_maps.shape),
                golden_loc.reshape(sample_location.shape),
                golden_w.reshape(weights.shape),
            )

            # Run the op twice on independent device copies; both runs must match the golden.
            # NOTE(review): the second run presumably checks repeated invocation — confirm intent.
            for _ in range(2):
                actual = self._run_npu_grads(
                    feature_maps, spatial_shape, scale_start_index, sample_location, weights
                )
                self._assert_grads_close(expected, actual)

    @unittest.skipIf(
        DEVICE_NAME != 'Ascend910B',
        "OP `DeformableAggregationGrad` is only supported on 910B, skip this ut!",
    )
    def test_deformable_aggregation_backward_empty_input(self):
        """Backward error branch: empty input tensors must raise."""
        torch_feature_maps = torch.empty((0,), dtype=torch.float32).npu()
        torch_spatial_shape = torch.empty((0,), dtype=torch.int32).npu()
        torch_scale_start_index = torch.tensor([[0]], dtype=torch.int32).npu()
        torch_sample_location = torch.empty((0,), dtype=torch.float32).npu()
        torch_weights = torch.randn(1, 1, 1, 1, 1, 1).npu()
        torch_grad_output = torch.randn(1, 1, 1).npu()
        ctx = Mock()
        ctx.saved_tensors = (
            torch_feature_maps,
            torch_spatial_shape,
            torch_scale_start_index,
            torch_sample_location,
            torch_weights
        )
        with self.assertRaises(Exception) as cm:
            AdsDeformableAggregation.backward(ctx, torch_grad_output)
        # Exact-match check: the text, including its spelling, must equal the op's message.
        self.assertEqual(str(cm.exception), "Erorr! Input Tensor can not be a empty Tensor.\n")

    @unittest.skipIf(
        DEVICE_NAME != 'Ascend910B',
        "OP `DeformableAggregationGrad` is only supported on 910B, skip this ut!",
    )
    def test_deformable_aggregation_forward_empty_input(self):
        """Forward error branch: empty input tensors must raise."""
        torch_feature_maps = torch.empty((0,), dtype=torch.float32).npu()
        torch_spatial_shape = torch.tensor([[[16, 22]]], dtype=torch.int32).npu()
        torch_scale_start_index = torch.tensor([[0]], dtype=torch.int32).npu()
        torch_sample_location = torch.randn(1, 10, 10, 1, 2).npu()
        torch_weights = torch.empty((0,), dtype=torch.float32).npu()
        with self.assertRaises(Exception) as cm:
            AdsDeformableAggregation.forward(
                Mock(),
                torch_feature_maps,
                torch_spatial_shape,
                torch_scale_start_index,
                torch_sample_location,
                torch_weights
            )
        # Exact-match check: the text, including its spelling, must equal the op's message.
        self.assertEqual(str(cm.exception), "Erorr! Input Tensor can not be a empty Tensor.\n")
if __name__ == "__main__":
    # Script entry point: hand control to torch_npu's test runner.
    run_tests()