05360171创建于 2022年3月18日历史提交
diff --git a/requirements.txt b/requirements.txt
index c7e7dfc..40e0dea 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 # TODO: Update with exact module version
 numpy
-torch>=1.7
+torch==1.7
 opencv_python
 loguru
 scikit-image
@@ -15,4 +15,4 @@ tensorboard
 # verified versions
 onnx==1.8.1
 onnxruntime==1.8.0
-onnx-simplifier==0.3.5
\ No newline at end of file
+onnx-simplifier==0.3.5
diff --git a/tools/export_onnx.py b/tools/export_onnx.py
index 5502ab3..5534fdb 100644
--- a/tools/export_onnx.py
+++ b/tools/export_onnx.py
@@ -12,6 +12,16 @@ from torch import nn
 from yolox.exp import get_exp
 from yolox.models.network_blocks import SiLU
 from yolox.utils import replace_module
+from onnx.utils import extract_model
+
+
+def truncate_model(model_path, save_path, beg_nodes, end_nodes):
+    extract_model(
+        model_path,
+        save_path,
+        beg_nodes,
+        end_nodes
+    )
 
 
 def make_parser():
@@ -110,7 +120,18 @@ def main():
         assert check, "Simplified ONNX model could not be validated"
         onnx.save(model_simp, args.output_name)
         logger.info("generated simplified onnx model named {}".format(args.output_name))
+    model_path = args.output_name
+    save_path = model_path
+    truncate_model(
+        model_path,
+        save_path,
+        beg_nodes=['images'],
+        end_nodes=['1535', '1536', '1526',
+                   '1561', '1562', '1552',
+                   '1587', '1588', '1578']
+    )
 
 
 if __name__ == "__main__":
     main()
+
diff --git a/yolox/evaluators/coco_evaluator.py b/yolox/evaluators/coco_evaluator.py
index 96eb56a..64827fc 100644
--- a/yolox/evaluators/coco_evaluator.py
+++ b/yolox/evaluators/coco_evaluator.py
@@ -164,33 +164,34 @@ class COCOEvaluator:
                 data_list.append(pred_data)
         return data_list
 
-    def evaluate_prediction(self, data_dict, statistics):
+    def evaluate_prediction(self, data_dict, statistics=None):
         if not is_main_process():
             return 0, 0, None
 
         logger.info("Evaluate in main process...")
 
         annType = ["segm", "bbox", "keypoints"]
+        
+        if statistics is not None:
+            inference_time = statistics[0].item()
+            nms_time = statistics[1].item()
+            n_samples = statistics[2].item()
+
+            a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
+            a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)
+    
+            time_info = ", ".join(
+                [
+                    "Average {} time: {:.2f} ms".format(k, v)
+                    for k, v in zip(
+                        ["forward", "NMS", "inference"],
+                        [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
+                    )
+                ]
+            )
 
-        inference_time = statistics[0].item()
-        nms_time = statistics[1].item()
-        n_samples = statistics[2].item()
-
-        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
-        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)
-
-        time_info = ", ".join(
-            [
-                "Average {} time: {:.2f} ms".format(k, v)
-                for k, v in zip(
-                    ["forward", "NMS", "inference"],
-                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
-                )
-            ]
-        )
-
-        info = time_info + "\n"
-
+            info = time_info + "\n"
+        info = "\n"
         # Evaluate the Dt (detection) json comparing with the ground truth
         if len(data_dict) > 0:
             cocoGt = self.dataloader.dataset.coco
@@ -216,6 +217,7 @@ class COCOEvaluator:
             with contextlib.redirect_stdout(redirect_string):
                 cocoEval.summarize()
             info += redirect_string.getvalue()
-            return cocoEval.stats[0], cocoEval.stats[1], info
+            #return cocoEval.stats[0], cocoEval.stats[1], info
+            return info
         else:
             return 0, 0, info