Skip to content

Commit 98f0366

Browse files
committed
fix some bugs of track_demo
1 parent dfdcf94 commit 98f0366

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ wandb/*
1111
track_result.txt
1212
.idea/
1313
tracker/results/*
14+
*.mp4
15+
*.mkv
16+
temp.py
17+
demo_result/
1418

1519
# Byte-compiled / optimized / DLL files
1620
__pycache__/

tracker/botsort.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def update(self, det_results, ori_img):
318318
ori_img: original image, np.ndarray, shape(H, W, C)
319319
"""
320320
if isinstance(det_results, torch.Tensor):
321-
det_results = det_results.cpu().numpy()
321+
det_results = det_results.detach().cpu().numpy()
322322
if isinstance(ori_img, torch.Tensor):
323323
ori_img = ori_img.numpy()
324324

tracker/track_demo.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ def main(opts):
9090
# check path
9191
assert os.path.exists(obj_name), 'the path does not exist! '
9292
obj, get_next_frame = None, None # init obj
93-
if 'mp4' in opts.obj or 'MP4' in opts.obj: # if it is a video
93+
if 'mp4' in opts.obj or 'MP4' in opts.obj or 'mkv' in opts.obj: # if it is a video
9494
obj = cv2.VideoCapture(obj_name)
9595
get_next_frame = lambda _ : obj.read()
9696

9797
if os.path.isabs(obj_name): obj_name = obj_name.split('/')[-1][:-4]
9898
else: obj_name = obj_name[:-4]
9999

100100
else:
101-
obj = my_queue(os.listdir(obj_name))
101+
obj = my_queue(os.listdir(obj_name), obj_name)
102102
get_next_frame = lambda _ : obj.pop_front()
103103

104104
if os.path.isabs(obj_name): obj_name = obj_name.split('/')[-1]
@@ -124,8 +124,9 @@ def main(opts):
124124

125125
timer.tic() # start timing this img
126126
img = img.unsqueeze(0) # (C, H, W) -> (bs == 1, C, H, W)
127-
out = model(img.to(device)) # model forward
128-
out = out[0] # NOTE: for yolo v7
127+
with torch.no_grad():
128+
out = model(img.to(device)) # model forward
129+
out = out[0] # NOTE: for yolo v7
129130

130131
out = post_process_v7(out, img_size=img.shape[2:], ori_img_size=img0.shape)
131132

@@ -181,15 +182,16 @@ class my_queue:
181182
"""
182183
implement a queue for image seq reading
183184
"""
184-
def __init__(self, arr: list) -> None:
185+
def __init__(self, arr: list, root_path: str) -> None:
185186
self.arr = arr
186187
self.start_idx = 0
188+
self.root_path = root_path
187189

188190
def push_back(self, item):
189191
self.arr.append(item)
190192

191193
def pop_front(self):
192-
ret = cv2.imread(self.arr[self.start_idx])
194+
ret = cv2.imread(os.path.join(self.root_path, self.arr[self.start_idx]))
193195
self.start_idx += 1
194196
return not self.is_empty(), ret
195197

@@ -299,7 +301,7 @@ def plot_img(img, frame_id, results, save_dir):
299301
# draw a rect
300302
cv2.rectangle(img_, tlbr[:2], tlbr[2:], get_color(id), thickness=3, )
301303
# note the id and cls
302-
text = f'{CATEGORY_DICT[cls]}-{id}'
304+
text = f'id: {id}'
303305
cv2.putText(img_, text, (tlbr[0], tlbr[1]), fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1,
304306
color=(255, 164, 0), thickness=2)
305307

@@ -364,7 +366,7 @@ def get_color(idx):
364366
parser.add_argument('--dhn_path', type=str, default='./weights/DHN.pth', help='path of DHN path for DeepMOT')
365367

366368
# threshs
367-
parser.add_argument('--conf_thresh', type=float, default=0.1, help='filter tracks')
369+
parser.add_argument('--conf_thresh', type=float, default=0.05, help='filter tracks')
368370
parser.add_argument('--nms_thresh', type=float, default=0.7, help='thresh for NMS')
369371
parser.add_argument('--iou_thresh', type=float, default=0.5, help='IOU thresh to filter tracks')
370372

0 commit comments

Comments
 (0)