@@ -11,7 +11,57 @@ from mmdet.core import (anchor_inside_flags, bbox2distance, bbox_overlaps,
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead
+class BatchNMSOp(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (batch, N, C, 4).
+ scores (torch.Tensor): scores in shape (batch, N, C).
+ return:
+ nmsed_boxes: (1, N, 4)
+ nmsed_scores: (1, N)
+ nmsed_classes: (1, N)
+ nmsed_num: (1,)
+ """
+
+ # Phony implementation for onnx export
+ nmsed_boxes = bboxes[:, :max_total_size, 0, :]
+ nmsed_scores = scores[:, :max_total_size, 0]
+ nmsed_classes = torch.arange(max_total_size, dtype=torch.long)
+ nmsed_num = torch.Tensor([max_total_size])
+
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+ @staticmethod
+ def symbolic(g, bboxes, scores, score_thr, iou_thr, max_size_p_class, max_t_size):
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = g.op('BatchMultiClassNMS',
+ bboxes, scores, score_threshold_f=score_thr, iou_threshold_f=iou_thr,
+ max_size_per_class_i=max_size_p_class, max_total_size_i=max_t_size, outputs=4)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
+
+def batch_nms_op(bboxes, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size):
+ """
+ boxes (torch.Tensor): boxes in shape (N, 4).
+ scores (torch.Tensor): scores in shape (N, ).
+ """
+ num_classes = bboxes.shape[1].numpy() // 4
+ if bboxes.dtype == torch.float32:
+ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4).half()
+ scores = scores.reshape(1, scores.shape[0].numpy(), -1).half()
+ else:
+ bboxes = bboxes.reshape(1, bboxes.shape[0].numpy(), -1, 4)
+ scores = scores.reshape(1, scores.shape[0].numpy(), -1)
+ nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = BatchNMSOp.apply(bboxes, scores,
+ score_threshold, iou_threshold, max_size_per_class, max_total_size)
+ nmsed_boxes = nmsed_boxes.float()
+ nmsed_scores = nmsed_scores.float()
+ nmsed_classes = nmsed_classes.long()
+ dets = torch.cat((nmsed_boxes.reshape((max_total_size, 4)), nmsed_scores.reshape((max_total_size, 1))), -1)
+ dets = dets.reshape((max_total_size, 5))
+ labels = nmsed_classes.reshape((max_total_size, ))
+ return dets, labels
+
class Integral(nn.Module):
"""A fixed layer for calculating integral result from distribution.
@@ -44,7 +94,7 @@ class Integral(nn.Module):
offsets from the box center in four directions, shape (N, 4).
"""
x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
- x = F.linear(x, self.project.type_as(x)).reshape(-1, 4)
+ x = F.linear(x, self.project.type_as(x).view(-1, self.reg_max + 1)).reshape(-1, 4)
return x
@@ -211,7 +261,9 @@ class GFocalHead(AnchorHead):
bbox_pred = scale(self.gfl_reg(reg_feat)).float()
N, C, H, W = bbox_pred.size()
prob = F.softmax(bbox_pred.reshape(N, 4, self.reg_max+1, H, W), dim=2)
- prob_topk, _ = prob.topk(self.reg_topk, dim=2)
+ prob_tmp = prob.permute(0, 1, 4, 3, 2)
+ prob_topk, _ = torch.topk(prob_tmp, self.reg_topk)
+ prob_topk = prob_topk.permute(0, 1, 4, 3, 2)
if self.add_mean:
stat = torch.cat([prob_topk, prob_topk.mean(dim=2, keepdim=True)],
@@ -477,10 +529,8 @@ class GFocalHead(AnchorHead):
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
if with_nms:
- det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
- cfg.score_thr, cfg.nms,
- cfg.max_per_img)
- return det_bboxes, det_labels
+ dets, labels = batch_nms_op(mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms.get("iou_threshold"), cfg.max_per_img, cfg.max_per_img)
+ return dets, labels
else:
return mlvl_bboxes, mlvl_scores
@@ -40,7 +40,8 @@ def pytorch2onnx(config_path,
export_params=True,
keep_initializers_as_inputs=True,
verbose=show,
- opset_version=opset_version)
+ opset_version=opset_version,
+ enable_onnx_checker=False)
model.forward = orig_model.forward
print(f'Successfully exported ONNX model: {output_file}')