@@ -10,6 +10,54 @@ from .anchor_free_head import AnchorFreeHead
INF = 1e8
+class BatchNMSOp(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+ scores (torch.Tensor): scores in shape (batch, N, C).
+ return:
+ nmsed_boxes: (1, N, 4)
+ nmsed_scores: (1, N)
+ nmsed_classes: (1, N)
+ nmsed_num: (1,)
+ """
+
+ # Phony implementation for onnx export
+ nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+ nmsed_scores = scores[:, :max_total_size, 0]
+ nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+ nmsed_num = torch.Tensor([max_total_size])
+
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+ @staticmethod
+ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (N, 4).
+ scores (torch.Tensor): scores in shape (N, ).
+ """
+
+ if bboxes.dtype == torch.float32:
+ bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4).half()
+ scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1).half()
+ else:
+ bboxes = bboxes.reshape(bboxes.size(0), bboxes.shape[1].numpy(), -1, 4)
+ scores = scores.reshape(scores.size(0), scores.shape[1].numpy(), -1)
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+ score_threshold, iou_threshold, max_size_per_class, max_total_size)
+ nmsed_boxes = nmsed_boxes.float()
+ nmsed_scores = nmsed_scores.float()
+ nmsed_classes = nmsed_classes.long()
+ dets = torch.cat((nmsed_boxes.reshape((bboxes.size(0), max_total_size, 4)), nmsed_scores.reshape((scores.size(0), max_total_size, 1))), -1)
+ labels = nmsed_classes.reshape((bboxes.size(0), max_total_size))
+ return dets, labels
@HEADS.register_module()
class FCOSHead(AnchorFreeHead):
@@ -387,19 +435,23 @@ class FCOSHead(AnchorFreeHead):
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(batch_size, -1, 4)
- points = points.expand(batch_size, -1, 2)
+ s = torch.zeros(batch_size, points.size(0), 2)
+ points = points.expand_as(s)
# Get top-k prediction
from mmdet.core.export import get_k_for_topk
- nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+ nms_pre = bbox_pred.shape[1]
+ if nms_pre_tensor > 0 and bbox_pred.shape[1] > nms_pre_tensor:
+ nms_pre = nms_pre_tensor#get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
if nms_pre > 0:
max_scores, _ = (scores * centerness[..., None]).max(-1)
_, topk_inds = max_scores.topk(nms_pre)
batch_inds = torch.arange(batch_size).view(
- -1, 1).expand_as(topk_inds).long()
+ -1, 1).to(dtype=torch.int32).expand_as(topk_inds).long()
# Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
if torch.onnx.is_in_onnx_export():
transformed_inds = bbox_pred.shape[
- 1] * batch_inds + topk_inds
+ 1] * batch_inds.to(dtype=torch.int32) + topk_inds.to(dtype=torch.int32)
+ transformed_inds = transformed_inds.to(dtype=torch.int64)
points = points.reshape(-1,
2)[transformed_inds, :].reshape(
batch_size, -1, 2)
@@ -438,6 +490,8 @@ class FCOSHead(AnchorFreeHead):
iou_threshold = cfg.nms.get('iou_threshold', 0.5)
score_threshold = cfg.score_thr
nms_pre = cfg.get('deploy_nms_pre', -1)
+ dets, labels = batch_nms_op(batch_mlvl_bboxes, batch_mlvl_scores, score_threshold, iou_threshold, cfg.max_per_img, cfg.max_per_img)
+ return dets, labels
return add_dummy_nms_for_onnx(batch_mlvl_bboxes, batch_mlvl_scores,
max_output_boxes_per_class,
iou_threshold, score_threshold,