@@ -2,7 +2,7 @@
from .batch_norm import FrozenBatchNorm2d, get_norm, NaiveSyncBatchNorm
from .deform_conv import DeformConv, ModulatedDeformConv
from .mask_ops import paste_masks_in_image
-from .nms import batched_nms, batched_nms_rotated, nms, nms_rotated
+from .nms import batched_nms, batch_nms_op, batched_nms_rotated, nms, nms_rotated
from .roi_align import ROIAlign, roi_align
from .roi_align_rotated import ROIAlignRotated, roi_align_rotated
from .shape_spec import ShapeSpec
@@ -15,6 +15,56 @@ if TORCH_VERSION < (1, 7):
else:
nms_rotated_func = torch.ops.detectron2.nms_rotated
+class BatchNMSOp(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+ scores (torch.Tensor): scores in shape (batch, N, C).
+ return:
+ nmsed_boxes: (1, N, 4)
+ nmsed_scores: (1, N)
+ nmsed_classes: (1, N)
+ nmsed_num: (1,)
+ """
+
+ # Phony implementation for onnx export
+ nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+ nmsed_scores = scores[:, :max_total_size, 0]
+ nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+ nmsed_num = torch.Tensor([max_total_size])
+
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+ @staticmethod
+ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (N, 4).
+ scores (torch.Tensor): scores in shape (N, ).
+ """
+
+ num_classes = bboxes.shape[1].numpy() // 4
+ if bboxes.dtype == torch.float32:
+ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
+ scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
+ else:
+ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
+ scores = scores.reshape(1, scores.shape[0].numpy(), -1)
+
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+ score_threshold, iou_threshold, max_size_per_class, max_total_size)
+ nmsed_boxes = nmsed_boxes.float()
+ nmsed_scores = nmsed_scores.float()
+ nmsed_classes = nmsed_classes.long()
+ dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
+ labels = nmsed_classes.reshape((max_total_size, ))
+ return dets, labels
def batched_nms(
boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
@@ -87,20 +87,33 @@ class Box2BoxTransform(object):
deltas = deltas.float() # ensure fp32 for decoding precision
boxes = boxes.to(deltas.dtype)
- widths = boxes[:, 2] - boxes[:, 0]
- heights = boxes[:, 3] - boxes[:, 1]
- ctr_x = boxes[:, 0] + 0.5 * widths
- ctr_y = boxes[:, 1] + 0.5 * heights
+ boxes_prof = boxes.permute(1, 0)
+ widths = boxes_prof[2, :] - boxes_prof[0, :]
+ heights = boxes_prof[3, :] - boxes_prof[1, :]
+ ctr_x = boxes_prof[0, :] + 0.5 * widths
+ ctr_y = boxes_prof[1, :] + 0.5 * heights
wx, wy, ww, wh = self.weights
- dx = deltas[:, 0::4] / wx
+ '''dx = deltas[:, 0::4] / wx
dy = deltas[:, 1::4] / wy
dw = deltas[:, 2::4] / ww
- dh = deltas[:, 3::4] / wh
+ dh = deltas[:, 3::4] / wh'''
+ denorm_deltas = deltas
+ if denorm_deltas.shape[1] > 4:
+ denorm_deltas = denorm_deltas.view(-1, 80, 4)
+ dx = denorm_deltas[:, :, 0:1:].view(-1, 80) / wx
+ dy = denorm_deltas[:, :, 1:2:].view(-1, 80) / wy
+ dw = denorm_deltas[:, :, 2:3:].view(-1, 80) / ww
+ dh = denorm_deltas[:, :, 3:4:].view(-1, 80) / wh
+ else:
+ dx = denorm_deltas[:, 0:1:] / wx
+ dy = denorm_deltas[:, 1:2:] / wy
+ dw = denorm_deltas[:, 2:3:] / ww
+ dh = denorm_deltas[:, 3:4:] / wh
# Prevent sending too large values into torch.exp()
- dw = torch.clamp(dw, max=self.scale_clamp)
- dh = torch.clamp(dh, max=self.scale_clamp)
+ dw = torch.clamp(dw, min=-float('inf'), max=self.scale_clamp)
+ dh = torch.clamp(dh, min=-float('inf'), max=self.scale_clamp)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
@@ -199,8 +199,9 @@ class GeneralizedRCNN(nn.Module):
"""
assert not self.training
- images = self.preprocess_image(batched_inputs)
- features = self.backbone(images.tensor)
+ # images = self.preprocess_image(batched_inputs)
+ images = batched_inputs
+ features = self.backbone(images)
if detected_instances is None:
if self.proposal_generator is not None:
@@ -94,6 +94,31 @@ def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
return pooler_fmt_boxes
+import torch.onnx.symbolic_helper as sym_help
+
+class RoiExtractor(torch.autograd.Function):
+ @staticmethod
+ def forward(self, f0, f1, f2, f3, rois, aligned=0, finest_scale=56, pooled_height=7, pooled_width=7,
+ pool_mode='avg', roi_scale_factor=0, sample_num=0, spatial_scale=[0.25, 0.125, 0.0625, 0.03125]):
+ """
+ feats (torch.Tensor): feats in shape (batch, 256, H, W).
+ rois (torch.Tensor): rois in shape (k, 5).
+ return:
+ roi_feats (torch.Tensor): (k, 256, pooled_width, pooled_width)
+ """
+
+ # phony implementation for shape inference
+ k = rois.size()[0]
+ roi_feats = torch.ones(k, 256, pooled_height, pooled_width)
+ return roi_feats
+
+ @staticmethod
+ def symbolic(g, f0, f1, f2, f3, rois, aligned=0, finest_scale=56, pooled_height=7, pooled_width=7):
+ # TODO: support tensor list type for feats
+ #f_tensors = sym_help._unpack_list(feats)
+ roi_feats = g.op('RoiExtractor', f0, f1, f2, f3, rois, aligned_i=0, finest_scale_i=56, pooled_height_i=pooled_height, pooled_width_i=pooled_width,
+ pool_mode_s='avg', roi_scale_factor_i=0, sample_num_i=0, spatial_scale_f=[0.25, 0.125, 0.0625, 0.03125], outputs=1)
+ return roi_feats
class ROIPooler(nn.Module):
"""
@@ -202,6 +227,12 @@ class ROIPooler(nn.Module):
A tensor of shape (M, C, output_size, output_size) where M is the total number of
boxes aggregated over all N batch images and C is the number of channels in `x`.
"""
+ if torch.onnx.is_in_onnx_export():
+ output_size = self.output_size[0]
+ pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists)
+ roi_feats = RoiExtractor.apply(x[0], x[1], x[2], x[3], pooler_fmt_boxes, 0, 56, output_size, output_size)
+ return roi_feats
+
num_level_assignments = len(self.level_poolers)
assert isinstance(x, list) and isinstance(
@@ -4,7 +4,7 @@ import math
from typing import List, Tuple
import torch
-from detectron2.layers import batched_nms, cat
+from detectron2.layers import batch_nms_op, cat
from detectron2.structures import Boxes, Instances
from detectron2.utils.env import TORCH_VERSION
@@ -68,15 +68,19 @@ def find_top_rpn_proposals(
for level_id, (proposals_i, logits_i) in enumerate(zip(proposals, pred_objectness_logits)):
Hi_Wi_A = logits_i.shape[1]
if isinstance(Hi_Wi_A, torch.Tensor): # it's a tensor in tracing
- num_proposals_i = torch.clamp(Hi_Wi_A, max=pre_nms_topk)
+ num_proposals_i = torch.clamp(Hi_Wi_A, min=0, max=pre_nms_topk)
else:
num_proposals_i = min(Hi_Wi_A, pre_nms_topk)
# sort is faster than topk: https://github.com/pytorch/pytorch/issues/22812
- # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
- logits_i, idx = logits_i.sort(descending=True, dim=1)
+ num_proposals_i = num_proposals_i.item()
+ logits_i = logits_i.reshape(logits_i.size(1))
+ topk_scores_i, topk_idx = torch.topk(logits_i, num_proposals_i)
+ topk_scores_i = topk_scores_i.reshape(1, topk_scores_i.size(0))
+ topk_idx = topk_idx.reshape(1, topk_idx.size(0))
+ '''logits_i, idx = logits_i.sort(descending=True, dim=1)
topk_scores_i = logits_i.narrow(1, 0, num_proposals_i)
- topk_idx = idx.narrow(1, 0, num_proposals_i)
+ topk_idx = idx.narrow(1, 0, num_proposals_i)'''
# each is N x topk
topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4
@@ -108,7 +112,7 @@ def find_top_rpn_proposals(
lvl = lvl[valid_mask]
boxes.clip(image_size)
- # filter empty boxes
+ '''# filter empty boxes
keep = boxes.nonempty(threshold=min_box_size)
if _is_tracing() or keep.sum().item() != len(boxes):
boxes, scores_per_img, lvl = boxes[keep], scores_per_img[keep], lvl[keep]
@@ -126,7 +130,14 @@ def find_top_rpn_proposals(
res = Instances(image_size)
res.proposal_boxes = boxes[keep]
res.objectness_logits = scores_per_img[keep]
+ results.append(res)'''
+
+ dets, labels = batch_nms_op(boxes.tensor, scores_per_img, 0, nms_thresh, post_nms_topk, post_nms_topk)
+ res = Instances(image_size)
+ res.proposal_boxes = Boxes(dets[:, :4])
+ res.objectness_logits = dets[:, 4]
results.append(res)
+
return results
@@ -434,7 +434,7 @@ class RPN(nn.Module):
else:
losses = {}
proposals = self.predict_proposals(
- anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
+ anchors, pred_objectness_logits, pred_anchor_deltas, [(1344, 1344)]
)
return proposals, losses
@@ -485,7 +485,8 @@ class RPN(nn.Module):
B = anchors_i.tensor.size(1)
pred_anchor_deltas_i = pred_anchor_deltas_i.reshape(-1, B)
# Expand anchors to shape (N*Hi*Wi*A, B)
- anchors_i = anchors_i.tensor.unsqueeze(0).expand(N, -1, -1).reshape(-1, B)
+ s = torch.zeros(N, anchors_i.tensor.unsqueeze(0).size(1), anchors_i.tensor.unsqueeze(0).size(2))
+ anchors_i = anchors_i.tensor.unsqueeze(0).expand_as(s).reshape(-1, B)
proposals_i = self.box2box_transform.apply_deltas(pred_anchor_deltas_i, anchors_i)
# Append feature map proposals with shape (N, Hi*Wi*A, B)
proposals.append(proposals_i.view(N, -1, B))
@@ -7,7 +7,7 @@ from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
-from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
+from detectron2.layers import ShapeSpec, batch_nms_op, cat, cross_entropy, nonzero_tuple
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
@@ -144,7 +144,7 @@ def fast_rcnn_inference_single_image(
# Convert to Boxes to use the `clip` function ...
boxes = Boxes(boxes.reshape(-1, 4))
boxes.clip(image_shape)
- boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # R x C x 4
+ boxes = boxes.tensor.view(-1, num_bbox_reg_classes.item(), 4) # R x C x 4
# 1. Filter results based on detection scores. It can make NMS more efficient
# by filtering out low-confidence detections.
@@ -152,7 +152,7 @@ def fast_rcnn_inference_single_image(
# R' x 2. First column contains indices of the R predictions;
# Second column contains indices of classes.
filter_inds = filter_mask.nonzero()
- if num_bbox_reg_classes == 1:
+ '''if num_bbox_reg_classes == 1:
boxes = boxes[filter_inds[:, 0], 0]
else:
boxes = boxes[filter_mask]
@@ -167,7 +167,14 @@ def fast_rcnn_inference_single_image(
result = Instances(image_shape)
result.pred_boxes = Boxes(boxes)
result.scores = scores
- result.pred_classes = filter_inds[:, 1]
+ result.pred_classes = filter_inds[:, 1]'''
+
+ dets, labels = batch_nms_op(boxes, scores, score_thresh, nms_thresh, topk_per_image, topk_per_image)
+ result = Instances(image_shape)
+ result.pred_boxes = Boxes(dets[:, :4])
+ result.scores = dets.permute(1, 0)[4, :]
+ result.pred_classes = labels
+
return result, filter_inds[:, 0]
@@ -142,7 +142,9 @@ def mask_rcnn_inference(pred_mask_logits: torch.Tensor, pred_instances: List[Ins
num_masks = pred_mask_logits.shape[0]
class_pred = cat([i.pred_classes for i in pred_instances])
indices = torch.arange(num_masks, device=class_pred.device)
- mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid()
+ print(indices,class_pred)
+ # mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid()
+ mask_probs_pred = pred_mask_logits.sigmoid()
# mask_probs_pred.shape: (B, 1, Hmask, Wmask)
num_boxes_per_image = [len(i) for i in pred_instances]
@@ -202,10 +202,11 @@ class Boxes:
"""
assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
h, w = box_size
- x1 = self.tensor[:, 0].clamp(min=0, max=w)
- y1 = self.tensor[:, 1].clamp(min=0, max=h)
- x2 = self.tensor[:, 2].clamp(min=0, max=w)
- y2 = self.tensor[:, 3].clamp(min=0, max=h)
+ boxes_prof = self.tensor.permute(1, 0)
+ x1 = boxes_prof[0, :].clamp(min=0, max=w)
+ y1 = boxes_prof[1, :].clamp(min=0, max=h)
+ x2 = boxes_prof[2, :].clamp(min=0, max=w)
+ y2 = boxes_prof[3, :].clamp(min=0, max=h)
self.tensor = torch.stack((x1, y1, x2, y2), dim=-1)
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
@@ -77,6 +77,28 @@ def export_scripting(torch_model):
# TODO inference in Python now missing postprocessing glue code
return None
+from typing import Dict, Tuple
+import numpy
+from detectron2.structures import ImageList
+def preprocess_image(batched_inputs: Tuple[Dict[str, torch.Tensor]]):
+ """
+ Normalize, pad and batch the input images.
+ """
+ images = [x["image"].to('cpu') for x in batched_inputs]
+ images = [(x - numpy.array([[[103.530]], [[116.280]], [[123.675]]])) / numpy.array([[[1.]], [[1.]], [[1.]]]) for x in images]
+ import torch.nn.functional as F
+ image = torch.zeros(0, 1344, 1344)
+ for i in range(images[0].size(0)):
+ img = images[0][i]
+ img = img.expand((1, 1, img.size(0), img.size(1)))
+ img = img.to(dtype=torch.float32)
+ img = F.interpolate(img, size=(int(1344), int(1344)), mode='bilinear', align_corners=False)
+ img = img[0][0]
+ img = img.unsqueeze(0)
+ image = torch.cat((image, img))
+ images = [image]
+ images = ImageList.from_tensors(images, 32)
+ return images
# experimental. API not yet final
def export_tracing(torch_model, inputs):
@@ -84,6 +106,8 @@ def export_tracing(torch_model, inputs):
image = inputs[0]["image"]
inputs = [{"image": image}] # remove other unused keys
+ inputs = preprocess_image(inputs).tensor.to(torch.float32)
+ image = inputs
if isinstance(torch_model, GeneralizedRCNN):
def inference(model, inputs):
@@ -104,7 +128,7 @@ def export_tracing(torch_model, inputs):
elif args.format == "onnx":
# NOTE onnx export currently failing in pytorch
with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
- torch.onnx.export(traceable_model, (image,), f)
+ torch.onnx.export(traceable_model, (image,), f, opset_version=11, verbose=True)
logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
logger.info("Outputs schema: " + str(traceable_model.outputs_schema))