Skip to content

Commit 3b803f1

Browse files
committed
fix bug of track_demo.py
1 parent fc70fc4 commit 3b803f1

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tracker/basetrack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def update(self, det_results, ori_img):
362362
ori_img: original image, np.ndarray, shape(H, W, C)
363363
"""
364364
if isinstance(det_results, torch.Tensor):
365-
det_results = det_results.cpu().numpy()
365+
det_results = det_results.detach().cpu().numpy()
366366
if isinstance(ori_img, torch.Tensor):
367367
ori_img = ori_img.numpy()
368368

tracker/track_demo.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def main(opts):
7272
model = TracedModel(model, device, opts.img_size)
7373
else:
7474
model.to(device)
75+
model.eval()
7576

7677
"""
7778
2. load videos or images
@@ -85,14 +86,19 @@ def main(opts):
8586
# check path
8687
assert os.path.exists(obj_name), 'the path does not exist! '
8788
obj, get_next_frame = None, None # init obj
88-
if 'mp4' in opts.obj: # if it is a video
89+
if 'mp4' in opts.obj or 'MP4' in opts.obj: # if it is a video
8990
obj = cv2.VideoCapture(obj_name)
9091
get_next_frame = lambda _ : obj.read()
92+
93+
if os.path.isabs(obj_name): obj_name = obj_name.split('/')[-1][:-4]
94+
else: obj_name = obj_name[:-4]
9195

9296
else:
9397
obj = my_queue(os.listdir(obj_name))
9498
get_next_frame = lambda _ : obj.pop_front()
9599

100+
if os.path.isabs(obj_name): obj_name = obj_name.split('/')[-1]
101+
96102

97103
"""
98104
3. start tracking
@@ -262,6 +268,7 @@ def save_videos(obj_name):
262268
263269
seq_names: List[str] or str, seqs that will be generated
264270
"""
271+
265272
if not isinstance(obj_name, list):
266273
obj_name = [obj_name]
267274

0 commit comments

Comments
 (0)