from typing import Any
import torch
import torch_npu
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from mx_driving._C import (
npu_sparse_conv3d,
unique_voxel,
npu_sparse_matmul,
npu_sparse_conv3d_grad,
npu_subm_sparse_conv3d_v3,
npu_subm_sparse_conv3d_arch35,
npu_subm_sparse_conv3d_grad_v2,
npu_subm_sparse_conv3d_grad_arch35,
npu_sparse_inverse_conv3d,
)
class SparseConvFunction(Function):
@staticmethod
def forward(
ctx: Any,
features,
indices,
weight,
out_spatial_shape,
out_channels,
batch_size,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
) -> torch.Tensor:
weight = weight.data
outidx_pair, ouidx_offset = npu_sparse_conv3d(
indices, kernel_size, stride, padding, out_channels, out_spatial_shape, batch_size
)
num_voxels_, uni_voxels, unique_indices_offset, sorted_idx_to_former_indices, uni_argsort_indices = (
unique_voxel(ouidx_offset)
)
indices_last = torch.tensor(ouidx_offset.shape).to(unique_indices_offset.device)
unique_indices_offset = torch.cat((unique_indices_offset, indices_last), dim=0)
out_features, _ = npu_sparse_matmul(
features, weight, unique_indices_offset.int(), sorted_idx_to_former_indices.int(), outidx_pair.int()
)
ctx.save_for_backward(features, weight, sorted_idx_to_former_indices.int(), unique_indices_offset.int())
return (
out_features,
outidx_pair.int()[uni_argsort_indices],
unique_indices_offset,
sorted_idx_to_former_indices,
outidx_pair,
)
@staticmethod
@once_differentiable
def backward(
ctx: Any,
grad_out_features: torch.Tensor,
grad_outidx=None,
grad_unique_indices_offset=None,
grad_sorted_idx_to_former_indices=None,
grad_outidx_pair=None,
) -> tuple:
features, weight, sorted_idx_to_former_indices, unique_indices_offset = ctx.saved_tensors
feature_grad, weight_grad = npu_sparse_conv3d_grad(
sorted_idx_to_former_indices, unique_indices_offset, features, weight, grad_out_features
)
return feature_grad, None, weight_grad, None, None, None, None, None, None, None, None, None
def generate_map(coors, spaned_spatial_shape, bs, kernel_size):
padding = kernel_size[0] // 2
spatial_shape_size = spaned_spatial_shape[0] * spaned_spatial_shape[1] * spaned_spatial_shape[2]
if spatial_shape_size > 400000000:
spatial_shape1 = spaned_spatial_shape[1] * spaned_spatial_shape[0]
new_coors1 = (
spatial_shape1 * coors[:, 0]
+ spaned_spatial_shape[1] * coors[:, 1]
+ coors[:, 2]
+ padding
+ spaned_spatial_shape[1] * padding
)
map1 = torch.full((spatial_shape1 * bs,), -1, dtype=torch.int32, device=coors.device)
map1_length, unique_idx, _, _, _ = unique_voxel(new_coors1)
map1[unique_idx] = torch.arange(map1_length, dtype=torch.int32, device=coors.device)
map2 = torch.full((map1_length, spaned_spatial_shape[2]), -1, dtype=torch.int32, device=coors.device)
map2[map1[new_coors1], (coors[:, 3] + padding)] = torch.arange(
new_coors1.numel(), dtype=torch.int32, device=coors.device
)
else:
map1 = torch.full(
(bs, spaned_spatial_shape[0], spaned_spatial_shape[1], spaned_spatial_shape[2]),
-1,
dtype=torch.int32,
device=coors.device,
)
bs, sp0, sp1, sp2 = coors.split(1, dim=1)
bs, sp0, sp1, sp2 = bs.flatten(), sp0.flatten(), sp1.flatten(), sp2.flatten()
map1[bs, sp0 + padding, sp1 + padding, sp2 + padding] = torch.arange(
coors.shape[0], dtype=torch.int32, device=coors.device
)
map1 = map1.flatten().contiguous()
map2 = torch.Tensor([]).int()
return map1, map2
class SubMConvFunction(Function):
@staticmethod
def forward(
ctx: Any,
features,
indices,
weight,
indices_offset,
out_spatial_shape,
out_channels,
batch_size,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
) -> torch.Tensor:
weight = weight.data
indices = indices.contiguous()
spaned_spatial_shape = (
out_spatial_shape[0] + 2 * (kernel_size[0] // 2),
out_spatial_shape[1] + 2 * (kernel_size[1] // 2),
out_spatial_shape[2] + 2 * (kernel_size[2] // 2),
)
if indices_offset is None:
map1, map2 = generate_map(indices, spaned_spatial_shape, batch_size, kernel_size)
with_key = 0
indices_offset = torch.Tensor([]).int()
else:
map1, map2 = torch.Tensor([]).int(), torch.Tensor([]).int()
with_key = 1
DEVICE_NAME = torch_npu.npu.get_device_name(features.device.index)
if 'Ascend910' in DEVICE_NAME:
subm_forward_func = npu_subm_sparse_conv3d_v3
elif 'Ascend950' in DEVICE_NAME:
subm_forward_func = npu_subm_sparse_conv3d_arch35
else:
raise NotImplementedError(
'The npu_subm_sparse_conv3d operator currently only supports Ascend910B, Ascend910C and Ascend950.'
)
out_features, out_indices_offset = subm_forward_func(
features,
weight,
indices,
indices_offset,
map1,
map2,
kernel_size,
features.shape[1],
out_channels,
spaned_spatial_shape,
batch_size,
with_key,
)
indices_offset = indices_offset if with_key == 1 else out_indices_offset
ctx.save_for_backward(features, weight, indices_offset)
return out_features, indices, indices_offset
@staticmethod
@once_differentiable
def backward(ctx: Any, grad_out_features: torch.Tensor, grad_outidx=None, grad_offset=None) -> tuple:
features, weight, ouidx_offset = ctx.saved_tensors
DEVICE_NAME = torch_npu.npu.get_device_name(features.device.index)
if 'Ascend910' in DEVICE_NAME:
subm_grad_func = npu_subm_sparse_conv3d_grad_v2
elif 'Ascend950' in DEVICE_NAME:
subm_grad_func = npu_subm_sparse_conv3d_grad_arch35
else:
raise NotImplementedError(
'The npu_subm_sparse_conv3d_grad operator currently only supports Ascend910B, Ascend910C and Ascend950.'
)
feature_grad, weight_grad = subm_grad_func(features, weight, grad_out_features, ouidx_offset)
return feature_grad, None, weight_grad, None, None, None, None, None, None, None, None, None, None
class SparseInverseConvFunction(Function):
@staticmethod
def forward(
ctx: Any,
features,
weight,
in_channels,
out_channels,
kernel_size,
indice_data,
) -> torch.Tensor:
weight = weight.data
output_img2col = npu_sparse_inverse_conv3d(
features,
indice_data.origin_indices,
indice_data.unique_indices_offset.int(),
indice_data.sorted_idx_to_former_indices,
kernel_size,
in_channels,
)
out_features = output_img2col @ weight.reshape(-1, out_channels)
ctx.save_for_backward(
weight,
indice_data.unique_indices_offset.int(),
indice_data.sorted_idx_to_former_indices,
indice_data.outidx_pair,
output_img2col,
)
return out_features
@staticmethod
@once_differentiable
def backward(ctx: Any, grad_out_features: torch.Tensor) -> tuple:
weight, unique_indices_offset, sorted_idx_to_former_indices, outidx_pair, output_img2col = ctx.saved_tensors
weight_shape = weight.shape
weight.data = weight.data.permute(0, 1, 2, 4, 3).contiguous()
inverse_feature_grad, _ = npu_sparse_matmul(
grad_out_features, weight, unique_indices_offset, sorted_idx_to_former_indices, outidx_pair
)
inverse_weight_grad = grad_out_features.transpose(0, 1).contiguous() @ output_img2col
inverse_weight_grad = (
inverse_weight_grad.transpose(0, 1)
.contiguous()
.view(weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3], weight_shape[4])
)
return inverse_feature_grad, inverse_weight_grad, None, None, None, None
indice_conv = SparseConvFunction.apply
indice_subm_conv = SubMConvFunction.apply
indice_inverse_conv = SparseInverseConvFunction.apply