#include "op_plugin/AclOpsInterface.h"
#include "op_plugin/utils/OpAdapter.h"
namespace acl_op {
using npu_preparation = at_npu::native::OpPreparation;
namespace {
/**
 * Dispatch the "NMSWithMask" operator onto caller-provided output tensors.
 *
 * This helper does no allocation and no argument validation of its own: the
 * three output tensors are registered with the operator exactly as received.
 *
 * @param boxes          registered as the first operator output
 * @param idx            registered as the second operator output
 * @param mask           registered as the third operator output
 * @param input          the single operator input
 * @param iou_threshold  scalar converted to float and passed as the
 *                       "iou_threshold" attribute
 * @return references to boxes, idx and mask, in that order
 */
std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> nms_with_mask_npu_nocheck(
    at::Tensor& boxes,
    at::Tensor& idx,
    at::Tensor& mask,
    const at::Tensor& input,
    at::Scalar iou_threshold)
{
    // The attribute is handed over as a plain float, so unwrap the scalar first.
    const float threshold = op_plugin::utils::get_scalar_float_value(iou_threshold);

    at_npu::native::OpCommand nms_cmd;
    nms_cmd.Name("NMSWithMask");
    nms_cmd.Input(input);
    // Outputs are registered in the order boxes, idx, mask.
    nms_cmd.Output(boxes);
    nms_cmd.Output(idx);
    nms_cmd.Output(mask);
    nms_cmd.Attr("iou_threshold", threshold);
    nms_cmd.Run();

    return std::tuple<at::Tensor&, at::Tensor&, at::Tensor&>(boxes, idx, mask);
}
}
/**
 * Allocate the three result tensors for "NMSWithMask" and run the operator.
 *
 * Output shapes come from op_infer::nms_with_mask_npu_output_size(input).
 * The boxes tensor is allocated from `input` with the first shape; idx and
 * mask are allocated with input's options overridden to at::kInt and
 * at::kByte respectively.
 *
 * @param input          tensor forwarded as the operator input
 * @param iou_threshold  scalar forwarded as the operator's IoU threshold
 * @return tuple of (boxes, idx, mask)
 */
std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_nms_with_mask(
    const at::Tensor& input,
    const at::Scalar& iou_threshold)
{
    const auto shapes = op_infer::nms_with_mask_npu_output_size(input);
    const auto& boxes_shape = std::get<0>(shapes);
    const auto& idx_shape = std::get<1>(shapes);
    const auto& mask_shape = std::get<2>(shapes);

    // idx and mask share input's options apart from the dtype override.
    const auto base_options = input.options();

    at::Tensor boxes = npu_preparation::apply_tensor(input, boxes_shape);
    at::Tensor idx = npu_preparation::apply_tensor(idx_shape, base_options.dtype(at::kInt), input);
    at::Tensor mask = npu_preparation::apply_tensor(mask_shape, base_options.dtype(at::kByte), input);

    // The helper fills the tensors in place; its reference tuple is not needed here.
    nms_with_mask_npu_nocheck(boxes, idx, mask, input, iou_threshold);

    return std::make_tuple(boxes, idx, mask);
}
}