3636
3737DATASET_ROOT = '/data/wujiapeng/datasets/VisDrone2019/VisDrone2019' # your dataset root
3838
39- # CATEGORY_NAMES = ['car', 'van', 'truck', 'bus']
40- CATEGORY_NAMES = ['pedestrain' , 'people' , 'bicycle' , 'car' , 'van' , 'truck' , 'tricycle' , 'awning-tricycle' , 'bus' , 'motor' ]
39+ CATEGORY_NAMES = ['car' , 'van' , 'truck' , 'bus' ]
40+ # CATEGORY_NAMES = ['pedestrain', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
4141CATEGORY_DICT = {i : CATEGORY_NAMES [i ] for i in range (len (CATEGORY_NAMES ))} # show class
4242
43+ # IGNORE_SEQS = []
44+ IGNORE_SEQS = ['uav0000073_00600_v' , 'uav0000088_00290_v' ] # ignore seqs
45+
4346timer = Timer ()
4447seq_fps = [] # list to store time used for every seq
4548def main (opts ):
@@ -59,7 +62,10 @@ def main(opts):
5962 elif opts .tracker == 'strongsort' :
6063 opts .kalman_format = 'strongsort'
6164
62-
65+ # NOTE: saving videos requires saving images as well, so enable it here
66+ if opts .save_videos :
67+ opts .save_images = True
68+
6369 """
6470 1. load model
6571 """
@@ -82,15 +88,17 @@ def main(opts):
8288 with open (f'./{ opts .dataset } /test.txt' , 'r' ) as f :
8389 lines = f .readlines ()
8490 for line in lines :
85- if line [- 2 ] not in seqs :
86- seqs .append (line [- 2 ])
91+ elems = line .split ('/' ) # split the path on '/' to get the sequence name (elems[-2])
92+ if elems [- 2 ] not in seqs :
93+ seqs .append (elems [- 2 ])
8794
8895 elif opts .data_format == 'origin' :
8996 DATA_ROOT = os .path .join (DATASET_ROOT , 'VisDrone2019-MOT-test-dev/sequences' )
9097 seqs = os .listdir (DATA_ROOT )
9198 else :
9299 raise NotImplementedError
93100 seqs = sorted (seqs )
101+ seqs = [seq for seq in seqs if seq not in IGNORE_SEQS ]
94102 print (f'Seqs will be evalueated, total{ len (seqs )} :' )
95103 print (seqs )
96104
@@ -105,8 +113,9 @@ def main(opts):
105113 for seq in seqs :
106114 print (f'--------------tracking seq { seq } --------------' )
107115
108- path = os .path .join (DATA_ROOT , seq ) if opts .data_format == 'origin' else seq
109- loader = tracker_dataloader .TrackerLoader (path , opts .img_size , opts .data_format )
116+ path = os .path .join (DATA_ROOT , seq ) if opts .data_format == 'origin' else os .path .join ('./' , f'{ opts .dataset } ' , 'test.txt' )
117+
118+ loader = tracker_dataloader .TrackerLoader (path , opts .img_size , opts .data_format , seq )
110119
111120 data_loader = torch .utils .data .DataLoader (loader , batch_size = 1 )
112121
@@ -120,27 +129,33 @@ def main(opts):
120129 pbar .update ()
121130 timer .tic () # start timing this img
122131
123- out = model (img .to (device )) # model forward
132+ if not i % opts .detect_per_frame : # if it's time to detect
133+
134+ out = model (img .to (device )) # model forward
135+ out = out [0 ] # NOTE: for yolo v7
136+
137+ if len (out .shape ) == 3 : # case (bs, num_obj, ...)
138+ # out = out.squeeze()
139+ # NOTE: assert batch size == 1
140+ out = out .squeeze (0 )
141+ img0 = img0 .squeeze (0 )
142+ # remove some low conf detections
143+ out = out [out [:, 4 ] > 0.001 ]
144+
124145
125- out = out [0 ] # NOTE: for yolo v7
126-
127- if len (out .shape ) == 3 : # case (bs, num_obj, ...)
128- # out = out.squeeze()
129- # NOTE: assert batch size == 1
130- out = out .squeeze (0 )
131- img0 = img0 .squeeze (0 )
132- # remove some low conf detections
133- out = out [out [:, 4 ] > 0.001 ]
146+ # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
147+ if opts .det_output_format == 'yolo' :
148+ cls_conf , cls_idx = torch .max (out [:, 5 :], dim = 1 )
149+ # out[:, 4] *= cls_conf # fuse object and cls conf
150+ out [:, 5 ] = cls_idx
151+ out = out [:, :6 ]
134152
135-
136- # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
137- if opts .det_output_format == 'yolo' :
138- cls_conf , cls_idx = torch .max (out [:, 5 :], dim = 1 )
139- # out[:, 4] *= cls_conf # fuse object and cls conf
140- out [:, 5 ] = cls_idx
141- out = out [:, :6 ]
142-
143- current_tracks = tracker .update (out , img0 ) # List[class(STracks)]
153+ current_tracks = tracker .update (out , img0 ) # List[class(STracks)]
154+ else : # otherwise, update the tracker without running the detector
155+ # drop the leading batch dimension when img0 is 4-D
156+ if len (img0 .shape ) == 4 :
157+ img0 = img0 .squeeze (0 )
158+ current_tracks = tracker .update_without_detection (None , img0 )
144159
145160 # save results
146161 cur_tlwh , cur_id , cur_cls = [], [], []
@@ -171,19 +186,17 @@ def main(opts):
171186 # every time assign a different name
172187 save_results (folder_name , seq , results )
173188
189+ ## finally, save videos
190+ if opts .save_images and opts .save_videos :
191+ save_videos (seq_names = seq )
192+
174193 """
175194 3. evaluate results
176195 """
177196 print (f'average fps: { np .mean (seq_fps )} ' )
178197 evaluate (sorted (os .listdir (f'./tracker/results/{ folder_name } ' )),
179198 sorted ([seq + '.txt' for seq in seqs ]), data_type = 'visdrone' , result_folder = folder_name )
180199
181- """
182- 4. save videos
183- """
184- if opts .save_videos :
185- save_videos (seq_names = 'uav0000119_02301_v' )
186-
187200
188201def save_results (folder_name , seq_name , results , data_type = 'default' ):
189202 """
@@ -283,15 +296,14 @@ def get_color(idx):
283296if __name__ == '__main__' :
284297 parser = argparse .ArgumentParser ()
285298
286- parser .add_argument ('--dataset' , type = str , default = 'visdrone' , help = 'visdrone or mot' )
299+ parser .add_argument ('--dataset' , type = str , default = 'visdrone' , help = 'visdrone, or mot' )
287300 parser .add_argument ('--data_format' , type = str , default = 'origin' , help = 'format of reading dataset' )
288301 parser .add_argument ('--det_output_format' , type = str , default = 'yolo' , help = 'data format of output of detector, yolo or other' )
289302
290303 parser .add_argument ('--tracker' , type = str , default = 'bytetrack' , help = 'sort, deepsort, etc' )
291304
292305 parser .add_argument ('--model_path' , type = str , default = None , help = 'model path' )
293306
294- parser .add_argument ('--trace' , action = 'store_true' , help = 'trace model' )
295307 parser .add_argument ('--img_size' , nargs = '+' , type = int , default = [1280 , 1280 ], help = '[train, test] image sizes' )
296308
297309 """For tracker"""
@@ -312,6 +324,9 @@ def get_color(idx):
312324
313325 parser .add_argument ('--save_images' , action = 'store_true' , help = 'save tracking results (image)' )
314326 parser .add_argument ('--save_videos' , action = 'store_true' , help = 'save tracking results (video)' )
327+
328+ # run the detector only on frames whose index is a multiple of this value
329+ parser .add_argument ('--detect_per_frame' , type = int , default = 1 , help = 'choose how many frames per detect' )
315330
316331
317332 opts = parser .parse_args ()
0 commit comments