05360171创建于 2022年3月18日历史提交
diff --git a/effdet/__init__.py b/effdet/__init__.py
index c2aa4bc..ae65662 100644
--- a/effdet/__init__.py
+++ b/effdet/__init__.py
@@ -1,5 +1,5 @@
 from .efficientdet import EfficientDet
-from .bench import DetBenchPredict, DetBenchTrain, unwrap_bench
+from .bench import DetBenchPredict, unwrap_bench
 from .data import create_dataset, create_loader, create_parser, DetectionDatset, SkipSubset
 from .evaluator import CocoEvaluator, PascalEvaluator, OpenImagesEvaluator, create_evaluator
 from .config import get_efficientdet_config, default_detection_model_configs
diff --git a/effdet/bench.py b/effdet/bench.py
index b528c8b..7cc6864 100644
--- a/effdet/bench.py
+++ b/effdet/bench.py
@@ -32,6 +32,7 @@ def _post_process(
 
         num_classes (int): number of output classes
     """
+    
     batch_size = cls_outputs[0].shape[0]
     cls_outputs_all = torch.cat([
         cls_outputs[level].permute(0, 2, 3, 1).reshape([batch_size, -1, num_classes])
@@ -56,7 +57,7 @@ def _post_process(
     return cls_outputs_all_after_topk, box_outputs_all_after_topk, indices_all, classes_all
 
 
-@torch.jit.script
+
 def _batch_detection(
         batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
         img_scale: Optional[torch.Tensor] = None,
@@ -77,22 +78,21 @@ def _batch_detection(
 
 
 class DetBenchPredict(nn.Module):
-    def __init__(self, model):
+    def __init__(self, config):
         super(DetBenchPredict, self).__init__()
-        self.model = model
-        self.config = model.config  # FIXME remove this when we can use @property (torchscript limitation)
-        self.num_levels = model.config.num_levels
-        self.num_classes = model.config.num_classes
-        self.anchors = Anchors.from_config(model.config)
-        self.max_detection_points = model.config.max_detection_points
-        self.max_det_per_image = model.config.max_det_per_image
-        self.soft_nms = model.config.soft_nms
-
-    def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
-        class_out, box_out = self.model(x)
+        self.config=config
+        self.num_levels = config.num_levels
+        self.num_classes = config.num_classes
+        self.anchors = Anchors.from_config(config)
+        self.max_detection_points = config.max_detection_points
+        self.max_det_per_image = config.max_det_per_image
+        self.soft_nms = config.soft_nms
+
+    def forward(self, x, class_out,box_out,img_info: Optional[Dict[str, torch.Tensor]] = None):
         class_out, box_out, indices, classes = _post_process(
             class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes,
             max_detection_points=self.max_detection_points)
+
         if img_info is None:
             img_scale, img_size = None, None
         else:
@@ -103,46 +103,7 @@ class DetBenchPredict(nn.Module):
         )
 
 
-class DetBenchTrain(nn.Module):
-    def __init__(self, model, create_labeler=True):
-        super(DetBenchTrain, self).__init__()
-        self.model = model
-        self.config = model.config  # FIXME remove this when we can use @property (torchscript limitation)
-        self.num_levels = model.config.num_levels
-        self.num_classes = model.config.num_classes
-        self.anchors = Anchors.from_config(model.config)
-        self.max_detection_points = model.config.max_detection_points
-        self.max_det_per_image = model.config.max_det_per_image
-        self.soft_nms = model.config.soft_nms
-        self.anchor_labeler = None
-        if create_labeler:
-            self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
-        self.loss_fn = DetectionLoss(model.config)
-
-    def forward(self, x, target: Dict[str, torch.Tensor]):
-        class_out, box_out = self.model(x)
-        if self.anchor_labeler is None:
-            # target should contain pre-computed anchor labels if labeler not present in bench
-            assert 'label_num_positives' in target
-            cls_targets = [target[f'label_cls_{l}'] for l in range(self.num_levels)]
-            box_targets = [target[f'label_bbox_{l}'] for l in range(self.num_levels)]
-            num_positives = target['label_num_positives']
-        else:
-            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
-                target['bbox'], target['cls'])
-
-        loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
-        output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
-        if not self.training:
-            # if eval mode, output detections for evaluation
-            class_out_pp, box_out_pp, indices, classes = _post_process(
-                class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes,
-                max_detection_points=self.max_detection_points)
-            output['detections'] = _batch_detection(
-                x.shape[0], class_out_pp, box_out_pp, self.anchors.boxes, indices, classes,
-                target['img_scale'], target['img_size'],
-                max_det_per_image=self.max_det_per_image, soft_nms=self.soft_nms)
-        return output
+
 
 
 def unwrap_bench(model):
diff --git a/effdet/config/model_config.py b/effdet/config/model_config.py
index df7f5f5..598d38d 100644
--- a/effdet/config/model_config.py
+++ b/effdet/config/model_config.py
@@ -12,20 +12,15 @@ from copy import deepcopy
 def default_detection_model_configs():
     """Returns a default detection configs."""
     h = OmegaConf.create()
-
     # model name.
     h.name = 'tf_efficientdet_d1'
-
     h.backbone_name = 'tf_efficientnet_b1'
     h.backbone_args = None  # FIXME sort out kwargs vs config for backbone creation
     h.backbone_indices = None
-
     # model specific, input preprocessing parameters
     h.image_size = (640, 640)
-
     # dataset specific head parameters
     h.num_classes = 90
-
     # feature + anchor config
     h.min_level = 3
     h.max_level = 7
@@ -36,7 +31,6 @@ def default_detection_model_configs():
     # aspect ratios can be specified as below too, pairs will be calc as sqrt(val), 1/sqrt(val)
     #h.aspect_ratios = [1.0, 2.0, 0.5]
     h.anchor_scale = 4.0
-
     # FPN and head config
     h.pad_type = 'same'  # original TF models require an equivalent of Tensorflow 'SAME' padding
     h.act_type = 'swish'
@@ -75,365 +69,11 @@ def default_detection_model_configs():
     h.soft_nms = False  # use soft-nms, this is incredibly slow
     h.max_detection_points = 5000  # max detections for post process, input to NMS
     h.max_det_per_image = 100  # max detections per image limit, output of NMS
-
     return h
 
 
 efficientdet_model_param_dict = dict(
     # Models with PyTorch friendly padding and my PyTorch pretrained backbones, training TBD
-    efficientdet_d0=dict(
-        name='efficientdet_d0',
-        backbone_name='efficientnet_b0',
-        image_size=(512, 512),
-        fpn_channels=64,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        redundant_bias=False,
-        backbone_args=dict(drop_path_rate=0.1),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_d0-f3276ba8.pth',
-    ),
-    efficientdet_d1=dict(
-        name='efficientdet_d1',
-        backbone_name='efficientnet_b1',
-        image_size=(640, 640),
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        redundant_bias=False,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_d1-bb7e98fe.pth',
-    ),
-    efficientdet_d2=dict(
-        name='efficientdet_d2',
-        backbone_name='efficientnet_b2',
-        image_size=(768, 768),
-        fpn_channels=112,
-        fpn_cell_repeats=5,
-        box_class_repeats=3,
-        pad_type='',
-        redundant_bias=False,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',  # no pretrained weights yet
-    ),
-    efficientdet_d3=dict(
-        name='efficientdet_d3',
-        backbone_name='efficientnet_b3',
-        image_size=(896, 896),
-        fpn_channels=160,
-        fpn_cell_repeats=6,
-        box_class_repeats=4,
-        pad_type='',
-        redundant_bias=False,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',  # no pretrained weights yet
-    ),
-    efficientdet_d4=dict(
-        name='efficientdet_d4',
-        backbone_name='efficientnet_b4',
-        image_size=(1024, 1024),
-        fpn_channels=224,
-        fpn_cell_repeats=7,
-        box_class_repeats=4,
-        backbone_args=dict(drop_path_rate=0.2),
-    ),
-    efficientdet_d5=dict(
-        name='efficientdet_d5',
-        backbone_name='efficientnet_b5',
-        image_size=(1280, 1280),
-        fpn_channels=288,
-        fpn_cell_repeats=7,
-        box_class_repeats=4,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',
-    ),
-
-    # My own experimental configs with alternate models, training TBD
-    # Note: any 'timm' model in the EfficientDet family can be used as a backbone here.
-    resdet50=dict(
-        name='resdet50',
-        backbone_name='resnet50',
-        image_size=(640, 640),
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='relu',
-        redundant_bias=False,
-        separable_conv=False,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/resdet50_416-08676892.pth',
-    ),
-    cspresdet50=dict(
-        name='cspresdet50',
-        backbone_name='cspresnet50',
-        image_size=(768, 768),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='leaky_relu',
-        head_act_type='silu',
-        downsample_type='bilinear',
-        upsample_type='bilinear',
-        redundant_bias=False,
-        separable_conv=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/cspresdet50b-386da277.pth',
-    ),
-    cspresdext50=dict(
-        name='cspresdext50',
-        backbone_name='cspresnext50',
-        image_size=(640, 640),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='leaky_relu',
-        redundant_bias=False,
-        separable_conv=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',
-    ),
-    cspresdext50pan=dict(
-        name='cspresdext50pan',
-        backbone_name='cspresnext50',
-        image_size=(640, 640),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=88,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='leaky_relu',
-        fpn_name='pan_fa',  # PAN FPN experiment
-        redundant_bias=False,
-        separable_conv=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/cspresdext50pan-92fdd094.pth',
-    ),
-    cspdarkdet53=dict(
-        name='cspdarkdet53',
-        backbone_name='cspdarknet53',
-        image_size=(640, 640),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='leaky_relu',
-        redundant_bias=False,
-        separable_conv=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        backbone_indices=(3, 4, 5),
-        url='',
-    ),
-    cspdarkdet53m=dict(
-        name='cspdarkdet53m',
-        backbone_name='cspdarknet53',
-        image_size=(768, 768),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=96,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        fpn_name='qufpn_fa',
-        act_type='leaky_relu',
-        head_act_type='mish',
-        downsample_type='bilinear',
-        upsample_type='bilinear',
-        redundant_bias=False,
-        separable_conv=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        backbone_indices=(3, 4, 5),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/cspdarkdet53m-79062b2d.pth',
-    ),
-    mixdet_m=dict(
-        name='mixdet_m',
-        backbone_name='mixnet_m',
-        image_size=(512, 512),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=64,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.1),
-        url='',  # no pretrained weights yet
-    ),
-    mixdet_l=dict(
-        name='mixdet_l',
-        backbone_name='mixnet_l',
-        image_size=(640, 640),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',  # no pretrained weights yet
-    ),
-    mobiledetv2_110d=dict(
-        name='mobiledetv2_110d',
-        backbone_name='mobilenetv2_110d',
-        image_size=(384, 384),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=48,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='relu6',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.05),
-        url='',  # no pretrained weights yet
-    ),
-    mobiledetv2_120d=dict(
-        name='mobiledetv2_120d',
-        backbone_name='mobilenetv2_120d',
-        image_size=(512, 512),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=56,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='relu6',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.1),
-        url='',  # no pretrained weights yet
-    ),
-    mobiledetv3_large=dict(
-        name='mobiledetv3_large',
-        backbone_name='mobilenetv3_large_100',
-        image_size=(512, 512),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=64,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='hard_swish',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.1),
-        url='',  # no pretrained weights yet
-    ),
-    efficientdet_q0=dict(
-        name='efficientdet_q0',
-        backbone_name='efficientnet_b0',
-        image_size=(512, 512),
-        fpn_channels=64,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        fpn_name='qufpn_fa',  # quad-fpn + fast attn experiment
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.1),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_q0-bdf1bdb5.pth',
-    ),
-    efficientdet_q1=dict(
-        name='efficientdet_q1',
-        backbone_name='efficientnet_b1',
-        image_size=(640, 640),
-        fpn_channels=88,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        fpn_name='qufpn_fa',  # quad-fpn + fast attn experiment
-        downsample_type='bilinear',
-        upsample_type='bilinear',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_q1b-d0612140.pth',
-    ),
-    efficientdet_q2=dict(
-        name='efficientdet_q2',
-        backbone_name='efficientnet_b2',
-        image_size=(768, 768),
-        fpn_channels=112,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        fpn_name='qufpn_fa',  # quad-fpn + fast attn experiment
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/efficientdet_q2-0f7564e5.pth',
-    ),
-    efficientdet_w0=dict(
-        name='efficientdet_w0',  # 'wide'
-        backbone_name='efficientnet_b0',
-        image_size=(512, 512),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=80,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(
-            drop_path_rate=0.1,
-            feature_location='depthwise'),  # features from after DW/SE in IR block
-        url='',  # no pretrained weights yet
-    ),
-    efficientdet_es=dict(
-        name='efficientdet_es',   #EdgeTPU-Small
-        backbone_name='efficientnet_es',
-        image_size=(512, 512),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=72,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='relu',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        separable_conv=False,
-        backbone_args=dict(drop_path_rate=0.1),
-        url='',
-    ),
-    efficientdet_em=dict(
-        name='efficientdet_em',  # Edge-TPU Medium
-        backbone_name='efficientnet_em',
-        image_size=(640, 640),
-        aspect_ratios=[1.0, 2.0, 0.5],
-        fpn_channels=96,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        pad_type='',
-        act_type='relu',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        separable_conv=False,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',  # no pretrained weights yet
-    ),
-    efficientdet_lite0=dict(
-        name='efficientdet_lite0',
-        backbone_name='efficientnet_lite0',
-        image_size=(512, 512),
-        fpn_channels=64,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        act_type='relu',
-        redundant_bias=False,
-        head_bn_level_first=True,
-        backbone_args=dict(drop_path_rate=0.1),
-        url='',
-    ),
-
     # Models ported from Tensorflow with pretrained backbones ported from Tensorflow
     tf_efficientdet_d0=dict(
         name='tf_efficientdet_d0',
@@ -445,67 +85,6 @@ efficientdet_model_param_dict = dict(
         backbone_args=dict(drop_path_rate=0.2),
         url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d0_34-f153e0cf.pth',
     ),
-    tf_efficientdet_d1=dict(
-        name='tf_efficientdet_d1',
-        backbone_name='tf_efficientnet_b1',
-        image_size=(640, 640),
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d1_40-a30f94af.pth'
-    ),
-    tf_efficientdet_d2=dict(
-        name='tf_efficientdet_d2',
-        backbone_name='tf_efficientnet_b2',
-        image_size=(768, 768),
-        fpn_channels=112,
-        fpn_cell_repeats=5,
-        box_class_repeats=3,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d2_43-8107aa99.pth',
-    ),
-    tf_efficientdet_d3=dict(
-        name='tf_efficientdet_d3',
-        backbone_name='tf_efficientnet_b3',
-        image_size=(896, 896),
-        fpn_channels=160,
-        fpn_cell_repeats=6,
-        box_class_repeats=4,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d3_47-0b525f35.pth',
-    ),
-    tf_efficientdet_d4=dict(
-        name='tf_efficientdet_d4',
-        backbone_name='tf_efficientnet_b4',
-        image_size=(1024, 1024),
-        fpn_channels=224,
-        fpn_cell_repeats=7,
-        box_class_repeats=4,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d4_49-f56376d9.pth',
-    ),
-    tf_efficientdet_d5=dict(
-        name='tf_efficientdet_d5',
-        backbone_name='tf_efficientnet_b5',
-        image_size=(1280, 1280),
-        fpn_channels=288,
-        fpn_cell_repeats=7,
-        box_class_repeats=4,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d5_51-c79f9be6.pth',
-    ),
-    tf_efficientdet_d6=dict(
-        name='tf_efficientdet_d6',
-        backbone_name='tf_efficientnet_b6',
-        image_size=(1280, 1280),
-        fpn_channels=384,
-        fpn_cell_repeats=8,
-        box_class_repeats=5,
-        fpn_name='bifpn_sum',  # Use unweighted sum for training stability.
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d6_52-4eda3773.pth'
-    ),
     tf_efficientdet_d7=dict(
         name='tf_efficientdet_d7',
         backbone_name='tf_efficientnet_b6',
@@ -518,160 +97,12 @@ efficientdet_model_param_dict = dict(
         backbone_args=dict(drop_path_rate=0.2),
         url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d7_53-6d1d7a95.pth'
     ),
-    tf_efficientdet_d7x=dict(
-        name='tf_efficientdet_d7x',
-        backbone_name='tf_efficientnet_b7',
-        image_size=(1536, 1536),
-        fpn_channels=384,
-        fpn_cell_repeats=8,
-        box_class_repeats=5,
-        anchor_scale=4.0,
-        max_level=8,
-        fpn_name='bifpn_sum',  # Use unweighted sum for training stability.
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d7x-f390b87c.pth'
-    ),
 
-    #  Models ported from Tensorflow AdvProp+AA weights
-    #  https://github.com/google/automl/blob/master/efficientdet/Det-AdvProp.md
-    tf_efficientdet_d0_ap=dict(
-        name='tf_efficientdet_d0_ap',
-        backbone_name='tf_efficientnet_b0',
-        image_size=(512, 512),
-        fpn_channels=64,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        mean=(0.5, 0.5, 0.5),
-        std=(0.5, 0.5, 0.5),
-        fill_color=0,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d0_ap-d0cdbd0a.pth',
-    ),
-    tf_efficientdet_d1_ap=dict(
-        name='tf_efficientdet_d1_ap',
-        backbone_name='tf_efficientnet_b1',
-        image_size=(640, 640),
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        mean=(0.5, 0.5, 0.5),
-        std=(0.5, 0.5, 0.5),
-        fill_color=0,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d1_ap-7721d075.pth'
-    ),
-    tf_efficientdet_d2_ap=dict(
-        name='tf_efficientdet_d2_ap',
-        backbone_name='tf_efficientnet_b2',
-        image_size=(768, 768),
-        fpn_channels=112,
-        fpn_cell_repeats=5,
-        box_class_repeats=3,
-        mean=(0.5, 0.5, 0.5),
-        std=(0.5, 0.5, 0.5),
-        fill_color=0,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d2_ap-a2995c19.pth',
-    ),
-    tf_efficientdet_d3_ap=dict(
-        name='tf_efficientdet_d3_ap',
-        backbone_name='tf_efficientnet_b3',
-        image_size=(896, 896),
-        fpn_channels=160,
-        fpn_cell_repeats=6,
-        box_class_repeats=4,
-        mean=(0.5, 0.5, 0.5),
-        std=(0.5, 0.5, 0.5),
-        fill_color=0,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d3_ap-e4a2feab.pth',
-    ),
-    tf_efficientdet_d4_ap=dict(
-        name='tf_efficientdet_d4_ap',
-        backbone_name='tf_efficientnet_b4',
-        image_size=(1024, 1024),
-        fpn_channels=224,
-        fpn_cell_repeats=7,
-        box_class_repeats=4,
-        mean=(0.5, 0.5, 0.5),
-        std=(0.5, 0.5, 0.5),
-        fill_color=0,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d4_ap-f601a5fc.pth',
-    ),
-    tf_efficientdet_d5_ap=dict(
-        name='tf_efficientdet_d5_ap',
-        backbone_name='tf_efficientnet_b5',
-        image_size=(1280, 1280),
-        fpn_channels=288,
-        fpn_cell_repeats=7,
-        box_class_repeats=4,
-        mean=(0.5, 0.5, 0.5),
-        std=(0.5, 0.5, 0.5),
-        fill_color=0,
-        backbone_args=dict(drop_path_rate=0.2),
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_d5_ap-3673ae5d.pth',
-    ),
 
-    # The lite configs are in TF automl repository but no weights yet and listed as 'not final'
-    tf_efficientdet_lite0=dict(
-        name='tf_efficientdet_lite0',
-        backbone_name='tf_efficientnet_lite0',
-        image_size=(512, 512),
-        fpn_channels=64,
-        fpn_cell_repeats=3,
-        box_class_repeats=3,
-        act_type='relu',
-        redundant_bias=False,
-        backbone_args=dict(drop_path_rate=0.1),
-        # unlike other tf_ models, this was not ported from tf automl impl, but trained from tf pretrained efficient lite
-        # weights using this code, will likely replace if/when official det-lite weights are released
-        url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/tf_efficientdet_lite0-f5f303a9.pth',
-    ),
-    tf_efficientdet_lite1=dict(
-        name='tf_efficientdet_lite1',
-        backbone_name='tf_efficientnet_lite1',
-        image_size=(640, 640),
-        fpn_channels=88,
-        fpn_cell_repeats=4,
-        box_class_repeats=3,
-        act_type='relu',
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',  # no pretrained weights yet
-    ),
-    tf_efficientdet_lite2=dict(
-        name='tf_efficientdet_lite2',
-        backbone_name='tf_efficientnet_lite2',
-        image_size=(768, 768),
-        fpn_channels=112,
-        fpn_cell_repeats=5,
-        box_class_repeats=3,
-        act_type='relu',
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',
-    ),
-    tf_efficientdet_lite3=dict(
-        name='tf_efficientdet_lite3',
-        backbone_name='tf_efficientnet_lite3',
-        image_size=(896, 896),
-        fpn_channels=160,
-        fpn_cell_repeats=6,
-        box_class_repeats=4,
-        act_type='relu',
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',
-    ),
-    tf_efficientdet_lite4=dict(
-        name='tf_efficientdet_lite4',
-        backbone_name='tf_efficientnet_lite4',
-        image_size=(1024, 1024),
-        fpn_channels=224,
-        fpn_cell_repeats=7,
-        box_class_repeats=4,
-        act_type='relu',
-        backbone_args=dict(drop_path_rate=0.2),
-        url='',
-    ),
+
+
+
+
 )
 
 
diff --git a/effdet/data/dataset.py b/effdet/data/dataset.py
index d751562..0087393 100644
--- a/effdet/data/dataset.py
+++ b/effdet/data/dataset.py
@@ -44,9 +44,9 @@ class DetectionDatset(data.Dataset):
 
         img_path = self.data_dir / img_info['file_name']
         img = Image.open(img_path).convert('RGB')
+
         if self.transform is not None:
             img, target = self.transform(img, target)
-
         return img, target
 
     def __len__(self):
diff --git a/effdet/data/dataset_factory.py b/effdet/data/dataset_factory.py
index d47e183..1cffa9e 100644
--- a/effdet/data/dataset_factory.py
+++ b/effdet/data/dataset_factory.py
@@ -18,6 +18,7 @@ def create_dataset(name, root, splits=('train', 'val')):
     name = name.lower()
     root = Path(root)
     dataset_cls = DetectionDatset
+
     datasets = OrderedDict()
     if name.startswith('coco'):
         if 'coco2014' in name:
@@ -33,10 +34,13 @@ def create_dataset(name, root, splits=('train', 'val')):
                 ann_filename=ann_file,
                 has_labels=split_cfg['has_labels']
             )
+            print(root / Path(split_cfg['img_dir']))
             datasets[s] = dataset_cls(
                 data_dir=root / Path(split_cfg['img_dir']),
                 parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
             )
+
+
     elif name.startswith('voc'):
         if 'voc0712' in name:
             dataset_cfg = Voc0712Cfg()
diff --git a/effdet/data/loader.py b/effdet/data/loader.py
index adf96cf..77b0dc9 100644
--- a/effdet/data/loader.py
+++ b/effdet/data/loader.py
@@ -111,31 +111,30 @@ class PrefetchLoader:
             re_count=1,
             ):
         self.loader = loader
-        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
-        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
+        self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
+        self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
         if re_prob > 0.:
             self.random_erasing = RandomErasing(probability=re_prob, mode=re_mode, max_count=re_count)
         else:
             self.random_erasing = None
 
     def __iter__(self):
-        stream = torch.cuda.Stream()
+
         first = True
 
         for next_input, next_target in self.loader:
-            with torch.cuda.stream(stream):
-                next_input = next_input.cuda(non_blocking=True)
-                next_input = next_input.float().sub_(self.mean).div_(self.std)
-                next_target = {k: v.cuda(non_blocking=True) for k, v in next_target.items()}
-                if self.random_erasing is not None:
-                    next_input = self.random_erasing(next_input, next_target)
+
+            next_input = next_input.float().sub_(self.mean).div_(self.std)
+            next_target = {k: v for k, v in next_target.items()}
+            if self.random_erasing is not None:
+                next_input = self.random_erasing(next_input, next_target)
 
             if not first:
                 yield input, target
             else:
                 first = False
 
-            torch.cuda.current_stream().wait_stream(stream)
+
             input = next_input
             target = next_target
 
@@ -186,34 +185,19 @@ def create_loader(
         # The fast collate fn accepts ONLY numpy uint8 images and annotations dicts of ndarrays and python scalars
         transform = transform_fn
     else:
-        if is_training:
-            transform = transforms_coco_train(
-                img_size,
-                interpolation=interpolation,
-                use_prefetcher=use_prefetcher,
-                fill_color=fill_color,
-                mean=mean,
-                std=std)
-        else:
-            transform = transforms_coco_eval(
-                img_size,
-                interpolation=interpolation,
-                use_prefetcher=use_prefetcher,
-                fill_color=fill_color,
-                mean=mean,
-                std=std)
+        transform = transforms_coco_eval(
+            img_size,
+            interpolation=interpolation,
+            use_prefetcher=use_prefetcher,
+            fill_color=fill_color,
+            mean=mean,
+            std=std)
     dataset.transform = transform
 
     sampler = None
-    if distributed:
-        if is_training:
-            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-        else:
-            # This will add extra duplicate entries to result in equal num
-            # of samples per-process, will slightly alter validation results
-            sampler = OrderedDistributedSampler(dataset)
 
     collate_fn = collate_fn or DetectionFastCollate(anchor_labeler=anchor_labeler)
+    print(anchor_labeler)
     loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
@@ -224,9 +208,6 @@ def create_loader(
         collate_fn=collate_fn,
     )
     if use_prefetcher:
-        if is_training:
-            loader = PrefetchLoader(loader, mean=mean, std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count)
-        else:
-            loader = PrefetchLoader(loader, mean=mean, std=std)
+        loader = PrefetchLoader(loader, mean=mean, std=std)
 
     return loader
diff --git a/effdet/data/transforms.py b/effdet/data/transforms.py
index 262cbaf..13f4d61 100644
--- a/effdet/data/transforms.py
+++ b/effdet/data/transforms.py
@@ -2,10 +2,6 @@
 
 Hacked together by Ross Wightman
 """
-import random
-import math
-from copy import deepcopy
-
 from PIL import Image
 import numpy as np
 import torch
@@ -109,113 +105,6 @@ class ResizePad:
         return new_img, anno
 
 
-class RandomResizePad:
-
-    def __init__(self, target_size: int, scale: tuple = (0.1, 2.0), interpolation: str = 'random',
-                 fill_color: tuple = (0, 0, 0)):
-        self.target_size = _size_tuple(target_size)
-        self.scale = scale
-        if interpolation == 'random':
-            self.interpolation = _RANDOM_INTERPOLATION
-        else:
-            self.interpolation = _pil_interp(interpolation)
-        self.fill_color = fill_color
-
-    def _get_params(self, img):
-        # Select a random scale factor.
-        scale_factor = random.uniform(*self.scale)
-        scaled_target_height = scale_factor * self.target_size[0]
-        scaled_target_width = scale_factor * self.target_size[1]
-
-        # Recompute the accurate scale_factor using rounded scaled image size.
-        width, height = img.size
-        img_scale_y = scaled_target_height / height
-        img_scale_x = scaled_target_width / width
-        img_scale = min(img_scale_y, img_scale_x)
-
-        # Select non-zero random offset (x, y) if scaled image is larger than target size
-        scaled_h = int(height * img_scale)
-        scaled_w = int(width * img_scale)
-        offset_y = scaled_h - self.target_size[0]
-        offset_x = scaled_w - self.target_size[1]
-        offset_y = int(max(0.0, float(offset_y)) * random.uniform(0, 1))
-        offset_x = int(max(0.0, float(offset_x)) * random.uniform(0, 1))
-        return scaled_h, scaled_w, offset_y, offset_x, img_scale
-
-    def __call__(self, img, anno: dict):
-        scaled_h, scaled_w, offset_y, offset_x, img_scale = self._get_params(img)
-
-        if isinstance(self.interpolation, (tuple, list)):
-            interpolation = random.choice(self.interpolation)
-        else:
-            interpolation = self.interpolation
-        img = img.resize((scaled_w, scaled_h), interpolation)
-        right, lower = min(scaled_w, offset_x + self.target_size[1]), min(scaled_h, offset_y + self.target_size[0])
-        img = img.crop((offset_x, offset_y, right, lower))
-        new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color)
-        new_img.paste(img)  # pastes at 0,0 (upper-left corner)
-
-        if 'bbox' in anno:
-            bbox = anno['bbox']  # for convenience, modifies in-place
-            bbox[:, :4] *= img_scale
-            box_offset = np.stack([offset_y, offset_x] * 2)
-            bbox -= box_offset
-            bbox_bound = (min(scaled_h, self.target_size[0]), min(scaled_w, self.target_size[1]))
-            clip_boxes_(bbox, bbox_bound)  # crop to bounds of target image or letter-box, whichever is smaller
-            valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)
-            anno['bbox'] = bbox[valid_indices, :]
-            anno['cls'] = anno['cls'][valid_indices]
-
-        anno['img_scale'] = 1. / img_scale  # back to original
-
-        return new_img, anno
-
-
-class RandomFlip:
-
-    def __init__(self, horizontal=True, vertical=False, prob=0.5):
-        self.horizontal = horizontal
-        self.vertical = vertical
-        self.prob = prob
-
-    def _get_params(self):
-        do_horizontal = random.random() < self.prob if self.horizontal else False
-        do_vertical = random.random() < self.prob if self.vertical else False
-        return do_horizontal, do_vertical
-
-    def __call__(self, img, annotations: dict):
-        do_horizontal, do_vertical = self._get_params()
-        width, height = img.size
-
-        def _fliph(bbox):
-            x_max = width - bbox[:, 1]
-            x_min = width - bbox[:, 3]
-            bbox[:, 1] = x_min
-            bbox[:, 3] = x_max
-
-        def _flipv(bbox):
-            y_max = height - bbox[:, 0]
-            y_min = height - bbox[:, 2]
-            bbox[:, 0] = y_min
-            bbox[:, 2] = y_max
-
-        if do_horizontal and do_vertical:
-            img = img.transpose(Image.ROTATE_180)
-            if 'bbox' in annotations:
-                _fliph(annotations['bbox'])
-                _flipv(annotations['bbox'])
-        elif do_horizontal:
-            img = img.transpose(Image.FLIP_LEFT_RIGHT)
-            if 'bbox' in annotations:
-                _fliph(annotations['bbox'])
-        elif do_vertical:
-            img = img.transpose(Image.FLIP_TOP_BOTTOM)
-            if 'bbox' in annotations:
-                _flipv(annotations['bbox'])
-
-        return img, annotations
-
-
 def resolve_fill_color(fill_color, img_mean=IMAGENET_DEFAULT_MEAN):
     if isinstance(fill_color, tuple):
         assert len(fill_color) == 3
@@ -262,25 +151,3 @@ def transforms_coco_eval(
     image_tf = Compose(image_tfl)
     return image_tf
 
-
-def transforms_coco_train(
-        img_size=224,
-        interpolation='random',
-        use_prefetcher=False,
-        fill_color='mean',
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD):
-
-    fill_color = resolve_fill_color(fill_color, mean)
-
-    image_tfl = [
-        RandomFlip(horizontal=True, prob=0.5),
-        RandomResizePad(
-            target_size=img_size, interpolation=interpolation, fill_color=fill_color),
-        ImageToNumpy(),
-    ]
-
-    assert use_prefetcher, "Only supporting prefetcher usage right now"
-
-    image_tf = Compose(image_tfl)
-    return image_tf
diff --git a/effdet/factory.py b/effdet/factory.py
index 1bee96d..1e1db07 100644
--- a/effdet/factory.py
+++ b/effdet/factory.py
@@ -1,5 +1,5 @@
 from .efficientdet import EfficientDet, HeadNet
-from .bench import DetBenchTrain, DetBenchPredict
+from .bench import DetBenchPredict
 from .config import get_efficientdet_config
 from .helpers import load_pretrained, load_checkpoint