@@ -8,8 +8,8 @@ def mkdirs(d):
os.makedirs(d)
-seq_root = '/data/yfzhang/MOT/JDE/MOT16/images/train'
-label_root = '/data/yfzhang/MOT/JDE/MOT16/labels_with_ids/train'
+seq_root = './dataset/MOT17/images/train'
+label_root = './dataset/MOT17/labels_with_ids/train'
mkdirs(label_root)
seqs = [s for s in os.listdir(seq_root)]
@@ -11,7 +11,9 @@ import json
import numpy as np
import torch
import copy
+import sys
+sys.path.insert(0, './FairMOT/src/lib')
from torch.utils.data import Dataset
from torchvision.transforms import transforms as T
from cython_bbox import bbox_overlaps as bbox_ious
@@ -484,7 +484,7 @@ class DLASeg(nn.Module):
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
model = DLASeg('dla{}'.format(num_layers), heads,
- pretrained=True,
+ pretrained=False,
down_ratio=down_ratio,
final_kernel=1,
last_level=5,
@@ -10,7 +10,7 @@ class opts(object):
def __init__(self):
self.parser = argparse.ArgumentParser()
# basic experiment setting
- self.parser.add_argument('task', default='mot', help='mot')
+ self.parser.add_argument('--task', default='mot', help='mot')
self.parser.add_argument('--dataset', default='jde', help='jde')
self.parser.add_argument('--exp_id', default='default')
self.parser.add_argument('--test', action='store_true')
@@ -158,6 +158,7 @@ class opts(object):
help='category specific bounding box size.')
self.parser.add_argument('--not_reg_offset', action='store_true',
help='not regress local offset.')
+ self.parser.add_argument('--input_root', type=str, default='./result/dumpOutput_device0')
def parse(self, args=''):
if args == '':
@@ -179,11 +179,7 @@ class JDETracker(object):
opt.device = torch.device('cuda')
else:
opt.device = torch.device('cpu')
- print('Creating model...')
- self.model = create_model(opt.arch, opt.heads, opt.head_conv)
- self.model = load_model(self.model, opt.load_model)
- self.model = self.model.to(opt.device)
- self.model.eval()
+
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
@@ -225,7 +221,7 @@ class JDETracker(object):
results[j] = results[j][keep_inds]
return results
- def update(self, im_blob, img0):
+ def update(self, hm_eval, wh_eval, id_eval, reg_eval, img0):
self.frame_id += 1
activated_starcks = []
refind_stracks = []
@@ -234,8 +230,9 @@ class JDETracker(object):
width = img0.shape[1]
height = img0.shape[0]
- inp_height = im_blob.shape[2]
- inp_width = im_blob.shape[3]
+
+ inp_height = 608
+ inp_width = 1088
c = np.array([width / 2., height / 2.], dtype=np.float32)
s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
meta = {'c': c, 's': s,
@@ -244,13 +241,12 @@ class JDETracker(object):
''' Step 1: Network forward, get detections & embeddings'''
with torch.no_grad():
- output = self.model(im_blob)[-1]
- hm = output['hm'].sigmoid_()
- wh = output['wh']
- id_feature = output['id']
+ hm = hm_eval.sigmoid_()
+ wh = wh_eval
+ id_feature = id_eval
id_feature = F.normalize(id_feature, dim=1)
- reg = output['reg'] if self.opt.reg_offset else None
+ reg = reg_eval if self.opt.reg_offset else None
dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)
id_feature = _tranpose_and_gather_feat(id_feature, inds)
id_feature = id_feature.squeeze(0)