#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
/**
 * Compute the anchors heading tensor on the NPU via the aclnnCalAnchorsHeading kernel.
 *
 * Validates the inputs. Allocates an output of shape
 * (anchors.size(0), anchors.size(1), anchors.size(2)) with the same TensorOptions
 * (dtype/device) as `anchors`. Launches the kernel to fill that output.
 *
 * @param anchors    4-D tensor whose last dimension must be of size 2.
 * @param origin_pos 2-D tensor forwarded to the kernel.
 *                   NOTE(review): its expected shape relative to `anchors` (presumably
 *                   (anchors.size(0), 2)) is not validated here — confirm against the kernel.
 * @return Newly allocated tensor of shape (anchors.size(0), anchors.size(1), anchors.size(2)).
 */
at::Tensor cal_anchors_heading(const at::Tensor& anchors, const at::Tensor& origin_pos)
{
    // Device checks provided by OpApiCommon.h (presumably: tensor must live on an NPU).
    TORCH_CHECK_NPU(anchors);
    TORCH_CHECK_NPU(origin_pos);
    TORCH_CHECK(anchors.dim() == 4, "anchors must be a 4D Tensor, but got: ", anchors.dim());
    TORCH_CHECK(origin_pos.dim() == 2, "origin_pos must be a 2D Tensor, but got: ", origin_pos.dim());
    TORCH_CHECK(anchors.size(3) == 2, "the last dim of anchors must be 2, but got: ", anchors.size(3));

    // Keep the sizes as int64_t (ATen's native size type). Narrowing them to uint32_t
    // would silently truncate very large dimensions and mis-shape the output allocation.
    const int64_t batch_size = anchors.size(0);
    const int64_t anchors_num = anchors.size(1);
    const int64_t seq_length = anchors.size(2);

    at::Tensor heading = at::empty({batch_size, anchors_num, seq_length}, anchors.options());
    EXEC_NPU_CMD(aclnnCalAnchorsHeading, anchors, origin_pos, heading);
    return heading;
}