diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
index e9eb3579..1311b8e0 100644
--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -1,3 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
 import numpy as np
 import torch
 
@@ -20,16 +24,25 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
             target for delta coordinates
         clip_border (bool, optional): Whether clip the objects outside the
             border of the image. Defaults to True.
+        add_ctr_clamp (bool): Whether to add center clamp, when added, the
+            predicted box is clamped is its center is too far away from
+            the original anchor's center. Only used by YOLOF. Default False.
+        ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
+            Default 32.
     """
 
     def __init__(self,
                  target_means=(0., 0., 0., 0.),
                  target_stds=(1., 1., 1., 1.),
-                 clip_border=True):
+                 clip_border=True,
+                 add_ctr_clamp=False,
+                 ctr_clamp=32):
         super(BaseBBoxCoder, self).__init__()
         self.means = target_means
         self.stds = target_stds
         self.clip_border = clip_border
+        self.add_ctr_clamp = add_ctr_clamp
+        self.ctr_clamp = ctr_clamp
 
     def encode(self, bboxes, gt_bboxes):
         """Get box regression transformation deltas that can be used to
@@ -57,10 +70,16 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
         """Apply transformation `pred_bboxes` to `boxes`.
 
         Args:
-            boxes (torch.Tensor): Basic boxes.
-            pred_bboxes (torch.Tensor): Encoded boxes with shape
-            max_shape (tuple[int], optional): Maximum shape of boxes.
-                Defaults to None.
+            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
+            pred_bboxes (Tensor): Encoded offsets with respect to each roi.
+               Has shape (B, N, num_classes * 4) or (B, N, 4) or
+               (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+               when rois is a grid of anchors.Offset encoding follows [1]_.
+            max_shape (Sequence[int] or torch.Tensor or Sequence[
+               Sequence[int]],optional): Maximum bounds for boxes, specifies
+               (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
+               the max_shape should be a Sequence[Sequence[int]]
+               and the length of max_shape should also be B.
             wh_ratio_clip (float, optional): The allowed ratio between
                 width and height.
 
@@ -69,8 +88,28 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
         """
 
         assert pred_bboxes.size(0) == bboxes.size(0)
-        decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
-                                    max_shape, wh_ratio_clip, self.clip_border)
+        if pred_bboxes.ndim == 3:
+            assert pred_bboxes.size(1) == bboxes.size(1)
+
+        if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export():
+            # single image decode
+            decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means,
+                                        self.stds, max_shape, wh_ratio_clip,
+                                        self.clip_border, self.add_ctr_clamp,
+                                        self.ctr_clamp)
+        else:
+            if pred_bboxes.ndim == 3 and not torch.onnx.is_in_onnx_export():
+                warnings.warn(
+                    'DeprecationWarning: onnx_delta2bbox is deprecated '
+                    'in the case of batch decoding and non-ONNX, '
+                    'please use “delta2bbox” instead. In order to improve '
+                    'the decoding speed, the batch function will no '
+                    'longer be supported. ')
+            decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means,
+                                             self.stds, max_shape,
+                                             wh_ratio_clip, self.clip_border,
+                                             self.add_ctr_clamp,
+                                             self.ctr_clamp)
 
         return decoded_bboxes
 
@@ -126,7 +165,108 @@ def delta2bbox(rois,
                stds=(1., 1., 1., 1.),
                max_shape=None,
                wh_ratio_clip=16 / 1000,
-               clip_border=True):
+               clip_border=True,
+               add_ctr_clamp=False,
+               ctr_clamp=32):
+    """Apply deltas to shift/scale base boxes.
+
+    Typically the rois are anchor or proposed bounding boxes and the deltas are
+    network outputs used to shift/scale those boxes.
+    This is the inverse function of :func:`bbox2delta`.
+
+    Args:
+        rois (Tensor): Boxes to be transformed. Has shape (N, 4).
+        deltas (Tensor): Encoded offsets relative to each roi.
+            Has shape (N, num_classes * 4) or (N, 4). Note
+            N = num_base_anchors * W * H, when rois is a grid of
+            anchors. Offset encoding follows [1]_.
+        means (Sequence[float]): Denormalizing means for delta coordinates.
+            Default (0., 0., 0., 0.).
+        stds (Sequence[float]): Denormalizing standard deviation for delta
+            coordinates. Default (1., 1., 1., 1.).
+        max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
+           (H, W). Default None.
+        wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
+            16 / 1000.
+        clip_border (bool, optional): Whether clip the objects outside the
+            border of the image. Default True.
+        add_ctr_clamp (bool): Whether to add center clamp. When set to True,
+            the center of the prediction bounding box will be clamped to
+            avoid being too far away from the center of the anchor.
+            Only used by YOLOF. Default False.
+        ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
+            Default 32.
+
+    Returns:
+        Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4
+           represent tl_x, tl_y, br_x, br_y.
+
+    References:
+        .. [1] https://arxiv.org/abs/1311.2524
+
+    Example:
+        >>> rois = torch.Tensor([[ 0.,  0.,  1.,  1.],
+        >>>                      [ 0.,  0.,  1.,  1.],
+        >>>                      [ 0.,  0.,  1.,  1.],
+        >>>                      [ 5.,  5.,  5.,  5.]])
+        >>> deltas = torch.Tensor([[  0.,   0.,   0.,   0.],
+        >>>                        [  1.,   1.,   1.,   1.],
+        >>>                        [  0.,   0.,   2.,  -1.],
+        >>>                        [ 0.7, -1.9, -0.5,  0.3]])
+        >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
+        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
+                [0.1409, 0.1409, 2.8591, 2.8591],
+                [0.0000, 0.3161, 4.1945, 0.6839],
+                [5.0000, 5.0000, 5.0000, 5.0000]])
+    """
+    num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
+    if num_bboxes == 0:
+        return deltas
+
+    deltas = deltas.reshape(-1, 4)
+
+    means = deltas.new_tensor(means).view(1, -1)
+    stds = deltas.new_tensor(stds).view(1, -1)
+    denorm_deltas = deltas * stds + means
+
+    dxy = denorm_deltas[:, :2]
+    dwh = denorm_deltas[:, 2:]
+
+    # Compute width/height of each roi
+    rois_ = rois.repeat(1, num_classes).reshape(-1, 4)
+    pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5)
+    pwh = (rois_[:, 2:] - rois_[:, :2])
+
+    dxy_wh = pwh * dxy
+
+    max_ratio = np.abs(np.log(wh_ratio_clip))
+    if add_ctr_clamp:
+        dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
+        dwh = torch.clamp(dwh, max=max_ratio)
+    else:
+        dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
+
+    gxy = pxy + dxy_wh
+    gwh = pwh * dwh.exp()
+    x1y1 = gxy - (gwh * 0.5)
+    x2y2 = gxy + (gwh * 0.5)
+    bboxes = torch.cat([x1y1, x2y2], dim=-1)
+    if clip_border and max_shape is not None:
+        bboxes[..., 0::2].clamp_(min=0, max=max_shape[1])
+        bboxes[..., 1::2].clamp_(min=0, max=max_shape[0])
+    bboxes = bboxes.reshape(num_bboxes, -1)
+    return bboxes
+
+
+def onnx_delta2bbox(rois,
+                    deltas,
+                    means=(0., 0., 0., 0.),
+                    stds=(1., 1., 1., 1.),
+                    max_shape=None,
+                    wh_ratio_clip=16 / 1000,
+                    clip_border=True,
+                    add_ctr_clamp=False,
+                    ctr_clamp=32):
     """Apply deltas to shift/scale base boxes.
 
     Typically the rois are anchor or proposed bounding boxes and the deltas are
@@ -134,21 +274,34 @@ def delta2bbox(rois,
     This is the inverse function of :func:`bbox2delta`.
 
     Args:
-        rois (Tensor): Boxes to be transformed. Has shape (N, 4)
+        rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
         deltas (Tensor): Encoded offsets with respect to each roi.
-            Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when
-            rois is a grid of anchors. Offset encoding follows [1]_.
-        means (Sequence[float]): Denormalizing means for delta coordinates
+            Has shape (B, N, num_classes * 4) or (B, N, 4) or
+            (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+            when rois is a grid of anchors.Offset encoding follows [1]_.
+        means (Sequence[float]): Denormalizing means for delta coordinates.
+            Default (0., 0., 0., 0.).
         stds (Sequence[float]): Denormalizing standard deviation for delta
-            coordinates
-        max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
+            coordinates. Default (1., 1., 1., 1.).
+        max_shape (Sequence[int] or torch.Tensor or Sequence[
+            Sequence[int]],optional): Maximum bounds for boxes, specifies
+            (H, W, C) or (H, W). If rois shape is (B, N, 4), then
+            the max_shape should be a Sequence[Sequence[int]]
+            and the length of max_shape should also be B. Default None.
         wh_ratio_clip (float): Maximum aspect ratio for boxes.
+            Default 16 / 1000.
         clip_border (bool, optional): Whether clip the objects outside the
-            border of the image. Defaults to True.
+            border of the image. Default True.
+        add_ctr_clamp (bool): Whether to add center clamp, when added, the
+            predicted box is clamped is its center is too far away from
+            the original anchor's center. Only used by YOLOF. Default False.
+        ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
+            Default 32.
 
     Returns:
-        Tensor: Boxes with shape (N, 4), where columns represent
-            tl_x, tl_y, br_x, br_y.
+        Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
+           (N, num_classes * 4) or (N, 4), where 4 represent
+           tl_x, tl_y, br_x, br_y.
 
     References:
         .. [1] https://arxiv.org/abs/1311.2524
@@ -162,43 +315,76 @@ def delta2bbox(rois,
         >>>                        [  1.,   1.,   1.,   1.],
         >>>                        [  0.,   0.,   2.,  -1.],
         >>>                        [ 0.7, -1.9, -0.5,  0.3]])
-        >>> delta2bbox(rois, deltas, max_shape=(32, 32))
+        >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
         tensor([[0.0000, 0.0000, 1.0000, 1.0000],
                 [0.1409, 0.1409, 2.8591, 2.8591],
                 [0.0000, 0.3161, 4.1945, 0.6839],
                 [5.0000, 5.0000, 5.0000, 5.0000]])
     """
-    means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4)
-    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4)
+    means = deltas.new_tensor(means).view(1,
+                                          -1).repeat(1,
+                                                     deltas.size(-1) // 4)
+    stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
     denorm_deltas = deltas * stds + means
-    dx = denorm_deltas[:, 0::4]
-    dy = denorm_deltas[:, 1::4]
-    dw = denorm_deltas[:, 2::4]
-    dh = denorm_deltas[:, 3::4]
-    max_ratio = np.abs(np.log(wh_ratio_clip))
-    dw = dw.clamp(min=-max_ratio, max=max_ratio)
-    dh = dh.clamp(min=-max_ratio, max=max_ratio)
+    dx = denorm_deltas[..., 0::4]
+    dy = denorm_deltas[..., 1::4]
+    dw = denorm_deltas[..., 2::4]
+    dh = denorm_deltas[..., 3::4]
+
+    x1, y1 = rois[..., 0], rois[..., 1]
+    x2, y2 = rois[..., 2], rois[..., 3]
     # Compute center of each roi
-    px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
-    py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
+    px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
+    py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
     # Compute width/height of each roi
-    pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw)
-    ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh)
+    pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
+    ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
+
+    dx_width = pw * dx
+    dy_height = ph * dy
+
+    max_ratio = np.abs(np.log(wh_ratio_clip))
+    if add_ctr_clamp:
+        dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
+        dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
+        dw = torch.clamp(dw, max=max_ratio)
+        dh = torch.clamp(dh, max=max_ratio)
+    else:
+        dw = dw.clamp(min=-max_ratio, max=max_ratio)
+        dh = dh.clamp(min=-max_ratio, max=max_ratio)
     # Use exp(network energy) to enlarge/shrink each roi
     gw = pw * dw.exp()
     gh = ph * dh.exp()
     # Use network energy to shift the center of each roi
-    gx = px + pw * dx
-    gy = py + ph * dy
+    gx = px + dx_width
+    gy = py + dy_height
     # Convert center-xy/width/height to top-left, bottom-right
     x1 = gx - gw * 0.5
     y1 = gy - gh * 0.5
     x2 = gx + gw * 0.5
     y2 = gy + gh * 0.5
-    if clip_border and max_shape is not None:
-        x1 = x1.clamp(min=0, max=max_shape[1])
-        y1 = y1.clamp(min=0, max=max_shape[0])
-        x2 = x2.clamp(min=0, max=max_shape[1])
-        y2 = y2.clamp(min=0, max=max_shape[0])
+
     bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
+
+    if clip_border and max_shape is not None:
+        # clip bboxes with dynamic `min` and `max` for onnx
+        if torch.onnx.is_in_onnx_export():
+            from mmdet.core.export.onnx_helper import dynamic_clip_for_onnx
+            x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
+            bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
+            return bboxes
+        if not isinstance(max_shape, torch.Tensor):
+            max_shape = x1.new_tensor(max_shape)
+        max_shape = max_shape[..., :2].type_as(x1)
+        if max_shape.ndim == 2:
+            assert bboxes.ndim == 3
+            assert max_shape.size(0) == bboxes.size(0)
+
+        min_xy = x1.new_tensor(0)
+        max_xy = torch.cat(
+            [max_shape] * (deltas.size(-1) // 2),
+            dim=-1).flip(-1).unsqueeze(-2)
+        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
     return bboxes
diff --git a/mmdet/core/export/onnx_helper.py b/mmdet/core/export/onnx_helper.py
new file mode 100644
index 00000000..9abd220b
--- /dev/null
+++ b/mmdet/core/export/onnx_helper.py
@@ -0,0 +1,245 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+import torch
+
+
+def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
+    """Clip boxes dynamically for onnx.
+
+    Since torch.clamp cannot have dynamic `min` and `max`, we scale the
+      boxes by 1/max_shape and clamp in the range [0, 1].
+
+    Args:
+        x1 (Tensor): The x1 for bounding boxes.
+        y1 (Tensor): The y1 for bounding boxes.
+        x2 (Tensor): The x2 for bounding boxes.
+        y2 (Tensor): The y2 for bounding boxes.
+        max_shape (Tensor or torch.Size): The (H,W) of original image.
+    Returns:
+        tuple(Tensor): The clipped x1, y1, x2, y2.
+    """
+    # assert isinstance(
+    #     max_shape,
+    #     torch.Tensor), '`max_shape` should be tensor of (h,w) for onnx, got {}'.format(max_shape.__class__.__name__)
+
+    assert isinstance(max_shape, (torch.Tensor, torch.Size, list, tuple)), '`max_shape` should be ' + \
+        'torch.Tensor/torch.Size/list/tuple of (h, w) for onnx, got {}'.format(max_shape.__class__.__name__)
+    if not isinstance(max_shape, torch.Tensor):
+        max_shape = torch.tensor(max_shape, dtype=x1.dtype, device=x1.device)
+    else:
+        max_shape = max_shape.type_as(x1)
+
+    # scale by 1/max_shape
+    x1 = x1 / max_shape[1]
+    y1 = y1 / max_shape[0]
+    x2 = x2 / max_shape[1]
+    y2 = y2 / max_shape[0]
+
+    # clamp [0, 1]
+    x1 = torch.clamp(x1, 0, 1)
+    y1 = torch.clamp(y1, 0, 1)
+    x2 = torch.clamp(x2, 0, 1)
+    y2 = torch.clamp(y2, 0, 1)
+
+    # scale back
+    x1 = x1 * max_shape[1]
+    y1 = y1 * max_shape[0]
+    x2 = x2 * max_shape[1]
+    y2 = y2 * max_shape[0]
+    return x1, y1, x2, y2
+
+
+def get_k_for_topk(k, size):
+    """Get k of TopK for onnx exporting.
+
+    The K of TopK in TensorRT should not be a Tensor, while in ONNX Runtime
+      it could be a Tensor.Due to dynamic shape feature, we have to decide
+      whether to do TopK and what K it should be while exporting to ONNX.
+    If returned K is less than zero, it means we do not have to do
+      TopK operation.
+
+    Args:
+        k (int or Tensor): The set k value for nms from config file.
+        size (Tensor or torch.Size): The number of elements of \
+            TopK's input tensor
+    Returns:
+        tuple: (int or Tensor): The final K for TopK.
+    """
+    ret_k = -1
+    if k <= 0 or size <= 0:
+        return ret_k
+    if torch.onnx.is_in_onnx_export():
+        is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+        if is_trt_backend:
+            # TensorRT does not support dynamic K with TopK op
+            if 0 < k < size:
+                ret_k = k
+        else:
+            # Always keep topk op for dynamic input in onnx for ONNX Runtime
+            ret_k = torch.where(k < size, k, size)
+    elif k < size:
+        ret_k = k
+    else:
+        # ret_k is -1
+        pass
+    return ret_k
+
+
+def add_dummy_nms_for_onnx(boxes,
+                           scores,
+                           max_output_boxes_per_class=1000,
+                           iou_threshold=0.5,
+                           score_threshold=0.05,
+                           pre_top_k=-1,
+                           after_top_k=-1,
+                           labels=None):
+    """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
+
+    This function helps exporting to onnx with batch and multiclass NMS op.
+    It only supports class-agnostic detection results. That is, the scores
+    is of shape (N, num_bboxes, num_classes) and the boxes is of shape
+    (N, num_boxes, 4).
+
+    Args:
+        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]
+        scores (Tensor): The detection scores of shape
+            [N, num_boxes, num_classes]
+        max_output_boxes_per_class (int): Maximum number of output
+            boxes per class of nms. Defaults to 1000.
+        iou_threshold (float): IOU threshold of nms. Defaults to 0.5
+        score_threshold (float): score threshold of nms.
+            Defaults to 0.05.
+        pre_top_k (bool): Number of top K boxes to keep before nms.
+            Defaults to -1.
+        after_top_k (int): Number of top K boxes to keep after nms.
+            Defaults to -1.
+        labels (Tensor, optional): It not None, explicit labels would be used.
+            Otherwise, labels would be automatically generated using
+            num_classed. Defaults to None.
+
+    Returns:
+        tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+            and class labels of shape [N, num_det].
+    """
+    max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
+    iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
+    score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
+    batch_size = scores.shape[0]
+    num_class = scores.shape[2]
+
+    if pre_top_k > 0:
+        nms_pre = torch.tensor(pre_top_k, device=scores.device, dtype=torch.long)
+        nms_pre = get_k_for_topk(nms_pre, boxes.shape[1])
+
+        if nms_pre > 0:
+            max_scores, _ = scores.max(-1)
+            _, topk_inds = max_scores.topk(nms_pre)
+            batch_inds = torch.arange(batch_size).view(
+                -1, 1).expand_as(topk_inds).long()
+            # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+            # transformed_inds = boxes.shape[1] * batch_inds + topk_inds
+            transformed_inds = (boxes.shape[1] * batch_inds.int()) + topk_inds.int()
+            transformed_inds = transformed_inds.long()
+            boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape(
+                batch_size, -1, 4)
+            scores = scores.reshape(-1, num_class)[transformed_inds, :].reshape(
+                batch_size, -1, num_class)
+            if labels is not None:
+                labels = labels.reshape(-1, 1)[transformed_inds].reshape(
+                    batch_size, -1)
+
+    scores = scores.permute(0, 2, 1)
+    num_box = boxes.shape[1]
+    # turn off tracing to create a dummy output of nms
+    state = torch._C._get_tracing_state()
+    # dummy indices of nms's output
+    num_fake_det = 2
+    batch_inds = torch.randint(batch_size, (num_fake_det, 1))
+    cls_inds = torch.randint(num_class, (num_fake_det, 1))
+    box_inds = torch.randint(num_box, (num_fake_det, 1))
+    indices = torch.cat([batch_inds, cls_inds, box_inds], dim=1)
+    output = indices
+    setattr(DummyONNXNMSop, 'output', output)
+
+    # open tracing
+    torch._C._set_tracing_state(state)
+    selected_indices = DummyONNXNMSop.apply(boxes, scores,
+                                            max_output_boxes_per_class,
+                                            iou_threshold, score_threshold)
+
+    batch_inds, cls_inds = selected_indices[:, 0], selected_indices[:, 1]
+    box_inds = selected_indices[:, 2]
+    if labels is None:
+        labels = torch.arange(num_class, dtype=torch.long).to(scores.device)
+        labels = labels.view(1, num_class, 1).expand_as(scores)
+    scores = scores.reshape(-1, 1)
+    boxes = boxes.reshape(batch_size, -1).repeat(1, num_class).reshape(-1, 4)
+    # pos_inds = (num_class * batch_inds + cls_inds) * num_box + box_inds # original
+    pos_inds = (num_class * batch_inds.int()) + cls_inds.int()
+    pos_inds = (pos_inds * num_box.int()) + box_inds.int()
+    pos_inds = pos_inds.long()
+    # pos_inds = (batch_inds.new_tensor(num_class) * batch_inds + cls_inds) * batch_inds.new_tensor(num_box) + box_inds
+    mask = scores.new_zeros(scores.shape)
+    # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+    # PyTorch style code: mask[batch_inds, box_inds] += 1
+    mask[pos_inds, :] += 1
+    scores = scores * mask
+    boxes = boxes * mask
+
+    scores = scores.reshape(batch_size, -1)
+    boxes = boxes.reshape(batch_size, -1, 4)
+    labels = labels.reshape(batch_size, -1)
+
+    if boxes.dtype != torch.float:
+        boxes = boxes.float()
+        scores = scores.float()
+
+    if after_top_k > 0:
+        nms_after = torch.tensor(
+            after_top_k, device=scores.device, dtype=torch.long)
+        nms_after = get_k_for_topk(nms_after, num_box * num_class)
+
+        if nms_after > 0:
+            _, topk_inds = scores.topk(nms_after)
+            batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds).long()
+            # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+            batch_inds = scores.shape[1] * batch_inds
+            # transformed_inds = batch_inds + topk_inds
+            transformed_inds = batch_inds.int() + topk_inds.int()
+            transformed_inds = transformed_inds.long()
+            scores = scores.reshape(-1, 1)[transformed_inds, :].reshape(
+                batch_size, -1)
+            boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape(
+                batch_size, -1, 4)
+            labels = labels.reshape(-1, 1)[transformed_inds, :].reshape(
+                batch_size, -1)
+
+    scores = scores.unsqueeze(2)
+    dets = torch.cat([boxes, scores], dim=2)
+    return dets, labels
+
+
+class DummyONNXNMSop(torch.autograd.Function):
+    """DummyONNXNMSop.
+
+    This class is only for creating onnx::NonMaxSuppression.
+    """
+
+    @staticmethod
+    def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold,
+                score_threshold):
+
+        return DummyONNXNMSop.output
+
+    @staticmethod
+    def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold,
+                 score_threshold):
+        return g.op(
+            'NonMaxSuppression',
+            boxes,
+            scores,
+            max_output_boxes_per_class,
+            iou_threshold,
+            score_threshold,
+            outputs=1)
diff --git a/mmdet/core/export/pytorch2onnx.py b/mmdet/core/export/pytorch2onnx.py
index 8f9309df..b9f43d48 100644
--- a/mmdet/core/export/pytorch2onnx.py
+++ b/mmdet/core/export/pytorch2onnx.py
@@ -39,6 +39,7 @@ def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config):
 
     model = build_model_from_cfg(config_path, checkpoint_path)
     one_img, one_meta = preprocess_example_input(input_config)
+    one_meta['img_shape_for_onnx'] = one_img.shape[-2:]
     tensor_data = [one_img]
     model.forward = partial(
         model.forward, img_metas=[[one_meta]], return_loss=False)
diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py
index cbc4fbb2..4bb7e37a 100644
--- a/mmdet/models/backbones/ssd_vgg.py
+++ b/mmdet/models/backbones/ssd_vgg.py
@@ -162,8 +162,14 @@ class L2Norm(nn.Module):
 
     def forward(self, x):
         """Forward function."""
-        # normalization layer convert to FP32 in FP16 training
+        # # normalization layer convert to FP32 in FP16 training
+        # x_float = x.float()
+        # norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
+        # return (self.weight[None, :, None, None].float().expand_as(x_float) *
+        #         x_float / norm).type_as(x)
+
         x_float = x.float()
-        norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
+        x_mul = x_float * x_float
+        norm = x_mul.sum(1, keepdim=True).sqrt() + self.eps
         return (self.weight[None, :, None, None].float().expand_as(x_float) *
                 x_float / norm).type_as(x)
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
index a5bb4137..1aef3804 100644
--- a/mmdet/models/dense_heads/anchor_head.py
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -487,6 +487,162 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
             num_total_samples=num_total_samples)
         return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
 
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+    def onnx_export(self,
+                    cls_scores,
+                    bbox_preds,
+                    score_factors=None,
+                    img_metas=None,
+                    with_nms=True):
+        """Transform network output for a batch into bbox predictions.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level
+                with shape (N, num_points * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_points * 4, H, W).
+            score_factors (list[Tensor]): score_factors for each s
+                cale level with shape (N, num_points * 1, H, W).
+                Default: None.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc. Default: None.
+            with_nms (bool): Whether apply nms to the bboxes. Default: True.
+
+        Returns:
+            tuple[Tensor, Tensor] | list[tuple]: When `with_nms` is True,
+            it is tuple[Tensor, Tensor], first tensor bboxes with shape
+            [N, num_det, 5], 5 arrange as (x1, y1, x2, y2, score)
+            and second element is class labels of shape [N, num_det].
+            When `with_nms` is False, first tensor is bboxes with
+            shape [N, num_det, 4], second tensor is raw score has
+            shape  [N, num_det, num_classes].
+        """
+        assert len(cls_scores) == len(bbox_preds)
+
+        num_levels = len(cls_scores)
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+
+        mlvl_priors = self.anchor_generator.grid_anchors(
+            featmap_sizes, device=bbox_preds[0].device)
+
+        mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
+        mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
+
+        assert len(
+            img_metas
+        ) == 1, 'Only support one input image while in exporting to ONNX'
+        img_shape = torch.tensor(
+            img_metas[0]['img_shape_for_onnx'],
+            dtype=torch.long,
+            device=bbox_preds[0].device)
+
+        cfg = self.test_cfg
+        assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
+        device = cls_scores[0].device
+        batch_size = cls_scores[0].shape[0]
+        # convert to tensor to keep tracing
+        nms_pre_tensor = torch.tensor(
+            cfg.get('nms_pre', -1), device=device, dtype=torch.long)
+
+        # e.g. Retina, FreeAnchor, etc.
+        if score_factors is None:
+            with_score_factors = False
+            mlvl_score_factor = [None for _ in range(num_levels)]
+        else:
+            # e.g. FCOS, PAA, ATSS, etc.
+            with_score_factors = True
+            mlvl_score_factor = [
+                score_factors[i].detach() for i in range(num_levels)
+            ]
+            mlvl_score_factors = []
+
+        mlvl_batch_bboxes = []
+        mlvl_scores = []
+
+        for cls_score, bbox_pred, score_factors, priors in zip(
+                mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor,
+                mlvl_priors):
+            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+            scores = cls_score.permute(0, 2, 3,
+                                       1).reshape(batch_size, -1,
+                                                  self.cls_out_channels)
+            if self.use_sigmoid_cls:
+                scores = scores.sigmoid()
+                nms_pre_score = scores
+            else:
+                scores = scores.softmax(-1)
+                nms_pre_score = scores
+
+            if with_score_factors:
+                score_factors = score_factors.permute(0, 2, 3, 1).reshape(
+                    batch_size, -1).sigmoid()
+            bbox_pred = bbox_pred.permute(0, 2, 3,
+                                          1).reshape(batch_size, -1, 4)
+            priors = priors.expand(batch_size, -1, priors.size(-1))
+            # Get top-k predictions
+            from mmdet.core.export.onnx_helper import get_k_for_topk
+            nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+            if nms_pre > 0:
+
+                if with_score_factors:
+                    nms_pre_score = (nms_pre_score * score_factors[..., None])
+                else:
+                    nms_pre_score = nms_pre_score
+
+                # Get maximum scores for foreground classes.
+                if self.use_sigmoid_cls:
+                    max_scores, _ = nms_pre_score.max(-1)
+                else:
+                    # remind that we set FG labels to [0, num_class-1]
+                    # since mmdet v2.0
+                    # BG cat_id: num_class
+                    max_scores, _ = nms_pre_score[..., :-1].max(-1)
+                _, topk_inds = max_scores.topk(nms_pre)
+
+                batch_inds = torch.arange(
+                    batch_size, device=bbox_pred.device).view(
+                        -1, 1).expand_as(topk_inds).long()
+                # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+                # transformed_inds = bbox_pred.shape[1] * batch_inds + topk_inds
+                transformed_inds = (bbox_pred.shape[1] * batch_inds).int() + topk_inds.int()
+                transformed_inds = transformed_inds.long()
+                priors = priors.reshape(
+                    -1, priors.size(-1))[transformed_inds, :].reshape(
+                        batch_size, -1, priors.size(-1))
+                bbox_pred = bbox_pred.reshape(-1,
+                                              4)[transformed_inds, :].reshape(
+                                                  batch_size, -1, 4)
+                scores = scores.reshape(
+                    -1, self.cls_out_channels)[transformed_inds, :].reshape(
+                        batch_size, -1, self.cls_out_channels)
+                if with_score_factors:
+                    score_factors = score_factors.reshape(
+                        -1, 1)[transformed_inds].reshape(batch_size, -1)
+
+            bboxes = self.bbox_coder.decode(
+                priors, bbox_pred, max_shape=img_shape)
+
+            mlvl_batch_bboxes.append(bboxes)
+            mlvl_scores.append(scores)
+            if with_score_factors:
+                mlvl_score_factors.append(score_factors)
+
+        batch_bboxes = torch.cat(mlvl_batch_bboxes, dim=1)
+        batch_scores = torch.cat(mlvl_scores, dim=1)
+        if with_score_factors:
+            batch_score_factors = torch.cat(mlvl_score_factors, dim=1)
+
+        if not self.use_sigmoid_cls:
+            batch_scores = batch_scores[..., :self.num_classes]
+
+        if with_score_factors:
+            batch_scores = batch_scores * (batch_score_factors.unsqueeze(2))
+
+        # directly return bboxes without NMS
+        return batch_bboxes, batch_scores
+
     @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
     def get_bboxes(self,
                    cls_scores,
@@ -545,38 +701,45 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
             >>> assert det_bboxes.shape[1] == 5
             >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
         """
-        assert len(cls_scores) == len(bbox_preds)
-        num_levels = len(cls_scores)
-
-        device = cls_scores[0].device
-        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
-        mlvl_anchors = self.anchor_generator.grid_anchors(
-            featmap_sizes, device=device)
-
-        result_list = []
-        for img_id in range(len(img_metas)):
-            cls_score_list = [
-                cls_scores[i][img_id].detach() for i in range(num_levels)
-            ]
-            bbox_pred_list = [
-                bbox_preds[i][img_id].detach() for i in range(num_levels)
-            ]
-            img_shape = img_metas[img_id]['img_shape']
-            scale_factor = img_metas[img_id]['scale_factor']
-            if with_nms:
-                # some heads don't support with_nms argument
-                proposals = self._get_bboxes_single(cls_score_list,
-                                                    bbox_pred_list,
-                                                    mlvl_anchors, img_shape,
-                                                    scale_factor, cfg, rescale)
-            else:
-                proposals = self._get_bboxes_single(cls_score_list,
-                                                    bbox_pred_list,
-                                                    mlvl_anchors, img_shape,
-                                                    scale_factor, cfg, rescale,
-                                                    with_nms)
-            result_list.append(proposals)
-        return result_list
+        if torch.onnx.is_in_onnx_export():
+            return self.onnx_export(cls_scores,
+                                    bbox_preds,
+                                    score_factors=None,
+                                    img_metas=img_metas,
+                                    with_nms=with_nms)
+        else:
+            assert len(cls_scores) == len(bbox_preds)
+            num_levels = len(cls_scores)
+
+            device = cls_scores[0].device
+            featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+            mlvl_anchors = self.anchor_generator.grid_anchors(
+                featmap_sizes, device=device)
+
+            result_list = []
+            for img_id in range(len(img_metas)):
+                cls_score_list = [
+                    cls_scores[i][img_id].detach() for i in range(num_levels)
+                ]
+                bbox_pred_list = [
+                    bbox_preds[i][img_id].detach() for i in range(num_levels)
+                ]
+                img_shape = img_metas[img_id]['img_shape']
+                scale_factor = img_metas[img_id]['scale_factor']
+                if with_nms:
+                    # some heads don't support with_nms argument
+                    proposals = self._get_bboxes_single(cls_score_list,
+                                                        bbox_pred_list,
+                                                        mlvl_anchors, img_shape,
+                                                        scale_factor, cfg, rescale)
+                else:
+                    proposals = self._get_bboxes_single(cls_score_list,
+                                                        bbox_pred_list,
+                                                        mlvl_anchors, img_shape,
+                                                        scale_factor, cfg, rescale,
+                                                        with_nms)
+                result_list.append(proposals)
+            return result_list
 
     def _get_bboxes_single(self,
                            cls_score_list,
@@ -612,6 +775,7 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
                 are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                 5-th column is a score between 0 and 1.
         """
+        print('in _get_bboxes_single')
         cfg = self.test_cfg if cfg is None else cfg
         assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
         mlvl_bboxes = []
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
index a8e7487b..97ed2d09 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/pytorch2onnx.py
@@ -33,23 +33,32 @@ def pytorch2onnx(config_path,
     one_img, one_meta = preprocess_example_input(input_config)
     model, tensor_data = generate_inputs_and_wrap_model(
         config_path, checkpoint_path, input_config)
+
+    input_names = ['input']
+    dynamic_axes = {'input': {0: 'batch', 2: 'height', 3: 'width'}}
+
     output_names = ['boxes']
+    dynamic_axes['boxes'] = {0: 'batch'}
     if model.with_bbox:
         output_names.append('labels')
+        dynamic_axes['labels'] = {0: 'batch'}
     if model.with_mask:
         output_names.append('masks')
+        dynamic_axes['masks'] = {0: 'batch'}
 
     torch.onnx.export(
         model,
         tensor_data,
         output_file,
-        input_names=['input'],
+        input_names=input_names,
         output_names=output_names,
+        dynamic_axes=dynamic_axes,
         export_params=True,
         keep_initializers_as_inputs=True,
         do_constant_folding=True,
         verbose=show,
-        opset_version=opset_version)
+        opset_version=opset_version,
+        enable_onnx_checker=False)
 
     model.forward = orig_model.forward
     print(f'Successfully exported ONNX model: {output_file}')
@@ -67,6 +76,7 @@ def pytorch2onnx(config_path,
             tensor_data = [one_img]
         # check the numerical value
         # get pytorch output
+        one_meta['img_shape_for_onnx'] = one_img.shape[-2:]
         pytorch_results = model(tensor_data, [[one_meta]], return_loss=False)
         pytorch_results = pytorch_results[0]
         # get onnx output