a1b0dbdf创建于 2022年7月7日历史提交
diff --git a/detectron2/layers/__init__.py b/detectron2/layers/__init__.py
index c8bd1fb..f5fa9ea 100644
--- a/detectron2/layers/__init__.py
+++ b/detectron2/layers/__init__.py
@@ -2,7 +2,7 @@
 from .batch_norm import FrozenBatchNorm2d, get_norm, NaiveSyncBatchNorm
 from .deform_conv import DeformConv, ModulatedDeformConv
 from .mask_ops import paste_masks_in_image
-from .nms import batched_nms, batched_nms_rotated, nms, nms_rotated
+from .nms import batched_nms, batch_nms_op, batched_nms_rotated, nms, nms_rotated
 from .roi_align import ROIAlign, roi_align
 from .roi_align_rotated import ROIAlignRotated, roi_align_rotated
 from .shape_spec import ShapeSpec
diff --git a/detectron2/layers/nms.py b/detectron2/layers/nms.py
index ac14d45..22efb24 100644
--- a/detectron2/layers/nms.py
+++ b/detectron2/layers/nms.py
@@ -15,6 +15,56 @@ if TORCH_VERSION < (1, 7):
 else:
     nms_rotated_func = torch.ops.detectron2.nms_rotated
 
+class BatchNMSOp(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+        """
+        boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+        scores (torch.Tensor): scores in shape (batch, N, C).
+        return:
+            nmsed_boxes: (1, N, 4)
+            nmsed_scores: (1, N)
+            nmsed_classes: (1, N)
+            nmsed_num: (1,)
+        """
+
+        # Phony implementation for onnx export
+        nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+        nmsed_scores = scores[:, :max_total_size, 0]
+        nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+        nmsed_num = torch.Tensor([max_total_size])
+
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+        nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+            bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+            max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+        return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+    """
+    boxes (torch.Tensor): boxes in shape (N, 4).
+    scores (torch.Tensor): scores in shape (N, ).
+    """
+
+    num_classes = bboxes.shape[1].numpy() // 4
+    if bboxes.dtype == torch.float32:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
+    else:
+        bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
+        scores = scores.reshape(1, scores.shape[0].numpy(), -1)
+
+    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+        score_threshold, iou_threshold, max_size_per_class, max_total_size)
+    nmsed_boxes = nmsed_boxes.float()
+    nmsed_scores = nmsed_scores.float()
+    nmsed_classes = nmsed_classes.long()
+    dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
+    labels = nmsed_classes.reshape((max_total_size, ))
+    return dets, labels
 
 def batched_nms(
     boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
diff --git a/detectron2/modeling/box_regression.py b/detectron2/modeling/box_regression.py
index 12be000..074f3e3 100644
--- a/detectron2/modeling/box_regression.py
+++ b/detectron2/modeling/box_regression.py
@@ -87,20 +87,33 @@ class Box2BoxTransform(object):
         deltas = deltas.float()  # ensure fp32 for decoding precision
         boxes = boxes.to(deltas.dtype)
 
-        widths = boxes[:, 2] - boxes[:, 0]
-        heights = boxes[:, 3] - boxes[:, 1]
-        ctr_x = boxes[:, 0] + 0.5 * widths
-        ctr_y = boxes[:, 1] + 0.5 * heights
+        boxes_prof = boxes.permute(1, 0)
+        widths = boxes_prof[2, :] - boxes_prof[0, :]
+        heights = boxes_prof[3, :] - boxes_prof[1, :]
+        ctr_x = boxes_prof[0, :] + 0.5 * widths
+        ctr_y = boxes_prof[1, :] + 0.5 * heights
 
         wx, wy, ww, wh = self.weights
-        dx = deltas[:, 0::4] / wx
+        '''dx = deltas[:, 0::4] / wx
         dy = deltas[:, 1::4] / wy
         dw = deltas[:, 2::4] / ww
-        dh = deltas[:, 3::4] / wh
+        dh = deltas[:, 3::4] / wh'''
+        denorm_deltas = deltas
+        if denorm_deltas.shape[1] > 4:
+            denorm_deltas = denorm_deltas.view(-1, 80, 4)
+            dx = denorm_deltas[:, :, 0:1:].view(-1, 80) / wx
+            dy = denorm_deltas[:, :, 1:2:].view(-1, 80) / wy
+            dw = denorm_deltas[:, :, 2:3:].view(-1, 80) / ww
+            dh = denorm_deltas[:, :, 3:4:].view(-1, 80) / wh
+        else:
+            dx = denorm_deltas[:, 0:1:] / wx
+            dy = denorm_deltas[:, 1:2:] / wy
+            dw = denorm_deltas[:, 2:3:] / ww
+            dh = denorm_deltas[:, 3:4:] / wh
 
         # Prevent sending too large values into torch.exp()
-        dw = torch.clamp(dw, max=self.scale_clamp)
-        dh = torch.clamp(dh, max=self.scale_clamp)
+        dw = torch.clamp(dw, min=-float('inf'), max=self.scale_clamp)
+        dh = torch.clamp(dh, min=-float('inf'), max=self.scale_clamp)
 
         pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
         pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
diff --git a/detectron2/modeling/meta_arch/rcnn.py b/detectron2/modeling/meta_arch/rcnn.py
index e5f66d1..1bbba71 100644
--- a/detectron2/modeling/meta_arch/rcnn.py
+++ b/detectron2/modeling/meta_arch/rcnn.py
@@ -199,8 +199,9 @@ class GeneralizedRCNN(nn.Module):
         """
         assert not self.training
 
-        images = self.preprocess_image(batched_inputs)
-        features = self.backbone(images.tensor)
+        # images = self.preprocess_image(batched_inputs)
+        images = batched_inputs
+        features = self.backbone(images)
 
         if detected_instances is None:
             if self.proposal_generator is not None:
diff --git a/detectron2/modeling/poolers.py b/detectron2/modeling/poolers.py
index e5d72ab..7c0dd2f 100644
--- a/detectron2/modeling/poolers.py
+++ b/detectron2/modeling/poolers.py
@@ -94,6 +94,31 @@ def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
 
     return pooler_fmt_boxes
 
+import torch.onnx.symbolic_helper as sym_help
+
+class RoiExtractor(torch.autograd.Function):
+    @staticmethod
+    def forward(self, f0, f1, f2, f3, rois, aligned=0, finest_scale=56, pooled_height=7, pooled_width=7,
+                         pool_mode='avg', roi_scale_factor=0, sample_num=0, spatial_scale=[0.25, 0.125, 0.0625, 0.03125]):
+        """
+        feats (torch.Tensor): feats in shape (batch, 256, H, W).
+        rois (torch.Tensor): rois in shape (k, 5).
+        return:
+            roi_feats (torch.Tensor): (k, 256, pooled_width, pooled_width)
+        """
+
+        # phony implementation for shape inference
+        k = rois.size()[0]
+        roi_feats = torch.ones(k, 256, pooled_height, pooled_width)
+        return roi_feats
+
+    @staticmethod
+    def symbolic(g, f0, f1, f2, f3, rois, aligned=0, finest_scale=56, pooled_height=7, pooled_width=7):
+        # TODO: support tensor list type for feats
+        #f_tensors = sym_help._unpack_list(feats)
+        roi_feats = g.op('RoiExtractor', f0, f1, f2, f3, rois, aligned_i=0, finest_scale_i=56, pooled_height_i=pooled_height, pooled_width_i=pooled_width,
+                         pool_mode_s='avg', roi_scale_factor_i=0, sample_num_i=0, spatial_scale_f=[0.25, 0.125, 0.0625, 0.03125], outputs=1)
+        return roi_feats
 
 class ROIPooler(nn.Module):
     """
@@ -202,6 +227,12 @@ class ROIPooler(nn.Module):
                 A tensor of shape (M, C, output_size, output_size) where M is the total number of
                 boxes aggregated over all N batch images and C is the number of channels in `x`.
         """
+        if torch.onnx.is_in_onnx_export():
+            output_size = self.output_size[0]
+            pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists)
+            roi_feats = RoiExtractor.apply(x[0], x[1], x[2], x[3], pooler_fmt_boxes, 0, 56, output_size, output_size)
+            return roi_feats
+
         num_level_assignments = len(self.level_poolers)
 
         assert isinstance(x, list) and isinstance(
diff --git a/detectron2/modeling/proposal_generator/proposal_utils.py b/detectron2/modeling/proposal_generator/proposal_utils.py
index 9c10436..b3437a7 100644
--- a/detectron2/modeling/proposal_generator/proposal_utils.py
+++ b/detectron2/modeling/proposal_generator/proposal_utils.py
@@ -4,7 +4,7 @@ import math
 from typing import List, Tuple
 import torch
 
-from detectron2.layers import batched_nms, cat
+from detectron2.layers import batch_nms_op, cat
 from detectron2.structures import Boxes, Instances
 from detectron2.utils.env import TORCH_VERSION
 
@@ -68,15 +68,19 @@ def find_top_rpn_proposals(
     for level_id, (proposals_i, logits_i) in enumerate(zip(proposals, pred_objectness_logits)):
         Hi_Wi_A = logits_i.shape[1]
         if isinstance(Hi_Wi_A, torch.Tensor):  # it's a tensor in tracing
-            num_proposals_i = torch.clamp(Hi_Wi_A, max=pre_nms_topk)
+            num_proposals_i = torch.clamp(Hi_Wi_A, min=0, max=pre_nms_topk)
         else:
             num_proposals_i = min(Hi_Wi_A, pre_nms_topk)
 
         # sort is faster than topk: https://github.com/pytorch/pytorch/issues/22812
-        # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
-        logits_i, idx = logits_i.sort(descending=True, dim=1)
+        num_proposals_i = num_proposals_i.item()
+        logits_i = logits_i.reshape(logits_i.size(1))
+        topk_scores_i, topk_idx = torch.topk(logits_i, num_proposals_i)
+        topk_scores_i = topk_scores_i.reshape(1, topk_scores_i.size(0))
+        topk_idx = topk_idx.reshape(1, topk_idx.size(0))
+        '''logits_i, idx = logits_i.sort(descending=True, dim=1)
         topk_scores_i = logits_i.narrow(1, 0, num_proposals_i)
-        topk_idx = idx.narrow(1, 0, num_proposals_i)
+        topk_idx = idx.narrow(1, 0, num_proposals_i)'''
 
         # each is N x topk
         topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx]  # N x topk x 4
@@ -108,7 +112,7 @@ def find_top_rpn_proposals(
             lvl = lvl[valid_mask]
         boxes.clip(image_size)
 
-        # filter empty boxes
+        '''# filter empty boxes
         keep = boxes.nonempty(threshold=min_box_size)
         if _is_tracing() or keep.sum().item() != len(boxes):
             boxes, scores_per_img, lvl = boxes[keep], scores_per_img[keep], lvl[keep]
@@ -126,7 +130,14 @@ def find_top_rpn_proposals(
         res = Instances(image_size)
         res.proposal_boxes = boxes[keep]
         res.objectness_logits = scores_per_img[keep]
+        results.append(res)'''
+
+        dets, labels = batch_nms_op(boxes.tensor, scores_per_img, 0, nms_thresh, post_nms_topk, post_nms_topk)
+        res = Instances(image_size)
+        res.proposal_boxes = Boxes(dets[:, :4])
+        res.objectness_logits = dets[:, 4]
         results.append(res)
+
     return results
 
 
diff --git a/detectron2/modeling/proposal_generator/rpn.py b/detectron2/modeling/proposal_generator/rpn.py
index 1675377..77d9f26 100644
--- a/detectron2/modeling/proposal_generator/rpn.py
+++ b/detectron2/modeling/proposal_generator/rpn.py
@@ -434,7 +434,7 @@ class RPN(nn.Module):
         else:
             losses = {}
         proposals = self.predict_proposals(
-            anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
+            anchors, pred_objectness_logits, pred_anchor_deltas, [(1344, 1344)]
         )
         return proposals, losses
 
@@ -485,7 +485,8 @@ class RPN(nn.Module):
             B = anchors_i.tensor.size(1)
             pred_anchor_deltas_i = pred_anchor_deltas_i.reshape(-1, B)
             # Expand anchors to shape (N*Hi*Wi*A, B)
-            anchors_i = anchors_i.tensor.unsqueeze(0).expand(N, -1, -1).reshape(-1, B)
+            s = torch.zeros(N, anchors_i.tensor.unsqueeze(0).size(1), anchors_i.tensor.unsqueeze(0).size(2))
+            anchors_i = anchors_i.tensor.unsqueeze(0).expand_as(s).reshape(-1, B)
             proposals_i = self.box2box_transform.apply_deltas(pred_anchor_deltas_i, anchors_i)
             # Append feature map proposals with shape (N, Hi*Wi*A, B)
             proposals.append(proposals_i.view(N, -1, B))
diff --git a/detectron2/modeling/roi_heads/fast_rcnn.py b/detectron2/modeling/roi_heads/fast_rcnn.py
index 348f6a0..87c7cd3 100644
--- a/detectron2/modeling/roi_heads/fast_rcnn.py
+++ b/detectron2/modeling/roi_heads/fast_rcnn.py
@@ -7,7 +7,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from detectron2.config import configurable
-from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
+from detectron2.layers import ShapeSpec, batch_nms_op, cat, cross_entropy, nonzero_tuple
 from detectron2.modeling.box_regression import Box2BoxTransform
 from detectron2.structures import Boxes, Instances
 from detectron2.utils.events import get_event_storage
@@ -144,7 +144,7 @@ def fast_rcnn_inference_single_image(
     # Convert to Boxes to use the `clip` function ...
     boxes = Boxes(boxes.reshape(-1, 4))
     boxes.clip(image_shape)
-    boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4)  # R x C x 4
+    boxes = boxes.tensor.view(-1, num_bbox_reg_classes.item(), 4)  # R x C x 4
 
     # 1. Filter results based on detection scores. It can make NMS more efficient
     #    by filtering out low-confidence detections.
@@ -152,7 +152,7 @@ def fast_rcnn_inference_single_image(
     # R' x 2. First column contains indices of the R predictions;
     # Second column contains indices of classes.
     filter_inds = filter_mask.nonzero()
-    if num_bbox_reg_classes == 1:
+    '''if num_bbox_reg_classes == 1:
         boxes = boxes[filter_inds[:, 0], 0]
     else:
         boxes = boxes[filter_mask]
@@ -167,7 +167,14 @@ def fast_rcnn_inference_single_image(
     result = Instances(image_shape)
     result.pred_boxes = Boxes(boxes)
     result.scores = scores
-    result.pred_classes = filter_inds[:, 1]
+    result.pred_classes = filter_inds[:, 1]'''
+
+    dets, labels = batch_nms_op(boxes, scores, score_thresh, nms_thresh, topk_per_image, topk_per_image)
+    result = Instances(image_shape)
+    result.pred_boxes = Boxes(dets[:, :4])
+    result.scores = dets.permute(1, 0)[4, :]
+    result.pred_classes = labels
+
     return result, filter_inds[:, 0]
 
 
diff --git a/detectron2/modeling/roi_heads/mask_head.py b/detectron2/modeling/roi_heads/mask_head.py
index 5ac5c4b..f81b96b 100644
--- a/detectron2/modeling/roi_heads/mask_head.py
+++ b/detectron2/modeling/roi_heads/mask_head.py
@@ -142,7 +142,9 @@ def mask_rcnn_inference(pred_mask_logits: torch.Tensor, pred_instances: List[Ins
         num_masks = pred_mask_logits.shape[0]
         class_pred = cat([i.pred_classes for i in pred_instances])
         indices = torch.arange(num_masks, device=class_pred.device)
-        mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid()
+        print(indices,class_pred)
+        # mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid()
+        mask_probs_pred = pred_mask_logits.sigmoid()
     # mask_probs_pred.shape: (B, 1, Hmask, Wmask)
 
     num_boxes_per_image = [len(i) for i in pred_instances]
diff --git a/detectron2/structures/boxes.py b/detectron2/structures/boxes.py
index 57f862a..bad473b 100644
--- a/detectron2/structures/boxes.py
+++ b/detectron2/structures/boxes.py
@@ -202,10 +202,11 @@ class Boxes:
         """
         assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
         h, w = box_size
-        x1 = self.tensor[:, 0].clamp(min=0, max=w)
-        y1 = self.tensor[:, 1].clamp(min=0, max=h)
-        x2 = self.tensor[:, 2].clamp(min=0, max=w)
-        y2 = self.tensor[:, 3].clamp(min=0, max=h)
+        boxes_prof = self.tensor.permute(1, 0)
+        x1 = boxes_prof[0, :].clamp(min=0, max=w)
+        y1 = boxes_prof[1, :].clamp(min=0, max=h)
+        x2 = boxes_prof[2, :].clamp(min=0, max=w)
+        y2 = boxes_prof[3, :].clamp(min=0, max=h)
         self.tensor = torch.stack((x1, y1, x2, y2), dim=-1)
 
     def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
diff --git a/tools/deploy/export_model.py b/tools/deploy/export_model.py
index fe2fe30..22145b7 100755
--- a/tools/deploy/export_model.py
+++ b/tools/deploy/export_model.py
@@ -77,6 +77,28 @@ def export_scripting(torch_model):
     # TODO inference in Python now missing postprocessing glue code
     return None
 
+from typing import Dict, Tuple
+import numpy
+from detectron2.structures import ImageList
+def preprocess_image(batched_inputs: Tuple[Dict[str, torch.Tensor]]):
+        """
+        Normalize, pad and batch the input images.
+        """
+        images = [x["image"].to('cpu') for x in batched_inputs]
+        images = [(x - numpy.array([[[103.530]], [[116.280]], [[123.675]]])) / numpy.array([[[1.]], [[1.]], [[1.]]]) for x in images]
+        import torch.nn.functional as F
+        image = torch.zeros(0, 1344, 1344)
+        for i in range(images[0].size(0)):
+            img = images[0][i]
+            img = img.expand((1, 1, img.size(0), img.size(1)))
+            img = img.to(dtype=torch.float32)
+            img = F.interpolate(img, size=(int(1344), int(1344)), mode='bilinear', align_corners=False)
+            img = img[0][0]
+            img = img.unsqueeze(0)
+            image = torch.cat((image, img))
+        images = [image]
+        images = ImageList.from_tensors(images, 32)
+        return images
 
 # experimental. API not yet final
 def export_tracing(torch_model, inputs):
@@ -84,6 +106,8 @@ def export_tracing(torch_model, inputs):
     image = inputs[0]["image"]
     inputs = [{"image": image}]  # remove other unused keys
 
+    inputs = preprocess_image(inputs).tensor.to(torch.float32)
+    image = inputs
     if isinstance(torch_model, GeneralizedRCNN):
 
         def inference(model, inputs):
@@ -104,7 +128,7 @@ def export_tracing(torch_model, inputs):
     elif args.format == "onnx":
         # NOTE onnx export currently failing in pytorch
         with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
-            torch.onnx.export(traceable_model, (image,), f)
+            torch.onnx.export(traceable_model, (image,), f, opset_version=11, verbose=True)
     logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
     logger.info("Outputs schema: " + str(traceable_model.outputs_schema))