import torch
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
from mx_driving import bev_pool_v3
from mx_driving.ops.bev_pool_v3 import BEVPoolV3
class TestBEVPoolV3(TestCase):
    """Compare the NPU ``bev_pool_v3`` forward/backward against CPU golden references.

    Every test iterates over ``self.shapes``; each entry is
    ``[B, D, H, W, C, N_RANKS]`` and is unpacked by ``generate_bev_pool_data``.
    """

    seed = 1024

    def setUp(self):
        # Re-seed before every test so the generated inputs are reproducible.
        torch.manual_seed(self.seed)

        class MockCtx:
            """Minimal stand-in for an autograd ctx exposing only ``saved_tensors``."""

            def __init__(self, saved_tensors):
                self.saved_tensors = saved_tensors

        self.MockCtx = MockCtx
        # Each entry: [B, D, H, W, C, N_RANKS]
        self.shapes = [
            [1, 1, 1, 1, 8, 1],
            [3, 3, 3, 3, 16, 3],
            [3, 3, 15, 15, 32, 33],
            [1, 5, 17, 23, 8, 777],
            [32, 7, 11, 17, 64, 9999],
        ]

    @staticmethod
    def ranks_bev_2d_to_1d(ranks_bev, D, H, W, dtype=torch.int32):
        """Flatten per-row ``(h, w, d, b)`` BEV coordinates into linear indices.

        Args:
            ranks_bev: 2D integer tensor whose first four columns are
                ``h, w, d, b`` (the order produced by ``generate_bev_pool_data``).
            D, H, W: BEV grid extents.
            dtype: integer dtype used for the arithmetic and the result.
                Defaults to ``torch.int32`` (what the NPU backward is fed here);
                pass ``torch.int64`` when ``B * D * H * W`` may exceed the int32
                range, because the overflow would otherwise happen before any
                later widening.

        Returns:
            1D tensor of ``dtype`` holding ``b*D*H*W + d*H*W + h*W + w`` per row,
            i.e. the row-major index into a ``(B, D, H, W)`` grid.
        """
        h, w, d, b = (ranks_bev[:, col].to(dtype) for col in range(4))
        return b * (D * H * W) + d * (H * W) + h * W + w

    @golden_data_cache(__file__)
    def golden_bev_pool_v3(self, depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape):
        """CPU reference implementation of ``bev_pool_v3``.

        With ``depth``: for every rank ``i``, accumulate
        ``depth.flatten()[ranks_depth[i]] * feat.view(-1, C)[ranks_feat[i]]``
        into output cell ``ranks_bev[i]`` (1D linear index).
        Without ``depth``: row ``i`` of ``feat.view(-1, C)`` is accumulated into
        the cell addressed by row ``i`` of the 2D ``ranks_bev`` (``h, w, d, b``).
        Repeated target cells are summed (``index_add_``).

        Args:
            depth: depth tensor or ``None``.
            feat: feature tensor; viewed as ``(-1, C)``.
            ranks_depth, ranks_feat: 1D index tensors (ignored when ``depth`` is None).
            ranks_bev: 1D linear indices with depth, 2D coordinates without.
            bev_feat_shape: ``[B, D, H, W, C]``.

        Returns:
            Contiguous tensor of shape ``(B, C, D, H, W)``.

        Raises:
            ValueError: if ``depth`` is None and ``ranks_bev`` is not 2D.
        """
        B, D, H, W, C = bev_feat_shape
        feat_flat = feat.view(-1, C)
        out_flat = torch.zeros(B * D * H * W, C, dtype=feat.dtype, device=feat.device)
        if depth is not None:
            depth_flat = depth.view(-1)
            weighted = depth_flat[ranks_depth].unsqueeze(1) * feat_flat[ranks_feat]
            out_flat.index_add_(0, ranks_bev, weighted)
        else:
            if ranks_bev.dim() != 2:
                raise ValueError("ranks_bev must be 2D when running without depth")
            # Flatten directly in int64: computing in int32 and widening afterwards
            # would silently overflow once B*D*H*W exceeds 2**31 - 1.
            linear_idx = self.ranks_bev_2d_to_1d(ranks_bev, D, H, W, dtype=torch.int64)
            out_flat.index_add_(0, linear_idx, feat_flat)
        return out_flat.view(B, D, H, W, C).permute(0, 4, 1, 2, 3).contiguous()

    @golden_data_cache(__file__)
    def golden_bev_pool_v3_grad_with_depth(self, bev_feat_cpu, grad_out, feat, depth):
        """Backpropagate ``grad_out`` through the golden output; return ``(feat.grad, depth.grad)``.

        ``feat`` and ``depth`` must be the ``requires_grad`` leaves that produced
        ``bev_feat_cpu``.
        """
        # NOTE(review): this needs bev_feat_cpu to still carry its autograd graph.
        # If golden_data_cache can serve golden_bev_pool_v3 from a stored file while
        # running this method fresh, backward() would fail. Confirm the decorator's
        # caching semantics.
        bev_feat_cpu.backward(grad_out)
        return feat.grad, depth.grad

    @golden_data_cache(__file__)
    def golden_bev_pool_v3_grad_without_depth(self, bev_feat_cpu, grad_out, feat):
        """Backpropagate ``grad_out`` through the golden output; return ``feat.grad``.

        ``feat`` must be the ``requires_grad`` leaf that produced ``bev_feat_cpu``
        (same caching caveat as ``golden_bev_pool_v3_grad_with_depth``).
        """
        bev_feat_cpu.backward(grad_out)
        return feat.grad

    def generate_bev_pool_data(self, input_shape, with_depth=True):
        """Build random inputs for one ``[B, D, H, W, C, N_RANKS]`` case.

        Float values are uniform in ``[-5, 5)``. The order of the RNG calls is
        part of the reproducible fixture, so do not reorder the draws.

        Returns:
            ``(feat, depth, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape, grad_out)``.

            - With depth: ``depth`` is ``(B, 1, D, H, W)`` and ``feat`` is
              ``(B, 1, H, W, C)``. The three int32 rank tensors are 1D of length
              ``N_RANKS``.
            - Without depth: ``feat`` is ``(N_RANKS, C)``. ``depth``,
              ``ranks_depth`` and ``ranks_feat`` are None. ``ranks_bev`` is an
              ``(N_RANKS, 4)`` int32 tensor with columns ``h, w, d, b``.
            - ``grad_out`` is always ``(B, C, D, H, W)``.
        """
        B, D, H, W, C, N_RANKS = input_shape
        grad_out = torch.rand([B, C, D, H, W]) * 10 - 5
        bev_feat_shape = [B, D, H, W, C]
        if with_depth:
            depth = torch.rand([B, 1, D, H, W]) * 10 - 5
            feat = torch.rand([B, 1, H, W, C]) * 10 - 5
            ranks_depth = torch.randint(0, B * D * H * W, [N_RANKS], dtype=torch.int32)
            ranks_feat = torch.randint(0, B * H * W, [N_RANKS], dtype=torch.int32)
            ranks_bev = torch.randint(0, B * D * H * W, [N_RANKS], dtype=torch.int32)
            return feat, depth, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape, grad_out
        feat = torch.rand([N_RANKS, C]) * 10 - 5
        ranks_bev = torch.stack(
            [
                torch.randint(0, H, (N_RANKS,)),
                torch.randint(0, W, (N_RANKS,)),
                torch.randint(0, D, (N_RANKS,)),
                torch.randint(0, B, (N_RANKS,)),
            ],
            dim=1,
        ).to(torch.int32)
        return feat, None, None, None, ranks_bev, bev_feat_shape, grad_out

    def test_bev_pool_v3_with_depth(self):
        """Forward with depth: NPU output must match the CPU golden output."""
        for shape in self.shapes:
            with self.subTest(shape=shape):
                feat, depth, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape, _ = self.generate_bev_pool_data(
                    input_shape=shape, with_depth=True
                )
                bev_feat_cpu = self.golden_bev_pool_v3(
                    depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape
                )
                bev_feat_npu = bev_pool_v3(
                    depth.npu(), feat.npu(), ranks_depth.npu(), ranks_feat.npu(), ranks_bev.npu(), bev_feat_shape
                )
                self.assertRtolEqual(bev_feat_npu.detach().cpu(), bev_feat_cpu)

    def test_bev_pool_v3_grad_with_depth(self):
        """Backward with depth: NPU feat/depth grads must match CPU autograd."""
        for shape in self.shapes:
            with self.subTest(shape=shape):
                feat, depth, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape, grad_out = (
                    self.generate_bev_pool_data(input_shape=shape, with_depth=True)
                )
                # Copy the inputs to the NPU before turning the CPU tensors into grad leaves.
                feat_npu = feat.clone().npu()
                depth_npu = depth.clone().npu()
                feat.requires_grad_()
                depth.requires_grad_()
                bev_feat_cpu = self.golden_bev_pool_v3(
                    depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape
                )
                grad_feat_cpu, grad_depth_cpu = self.golden_bev_pool_v3_grad_with_depth(
                    bev_feat_cpu, grad_out, feat, depth
                )
                # NOTE(review): the order (depth, feat, ranks_feat, ranks_depth, ranks_bev)
                # is assumed to match what BEVPoolV3.backward unpacks from
                # ctx.saved_tensors. Confirm against mx_driving.
                saved_tensors = (depth_npu, feat_npu, ranks_feat.npu(), ranks_depth.npu(), ranks_bev.npu())
                # Exact 6-way unpacking also checks backward's output arity.
                grad_depth_npu, grad_feat_npu, _, _, _, _ = BEVPoolV3.backward(
                    self.MockCtx(saved_tensors), grad_out.npu()
                )
                self.assertRtolEqual(grad_feat_npu.detach().cpu(), grad_feat_cpu)
                self.assertRtolEqual(grad_depth_npu.detach().cpu(), grad_depth_cpu)

    def test_bev_pool_v3_without_depth(self):
        """Forward without depth: NPU output must match the CPU golden output."""
        for shape in self.shapes:
            with self.subTest(shape=shape):
                feat, _, _, _, ranks_bev, bev_feat_shape, _ = self.generate_bev_pool_data(
                    input_shape=shape, with_depth=False
                )
                bev_feat_cpu = self.golden_bev_pool_v3(None, feat, None, None, ranks_bev, bev_feat_shape)
                bev_feat_npu = bev_pool_v3(None, feat.npu(), None, None, ranks_bev.npu(), bev_feat_shape)
                self.assertRtolEqual(bev_feat_npu.detach().cpu(), bev_feat_cpu)

    def test_bev_pool_v3_grad_without_depth(self):
        """Backward without depth: NPU feat grad must match CPU autograd."""
        for shape in self.shapes:
            with self.subTest(shape=shape):
                feat, _, _, _, ranks_bev, bev_feat_shape, grad_out = self.generate_bev_pool_data(
                    input_shape=shape, with_depth=False
                )
                feat_npu = feat.clone().npu()
                feat.requires_grad_()
                bev_feat_cpu = self.golden_bev_pool_v3(None, feat, None, None, ranks_bev, bev_feat_shape)
                grad_feat_cpu = self.golden_bev_pool_v3_grad_without_depth(bev_feat_cpu, grad_out, feat)
                _, D, H, W, _ = bev_feat_shape
                # The backward receives 1D linear ranks here (int32, the default dtype).
                ranks_bev_1d = self.ranks_bev_2d_to_1d(ranks_bev, D, H, W)
                # NOTE(review): slot layout mirrors the with-depth test; unused slots are None.
                saved_tensors = (None, feat_npu, None, None, ranks_bev_1d.npu())
                _, grad_feat_npu, _, _, _, _ = BEVPoolV3.backward(self.MockCtx(saved_tensors), grad_out.npu())
                self.assertRtolEqual(grad_feat_npu.detach().cpu(), grad_feat_cpu)
if __name__ == "__main__":
    # Executed directly rather than imported: hand control to torch_npu's run_tests().
    run_tests()