#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
namespace {
constexpr uint32_t BLOCK_NUM = 8;
constexpr float DEFAULT_VALUE = -1.0f;
inline void npu_dynamic_scatter_check(const at::Tensor &feats, const at::Tensor &coors) {
TORCH_CHECK_NPU(feats);
TORCH_CHECK_NPU(coors);
TORCH_CHECK(coors.size(1) == 3, "npu_dynamic_scatter only support coors.size(1) == 3.");
TORCH_CHECK(feats.size(0) == coors.size(0), "npu_dynamic_scatter: feats.size(0) should equal coors.size(0).");
TORCH_CHECK(feats.size(1) <= 2048, "npu_dynamic_scatter: feats.size(1) should less than or equal to 2048.");
}
}
std::tuple<at::Tensor, at::Tensor> npu_dynamic_scatter(const at::Tensor &feats, const at::Tensor &coors,
const at::Tensor &prefix_sum_point_per_voxel, const at::Tensor &argsort_coor, int32_t num_voxels,
const char *reduce_type) {
npu_dynamic_scatter_check(feats, coors);
uint32_t point_num = static_cast<uint32_t>(feats.size(0));
uint32_t feats_dim = static_cast<uint32_t>(feats.size(1));
if (point_num == 0 || feats_dim == 0) {
return std::make_tuple(feats.clone().detach(), coors.new_empty({0}, at::kByte));
}
TORCH_CHECK(num_voxels > 0, "num_voxels must be positive, but got: ", num_voxels);
uint32_t mask_dim = (feats_dim + BLOCK_NUM - 1) / BLOCK_NUM;
at::Tensor voxel_feats = at::zeros({num_voxels, feats_dim}, feats.options());
at::Tensor compare_mask = at::empty({point_num, mask_dim}, feats.options().dtype(at::kByte));
EXEC_NPU_CMD_SYNC(
aclnnDynamicScatter, feats, prefix_sum_point_per_voxel, argsort_coor, reduce_type, voxel_feats, compare_mask);
return std::make_tuple(voxel_feats, compare_mask);
}
void npu_dynamic_scatter_grad(at::Tensor &grad_point_feats, const at::Tensor &grad_voxel_feats,
const at::Tensor &prefix_sum_point_per_voxel, const at::Tensor &argsort_coor, const at::Tensor &compare_mask,
const char *reduce_type) {
auto point_num = grad_point_feats.size(0);
auto feats_dim = grad_point_feats.size(1);
if (point_num > 0 && feats_dim > 0) {
TORCH_CHECK(grad_voxel_feats.size(0) > 0,
"grad_voxel_feats must have at least 1 voxel, but got: ", grad_voxel_feats.size(0));
EXEC_NPU_CMD(aclnnDynamicScatterGrad, grad_voxel_feats, prefix_sum_point_per_voxel, argsort_coor, compare_mask,
reduce_type, grad_point_feats);
}
}