@@ -90,15 +90,15 @@ def main(opts):
9090 # check path
9191 assert os .path .exists (obj_name ), 'the path does not exist! '
9292 obj , get_next_frame = None , None # init obj
93- if 'mp4' in opts .obj or 'MP4' in opts .obj : # if it is a video
93+ if 'mp4' in opts .obj or 'MP4' in opts .obj or 'mkv' in opts . obj : # if it is a video
9494 obj = cv2 .VideoCapture (obj_name )
9595 get_next_frame = lambda _ : obj .read ()
9696
9797 if os .path .isabs (obj_name ): obj_name = obj_name .split ('/' )[- 1 ][:- 4 ]
9898 else : obj_name = obj_name [:- 4 ]
9999
100100 else :
101- obj = my_queue (os .listdir (obj_name ))
101+ obj = my_queue (os .listdir (obj_name ), obj_name )
102102 get_next_frame = lambda _ : obj .pop_front ()
103103
104104 if os .path .isabs (obj_name ): obj_name = obj_name .split ('/' )[- 1 ]
@@ -124,8 +124,9 @@ def main(opts):
124124
125125 timer .tic () # start timing this img
126126 img = img .unsqueeze (0 ) # (C, H, W) -> (bs == 1, C, H, W)
127- out = model (img .to (device )) # model forward
128- out = out [0 ] # NOTE: for yolo v7
127+ with torch .no_grad ():
128+ out = model (img .to (device )) # model forward
129+ out = out [0 ] # NOTE: for yolo v7
129130
130131 out = post_process_v7 (out , img_size = img .shape [2 :], ori_img_size = img0 .shape )
131132
@@ -181,15 +182,16 @@ class my_queue:
181182 """
182183 implement a queue for image seq reading
183184 """
184- def __init__ (self , arr : list ) -> None :
185+ def __init__ (self , arr : list , root_path : str ) -> None :
185186 self .arr = arr
186187 self .start_idx = 0
188+ self .root_path = root_path
187189
188190 def push_back (self , item ):
189191 self .arr .append (item )
190192
191193 def pop_front (self ):
192- ret = cv2 .imread (self .arr [self .start_idx ])
194+ ret = cv2 .imread (os . path . join ( self .root_path , self . arr [self .start_idx ]) )
193195 self .start_idx += 1
194196 return not self .is_empty (), ret
195197
@@ -299,7 +301,7 @@ def plot_img(img, frame_id, results, save_dir):
299301 # draw a rect
300302 cv2 .rectangle (img_ , tlbr [:2 ], tlbr [2 :], get_color (id ), thickness = 3 , )
301303 # note the id and cls
302- text = f'{ CATEGORY_DICT [ cls ] } - { id } '
304+ text = f'id: { id } '
303305 cv2 .putText (img_ , text , (tlbr [0 ], tlbr [1 ]), fontFace = cv2 .FONT_HERSHEY_PLAIN , fontScale = 1 ,
304306 color = (255 , 164 , 0 ), thickness = 2 )
305307
@@ -364,7 +366,7 @@ def get_color(idx):
364366 parser .add_argument ('--dhn_path' , type = str , default = './weights/DHN.pth' , help = 'path of DHN path for DeepMOT' )
365367
366368 # threshs
367- parser .add_argument ('--conf_thresh' , type = float , default = 0.1 , help = 'filter tracks' )
369+ parser .add_argument ('--conf_thresh' , type = float , default = 0.05 , help = 'filter tracks' )
368370 parser .add_argument ('--nms_thresh' , type = float , default = 0.7 , help = 'thresh for NMS' )
369371 parser .add_argument ('--iou_thresh' , type = float , default = 0.5 , help = 'IOU thresh to filter tracks' )
370372
0 commit comments