你çšxiahaishengmodify README.py
28b19213创建于 2022年11月15日历史提交
diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
index e9eb357..f72cef7 100644
--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -168,13 +168,31 @@ def delta2bbox(rois,
                 [0.0000, 0.3161, 4.1945, 0.6839],
                 [5.0000, 5.0000, 5.0000, 5.0000]])
     """
-    means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
-    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
+
+    # fix shape for means and stds when exporting onnx
+    if torch.onnx.is_in_onnx_export():
+        means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1).numpy() // 4)
+        stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1).numpy() // 4)
+    else:
+        means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
+        stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
     denorm_deltas = deltas * stds + means
-    dx = denorm_deltas[:, 0::4]
-    dy = denorm_deltas[:, 1::4]
-    dw = denorm_deltas[:, 2::4]
-    dh = denorm_deltas[:, 3::4]
+    # dx = denorm_deltas[:, 0::4]
+    # dy = denorm_deltas[:, 1::4]
+    # dw = denorm_deltas[:, 2::4]
+    # dh = denorm_deltas[:, 3::4]
+    if denorm_deltas.shape[1] > 4:
+        denorm_deltas = denorm_deltas.view(-1, 80, 4)
+        dx = denorm_deltas[:, :, 0:1:].view(-1, 80)
+        dy = denorm_deltas[:, :, 1:2:].view(-1, 80)
+        dw = denorm_deltas[:, :, 2:3:].view(-1, 80)
+        dh = denorm_deltas[:, :, 3:4:].view(-1, 80)
+    else:
+        dx = denorm_deltas[:, 0:1:]
+        dy = denorm_deltas[:, 1:2:]
+        dw = denorm_deltas[:, 2:3:]
+        dh = denorm_deltas[:, 3:4:]
+
     max_ratio = np.abs(np.log(wh_ratio_clip))
     dw = dw.clamp(min=-max_ratio, max=max_ratio)
     dh = dh.clamp(min=-max_ratio, max=max_ratio)
diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py
index c43aea9..e99f5d8 100644
--- a/mmdet/core/post_processing/bbox_nms.py
+++ b/mmdet/core/post_processing/bbox_nms.py
@@ -4,6 +4,59 @@ from mmcv.ops.nms import batched_nms
 from mmdet.core.bbox.iou_calculators import bbox_overlaps
 
 
+class BatchNMSOp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+        """
+        boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+        scores (torch.Tensor): scores in shape (batch, N, C).
+        return:
+            nmsed_boxes: (1, N, 4)
+            nmsed_scores: (1, N)
+            nmsed_classes: (1, N)
+            nmsed_num: (1,)
+        """
+
+        # Phony implementation for onnx export
+        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+        nmsed_scores = scores[:, :max_total_size, 0]
+        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+        nmsed_num = torch.Tensor([max_total_size])
+
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+        nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+            bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+            max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+    """
+    boxes (torch.Tensor): boxes in shape (N, 4).
+    scores (torch.Tensor): scores in shape (N, ).
+    """
+
+    num_classes = bboxes.shape[1].numpy() // 4
+    if bboxes.dtype == torch.float32:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
+    else:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1)
+
+    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+        score_threshold, iou_threshold, max_size_per_class, max_total_size)
+    nmsed_boxes = nmsed_boxes.float()
+    nmsed_scores = nmsed_scores.float()
+    nmsed_classes = nmsed_classes.long()
+    dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
+    dets = dets.reshape((max_total_size, 5))
+    labels = nmsed_classes.reshape((max_total_size, ))
+    return dets, labels
+
+
 def multiclass_nms(multi_bboxes,
                    multi_scores,
                    score_thr,
@@ -40,7 +93,17 @@ def multiclass_nms(multi_bboxes,
             multi_scores.size(0), num_classes, 4)
 
     scores = multi_scores[:, :-1]
+    # multiply score_factor after threshold to preserve more bboxes, improve
+    # mAP by 1% for YOLOv3
+    if score_factors is not None:
+        # expand the shape to match original shape of score
+        score_factors = score_factors.view(-1, 1).expand(
+            multi_scores.size(0), num_classes)
+        score_factors = score_factors.reshape(-1)
+        scores = scores * score_factors
 
+    # cpu and gpu
+    '''
     labels = torch.arange(num_classes, dtype=torch.long)
     labels = labels.view(1, -1).expand_as(scores)
 
@@ -80,7 +143,11 @@ def multiclass_nms(multi_bboxes,
         return dets, labels[keep], keep
     else:
         return dets, labels[keep]
+    '''
 
+    # npu
+    dets, labels = batch_nms_op(bboxes, scores, score_thr, nms_cfg.get("iou_threshold"), max_num, max_num)
+    return dets, labels
 
 def fast_nms(multi_bboxes,
              multi_scores,
diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py
index f565d1a..3c29386 100644
--- a/mmdet/models/dense_heads/rpn_head.py
+++ b/mmdet/models/dense_heads/rpn_head.py
@@ -9,6 +9,57 @@ from .anchor_head import AnchorHead
 from .rpn_test_mixin import RPNTestMixin
 
 
+class BatchNMSOp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+        """
+        boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+        scores (torch.Tensor): scores in shape (batch, N, C).
+        return:
+            nmsed_boxes: (1, N, 4)
+            nmsed_scores: (1, N)
+            nmsed_classes: (1, N)
+            nmsed_num: (1,)
+        """
+
+        # Phony implementation for onnx export
+        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+        nmsed_scores = scores[:, :max_total_size, 0]
+        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+        nmsed_num = torch.Tensor([max_total_size])
+
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+        nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+            bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+            max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+    """
+    boxes (torch.Tensor): boxes in shape (N, 4).
+    scores (torch.Tensor): scores in shape (N, ).
+    """
+
+    num_classes = bboxes.shape[1].numpy() // 4
+    if bboxes.dtype == torch.float32:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
+    else:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1)
+
+    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+        score_threshold, iou_threshold, max_size_per_class, max_total_size)
+    nmsed_boxes = nmsed_boxes.float()
+    nmsed_scores = nmsed_scores.float()
+    nmsed_classes = nmsed_classes.long()
+    dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
+    labels = nmsed_classes.reshape((max_total_size, ))
+    return dets, labels
+
 @HEADS.register_module()
 class RPNHead(RPNTestMixin, AnchorHead):
     """RPN head.
@@ -132,9 +183,14 @@ class RPNHead(RPNTestMixin, AnchorHead):
             if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                 # sort is faster than topk
                 # _, topk_inds = scores.topk(cfg.nms_pre)
-                ranked_scores, rank_inds = scores.sort(descending=True)
-                topk_inds = rank_inds[:cfg.nms_pre]
-                scores = ranked_scores[:cfg.nms_pre]
+                # onnx uses topk to sort, this is simpler for onnx export
+                if torch.onnx.is_in_onnx_export():
+                    scores, topk_inds = torch.topk(scores, cfg.nms_pre)
+                else:
+                    ranked_scores, rank_inds = scores.sort(descending=True)
+                    topk_inds = rank_inds[:cfg.nms_pre]
+                    scores = ranked_scores[:cfg.nms_pre]
+
                 rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                 anchors = anchors[topk_inds, :]
             mlvl_scores.append(scores)
@@ -164,5 +220,12 @@ class RPNHead(RPNTestMixin, AnchorHead):
 
         # TODO: remove the hard coded nms type
         nms_cfg = dict(type='nms', iou_threshold=cfg.nms_thr)
+        # cpu and gpu return
+        '''
         dets, keep = batched_nms(proposals, scores, ids, nms_cfg)
         return dets[:cfg.nms_post]
+        '''
+
+        # npu return
+        dets, labels = batch_nms_op(proposals, scores, 0.0, nms_cfg.get("iou_threshold"), cfg.nms_post, cfg.nms_post)
+        return dets
diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
index 0cba3cd..a965e53 100644
--- a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
+++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
@@ -199,11 +199,11 @@ class FCNMaskHead(nn.Module):
             # TODO: Remove after F.grid_sample is supported.
             from torchvision.models.detection.roi_heads \
                 import paste_masks_in_image
-            masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
+            '''masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
             thr = rcnn_test_cfg.get('mask_thr_binary', 0)
             if thr > 0:
-                masks = masks >= thr
-            return masks
+                masks = masks >= thr'''
+            return mask_pred
 
         N = len(mask_pred)
         # The actual implementation split the input into chunks,
diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
index c0eebc4..63605c5 100644
--- a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
+++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
@@ -4,6 +4,31 @@ from mmcv.runner import force_fp32
 from mmdet.models.builder import ROI_EXTRACTORS
 from .base_roi_extractor import BaseRoIExtractor
 
+import torch.onnx.symbolic_helper as sym_help
+
+class RoiExtractor(torch.autograd.Function):
+    @staticmethod
+    def forward(self, f0, f1, f2, f3, rois, aligned=1, finest_scale=56, pooled_height=7, pooled_width=7,
+                         pool_mode='avg', roi_scale_factor=0, sample_num=0, spatial_scale=[0.25, 0.125, 0.0625, 0.03125]):
+        """
+        feats (torch.Tensor): feats in shape (batch, 256, H, W).
+        rois (torch.Tensor): rois in shape (k, 5).
+        return:
+            roi_feats (torch.Tensor): (k, 256, pooled_width, pooled_width)
+        """
+
+        # phony implementation for shape inference
+        k = rois.size()[0]
+        roi_feats = torch.ones(k, 256, pooled_height, pooled_width)
+        return roi_feats
+
+    @staticmethod
+    def symbolic(g, f0, f1, f2, f3, rois, aligned=1, finest_scale=56, pooled_height=7, pooled_width=7):
+        # TODO: support tensor list type for feats
+        #f_tensors = sym_help._unpack_list(feats)
+        roi_feats = g.op('RoiExtractor', f0, f1, f2, f3, rois, aligned_i=1, finest_scale_i=56, pooled_height_i=pooled_height, pooled_width_i=pooled_width,
+                         pool_mode_s='avg', roi_scale_factor_i=0, sample_num_i=0, spatial_scale_f=[0.25, 0.125, 0.0625, 0.03125], outputs=1)
+        return roi_feats
 
 @ROI_EXTRACTORS.register_module()
 class SingleRoIExtractor(BaseRoIExtractor):
@@ -52,6 +77,14 @@ class SingleRoIExtractor(BaseRoIExtractor):
 
     @force_fp32(apply_to=('feats', ), out_fp16=True)
     def forward(self, feats, rois, roi_scale_factor=None):
+        # Work around to export onnx for npu
+        if torch.onnx.is_in_onnx_export():
+            out_size = self.roi_layers[0].output_size
+            roi_feats = RoiExtractor.apply(feats[0], feats[1], feats[2], feats[3], rois, 1, 56, out_size[0], out_size[1])
+            # roi_feats = RoiExtractor.apply(list(feats), rois)
+            return roi_feats
+
+
         """Forward function."""
         out_size = self.roi_layers[0].output_size
         num_levels = len(feats)
diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py
index 1305a79..c79e9fb 100644
--- a/tools/deployment/pytorch2onnx.py
+++ b/tools/deployment/pytorch2onnx.py
@@ -48,7 +48,7 @@ def pytorch2onnx(config_path,
         input_names=['input'],
         output_names=output_names,
         export_params=True,
-        keep_initializers_as_inputs=True,
+        #keep_initializers_as_inputs=True,
         do_constant_folding=True,
         verbose=show,
         opset_version=opset_version)