Skip to content

Commit 0073e08

Browse files
authored
fix bug of yolo format data and generate videos
1 parent cc044e5 commit 0073e08

File tree

3 files changed

+396
-42
lines changed

3 files changed

+396
-42
lines changed

tracker/track.py

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,13 @@
3636

3737
DATASET_ROOT = '/data/wujiapeng/datasets/VisDrone2019/VisDrone2019' # your dataset root
3838

39-
# CATEGORY_NAMES = ['car', 'van', 'truck', 'bus']
40-
CATEGORY_NAMES = ['pedestrain', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
39+
CATEGORY_NAMES = ['car', 'van', 'truck', 'bus']
40+
# CATEGORY_NAMES = ['pedestrain', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
4141
CATEGORY_DICT = {i: CATEGORY_NAMES[i] for i in range(len(CATEGORY_NAMES))} # show class
4242

43+
# IGNORE_SEQS = []
44+
IGNORE_SEQS = ['uav0000073_00600_v', 'uav0000088_00290_v'] # ignore seqs
45+
4346
timer = Timer()
4447
seq_fps = [] # list to store time used for every seq
4548
def main(opts):
@@ -59,7 +62,10 @@ def main(opts):
5962
elif opts.tracker == 'strongsort':
6063
opts.kalman_format = 'strongsort'
6164

62-
65+
# NOTE: saving videos requires the saved images, so image saving is forced on as well
66+
if opts.save_videos:
67+
opts.save_images = True
68+
6369
"""
6470
1. load model
6571
"""
@@ -82,15 +88,17 @@ def main(opts):
8288
with open(f'./{opts.dataset}/test.txt', 'r') as f:
8389
lines = f.readlines()
8490
for line in lines:
85-
if line[-2] not in seqs:
86-
seqs.append(line[-2])
91+
elems = line.split('/') # divide the path by '/' in order to get the sequence name (elems[-2])
92+
if elems[-2] not in seqs:
93+
seqs.append(elems[-2])
8794

8895
elif opts.data_format == 'origin':
8996
DATA_ROOT = os.path.join(DATASET_ROOT, 'VisDrone2019-MOT-test-dev/sequences')
9097
seqs = os.listdir(DATA_ROOT)
9198
else:
9299
raise NotImplementedError
93100
seqs = sorted(seqs)
101+
seqs = [seq for seq in seqs if seq not in IGNORE_SEQS]
94102
print(f'Seqs will be evalueated, total{len(seqs)}:')
95103
print(seqs)
96104

@@ -105,8 +113,9 @@ def main(opts):
105113
for seq in seqs:
106114
print(f'--------------tracking seq {seq}--------------')
107115

108-
path = os.path.join(DATA_ROOT, seq) if opts.data_format == 'origin' else seq
109-
loader = tracker_dataloader.TrackerLoader(path, opts.img_size, opts.data_format)
116+
path = os.path.join(DATA_ROOT, seq) if opts.data_format == 'origin' else os.path.join('./', f'{opts.dataset}', 'test.txt')
117+
118+
loader = tracker_dataloader.TrackerLoader(path, opts.img_size, opts.data_format, seq)
110119

111120
data_loader = torch.utils.data.DataLoader(loader, batch_size=1)
112121

@@ -120,27 +129,33 @@ def main(opts):
120129
pbar.update()
121130
timer.tic() # start timing this img
122131

123-
out = model(img.to(device)) # model forward
132+
if not i % opts.detect_per_frame: # if it's time to detect
133+
134+
out = model(img.to(device)) # model forward
135+
out = out[0] # NOTE: for yolo v7
136+
137+
if len(out.shape) == 3: # case (bs, num_obj, ...)
138+
# out = out.squeeze()
139+
# NOTE: assert batch size == 1
140+
out = out.squeeze(0)
141+
img0 = img0.squeeze(0)
142+
# remove some low conf detections
143+
out = out[out[:, 4] > 0.001]
144+
124145

125-
out = out[0] # NOTE: for yolo v7
126-
127-
if len(out.shape) == 3: # case (bs, num_obj, ...)
128-
# out = out.squeeze()
129-
# NOTE: assert batch size == 1
130-
out = out.squeeze(0)
131-
img0 = img0.squeeze(0)
132-
# remove some low conf detections
133-
out = out[out[:, 4] > 0.001]
146+
# NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
147+
if opts.det_output_format == 'yolo':
148+
cls_conf, cls_idx = torch.max(out[:, 5:], dim=1)
149+
# out[:, 4] *= cls_conf # fuse object and cls conf
150+
out[:, 5] = cls_idx
151+
out = out[:, :6]
134152

135-
136-
# NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
137-
if opts.det_output_format == 'yolo':
138-
cls_conf, cls_idx = torch.max(out[:, 5:], dim=1)
139-
# out[:, 4] *= cls_conf # fuse object and cls conf
140-
out[:, 5] = cls_idx
141-
out = out[:, :6]
142-
143-
current_tracks = tracker.update(out, img0) # List[class(STracks)]
153+
current_tracks = tracker.update(out, img0) # List[class(STracks)]
154+
else: # otherwise
155+
# make the img shape (bs, C, H, W) as (C, H, W)
156+
if len(img0.shape) == 4:
157+
img0 = img0.squeeze(0)
158+
current_tracks = tracker.update_without_detection(None, img0)
144159

145160
# save results
146161
cur_tlwh, cur_id, cur_cls = [], [], []
@@ -171,19 +186,17 @@ def main(opts):
171186
# every time assign a different name
172187
save_results(folder_name, seq, results)
173188

189+
## finally, save videos
190+
if opts.save_images and opts.save_videos:
191+
save_videos(seq_names=seq)
192+
174193
"""
175194
3. evaluate results
176195
"""
177196
print(f'average fps: {np.mean(seq_fps)}')
178197
evaluate(sorted(os.listdir(f'./tracker/results/{folder_name}')),
179198
sorted([seq + '.txt' for seq in seqs]), data_type='visdrone', result_folder=folder_name)
180199

181-
"""
182-
4. save videos
183-
"""
184-
if opts.save_videos:
185-
save_videos(seq_names='uav0000119_02301_v')
186-
187200

188201
def save_results(folder_name, seq_name, results, data_type='default'):
189202
"""
@@ -283,15 +296,14 @@ def get_color(idx):
283296
if __name__ == '__main__':
284297
parser = argparse.ArgumentParser()
285298

286-
parser.add_argument('--dataset', type=str, default='visdrone', help='visdrone or mot')
299+
parser.add_argument('--dataset', type=str, default='visdrone', help='visdrone, or mot')
287300
parser.add_argument('--data_format', type=str, default='origin', help='format of reading dataset')
288301
parser.add_argument('--det_output_format', type=str, default='yolo', help='data format of output of detector, yolo or other')
289302

290303
parser.add_argument('--tracker', type=str, default='bytetrack', help='sort, deepsort, etc')
291304

292305
parser.add_argument('--model_path', type=str, default=None, help='model path')
293306

294-
parser.add_argument('--trace', action='store_true', help='trace model')
295307
parser.add_argument('--img_size', nargs='+', type=int, default=[1280, 1280], help='[train, test] image sizes')
296308

297309
"""For tracker"""
@@ -312,6 +324,9 @@ def get_color(idx):
312324

313325
parser.add_argument('--save_images', action='store_true', help='save tracking results (image)')
314326
parser.add_argument('--save_videos', action='store_true', help='save tracking results (video)')
327+
328+
# detect per several frames
329+
parser.add_argument('--detect_per_frame', type=int, default=1, help='choose how many frames per detect')
315330

316331

317332
opts = parser.parse_args()

0 commit comments

Comments
 (0)