
Commit 12d2fb7

Create track_yolov5.py
1 parent 6fd5c88 commit 12d2fb7

other/track_yolov5.py

Lines changed: 325 additions & 0 deletions
@@ -0,0 +1,325 @@
"""
Main code for tracking.
"""
import numpy as np
import torch
import cv2
from PIL import Image
import tqdm

import argparse
import os
from time import gmtime, strftime
from timer import Timer

from basetrack import BaseTracker  # framework base class
from deepsort import DeepSORT
from bytetrack import ByteTrack
from deepmot import DeepMOT
from botsort import BoTSORT
from uavmot import UAVMOT
from strongsort import StrongSORT

try:  # import packages that live outside the tracker folder (detector code)
    import sys
    sys.path.append(os.getcwd())

    from models.common import DetectMultiBackend
    from evaluate import evaluate
    print('Note: running yolo v5 detector')

except ImportError:
    pass

import tracker_dataloader

DATASET_ROOT = '/data/wujiapeng/datasets/VisDrone2019/VisDrone2019'  # your dataset root

# CATEGORY_NAMES = ['car', 'van', 'truck', 'bus']
CATEGORY_NAMES = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
CATEGORY_DICT = {i: CATEGORY_NAMES[i] for i in range(len(CATEGORY_NAMES))}  # class id -> class name, for display
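# e.g. CATEGORY_DICT[0] == 'pedestrian' and CATEGORY_DICT[3] == 'car', so a
# detector class index maps straight to a display name in plot_img below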

IGNORE_SEQS = []  # seqs to skip, e.g. ['uav0000073_00600_v', 'uav0000088_00290_v']

# NOTE: data yaml files for the yolo v5 model loader (func DetectMultiBackend)
YAML_DICT = {'visdrone': './data/Visdrone_car.yaml',
             'uavdt': './data/UAVDT.yaml'}

timer = Timer()
seq_fps = []  # list to store the fps of every seq


def main(opts):
    TRACKER_DICT = {
        'sort': BaseTracker,
        'deepsort': DeepSORT,
        'bytetrack': ByteTrack,
        'deepmot': DeepMOT,
        'botsort': BoTSORT,
        'uavmot': UAVMOT,
        'strongsort': StrongSORT,
    }  # dict of trackers, key: str, value: tracker class (BaseTracker or subclass)

    # NOTE: make the Kalman filter format compatible with the chosen tracker
    if opts.tracker == 'botsort':
        opts.kalman_format = 'botsort'
    elif opts.tracker == 'strongsort':
        opts.kalman_format = 'strongsort'

"""
71+
1. load model
72+
"""
73+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
74+
model = DetectMultiBackend(opts.model_path, device=device, dnn=False, data=YAML_DICT[opts.dataset], fp16=False)
75+
model.eval()
76+
# warm up
77+
model.warmup(imgsz=(1, 3, 640, 640))
78+
"""
79+
2. load dataset and track
80+
"""
81+
# track per seq
82+
# firstly, create seq list
83+
seqs = []
84+
if opts.data_format == 'yolo':
85+
with open(f'./{opts.dataset}/test.txt', 'r') as f:
86+
lines = f.readlines()
87+
for line in lines:
88+
if line[-2] not in seqs:
89+
seqs.append(line[-2])
90+
91+
    elif opts.data_format == 'origin':
        DATA_ROOT = os.path.join(DATASET_ROOT, 'VisDrone2019-MOT-test-dev/sequences')
        seqs = os.listdir(DATA_ROOT)
    else:
        raise NotImplementedError
    seqs = sorted(seqs)
    seqs = [seq for seq in seqs if seq not in IGNORE_SEQS]
    print(f'Seqs will be evaluated, total {len(seqs)}:')
    print(seqs)

    # secondly, for each seq, instantiate a dataloader and track
    # every run gets its own folder to store results
    folder_name = strftime("%Y-%d-%m %H:%M:%S", gmtime())
    folder_name = folder_name[5:-3].replace('-', '_')
    folder_name = folder_name.replace(' ', '_')
    folder_name = folder_name.replace(':', '_')
    folder_name = opts.tracker + '_' + folder_name
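    # e.g. 'deepmot_17_08_02_38' (tracker_day_month_hour_minute); the debug call
    # at the bottom of this file references a folder named this way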

    for seq in seqs:
        print(f'--------------tracking seq {seq}--------------')

        path = os.path.join(DATA_ROOT, seq) if opts.data_format == 'origin' else seq
        loader = tracker_dataloader.TrackerLoader(path, opts.img_size, opts.data_format)

        data_loader = torch.utils.data.DataLoader(loader, batch_size=1)

        tracker = TRACKER_DICT[opts.tracker](opts, frame_rate=30, gamma=opts.gamma)  # instantiate tracker  TODO: finish init params

        results = []  # store results of the current seq
        frame_id = 0

        pbar = tqdm.tqdm(desc=f"{seq}", ncols=80)
        for i, (img, img0) in enumerate(data_loader):
            pbar.update()
            timer.tic()  # start timing this img

            out = model(img.to(device))  # model forward
            out = out[0]  # NOTE: for yolo v7

            if len(out.shape) == 3:  # case (bs, num_obj, ...)
                # out = out.squeeze()
                # NOTE: assumes batch size == 1
                out = out.squeeze(0)
                img0 = img0.squeeze(0)
            # remove some low-conf detections
            out = out[out[:, 4] > 0.001]

            # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
            if opts.det_output_format == 'yolo':
                cls_conf, cls_idx = torch.max(out[:, 5:], dim=1)
                # out[:, 4] *= cls_conf  # fuse object and cls conf
                out[:, 5] = cls_idx
                out = out[:, :6]

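            # after this branch each row of `out` is [xc, yc, w, h, conf, cls],
            # e.g. roughly [640.0, 360.0, 35.0, 80.0, 0.91, 3.0] (illustrative values
            # only); tracker.update below consumes detections in this layout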
            current_tracks = tracker.update(out, img0)  # List[STrack]

            # save results
            cur_tlwh, cur_id, cur_cls = [], [], []
            for trk in current_tracks:
                bbox = trk.tlwh
                id = trk.track_id
                cls = trk.cls

                # filter out small-area bboxes
                if bbox[2] * bbox[3] > opts.min_area:
                    cur_tlwh.append(bbox)
                    cur_id.append(id)
                    cur_cls.append(cls)
                    # results.append((frame_id + 1, id, bbox, cls))

            results.append((frame_id + 1, cur_id, cur_tlwh, cur_cls))
            timer.toc()  # end timing this image

            if opts.save_images:
                plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls], save_dir=os.path.join(DATASET_ROOT, 'result_images', seq))

            frame_id += 1

        seq_fps.append(i / timer.total_time)  # calculate fps for the current seq
        timer.clear()  # clear for the next seq
        pbar.close()
        # thirdly, save results
        # every seq gets its own txt file inside the run folder
        save_results(folder_name, seq, results)

    """
    3. evaluate results
    """
    print(f'average fps: {np.mean(seq_fps)}')
    evaluate(sorted(os.listdir(f'./tracker/results/{folder_name}')),
             sorted([seq + '.txt' for seq in seqs]), data_type='visdrone', result_folder=folder_name)

    """
    4. save videos
    """
    if opts.save_videos:
        save_videos(seq_names='uav0000119_02301_v')


def save_results(folder_name, seq_name, results, data_type='default'):
    """
    Write results to a txt file.

    results: list, row format: (frame id, target ids, boxes, classes)
    folder_name: result folder under ./tracker/results
    data_type: output data format
    """
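    # for reference, one emitted text row looks like (illustrative values only):
    #   frame_id,track_id,x,y,w,h,cls  ->  1,4,102.50,200.00,35.00,80.00,3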
    assert len(results)
    if not data_type == 'default':
        raise NotImplementedError  # TODO

    if not os.path.exists(f'./tracker/results/{folder_name}'):
        os.makedirs(f'./tracker/results/{folder_name}')

    with open(os.path.join('./tracker/results', folder_name, seq_name + '.txt'), 'w') as f:
        for frame_id, target_ids, tlwhs, clses in results:
            if data_type == 'default':
                # f.write(f'{frame_id},{target_id},{tlwh[0]},{tlwh[1]},\
                #     {tlwh[2]},{tlwh[3]},{cls}\n')
                for id, tlwh, cls in zip(target_ids, tlwhs, clses):
                    f.write(f'{frame_id},{id},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{int(cls)}\n')

    return folder_name


def plot_img(img, frame_id, results, save_dir):
    """
    Plot one image of a seq with its bboxes.

    img: np.ndarray: (H, W, C)
    frame_id: int
    results: [tlwhs, ids, clses]
    save_dir: str
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    img_ = np.ascontiguousarray(np.copy(img))

    tlwhs, ids, clses = results[0], results[1], results[2]
    for tlwh, id, cls in zip(tlwhs, ids, clses):
        # convert tlwh to tlbr
        tlbr = tuple([int(tlwh[0]), int(tlwh[1]), int(tlwh[0] + tlwh[2]), int(tlwh[1] + tlwh[3])])
        # draw a rect
        cv2.rectangle(img_, tlbr[:2], tlbr[2:], get_color(id), thickness=1)
        # annotate the id and cls
        text = f'{CATEGORY_DICT[cls]}-{id}'
        cv2.putText(img_, text, (tlbr[0], tlbr[1]), fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1,
                    color=(255, 164, 0), thickness=1)

    cv2.imwrite(filename=os.path.join(save_dir, f'{frame_id:05d}.jpg'), img=img_)


def save_videos(seq_names):
    """
    Convert the saved result images of a seq into a video.

    seq_names: List[str] or str, seqs whose videos will be generated
    """
    if not isinstance(seq_names, list):
        seq_names = [seq_names]

    for seq in seq_names:
        images_path = os.path.join(DATASET_ROOT, 'result_images', seq)
        images_name = sorted(os.listdir(images_path))

        to_video_path = os.path.join(images_path, '../', seq + '.mp4')
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        img0 = Image.open(os.path.join(images_path, images_name[0]))
        vw = cv2.VideoWriter(to_video_path, fourcc, 15, img0.size)

        for img in images_name:
            if img.endswith('.jpg'):
                frame = cv2.imread(os.path.join(images_path, img))
                vw.write(frame)

        vw.release()  # flush and close the video file

    print('Save videos Done!!')


def get_color(idx):
    """
    Aux func for plot_img: get a unique color for each id.
    """
    idx = idx * 3
    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)

    return color


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', type=str, default='visdrone', help='visdrone or uavdt')
    parser.add_argument('--data_format', type=str, default='origin', help='format of reading dataset: origin or yolo')
    parser.add_argument('--det_output_format', type=str, default='yolo', help='data format of the detector output, yolo or other')

    parser.add_argument('--tracker', type=str, default='bytetrack', help='sort, deepsort, etc')

    parser.add_argument('--model_path', type=str, default=None, help='detector model path')

    parser.add_argument('--img_size', nargs='+', type=int, default=[1280, 1280], help='[train, test] image sizes')

    """For tracker"""
    # model paths
    parser.add_argument('--reid_model_path', type=str, default='./weights/ckpt.t7', help='path of the ReID model')
    parser.add_argument('--dhn_path', type=str, default='./weights/DHN.pth', help='path of the DHN model for DeepMOT')

    # thresholds
    parser.add_argument('--conf_thresh', type=float, default=0.5, help='confidence thresh to filter tracks')
    parser.add_argument('--nms_thresh', type=float, default=0.7, help='thresh for NMS')
    parser.add_argument('--iou_thresh', type=float, default=0.5, help='IoU thresh to filter tracks')

    # other options
    parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')
    parser.add_argument('--gamma', type=float, default=0.1, help='param to control fusing of motion and appearance dist')
    parser.add_argument('--kalman_format', type=str, default='default', help='which Kalman filter format to use: default, naive, strongsort or botsort-like')
    parser.add_argument('--min_area', type=float, default=150, help='used to filter small bboxes')

    parser.add_argument('--save_images', action='store_true', help='save tracking results (images)')
    parser.add_argument('--save_videos', action='store_true', help='save tracking results (videos)')

    opts = parser.parse_args()

    # for debug
    # evaluate(sorted(os.listdir('./tracker/results/deepmot_17_08_02_38')),
    #          sorted(os.listdir('./tracker/results/deepmot_17_08_02_38')), data_type='visdrone', result_folder='deepmot_17_08_02_38')
    main(opts)
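
# Example invocation from the repo root (the weights path below is a placeholder,
# adjust it to your own setup):
#   python track_yolov5.py --dataset visdrone --data_format origin \
#       --tracker bytetrack --model_path ./weights/yolov5s.pt --save_images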
