05360171创建于 2022年3月18日历史提交
diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py
index 596cfa0b..4a015e93 100644
--- a/mmdet/core/anchor/anchor_generator.py
+++ b/mmdet/core/anchor/anchor_generator.py
@@ -205,6 +205,8 @@ class AnchorGenerator:
             tuple[torch.Tensor]: The mesh grids of x and y.
         """
         # use shape instead of len to keep tracing while exporting to onnx
+        x = x.to(dtype=torch.int32)
+        y = y.to(dtype=torch.int32)
         xx = x.repeat(y.shape[0])
         yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
         if row_major:
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
index 96753ef8..da0e03d2 100644
--- a/mmdet/models/dense_heads/anchor_head.py
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -10,6 +10,56 @@ from .base_dense_head import BaseDenseHead
 from .dense_test_mixins import BBoxTestMixin
 
 
+class BatchNMSOp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+        """
+        boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+        scores (torch.Tensor): scores in shape (batch, N, C).
+        return:
+            nmsed_boxes: (1, N, 4)
+            nmsed_scores: (1, N)
+            nmsed_classes: (1, N)
+            nmsed_num: (1,)
+        """
+
+        # Phony implementation for onnx export
+        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+        nmsed_scores = scores[:, :max_total_size, 0]
+        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+        nmsed_num = torch.Tensor([max_total_size])
+
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+        nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+            bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+            max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+    """
+    boxes (torch.Tensor): boxes in shape (N, 4).
+    scores (torch.Tensor): scores in shape (N, ).
+    """
+
+    if bboxes.dtype == torch.float32:
+        bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4).half()
+        scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1).half()
+    else:
+        bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4)
+        scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1)
+
+    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+        score_threshold, iou_threshold, max_size_per_class, max_total_size)
+    nmsed_boxes = nmsed_boxes.float()
+    nmsed_scores = nmsed_scores.float()
+    nmsed_classes = nmsed_classes.long()
+    dets = torch.cat((nmsed_boxes.reshape((bboxes.size(0), max_total_size, 4)), nmsed_scores.reshape((bboxes.size(0), max_total_size, 1))), -1)
+    labels = nmsed_classes.reshape((bboxes.size(0), max_total_size))
+    return dets, labels
+
 @HEADS.register_module()
 class AnchorHead(BaseDenseHead, BBoxTestMixin):
     """Anchor-based head (RPN, RetinaNet, SSD, etc.).
@@ -654,7 +704,10 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
             anchors = anchors.expand_as(bbox_pred)
             # Always keep topk op for dynamic input in onnx
             from mmdet.core.export import get_k_for_topk
-            nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+            nms_pre = bbox_pred.shape[1]
+            if nms_pre_tensor > 0 and bbox_pred.shape[1] > nms_pre_tensor:
+                nms_pre = nms_pre_tensor#get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+
             if nms_pre > 0:
                 # Get maximum scores for foreground classes.
                 if self.use_sigmoid_cls:
@@ -667,7 +720,8 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
 
                 _, topk_inds = max_scores.topk(nms_pre)
                 batch_inds = torch.arange(batch_size).view(
-                    -1, 1).expand_as(topk_inds)
+                    -1, 1).to(dtype=torch.int32).expand_as(topk_inds)
+                batch_inds = batch_inds.to(dtype=torch.int64)
                 anchors = anchors[batch_inds, topk_inds, :]
                 bbox_pred = bbox_pred[batch_inds, topk_inds, :]
                 scores = scores[batch_inds, topk_inds, :]
@@ -695,6 +749,8 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
             iou_threshold = cfg.nms.get('iou_threshold', 0.5)
             score_threshold = cfg.score_thr
             nms_pre = cfg.get('deploy_nms_pre', -1)
+            dets, labels = batch_nms_op(batch_mlvl_bboxes, batch_mlvl_scores, score_threshold, iou_threshold, cfg.max_per_img, cfg.max_per_img)
+            return dets, labels
             return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores,
                                           max_output_boxes_per_class,
                                           iou_threshold, score_threshold,
diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py
index 0e10000d..4b9ba3fd 100644
--- a/tools/deployment/pytorch2onnx.py
+++ b/tools/deployment/pytorch2onnx.py
@@ -75,7 +75,8 @@ def pytorch2onnx(model,
         do_constant_folding=True,
         verbose=show,
         opset_version=opset_version,
-        dynamic_axes=dynamic_axes)
+        dynamic_axes=dynamic_axes,
+        enable_onnx_checker=False)
 
     model.forward = origin_forward