diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py
index ff5db38..b4c49f7 100644
--- a/maskrcnn_benchmark/config/defaults.py
+++ b/maskrcnn_benchmark/config/defaults.py
@@ -36,7 +36,7 @@ _C.MODEL.WEIGHT = ""
 # -----------------------------------------------------------------------------
 _C.INPUT = CN()
 # Size of the fixed shape
-_C.INPUT.FIX_SHAPE = (1344, 1344)
+_C.INPUT.FIX_SHAPE = 1344
 # Size of the smallest side of the image during training
 _C.INPUT.MIN_SIZE_TRAIN = (800,)  # 800
 # Maximum size of the side of the image during training
diff --git a/maskrcnn_benchmark/layers/npu_roi_align.py b/maskrcnn_benchmark/layers/npu_roi_align.py
index 3b80a9e..f905232 100644
--- a/maskrcnn_benchmark/layers/npu_roi_align.py
+++ b/maskrcnn_benchmark/layers/npu_roi_align.py
@@ -54,12 +54,41 @@ class _ROIAlign(Function):
         return grad_input, None, None, None, None, None
 
 
-roi_align = _ROIAlign.apply
+class RoiExtractor(torch.autograd.Function):
+    @staticmethod
+    def forward(self, feats, rois, aligned=0, finest_scale=56, pooled_height=14, pooled_width=14,
+                pool_mode='avg', roi_scale_factor=0, sample_num=0, spatial_scale=[0.125, ]):
+        """
+        feats (torch.Tensor): feats in shape (batch, 256, H, W).
+        rois (torch.Tensor): rois in shape (k, 5).
+        return:
+            roi_feats (torch.Tensor): (k, 256, pooled_width, pooled_width)
+        """
+
+        # phony implementation for shape inference
+        k = rois.shape[0]
+        roi_feats = torch.ones(k, 256, pooled_height, pooled_width)
+        return roi_feats
+
+    @staticmethod
+    def symbolic(g, feats, rois, aligned=0, finest_scale=56, pooled_height=14, pooled_width=14):
+        # TODO: support tensor list type for feats
+        # f_tensors = sym_help._unpack_list(feats)
+        roi_feats = g.op('RoiExtractor', feats, rois, aligned_i=0, finest_scale_i=56,
+                         pooled_height_i=pooled_height, pooled_width_i=pooled_width,
+                         pool_mode_s='avg', roi_scale_factor_i=0, sample_num_i=0,
+                         spatial_scale_f=[0.125, ], outputs=1)
+        return roi_feats
+
+
+from torchvision.ops import roi_align
+
+roi_align_ = roi_align
 
 
 # NOTE: torchvision's RoIAlign has a different default aligned=False
 class ROIAlign(nn.Module):
-    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
+    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=False):
         """ROIAlign using npu api.
 
         Origin implement from detectron2 is
@@ -108,10 +137,19 @@ class ROIAlign(nn.Module):
             rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
         """
         assert rois.dim() == 2 and rois.size(1) == 5
-        return roi_align(
-            input_tensor.float(), rois, self.output_size,
-            self.spatial_scale, self.sampling_ratio, self.aligned
-        )
+
+        # if torch.onnx.is_in_onnx_export():
+        #     return RoiExtractor.apply(input_tensor.float(), rois, 0, 56, self.output_size[0], self.output_size[1])
+        # else:
+        #     return roi_align_(
+        #         input_tensor.float(), rois, self.output_size,
+        #         self.spatial_scale, self.sampling_ratio, self.aligned
+        #     )
+
+        res = roi_align_(input_tensor.float(), rois, self.output_size, self.spatial_scale, self.sampling_ratio,
+                         self.aligned)
+
+        return res
 
     def __repr__(self):
         tmpstr = self.__class__.__name__ + "("
diff --git a/maskrcnn_benchmark/modeling/box_coder.py b/maskrcnn_benchmark/modeling/box_coder.py
index 7ecd760..4a501fa 100644
--- a/maskrcnn_benchmark/modeling/box_coder.py
+++ b/maskrcnn_benchmark/modeling/box_coder.py
@@ -68,10 +68,10 @@ class BoxCoder(object):
         ctr_y = boxes[:, 1] + 0.5 * heights
 
         wx, wy, ww, wh = self.weights
-        dx = rel_codes[:, 0::4] / wx
-        dy = rel_codes[:, 1::4] / wy
-        dw = rel_codes[:, 2::4] / ww
-        dh = rel_codes[:, 3::4] / wh
+        dx = torch.true_divide(rel_codes[:, 0::4], wx)
+        dy = torch.true_divide(rel_codes[:, 1::4], wy)
+        dw = torch.true_divide(rel_codes[:, 2::4], ww)
+        dh = torch.true_divide(rel_codes[:, 3::4], wh)
 
         # Prevent sending too large values into torch.exp()
         dw = torch.clamp(dw, max=self.bbox_xform_clip)
@@ -82,14 +82,10 @@ class BoxCoder(object):
         pred_w = torch.exp(dw) * widths[:, None]
         pred_h = torch.exp(dh) * heights[:, None]
 
-        pred_boxes = torch.zeros_like(rel_codes)
-        # x1
-        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w
-        # y1
-        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h
-        # x2 (note: "- 1" is correct; don't be fooled by the asymmetry)
-        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1
-        # y2 (note: "- 1" is correct; don't be fooled by the asymmetry)
-        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1
+        tmp1 = pred_ctr_x - 0.5 * pred_w
+        tmp2 = pred_ctr_y - 0.5 * pred_h
+        tmp3 = pred_ctr_x + 0.5 * pred_w - 1
+        tmp4 = pred_ctr_y + 0.5 * pred_h - 1
+        pred_boxes = torch.cat((tmp1, tmp2, tmp3, tmp4), dim=1)
 
         return pred_boxes
diff --git a/maskrcnn_benchmark/modeling/detector/retinanet.py b/maskrcnn_benchmark/modeling/detector/retinanet.py
index df43a82..e24cdcb 100644
--- a/maskrcnn_benchmark/modeling/detector/retinanet.py
+++ b/maskrcnn_benchmark/modeling/detector/retinanet.py
@@ -48,15 +48,15 @@ class RetinaNet(nn.Module):
         """
         if self.training and targets is None:
             raise ValueError("In training mode, targets should be passed")
-        images = to_image_list(images)
-        features = self.backbone(images.tensors)
+        features = self.backbone(images)
 
         # Retina RPN Output
         rpn_features = features
         if self.cfg.RETINANET.BACKBONE == "p2p7":
             rpn_features = features[1:]
 
-        (anchors, detections), detector_losses = self.rpn(images, rpn_features, targets)
+        image_sizes = [images.shape[-2:] for _ in range(images.shape[0])]
+        (anchors, detections), detector_losses = self.rpn(image_sizes, rpn_features)
 
         if self.training:
             losses = {}
@@ -89,25 +89,6 @@ class RetinaNet(nn.Module):
             return losses
         else:
             if self.mask:
-                proposals = []
-                for image_detections in detections:
-                    num_of_detections = image_detections.bbox.shape[0]
-                    if num_of_detections > self.cfg.RETINANET.NUM_MASKS_TEST > 0:
-                        cls_scores = image_detections.get_field("scores")
-                        cls_scores = cls_scores.type(torch.float32)
-                        _, keep = torch.topk(
-                            cls_scores,
-                            self.cfg.RETINANET.NUM_MASKS_TEST,
-                            largest=True
-                        )
-                        image_detections = image_detections[keep]
-
-                    proposals.append(image_detections)
-
-                if self.cfg.MODEL.SPARSE_MASK_ON:
-                    x, detections, mask_losses = self.mask(
-                        features, proposals, targets
-                    )
-                else:
-                    x, detections, mask_losses = self.mask(features, proposals, targets)
+                detections, masks = self.mask(features, detections)
+                return detections, masks
             return detections
diff --git a/maskrcnn_benchmark/modeling/poolers.py b/maskrcnn_benchmark/modeling/poolers.py
index c2601eb..90af26a 100644
--- a/maskrcnn_benchmark/modeling/poolers.py
+++ b/maskrcnn_benchmark/modeling/poolers.py
@@ -1,6 +1,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
 
+import numpy as np
+
 from torch import nn
 from maskrcnn_benchmark.layers import ROIAlign
 from .utils import cat
@@ -26,18 +28,19 @@ class LevelMapper(object):
         self.lvl0 = canonical_level
         self.eps = eps
 
-    def __call__(self, boxlists):
+    def __call__(self, bboxes):
         """
         Arguments:
             boxlists (list[BoxList])
         """
         # Compute level ids
-        s = torch.sqrt(cat([boxlist.area() for boxlist in boxlists]))
+        s = torch.sqrt(cat([(bbox[0][:, 2] - bbox[0][:, 0] + 1) * (bbox[0][:, 3] - bbox[0][:, 1] + 1)
+                            for bbox in bboxes]))
 
         # Eqn.(1) in FPN paper
         target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0 + self.eps))
         target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
-        return target_lvls.to(torch.int64) - self.k_min
+        return target_lvls - self.k_min
 
 
 class Pooler(nn.Module):
@@ -76,16 +79,18 @@ class Pooler(nn.Module):
         )
 
     def convert_to_roi_format(self, boxes):
-        concat_boxes = cat([b.bbox for b in boxes], dim=0)
+        num_box = [b[0].shape[0] for b in boxes]
+        concat_boxes = cat([b[0] for b in boxes], dim=0)
         device, dtype = concat_boxes.device, concat_boxes.dtype
         ids = cat(
             [
-                torch.full((len(b), 1), i, dtype=dtype, device=device)
+                torch.full((num_box[i], 1), i, dtype=dtype, device=device)
                 for i, b in enumerate(boxes)
             ],
             dim=0,
         )
         rois = torch.cat([ids, concat_boxes], dim=1)
+
         return rois
 
     def forward(self, x, boxes):
@@ -104,7 +109,7 @@ class Pooler(nn.Module):
 
         levels = self.map_levels(boxes)
 
-        num_rois = len(rois)
+        num_rois = rois.shape[0]
         num_channels = x[0].shape[1]
         output_size = self.output_size[0]
 
@@ -117,14 +122,9 @@ class Pooler(nn.Module):
 
         for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
             idx_in_level = levels == level
-
-            rois_per_level = rois[idx_in_level]
-
-            num_rois_per_level = len(rois_per_level)
-            max_len = len(rois)
-            fix_shape_rois = rois_per_level.new_zeros([max_len, 5])
-            fix_shape_rois[:num_rois_per_level] = rois_per_level
-            fix_shape_res = pooler(per_level_feature, fix_shape_rois)
-            result[idx_in_level] = fix_shape_res[:num_rois_per_level]
+            pooler_res = pooler(per_level_feature, rois)
+            idx_in_level = idx_in_level[:, None, None, None]
+            res = torch.mul(idx_in_level, pooler_res)
+            result += res
 
         return result
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
index c66d7d0..f67f376 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
@@ -4,6 +4,7 @@ import torch
 from PIL import Image
 from torch import nn
 
+from maskrcnn_benchmark.modeling.utils import cat
 from maskrcnn_benchmark.structures.bounding_box import BoxList
 
 
@@ -38,26 +39,18 @@ class MaskPostProcessor(nn.Module):
 
         # select masks coresponding to the predicted classes
         num_masks = x.shape[0]
-        labels = [bbox.get_field("labels") for bbox in boxes]
-        labels = torch.cat(labels)
+        labels = [bbox[1] for bbox in boxes]
+        labels = cat(labels)
         index = torch.arange(num_masks, device=labels.device)
         mask_prob = mask_prob[index.long(), labels.long()][:, None]
 
         if self.masker:
             mask_prob = self.masker(mask_prob, boxes)
 
-        boxes_per_image = [len(box) for box in boxes]
-        mask_prob = mask_prob.split(boxes_per_image, dim=0)
+        # boxes_per_image = [len(box[0]) for box in boxes]
+        # mask_prob = mask_prob.split(boxes_per_image, dim=0)
 
-        results = []
-        for prob, box in zip(mask_prob, boxes):
-            bbox = BoxList(box.bbox, box.size, mode="xyxy")
-            for field in box.fields():
-                bbox.add_field(field, box.get_field(field))
-            bbox.add_field("mask", prob)
-            results.append(bbox)
-
-        return results
+        return boxes, mask_prob
 
 
 class MaskPostProcessorCOCOFormat(MaskPostProcessor):
diff --git a/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
index fd88cc1..6721001 100644
--- a/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
+++ b/maskrcnn_benchmark/modeling/roi_heads/mask_head/mask_head.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import math
 import torch
+from torch.nn import functional as F
 from maskrcnn_benchmark.structures.bounding_box import BoxList
 
 from .roi_mask_feature_extractors import make_roi_mask_feature_extractor
@@ -34,22 +35,23 @@ def keep_only_positive_boxes(boxes):
 
 
 def extra_proposals(proposals):
+    outputs = []
     for proposal in proposals:
-        cur_count = len(proposal)
-        boxes = proposal.bbox
-        labels = proposal.get_field('labels')
+        cur_count = proposal[1].shape[0]
+        boxes = proposal[0]
+        labels = proposal[1]
+        scores = proposal[2]
 
         box_count = 180
         if cur_count > box_count:
             box_count = int(math.ceil(cur_count / 45)) * 45
-        new_boxes = boxes.new_zeros((box_count, 4), dtype=torch.float)
-        new_labels = boxes.new_full((box_count,), fill_value=-1, dtype=torch.int)
-        new_boxes[:cur_count] = boxes
-        new_labels[:cur_count] = labels
+        pad_diff = box_count - cur_count
+        new_boxes = F.pad(boxes, (0, 0, 0, pad_diff), value=0)
+        new_labels = F.pad(labels, (0, pad_diff), value=-1)
+        new_scores = F.pad(scores, (0, pad_diff), value=0)
 
-        proposal.bbox = new_boxes
-        proposal.add_field('labels', new_labels)
-    return proposals
+        outputs.append([new_boxes, new_labels, new_scores])
+    return outputs
 
 
 class ROIMaskHead(torch.nn.Module):
@@ -80,7 +82,7 @@ class ROIMaskHead(torch.nn.Module):
         if self.training:
             # during training, only focus on positive boxes
             all_proposals = proposals
-        proposals = extra_proposals(proposals)
+        # proposals = extra_proposals(proposals)
 
         if self.training and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
             x = features
@@ -92,7 +94,7 @@ class ROIMaskHead(torch.nn.Module):
 
         if not self.training:
             result = self.post_processor(mask_logits, proposals)
-            return x, result, {}
+            return result
 
         loss_mask = self.loss_evaluator(proposals, mask_logits, targets)
 
diff --git a/maskrcnn_benchmark/modeling/rpn/anchor_generator.py b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py
index 3554353..e60ca51 100644
--- a/maskrcnn_benchmark/modeling/rpn/anchor_generator.py
+++ b/maskrcnn_benchmark/modeling/rpn/anchor_generator.py
@@ -108,19 +108,14 @@ class AnchorGenerator(nn.Module):
             inds_inside = torch.ones(anchors.shape[0], dtype=torch.uint8, device=device)
         boxlist.add_field("visibility", inds_inside)
 
-    def forward(self, image_list, feature_maps):
-        grid_height, grid_width = feature_maps[0].shape[-2:]
+    def forward(self, image_sizes, feature_maps):
         grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
         anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
         anchors = []
-        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
+        for i, (image_height, image_width) in enumerate(image_sizes):
             anchors_in_image = []
             for anchors_per_feature_map in anchors_over_all_feature_maps:
-                boxlist = BoxList(
-                    anchors_per_feature_map, (image_width, image_height), mode="xyxy"
-                )
-                self.add_visibility_to(boxlist)
-                anchors_in_image.append(boxlist)
+                anchors_in_image.append(anchors_per_feature_map)
             anchors.append(anchors_in_image)
         return anchors
 
diff --git a/maskrcnn_benchmark/modeling/rpn/retinanet.py b/maskrcnn_benchmark/modeling/rpn/retinanet.py
index a588ad1..8d39725 100644
--- a/maskrcnn_benchmark/modeling/rpn/retinanet.py
+++ b/maskrcnn_benchmark/modeling/rpn/retinanet.py
@@ -186,10 +186,10 @@ class RetinaNetModule(torch.nn.Module):
 
     def _forward_test(self, anchors, box_cls, box_regression):
         N = int(box_cls[0].size(0))
-        A = int(box_regression[0].size(1) / 4)
-        C = int(box_cls[0].size(1) / A)
+        A = torch.floor_divide(box_regression[0].size(1), 4)
+        C = torch.floor_divide(box_cls[0].size(1), A)
         anchors_size = [anchor_list[0].size for anchor_list in anchors]
-        anchors_bbox = [[anchor.bbox for anchor in anchor_list] for anchor_list in anchors]
+        anchors_bbox = [[anchor for anchor in anchor_list] for anchor_list in anchors]
         anchors_per_img = [torch.cat(anchor_list, 0) for anchor_list in anchors_bbox]
 
         box_cls = self.permute_and_concat(box_cls, C)
diff --git a/maskrcnn_benchmark/modeling/rpn/retinanet_infer.py b/maskrcnn_benchmark/modeling/rpn/retinanet_infer.py
index 27431ab..6290547 100644
--- a/maskrcnn_benchmark/modeling/rpn/retinanet_infer.py
+++ b/maskrcnn_benchmark/modeling/rpn/retinanet_infer.py
@@ -1,51 +1,6 @@
 import torch
 
 from maskrcnn_benchmark.modeling.box_coder import BoxCoder
-from maskrcnn_benchmark.structures.bounding_box import BoxList
-from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes
-
-
-def batched_nms(boxes, scores, max_output_size, iou_threshold, scores_threshold):
-    """
-    Performs non-maximum suppression in a batched fashion.
-
-    Each index value correspond to a category, and NMS
-    will not be applied between elements of different categories.
-
-    Parameters
-    ----------
-    boxes : Tensor[N, 4]
-        boxes where NMS will be performed. They
-        are expected to be in (x1, y1, x2, y2) format
-    scores : Tensor[N]
-        scores for each one of the boxes
-    idxs : Tensor[N]
-        indices of the categories for each one of the boxes.
-    iou_threshold : float
-        discards all overlapping boxes
-        with IoU > iou_threshold
-
-    Returns
-    -------
-    keep : Tensor
-        int64 tensor with the indices of
-        the elements that have been kept by NMS, sorted
-        in decreasing order of scores
-    """
-    num_classes = scores.size(1)
-    num_boxes = scores.size(0)
-    multi_bboxes = boxes.reshape(1, num_boxes, -1, 4)
-    multi_scores = scores.reshape(1, num_boxes, num_classes)
-    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = torch.npu_batch_nms(multi_bboxes.half(), multi_scores.half(),
-                                                                              scores_threshold,
-                                                                              iou_threshold, max_output_size,
-                                                                              max_output_size)
-    nmsed_boxes = nmsed_boxes.reshape(nmsed_boxes.shape[1:])
-    nmsed_scores = nmsed_scores.reshape(nmsed_scores.shape[1])
-    nmsed_classes = nmsed_classes.reshape(nmsed_classes.shape[1])
-    nmsed_num = nmsed_num.item()
-
-    return nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num
 
 
 class RetinaNetPostProcessor(torch.nn.Module):
@@ -83,74 +38,6 @@ class RetinaNetPostProcessor(torch.nn.Module):
             box_coder = BoxCoder(weights=(10., 10., 5., 5.))
         self.box_coder = box_coder
 
-    def forward_for_single_feature_map(self, anchors, box_cls, box_regression,
-                                       pre_nms_thresh):
-        """
-        Arguments:
-            anchors: list[BoxList]
-            box_cls: tensor of size N, A * C, H, W
-            box_regression: tensor of size N, A * 4, H, W
-        """
-        device = box_cls.device
-        N, _, H, W = box_cls.shape
-        A = int(box_regression.size(1) / 4)
-        C = int(box_cls.size(1) / A)
-
-        # put in the same format as anchors
-        box_cls = box_cls.permute(0, 2, 3, 1)
-        box_cls = box_cls.reshape(N, -1, C)
-        box_cls = box_cls.sigmoid().cpu().float()
-
-        box_regression = box_regression.permute(0, 2, 3, 1)
-        box_regression = box_regression.reshape(N, -1, 4).cpu().float()
-
-        num_anchors = A * H * W
-
-        results = [[] for _ in range(N)]
-        candidate_inds = box_cls > pre_nms_thresh
-        if candidate_inds.sum().item() == 0:
-            empty_boxlists = []
-            for a in anchors:
-                empty_boxlist = BoxList(torch.zeros(1, 4).cpu().float(), a.size)
-                empty_boxlist.add_field(
-                    "labels", torch.LongTensor([-1]).cpu())
-                empty_boxlist.add_field(
-                    "scores", torch.Tensor([0]).cpu().float())
-                empty_boxlists.append(empty_boxlist)
-            return empty_boxlists
-
-        pre_nms_top_n = candidate_inds.reshape(N, -1).sum(1)
-        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)
-
-        for batch_idx, (per_box_cls, per_box_regression, per_pre_nms_top_n,
-                        per_candidate_inds, per_anchors) in enumerate(
-            zip(box_cls, box_regression, pre_nms_top_n, candidate_inds, anchors)):
-            # Sort and select TopN
-            per_box_cls = per_box_cls[per_candidate_inds]
-            per_box_cls, top_k_indices = \
-                per_box_cls.topk(per_pre_nms_top_n, sorted=False)
-
-            per_candidate_nonzeros = \
-                per_candidate_inds.nonzero()[top_k_indices, :]
-
-            per_box_loc = per_candidate_nonzeros[:, 0]
-            per_class = per_candidate_nonzeros[:, 1]
-            per_class += 1
-
-            detections = self.box_coder.decode_cpu(
-                per_box_regression[per_box_loc, :].view(-1, 4),
-                per_anchors.bbox[per_box_loc, :].view(-1, 4)
-            )
-
-            boxlist = BoxList(detections, per_anchors.size, mode="xyxy")
-            boxlist.add_field("labels", per_class)
-            boxlist.add_field("scores", per_box_cls)
-            boxlist = boxlist.clip_to_image(remove_empty=False)
-            boxlist = remove_small_boxes(boxlist, self.min_size)
-            results[batch_idx] = boxlist
-
-        return results
-
     def forward(self, anchors_per_img, box_cls, box_regression, anchors_size, N, C, targets=None):
         """
         Arguments:
@@ -164,7 +51,7 @@ class RetinaNetPostProcessor(torch.nn.Module):
         """
         device = box_cls.device
         box_cls = box_cls.sigmoid()
-        k = self.pre_nms_top_n * 2
+        k = self.pre_nms_top_n  # * 4
         results = []
         for i in range(N):
             cls_scores = box_cls[i]
@@ -176,10 +63,11 @@ class RetinaNetPostProcessor(torch.nn.Module):
                 achrs.view(-1, 4)
             )
             if not self.training:
-                k = k * 2
                 scores, topk_inds = torch.topk(cls_scores.flatten(), k=k, largest=True)
-                labels = topk_inds % C
-                topk_inds = topk_inds // C
+                C = torch.tensor(C, dtype=torch.int32)
+                labels = topk_inds.int() % C
+                topk_inds = torch.floor_divide(topk_inds.int(), C).long()
+                labels = labels.int()
                 bboxes = bboxes[topk_inds]
             else:
                 max_scores, labels = torch.max(cls_scores, 1)
@@ -188,31 +76,17 @@ class RetinaNetPostProcessor(torch.nn.Module):
                 scores = topk_scores
                 labels = labels[topk_inds]
             if labels.numel() == 0:
-                result = BoxList(bboxes.new_ones([1, 4]), anchor_size, mode="xyxy")
-                result.add_field("scores", bboxes.new_zeros([1, ]))
-                result.add_field("labels", bboxes.new_full((1,), -1, dtype=torch.long))
+                result_boxes = bboxes.new_ones([1, 4]).to(device)
+                result_scores = bboxes.new_zeros([1, ]).to(device)
+                result_labels = (bboxes.new_ones((1,), dtype=torch.int32) * -1).to(device)
             else:
-                multi_scores = scores.new_zeros([k, C])
-                multi_bboxes = bboxes.new_zeros([k, 4])
                 k = min(k, labels.numel())
-                multi_bboxes[:k] = bboxes[:k]
-                indices = torch.arange(0, k).to(device)
-                multi_scores[indices, labels[:k]] = scores[:k]
-
-                nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = batched_nms(multi_bboxes, multi_scores,
-                                                                                  self.fpn_post_nms_top_n,
-                                                                                  iou_threshold=self.nms_thresh,
-                                                                                  scores_threshold=self.pre_nms_thresh)
-                nmsed_classes = nmsed_classes + 1
-                result = BoxList(nmsed_boxes, anchor_size, mode="xyxy")
-                result.add_field("scores", nmsed_scores)
-                result.add_field("labels", nmsed_classes)
-                result = result.clip_to_image(remove_empty=False)
+                result_boxes = bboxes[:k]
+                result_boxes = torch.clamp(result_boxes, 0, 1344)
+                result_scores = scores[:k]
+                result_labels = labels + 1
 
-            result.bbox = result.bbox.to(device)
-            result.add_field('labels', result.get_field('labels').to(device))
-            result.add_field('scores', result.get_field('scores').to(device))
-            results.append(result)
+            results.append([result_boxes, result_labels, result_scores])
 
         return results