05360171创建于 2022年3月18日历史提交
diff --git a/mmdet/models/dense_heads/gfocal_head.py b/mmdet/models/dense_heads/gfocal_head.py
index 6e62382..48d5c35 100644
--- a/mmdet/models/dense_heads/gfocal_head.py
+++ b/mmdet/models/dense_heads/gfocal_head.py
@@ -11,7 +11,57 @@ from mmdet.core import (anchor_inside_flags, bbox2distance, bbox_overlaps,
 from ..builder import HEADS, build_loss
 from .anchor_head import AnchorHead
 
+class BatchNMSOp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+        """
+        boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+        scores (torch.Tensor): scores in shape (batch, N, C).
+        return:
+            nmsed_boxes: (1, N, 4)
+            nmsed_scores: (1, N)
+            nmsed_classes: (1, N)
+            nmsed_num: (1,)
+        """
+
+        # Phony implementation for onnx export
+        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+        nmsed_scores = scores[:, :max_total_size, 0]
+        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+        nmsed_num = torch.Tensor([max_total_size])
+
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+        nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+            bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+            max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+    """
+    boxes (torch.Tensor): boxes in shape (N, 4).
+    scores (torch.Tensor): scores in shape (N, ).
+    """
 
+    num_classes = bboxes.shape[1].numpy() // 4
+    if bboxes.dtype == torch.float32:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
+    else:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1)
+    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+        score_threshold, iou_threshold, max_size_per_class, max_total_size)
+    nmsed_boxes = nmsed_boxes.float()
+    nmsed_scores = nmsed_scores.float()
+    nmsed_classes = nmsed_classes.long()
+    dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
+    dets = dets.reshape((max_total_size, 5))
+    labels = nmsed_classes.reshape((max_total_size, ))
+    return dets, labels
+    
 class Integral(nn.Module):
     """A fixed layer for calculating integral result from distribution.
 
@@ -44,7 +94,7 @@ class Integral(nn.Module):
                 offsets from the box center in four directions, shape (N, 4).
         """
         x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
-        x = F.linear(x, self.project.type_as(x)).reshape(-1, 4)
+        x = F.linear(x, self.project.type_as(x).view(-1, self.reg_max + 1)).reshape(-1, 4)
         return x
 
 
@@ -211,7 +261,9 @@ class GFocalHead(AnchorHead):
         bbox_pred = scale(self.gfl_reg(reg_feat)).float()
         N, C, H, W = bbox_pred.size()
         prob = F.softmax(bbox_pred.reshape(N, 4, self.reg_max+1, H, W), dim=2)
-        prob_topk, _ = prob.topk(self.reg_topk, dim=2)
+        prob_tmp = prob.permute(0, 1, 4, 3, 2)
+        prob_topk, _ = torch.topk(prob_tmp, self.reg_topk)
+        prob_topk = prob_topk.permute(0, 1, 4, 3, 2)
 
         if self.add_mean:
             stat = torch.cat([prob_topk, prob_topk.mean(dim=2, keepdim=True)],
@@ -477,10 +529,8 @@ class GFocalHead(AnchorHead):
         mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
 
         if with_nms:
-            det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
-                                                    cfg.score_thr, cfg.nms,
-                                                    cfg.max_per_img)
-            return det_bboxes, det_labels
+            dets, labels = batch_nms_op(mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms.get("iou_threshold"), cfg.max_per_img, cfg.max_per_img)
+            return dets, labels
         else:
             return mlvl_bboxes, mlvl_scores
 
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index 985dd5d..b6118da 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -40,7 +40,8 @@ def pytorch2onnx(config_path,
         export_params=True,
         keep_initializers_as_inputs=True,
         verbose=show,
-        opset_version=opset_version)
+        opset_version=opset_version,
+        enable_onnx_checker=False)
 
     model.forward = orig_model.forward
     print(f'Successfully exported ONNX model: {output_file}')