Commit 3658198

add more reid models and fix bugs
1 parent c3d0c2f commit 3658198

30 files changed, +1247 -558 lines

tracker/track.py

Lines changed: 4 additions & 2 deletions
@@ -85,12 +85,14 @@ def get_args():
     parser.add_argument('--dataset', type=str, default='visdrone_part', help='visdrone, mot17, etc.')
     parser.add_argument('--detector', type=str, default='yolo_ultralytics_v8', help='yolov7, yolox, etc.')
     parser.add_argument('--tracker', type=str, default='sort', help='sort, deepsort, etc')
+    parser.add_argument('--reid', action='store_true', help='enable reid model, work in bot, byte, ocsort and hybridsort')
     parser.add_argument('--reid_model', type=str, default='osnet_x0_25', help='osnet or deppsort')

     parser.add_argument('--kalman_format', type=str, default='default', help='use what kind of Kalman, sort, deepsort, byte, etc.')
     parser.add_argument('--img_size', type=int, default=1280, help='image size, [h, w]')

-    parser.add_argument('--conf_thresh', type=float, default=0.2, help='filter tracks')
+    parser.add_argument('--conf_thresh', type=float, default=0.2, help='filter detections')
+    parser.add_argument('--conf_thresh_low', type=float, default=0.1, help='filter low conf detections, used in two-stage association')
     parser.add_argument('--nms_thresh', type=float, default=0.7, help='thresh for NMS')
     parser.add_argument('--iou_thresh', type=float, default=0.5, help='IOU thresh to filter tracks')

@@ -108,7 +110,7 @@ def get_args():


     """other options"""
-    parser.add_argument('--discard_reid', action='store_true', help='discard reid model, only work in bot-sort etc. which need a reid part')
+    parser.add_argument('--fuse_detection_score', action='store_true', help='fuse detection conf with iou score')
     parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')
     parser.add_argument('--gamma', type=float, default=0.1, help='param to control fusing motion and apperance dist')
     parser.add_argument('--min_area', type=float, default=150, help='use to filter small bboxs')
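
The renamed --conf_thresh help and the new --conf_thresh_low flag reflect the two-stage (ByteTrack-style) association performed in tracker/trackers/byte_tracker.py below: detections above conf_thresh enter the first matching round, detections between the two thresholds enter the second. A minimal sketch of that split (the function and variable names here are illustrative, not the repo's):

    import numpy as np

    def split_detections(scores, conf_thresh=0.2, conf_thresh_low=0.1):
        """Mirror of the index logic in byte_tracker.py: high-confidence
        detections go to the first association round, the in-between band
        to the second round."""
        remain_inds = scores > conf_thresh                        # first round
        inds_second = np.logical_and(scores > conf_thresh_low,
                                     scores < conf_thresh)        # second round
        return remain_inds, inds_second

    # e.g. scores = np.array([0.9, 0.15, 0.05]) keeps index 0 for round one
    # and index 1 for round two; index 2 is discarded.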

tracker/track_demo.py

Lines changed: 4 additions & 1 deletion
@@ -85,6 +85,7 @@ def get_args():

     parser.add_argument('--detector', type=str, default='yolo_ultralytics_v8', help='yolov7, yolox, etc.')
     parser.add_argument('--tracker', type=str, default='sort', help='sort, deepsort, etc')
+    parser.add_argument('--reid', action='store_true', help='enable reid model, work in bot, byte, ocsort and hybridsort')
     parser.add_argument('--reid_model', type=str, default='osnet_x0_25', help='osnet or deppsort')

     parser.add_argument('--kalman_format', type=str, default='default', help='use what kind of Kalman, sort, deepsort, byte, etc.')
@@ -109,7 +110,7 @@ def get_args():


     """other options"""
-    parser.add_argument('--discard_reid', action='store_true', help='discard reid model, only work in bot-sort etc. which need a reid part')
+    parser.add_argument('--fuse_detection_score', action='store_true', help='fuse detection conf with iou score')
     parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')
     parser.add_argument('--gamma', type=float, default=0.1, help='param to control fusing motion and apperance dist')
     parser.add_argument('--min_area', type=float, default=150, help='use to filter small bboxs')
@@ -120,6 +121,8 @@ def get_args():

     parser.add_argument('--track_eval', type=bool, default=True, help='Use TrackEval to evaluate')

+    """camera parameter"""
+    parser.add_argument('--camera_parameter_folder', type=str, default='./tracker/cam_param_files', help='folder path of camera parameter files')
     return parser.parse_args()

 def main(args):

tracker/trackers/basetrack.py

Lines changed: 10 additions & 6 deletions
@@ -35,6 +35,10 @@ def end_frame(self):
     def next_id():
         BaseTrack._count += 1
         return BaseTrack._count
+
+    @staticmethod
+    def clear_count():
+        BaseTrack._count = 0

     def activate(self, *args):
         raise NotImplementedError
@@ -106,12 +110,6 @@ def tlwh_to_xysa(tlwh):
         ret[3] = tlwh[2] / tlwh[3]
         return ret

-    def to_xyah(self):
-        return self.tlwh_to_xyah(self.tlwh)
-
-    def to_xywh(self):
-        return self.tlwh_to_xywh(self.tlwh)
-
     @staticmethod
     def tlbr_to_tlwh(tlbr):
         ret = np.asarray(tlbr).copy()
@@ -124,6 +122,12 @@ def tlwh_to_tlbr(tlwh):
         ret = np.asarray(tlwh).copy()
         ret[2:] += ret[:2]
         return ret
+
+    def to_xyah(self):
+        return self.tlwh_to_xyah(self.tlwh)
+
+    def to_xywh(self):
+        return self.tlwh_to_xywh(self.tlwh)

     def __repr__(self):
         return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
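
clear_count() resets the class-wide ID counter; the trackers below call it from __init__ (see the "once init, clear all trackid count to avoid large id" comments), so track IDs restart from 1 whenever a new tracker instance is built, e.g. per sequence. A minimal sketch of the behaviour this enables:

    class BaseTrack:
        _count = 0  # shared by all tracks

        @staticmethod
        def next_id():
            BaseTrack._count += 1
            return BaseTrack._count

        @staticmethod
        def clear_count():
            BaseTrack._count = 0

    BaseTrack.next_id(); BaseTrack.next_id()   # -> 1, 2 (first sequence)
    BaseTrack.clear_count()                    # new tracker instance
    BaseTrack.next_id()                        # -> 1 again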

tracker/trackers/botsort_tracker.py

Lines changed: 25 additions & 71 deletions
@@ -13,36 +13,10 @@
 from .tracklet import Tracklet, Tracklet_w_reid
 from .matching import *

-from .reid_models.OSNet import *
-from .reid_models.load_model_tools import load_pretrained_weights
-from .reid_models.deepsort_reid import Extractor
-
-from .camera_motion_compensation import GMC
-
-REID_MODEL_DICT = {
-    'osnet_x1_0': osnet_x1_0,
-    'osnet_x0_75': osnet_x0_75,
-    'osnet_x0_5': osnet_x0_5,
-    'osnet_x0_25': osnet_x0_25,
-    'deepsort': Extractor
-}
-
-
-def load_reid_model(reid_model, reid_model_path):
-
-    if 'osnet' in reid_model:
-        func = REID_MODEL_DICT[reid_model]
-        model = func(num_classes=1, pretrained=False, )
-        load_pretrained_weights(model, reid_model_path)
-        model.cuda().eval()
-
-    elif 'deepsort' in reid_model:
-        model = REID_MODEL_DICT[reid_model](reid_model_path, use_cuda=True)
+# for reid
+from .reid_models.engine import load_reid_model, crop_and_resize, select_device

-    else:
-        raise NotImplementedError
-
-    return model
+from .camera_motion_compensation.cmc import GMC

 class BotTracker(object):
     def __init__(self, args, frame_rate=30):
@@ -59,60 +33,34 @@ def __init__(self, args, frame_rate=30):

         self.motion = args.kalman_format

-        self.with_reid = not args.discard_reid
+        self.with_reid = args.reid

-        self.reid_model, self.crop_transforms = None, None
+        self.reid_model = None
         if self.with_reid:
-            self.reid_model = load_reid_model(args.reid_model, args.reid_model_path)
-            self.crop_transforms = T.Compose([
-                # T.ToPILImage(),
-                # T.Resize(size=(256, 128)),
-                T.ToTensor(), # (c, 128, 256)
-                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-            ])
-
+            self.reid_model = load_reid_model(args.reid_model, args.reid_model_path, device=args.device)
+            self.reid_model.eval()

         # camera motion compensation module
         self.gmc = GMC(method='orb', downscale=2, verbose=None)

-    def reid_preprocess(self, obj_bbox):
-        """
-        preprocess cropped object bboxes
-
-        obj_bbox: np.ndarray, shape=(h_obj, w_obj, c)
-
-        return:
-        torch.Tensor of shape (c, 128, 256)
-        """
-        obj_bbox = cv2.resize(obj_bbox.astype(np.float32) / 255.0, dsize=(128, 128)) # shape: (128, 256, c)
-
-        return self.crop_transforms(obj_bbox)
+        # once init, clear all trackid count to avoid large id
+        BaseTrack.clear_count()

+    @torch.no_grad()
     def get_feature(self, tlwhs, ori_img):
         """
         get apperance feature of an object
         tlwhs: shape (num_of_objects, 4)
         ori_img: original image, np.ndarray, shape(H, W, C)
         """
-        obj_bbox = []
-
-        for tlwh in tlwhs:
-            tlwh = list(map(int, tlwh))
-            # if any(tlbr_ == -1 for tlbr_ in tlwh):
-            #     print(tlwh)
-
-            tlbr_tensor = self.reid_preprocess(ori_img[tlwh[1]: tlwh[1] + tlwh[3], tlwh[0]: tlwh[0] + tlwh[2]])
-            obj_bbox.append(tlbr_tensor)
-
-        if not obj_bbox:
-            return np.array([])
-
-        obj_bbox = torch.stack(obj_bbox, dim=0)
-        obj_bbox = obj_bbox.cuda()
-
-        features = self.reid_model(obj_bbox) # shape: (num_of_objects, feature_dim)
-        return features.cpu().detach().numpy()

+        if tlwhs.size == 0:
+            return np.empty((0, 512))
+
+        crop_bboxes = crop_and_resize(tlwhs, ori_img, input_format='tlwh', sz=(64, 128))
+        features = self.reid_model(crop_bboxes).cpu().numpy()
+
+        return features

     def update(self, output_results, img, ori_img):
         """
@@ -181,10 +129,13 @@ def update(self, output_results, img, ori_img):
         ious_dists = iou_distance(tracklet_pool, detections)
         ious_dists_mask = (ious_dists > 0.5) # high conf iou

+        # fuse detection conf into iou dist
+        if self.args.fuse_detection_score:
+            ious_dists = fuse_det_score(ious_dists, detections)
+
         if self.with_reid:
             # mixed cost matrix
             emb_dists = embedding_distance(tracklet_pool, detections) / 2.0
-            raw_emb_dists = emb_dists.copy()
             emb_dists[emb_dists > 0.25] = 1.0
             emb_dists[ious_dists_mask] = 1.0
             dists = np.minimum(ious_dists, emb_dists)
@@ -238,9 +189,12 @@ def update(self, output_results, img, ori_img):
         ious_dists = iou_distance(unconfirmed, detections)
         ious_dists_mask = (ious_dists > 0.5)

+        # fuse detection conf into iou dist
+        if self.args.fuse_detection_score:
+            ious_dists = fuse_det_score(ious_dists, detections)
+
         if self.with_reid:
             emb_dists = embedding_distance(unconfirmed, detections) / 2.0
-            raw_emb_dists = emb_dists.copy()
             emb_dists[emb_dists > 0.25] = 1.0
             emb_dists[ious_dists_mask] = 1.0
             dists = np.minimum(ious_dists, emb_dists)
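
With the re-ID path enabled, BoT-SORT keeps its min-fusion of motion and appearance costs: the embedding distance is halved, gated at 0.25 and by the low-IoU mask, then combined with the (optionally score-fused) IoU distance via an element-wise minimum. A standalone sketch of that gating, lifted from the hunks above (the helper name is mine, not the repo's):

    import numpy as np

    def fuse_iou_and_embedding(ious_dists, emb_dists, iou_gate=0.5, emb_gate=0.25):
        """BoT-SORT style cost fusion: reject weak appearance matches and
        spatially implausible pairs, then take the element-wise minimum."""
        ious_dists_mask = ious_dists > iou_gate      # pairs with poor overlap
        emb = emb_dists.copy() / 2.0
        emb[emb > emb_gate] = 1.0                    # weak appearance similarity
        emb[ious_dists_mask] = 1.0                   # far apart in the image
        return np.minimum(ious_dists, emb)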

tracker/trackers/byte_tracker.py

Lines changed: 56 additions & 5 deletions
@@ -5,9 +5,14 @@
 import numpy as np
 from collections import deque
 from .basetrack import BaseTrack, TrackState
-from .tracklet import Tracklet
+from .tracklet import Tracklet, Tracklet_w_reid
 from .matching import *

+# for reid
+import torch
+import torchvision.transforms as T
+from .reid_models.engine import load_reid_model, crop_and_resize
+
 class ByteTracker(object):
     def __init__(self, args, frame_rate=30):
         self.tracked_tracklets = [] # type: list[Tracklet]
@@ -23,6 +28,31 @@ def __init__(self, args, frame_rate=30):

         self.motion = args.kalman_format

+        # whether to use reid
+        self.with_reid = args.reid
+        self.reid_model = None
+        if self.with_reid:
+            self.reid_model = load_reid_model(args.reid_model, args.reid_model_path, device=args.device)
+
+        # once init, clear all trackid count to avoid large id
+        BaseTrack.clear_count()
+
+    @torch.no_grad()
+    def get_feature(self, tlwhs, ori_img):
+        """
+        get apperance feature of an object
+        tlwhs: shape (num_of_objects, 4)
+        ori_img: original image, np.ndarray, shape(H, W, C)
+        """
+
+        if tlwhs.size == 0:
+            return np.empty((0, 512))
+
+        crop_bboxes = crop_and_resize(tlwhs, ori_img, input_format='tlwh', sz=(64, 128))
+        features = self.reid_model(crop_bboxes).cpu().numpy()
+
+        return features
+
     def update(self, output_results, img, ori_img):
         """
         output_results: processed detections (scale to original size) tlbr format
@@ -39,7 +69,7 @@ def update(self, output_results, img, ori_img):
         categories = output_results[:, -1]

         remain_inds = scores > self.args.conf_thresh
-        inds_low = scores > 0.1
+        inds_low = scores > self.args.conf_thresh_low
         inds_high = scores < self.args.conf_thresh

         inds_second = np.logical_and(inds_low, inds_high)
@@ -52,10 +82,17 @@ def update(self, output_results, img, ori_img):
         scores_keep = scores[remain_inds]
         scores_second = scores[inds_second]

+        """Step 1: Extract reid features"""
+        if self.with_reid:
+            features_keep = self.get_feature(tlwhs=dets[:, :4], ori_img=ori_img)
+
         if len(dets) > 0:
-            '''Detections'''
-            detections = [Tracklet(tlwh, s, cate, motion=self.motion) for
-                          (tlwh, s, cate) in zip(dets, scores_keep, cates)]
+            if self.with_reid:
+                detections = [Tracklet_w_reid(tlwh, s, cate, motion=self.motion, feat=feat) for
+                              (tlwh, s, cate, feat) in zip(dets, scores_keep, cates, features_keep)]
+            else:
+                detections = [Tracklet(tlwh, s, cate, motion=self.motion) for
+                              (tlwh, s, cate) in zip(dets, scores_keep, cates)]
         else:
             detections = []

@@ -76,6 +113,16 @@ def update(self, output_results, img, ori_img):
             tracklet.predict()

         dists = iou_distance(tracklet_pool, detections)
+
+        # fuse detection conf into iou dist
+        if self.args.fuse_detection_score:
+            dists = fuse_det_score(dists, detections)
+
+        if self.with_reid:
+            # eq. 11 in Bot-SORT paper, i.e., the common method of
+            # fusing reid and motion. you can adjust the weight here
+            emb_dists = embedding_distance(tracklet_pool, detections)
+            dists = 0.9 * dists + 0.1 * emb_dists

         matches, u_track, u_detection = linear_assignment(dists, thresh=0.9)

@@ -119,6 +166,10 @@ def update(self, output_results, img, ori_img):
         '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
         detections = [detections[i] for i in u_detection]
         dists = iou_distance(unconfirmed, detections)
+
+        # fuse detection conf into iou dist
+        if self.args.fuse_detection_score:
+            dists = fuse_det_score(dists, detections)

         matches, u_unconfirmed, u_detection = linear_assignment(dists, thresh=0.7)

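
Two cost adjustments appear in ByteTracker's first association: an optional 0.9/0.1 blend of IoU and embedding distances (the "eq. 11 in Bot-SORT paper" weighting noted in the hunk), and fuse_det_score(), which folds detection confidence into the IoU cost. fuse_det_score is presumably pulled in via from .matching import * and its implementation is not part of this diff; a sketch under the assumption that it follows ByteTrack's usual fuse_score recipe:

    import numpy as np

    def fuse_det_score(cost_matrix, detections):
        """Assumed behaviour (implementation not shown in this commit): scale the
        IoU similarity of each track-detection pair by the detection's
        confidence, as ByteTrack's fuse_score does."""
        if cost_matrix.size == 0:
            return cost_matrix
        det_scores = np.array([det.score for det in detections])   # assumes a .score attribute
        iou_sim = 1 - cost_matrix
        fused_sim = iou_sim * det_scores[np.newaxis, :]
        return 1 - fused_sim

    # with reid enabled the final cost is then: dists = 0.9 * dists + 0.1 * emb_dists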

tracker/trackers/c_biou_tracker.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ def __init__(self, args, frame_rate=30):

         self.motion = args.kalman_format

+        # once init, clear all trackid count to avoid large id
+        BaseTrack.clear_count()
+
     def update(self, output_results, img, ori_img):
         """
         output_results: processed detections (scale to original size) tlbr format
