3535except :
3636 pass
3737
38- SAVE_FOLDER = 'demo_result'
39- CATEGORY_DICT = {0 : 'car' }
38+ SAVE_FOLDER = 'demo_result' # NOTE: set your save path here
39+ CATEGORY_DICT = {0 : 'car' } # NOTE: set the categories in your videos here,
40+ # format: class_id(start from 0): class_name
4041
4142timer = Timer ()
4243seq_fps = [] # list to store time used for every seq
@@ -81,42 +82,16 @@ def main(opts):
8182 resized_images_queue = [] # List[torch.Tensor] store resized images
8283 images_queue = [] # List[torch.Tensor] store origin images
8384
84- """
85- func: resize a frame to target size
86- """
87- def resize_a_frame (frame , target_size ):
88- # resize to input to the YOLO net
89- frame_resized = cv2 .resize (frame , (target_size [0 ], target_size [1 ])) # (H', W', C)
90- # convert BGR to RGB and to (C, H, W)
91- frame_resized = frame_resized [:, :, ::- 1 ].transpose (2 , 0 , 1 )
92-
93- frame_resized = np .ascontiguousarray (frame_resized , dtype = np .float32 )
94- frame_resized /= 255.0
95-
96- frame_resized = torch .from_numpy (frame_resized )
97-
98- return frame_resized
99-
85+ # check path
86+ assert os .path .exists (obj_name ), 'the path does not exist! '
87+ obj , get_next_frame = None , None # init obj
10088 if 'mp4' in opts .obj : # if it is a video
101- assert os .path .exists (obj_name ), 'the path does not exist! '
102-
103- video = cv2 .VideoCapture (obj_name )
104- while True :
105- result , frame = video .read () # frame: np.ndarray, shape (H, W, C)
106- if not result : break # end to the video
107- frame_resized = resize_a_frame (frame , [opts .img_size , opts .img_size ])
108-
109- resized_images_queue .append (frame_resized )
110- images_queue .append (frame )
89+ obj = cv2 .VideoCapture (obj_name )
90+ get_next_frame = lambda _ : obj .read ()
91+
11192 else :
112- assert os .path .exists (obj_name ), 'the path does not exist! '
113- frames = os .listdir (obj_name )
114- for item in frames :
115- frame = cv2 .imread (item )
116- frame_resized = resize_a_frame (frame , [opts .img_size , opts .img_size ])
117-
118- resized_images_queue .append (frame_resized )
119- images_queue .append (frame )
93+ obj = my_queue (os .listdir (obj_name ))
94+ get_next_frame = lambda _ : obj .pop_front ()
12095
12196
12297 """
@@ -125,9 +100,18 @@ def resize_a_frame(frame, target_size):
125100 tracker = TRACKER_DICT [opts .tracker ](opts , frame_rate = 30 , gamma = opts .gamma ) # instantiate tracker TODO: finish init params
126101 results = [] # store current seq results
127102 frame_id = 0
128- pbar = tqdm .tqdm (desc = "demo--" , ncols = 80 )
129- for i , (img , img0 ) in enumerate (zip (resized_images_queue , images_queue )):
130- pbar .update ()
103+
104+ while True :
105+ print (f'----------processing frame { frame_id } ----------' )
106+
107+ # end condition
108+ is_valid , img0 = get_next_frame (None ) # img0: (H, W, C)
109+
110+ if not is_valid :
111+ break # end of reading
112+
113+ img = resize_a_frame (img0 , [opts .img_size , opts .img_size ])
114+
131115 timer .tic () # start timing this img
132116 img = img .unsqueeze (0 ) # (C, H, W) -> (bs == 1, C, H, W)
133117 out = model (img .to (device )) # model forward
@@ -171,17 +155,55 @@ def resize_a_frame(frame, target_size):
171155
172156 frame_id += 1
173157
174- seq_fps .append (i / timer .total_time ) # cal fps for current seq
158+ seq_fps .append (frame_id / timer .total_time ) # cal fps for current seq
175159 timer .clear () # clear for next seq
176- pbar .close ()
177160 # thirdly, save results
178161 # every time assign a different name
179162 if opts .save_txt : save_results (obj_name , '' , results )
180163
181164 ## finally, save videos
182165 save_videos (obj_name )
183166
167+
class my_queue:
    """
    A simple FIFO queue over a list of image file paths.

    Mirrors the ``(is_valid, frame)`` interface of ``cv2.VideoCapture.read()``
    so the demo loop can consume a video and an image sequence uniformly.
    """

    def __init__(self, arr: list) -> None:
        # arr: list of image file paths, read lazily on pop_front()
        self.arr = arr
        # index of the next item to pop (items before it are consumed)
        self.start_idx = 0

    def push_back(self, item) -> None:
        """Append an image path to the back of the queue."""
        self.arr.append(item)

    def pop_front(self):
        """
        Read and return the next image from the front of the queue.

        Returns:
            tuple: ``(is_valid, frame)`` — ``(False, None)`` once the queue
            is exhausted, otherwise ``(True, <image read by cv2.imread>)``.

        NOTE: the previous version checked emptiness *after* popping, which
        raised IndexError on an empty queue and paired ``is_valid=False``
        with the last frame, so the caller's ``if not is_valid: break``
        dropped the final image of every sequence. Checking first fixes both
        and matches ``cv2.VideoCapture.read()`` semantics.
        """
        if self.is_empty():
            return False, None
        frame = cv2.imread(self.arr[self.start_idx])
        self.start_idx += 1
        return True, frame

    def is_empty(self) -> bool:
        """Return True when every queued path has been consumed."""
        return self.start_idx == len(self.arr)
187+
def resize_a_frame(frame, target_size):
    """
    Resize an image and convert it into a normalized CHW float tensor.

    Args:
        frame: np.ndarray of shape (H, W, C), BGR channel order (cv2 convention).
        target_size: sequence of two ints passed to ``cv2.resize`` as
            ``(target_size[0], target_size[1])`` — note cv2 interprets this
            as (width, height); callers here pass a square size so it is moot.

    Returns:
        torch.Tensor of shape (C, H', W'), RGB, float32, values in [0, 1].
    """
    # fit the frame to the YOLO net input resolution
    resized = cv2.resize(frame, (target_size[0], target_size[1]))  # (H', W', C)
    # BGR -> RGB, then HWC -> CHW
    chw = resized[:, :, ::-1].transpose(2, 0, 1)
    # the reversed channel axis leaves negative strides; copy to a
    # contiguous float32 buffer before handing the array to torch
    chw = np.ascontiguousarray(chw, dtype=np.float32) / 255.0
    return torch.from_numpy(chw)
206+
185207
186208def save_results (obj_name , results , data_type = 'default' ):
187209 """
@@ -227,7 +249,7 @@ def plot_img(img, frame_id, results, save_dir):
227249 # draw a rect
228250 cv2 .rectangle (img_ , tlbr [:2 ], tlbr [2 :], get_color (id ), thickness = 3 , )
229251 # note the id and cls
230- text = f'{ CATEGORY_DICT [ cls ] } -{ id } '
252+ text = f'car -{ id } '
231253 cv2 .putText (img_ , text , (tlbr [0 ], tlbr [1 ]), fontFace = cv2 .FONT_HERSHEY_PLAIN , fontScale = 1 ,
232254 color = (255 , 164 , 0 ), thickness = 2 )
233255
@@ -282,7 +304,7 @@ def get_color(idx):
282304 parser .add_argument ('--model_path' , type = str , default = './weights/best.pt' , help = 'model path' )
283305 parser .add_argument ('--trace' , type = bool , default = False , help = 'traced model of YOLO v7' )
284306
285- parser .add_argument ('--img_size' , nargs = '+' , type = int , default = 1280 , help = '[train, test] image sizes' )
307+ parser .add_argument ('--img_size' , type = int , default = 1280 , help = '[train, test] image sizes' )
286308
287309 """For tracker"""
288310 # model path
0 commit comments