"""Compare results between different algos:
CPU: simple gather-mm-scatter
Native: Fused gather-mm-scatter
ImplicitGemm: implicit gemm
"""
import time
from pathlib import Path
import numpy as np
import torch
import torch_npu
from torch import nn
from torch_npu.testing.testcase import TestCase, run_tests
from data_cache import golden_data_cache
from mx_driving.spconv import SparseSequential, SparseConvTensor, SubMConv3d
@golden_data_cache(__file__)
def generate_sparse_data(num_points, spatial_shape, in_channels):
    """Create random features and unique voxel coordinates for a sparse batch.

    Args:
        num_points: Sequence with the number of active voxels of each sample.
        spatial_shape: Sequence of exclusive upper bounds, one per spatial axis.
        in_channels: Number of feature channels per voxel.

    Returns:
        Tuple ``(features, indices)``. ``features`` is a float32 tensor of
        shape ``(sum(num_points), in_channels)`` drawn uniformly from [0, 5).
        ``indices`` is an int32 tensor of shape
        ``(sum(num_points), 1 + len(spatial_shape))`` whose first column is
        the batch index and whose rows are unique within each sample.

    Raises:
        ValueError: If the random draw for a sample yields fewer than the
            requested number of distinct coordinates; returning them anyway
            would leave ``indices`` with fewer rows than ``features``.
    """
    total_points = sum(num_points)
    features = np.random.uniform(0, 5, (total_points, in_channels))

    indices = []
    for batch_idx, num_point in enumerate(num_points):
        # Oversample by 2x so that enough rows survive de-duplication.
        num_candidates = 2 * num_point
        columns = [np.full((num_candidates, 1), batch_idx, dtype=np.float64)]
        for spatial_size in spatial_shape:
            coord = np.random.uniform(0, spatial_size, (num_candidates, 1)).astype(np.int32)
            columns.append(coord)
        candidates = np.concatenate(columns, axis=1)

        # np.unique returns the rows sorted, so the slice below keeps the
        # lexicographically smallest distinct coordinates.
        unique_rows = np.unique(candidates, axis=0)
        if unique_rows.shape[0] < num_point:
            raise ValueError(
                f"only {unique_rows.shape[0]} unique coordinates could be drawn for "
                f"sample {batch_idx}, but {num_point} were requested"
            )
        indices.append(unique_rows[:num_point])

    indices = np.concatenate(indices, axis=0)
    return torch.from_numpy(features).float(), torch.from_numpy(indices).int()
def generate_map(coors, spatial_shape, bs):
    """Build a two-level lookup from voxel coordinates to point row numbers.

    Args:
        coors: Integer tensor with one row per point; columns 0..3 are read as
            (batch index, first axis, second axis, third axis).
        spatial_shape: Sequence with the extent of the three spatial axes.
        bs: Number of samples in the batch.

    Returns:
        Tuple ``(map1, map2)`` of int32 tensors on ``coors.device``.
        ``map1`` has ``spatial_shape[0] * spatial_shape[1] * bs`` entries, one
        per (batch, first axis, second axis) column; it holds a compact id for
        every column that contains at least one point and -1 elsewhere.
        ``map2`` has shape ``(number of occupied columns, spatial_shape[2])``
        and holds the row number in ``coors`` of the point at that column and
        third-axis position, or -1 when there is none.
    """
    device = coors.device
    columns_per_sample = spatial_shape[1] * spatial_shape[0]

    # Flattened (batch, first axis, second axis) key of every point.
    column_key = columns_per_sample * coors[:, 0] + spatial_shape[1] * coors[:, 1] + coors[:, 2]
    point_ids = torch.arange(column_key.numel(), dtype=torch.int32, device=device)

    # First pass only marks which columns are occupied ...
    map1 = torch.full((columns_per_sample * bs,), -1, dtype=torch.int32, device=device)
    map1[column_key] = point_ids
    occupied = map1 != -1
    num_occupied = int(occupied.sum())
    # ... second pass renumbers the occupied columns consecutively.
    map1[occupied] = torch.arange(num_occupied, dtype=torch.int32, device=device)

    map2 = torch.full((num_occupied, spatial_shape[2]), -1, dtype=torch.int32, device=device)
    map2[map1[column_key], coors[:, 3]] = point_ids
    return map1, map2
@golden_data_cache(__file__)
def get_golden_output(features, indices, weights, bias, batch_size, in_channels,
                      out_channels, kernel_size, out_spatial_shape):
    """Compute the reference result by gathering neighbours and one matmul.

    For every point, the ``kernel_size``-cubed window centred on it is
    scanned; features of points found at a window position are copied into a
    dense buffer, which is then multiplied with the flattened weights.

    Args:
        features: Tensor with one feature row per point.
        indices: Integer tensor with one row per point; columns 0..3 are read
            as (batch index, first axis, second axis, third axis).
        weights: Kernel weights; reshaped here to
            ``(kernel_size**3 * in_channels, out_channels)``.
        bias: Bias; reshaped here to a single row.
        batch_size: Number of samples in the batch.
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Edge length of the cubic kernel.
        out_spatial_shape: Extent of the three spatial axes.

    Returns:
        Tuple ``(out, indices_offset, M)``: the matmul result plus bias, the
        flattened int32 table with the neighbour row number for every
        (point, window position) or -1 when absent, and the dense gathered
        feature buffer of shape
        ``(num points, kernel_size, kernel_size, kernel_size, in_channels)``.
    """
    device = features.device
    num_points = features.shape[0]
    cube = (kernel_size, kernel_size, kernel_size)
    window_shape = (num_points,) + cube
    half_kernel = kernel_size // 2

    map1, map2 = generate_map(indices, out_spatial_shape, batch_size)

    M = torch.zeros(window_shape + (in_channels,), device=device, dtype=features.dtype)
    indices_offset = torch.full(window_shape, -1, dtype=torch.int32, device=device)
    weight_flatten = weights.reshape(kernel_size * kernel_size * kernel_size * in_channels, out_channels)

    # Per-axis offsets inside the window, each expanded to the full cube.
    kernel_offset = torch.arange(kernel_size, device=device)
    k0 = kernel_offset.view(kernel_size, 1, 1).expand(cube)
    k1 = kernel_offset.view(1, kernel_size, 1).expand(cube)
    k2 = kernel_offset.view(1, 1, kernel_size).expand(cube)

    # Absolute coordinates of every window position of every point.
    x_idx = (indices[:, 1] - half_kernel)[:, None, None, None] + k0[None]
    y_idx = (indices[:, 2] - half_kernel)[:, None, None, None] + k1[None]
    z_idx = (indices[:, 3] - half_kernel)[:, None, None, None] + k2[None]

    # Keep only window positions that fall inside the spatial volume.
    mask = (
        (x_idx >= 0) & (y_idx >= 0) & (z_idx >= 0)
        & (x_idx < out_spatial_shape[0])
        & (y_idx < out_spatial_shape[1])
        & (z_idx < out_spatial_shape[2])
    )

    batch_base = indices[:, 0, None, None, None] * out_spatial_shape[1] * out_spatial_shape[0]
    map1_idx = (batch_base + x_idx * out_spatial_shape[1] + y_idx)[mask]
    map2_idx = z_idx[mask]

    # Stage 1: drop positions whose (batch, x, y) column holds no point.
    map1_val = map1[map1_idx]
    column_hit = map1_val != -1
    map1_val = map1_val[column_hit]
    map2_idx = map2_idx[column_hit]
    # Narrow the True entries of mask to the ones that survived this stage.
    mask[mask.clone()] = column_hit

    # Stage 2: drop positions whose z slot inside the column is empty.
    points_offset = map2[map1_val, map2_idx]
    point_hit = points_offset != -1
    mask[mask.clone()] = point_hit

    neighbour_rows = points_offset[point_hit]
    M[mask] = features[neighbour_rows, :]
    out = M.reshape(num_points, -1) @ weight_flatten + bias.reshape(1, -1)
    indices_offset[mask] = neighbour_rows
    return out, indices_offset.flatten(), M
def get_output(num_points, batch_size, in_channels, out_channels,
               kernel_size, spatial_shape, dtype=torch.float32):
    """Run SubMConv3d on NPU and compute the matching golden result.

    Args:
        num_points: Sequence with the number of active voxels of each sample.
        batch_size: Number of samples in the batch.
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel size passed to ``SubMConv3d``.
        spatial_shape: Extent of the three spatial axes.
        dtype: Torch dtype used for features, weights and bias.

    Returns:
        Tuple ``(res, golden)`` of NumPy arrays: the features produced by the
        layer and the output of ``get_golden_output`` for the same inputs.
    """
    features, indices = generate_sparse_data(num_points, spatial_shape, in_channels)
    features = features.to(dtype).npu()
    indices = indices.npu()

    net = SubMConv3d(in_channels, out_channels, kernel_size).npu()
    # Cast the layer parameters so that they match the feature dtype.
    for param in (net.weight, net.bias):
        param.data = param.data.to(dtype)

    sparse_input = SparseConvTensor(features, indices, spatial_shape, batch_size)
    golden_output, _, _ = get_golden_output(
        features, indices, net.weight.data, net.bias.data, batch_size,
        in_channels, out_channels, kernel_size, spatial_shape,
    )
    res = net(sparse_input).features

    return res.detach().cpu().numpy(), golden_output.detach().cpu().numpy()
class TestSubmSparseConv3d(TestCase):
    """Checks SubMConv3d output against the golden reference in fp32 and fp16."""

    def do_custom_test(self, num_points, out_spatial_shape, in_channels, out_channels, kernel_size, batch_size):
        """Compare layer output and golden output for one configuration.

        The comparison runs first in float32 and then in float16, each with
        its own tolerance passed to ``assertRtolEqual``.
        """
        tolerance_by_dtype = (
            (torch.float32, 0.00048828),
            (torch.float16, 1e-3),
        )
        for dtype, tolerance in tolerance_by_dtype:
            res, golden = get_output(
                num_points, batch_size, in_channels, out_channels, kernel_size, out_spatial_shape, dtype
            )
            self.assertRtolEqual(golden, res, tolerance, tolerance)

    def test(self):
        """Run the comparison over a fixed list of configurations."""
        # (num_points, out_spatial_shape, in_channels, out_channels, kernel_size, batch_size)
        cases = (
            ([61557], [1440, 1440, 41], 16, 32, 3, 1),
            ([38153], [1180, 180, 5], 128, 256, 3, 1),
            ([38153], [1180, 180, 5], 128, 256, 5, 1),
            ([23787], [3571, 4251, 1062], 4, 32, 5, 1),
            ([50000], [128, 128, 128], 1024, 1024, 3, 1),
            ([50000], [128, 128, 128], 1024, 1024, 5, 1),
            ([200000], [128, 128, 128], 128, 256, 7, 1),
            ([370000], [1440, 1440, 41], 16, 32, 3, 1),
        )
        for case in cases:
            self.do_custom_test(*case)
if __name__ == "__main__":
    # Seed NumPy's global RNG before the tests draw data with np.random.uniform
    # in generate_sparse_data. The seed is only applied when this file is run
    # as a script.
    np.random.seed(seed=100)
    run_tests()