"""
Copyright (c) OpenMMLab. All rights reserved.
Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
Modification by: Huawei Developers
Modification date: 2024-06-04
Modification Description:
Modification 1. Add support for Ascend NPU
"""
import torch
import torch_npu
from torch.autograd import Function
import mx_driving._C
class RoipointPool3dFunction(Function):
    """Autograd wrapper around the NPU RoI-point-pool-3d forward kernel.

    Only ``forward`` is defined here; it validates its inputs and then
    delegates to ``mx_driving._C.npu_roipoint_pool3d_forward``.
    """

    @staticmethod
    def forward(ctx, num_sampled_points, points, point_features, boxes3d):
        """Validate the inputs and run the NPU pooling kernel.

        Args:
            ctx: Autograd context (unused by this forward pass).
            num_sampled_points: Number of points to sample; must be > 0.
            points: Tensor of shape (B, N, 3).
            point_features: Tensor of shape (B, N, C).
            boxes3d: Tensor of shape (B, M, 7).

        Returns:
            The ``(pooled_features, pooled_empty_flag)`` pair produced by the
            kernel, returned unchanged.
            NOTE(review): output shapes are defined by the kernel and are not
            visible here - confirm against the op implementation.

        Raises:
            ValueError: If ``num_sampled_points`` is not positive, if any
                tensor has an unexpected rank/shape, if the dtypes differ,
                or if any tensor is not on an NPU device.
        """
        if num_sampled_points <= 0:
            raise ValueError("Input num_sampled_points should be more than 0")

        # Check ranks before touching size(0)/size(1)/size(2): indexing a
        # dimension that does not exist would raise IndexError rather than
        # the descriptive shape error below.
        if (points.dim() != 3) or (points.size(2) != 3):
            raise ValueError("Input points shape should be (B, N, 3)")
        if point_features.dim() != 3:
            raise ValueError("Input point_features shape should be (B, N, C)")
        if (boxes3d.dim() != 3) or (boxes3d.size(2) != 7):
            raise ValueError("Input boxes3d shape should be (B, M, 7)")

        # All three tensors must share the batch dimension B.
        if (points.size(0) != point_features.size(0)) or (points.size(0) != boxes3d.size(0)):
            raise ValueError("Input points/point_features/boxes3d shape should be (B, x, x)")
        # points and point_features must describe the same N points.
        if points.size(1) != point_features.size(1):
            raise ValueError("Input point_features shape should be (B, N, C)")

        if (points.dtype != point_features.dtype) or (points.dtype != boxes3d.dtype):
            raise ValueError("Input points/point_features/boxes3d dtype should be the same.")
        if (points.device.type != "npu") or (point_features.device.type != "npu") or (boxes3d.device.type != "npu"):
            raise ValueError("The device is not npu!")

        pooled_features, pooled_empty_flag = mx_driving._C.npu_roipoint_pool3d_forward(
            num_sampled_points, points, point_features, boxes3d
        )
        return pooled_features, pooled_empty_flag
# Functional entry point: invokes RoipointPool3dFunction through autograd's
# ``apply``. Call it with forward's arguments minus ``ctx``:
# ``roipoint_pool3d(num_sampled_points, points, point_features, boxes3d)``.
roipoint_pool3d = RoipointPool3dFunction.apply