05360171创建于 2022年3月18日历史提交
diff --git a/mmdet/models/dense_heads/fcos_head.py b/mmdet/models/dense_heads/fcos_head.py
index 323d154b..cfa5e5a4 100644
--- a/mmdet/models/dense_heads/fcos_head.py
+++ b/mmdet/models/dense_heads/fcos_head.py
@@ -10,6 +10,54 @@ from .anchor_free_head import AnchorFreeHead
 
 INF = 1e8
 
+class BatchNMSOp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+        """
+        boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+        scores (torch.Tensor): scores in shape (batch, N, C).
+        return:
+            nmsed_boxes: (1, N, 4)
+            nmsed_scores: (1, N)
+            nmsed_classes: (1, N)
+            nmsed_num: (1,)
+        """
+
+        # Phony implementation for onnx export
+        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+        nmsed_scores = scores[:, :max_total_size, 0]
+        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+        nmsed_num = torch.Tensor([max_total_size])
+
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+        nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+            bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+            max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+    """
+    boxes (torch.Tensor): boxes in shape (N, 4).
+    scores (torch.Tensor): scores in shape (N, ).
+    """
+
+    if bboxes.dtype == torch.float32:
+        bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4).half()
+        scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1).half()
+    else:
+        bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4)
+        scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1)
+    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+        score_threshold, iou_threshold, max_size_per_class, max_total_size)
+    nmsed_boxes = nmsed_boxes.float()
+    nmsed_scores = nmsed_scores.float()
+    nmsed_classes = nmsed_classes.long()
+    dets = torch.cat((nmsed_boxes.reshape((bboxes.size(0), max_total_size, 4)), nmsed_scores.reshape((scores.size(0), max_total_size, 1))), -1)
+    labels = nmsed_classes.reshape((bboxes.size(0), max_total_size))
+    return dets, labels
 
 @HEADS.register_module()
 class FCOSHead(AnchorFreeHead):
@@ -387,19 +435,23 @@ class FCOSHead(AnchorFreeHead):
 
             bbox_pred = bbox_pred.permute(0, 2, 3,
                                           1).reshape(batch_size, -1, 4)
-            points = points.expand(batch_size, -1, 2)
+            s = torch.zeros(batch_size, points.size(0), 2)
+            points = points.expand_as(s)
             # Get top-k prediction
             from mmdet.core.export import get_k_for_topk
-            nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+            nms_pre = bbox_pred.shape[1]
+            if nms_pre_tensor > 0 and bbox_pred.shape[1] > nms_pre_tensor:
+                nms_pre = nms_pre_tensor#get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
             if nms_pre > 0:
                 max_scores, _ = (scores * centerness[..., None]).max(-1)
                 _, topk_inds = max_scores.topk(nms_pre)
                 batch_inds = torch.arange(batch_size).view(
-                    -1, 1).expand_as(topk_inds).long()
+                    -1, 1).to(dtype=torch.int32).expand_as(topk_inds).long()
                 # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
                 if torch.onnx.is_in_onnx_export():
                     transformed_inds = bbox_pred.shape[
-                        1] * batch_inds + topk_inds
+                        1] * batch_inds.to(dtype=torch.int32) + topk_inds.to(dtype=torch.int32)
+                    transformed_inds = transformed_inds.to(dtype=torch.int64)
                     points = points.reshape(-1,
                                             2)[transformed_inds, :].reshape(
                                                 batch_size, -1, 2)
@@ -438,6 +490,8 @@ class FCOSHead(AnchorFreeHead):
             iou_threshold = cfg.nms.get('iou_threshold', 0.5)
             score_threshold = cfg.score_thr
             nms_pre = cfg.get('deploy_nms_pre', -1)
+            dets, labels = batch_nms_op(batch_mlvl_bboxes, batch_mlvl_scores, score_threshold, iou_threshold, cfg.max_per_img, cfg.max_per_img)
+            return dets, labels
             return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores,
                                           max_output_boxes_per_class,
                                           iou_threshold, score_threshold,