#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
// Required rank (number of dimensions) of each input tensor of cylinder_query.
constexpr uint32_t NEW_XYZ_DIM{3};
constexpr uint32_t XYZ_DIM{3};
constexpr uint32_t ROT_DIM{3};

// Required extent of the last dimension of the inputs.
constexpr uint32_t POINT_DIM{3};  // last dimension of new_xyz and of xyz
constexpr uint32_t ROT_SIZE{9};   // last dimension of rot

// Dimension positions that cylinder_query reads from its inputs.
constexpr uint32_t B_INDEX{0};  // batch dimension, compared across all three inputs
constexpr uint32_t N_INDEX{1};  // dimension 1 of xyz (number of points in xyz)
constexpr uint32_t M_INDEX{1};  // dimension 1 of new_xyz and of rot
constexpr uint32_t INPUT_XYZ_DIM_INDEX{2};      // last dimension of xyz
constexpr uint32_t INPUT_NEW_XYZ_DIM_INDEX{2};  // last dimension of new_xyz
constexpr uint32_t INPUT_ROT_DIM_INDEX{2};      // last dimension of rot
/// @brief Validates the inputs and launches the `aclnnCylinderQuery` NPU kernel.
///
/// @param radius   Forwarded unchanged to the kernel (not validated here).
/// @param hmin     Forwarded to the kernel; must be strictly less than @p hmax.
/// @param hmax     Forwarded to the kernel; must be strictly greater than @p hmin.
/// @param nsample  Must satisfy 0 < nsample <= xyz.size(1).
/// @param new_xyz  3-D NPU tensor of shape (B, M, 3).
/// @param xyz      3-D NPU tensor of shape (B, N, 3).
/// @param rot      3-D NPU tensor of shape (B, M, 9)
///                 (presumably flattened 3x3 rotation matrices -- TODO confirm).
/// @return A tensor of shape (B, M, N) allocated with `new_xyz.options()` and
///         filled by the kernel.
/// @throws c10::Error (via TORCH_CHECK) when any shape or argument constraint is violated.
at::Tensor cylinder_query(double radius, double hmin, double hmax, int nsample, const at::Tensor &new_xyz,
                          const at::Tensor &xyz, const at::Tensor &rot) {
    TORCH_CHECK_NPU(new_xyz);
    TORCH_CHECK_NPU(xyz);
    TORCH_CHECK_NPU(rot);

    // Rank checks come first so that the size(...) lookups below are in range.
    TORCH_CHECK(new_xyz.dim() == NEW_XYZ_DIM, "new_xyz must be a 3D Tensor, but got: ", new_xyz.dim());
    TORCH_CHECK(xyz.dim() == XYZ_DIM, "xyz must be a 3D Tensor, but got: ", xyz.dim());
    TORCH_CHECK(rot.dim() == ROT_DIM, "rot must be a 3D Tensor, but got: ", rot.dim());

    TORCH_CHECK(rot.size(B_INDEX) == new_xyz.size(B_INDEX), "The batch sizes of rot and new_xyz must be equal.");
    TORCH_CHECK(rot.size(B_INDEX) == xyz.size(B_INDEX), "The batch sizes of rot and xyz must be equal.");

    TORCH_CHECK(new_xyz.size(INPUT_NEW_XYZ_DIM_INDEX) == POINT_DIM,
                "new_xyz Coordinates should be represented by 3 numbers, but got: ",
                new_xyz.size(INPUT_NEW_XYZ_DIM_INDEX));
    TORCH_CHECK(xyz.size(INPUT_XYZ_DIM_INDEX) == POINT_DIM,
                "xyz Coordinates should be represented by 3 numbers, but got: ", xyz.size(INPUT_XYZ_DIM_INDEX));
    // Report the offending extent of rot itself (previously xyz's extent was printed).
    TORCH_CHECK(rot.size(INPUT_ROT_DIM_INDEX) == ROT_SIZE,
                "The size of the last dimension in rot should be 9, but got: ", rot.size(INPUT_ROT_DIM_INDEX));
    TORCH_CHECK(rot.size(M_INDEX) == new_xyz.size(M_INDEX), "The number of rot and new_xyz must be equal.");

    TORCH_CHECK(hmin < hmax, "The value of hmin needs to be less than the value of hmax.");
    TORCH_CHECK(nsample > 0, "The value of nsample should be greater than 0.");
    // The message now matches the condition: nsample may not exceed the point count of xyz.
    TORCH_CHECK(nsample <= xyz.size(N_INDEX),
                "The value of nsample should not be greater than the number of points in the tensor xyz.");

    uint32_t B = static_cast<uint32_t>(new_xyz.size(B_INDEX));
    uint32_t N = static_cast<uint32_t>(xyz.size(N_INDEX));
    uint32_t M = static_cast<uint32_t>(new_xyz.size(M_INDEX));

    // Float sequence 0, 1, ..., N-1 built with new_xyz's tensor options; passed to the kernel as an input.
    at::Tensor origin_index = at::arange(0, xyz.size(N_INDEX), new_xyz.options().dtype(at::kFloat));
    at::Tensor out = at::empty({B, M, N}, new_xyz.options());

    EXEC_NPU_CMD(aclnnCylinderQuery, new_xyz, xyz, rot, origin_index, B, N, M, radius, hmin, hmax, nsample, out);
    return out;
}