@@ -205,6 +205,8 @@ class AnchorGenerator:
tuple[torch.Tensor]: The mesh grids of x and y.
"""
# use shape instead of len to keep tracing while exporting to onnx
+ x = x.to(dtype=torch.int32)
+ y = y.to(dtype=torch.int32)
xx = x.repeat(y.shape[0])
yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
if row_major:
@@ -10,6 +10,56 @@ from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
+class BatchNMSOp(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+ scores (torch.Tensor): scores in shape (batch, N, C).
+ return:
+ nmsed_boxes: (1, N, 4)
+ nmsed_scores: (1, N)
+ nmsed_classes: (1, N)
+ nmsed_num: (1,)
+ """
+
+ # Phony implementation for onnx export
+ nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+ nmsed_scores = scores[:, :max_total_size, 0]
+ nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+ nmsed_num = torch.Tensor([max_total_size])
+
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+ @staticmethod
+ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (N, 4).
+ scores (torch.Tensor): scores in shape (N, ).
+ """
+
+ if bboxes.dtype == torch.float32:
+ bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4).half()
+ scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1).half()
+ else:
+ bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4)
+ scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1)
+
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+ score_threshold, iou_threshold, max_size_per_class, max_total_size)
+ nmsed_boxes = nmsed_boxes.float()
+ nmsed_scores = nmsed_scores.float()
+ nmsed_classes = nmsed_classes.long()
+ dets = torch.cat((nmsed_boxes.reshape((bboxes.size(0), max_total_size, 4)), nmsed_scores.reshape((bboxes.size(0), max_total_size, 1))), -1)
+ labels = nmsed_classes.reshape((bboxes.size(0), max_total_size))
+ return dets, labels
+
@HEADS.register_module()
class AnchorHead(BaseDenseHead, BBoxTestMixin):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
@@ -654,7 +704,10 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
anchors = anchors.expand_as(bbox_pred)
# Always keep topk op for dynamic input in onnx
from mmdet.core.export import get_k_for_topk
- nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+ nms_pre = bbox_pred.shape[1]
+ if nms_pre_tensor > 0 and bbox_pred.shape[1] > nms_pre_tensor:
+ nms_pre = nms_pre_tensor#get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+
if nms_pre > 0:
# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
@@ -667,7 +720,8 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
_, topk_inds = max_scores.topk(nms_pre)
batch_inds = torch.arange(batch_size).view(
- -1, 1).expand_as(topk_inds)
+ -1, 1).to(dtype=torch.int32).expand_as(topk_inds)
+ batch_inds = batch_inds.to(dtype=torch.int64)
anchors = anchors[batch_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
@@ -695,6 +749,8 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
iou_threshold = cfg.nms.get('iou_threshold', 0.5)
score_threshold = cfg.score_thr
nms_pre = cfg.get('deploy_nms_pre', -1)
+ dets, labels = batch_nms_op(batch_mlvl_bboxes, batch_mlvl_scores, score_threshold, iou_threshold, cfg.max_per_img, cfg.max_per_img)
+ return dets, labels
return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold, score_threshold,
@@ -75,7 +75,8 @@ def pytorch2onnx(model,
do_constant_folding=True,
verbose=show,
opset_version=opset_version,
- dynamic_axes=dynamic_axes)
+ dynamic_axes=dynamic_axes,
+ enable_onnx_checker=False)
model.forward = origin_forward