@@ -90,15 +90,15 @@ def main(opts):
90
90
# check path
91
91
assert os .path .exists (obj_name ), 'the path does not exist! '
92
92
obj , get_next_frame = None , None # init obj
93
- if 'mp4' in opts .obj or 'MP4' in opts .obj : # if it is a video
93
+ if 'mp4' in opts .obj or 'MP4' in opts .obj or 'mkv' in opts . obj : # if it is a video
94
94
obj = cv2 .VideoCapture (obj_name )
95
95
get_next_frame = lambda _ : obj .read ()
96
96
97
97
if os .path .isabs (obj_name ): obj_name = obj_name .split ('/' )[- 1 ][:- 4 ]
98
98
else : obj_name = obj_name [:- 4 ]
99
99
100
100
else :
101
- obj = my_queue (os .listdir (obj_name ))
101
+ obj = my_queue (os .listdir (obj_name ), obj_name )
102
102
get_next_frame = lambda _ : obj .pop_front ()
103
103
104
104
if os .path .isabs (obj_name ): obj_name = obj_name .split ('/' )[- 1 ]
@@ -124,8 +124,9 @@ def main(opts):
124
124
125
125
timer .tic () # start timing this img
126
126
img = img .unsqueeze (0 ) # (C, H, W) -> (bs == 1, C, H, W)
127
- out = model (img .to (device )) # model forward
128
- out = out [0 ] # NOTE: for yolo v7
127
+ with torch .no_grad ():
128
+ out = model (img .to (device )) # model forward
129
+ out = out [0 ] # NOTE: for yolo v7
129
130
130
131
out = post_process_v7 (out , img_size = img .shape [2 :], ori_img_size = img0 .shape )
131
132
@@ -181,15 +182,16 @@ class my_queue:
181
182
"""
182
183
implement a queue for image seq reading
183
184
"""
184
- def __init__ (self , arr : list ) -> None :
185
+ def __init__ (self , arr : list , root_path : str ) -> None :
185
186
self .arr = arr
186
187
self .start_idx = 0
188
+ self .root_path = root_path
187
189
188
190
def push_back (self , item ):
189
191
self .arr .append (item )
190
192
191
193
def pop_front (self ):
192
- ret = cv2 .imread (self .arr [self .start_idx ])
194
+ ret = cv2 .imread (os . path . join ( self .root_path , self . arr [self .start_idx ]) )
193
195
self .start_idx += 1
194
196
return not self .is_empty (), ret
195
197
@@ -299,7 +301,7 @@ def plot_img(img, frame_id, results, save_dir):
299
301
# draw a rect
300
302
cv2 .rectangle (img_ , tlbr [:2 ], tlbr [2 :], get_color (id ), thickness = 3 , )
301
303
# note the id and cls
302
- text = f'{ CATEGORY_DICT [ cls ] } - { id } '
304
+ text = f'id: { id } '
303
305
cv2 .putText (img_ , text , (tlbr [0 ], tlbr [1 ]), fontFace = cv2 .FONT_HERSHEY_PLAIN , fontScale = 1 ,
304
306
color = (255 , 164 , 0 ), thickness = 2 )
305
307
@@ -364,7 +366,7 @@ def get_color(idx):
364
366
parser .add_argument ('--dhn_path' , type = str , default = './weights/DHN.pth' , help = 'path of DHN path for DeepMOT' )
365
367
366
368
# threshs
367
- parser .add_argument ('--conf_thresh' , type = float , default = 0.1 , help = 'filter tracks' )
369
+ parser .add_argument ('--conf_thresh' , type = float , default = 0.05 , help = 'filter tracks' )
368
370
parser .add_argument ('--nms_thresh' , type = float , default = 0.7 , help = 'thresh for NMS' )
369
371
parser .add_argument ('--iou_thresh' , type = float , default = 0.5 , help = 'IOU thresh to filter tracks' )
370
372
0 commit comments