    from models.experimental import attempt_load
    from evaluate import evaluate
    from utils.torch_utils import select_device, time_synchronized, TracedModel
+   from utils.general import non_max_suppression, scale_coords, check_img_size
    print('Note: running yolo v7 detector')
 
except:
@@ -64,8 +65,11 @@ def main(opts):
     1. load model
     """
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
     ckpt = torch.load(opts.model_path, map_location=device)
     model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()  # for yolo v7
+    stride = int(model.stride.max())  # model stride
+    opts.img_size = check_img_size(opts.img_size, s=stride)  # check img_size
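+    # note: check_img_size rounds img_size up to the nearest multiple of the stride,
+    # e.g. with stride 32 the default 1280 is kept as-is, while 1000 would become 1024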
 
     if opts.trace:
         print(opts.img_size)
@@ -116,13 +120,15 @@ def main(opts):
         if not is_valid:
             break  # end of reading
 
-        img = resize_a_frame(img0, [opts.img_size, opts.img_size])
+        img, img0 = preprocess_v7(ori_img=img0, model_size=(opts.img_size, opts.img_size), model_stride=stride)
 
         timer.tic()  # start timing this img
         img = img.unsqueeze(0)  # (C, H, W) -> (bs == 1, C, H, W)
         out = model(img.to(device))  # model forward
         out = out[0]  # NOTE: for yolo v7
-
+
+        out = post_process_v7(out, img_size=img.shape[2:], ori_img_size=img0.shape)
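+        # boxes are now [x1, y1, x2, y2, conf, cls] in original-frame pixels (see post_process_v7)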
+
         if len(out.shape) == 3:  # case (bs, num_obj, ...)
             # out = out.squeeze()
             # NOTE: assert batch size == 1
@@ -157,7 +163,7 @@ def main(opts):
         results.append((frame_id + 1, cur_id, cur_tlwh, cur_cls))
         timer.toc()  # end timing this image
 
-        plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls], save_dir=os.path.join(SAVE_FOLDER, 'reuslt_images', obj_name))
+        plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls], save_dir=os.path.join(SAVE_FOLDER, 'result_images', obj_name))
 
         frame_id += 1
 
@@ -191,25 +197,63 @@ def is_empty(self):
         return self.start_idx == len(self.arr)
 
 
-def resize_a_frame(frame, target_size):
-    """
-    resize a frame to target size
-
-    frame: np.ndarray, shape (H, W, C)
-    target_size: List[int, int] | Tuple[int, int]
+def post_process_v7(out, img_size, ori_img_size):
+    """ post-process raw YOLO v5 / v7 outputs: NMS, then rescale boxes to the original frame
+
     """
-    # resize to input to the YOLO net
-    frame_resized = cv2.resize(frame, (target_size[0], target_size[1]))  # (H', W', C)
-    # convert BGR to RGB and to (C, H, W)
-    frame_resized = frame_resized[:, :, ::-1].transpose(2, 0, 1)
 
-    frame_resized = np.ascontiguousarray(frame_resized, dtype=np.float32)
-    frame_resized /= 255.0
+    out = non_max_suppression(out, conf_thres=0.01)[0]
+    out[:, :4] = scale_coords(img_size, out[:, :4], ori_img_size, ratio_pad=None).round()
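+    # non_max_suppression returns one (N, 6) tensor per image (batch size is 1 here);
+    # scale_coords then undoes the letterbox resize/padding so boxes land in original-image pixels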
 
-    frame_resized = torch.from_numpy(frame_resized)
+    # out: tlbr, conf, cls
 
-    return frame_resized
+    return out
 
+def preprocess_v7(ori_img, model_size, model_stride):
+    """ simple preprocess for a single image
+
+    """
+    img_resized = _letterbox(ori_img, new_shape=model_size, stride=model_stride)[0]
+
+    img_resized = img_resized[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
+    img_resized = np.ascontiguousarray(img_resized)
+
+    img_resized = torch.from_numpy(img_resized).float()
+    img_resized /= 255.0
+
+    return img_resized, ori_img
+
+def _letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+    # Resize and pad image while meeting stride-multiple constraints
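+    # e.g. a 1080x1920 frame with new_shape=(1280, 1280) and stride 32:
+    # r = min(1280/1080, 1280/1920) ~ 0.667 -> new_unpad = (1280, 720), dh = 560 -> 560 % 32 = 16,
+    # i.e. 8 px of gray padding on top and bottom, giving a 736x1280 network input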
+    shape = img.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        new_shape = (new_shape, new_shape)
+
+    # Scale ratio (new / old)
+    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+    if not scaleup:  # only scale down, do not scale up (for better test mAP)
+        r = min(r, 1.0)
+
+    # Compute padding
+    ratio = r, r  # width, height ratios
+    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+    if auto:  # minimum rectangle
+        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
+    elif scaleFill:  # stretch
+        dw, dh = 0.0, 0.0
+        new_unpad = (new_shape[1], new_shape[0])
+        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
+
+    dw /= 2  # divide padding into 2 sides
+    dh /= 2
+
+    if shape[::-1] != new_unpad:  # resize
+        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+    return img, ratio, (dw, dh)
 
 def save_results(obj_name, results, data_type='default'):
     """
@@ -273,7 +317,8 @@ def save_videos(obj_name):
         obj_name = [obj_name]
 
     for seq in obj_name:
-        images_path = os.path.join(SAVE_FOLDER, 'reuslt_images', seq)
+        if seq.endswith('.mp4'):  # strip the extension so folder and video names match
+            seq = seq[:-4]
+        images_path = os.path.join(SAVE_FOLDER, 'result_images', seq)
         images_name = sorted(os.listdir(images_path))
 
         to_video_path = os.path.join(images_path, '../', seq + '.mp4')
@@ -303,12 +348,12 @@ def get_color(idx):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--obj', type=str, default='M1305.mp4', help='video NAME or images FOLDER NAME')
+    parser.add_argument('--obj', type=str, default='demo.mp4', help='video NAME or images FOLDER NAME')
 
     parser.add_argument('--save_txt', type=bool, default=False, help='whether to save results as txt')
 
     parser.add_argument('--tracker', type=str, default='sort', help='sort, deepsort, etc')
-    parser.add_argument('--model_path', type=str, default='./weights/best.pt', help='model path')
+    parser.add_argument('--model_path', type=str, default='./weights/yolov7_UAVDT_35epochs_20230507.pt', help='model path')
     parser.add_argument('--trace', type=bool, default=False, help='traced model of YOLO v7')
 
     parser.add_argument('--img_size', type=int, default=1280, help='[train, test] image sizes')
@@ -319,7 +364,7 @@ def get_color(idx):
     parser.add_argument('--dhn_path', type=str, default='./weights/DHN.pth', help='path of DHN weights for DeepMOT')
 
     # threshs
-    parser.add_argument('--conf_thresh', type=float, default=0.5, help='filter tracks')
+    parser.add_argument('--conf_thresh', type=float, default=0.1, help='filter tracks')
     parser.add_argument('--nms_thresh', type=float, default=0.7, help='thresh for NMS')
     parser.add_argument('--iou_thresh', type=float, default=0.5, help='IOU thresh to filter tracks')
 
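For reference, a minimal sketch of how the new pre/post-processing pair fits together (the frame path and standalone setup are illustrative, not part of the commit; `model`, `device`, and `stride` are assumed to be loaded as in `main` above):

```python
import cv2

# read one BGR frame (placeholder path)
img0 = cv2.imread('frame.jpg')

# letterbox to the network size and convert to a normalized CHW float tensor
img, img0 = preprocess_v7(ori_img=img0, model_size=(1280, 1280), model_stride=stride)

# forward pass; YOLO v7 returns a tuple, so keep the first element
out = model(img.unsqueeze(0).to(device))[0]

# NMS + rescale boxes back to the original frame
out = post_process_v7(out, img_size=img.shape[1:], ori_img_size=img0.shape)
# out: (N, 6) tensor of [x1, y1, x2, y2, conf, cls] in original-frame pixels
```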