#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
// Element count of the `actual_num_neighbors` buffer allocated for aclnnRadius.
// NOTE(review): the name suggests an alignment granularity required by the kernel — TODO confirm.
constexpr uint32_t ALIGN_NUM = 8;
// Exclusive upper bound on the coordinate dimensionality of x / y (i.e. at most 8 coordinates).
constexpr uint32_t NUM_COORDINATES = 9;
// Required rank of x / y; also the leading dimension of the two index tensors handed to the kernel.
constexpr uint32_t NUM_DIM = 2;
/**
 * Radius neighbour search on NPU, dispatched to the `aclnnRadius` kernel.
 *
 * NOTE(review): the kernel's exact semantics are not visible here. Presumably, for each
 * point of `y` it collects up to `max_num_neighbors` points of `x` within distance `r`,
 * with `ptr_x` / `ptr_y` holding per-batch offsets (batch_size + 1 entries) — TODO confirm.
 *
 * @param x                 NPU tensor of rank 2; size(1) must be < NUM_COORDINATES.
 * @param y                 NPU tensor of rank 2; size(1) must be < NUM_COORDINATES.
 * @param ptr_x             NPU tensor of rank 1 with at least one element.
 * @param ptr_y             NPU tensor of rank 1 with at least one element.
 * @param r                 Search radius, forwarded unchanged to the kernel.
 * @param max_num_neighbors Per-query neighbour cap; must be non-negative.
 * @return (out_final, actual_num_neighbors): an int32 tensor of shape
 *         [2, y.size(0) * max_num_neighbors] and an int32 tensor of ALIGN_NUM elements,
 *         both zero-initialised on ptr_x's device and then written by the kernel.
 *         NOTE(review): presumably actual_num_neighbors reports how many columns of
 *         out_final are valid — TODO confirm against the kernel.
 * @throws c10::Error (via TORCH_CHECK) when any of the checks below fails.
 */
std::tuple<at::Tensor, at::Tensor> radius(
    at::Tensor &x, at::Tensor &y, at::Tensor &ptr_x, at::Tensor &ptr_y, double r, int max_num_neighbors) {
    TORCH_CHECK_NPU(x);
    TORCH_CHECK_NPU(y);
    TORCH_CHECK_NPU(ptr_x);
    TORCH_CHECK_NPU(ptr_y);

    // Rank checks must come first: size(1) / size(0) on a tensor of too-low rank raises an
    // opaque "dimension out of range" error instead of the descriptive messages below.
    TORCH_CHECK(x.dim() == NUM_DIM, "x must be a 2D Tensor, but got: ", x.dim());
    TORCH_CHECK(y.dim() == NUM_DIM, "y must be a 2D Tensor, but got: ", y.dim());
    TORCH_CHECK(ptr_x.dim() == 1, "ptr_x must be a 1D Tensor, but got: ", ptr_x.dim());
    TORCH_CHECK(ptr_y.dim() == 1, "ptr_y must be a 1D Tensor, but got: ", ptr_y.dim());

    TORCH_CHECK(x.size(1) < NUM_COORDINATES, "x must be a coordinate which is smaller than 9D, but got: ", x.size(1));
    TORCH_CHECK(y.size(1) < NUM_COORDINATES, "y must be a coordinate which is smaller than 9D, but got: ", y.size(1));
    TORCH_CHECK(ptr_x.size(0) >= 1, "ptr_x must have at least 1 element (batch_size + 1), but got: ", ptr_x.size(0));
    TORCH_CHECK(ptr_y.size(0) >= 1, "ptr_y must have at least 1 element (batch_size + 1), but got: ", ptr_y.size(0));
    TORCH_CHECK(max_num_neighbors >= 0, "max_num_neighbors must be non-negative, but got: ", max_num_neighbors);

    // Hand coordinates to the kernel as contiguous [dim, num_points] buffers.
    at::Tensor x_trans = x.transpose(0, 1).contiguous();
    at::Tensor y_trans = y.transpose(0, 1).contiguous();

    // Keep the product in int64_t: y.size(0) * max_num_neighbors can exceed INT_MAX.
    const int64_t out_dim = y.size(0) * static_cast<int64_t>(max_num_neighbors);
    const auto index_options = ptr_x.options().dtype(at::kInt);

    // NOTE(review): out_temp is passed to the kernel but not returned; presumably scratch space.
    at::Tensor out_temp = at::zeros({NUM_DIM, out_dim}, index_options);
    at::Tensor out_final = at::zeros({NUM_DIM, out_dim}, index_options);
    at::Tensor actual_num_neighbors = at::zeros({ALIGN_NUM}, index_options);

    EXEC_NPU_CMD(
        aclnnRadius, x_trans, y_trans, ptr_x, ptr_y, r, max_num_neighbors, out_temp, out_final, actual_num_neighbors);
    return std::make_tuple(out_final, actual_num_neighbors);
}