import numpy as np
import torch
import torch_npu
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving.point
from mx_driving import bev_pool_v2
from mx_driving._C import npu_bev_pool_v2_backward
@golden_data_cache(__file__)
def golden_bev_pool_v2(
    depth, feat, ranks_depth, ranks_feat, ranks_bev, interval_starts, interval_lengths, b, d, h, w, c
):
    """NumPy reference for the bev_pool_v2 forward pass.

    Every rank ``j`` in ``[start, start + length)`` of each interval contributes
    ``depth[ranks_depth[j]] * feat[ranks_feat[j]]`` to one BEV row.

    Args:
        depth: depth weights; flattened and indexed by ``ranks_depth``.
        feat: features; reshaped to ``(-1, c)`` and row-indexed by ``ranks_feat``.
        ranks_depth, ranks_feat, ranks_bev: per-rank index arrays.
        interval_starts, interval_lengths: runs of ranks to accumulate.
        b, d, h, w, c: dimensions of the pooled BEV grid.

    Returns:
        float32 ``np.ndarray`` of shape ``(b, c, d, h, w)`` (a transposed view).
    """
    depth_flat = depth.flatten()
    feat_rows = feat.reshape((-1, c))
    pooled = np.zeros((b * d * h * w, c), dtype=np.float32)

    for first, count in zip(interval_starts, interval_lengths):
        # NOTE(review): the whole interval accumulates into the BEV cell of its
        # first rank (ranks_bev[first]). This presumably mirrors a kernel that
        # assumes an interval shares one BEV cell. Confirm this, because the
        # backward reference indexes ranks_bev per element instead.
        target = pooled[ranks_bev[first]]  # row view; in-place += writes through
        for j in range(first, first + count):
            target += depth_flat[ranks_depth[j]] * feat_rows[ranks_feat[j]]

    return pooled.reshape((b, d, h, w, c)).transpose((0, 4, 1, 2, 3))
@golden_data_cache(__file__)
def golden_bev_pool_v2_grad(
    grad_out, depth, feat, ranks_depth, ranks_feat, ranks_bev, interval_starts, interval_lengths, b, d, h, w, c
):
    """NumPy reference for the feature gradient of bev_pool_v2.

    For every rank ``j`` covered by the intervals, accumulates
    ``depth[ranks_depth[j]] * grad_out[ranks_bev[j]]`` into row
    ``ranks_feat[j]`` of the feature gradient. Only the feature gradient is
    computed and returned.

    Args:
        grad_out: upstream gradient; reshaped to ``(-1, c)`` and row-indexed by ``ranks_bev``.
        depth: depth weights; flattened and indexed by ``ranks_depth``.
        feat: features; only its shape and dtype are used to allocate the gradient.
        ranks_depth, ranks_feat, ranks_bev: per-rank index arrays.
        interval_starts, interval_lengths: runs of ranks to visit.
        b, h, w, c: the result is reshaped to ``(b, 1, h, w, c)``.
        d: unused; kept so the signature mirrors ``golden_bev_pool_v2``.

    Returns:
        ``np.ndarray`` of shape ``(b, 1, h, w, c)`` with ``feat``'s dtype.
    """
    depth = depth.flatten()
    grad_out = grad_out.reshape((-1, c))
    grad_feat = np.zeros_like(feat).reshape((-1, c))

    for start, length in zip(interval_starts, interval_lengths):
        for j in range(start, start + length):
            # Unlike the forward reference, ranks_bev is read per element here.
            grad_feat[ranks_feat[j]] += depth[ranks_depth[j]] * grad_out[ranks_bev[j]]

    return grad_feat.reshape((b, 1, h, w, c))
@golden_data_cache(__file__)
def generate_bev_pool_data(B, D, H, W, C, N_RANKS):
    """Create random inputs for a bev_pool_v2 test case.

    Draws from NumPy's global RNG in a fixed order: feat, depth, grad_out,
    then the three rank arrays.

    Returns:
        ``(feat, depth, grad_out, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape)``:
        float32 ``feat`` ``(B, 1, H, W, C)``, ``depth`` ``(B, 1, D, H, W)`` and
        ``grad_out`` ``(B, D, H, W, C)``; sorted int32 rank arrays of length
        ``N_RANKS``; and the tuple ``(B, D, H, W, C)``.
    """
    bev_cells = B * D * H * W
    feat_cells = B * H * W

    feat = np.random.rand(B, 1, H, W, C).astype(np.float32)
    depth = np.random.rand(B, 1, D, H, W).astype(np.float32)
    grad_out = np.random.rand(B, D, H, W, C).astype(np.float32)

    def sorted_ranks(upper):
        # N_RANKS random indices in [0, upper), cast to int32 and sorted ascending.
        return np.sort(np.random.randint(0, upper, (N_RANKS,)).astype(np.int32))

    ranks_depth = sorted_ranks(bev_cells)
    ranks_feat = sorted_ranks(feat_cells)
    ranks_bev = sorted_ranks(bev_cells)
    return feat, depth, grad_out, ranks_depth, ranks_feat, ranks_bev, (B, D, H, W, C)
class TestBEVPoolV2(TestCase):
    """Compare the NPU bev_pool_v2 forward/backward against the NumPy references."""

    seed = 1024
    np.random.seed(seed)

    @staticmethod
    def _compute_intervals(ranks_feat):
        """Group consecutive equal entries of ``ranks_feat`` into intervals.

        Returns ``(interval_starts, interval_lengths)`` as int32 arrays whose
        lengths sum to ``len(ranks_feat)``. Assumes ``ranks_feat`` is non-empty,
        because the last length is written unconditionally.
        """
        n_ranks = ranks_feat.shape[0]
        kept = np.ones(n_ranks, dtype=bool)
        kept[1:] = ranks_feat[1:] != ranks_feat[:-1]
        interval_starts = np.where(kept)[0].astype(np.int32)
        interval_lengths = np.zeros_like(interval_starts, dtype=np.int32)
        interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1]
        interval_lengths[-1] = n_ranks - interval_starts[-1]
        return interval_starts, interval_lengths

    def test_bev_pool_v2(self):
        # Re-seed right before generating data. The class-body seed runs at
        # import time, so anything that touches or resets NumPy's global RNG
        # afterwards (e.g. framework setUp) would otherwise change the inputs.
        np.random.seed(self.seed)
        shapes = [
            [1, 1, 1, 1, 1, 1],
            [3, 3, 3, 3, 3, 3],
            [3, 3, 15, 15, 17, 33],
            [1, 5, 128, 128, 31, 777],
            [32, 4, 128, 128, 64, 9999],
        ]
        for shape in shapes:
            with self.subTest(shape=shape):
                B, D, H, W, C, N_RANKS = shape
                feat, depth, grad_out, ranks_depth, ranks_feat, ranks_bev, _ = generate_bev_pool_data(
                    B, D, H, W, C, N_RANKS
                )
                interval_starts, interval_lengths = self._compute_intervals(ranks_feat)

                feat_npu = torch.from_numpy(feat).npu()
                grad_out_npu = torch.from_numpy(grad_out).npu()
                depth_npu = torch.from_numpy(depth).npu()
                ranks_depth_npu = torch.from_numpy(ranks_depth).npu()
                ranks_feat_npu = torch.from_numpy(ranks_feat).npu()
                ranks_bev_npu = torch.from_numpy(ranks_bev).npu()
                interval_lengths_npu = torch.from_numpy(interval_lengths).npu()
                interval_starts_npu = torch.from_numpy(interval_starts).npu()

                bev_feat = bev_pool_v2(
                    depth_npu,
                    feat_npu,
                    ranks_depth_npu,
                    ranks_feat_npu,
                    ranks_bev_npu,
                    (B, D, H, W, C),
                    interval_starts_npu,
                    interval_lengths_npu,
                )
                bev_feat_point = mx_driving.point.bev_pool_v2(
                    depth_npu,
                    feat_npu,
                    ranks_depth_npu,
                    ranks_feat_npu,
                    ranks_bev_npu,
                    (B, D, H, W, C),
                    interval_starts_npu,
                    interval_lengths_npu,
                )
                bev_feat_cpu = golden_bev_pool_v2(
                    depth, feat, ranks_depth, ranks_feat, ranks_bev, interval_starts, interval_lengths, B, D, H, W, C
                )

                # The raw backward binding takes interval_lengths before
                # interval_starts, which is the opposite order from bev_pool_v2 above.
                _, grad_feat_npu = npu_bev_pool_v2_backward(
                    grad_out_npu,
                    depth_npu,
                    feat_npu,
                    ranks_depth_npu,
                    ranks_feat_npu,
                    ranks_bev_npu,
                    interval_lengths_npu,
                    interval_starts_npu,
                    B,
                    D,
                    H,
                    W,
                )
                grad_feat = golden_bev_pool_v2_grad(
                    grad_out,
                    depth,
                    feat,
                    ranks_depth,
                    ranks_feat,
                    ranks_bev,
                    interval_starts,
                    interval_lengths,
                    B,
                    D,
                    H,
                    W,
                    C,
                )

                self.assertRtolEqual(bev_feat.detach().cpu().numpy(), bev_feat_cpu)
                # Check the mx_driving.point entry point against the same reference.
                self.assertRtolEqual(bev_feat_point.detach().cpu().numpy(), bev_feat_cpu)
                self.assertRtolEqual(grad_feat_npu.cpu().numpy(), grad_feat)
if __name__ == "__main__":
    # When executed as a script, hand control to torch_npu's test runner.
    run_tests()