
Commit d46901e

Update and fix SORT
1 parent 4b56173 commit d46901e

7 files changed: +159 -16 lines changed


examples/example_scripts/mot_TF_SSDMobileNet.py

Lines changed: 4 additions & 2 deletions
@@ -23,14 +23,16 @@ def main(video_path, weights_path, config_path, use_gpu, tracker):
             print("Cannot read the video feed.")
             break
 
+        image = cv.resize(image, (700, 500))
+
         bboxes, confidences, class_ids = model.detect(image)
         tracks = tracker.update(bboxes, confidences, class_ids)
         updated_image = model.draw_bboxes(image.copy(), bboxes, confidences, class_ids)
 
         updated_image = draw_tracks(updated_image, tracks)
 
         cv.imshow("image", updated_image)
-        if cv.waitKey(0) & 0xFF == ord('q'):
+        if cv.waitKey(1) & 0xFF == ord('q'):
             break
 
     cap.release()
@@ -75,7 +77,7 @@ def main(video_path, weights_path, config_path, use_gpu, tracker):
     elif args.tracker == 'CentroidKF_Tracker':
         tracker = CentroidKF_Tracker(max_lost=0, tracker_output_format='mot_challenge')
     elif args.tracker == 'SORT':
-        tracker = SORT(max_lost=2, tracker_output_format='mot_challenge', iou_threshold=0.5, time_step=1)
+        tracker = SORT(max_lost=3, tracker_output_format='mot_challenge', iou_threshold=0.3)
     else:
         raise NotImplementedError
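
For reference, a minimal sketch of the tracker as this commit now constructs it. The keyword arguments are taken from the diff above; the commented update call is an assumption, mirroring the call pattern used in the example scripts.

from motrackers import SORT

# Retuned in this commit: a track now survives 3 missed frames, a match needs
# IoU >= 0.3, and the old time_step argument is no longer passed here.
tracker = SORT(max_lost=3, tracker_output_format='mot_challenge', iou_threshold=0.3)

# One update per frame (assumed call pattern, as in the example scripts):
# tracks = tracker.update(bboxes, confidences, class_ids)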

Lines changed: 90 additions & 0 deletions (new file)
@@ -0,0 +1,90 @@
+import cv2 as cv
+from motrackers.detectors import YOLOv3
+from motrackers import CentroidTracker, CentroidKF_Tracker, SORT
+from motrackers.utils import draw_tracks
+
+
+def main(video_path, weights_path, config_path, labels_path, use_gpu, tracker):
+    model = YOLOv3(
+        weights_path=weights_path,
+        configfile_path=config_path,
+        labels_path=labels_path,
+        confidence_threshold=0.5,
+        nms_threshold=0.2,
+        draw_bboxes=True,
+        use_gpu=use_gpu
+    )
+
+    cap = cv.VideoCapture(video_path)
+    while True:
+        ok, image = cap.read()
+
+        if not ok:
+            print("Cannot read the video feed.")
+            break
+
+        image = cv.resize(image, (700, 500))
+
+        bboxes, confidences, class_ids = model.detect(image)
+        tracks = tracker.update(bboxes, confidences, class_ids)
+        updated_image = model.draw_bboxes(image.copy(), bboxes, confidences, class_ids)
+
+        updated_image = draw_tracks(updated_image, tracks)
+
+        cv.imshow("image", updated_image)
+        if cv.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    cap.release()
+    cv.destroyAllWindows()
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description='Object detections in input video using YOLOv3 trained on COCO dataset.'
+    )
+
+    parser.add_argument(
+        '--video', '-v', type=str, default="./../video_data/cars.mp4", help='Input video path.')
+
+    parser.add_argument(
+        '--weights', '-w', type=str,
+        default="./../pretrained_models/yolo_weights/yolov3.weights",
+        help='path to weights file of YOLOv3 (`.weights` file.)'
+    )
+
+    parser.add_argument(
+        '--config', '-c', type=str,
+        default="./../pretrained_models/yolo_weights/yolov3.cfg",
+        help='path to config file of YOLOv3 (`.cfg` file.)'
+    )
+
+    parser.add_argument(
+        '--labels', '-l', type=str,
+        default="./../pretrained_models/yolo_weights/coco.names",
+        help='path to labels file of coco dataset (`.names` file.)'
+    )
+
+    parser.add_argument(
+        '--gpu', type=bool,
+        default=False, help='Flag to use gpu to run the deep learning model. Default is `False`'
+    )
+
+    parser.add_argument(
+        '--tracker', type=str, default='CentroidTracker',
+        help="Tracker used to track objects. Options include ['CentroidTracker', 'CentroidKF_Tracker', 'SORT']")
+
+    args = parser.parse_args()
+
+    if args.tracker == 'CentroidTracker':
+        tracker = CentroidTracker(max_lost=3, tracker_output_format='mot_challenge')
+    elif args.tracker == 'CentroidKF_Tracker':
+        tracker = CentroidKF_Tracker(max_lost=3, tracker_output_format='mot_challenge')
+    elif args.tracker == 'SORT':
+        tracker = SORT(max_lost=3, tracker_output_format='mot_challenge', iou_threshold=0.3, time_step=1)
+    else:
+        raise NotImplementedError
+
+    main(args.video, args.weights, args.config, args.labels, args.gpu, tracker)

motrackers/kalman_tracker.py

Lines changed: 5 additions & 0 deletions
@@ -224,9 +224,14 @@ def __init__(self, bbox, process_noise_scale=1.0, measurement_noise_scale=1.0, t
             [0, 1, 0, 0, 0, 1, 0],
             [0, 0, 1, 0, 0, 0, 1]]) * process_noise_scale
 
+        process_noise_covariance[-1, -1] *= 0.01
+        process_noise_covariance[4:, 4:] *= 0.01
+
         measurement_noise_covariance = np.eye(4) * measurement_noise_scale
+        measurement_noise_covariance[2:, 2:] *= 0.01
 
         prediction_covariance = np.ones_like(transition_matrix) * 10.
+        prediction_covariance[4:, 4:] *= 100.
 
         initial_state = np.array([bbox[0], bbox[1], bbox[2], bbox[3], 0., 0., 0.])
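
In short, the added lines damp the process and measurement noise on the velocity and scale terms and inflate the initial uncertainty of the unobserved velocities. A minimal sketch of the effect, assuming the 7-element SORT-style state implied by the surrounding code; the identity matrices below are stand-ins for the full matrices built in __init__:

import numpy as np

process_noise_scale = 1.0
measurement_noise_scale = 1.0

# Stand-ins for the 7x7 / 4x4 matrices constructed above (assumed shapes).
process_noise_covariance = np.eye(7) * process_noise_scale
process_noise_covariance[-1, -1] *= 0.01    # damp noise on the last state term
process_noise_covariance[4:, 4:] *= 0.01    # damp noise on all velocity terms

measurement_noise_covariance = np.eye(4) * measurement_noise_scale
measurement_noise_covariance[2:, 2:] *= 0.01    # trust the measured scale/aspect more

prediction_covariance = np.ones_like(process_noise_covariance) * 10.
prediction_covariance[4:, 4:] *= 100.       # start the velocities with high uncertainty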

motrackers/sort_tracker.py

Lines changed: 23 additions & 9 deletions
@@ -1,7 +1,7 @@
 import numpy as np
 from scipy.optimize import linear_sum_assignment
 from motrackers.utils.misc import iou_xywh as iou
-from motrackers.track import KFTrackSORT
+from motrackers.track import KFTrackSORT, KFTrack4DSORT
 from motrackers.centroid_kf_tracker import CentroidKF_Tracker
 
 
@@ -36,7 +36,6 @@ def assign_tracks2detection_iou(bbox_tracks, bbox_detections, iou_threshold=0.3)
         bbox_detections = bbox_detections[None, :]
 
     iou_matrix = np.zeros((bbox_tracks.shape[0], bbox_detections.shape[0]), dtype=np.float32)
-
     for t in range(bbox_tracks.shape[0]):
         for d in range(bbox_detections.shape[0]):
             iou_matrix[t, d] = iou(bbox_tracks[t, :], bbox_detections[d, :])
@@ -107,27 +106,42 @@ def __init__(
         )
 
     def _add_track(self, frame_id, bbox, detection_confidence, class_id, **kwargs):
-        self.tracks[self.next_track_id] = KFTrackSORT(
+        # self.tracks[self.next_track_id] = KFTrackSORT(
+        #     self.next_track_id, frame_id, bbox, detection_confidence, class_id=class_id,
+        #     data_output_format=self.tracker_output_format, process_noise_scale=self.process_noise_scale,
+        #     measurement_noise_scale=self.measurement_noise_scale, **kwargs
+        # )
+        self.tracks[self.next_track_id] = KFTrack4DSORT(
             self.next_track_id, frame_id, bbox, detection_confidence, class_id=class_id,
             data_output_format=self.tracker_output_format, process_noise_scale=self.process_noise_scale,
-            measurement_noise_scale=self.measurement_noise_scale, **kwargs
-        )
+            measurement_noise_scale=self.measurement_noise_scale, kf_time_step=1, **kwargs)
         self.next_track_id += 1
 
     def update(self, bboxes, detection_scores, class_ids):
        self.frame_count += 1
 
        bbox_detections = np.array(bboxes, dtype='int')
 
+        # track_ids_all = list(self.tracks.keys())
+        # bbox_tracks = []
+        # track_ids = []
+        # for track_id in track_ids_all:
+        #     bb = self.tracks[track_id].predict()
+        #     if np.any(np.isnan(bb)):
+        #         self._remove_track(track_id)
+        #     else:
+        #         track_ids.append(track_id)
+        #         bbox_tracks.append(bb)
+
        track_ids = list(self.tracks.keys())
        bbox_tracks = []
        for track_id in track_ids:
-            bbox_tracks.append(self.tracks[track_id].predict())
-        bbox_tracks = np.array(bbox_tracks)
+            bb = self.tracks[track_id].predict()
+            bbox_tracks.append(bb)
 
+        bbox_tracks = np.array(bbox_tracks)
        matches, unmatched_detections, unmatched_tracks = assign_tracks2detection_iou(
-            bbox_tracks, bbox_detections, iou_threshold=0.3
-        )
+            bbox_tracks, bbox_detections, iou_threshold=0.3)
 
        for i in range(matches.shape[0]):
            t, d = matches[i, :]
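
The rewritten update keeps every predicted box (the NaN-filtering variant stays commented out) and hands matching to assign_tracks2detection_iou. A minimal sketch of that IoU-plus-Hungarian step, under the assumption that the helper thresholds matches at iou_threshold; iou_xywh below is a local stand-in for motrackers.utils.misc.iou_xywh:

import numpy as np
from scipy.optimize import linear_sum_assignment

def iou_xywh(b1, b2):
    # Stand-in for motrackers.utils.misc.iou_xywh: IoU of two (x, y, w, h) boxes.
    xa, ya = max(b1[0], b2[0]), max(b1[1], b2[1])
    xb, yb = min(b1[0] + b1[2], b2[0] + b2[2]), min(b1[1] + b1[3], b2[1] + b2[3])
    inter = max(0., xb - xa) * max(0., yb - ya)
    return inter / float(b1[2] * b1[3] + b2[2] * b2[3] - inter)

def assign_by_iou(bbox_tracks, bbox_detections, iou_threshold=0.3):
    # Build the IoU matrix exactly as in assign_tracks2detection_iou above.
    iou_matrix = np.zeros((len(bbox_tracks), len(bbox_detections)), dtype=np.float32)
    for t, tb in enumerate(bbox_tracks):
        for d, db in enumerate(bbox_detections):
            iou_matrix[t, d] = iou_xywh(tb, db)
    # Hungarian assignment on the negated IoU maximizes total overlap (assumed detail).
    rows, cols = linear_sum_assignment(-iou_matrix)
    matches = [(t, d) for t, d in zip(rows, cols) if iou_matrix[t, d] >= iou_threshold]
    matched_t = {t for t, _ in matches}
    matched_d = {d for _, d in matches}
    unmatched_tracks = [t for t in range(len(bbox_tracks)) if t not in matched_t]
    unmatched_detections = [d for d in range(len(bbox_detections)) if d not in matched_d]
    return np.array(matches).reshape(-1, 2), unmatched_detections, unmatched_tracks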

motrackers/track.py

Lines changed: 33 additions & 2 deletions
@@ -1,5 +1,5 @@
 import numpy as np
-from motrackers.kalman_tracker import KFTrackerSORT, KFTracker2D
+from motrackers.kalman_tracker import KFTracker2D, KFTrackerSORT, KFTracker4D
 
 
 class Track:
@@ -183,9 +183,16 @@ def __init__(self, track_id, frame_id, bbox, detection_confidence, class_id=None
             iou_score=iou_score, data_output_format=data_output_format, **kwargs)
 
     def predict(self):
+        if (self.kf.x[6] + self.kf.x[2]) <= 0:
+            self.kf.x[6] *= 0.0
+
         x = self.kf.predict()
+
+        if x[2] * x[3] < 0:
+            return np.array([np.nan, np.nan, np.nan, np.nan])
+
         w = np.sqrt(x[2] * x[3])
-        h = x[2] / w
+        h = x[2] / float(w)
         bb = np.array([x[0]-0.5*w, x[1]-0.5*h, w, h])
         return bb
@@ -196,6 +203,30 @@ def update(self, frame_id, bbox, detection_confidence, class_id=None, lost=0, io
         self.kf.update(z)
 
 
+class KFTrack4DSORT(Track):
+    """
+    Track based on Kalman filter tracker used for SORT MOT-Algorithm.
+    """
+    def __init__(self, track_id, frame_id, bbox, detection_confidence, class_id=None, lost=0, iou_score=0.,
+                 data_output_format='mot_challenge', process_noise_scale=1.0, measurement_noise_scale=1.0,
+                 kf_time_step=1, **kwargs):
+        self.kf = KFTracker4D(
+            bbox.copy(), process_noise_scale=process_noise_scale, measurement_noise_scale=measurement_noise_scale,
+            time_step=kf_time_step)
+        super().__init__(track_id, frame_id, bbox, detection_confidence, class_id=class_id, lost=lost,
+                         iou_score=iou_score, data_output_format=data_output_format, **kwargs)
+
+    def predict(self):
+        x = self.kf.predict()
+        bb = np.array([x[0], x[3], x[6], x[9]])
+        return bb
+
+    def update(self, frame_id, bbox, detection_confidence, class_id=None, lost=0, iou_score=0., **kwargs):
+        super().update(
+            frame_id, bbox, detection_confidence, class_id=class_id, lost=lost, iou_score=iou_score, **kwargs)
+        self.kf.update(bbox.copy())
+
+
 class KFTrackCentroid(Track):
     """
     Track based on Kalman filter used for Centroid Tracking of bounding box in MOT.
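
The predict fix above only concerns KFTrackSORT's scale/aspect state, where the box is kept as (centre x, centre y, area s, aspect r) and a negative predicted area has no real square root; KFTrack4DSORT instead reads the four box values straight out of a larger state vector (indices 0, 3, 6, 9). A minimal sketch of the (s, r) round-trip and the new guard, assuming the standard SORT convention s = w*h, r = w/h:

import numpy as np

def to_scale_aspect(w, h):
    # (w, h) -> (s, r) with s = area and r = aspect ratio (assumed SORT convention).
    return w * h, w / float(h)

def to_width_height(s, r):
    # Inverse mapping used by KFTrackSORT.predict; a negative s*r has no real
    # square root, which is what the new NaN guard catches.
    if s * r < 0:
        return np.nan, np.nan
    w = np.sqrt(s * r)
    return w, s / float(w)

# Round-trip check: to_width_height(*to_scale_aspect(80., 40.)) == (80.0, 40.0)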

motrackers/tracker.py

Lines changed: 2 additions & 1 deletion
@@ -115,7 +115,8 @@ def _get_tracks(tracks):
 
         outputs = []
         for trackid, track in tracks.items():
-            outputs.append(track.output())
+            if not track.lost:
+                outputs.append(track.output())
         return outputs
 
     @staticmethod

motrackers/utils/misc.py

Lines changed: 2 additions & 2 deletions
@@ -98,8 +98,8 @@ def iou_xywh(bbox1, bbox2):
     iou: float
         intersection-over-union of bbox1, bbox2.
     """
-    bbox1 = bbox1[0], bbox1[1], bbox1[2]+bbox1[0], bbox1[3]+bbox1[0]
-    bbox2 = bbox2[0], bbox2[1], bbox2[2] + bbox2[0], bbox2[3] + bbox2[0]
+    bbox1 = bbox1[0], bbox1[1], bbox1[0]+bbox1[2], bbox1[1]+bbox1[3]
+    bbox2 = bbox2[0], bbox2[1], bbox2[0]+bbox2[2], bbox2[1]+bbox2[3]
 
     iou_ = iou(bbox1, bbox2)
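
A quick numeric check of why the conversion needed fixing: the old code added the x origin when computing the bottom edge, so y2 came out wrong whenever x != y.

bbox = (10, 20, 30, 40)   # (x, y, w, h)

old_xyxy = (bbox[0], bbox[1], bbox[2] + bbox[0], bbox[3] + bbox[0])   # (10, 20, 40, 50): y2 built from x
new_xyxy = (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])   # (10, 20, 40, 60): correct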
