36
36
37
37
DATASET_ROOT = '/data/wujiapeng/datasets/VisDrone2019/VisDrone2019' # your dataset root
38
38
39
- # CATEGORY_NAMES = ['car', 'van', 'truck', 'bus']
40
- CATEGORY_NAMES = ['pedestrain' , 'people' , 'bicycle' , 'car' , 'van' , 'truck' , 'tricycle' , 'awning-tricycle' , 'bus' , 'motor' ]
39
+ CATEGORY_NAMES = ['car' , 'van' , 'truck' , 'bus' ]
40
+ # CATEGORY_NAMES = ['pedestrain', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
41
41
CATEGORY_DICT = {i : CATEGORY_NAMES [i ] for i in range (len (CATEGORY_NAMES ))} # show class
42
42
43
+ # IGNORE_SEQS = []
44
+ IGNORE_SEQS = ['uav0000073_00600_v' , 'uav0000088_00290_v' ] # ignore seqs
45
+
43
46
timer = Timer ()
44
47
seq_fps = [] # list to store time used for every seq
45
48
def main (opts ):
@@ -59,7 +62,10 @@ def main(opts):
59
62
elif opts .tracker == 'strongsort' :
60
63
opts .kalman_format = 'strongsort'
61
64
62
-
65
+ # NOTE: videos are built from the saved images, so saving videos forces saving images
66
+ if opts .save_videos :
67
+ opts .save_images = True
68
+
63
69
"""
64
70
1. load model
65
71
"""
@@ -82,15 +88,17 @@ def main(opts):
82
88
with open (f'./{ opts .dataset } /test.txt' , 'r' ) as f :
83
89
lines = f .readlines ()
84
90
for line in lines :
85
- if line [- 2 ] not in seqs :
86
- seqs .append (line [- 2 ])
91
+ elems = line .split ('/' ) # divide the path by '/' to get the sequence name (elems[-2])
92
+ if elems [- 2 ] not in seqs :
93
+ seqs .append (elems [- 2 ])
87
94
88
95
elif opts .data_format == 'origin' :
89
96
DATA_ROOT = os .path .join (DATASET_ROOT , 'VisDrone2019-MOT-test-dev/sequences' )
90
97
seqs = os .listdir (DATA_ROOT )
91
98
else :
92
99
raise NotImplementedError
93
100
seqs = sorted (seqs )
101
+ seqs = [seq for seq in seqs if seq not in IGNORE_SEQS ]
94
102
print (f'Seqs will be evalueated, total{ len (seqs )} :' )
95
103
print (seqs )
96
104
@@ -105,8 +113,9 @@ def main(opts):
105
113
for seq in seqs :
106
114
print (f'--------------tracking seq { seq } --------------' )
107
115
108
- path = os .path .join (DATA_ROOT , seq ) if opts .data_format == 'origin' else seq
109
- loader = tracker_dataloader .TrackerLoader (path , opts .img_size , opts .data_format )
116
+ path = os .path .join (DATA_ROOT , seq ) if opts .data_format == 'origin' else os .path .join ('./' , f'{ opts .dataset } ' , 'test.txt' )
117
+
118
+ loader = tracker_dataloader .TrackerLoader (path , opts .img_size , opts .data_format , seq )
110
119
111
120
data_loader = torch .utils .data .DataLoader (loader , batch_size = 1 )
112
121
@@ -120,27 +129,33 @@ def main(opts):
120
129
pbar .update ()
121
130
timer .tic () # start timing this img
122
131
123
- out = model (img .to (device )) # model forward
132
+ if not i % opts .detect_per_frame : # if it's time to detect
133
+
134
+ out = model (img .to (device )) # model forward
135
+ out = out [0 ] # NOTE: for yolo v7
136
+
137
+ if len (out .shape ) == 3 : # case (bs, num_obj, ...)
138
+ # out = out.squeeze()
139
+ # NOTE: assert batch size == 1
140
+ out = out .squeeze (0 )
141
+ img0 = img0 .squeeze (0 )
142
+ # remove some low conf detections
143
+ out = out [out [:, 4 ] > 0.001 ]
144
+
124
145
125
- out = out [0 ] # NOTE: for yolo v7
126
-
127
- if len (out .shape ) == 3 : # case (bs, num_obj, ...)
128
- # out = out.squeeze()
129
- # NOTE: assert batch size == 1
130
- out = out .squeeze (0 )
131
- img0 = img0 .squeeze (0 )
132
- # remove some low conf detections
133
- out = out [out [:, 4 ] > 0.001 ]
146
+ # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
147
+ if opts .det_output_format == 'yolo' :
148
+ cls_conf , cls_idx = torch .max (out [:, 5 :], dim = 1 )
149
+ # out[:, 4] *= cls_conf # fuse object and cls conf
150
+ out [:, 5 ] = cls_idx
151
+ out = out [:, :6 ]
134
152
135
-
136
- # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
137
- if opts .det_output_format == 'yolo' :
138
- cls_conf , cls_idx = torch .max (out [:, 5 :], dim = 1 )
139
- # out[:, 4] *= cls_conf # fuse object and cls conf
140
- out [:, 5 ] = cls_idx
141
- out = out [:, :6 ]
142
-
143
- current_tracks = tracker .update (out , img0 ) # List[class(STracks)]
153
+ current_tracks = tracker .update (out , img0 ) # List[class(STracks)]
154
+ else : # otherwise, update the tracker without running detection on this frame
155
+ # drop the batch dimension so the img shape (bs, C, H, W) becomes (C, H, W)
156
+ if len (img0 .shape ) == 4 :
157
+ img0 = img0 .squeeze (0 )
158
+ current_tracks = tracker .update_without_detection (None , img0 )
144
159
145
160
# save results
146
161
cur_tlwh , cur_id , cur_cls = [], [], []
@@ -171,19 +186,17 @@ def main(opts):
171
186
# every time assign a different name
172
187
save_results (folder_name , seq , results )
173
188
189
+ ## finally, save videos
190
+ if opts .save_images and opts .save_videos :
191
+ save_videos (seq_names = seq )
192
+
174
193
"""
175
194
3. evaluate results
176
195
"""
177
196
print (f'average fps: { np .mean (seq_fps )} ' )
178
197
evaluate (sorted (os .listdir (f'./tracker/results/{ folder_name } ' )),
179
198
sorted ([seq + '.txt' for seq in seqs ]), data_type = 'visdrone' , result_folder = folder_name )
180
199
181
- """
182
- 4. save videos
183
- """
184
- if opts .save_videos :
185
- save_videos (seq_names = 'uav0000119_02301_v' )
186
-
187
200
188
201
def save_results (folder_name , seq_name , results , data_type = 'default' ):
189
202
"""
@@ -283,15 +296,14 @@ def get_color(idx):
283
296
if __name__ == '__main__' :
284
297
parser = argparse .ArgumentParser ()
285
298
286
- parser .add_argument ('--dataset' , type = str , default = 'visdrone' , help = 'visdrone or mot' )
299
+ parser .add_argument ('--dataset' , type = str , default = 'visdrone' , help = 'visdrone, or mot' )
287
300
parser .add_argument ('--data_format' , type = str , default = 'origin' , help = 'format of reading dataset' )
288
301
parser .add_argument ('--det_output_format' , type = str , default = 'yolo' , help = 'data format of output of detector, yolo or other' )
289
302
290
303
parser .add_argument ('--tracker' , type = str , default = 'bytetrack' , help = 'sort, deepsort, etc' )
291
304
292
305
parser .add_argument ('--model_path' , type = str , default = None , help = 'model path' )
293
306
294
- parser .add_argument ('--trace' , action = 'store_true' , help = 'trace model' )
295
307
parser .add_argument ('--img_size' , nargs = '+' , type = int , default = [1280 , 1280 ], help = '[train, test] image sizes' )
296
308
297
309
"""For tracker"""
@@ -312,6 +324,9 @@ def get_color(idx):
312
324
313
325
parser .add_argument ('--save_images' , action = 'store_true' , help = 'save tracking results (image)' )
314
326
parser .add_argument ('--save_videos' , action = 'store_true' , help = 'save tracking results (video)' )
327
+
328
+ # run detection only once every several frames
329
+ parser .add_argument ('--detect_per_frame' , type = int , default = 1 , help = 'choose how many frames per detect' )
315
330
316
331
317
332
opts = parser .parse_args ()
0 commit comments