#include "pytorch_npu_helper.hpp"
#include <cfloat>
#include <string>
#include <ATen/ATen.h>
#include <torch/library.h>
namespace vision {
namespace ops {
namespace {
/// Runs the `RoiPoolingWithArgMax` operator through `at_npu::native::OpCommand`.
///
/// @param input          Feature tensor; only `input.size(1)` is inspected here.
/// @param spatial_scale  Passed as both `spatial_scale_h` and `spatial_scale_w`.
/// @param pooled_height  Passed as the `pooled_h` attribute.
/// @param pooled_width   Passed as the `pooled_w` attribute.
/// @param rois           Region-of-interest tensor handed to the operator.
/// @param output         Tensor the operator writes its first output to.
/// @param argmax         Tensor the operator writes its second output to.
///
/// The template parameter `T` and the `channels`, `height`, `width` and
/// `num_rois` parameters are not used by this implementation; they are kept
/// so existing call sites keep compiling.
template <typename T>
void roi_pool_forward_kernel_impl(
    const at::Tensor& input,
    const float spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    const at::Tensor& rois,
    int num_rois,
    at::Tensor& output,
    at::Tensor& argmax)
{
    // The operator attributes are passed as 64-bit integers.
    const int64_t pooled_height_64 = pooled_height;
    const int64_t pooled_width_64 = pooled_width;
    const int64_t pooled_channel = 1;

    // Created with at::empty, so its contents are not initialised here.
    // NOTE(review): presumably the operator ignores the value - confirm.
    at::Tensor roi_actual_num = at::empty({}, rois.options().dtype(at::kInt));

    // Build the command once; both code paths share every input, output and
    // attribute except for the optional engine exclusion below.
    at_npu::native::OpCommand cmd;
    cmd.Name("RoiPoolingWithArgMax")
        .Input(input)
        .Input(rois)
        .Input(roi_actual_num)
        .Output(output)
        .Output(argmax)
        .Attr("pooled_h", pooled_height_64)
        .Attr("pooled_w", pooled_width_64)
        .Attr("spatial_scale_h", spatial_scale)
        .Attr("spatial_scale_w", spatial_scale)
        .Attr("pool_channel", pooled_channel);

    // When input.size(1) is not a multiple of 16 the "AiCore" engine is
    // excluded. NOTE(review): presumably an alignment requirement of the
    // AiCore implementation - confirm against the operator documentation.
    if (input.size(1) % 16 != 0) {
        cmd.Attr("_exclude_engines", std::string("AiCore"));
    }
    cmd.Run();
}
/// Runs the `RoiPoolingGradWithArgMax` operator through
/// `at_npu::native::OpCommand`.
///
/// Operator inputs, in order: `grad`, `input`, `rois`, a scalar int tensor
/// and `argmax`. The single operator output is written to `output`.
/// `spatial_scale` is supplied as both the `_h` and `_w` scale attributes.
///
/// The template parameter `T` and the `num_rois`, `channels`, `height` and
/// `width` parameters are not used by this implementation.
template <typename T>
void roi_pool_backward_kernel_impl(
    const at::Tensor& grad,
    const at::Tensor& input,
    at::Tensor& output,
    const at::Tensor& argmax,
    const float spatial_scale,
    int num_rois,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    const at::Tensor& rois)
{
    // Attribute values are widened to 64-bit before being attached.
    int64_t pooled_h = static_cast<int64_t>(pooled_height);
    int64_t pooled_w = static_cast<int64_t>(pooled_width);
    int64_t pool_channel = 1;

    // Scalar int tensor created with at::empty (contents not initialised here).
    // NOTE(review): presumably ignored by the operator - confirm.
    at::Tensor actual_roi_count =
        at::empty({}, rois.options().dtype(at::kInt));

    at_npu::native::OpCommand op;
    op.Name("RoiPoolingGradWithArgMax");
    op.Input(grad)
        .Input(input)
        .Input(rois)
        .Input(actual_roi_count)
        .Input(argmax);
    op.Output(output);
    op.Attr("pooled_h", pooled_h)
        .Attr("pooled_w", pooled_w)
        .Attr("spatial_scale_h", spatial_scale)
        .Attr("spatial_scale_w", spatial_scale)
        .Attr("pool_channel", pool_channel);
    op.Run();
}
/// Forward entry point registered for `torchvision::roi_pool`.
///
/// Validates the arguments, allocates the result tensors and dispatches to
/// `roi_pool_forward_kernel_impl`.
///
/// @param input          Feature tensor; must be non-empty with more than
///                       three dimensions. Sizes 1, 2 and 3 are read as
///                       channels, height and width.
/// @param rois           Non-empty tensor of shape `[K, 5]` with the same
///                       scalar type as `input`.
/// @param spatial_scale  Forwarded to the implementation.
/// @param pooled_height  Height of each pooled output.
/// @param pooled_width   Width of each pooled output.
/// @return `(output, argmax)`, both shaped
///         `[K, channels, pooled_height, pooled_width]`; `argmax` has dtype
///         `at::kInt`. Both are zero-filled when `output` has no elements.
/// @throws c10::Error when any of the checks above fails.
std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width)
{
    at::TensorArg input_t{input, "input", 1};
    at::TensorArg rois_t{rois, "rois", 2};
    at::CheckedFrom c = "roi_pool_forward_kernel";
    at::checkAllSameType(c, {input_t, rois_t});

    TORCH_CHECK(input.numel() != 0 && rois.numel() != 0,
        "Expected input and rois to be non-empty tensors, but got empty input and/or rois.");
    TORCH_CHECK(input.dim() > 3, "tensor input's dimension must be greater than 3, "
        "but got Tensor of dimension ", input.dim());
    // Same layout requirement (and message) that the backward kernel enforces,
    // so malformed rois are rejected before they reach the operator.
    TORCH_CHECK(
        rois.dim() == 2 && rois.size(1) == 5,
        "Tensor rois should have shape as Tensor[K, 5]");

    int num_rois = rois.size(0);
    int channels = input.size(1);
    int height = input.size(2);
    int width = input.size(3);

    at::Tensor output = at::zeros(
        {num_rois, channels, pooled_height, pooled_width}, input.options());
    at::Tensor argmax = at::zeros(
        {num_rois, channels, pooled_height, pooled_width},
        input.options().dtype(at::kInt));

    // Nothing to compute (e.g. a zero pooled size): return the empty results.
    if (output.numel() == 0) {
        return std::make_tuple(output, argmax);
    }

    auto input_ = input.contiguous();
    auto rois_ = rois.contiguous();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "roi_pool_forward_kernel", [&] {
            roi_pool_forward_kernel_impl<scalar_t>(
                input_,
                spatial_scale,
                channels,
                height,
                width,
                pooled_height,
                pooled_width,
                rois_,
                num_rois,
                output,
                argmax);
        });
    return std::make_tuple(output, argmax);
}
/// Backward entry point registered for `torchvision::_roi_pool_backward`.
///
/// @param grad           Gradient with respect to the pooled output.
/// @param rois           Tensor whose second dimension must have size 5; must
///                       have the same scalar type as `grad`.
/// @param argmax         Argmax tensor forwarded to the implementation.
/// @param spatial_scale  Forwarded to the implementation.
/// @param pooled_height  Forwarded to the implementation.
/// @param pooled_width   Forwarded to the implementation.
/// @param batch_size, channels, height, width
///                       Shape of the returned gradient tensor.
/// @return A tensor of shape `[batch_size, channels, height, width]` with the
///         options of `grad`; all zeros when `grad` has no elements.
/// @throws c10::Error when the rois shape or scalar-type checks fail.
at::Tensor roi_pool_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width)
{
    TORCH_CHECK(
        rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");

    at::TensorArg grad_t{grad, "grad", 1};
    at::TensorArg rois_t{rois, "rois", 2};
    at::CheckedFrom c = "roi_pool_backward_kernel";
    at::checkAllSameType(c, {grad_t, rois_t});

    auto num_rois = rois.size(0);

    // Tensor that receives the operator result and is returned to the caller.
    at::Tensor output =
        at::zeros({batch_size, channels, height, width}, grad.options());
    if (grad.numel() == 0) {
        return output;
    }

    // Zero tensor passed as the implementation's `input` operand. It is only
    // allocated once we know the operator will actually run.
    // NOTE(review): presumably only its shape/dtype matter to the operator -
    // confirm against the operator documentation.
    at::Tensor input_like =
        at::zeros({batch_size, channels, height, width}, grad.options());

    auto rois_ = rois.contiguous();

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        grad.scalar_type(), "roi_pool_backward_kernel", [&] {
            roi_pool_backward_kernel_impl<scalar_t>(
                grad,
                input_like,
                output,
                argmax,
                spatial_scale,
                num_rois,
                channels,
                height,
                width,
                pooled_height,
                pooled_width,
                rois_);
        });
    return output;
}
}
// Registers the roi_pool kernels defined in this file as the `torchvision`
// library implementations for the PrivateUse1 dispatch key.
TORCH_LIBRARY_IMPL(torchvision, PrivateUse1, m)
{
    // Forward op.
    m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel));

    // Backward op.
    m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
}
}
}