diff --git a/d2/converter.py b/d2/converter.py
index 6fa5ff4..42882ce 100644
--- a/d2/converter.py
+++ b/d2/converter.py
@@ -12,8 +12,8 @@ import torch
 def parse_args():
     parser = argparse.ArgumentParser("D2 model converter")
 
-    parser.add_argument("--source_model", default="", type=str, help="Path or url to the DETR model to convert")
-    parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model")
+    parser.add_argument("--source_model", default="../detr.pth", type=str, help="Path or url to the DETR model to convert")
+    parser.add_argument("--output_model", default="../detr_.pth", type=str, help="Path where to save the converted model")
     return parser.parse_args()
 
 
diff --git a/datasets/coco.py b/datasets/coco.py
index 93a436b..3359e99 100644
--- a/datasets/coco.py
+++ b/datasets/coco.py
@@ -62,6 +62,7 @@ class ConvertCocoPolysToMask(object):
         anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
 
         boxes = [obj["bbox"] for obj in anno]
+        # print(boxes)
         # guard against no boxes via resizing
         boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
         boxes[:, 2:] += boxes[:, :2]
@@ -137,7 +138,11 @@ def make_coco_transforms(image_set):
 
     if image_set == 'val':
         return T.Compose([
-            T.RandomResize([800], max_size=1333),
+
+            # T.pad_resize(),
+            # T.RandomResize(sizes=(640,640),),
+            T.RandomResize([768], max_size=1400),
+
             normalize,
         ])
 
diff --git a/datasets/transforms.py b/datasets/transforms.py
index 0635857..a41038f 100644
--- a/datasets/transforms.py
+++ b/datasets/transforms.py
@@ -8,9 +8,10 @@ import PIL
 import torch
 import torchvision.transforms as T
 import torchvision.transforms.functional as F
-
+from PIL import Image,ImageDraw
 from util.box_ops import box_xyxy_to_cxcywh
 from util.misc import interpolate
+import numpy as np
 
 
 def crop(image, target, region):
@@ -75,7 +76,6 @@ def hflip(image, target):
 
 def resize(image, target, size, max_size=None):
     # size can be min_size (scalar) or (w, h) tuple
-
     def get_size_with_aspect_ratio(image_size, size, max_size=None):
         w, h = image_size
         if max_size is not None:
@@ -103,14 +103,39 @@ def resize(image, target, size, max_size=None):
             return get_size_with_aspect_ratio(image_size, size, max_size)
 
     size = get_size(image.size, size, max_size)
+
+    h, w = size[0], size[1]
+    scale = np.array([i for i in range(768, 1400, 256)])
+    min_scale = np.array([i for i in range(512, 900, 256)])
+
+    def resize_step(scale, x, step):
+        value = min(abs(scale - x))
+        # print(value)
+        if (x + value) % step == 0:
+            x = x + value
+            return x
+        elif (x - value) % step == 0:
+            x = x - value
+            return x
+
+    if h == 768:
+        size = (h, resize_step(scale, w, step=256))
+    elif w == 768:
+        size = (resize_step(scale, h, step=256), w)
+    elif h < 768:
+        size = (resize_step(min_scale, h, step=256), 1024)
+    elif w < 768:
+        size = (1024, resize_step(min_scale, w, step=256))
+
     rescaled_image = F.resize(image, size)
 
+
+
     if target is None:
         return rescaled_image, None
 
     ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
     ratio_width, ratio_height = ratios
-
     target = target.copy()
     if "boxes" in target:
         boxes = target["boxes"]
@@ -199,6 +224,62 @@ class RandomResize(object):
         return resize(img, target, size, self.max_size)
 
 
+
+class pad_resize(object):
+    def __init__(self, sizes=None):
+        # assert isinstance(sizes, (list, tuple))
+        self.sizes = sizes
+
+
+    def __call__(self, img, target=None):
+        # print(img.size)
+        img,target=Pad_img(img,target)
+        return resize(img, target, size=(1280,1280))
+
+
+def Pad_img(image, target):
+    # assumes that we only pad on the bottom right corners
+    # image.show()
+    # print('ori_img size',image.size)
+    h, w = image.size
+    pad_value=int(abs(h-w)/2)
+    # print('boxes',len(target['boxes']))
+    # if target['image_id'].item()==2592:
+    #     print(target['boxes'])
+    #     for i in target['boxes']:
+    #         draw=ImageDraw.Draw(image)
+    #         draw.line([(i[0].item(), i[1].item()),(i[2].item(),i[1].item()),
+    #                    (i[2].item(), i[3].item()),(i[0].item(),i[3].item()),
+    #                    (i[0].item(), i[1].item())], width=2, fill='red')
+    #     image.show()
+    if h>w:
+        padded_image = F.pad(image, (0, pad_value, 0, pad_value))
+    else:
+        padded_image = F.pad(image, (pad_value, 0, pad_value, 0))
+    h_,w_=padded_image.size
+    target = target.copy()
+    if "boxes" in target:
+        boxes = target["boxes"]
+        scaled_boxes = boxes + torch.as_tensor([abs(h-h_)/2, abs(w-w_)/2, abs(h-h_)/2, abs(w-w_)/2])
+        target["boxes"] = scaled_boxes
+    # if target['image_id'].item() == 2592:
+    #     print(target['boxes'])
+    #     for i in target['boxes']:
+    #         print(i[0],i[1])
+    #         draw=ImageDraw.Draw(padded_image)
+    #         draw.line([(i[0].item(), i[1].item()),(i[2].item(),i[1].item()),
+    #                    (i[2].item(), i[3].item()),(i[0].item(),i[3].item()),
+    #                    (i[0].item(), i[1].item())], width=2, fill='red')
+    #     padded_image.show()
+
+    if target is None:
+        return padded_image, None
+
+    target["size"] = torch.tensor([h_, w_])
+
+
+    return padded_image, target
+
 class RandomPad(object):
     def __init__(self, max_pad):
         self.max_pad = max_pad
diff --git a/engine.py b/engine.py
index ac5ea6f..6cd96ca 100644
--- a/engine.py
+++ b/engine.py
@@ -8,11 +8,11 @@ import sys
 from typing import Iterable
 
 import torch
-
+from apex import amp
 import util.misc as utils
 from datasets.coco_eval import CocoEvaluator
 from datasets.panoptic_eval import PanopticEvaluator
-
+import time
 
 def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
@@ -26,6 +26,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
     print_freq = 10
 
     for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
+        optimizer.zero_grad()
         samples = samples.to(device)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
 
@@ -49,8 +50,9 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
             print(loss_dict_reduced)
             sys.exit(1)
 
-        optimizer.zero_grad()
-        losses.backward()
+        with amp.scale_loss(losses, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        # losses.backward()
         if max_norm > 0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
         optimizer.step()
@@ -65,7 +67,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
 
 
 @torch.no_grad()
-def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
+def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir,batch_size):
     model.eval()
     criterion.eval()
 
@@ -85,10 +87,11 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
             output_dir=os.path.join(output_dir, "panoptic_eval"),
         )
 
+    val_times=[]
     for samples, targets in metric_logger.log_every(data_loader, 10, header):
         samples = samples.to(device)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
+        start=time.time()
         outputs = model(samples)
         loss_dict = criterion(outputs, targets)
         weight_dict = criterion.weight_dict
@@ -103,7 +106,12 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out
                              **loss_dict_reduced_scaled,
                              **loss_dict_reduced_unscaled)
         metric_logger.update(class_error=loss_dict_reduced['class_error'])
-
+        val_time=time.time()-start
+        FPS=(1/(val_time/batch_size))
+        val_times.append(FPS)
+        if len(val_times)==10:
+            print('FPS:',sum(val_times)/10)
+            val_times=[]
         orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
         results = postprocessors['bbox'](outputs, orig_target_sizes)
         if 'segm' in postprocessors.keys():
diff --git a/hubconf.py b/hubconf.py
index 328c330..87d335e 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -3,13 +3,47 @@ import torch
 
 from models.backbone import Backbone, Joiner
 from models.detr import DETR, PostProcess
-from models.position_encoding import PositionEmbeddingSine
+from models.position_encoding import PositionEmbeddingSine, PositionEmbeddingSine_onnx 
 from models.segmentation import DETRsegm, PostProcessPanoptic
 from models.transformer import Transformer
 
+from models.backbone import Joiner_onnx
+from models.detr import DETR_onnx
+from models.backbone import Backbone_onnx
+
 dependencies = ["torch", "torchvision"]
 
 
+def _make_detr_onnx(backbone_name: str, dilation=False, num_classes=91, mask=False):
+    hidden_dim = 256
+    backbone = Backbone_onnx(backbone_name, train_backbone=True, return_interm_layers=mask, dilation=dilation)
+    pos_enc = PositionEmbeddingSine_onnx(hidden_dim // 2, normalize=True)
+    backbone_with_pos_enc = Joiner_onnx(backbone, pos_enc)
+    backbone_with_pos_enc.num_channels = backbone.num_channels
+    transformer = Transformer(d_model=hidden_dim, return_intermediate_dec=True)
+    detr = DETR_onnx(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=100)
+    if mask:
+        return DETRsegm(detr)
+    return detr
+
+
+def detr_resnet50_onnx(pretrained=False, num_classes=91, return_postprocessor=False):
+    """
+    DETR R50 with 6 encoder and 6 decoder layers.
+
+    Achieves 42/62.4 AP/AP50 on COCO val5k.
+    """
+    model = _make_detr_onnx("resnet50", dilation=False, num_classes=num_classes)
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url="https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth", map_location="cpu", check_hash=True
+        )
+        model.load_state_dict(checkpoint["model"])
+    if return_postprocessor:
+        return model, PostProcess()
+    return model
+
+
 def _make_detr(backbone_name: str, dilation=False, num_classes=91, mask=False):
     hidden_dim = 256
     backbone = Backbone(backbone_name, train_backbone=True, return_interm_layers=mask, dilation=dilation)
diff --git a/main.py b/main.py
index e5f9eff..39cbef9 100644
--- a/main.py
+++ b/main.py
@@ -80,7 +80,7 @@ def get_args_parser():
 
     # dataset parameters
     parser.add_argument('--dataset_file', default='coco')
-    parser.add_argument('--coco_path', type=str)
+    parser.add_argument('--coco_path', type=str,default='/home/xu/SJH/datasets/coco')
     parser.add_argument('--coco_panoptic_path', type=str)
     parser.add_argument('--remove_difficult', action='store_true')
 
@@ -122,6 +122,7 @@ def main(args):
     model.to(device)
 
     model_without_ddp = model
+    print(args.distributed)
     if args.distributed:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
@@ -212,7 +213,7 @@ def main(args):
                 }, checkpoint_path)
 
         test_stats, coco_evaluator = evaluate(
-            model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
+            model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir,args.batch_size,
         )
 
         log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
diff --git a/models/backbone.py b/models/backbone.py
index 9668093..65e65c4 100644
--- a/models/backbone.py
+++ b/models/backbone.py
@@ -93,6 +93,38 @@ class Backbone(BackboneBase):
         super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
 
 
+class BackboneBase_onnx(nn.Module):
+    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
+        super().__init__()
+        for name, parameter in backbone.named_parameters():
+            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                parameter.requires_grad_(False)
+        if return_interm_layers:
+            self.return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+        else:
+            self.return_layers = {'layer4': "0"}
+        self.body = IntermediateLayerGetter_(backbone, return_layers=self.return_layers,export_onnx=True)
+        self.num_channels = num_channels
+
+    def forward(self, input_tensor):
+        out = self.body(input_tensor)
+        return out[0]
+
+
+class Backbone_onnx(BackboneBase_onnx):
+    """ResNet backbone with frozen BatchNorm."""
+    def __init__(self, name: str,
+                 train_backbone: bool,
+                 return_interm_layers: bool,
+                 dilation: bool):
+        backbone = getattr(torchvision.models, name)(
+            replace_stride_with_dilation=[False, False, dilation],
+            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
+        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
+        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
+
+
+
 class Joiner(nn.Sequential):
     def __init__(self, backbone, position_embedding):
         super().__init__(backbone, position_embedding)
@@ -105,10 +137,19 @@ class Joiner(nn.Sequential):
             out.append(x)
             # position encoding
             pos.append(self[1](x).to(x.tensors.dtype))
-
         return out, pos
 
 
+class Joiner_onnx(nn.Sequential):
+    def __init__(self, backbone, position_embedding):
+        super().__init__(backbone, position_embedding)
+
+    def forward(self, input_tensor, mask):
+        x = self[0](input_tensor)
+        pos = self[1](x, mask)
+        return x, pos.to(x.dtype)
+
+
 def build_backbone(args):
     position_embedding = build_position_encoding(args)
     train_backbone = args.lr_backbone > 0
@@ -117,3 +158,38 @@ def build_backbone(args):
     model = Joiner(backbone, position_embedding)
     model.num_channels = backbone.num_channels
     return model
+
+class IntermediateLayerGetter_(nn.ModuleDict):
+    def __init__(self, model, return_layers, export_onnx=False):
+        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+            raise ValueError("return_layers are not present in model")
+        orig_return_layers = return_layers
+        return_layers = {str(k): str(v) for k, v in return_layers.items()}
+        layers = OrderedDict()
+        self.export_onnx = export_onnx
+        for name, module in model.named_children():
+            layers[name] = module
+            if name in return_layers:
+                del return_layers[name]
+            if not return_layers:
+                break
+
+        super(IntermediateLayerGetter_, self).__init__(layers)
+        self.return_layers = orig_return_layers
+
+    def forward(self, x):
+        if self.export_onnx:
+            out = []
+            for name, module in self.items():
+                x = module(x)
+                if name in self.return_layers:
+                    out_name = self.return_layers[name]
+                    out.append(x)
+            return out
+        out = OrderedDict()
+        for name, module in self.items():
+            x = module(x)
+            if name in self.return_layers:
+                out_name = self.return_layers[name]
+                out[out_name] = x
+        return out
diff --git a/models/detr.py b/models/detr.py
index 23c2376..f7ac457 100644
--- a/models/detr.py
+++ b/models/detr.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from util import box_ops
-from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+from util.misc import (NestedTensor, nested_tensor_from_tensor_list,_onnx_nested_tensor_from_tensor_list,
                        accuracy, get_world_size, interpolate,
                        is_dist_avail_and_initialized)
 
@@ -59,11 +59,68 @@ class DETR(nn.Module):
         if isinstance(samples, (list, torch.Tensor)):
             samples = nested_tensor_from_tensor_list(samples)
         features, pos = self.backbone(samples)
-
+        # import pdb;pdb.set_trace()
         src, mask = features[-1].decompose()
         assert mask is not None
         hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
+        outputs_class = self.class_embed(hs)
+        outputs_coord = self.bbox_embed(hs).sigmoid()
+        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
+        if self.aux_loss:
+            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
+        return out
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_coord):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [{'pred_logits': a, 'pred_boxes': b}
+                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+class DETR_onnx(nn.Module):
+    """ This is the DETR module that performs object detection """
+    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
+        """ Initializes the model.
+        Parameters:
+            backbone: torch module of the backbone to be used. See backbone.py
+            transformer: torch module of the transformer architecture. See transformer.py
+            num_classes: number of object classes
+            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+                         DETR can detect in a single image. For COCO, we recommend 100 queries.
+            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+        """
+        super().__init__()
+        self.num_queries = num_queries
+        self.transformer = transformer
+        hidden_dim = transformer.d_model
+        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)
+        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
+        self.backbone = backbone
+        self.aux_loss = aux_loss
 
+    def forward(self, samples, mask):
+        """ The forward expects a NestedTensor, which consists of:
+               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+            It returns a dict with the following elements:
+               - "pred_logits": the classification logits (including no-object) for all queries.
+                                Shape= [batch_size x num_queries x (num_classes + 1)]
+               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
+                               (center_x, center_y, height, width). These values are normalized in [0, 1],
+                               relative to the size of each individual image (disregarding possible padding).
+                               See PostProcess for information on how to retrieve the unnormalized bounding box.
+               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
+                                dictionnaries containing the two above keys for each decoder layer.
+        """
+        src, pos= self.backbone(samples, mask)
+        # src, mask = features[-1].decompose()
+        # assert mask is not None
+        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)[0]
         outputs_class = self.class_embed(hs)
         outputs_coord = self.bbox_embed(hs).sigmoid()
         out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
@@ -280,7 +337,6 @@ class PostProcess(nn.Module):
         img_h, img_w = target_sizes.unbind(1)
         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
         boxes = boxes * scale_fct[:, None, :]
-
         results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
 
         return results
@@ -346,6 +402,8 @@ def build(args):
     losses = ['labels', 'boxes', 'cardinality']
     if args.masks:
         losses += ["masks"]
+    print('losses',losses)
+
     criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
                              eos_coef=args.eos_coef, losses=losses)
     criterion.to(device)
diff --git a/models/matcher.py b/models/matcher.py
index 0c29147..ed45dd9 100644
--- a/models/matcher.py
+++ b/models/matcher.py
@@ -62,6 +62,7 @@ class HungarianMatcher(nn.Module):
         tgt_ids = torch.cat([v["labels"] for v in targets])
         tgt_bbox = torch.cat([v["boxes"] for v in targets])
 
+
         # Compute the classification cost. Contrary to the loss, we don't use the NLL,
         # but approximate it in 1 - proba[target class].
         # The 1 is a constant that doesn't change the matching, it can be ommitted.
@@ -71,6 +72,7 @@ class HungarianMatcher(nn.Module):
         cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
 
         # Compute the giou cost betwen boxes
+
         cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
 
         # Final cost matrix
diff --git a/models/position_encoding.py b/models/position_encoding.py
index 73ae39e..363abfb 100644
--- a/models/position_encoding.py
+++ b/models/position_encoding.py
@@ -29,7 +29,8 @@ class PositionEmbeddingSine(nn.Module):
         x = tensor_list.tensors
         mask = tensor_list.mask
         assert mask is not None
-        not_mask = ~mask
+        # not_mask = ~mask
+        not_mask = (~mask).float()
         y_embed = not_mask.cumsum(1, dtype=torch.float32)
         x_embed = not_mask.cumsum(2, dtype=torch.float32)
         if self.normalize:
@@ -47,6 +48,49 @@ class PositionEmbeddingSine(nn.Module):
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
         return pos
 
+    
+class PositionEmbeddingSine_onnx(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x, mask):
+        # mask = torch.zeros([x.shape[0], x.shape[2], x.shape[3]], dtype=bool)
+        # mask = torch.zeros([1, 40, 40], dtype=bool)
+        # assert mask is not None
+        # not_mask = ~mask
+        batch, w, h = mask.size()
+        not_mask = (~mask).float()
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        # pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        # pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).reshape(batch, w, h, -1)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).reshape(batch, w, h, -1)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
 
 class PositionEmbeddingLearned(nn.Module):
     """
@@ -69,6 +113,7 @@ class PositionEmbeddingLearned(nn.Module):
         j = torch.arange(h, device=x.device)
         x_emb = self.col_embed(i)
         y_emb = self.row_embed(j)
+
         pos = torch.cat([
             x_emb.unsqueeze(0).repeat(h, 1, 1),
             y_emb.unsqueeze(1).repeat(1, w, 1),
diff --git a/models/transformer.py b/models/transformer.py
index dcd5367..75748ed 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -45,12 +45,16 @@ class Transformer(nn.Module):
                 nn.init.xavier_uniform_(p)
 
     def forward(self, src, mask, query_embed, pos_embed):
+        mask = mask.to(torch.bool)
         # flatten NxCxHxW to HWxNxC
         bs, c, h, w = src.shape
-        src = src.flatten(2).permute(2, 0, 1)
-        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        src = src.reshape(bs, c, -1).permute(2, 0, 1)
+        pos_embed = pos_embed.reshape(bs, c, -1).permute(2, 0, 1)
+        # src = src.flatten(2).permute(2, 0, 1)
+        # pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
         query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
-        mask = mask.flatten(1)
+        mask = mask.reshape(bs, -1)
+        # mask = mask.flatten(1)
 
         tgt = torch.zeros_like(query_embed)
         memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
diff --git a/test_all.py b/test_all.py
index 7153892..9a41517 100644
--- a/test_all.py
+++ b/test_all.py
@@ -135,8 +135,13 @@ class ONNXExporterTester(unittest.TestCase):
 
         onnx_io = io.BytesIO()
         # export to onnx with the first input
-        torch.onnx.export(model, inputs_list[0], onnx_io,
-                          do_constant_folding=do_constant_folding, opset_version=12,
+        # print(inputs_list[0])
+
+        torch.onnx.export(model, inputs_list[0],onnx_io,
+                          do_constant_folding=do_constant_folding, opset_version=11,
+                          dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
+        torch.onnx.export(model, inputs_list[0], 'detr_640.onnx',
+                          do_constant_folding=do_constant_folding, opset_version=11,
                           dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
         # validate the exported model with onnx runtime
         for test_inputs in inputs_list:
@@ -177,13 +182,22 @@ class ONNXExporterTester(unittest.TestCase):
 
     def test_model_onnx_detection(self):
         model = detr_resnet50(pretrained=False).eval()
+        model.load_state_dict(torch.load('model_file/detr.pth',map_location="cpu")['model'])
         dummy_image = torch.ones(1, 3, 800, 800) * 0.3
         model(dummy_image)
-
+        import cv2
+        import numpy as np
+        img = cv2.imread('785.jpg').astype(np.float32)
+        img = cv2.resize(img, (640, 640))
+        img = np.transpose(img)
+        img = np.expand_dims(img, 0)
+        torch_input = torch.from_numpy(img)
         # Test exported model on images of different size, or dummy input
+
         self.run_model(
             model,
-            [(torch.rand(1, 3, 750, 800),)],
+            [(torch_input,)],
+            dynamic_axes={'inputs': {0: 'batch', 2: 'height', 3: 'width'}},
             input_names=["inputs"],
             output_names=["pred_logits", "pred_boxes"],
             tolerate_small_mismatch=True,
diff --git a/util/misc.py b/util/misc.py
index dfa9fb5..20a3b8c 100644
--- a/util/misc.py
+++ b/util/misc.py
@@ -1,7 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 """
 Misc functions, including distributed helpers.
-
 Mostly copy-paste from torchvision references.
 """
 import os
@@ -313,8 +312,9 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
             return _onnx_nested_tensor_from_tensor_list(tensor_list)
 
         # TODO make it support different-sized images
-        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+
         # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
         batch_shape = [len(tensor_list)] + max_size
         b, c, h, w = batch_shape
         dtype = tensor_list[0].dtype
@@ -333,11 +333,11 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
 # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
 @torch.jit.unused
 def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
-    max_size = []
-    for i in range(tensor_list[0].dim()):
-        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
-        max_size.append(max_size_i)
-    max_size = tuple(max_size)
+    # max_size = []
+    # for i in range(tensor_list[0].dim()):
+    #     max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+    #     max_size.append(max_size_i)
+    # max_size = tuple(max_size)
 
     # work around for
     # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
@@ -346,7 +346,8 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen
     padded_imgs = []
     padded_masks = []
     for img in tensor_list:
-        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        # padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        padding = [torch.tensor(0), torch.tensor(0), torch.tensor(0)]
         padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
         padded_imgs.append(padded_img)