#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
// Builds and runs the "ROIAlignGrad" operator, writing into `result`.
//
// Operator bindings:
//   - `self`   -> input named "ydiff"
//   - `rois`   -> second input (no explicit name is given)
//   - `result` -> output named "xdiff"
// This function does not allocate or resize `result`; it is bound as-is as the
// operator output and returned.
//
// @param result        destination tensor, bound as the "xdiff" output.
// @param self          tensor bound as the "ydiff" input.
// @param rois          tensor bound as the second input (layout not validated here).
// @param xdiff_shape   forwarded verbatim as the "xdiff_shape" attribute.
// @param pooled_width  forwarded as the "pooled_width" attribute.
// @param pooled_height forwarded as the "pooled_height" attribute.
// @param spatial_scale narrowed to float for the "spatial_scale" attribute.
// @param sample_num    forwarded as the "sample_num" attribute.
// @param roi_end_mode  forwarded as the "roi_end_mode" attribute only when it
//                      holds a value; otherwise the attribute is omitted.
// @return `result`.
at::Tensor& roi_align_backward_npu_nocheck(
    at::Tensor& result,
    const at::Tensor& self,
    const at::Tensor& rois,
    at::IntArrayRef xdiff_shape,
    int64_t pooled_width,
    int64_t pooled_height,
    double spatial_scale,
    int64_t sample_num,
    c10::optional<int64_t> roi_end_mode)
{
    // The operator attribute is passed as single precision.
    const float scale_f32 = static_cast<float>(spatial_scale);

    at_npu::native::OpCommand op_cmd;
    op_cmd.Name("ROIAlignGrad");

    // Inputs/outputs are registered in the same order the operator expects.
    op_cmd.Input(self, "ydiff");
    op_cmd.Input(rois);
    op_cmd.Output(result, "xdiff");

    op_cmd.Attr("xdiff_shape", xdiff_shape);
    op_cmd.Attr("spatial_scale", scale_f32);
    op_cmd.Attr("pooled_height", pooled_height);
    op_cmd.Attr("pooled_width", pooled_width);
    op_cmd.Attr("sample_num", sample_num);

    // Optional attribute: leave it unset so the operator keeps its own default.
    if (roi_end_mode) {
        op_cmd.Attr("roi_end_mode", *roi_end_mode);
    }

    op_cmd.Run();
    return result;
}
}
// NPU entry point that computes the ROIAlignGrad result for `self` and `rois`.
//
// The output tensor is allocated from `self` via apply_tensor_with_format with
// shape `xdiff_shape` and ACL_FORMAT_NC1HWC0. If `self` has any zero-sized
// dimension, the operator is not launched and the output is zero-filled instead.
//
// @param self          tensor passed to the operator as "ydiff".
// @param rois          tensor passed as the operator's second input.
// @param xdiff_shape   shape of the returned tensor; also forwarded as an attribute.
// @param pooled_width  forwarded to the "pooled_width" attribute.
// @param pooled_height forwarded to the "pooled_height" attribute.
// @param spatial_scale forwarded (as float) to the "spatial_scale" attribute.
// @param sample_num    forwarded to the "sample_num" attribute.
// @param roi_end_mode  forwarded to "roi_end_mode" only when set.
// @return newly allocated tensor of shape `xdiff_shape`.
at::Tensor npu_roi_alignbk(
    const at::Tensor& self,
    const at::Tensor& rois,
    at::IntArrayRef xdiff_shape,
    int64_t pooled_width,
    int64_t pooled_height,
    double spatial_scale,
    int64_t sample_num,
    c10::optional<int64_t> roi_end_mode)
{
    at::Tensor result =
        npu_preparation::apply_tensor_with_format(self, xdiff_shape, ACL_FORMAT_NC1HWC0);

    // numel() is zero exactly when some dimension of `self` has size 0 (a 0-dim
    // tensor has numel() == 1). Skip the kernel and return zeros in that case.
    if (self.numel() == 0) {
        acl_op::fill_(result, 0);
        return result;
    }

    roi_align_backward_npu_nocheck(
        result, self, rois, xdiff_shape,
        pooled_width, pooled_height, spatial_scale, sample_num, roi_end_mode);
    return result;
}
}