|
| 1 | +""" |
| 2 | +main code for track |
| 3 | +""" |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +import cv2 |
| 7 | +from PIL import Image |
| 8 | +import tqdm |
| 9 | + |
| 10 | +import argparse |
| 11 | +import os |
| 12 | +from time import gmtime, strftime |
| 13 | +from timer import Timer |
| 14 | + |
| 15 | +from basetrack import BaseTracker # for framework |
| 16 | +from deepsort import DeepSORT |
| 17 | +from bytetrack import ByteTrack |
| 18 | +from deepmot import DeepMOT |
| 19 | +from botsort import BoTSORT |
| 20 | +from uavmot import UAVMOT |
| 21 | +from strongsort import StrongSORT |
| 22 | + |
try:
    # Import detector code that lives outside the tracker folder (yolo v5/v7).
    import sys
    sys.path.append(os.getcwd())

    from models.common import DetectMultiBackend
    from evaluate import evaluate
    print('Note: running yolo v5 detector')

except ImportError as e:
    # FIX: was a bare `except: pass`, which silently hid *any* failure.
    # Only swallow import errors, and say so: main() will fail later with a
    # NameError on DetectMultiBackend/evaluate if the detector is actually needed.
    print(f'Warning: could not import detector modules ({e})')
| 33 | + |
| 34 | +import tracker_dataloader |
| 35 | + |
DATASET_ROOT = '/data/wujiapeng/datasets/VisDrone2019/VisDrone2019'  # your dataset root

# CATEGORY_NAMES = ['car', 'van', 'truck', 'bus']
CATEGORY_NAMES = ['pedestrain', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
# map class index -> class name, used when annotating result images
CATEGORY_DICT = {i: CATEGORY_NAMES[i] for i in range(len(CATEGORY_NAMES))}

# Sequences to skip during evaluation.
# FIX: this constant was commented out but is still referenced in main(),
# which raised a NameError at runtime — keep it defined (empty by default).
IGNORE_SEQS = []
# IGNORE_SEQS = ['uav0000073_00600_v', 'uav0000088_00290_v']

# NOTE: per-dataset yaml configs for the yolo v5 model loader (DetectMultiBackend)
YAML_DICT = {'visdrone': './data/Visdrone_car.yaml',
             'uavdt': './data/UAVDT.yaml'}

timer = Timer()  # shared timer used for per-frame timing in main()
seq_fps = []     # list to store fps measured for every seq
def main(opts):
    """
    Run detection + tracking over every test sequence of the chosen dataset,
    write one result txt per sequence, then evaluate and optionally save videos.

    opts: argparse.Namespace built by the parser in __main__.
    """
    # dict for trackers, key: str, value: class(BaseTracker)
    TRACKER_DICT = {
        'sort': BaseTracker,
        'deepsort': DeepSORT,
        'bytetrack': ByteTrack,
        'deepmot': DeepMOT,
        'botsort': BoTSORT,
        'uavmot': UAVMOT,
        'strongsort': StrongSORT,
    }

    # NOTE: ATTENTION: make kalman and tracker compatible
    if opts.tracker == 'botsort':
        opts.kalman_format = 'botsort'
    elif opts.tracker == 'strongsort':
        opts.kalman_format = 'strongsort'

    """
    1. load model
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = DetectMultiBackend(opts.model_path, device=device, dnn=False, data=YAML_DICT[opts.dataset], fp16=False)
    model.eval()
    # warm up
    model.warmup(imgsz=(1, 3, 640, 640))

    """
    2. load dataset and track
    """
    # firstly, create the list of sequences to track
    seqs = []
    if opts.data_format == 'yolo':
        with open(f'./{opts.dataset}/test.txt', 'r') as f:
            for line in f.readlines():
                # NOTE(review): line[-2] keeps a single character (the one just
                # before the trailing newline) — presumably the seq identifier;
                # confirm against the actual format of test.txt.
                if line[-2] not in seqs:
                    seqs.append(line[-2])

    elif opts.data_format == 'origin':
        DATA_ROOT = os.path.join(DATASET_ROOT, 'VisDrone2019-MOT-test-dev/sequences')
        seqs = os.listdir(DATA_ROOT)
    else:
        raise NotImplementedError

    seqs = sorted(seqs)
    # FIX: IGNORE_SEQS may be commented out at module level; fall back to an
    # empty list instead of crashing with NameError.
    ignore_seqs = globals().get('IGNORE_SEQS', [])
    seqs = [seq for seq in seqs if seq not in ignore_seqs]
    print(f'Seqs will be evaluated, total {len(seqs)}:')
    print(seqs)

    # secondly, for each seq, instantiate dataloader class and track.
    # every run gets its own result folder named '<tracker>_<timestamp>'
    folder_name = strftime("%Y-%d-%m %H:%M:%S", gmtime())
    folder_name = folder_name[5:-3].replace('-', '_')
    folder_name = folder_name.replace(' ', '_')
    folder_name = folder_name.replace(':', '_')
    folder_name = opts.tracker + '_' + folder_name

    for seq in seqs:
        print(f'--------------tracking seq {seq}--------------')

        path = os.path.join(DATA_ROOT, seq) if opts.data_format == 'origin' else seq
        loader = tracker_dataloader.TrackerLoader(path, opts.img_size, opts.data_format)
        data_loader = torch.utils.data.DataLoader(loader, batch_size=1)

        # a fresh tracker per sequence
        tracker = TRACKER_DICT[opts.tracker](opts, frame_rate=30, gamma=opts.gamma)

        results = []  # store current seq results
        frame_id = 0

        pbar = tqdm.tqdm(desc=f"{seq}", ncols=80)
        for i, (img, img0) in enumerate(data_loader):
            pbar.update()
            timer.tic()  # start timing this img

            out = model(img.to(device))  # model forward
            out = out[0]  # NOTE: for yolo v7

            if len(out.shape) == 3:  # case (bs, num_obj, ...)
                # NOTE: assert batch size == 1
                out = out.squeeze(0)
                img0 = img0.squeeze(0)

            # remove some low conf detections
            out = out[out[:, 4] > 0.001]

            # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
            if opts.det_output_format == 'yolo':
                cls_conf, cls_idx = torch.max(out[:, 5:], dim=1)
                # out[:, 4] *= cls_conf  # fuse object and cls conf
                out[:, 5] = cls_idx
                out = out[:, :6]

            current_tracks = tracker.update(out, img0)  # List[class(STracks)]

            # save results, filtering out low-area bboxes
            cur_tlwh, cur_id, cur_cls = [], [], []
            for trk in current_tracks:
                bbox = trk.tlwh
                if bbox[2] * bbox[3] > opts.min_area:
                    cur_tlwh.append(bbox)
                    cur_id.append(trk.track_id)
                    cur_cls.append(trk.cls)

            results.append((frame_id + 1, cur_id, cur_tlwh, cur_cls))
            timer.toc()  # end timing this image

            if opts.save_images:
                plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls],
                         save_dir=os.path.join(DATASET_ROOT, 'reuslt_images', seq))

            frame_id += 1

        # FIX: original divided the loop index `i` (undefined when the loader is
        # empty, and frames-1 otherwise) by total_time; use the frame count and
        # guard against an empty sequence / zero elapsed time.
        if frame_id > 0 and timer.total_time > 0:
            seq_fps.append(frame_id / timer.total_time)
        timer.clear()  # clear for next seq
        pbar.close()

        # thirdly, save results of this seq
        save_results(folder_name, seq, results)

    """
    3. evaluate results
    """
    print(f'average fps: {np.mean(seq_fps)}')
    evaluate(sorted(os.listdir(f'./tracker/results/{folder_name}')),
             sorted([seq + '.txt' for seq in seqs]), data_type='visdrone', result_folder=folder_name)

    """
    4. save videos
    """
    if opts.save_videos:
        save_videos(seq_names='uav0000119_02301_v')
| 190 | + |
| 191 | + |
def save_results(folder_name, seq_name, results, data_type='default'):
    """
    Write tracking results of one sequence to a txt file.

    folder_name: sub-folder under ./tracker/results to write into
    seq_name: sequence name, used as the txt file name
    results: list of per-frame tuples (frame_id, target_ids, tlwhs, clses)
    data_type: output row format; only 'default' is implemented
        (rows: frame_id,target_id,x,y,w,h,cls)

    Returns folder_name so callers can locate the output.
    """
    assert len(results)
    if data_type != 'default':
        raise NotImplementedError  # TODO

    # FIX: exist_ok avoids the race between the exists() check and makedirs()
    os.makedirs(f'./tracker/results/{folder_name}', exist_ok=True)

    # the with-block closes the file; the old explicit f.close() was redundant
    with open(os.path.join('./tracker/results', folder_name, seq_name + '.txt'), 'w') as f:
        for frame_id, target_ids, tlwhs, clses in results:
            # renamed loop vars: `id` and `cls` shadowed builtins
            for target_id, tlwh, cls_id in zip(target_ids, tlwhs, clses):
                f.write(f'{frame_id},{target_id},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{int(cls_id)}\n')

    return folder_name
| 219 | + |
def plot_img(img, frame_id, results, save_dir):
    """
    Draw the tracked boxes of one frame and write the annotated image to disk.

    img: np.ndarray: (H, W, C)
    frame_id: int, used for the output file name
    results: [tlwhs, ids, clses]
    save_dir: str, directory the jpg is written into (created if missing)
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    canvas = np.ascontiguousarray(np.copy(img))

    for tlwh, track_id, cls_id in zip(*results):
        # tlwh -> corner points for cv2
        x, y, w, h = tlwh
        top_left = (int(x), int(y))
        bottom_right = (int(x + w), int(y + h))

        # rectangle colored per track id, labelled with class name and id
        cv2.rectangle(canvas, top_left, bottom_right, get_color(track_id), thickness=1, )
        label = f'{CATEGORY_DICT[cls_id]}-{track_id}'
        cv2.putText(canvas, label, top_left, fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1,
                    color=(255, 164, 0), thickness=1)

    cv2.imwrite(filename=os.path.join(save_dir, f'{frame_id:05d}.jpg'), img=canvas)
| 247 | + |
| 248 | + |
def save_videos(seq_names):
    """
    Convert the saved result images of each sequence into an mp4 video.

    seq_names: List[str] or str, seqs that will be generated
    """
    if not isinstance(seq_names, list):
        seq_names = [seq_names]

    for seq in seq_names:
        images_path = os.path.join(DATASET_ROOT, 'reuslt_images', seq)
        images_name = sorted(os.listdir(images_path))

        to_video_path = os.path.join(images_path, '../', seq + '.mp4')
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        # frame size comes from the first image; PIL .size is (W, H), which is
        # the order cv2.VideoWriter expects. Close the image handle promptly.
        with Image.open(os.path.join(images_path, images_name[0])) as img0:
            frame_size = img0.size
        vw = cv2.VideoWriter(to_video_path, fourcc, 15, frame_size)

        try:
            for img in images_name:
                if img.endswith('.jpg'):
                    frame = cv2.imread(os.path.join(images_path, img))
                    vw.write(frame)
        finally:
            # FIX: the writer was never released, so the video file could be
            # left unflushed/truncated on disk
            vw.release()

    print('Save videos Done!!')
| 274 | + |
| 275 | + |
| 276 | + |
def get_color(idx):
    """
    Aux func for plot_img: map a track id to a stable BGR color tuple.
    """
    seed = idx * 3
    return tuple((factor * seed) % 255 for factor in (37, 17, 29))
| 286 | + |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # --- dataset & detector ---
    parser.add_argument('--dataset', type=str, default='visdrone', help='visdrone, or mot')
    parser.add_argument('--data_format', type=str, default='origin', help='format of reading dataset')
    parser.add_argument('--det_output_format', type=str, default='yolo', help='data format of output of detector, yolo or other')
    parser.add_argument('--model_path', type=str, default=None, help='model path')
    parser.add_argument('--img_size', nargs='+', type=int, default=[1280, 1280], help='[train, test] image sizes')

    # --- tracker choice ---
    parser.add_argument('--tracker', type=str, default='bytetrack', help='sort, deepsort, etc')

    # --- tracker model weights ---
    parser.add_argument('--reid_model_path', type=str, default='./weights/ckpt.t7', help='path for reid model path')
    parser.add_argument('--dhn_path', type=str, default='./weights/DHN.pth', help='path of DHN path for DeepMOT')

    # --- thresholds ---
    parser.add_argument('--conf_thresh', type=float, default=0.5, help='filter tracks')
    parser.add_argument('--nms_thresh', type=float, default=0.7, help='thresh for NMS')
    parser.add_argument('--iou_thresh', type=float, default=0.5, help='IOU thresh to filter tracks')

    # --- misc tracker options ---
    parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')
    parser.add_argument('--gamma', type=float, default=0.1, help='param to control fusing motion and apperance dist')
    parser.add_argument('--kalman_format', type=str, default='default', help='use what kind of Kalman, default, naive, strongsort or bot-sort like')
    parser.add_argument('--min_area', type=float, default=150, help='use to filter small bboxs')

    # --- outputs ---
    parser.add_argument('--save_images', action='store_true', help='save tracking results (image)')
    parser.add_argument('--save_videos', action='store_true', help='save tracking results (video)')

    main(parser.parse_args())
0 commit comments