import unittest
import numpy as np
import torch
import torch_npu
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving._C
import mx_driving.detection
class TestRoIAwarePool3dGrad(TestCase):
    """Check the NPU RoI-aware pool3d backward op against a CPU reference.

    Throughout this class ``pool_method`` is ``0`` for max pooling and ``1``
    for average pooling.
    """

    @golden_data_cache(__file__)
    def roiaware_pool3d_grad_cpu(self, pts_idx_of_voxels, argmax, grad_out,
                                 npoints, pool_method):
        """Compute the reference input gradient on the CPU.

        Args:
            pts_idx_of_voxels: 5-D integer tensor. Along the last axis, slot 0
                holds the number of points in the voxel and the following
                slots hold their point indices (used by average pooling).
            argmax: 5-D integer tensor giving, per voxel and channel, the
                point index that receives the gradient (used by max pooling).
            grad_out: 5-D tensor ``(boxes, out_x, out_y, out_z, channels)``.
            npoints: Number of rows of the returned gradient.
            pool_method: ``0`` for max pooling, ``1`` for average pooling.
                Any other value leaves the gradient at zero.

        Returns:
            Tensor of shape ``(npoints, channels)`` with the dtype of
            ``grad_out``.
        """
        channels = grad_out.shape[-1]
        grad_in = torch.zeros((npoints, channels)).type_as(grad_out)

        # Half-precision inputs are accumulated in float32 and cast back once
        # at the end.
        is_half = grad_out.dtype == torch.float16
        if is_half:
            grad_out_cast = grad_out.type(torch.float32)
            grad_in_cast = grad_in.type(torch.float32)
        else:
            grad_out_cast = grad_out
            grad_in_cast = grad_in

        # Both helpers accumulate into ``grad_in_cast`` in place.
        if pool_method == 0:
            self.roiaware_maxpool3d_grad_cpu(argmax, grad_out_cast, grad_in_cast)
        elif pool_method == 1:
            self.roiaware_avgpool3d_grad_cpu(pts_idx_of_voxels, grad_out_cast, grad_in_cast)

        if is_half:
            grad_in_cast = grad_in_cast.type(torch.float16)
        return grad_in_cast

    def roiaware_pool3d_grad_npu(self, pts_idx_of_voxels, argmax, grad_out, npoints, pool_method):
        """Run ``RoIAwarePool3dFunction.backward`` with a hand-built context.

        The arguments mirror ``roiaware_pool3d_grad_cpu``; the tensors are
        handed to the operator unchanged.

        Returns:
            The third element of the tuple returned by ``backward``.
        """

        class MockCtx:
            """Minimal stand-in exposing the attribute that ``backward`` reads."""

            def __init__(self, pts_idx_of_voxels, argmax, mode, num_pts, num_channels):
                self.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode, num_pts, num_channels)

        # NOTE(review): num_channels is passed as 1 regardless of grad_out's
        # channel count; presumably backward() ignores it -- confirm.
        ctx = MockCtx(pts_idx_of_voxels, argmax, pool_method, npoints, 1)
        from mx_driving.ops.roiaware_pool3d import RoIAwarePool3dFunction

        _, _, grad_in, _, _, _ = RoIAwarePool3dFunction.backward(ctx, grad_out)
        return grad_in

    def roiaware_maxpool3d_grad_cpu(self, argmax, grad_out, grad_in):
        """Accumulate the max-pool gradient into ``grad_in`` in place.

        For every voxel, channel ``c`` of ``grad_out`` is added to row
        ``argmax[voxel][c]``, column ``c`` of ``grad_in``.
        """
        channels = grad_out.shape[-1]
        channel_idx = np.arange(channels)  # loop-invariant, built once
        # np.ndindex walks the voxel grid in the same order as four nested
        # loops over (box, x, y, z).
        for voxel in np.ndindex(*grad_out.shape[:-1]):
            point_rows = argmax[voxel]
            # Each (row, channel) pair is distinct within one voxel because
            # the channel indices are unique, so the indexed in-place add
            # cannot drop contributions.
            grad_in[point_rows, channel_idx] += grad_out[voxel]

    def roiaware_avgpool3d_grad_cpu(self, pts_idx_of_voxels, grad_out, grad_in):
        """Accumulate the average-pool gradient into ``grad_in`` in place.

        Every point listed for a voxel receives that voxel's output gradient
        divided by the number of points in the voxel.
        """
        for voxel in np.ndindex(*grad_out.shape[:-1]):
            slots = pts_idx_of_voxels[voxel]
            total_pts = int(slots[0])
            if total_pts <= 0:
                # Empty voxel: nothing to distribute.
                continue
            # The per-point share is the same for every point of the voxel,
            # so compute it once.
            share = grad_out[voxel] / total_pts
            for pts_idx in slots[1:total_pts + 1]:
                grad_in[pts_idx, :] += share

    @golden_data_cache(__file__)
    def gen_input_data(self, pts_idx_of_voxels_shape, channels, npoints, dtype):
        """Generate random inputs for the backward op.

        Args:
            pts_idx_of_voxels_shape: 5-tuple
                ``(boxes, out_x, out_y, out_z, max_pts_per_voxel)``.
            channels: Size of the last axis of ``grad_out`` and ``argmax``.
            npoints: Exclusive upper bound for generated point indices.
            dtype: NumPy dtype (or dtype name) for ``grad_out``.

        Returns:
            Tuple ``(argmax, grad_out, pts_idx_of_voxels)`` of CPU tensors;
            ``argmax`` and ``pts_idx_of_voxels`` are int32.
        """
        boxes_num, out_x, out_y, out_z, _ = pts_idx_of_voxels_shape
        pooled_shape = (boxes_num, out_x, out_y, out_z, channels)

        grad_out = np.random.uniform(-5, 5, pooled_shape).astype(dtype)
        argmax = np.random.randint(0, npoints, pooled_shape).astype("int32")
        pts_idx_of_voxels = self.gen_pts_idx_of_voxels(pts_idx_of_voxels_shape, npoints).astype("int32")

        return (
            torch.from_numpy(argmax),
            torch.from_numpy(grad_out),
            torch.from_numpy(pts_idx_of_voxels),
        )

    @golden_data_cache(__file__)
    def gen_pts_idx_of_voxels(self, pts_idx_of_voxels_shape, npoints):
        """Build a random voxel-to-points table as a NumPy array.

        Each voxel gets a random count in ``[0, max_pts_per_voxel - 1]`` and
        that many distinct, ascending point indices drawn from
        ``range(npoints)``; unused slots stay zero. ``np.random.choice`` is
        called with ``replace=False``, so it raises if a drawn count exceeds
        ``npoints``.

        Returns:
            Array of shape ``pts_idx_of_voxels_shape`` whose last axis is
            ``[count, idx_0, idx_1, ...]``.
        """
        boxes_num, out_x, out_y, out_z, max_pts_per_voxel = pts_idx_of_voxels_shape
        grid = (boxes_num, out_x, out_y, out_z)

        pts_idx_of_voxels = np.zeros(grid + (max_pts_per_voxel - 1,)).astype("int32")
        total_pts_array = np.random.randint(0, max_pts_per_voxel, grid)
        for voxel in np.ndindex(*grid):
            total_pts = total_pts_array[voxel]
            choiced_idx = np.array(np.random.choice(npoints, total_pts, replace=False)).astype("int32")
            pts_idx_of_voxels[voxel][:total_pts] = np.sort(choiced_idx)

        # Prepend the per-voxel counts as slot 0 of the last axis.
        return np.concatenate([total_pts_array[..., np.newaxis], pts_idx_of_voxels], axis=-1)

    def one_case(self, boxes_num, out_size, channels, npoints, max_pts_per_voxel, pool_method, dtype):
        """Run one CPU-vs-NPU comparison for the given configuration.

        Args:
            boxes_num: Number of boxes (first axis of the generated tensors).
            out_size: 3-tuple ``(out_x, out_y, out_z)``.
            channels: Number of feature channels.
            npoints: Number of points (rows of the input gradient).
            max_pts_per_voxel: Size of the last axis of ``pts_idx_of_voxels``.
            pool_method: ``0`` for max pooling, ``1`` for average pooling.
            dtype: NumPy dtype name for ``grad_out``.
        """
        out_x, out_y, out_z = out_size
        pts_idx_of_voxels_shape = (boxes_num, out_x, out_y, out_z, max_pts_per_voxel)
        argmax, grad_out, pts_idx_of_voxels = self.gen_input_data(pts_idx_of_voxels_shape, channels, npoints, dtype)

        golden_grad_in = self.roiaware_pool3d_grad_cpu(pts_idx_of_voxels, argmax, grad_out, npoints, pool_method)
        grad_in = self.roiaware_pool3d_grad_npu(
            pts_idx_of_voxels.npu(), argmax.npu(), grad_out.npu(), npoints, pool_method
        )
        self.assertRtolEqual(grad_in, golden_grad_in)

    def test_roiaware_pool3d_grad(self):
        """Compare CPU and NPU gradients for both pool methods and dtypes."""
        out_size = (14, 14, 14)
        cases = (
            (0, "float32"),
            (1, "float32"),
            (0, "float16"),
            (1, "float16"),
        )
        for pool_method, dtype in cases:
            # subTest keeps the remaining configurations running, and labels
            # the failing one, if a single comparison fails.
            with self.subTest(pool_method=pool_method, dtype=dtype):
                self.one_case(1, out_size, 256, 128, 128, pool_method, dtype)

    def test_roiaware_gradout_empty(self):
        """An empty ``grad_out`` must be rejected with the expected message."""
        pts_idx_of_voxels_shape = (1, 14, 14, 14, 128)
        argmax, _, pts_idx_of_voxels = self.gen_input_data(pts_idx_of_voxels_shape, 256, 128, "float32")
        grad_out = torch.tensor([]).npu()
        argmax = argmax.npu()
        pts_idx_of_voxels = pts_idx_of_voxels.npu()

        with self.assertRaises(Exception) as ctx:
            self.roiaware_pool3d_grad_npu(pts_idx_of_voxels, argmax, grad_out, 128, 0)
        # Checked after the ``with`` block: statements following the raising
        # call inside the block would never execute.
        self.assertEqual(str(ctx.exception), "Error! Input Tensor can not be a empty Tensor.\n")
if __name__ == "__main__":
    # Clear the allow_hf32 flag on torch.npu.conv before the suite runs.
    # NOTE(review): presumably this keeps convolutions out of HF32 mode so
    # results stay comparable with the CPU reference -- confirm.
    torch.npu.conv.allow_hf32 = False
    run_tests()