35
35
except :
36
36
pass
37
37
38
- SAVE_FOLDER = 'demo_result'
39
- CATEGORY_DICT = {0 : 'car' }
38
+ SAVE_FOLDER = 'demo_result' # NOTE: set your save path here
39
+ CATEGORY_DICT = {0 : 'car' } # NOTE: set the categories in your videos here,
40
+ # format: class_id(start from 0): class_name
40
41
41
42
timer = Timer ()
42
43
seq_fps = [] # list to store time used for every seq
@@ -81,42 +82,16 @@ def main(opts):
81
82
resized_images_queue = [] # List[torch.Tensor] store resized images
82
83
images_queue = [] # List[torch.Tensor] store origin images
83
84
84
- """
85
- func: resize a frame to target size
86
- """
87
- def resize_a_frame (frame , target_size ):
88
- # resize to input to the YOLO net
89
- frame_resized = cv2 .resize (frame , (target_size [0 ], target_size [1 ])) # (H', W', C)
90
- # convert BGR to RGB and to (C, H, W)
91
- frame_resized = frame_resized [:, :, ::- 1 ].transpose (2 , 0 , 1 )
92
-
93
- frame_resized = np .ascontiguousarray (frame_resized , dtype = np .float32 )
94
- frame_resized /= 255.0
95
-
96
- frame_resized = torch .from_numpy (frame_resized )
97
-
98
- return frame_resized
99
-
85
+ # check path
86
+ assert os .path .exists (obj_name ), 'the path does not exist! '
87
+ obj , get_next_frame = None , None # init obj
100
88
if 'mp4' in opts .obj : # if it is a video
101
- assert os .path .exists (obj_name ), 'the path does not exist! '
102
-
103
- video = cv2 .VideoCapture (obj_name )
104
- while True :
105
- result , frame = video .read () # frame: np.ndarray, shape (H, W, C)
106
- if not result : break # end to the video
107
- frame_resized = resize_a_frame (frame , [opts .img_size , opts .img_size ])
108
-
109
- resized_images_queue .append (frame_resized )
110
- images_queue .append (frame )
89
+ obj = cv2 .VideoCapture (obj_name )
90
+ get_next_frame = lambda _ : obj .read ()
91
+
111
92
else :
112
- assert os .path .exists (obj_name ), 'the path does not exist! '
113
- frames = os .listdir (obj_name )
114
- for item in frames :
115
- frame = cv2 .imread (item )
116
- frame_resized = resize_a_frame (frame , [opts .img_size , opts .img_size ])
117
-
118
- resized_images_queue .append (frame_resized )
119
- images_queue .append (frame )
93
+ obj = my_queue (os .listdir (obj_name ))
94
+ get_next_frame = lambda _ : obj .pop_front ()
120
95
121
96
122
97
"""
@@ -125,9 +100,18 @@ def resize_a_frame(frame, target_size):
125
100
tracker = TRACKER_DICT [opts .tracker ](opts , frame_rate = 30 , gamma = opts .gamma ) # instantiate tracker TODO: finish init params
126
101
results = [] # store current seq results
127
102
frame_id = 0
128
- pbar = tqdm .tqdm (desc = "demo--" , ncols = 80 )
129
- for i , (img , img0 ) in enumerate (zip (resized_images_queue , images_queue )):
130
- pbar .update ()
103
+
104
+ while True :
105
+ print (f'----------processing frame { frame_id } ----------' )
106
+
107
+ # end condition
108
+ is_valid , img0 = get_next_frame (None ) # img0: (H, W, C)
109
+
110
+ if not is_valid :
111
+ break # end of reading
112
+
113
+ img = resize_a_frame (img0 , [opts .img_size , opts .img_size ])
114
+
131
115
timer .tic () # start timing this img
132
116
img = img .unsqueeze (0 ) # (C, H, W) -> (bs == 1, C, H, W)
133
117
out = model (img .to (device )) # model forward
@@ -171,17 +155,55 @@ def resize_a_frame(frame, target_size):
171
155
172
156
frame_id += 1
173
157
174
- seq_fps .append (i / timer .total_time ) # cal fps for current seq
158
+ seq_fps .append (frame_id / timer .total_time ) # cal fps for current seq
175
159
timer .clear () # clear for next seq
176
- pbar .close ()
177
160
# thirdly, save results
178
161
# every time assign a different name
179
162
if opts .save_txt : save_results (obj_name , '' , results )
180
163
181
164
## finally, save videos
182
165
save_videos (obj_name )
183
166
167
+
168
class my_queue:
    """
    A simple FIFO queue over a list of image file paths, read lazily.

    pop_front() loads the next image from disk and reports whether the
    returned frame is valid, mirroring the (ok, frame) contract of
    cv2.VideoCapture.read() so both sources share one read loop.

    NOTE(review): callers appear to fill this with os.listdir() results,
    which are bare filenames — reads only succeed when the CWD is the
    image folder; confirm against the caller.
    """

    def __init__(self, arr: list) -> None:
        # arr: ordered image paths; start_idx marks the next unread item.
        self.arr = arr
        self.start_idx = 0

    def push_back(self, item):
        """Append one more path to the back of the queue."""
        self.arr.append(item)

    def pop_front(self):
        """
        Read and return the next frame.

        Returns:
            (is_valid, frame): False (with frame=None) once the queue is
            exhausted; otherwise True and the decoded image (np.ndarray).

        Bug fix: the original tested emptiness AFTER advancing, so the
        LAST element came back as (False, frame) and the caller's
        `if not is_valid: break` dropped the final frame; it also raised
        IndexError when popped while already empty. Emptiness is now
        checked before reading.
        """
        if self.is_empty():
            return False, None
        ret = cv2.imread(self.arr[self.start_idx])
        self.start_idx += 1
        return True, ret

    def is_empty(self):
        """True when every queued path has been popped."""
        return self.start_idx == len(self.arr)
187
+
188
def resize_a_frame(frame, target_size):
    """
    Resize one frame and convert it into a network-ready tensor.

    frame: np.ndarray, shape (H, W, C), BGR uint8 as read by OpenCV
    target_size: List[int, int] | Tuple[int, int] — (width, height) for
        the YOLO input

    Returns a torch.float32 tensor of shape (C, H', W') scaled to [0, 1].
    """
    width, height = target_size[0], target_size[1]
    # Resize to the YOLO input resolution -> (H', W', C).
    resized = cv2.resize(frame, (width, height))
    # BGR -> RGB (reverse the channel axis), then HWC -> CHW.
    chw = resized[:, :, ::-1].transpose(2, 0, 1)
    # The reversed view is non-contiguous; copy into a contiguous
    # float32 buffer and normalize to [0, 1].
    chw = np.ascontiguousarray(chw, dtype=np.float32)
    chw /= 255.0
    return torch.from_numpy(chw)
206
+
185
207
186
208
def save_results (obj_name , results , data_type = 'default' ):
187
209
"""
@@ -227,7 +249,7 @@ def plot_img(img, frame_id, results, save_dir):
227
249
# draw a rect
228
250
cv2 .rectangle (img_ , tlbr [:2 ], tlbr [2 :], get_color (id ), thickness = 3 , )
229
251
# note the id and cls
230
- text = f'{ CATEGORY_DICT [ cls ] } -{ id } '
252
+ text = f'car -{ id } '
231
253
cv2 .putText (img_ , text , (tlbr [0 ], tlbr [1 ]), fontFace = cv2 .FONT_HERSHEY_PLAIN , fontScale = 1 ,
232
254
color = (255 , 164 , 0 ), thickness = 2 )
233
255
@@ -282,7 +304,7 @@ def get_color(idx):
282
304
parser .add_argument ('--model_path' , type = str , default = './weights/best.pt' , help = 'model path' )
283
305
parser .add_argument ('--trace' , type = bool , default = False , help = 'traced model of YOLO v7' )
284
306
285
- parser .add_argument ('--img_size' , nargs = '+' , type = int , default = 1280 , help = '[train, test] image sizes' )
307
+ parser .add_argument ('--img_size' , type = int , default = 1280 , help = '[train, test] image sizes' )
286
308
287
309
"""For tracker"""
288
310
# model path
0 commit comments