# aac322fb - created on 2025-08-13 (commit history)
"""
Copyright (c) OpenMMLab. All rights reserved.
Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
Modification by: Huawei Developers
Modification date: 2024-06-04 
Modification Description: 
Modification 1. Add support for Ascend NPU
"""
import warnings

import numpy as np
import torch
import torch_npu
from torch.autograd import Function

import mx_driving._C


class AdsFurthestPointSampling(Function):
    """Autograd ``Function`` wrapping ``mx_driving._C.npu_furthest_point_sampling``."""

    @staticmethod
    def forward(ctx, point_xyz, num_points):
        """Validate the inputs and invoke the NPU furthest-point-sampling kernel.

        Args:
            ctx: Autograd context; not used by this method.
            point_xyz: Input point tensor. Its first two dimensions are read as
                ``(B, N)`` and its last two axes are swapped before the kernel
                call, so it must be 3-D (presumably ``(B, N, 3)`` coordinates
                -- TODO confirm against the kernel contract).
            num_points: Number of points to sample; must be positive.

        Returns:
            The value returned by ``mx_driving._C.npu_furthest_point_sampling``
            (presumably the indices of the sampled points -- confirm against
            the kernel documentation).

        Raises:
            Exception: If ``point_xyz`` has no elements or ``num_points`` is 0.
            ValueError: If ``num_points`` is negative.
        """
        if torch.numel(point_xyz) == 0:
            raise Exception("Error! Input Tensor can not be a empty Tensor.\n")

        if num_points == 0:
            raise Exception("Error! num_points can not zero.\n")

        # A negative count can never be satisfied; fail here with a clear
        # message instead of deep inside the kernel.
        if num_points < 0:
            raise ValueError(
                f"Error! num_points must be positive, but got {num_points}.\n"
            )

        B, N = point_xyz.size()[:2]
        point_xyz = point_xyz.permute(0, 2, 1).contiguous()

        # The distance buffer follows the input dtype, except that bfloat16
        # inputs use a float32 buffer.
        nearest_dtype = point_xyz.dtype
        if point_xyz.dtype == torch.bfloat16:
            nearest_dtype = torch.float32

        # Allocate the "distance so far" buffer directly on the same device as
        # the input rather than building it with NumPy on the host and copying
        # it to a hard-coded device. 1e10 is exactly representable in float32;
        # filling in float32 and casting afterwards keeps the previous
        # semantics for narrower dtypes (e.g. float16 saturates to inf).
        nearest_dist = torch.full(
            (B, N), 1e10, dtype=torch.float32, device=point_xyz.device
        ).to(nearest_dtype)
        output = mx_driving._C.npu_furthest_point_sampling(point_xyz, nearest_dist, num_points)

        return output


def furthest_point_sampling(point_xyz, num_points):
    """Apply ``AdsFurthestPointSampling`` to ``point_xyz``.

    Args:
        point_xyz: Point tensor forwarded unchanged to the autograd function.
        num_points: Number of points to sample, forwarded unchanged.

    Returns:
        The result of ``AdsFurthestPointSampling.apply``.
    """
    sampled = AdsFurthestPointSampling.apply(point_xyz, num_points)
    return sampled


def npu_furthest_point_sampling(point_xyz, num_points):
    """Deprecated alias of ``furthest_point_sampling``.

    Emits a ``DeprecationWarning`` and then applies
    ``AdsFurthestPointSampling`` to the arguments.

    Args:
        point_xyz: Point tensor forwarded unchanged to the autograd function.
        num_points: Number of points to sample, forwarded unchanged.

    Returns:
        The result of ``AdsFurthestPointSampling.apply``.
    """
    warnings.warn(
        "`npu_furthest_point_sampling` will be deprecated in future. Please use `furthest_point_sampling` instead.",
        DeprecationWarning,
        # Attribute the warning to the caller's line so that it points at the
        # code that needs migrating (and is not hidden by default filters,
        # which key on the module the warning is attributed to).
        stacklevel=2,
    )
    return AdsFurthestPointSampling.apply(point_xyz, num_points)