#include <ATen/ATen.h>
#include <torch/library.h>
#include "pytorch_npu_helper.hpp"
namespace vision {
namespace ops {
namespace {
// Inline (on-stack) element capacity of the c10::SmallVector that holds the
// NMS output shape in nms_kernel_impl.
constexpr int SIZE = 8;
/// Runs the Ascend `NonMaxSuppressionV3` operator on `boxes`/`scores` and
/// returns the indices it produced as an int64 tensor.
///
/// @tparam scalar_t      floating type chosen by the dispatch in nms_kernel;
///                       it is not referenced in this body.
/// @param boxes          box tensor, forwarded unchanged to the operator
/// @param scores         score tensor, forwarded unchanged to the operator
/// @param iou_threshold  IoU threshold, passed to the operator as a 0-d float tensor
/// @return 1-D int64 tensor made of the first `k` operator outputs, where `k`
///         is the number of output entries greater than -1.
template <typename scalar_t>
at::Tensor nms_kernel_impl(
    const at::Tensor& boxes,
    const at::Tensor& scores,
    double iou_threshold)
{
    const auto float_opts = boxes.options().dtype(at::kFloat);
    const auto int_opts = boxes.options().dtype(at::kInt);
    const int64_t num_boxes = boxes.size(0);

    // Scalar operator parameters are passed as 0-d tensors on the boxes' device.
    at::Tensor iou_threshold_y = at::empty({}, float_opts).fill_(iou_threshold);

    // Function-local static: the CANN version is queried once per template
    // instantiation. NOTE(review): presumably lowest() disables score-based
    // filtering on CANN >= 8.1.RC1 while older releases need 0 -- confirm.
    static const bool is_fill_lowest = IsGteCANNVersion("8.1.RC1");
    const float score_floor =
        is_fill_lowest ? std::numeric_limits<float>::lowest() : 0.0f;
    at::Tensor scores_threshold_y = at::empty({}, float_opts).fill_(score_floor);

    // Allow the operator to keep up to every input box.
    at::Tensor max_outputsize_y = at::empty({}, int_opts).fill_(num_boxes);

    // Pre-filled with -1 so that slots the operator does not write can be
    // distinguished from real (non-negative) indices afterwards.
    c10::SmallVector<int64_t, SIZE> out_shape = {num_boxes};
    at::Tensor output = at::empty(out_shape, int_opts).fill_(-1);

    at_npu::native::OpCommand cmd;
    cmd.Name("NonMaxSuppressionV3")
        .Input(boxes)
        .Input(scores)
        .Input(max_outputsize_y)
        .Input(iou_threshold_y)
        .Input(scores_threshold_y)
        .Output(output)
        .Run();

    // Count entries still > -1; item() reads that count back to the host.
    // NOTE(review): slicing [0, kept) assumes the operator writes the kept
    // indices as a contiguous prefix of `output` -- confirm with the op spec.
    const at::Tensor written_mask = at::gt(output, -1).to(at::ScalarType::Int);
    const int64_t kept = at::sum(written_mask, at::ScalarType::Int).item().toLong();
    return output.slice(0, 0, kept).to(at::kLong);
}
/// Backend entry point for `torchvision::nms` on PrivateUse1 (NPU) devices.
///
/// Validates the input shapes (mirroring the checks in torchvision's CPU/CUDA
/// nms kernels), then dispatches on the boxes' floating dtype to
/// nms_kernel_impl.
///
/// @param dets           boxes; must be 2-D with 4 columns, i.e. shape (N, 4)
/// @param scores         scores; must be 1-D with N elements
/// @param iou_threshold  IoU threshold forwarded to nms_kernel_impl
/// @return 1-D int64 tensor of kept box indices on the device of `dets`;
///         empty when N == 0.
/// @throws c10::Error on a shape mismatch, or (from the dispatch macro) when
///         `dets` is not float/double/half.
at::Tensor nms_kernel(
    const at::Tensor& dets,
    const at::Tensor& scores,
    double iou_threshold)
{
    TORCH_CHECK(
        dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
    TORCH_CHECK(
        dets.size(1) == 4,
        "boxes should have 4 elements in dimension 1, got ",
        dets.size(1));
    TORCH_CHECK(
        scores.dim() == 1,
        "scores should be a 1d tensor, got ",
        scores.dim(),
        "D");
    TORCH_CHECK(
        dets.size(0) == scores.size(0),
        "boxes and scores should have same number of elements in ",
        "dimension 0, got ",
        dets.size(0),
        " and ",
        scores.size(0));

    const c10::OptionalDeviceGuard device_guard(device_of(dets));

    // Default result: an empty int64 index tensor, the same dtype the non-empty
    // path returns. It is used as-is when there are no boxes, so the NPU
    // operator is never launched on zero-sized inputs.
    auto result = at::empty({0}, dets.options().dtype(at::kLong));
    // The dispatch still runs for empty inputs so unsupported dtypes are
    // rejected consistently.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(dets.scalar_type(), "nms_kernel", [&] {
        if (dets.size(0) > 0) {
            result = nms_kernel_impl<scalar_t>(dets, scores, iou_threshold);
        }
    });
    return result;
}
}
// Registers nms_kernel as the PrivateUse1 backend implementation of the
// torchvision::nms operator. In this file PrivateUse1 is the Ascend NPU,
// driven through the at_npu OpCommand API.
// TORCH_SELECTIVE_NAME allows selective builds to drop this registration when
// the operator is not in the build's allow-list.
TORCH_LIBRARY_IMPL(torchvision, PrivateUse1, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}
}
}