#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
/// @brief Dispatches the aclnnRoiAlignRotatedV2 kernel for the forward pass.
///
/// Both tensor inputs are re-laid-out before the call: `input` is permuted with
/// axis order {0, 2, 3, 1} and `rois_map` with axis order {1, 0}; each result is
/// made contiguous.
///
/// @param input          Source tensor; must be 4-D for the permutation to succeed
///                       (presumably NCHW, becoming NHWC — confirm with caller).
/// @param rois_map       2-D ROI tensor; it is transposed before being handed to the kernel.
/// @param output         Passed as the final kernel argument (presumably the result
///                       buffer the kernel writes into — confirm against the op spec).
/// @param spatial_scale  Forwarded unchanged to the kernel.
/// @param sampling_ratio Forwarded unchanged to the kernel.
/// @param pooled_height  Forwarded unchanged to the kernel.
/// @param pooled_width   Forwarded unchanged to the kernel.
/// @param aligned        Forwarded unchanged to the kernel.
/// @param clockwise      Forwarded unchanged to the kernel.
void roi_align_rotated_v2_forward_npu(const at::Tensor& input, const at::Tensor& rois_map, at::Tensor& output,
    double spatial_scale, int32_t sampling_ratio, int32_t pooled_height, int32_t pooled_width, bool aligned,
    bool clockwise)
{
    // Move axis 1 of the input to the last position and materialize the new layout.
    const at::Tensor input_permuted = input.permute({0, 2, 3, 1});
    const at::Tensor kernel_features = input_permuted.contiguous();

    // Swap the two ROI axes and materialize the result as well.
    const at::Tensor rois_swapped = rois_map.permute({1, 0});
    const at::Tensor kernel_rois = rois_swapped.contiguous();

    EXEC_NPU_CMD(aclnnRoiAlignRotatedV2, kernel_features, kernel_rois, spatial_scale, sampling_ratio, pooled_height,
        pooled_width, aligned, clockwise, output);
}
/// @brief Dispatches the aclnnRoiAlignRotatedGradV2 kernel and returns the input gradient.
///
/// A zero-filled gradient buffer is allocated with the same tensor options as `input`
/// and with `input`'s sizes reordered as (size(0), size(2), size(3), size(1)). The
/// buffer is passed to the kernel as its final argument and then returned.
///
/// @param input          Forward-pass input; only its sizes, options and dtype are read here
///                       before it is forwarded to the kernel. Must have at least 4 dimensions.
/// @param rois           Forwarded unchanged to the kernel.
/// @param grad_output    Forwarded unchanged to the kernel.
/// @param pooled_height  Forwarded unchanged to the kernel.
/// @param pooled_width   Forwarded unchanged to the kernel.
/// @param spatial_scale  Forwarded unchanged to the kernel.
/// @param sampling_ratio Forwarded unchanged to the kernel.
/// @param aligned        Forwarded unchanged to the kernel.
/// @param clockwise      Forwarded unchanged to the kernel.
/// @return The gradient buffer, converted to `input`'s scalar type. Note that its shape
///         uses the reordered axes above, not `input`'s own axis order (presumably the
///         caller permutes it back — confirm).
at::Tensor npu_roi_align_rotated_grad_v2(const at::Tensor& input, const at::Tensor& rois, const at::Tensor& grad_output,
    int32_t pooled_height, int32_t pooled_width, double spatial_scale, int32_t sampling_ratio, bool aligned,
    bool clockwise)
{
    const auto input_dtype = input.scalar_type();

    // Read the four extents individually; axis 1 is placed last in the gradient shape.
    const int64_t dim0 = input.size(0);
    const int64_t dim1 = input.size(1);
    const int64_t dim2 = input.size(2);
    const int64_t dim3 = input.size(3);
    c10::SmallVector<int64_t, SIZE> reordered_shape = {dim0, dim2, dim3, dim1};

    // Start from zeros so the buffer has a defined value before the kernel runs.
    at::Tensor grad_buffer = at::zeros(reordered_shape, input.options());

    EXEC_NPU_CMD(aclnnRoiAlignRotatedGradV2, input, rois, grad_output, pooled_height, pooled_width, spatial_scale,
        sampling_ratio, aligned, clockwise, grad_buffer);

    // NOTE(review): the buffer was created with input.options(), so this conversion
    // looks like a no-op unless the kernel alters the dtype — confirm before removing.
    return grad_buffer.to(input_dtype);
}