11from collections import OrderedDict
22import numpy as np
33from scipy .spatial import distance
4- from motrackers .utils .misc import get_centroid
4+ from motrackers .utils .misc import get_centroids
55from motrackers .track import Track
66
77
@@ -24,7 +24,7 @@ def __init__(self, max_lost=5):
2424 self .max_lost = max_lost
2525 self .frame_count = 0
2626
27- def _add_object (self , centroid , bbox , class_id ):
27+ def _add_track (self , centroid , bbox , class_id , ** kwargs ):
2828 """
2929 Add a newly detected object to the queue
3030
@@ -46,6 +46,9 @@ class label
4646 centroid = centroid ,
4747 bbox = bbox ,
4848 class_id = class_id )
49+ for key , value in kwargs .items ():
50+ self .tracks [self .next_track_id ].info [key ] = value
51+
4952 self .next_track_id += 1
5053
5154 def _remove_track (self , track_id ):
@@ -63,7 +66,14 @@ def _remove_track(self, track_id):
6366 """
6467 del self .tracks [track_id ]
6568
66- def get_tracks (self , tracks ):
69+ def _update_track (self , track_id , centroid , bbox , ** kwargs ):
70+ self .tracks [track_id ].centroid = centroid
71+ self .tracks [track_id ].bbox = bbox
72+ self .tracks [track_id ].lost = 0
73+ for key , value in kwargs .items ():
74+ self .tracks [track_id ].info [key ] = value
75+
76+ def _get_tracks (self , tracks ):
6777 """
6878 Output the information of tracks
6979
@@ -95,7 +105,7 @@ def get_tracks(self, tracks):
95105
96106 return outputs
97107
98- def update (self , bboxes : list , class_ids : list ):
108+ def update (self , bboxes : list , class_ids : list , detection_scores : list ):
99109 """
100110 Update the tracker based on the new bboxes as input.
101111
@@ -106,6 +116,8 @@ def update(self, bboxes: list, class_ids: list):
106116 coordinates of bounding box as tuple (top-left-x, top-left-y, bottom-right-x, bottom-right-y).
107117 class_ids : list
108118 List of class_ids (int) corresponding to labels of the detected object. Default is `None`.
119+ detection_scores: list
120+ List of detection scores / probability of each detected object or objectness.
109121
110122 Returns
111123 -------
@@ -122,39 +134,45 @@ def update(self, bboxes: list, class_ids: list):
122134 """
123135 self .frame_count += 1
124136
137+ new_bboxes = np .array (bboxes , dtype = 'int' )
138+ new_class_ids = np .array (class_ids , dtype = 'int' )
139+ new_detection_scores = np .array (detection_scores )
140+
141+ new_centroids = get_centroids (new_bboxes )
142+
143+ new_detections = list (zip (
144+ range (len (bboxes )), new_bboxes , new_class_ids , new_centroids , new_detection_scores
145+ ))
146+
125147 if len (bboxes ) == 0 : # if no object detected
126148 lost_ids = list (self .tracks .keys ())
127149 for track_id in lost_ids :
128150 self .tracks [track_id ].lost += 1
129151 if self .tracks [track_id ].lost > self .max_lost :
130152 self ._remove_track (track_id )
131153
132- outputs = self .get_tracks (self .tracks )
154+ outputs = self ._get_tracks (self .tracks )
133155 return outputs
134156
135- new_class_ids = np .array (class_ids , dtype = 'int' )
136- new_centroids = np .zeros ((len (bboxes ), 2 ), dtype = "int" )
137- for (i , bbox ) in enumerate (bboxes ):
138- new_centroids [i ] = get_centroid (bbox )
139-
140- if len (self .tracks ):
141- track_ids = list (self .tracks .keys ())
157+ track_ids = list (self .tracks .keys ())
158+ if len (track_ids ):
142159 old_centroids = np .array ([self .tracks [tid ].centroid for tid in track_ids ])
143- D = distance .cdist (old_centroids , new_centroids ) # (row, col) = distance between old (row) and new (col)
144- row_idxs = D .min (axis = 1 ).argsort () # old tracks sorted as per min distance from new
145- col_idxs = D .argmin (axis = 1 )[row_idxs ] # new tracks sorted as per min distance from old
160+ D = distance .cdist (old_centroids , new_centroids ) # (row, col) = distance between old (row) and new (col)
161+
162+ row_idxs = D .min (axis = 1 ).argsort () # old tracks sorted as per min distance from new
163+ col_idxs = D .argmin (axis = 1 )[row_idxs ] # new tracks sorted as per min distance from old
146164
147165 assigned_rows , assigned_cols = set (), set ()
148166 for (row_idx , col_idx ) in zip (row_idxs , col_idxs ):
149167 if row_idx in assigned_rows or col_idx in assigned_cols :
150168 continue
151169
152170 track_id = track_ids [row_idx ]
171+
172+ col_idx , bbox , class_id , centroid , detection_score = new_detections [col_idx ]
153173
154- if self .tracks [track_id ].class_id == new_class_ids [col_idx ]:
155- self .tracks [track_id ].centroid = new_centroids [col_idx ]
156- self .tracks [track_id ].bbox = bboxes [col_idx ]
157- self .tracks [track_id ].lost = 0
174+ if self .tracks [track_id ].class_id == class_id :
175+ self ._update_track (track_id , centroid , bbox , score = detection_score )
158176 assigned_rows .add (row_idx )
159177 assigned_cols .add (col_idx )
160178
@@ -170,10 +188,10 @@ def update(self, bboxes: list, class_ids: list):
170188 self ._remove_track (track_id )
171189 else :
172190 for col_idx in unassigned_cols :
173- self ._add_object (new_centroids [col_idx ], bboxes [col_idx ], class_ids [col_idx ])
191+ self ._add_track (new_centroids [col_idx ], bboxes [col_idx ], class_ids [col_idx ])
174192 else :
175193 for i in range (0 , len (bboxes )):
176- self ._add_object (new_centroids [i ], bboxes [i ], class_ids [i ])
194+ self ._add_track (new_centroids [i ], bboxes [i ], class_ids [i ])
177195
178- outputs = self .get_tracks (self .tracks )
196+ outputs = self ._get_tracks (self .tracks )
179197 return outputs
0 commit comments