05360171创建于 2022年3月18日历史提交
diff --git a/src/gen_labels_16.py b/src/gen_labels_16.py
index 71750ae..e511ec4 100644
--- a/src/gen_labels_16.py
+++ b/src/gen_labels_16.py
@@ -8,8 +8,8 @@ def mkdirs(d):
         os.makedirs(d)
 
 
-seq_root = '/data/yfzhang/MOT/JDE/MOT16/images/train'
-label_root = '/data/yfzhang/MOT/JDE/MOT16/labels_with_ids/train'
+seq_root = './dataset/MOT17/images/train'
+label_root = './dataset/MOT17/labels_with_ids/train'
 mkdirs(label_root)
 seqs = [s for s in os.listdir(seq_root)]
 
diff --git a/src/lib/datasets/dataset/jde.py b/src/lib/datasets/dataset/jde.py
index a13ff1f..ed77b2e 100644
--- a/src/lib/datasets/dataset/jde.py
+++ b/src/lib/datasets/dataset/jde.py
@@ -11,7 +11,9 @@ import json
 import numpy as np
 import torch
 import copy
+import sys
 
+sys.path.insert(0, './FairMOT/src/lib')
 from torch.utils.data import Dataset
 from torchvision.transforms import transforms as T
 from cython_bbox import bbox_overlaps as bbox_ious
diff --git a/src/lib/models/networks/pose_dla_dcn.py b/src/lib/models/networks/pose_dla_dcn.py
index b083640..ff925a4 100644
--- a/src/lib/models/networks/pose_dla_dcn.py
+++ b/src/lib/models/networks/pose_dla_dcn.py
@@ -484,7 +484,7 @@ class DLASeg(nn.Module):
 
 def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
   model = DLASeg('dla{}'.format(num_layers), heads,
-                 pretrained=True,
+                 pretrained=False,
                  down_ratio=down_ratio,
                  final_kernel=1,
                  last_level=5,
diff --git a/src/lib/opts.py b/src/lib/opts.py
index 00e6271..a6992e2 100644
--- a/src/lib/opts.py
+++ b/src/lib/opts.py
@@ -10,7 +10,7 @@ class opts(object):
   def __init__(self):
     self.parser = argparse.ArgumentParser()
     # basic experiment setting
-    self.parser.add_argument('task', default='mot', help='mot')
+    self.parser.add_argument('--task', default='mot', help='mot')
     self.parser.add_argument('--dataset', default='jde', help='jde')
     self.parser.add_argument('--exp_id', default='default')
     self.parser.add_argument('--test', action='store_true')
@@ -158,6 +158,7 @@ class opts(object):
                              help='category specific bounding box size.')
     self.parser.add_argument('--not_reg_offset', action='store_true',
                              help='not regress local offset.')
+    self.parser.add_argument('--input_root', type=str, default='./result/dumpOutput_device0')
 
   def parse(self, args=''):
     if args == '':
diff --git a/src/lib/tracker/multitracker.py b/src/lib/tracker/multitracker.py
index 64474bc..9fb427c 100644
--- a/src/lib/tracker/multitracker.py
+++ b/src/lib/tracker/multitracker.py
@@ -179,11 +179,7 @@ class JDETracker(object):
             opt.device = torch.device('cuda')
         else:
             opt.device = torch.device('cpu')
-        print('Creating model...')
-        self.model = create_model(opt.arch, opt.heads, opt.head_conv)
-        self.model = load_model(self.model, opt.load_model)
-        self.model = self.model.to(opt.device)
-        self.model.eval()
+        
 
         self.tracked_stracks = []  # type: list[STrack]
         self.lost_stracks = []  # type: list[STrack]
@@ -225,7 +221,7 @@ class JDETracker(object):
                 results[j] = results[j][keep_inds]
         return results
 
-    def update(self, im_blob, img0):
+    def update(self, hm_eval, wh_eval, id_eval, reg_eval, img0):
         self.frame_id += 1
         activated_starcks = []
         refind_stracks = []
@@ -234,8 +230,9 @@ class JDETracker(object):
 
         width = img0.shape[1]
         height = img0.shape[0]
-        inp_height = im_blob.shape[2]
-        inp_width = im_blob.shape[3]
+
+        inp_height = 608
+        inp_width = 1088
         c = np.array([width / 2., height / 2.], dtype=np.float32)
         s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
         meta = {'c': c, 's': s,
@@ -244,13 +241,12 @@ class JDETracker(object):
 
         ''' Step 1: Network forward, get detections & embeddings'''
         with torch.no_grad():
-            output = self.model(im_blob)[-1]
-            hm = output['hm'].sigmoid_()
-            wh = output['wh']
-            id_feature = output['id']
+            hm = hm_eval.sigmoid_()
+            wh = wh_eval
+            id_feature = id_eval
             id_feature = F.normalize(id_feature, dim=1)
 
-            reg = output['reg'] if self.opt.reg_offset else None
+            reg = reg_eval if self.opt.reg_offset else None
             dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)
             id_feature = _tranpose_and_gather_feat(id_feature, inds)
             id_feature = id_feature.squeeze(0)