1+ """
2+ Unofficial Implementation of paper
3+ Hard to Track Objects with Irregular Motions and Similar Appearances? Make It Easier by Buffering the Matching Space(arxiv 2212)
4+ """
5+
6+ import numpy as np
7+ from collections import deque
8+ from basetrack import TrackState , BaseTrack , BaseTracker
9+ import matching
10+ import torch
11+ from torchvision .ops import nms
12+
13+
14+ """
15+ Because the paper drops Kalman so we rewrite the STrack class and rename it as C_BIoUSTrack
16+ """
17+ class C_BIoUSTrack (BaseTrack ):
18+ def __init__ (self , cls , tlwh , score ) -> None :
19+ """
20+ cls: category of this obj
21+ tlwh: positoin score: conf score
22+ """
23+ super ().__init__ ()
24+ self .cls = cls
25+ self ._tlwh = np .asarray (tlwh , dtype = np .float32 ) # init tlwh
26+ self .score = score
27+
28+ self .is_activated = False
29+ self .tracklet_len = 0
30+
31+ self .track_id = None
32+ self .start_frame = None
33+ self .frame_id = None
34+ self .time_since_update = 0 # \delta in paper, use to calculate motion state
35+
36+ # params in motion state
37+ self .b1 , self .b2 , self .n = 0.3 , 0.5 , 2
38+ self .origin_bbox_buffer = deque () # a deque store the original bbox(tlwh) from t - self.n to t, where t is the last time detected
39+ # buffered bbox, two buffer sizes
40+ self .buffer_bbox1 = self .get_buffer_bbox (level = 1 )
41+ self .buffer_bbox2 = self .get_buffer_bbox (level = 2 )
42+ # motion state, s^{t + \delta} = o^t + (\delta / n) * \sum_{i=t-n+1}^t(o^i - o^{i-1}) = o^t + (\delta / n) * (o^t - o^{t - n})
43+ self .motion_state1 = self .buffer_bbox1 .copy ()
44+ self .motion_state2 = self .buffer_bbox2 .copy ()
45+
46+
47+ def get_buffer_bbox (self , level = 1 , bbox = None ):
48+ """
49+ get buffered bbox as: (top, left, w, h) -> (top - bw, y - bh, w + 2bw, h + 2bh)
50+ level = 1: b = self.b1 level = 2: b = self.b2
51+ bbox: if not None, use bbox to calculate buffer_bbox, else use self._tlwh
52+ """
53+ assert level in [1 , 2 ], 'level must be 1 or 2'
54+
55+ b = self .b1 if level == 1 else self .b2
56+
57+ if not bbox :
58+ buffer_bbox = self ._tlwh + np .array ([- b * self ._tlwh [2 ], - b * self ._tlwh [3 ], 2 * b * self ._tlwh [2 ], 2 * b * self ._tlwh [3 ]])
59+ else :
60+ buffer_bbox = bbox + np .array ([- b * bbox [2 ], - b * bbox [3 ], 2 * b * bbox [2 ], 2 * b * bbox [3 ]])
61+ return np .maximum (0.0 , buffer_bbox )
62+
63+
64+ @property
65+ def tlbr (self ):
66+ ret = self ._tlwh .copy ()
67+ ret [2 :] += ret [:2 ]
68+ return ret
69+
70+ def activate (self , frame_id ):
71+ """
72+ init a new track
73+ """
74+ self .track_id = BaseTrack .next_id ()
75+ self .state = TrackState .Tracked
76+
77+ if frame_id == 1 :
78+ self .is_activated = True
79+ # self.is_activated = True
80+ self .frame_id = frame_id
81+ self .start_frame = frame_id
82+
83+ def re_activate (self , new_track , frame_id , new_id = False ):
84+ """
85+ reactivate a lost track
86+ """
87+ self .tracklet_len = 0
88+ self .state = TrackState .Tracked
89+ self .is_activated = True
90+ self .frame_id = frame_id
91+ if new_id :
92+ self .track_id = self .next_id ()
93+ self .score = new_track .score
94+
95+ self ._tlwh = new_track ._tlwh
96+ self .buffer_bbox1 = self .get_buffer_bbox (level = 1 )
97+ self .buffer_bbox2 = self .get_buffer_bbox (level = 2 )
98+ self .motion_state1 = self .buffer_bbox1 .copy ()
99+ self .motion_state2 = self .buffer_bbox2 .copy ()
100+
101+ def update (self , new_track , frame_id ):
102+ """
103+ update a track
104+ """
105+ self .frame_id = frame_id
106+ self .tracklet_len += 1
107+
108+ # update position and score
109+ new_tlwh = new_track .tlwh
110+ self .score = new_track .score
111+
112+ # update stored bbox
113+ if (self .frame_id > self .n ):
114+ self .origin_bbox_buffer .popleft ()
115+ self .origin_bbox_buffer .append (new_tlwh )
116+ else :
117+ self .origin_bbox_buffer .append (new_tlwh )
118+
119+ # update motion state
120+ if self .time_since_update : # have some unmatched frames
121+ if self .frame_id < self .n :
122+ self .motion_state1 = self .get_buffer_bbox (level = 1 , bbox = new_tlwh )
123+ self .motion_state2 = self .get_buffer_bbox (level = 2 , bbox = new_tlwh )
124+ else : # s^{t + \delta} = o^t + (\delta / n) * (o^t - o^{t - n})
125+ motion_state = self .origin_bbox_buffer [- 1 ] + \
126+ (self .time_since_update / self .n ) * (self .origin_bbox_buffer [- 1 ] - self .origin_bbox_buffer [0 ])
127+ self .motion_state1 = self .get_buffer_bbox (level = 1 , bbox = motion_state )
128+ self .motion_state2 = self .get_buffer_bbox (level = 2 , bbox = motion_state )
129+
130+ else : # no unmatched frames, use current detection as motion state
131+ self .motion_state1 = self .get_buffer_bbox (level = 1 , bbox = new_tlwh )
132+ self .motion_state2 = self .get_buffer_bbox (level = 2 , bbox = new_tlwh )
133+
134+ @staticmethod
135+ def xywh2tlbr (xywh ):
136+ """
137+ convert xc, yc, wh to tlbr
138+ """
139+ if len (xywh .shape ) > 1 : # case shape (N, 4) used for Tracker update
140+ result = np .asarray (xywh ).copy ()
141+ result [:, :2 ] -= result [:, 2 :] // 2
142+ result [:, 2 :] = result [:, :2 ] + result [:, 2 :]
143+ result = np .maximum (0.0 , result ) # in case exists minus
144+ else :
145+ result = np .asarray (xywh ).copy ()
146+ result [:2 ] -= result [2 :] // 2
147+ result [2 :] = result [:2 ] + result [2 :]
148+ result = np .maximum (0.0 , result )
149+ return result
150+
151+ @staticmethod
152+ def xywh2tlwh (xywh ):
153+ """
154+ convert xc, yc, wh to tlwh
155+ """
156+ if len (xywh .shape ) > 1 :
157+ result = np .asarray (xywh ).copy ()
158+ result [:, :2 ] -= result [:, 2 :] // 2
159+ else :
160+ result = np .asarray (xywh ).copy ()
161+ result [:2 ] -= result [2 :] // 2
162+
163+ return result
164+
165+ @staticmethod
166+ def tlwh2tlbr (tlwh ):
167+ """
168+ convert top, left, wh to tlbr
169+ """
170+ if len (tlwh .shape ) > 1 :
171+ result = np .asarray (tlwh ).copy ()
172+ result [:, 2 :] += result [:, :2 ]
173+ else :
174+ result = np .asarray (tlwh ).copy ()
175+ result [2 :] += result [:2 ]
176+
177+ return result
178+
179+
180+ class C_BIoUTracker (BaseTracker ):
181+ def __init__ (self , opts , frame_rate = 30 , * args , ** kwargs ) -> None :
182+ super ().__init__ (opts , frame_rate , * args , ** kwargs )
183+
184+ self .kalman = None # The paper drops Kalman Filter so we donot use it
185+
186+ def update (self , det_results , ori_img ):
187+ """
188+ this func is called by every time step
189+
190+ det_results: numpy.ndarray or torch.Tensor, shape(N, 6), 6 includes bbox, conf_score, cls
191+ ori_img: original image, np.ndarray, shape(H, W, C)
192+ """
193+
194+ if isinstance (det_results , torch .Tensor ):
195+ det_results = det_results .cpu ().numpy ()
196+ if isinstance (ori_img , torch .Tensor ):
197+ ori_img = ori_img .numpy ()
198+
199+ self .frame_id += 1
200+ activated_starcks = [] # for storing active tracks, for the current frame
201+ refind_stracks = [] # Lost Tracks whose detections are obtained in the current frame
202+ lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
203+ removed_stracks = []
204+
205+ """step 1. filter results and init tracks"""
206+ det_results = det_results [det_results [:, 4 ] > self .det_thresh ]
207+
208+ # convert the scale to origin size
209+ # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
210+ # TODO: check here, if nesscessary use two ratio
211+ img_h , img_w = ori_img .shape [0 ], ori_img .shape [1 ]
212+ ratio = [img_h / self .model_img_size [0 ], img_w / self .model_img_size [1 ]] # usually > 1
213+ det_results [:, 0 ], det_results [:, 2 ] = det_results [:, 0 ]* ratio [1 ], det_results [:, 2 ]* ratio [1 ]
214+ det_results [:, 1 ], det_results [:, 3 ] = det_results [:, 1 ]* ratio [0 ], det_results [:, 3 ]* ratio [0 ]
215+
216+ if det_results .shape [0 ] > 0 :
217+ if self .NMS :
218+ # TODO: Note nms need tlbr format
219+ bbox_temp = C_BIoUSTrack .xywh2tlbr (det_results [:, :4 ])
220+ nms_indices = nms (torch .from_numpy (bbox_temp ), torch .from_numpy (det_results [:, 4 ]),
221+ self .opts .nms_thresh )
222+ det_results = det_results [nms_indices .numpy ()]
223+
224+ # detections: List[Strack]
225+ detections = [C_BIoUSTrack (cls , C_BIoUSTrack .xywh2tlwh (xywh ), score )
226+ for (cls , xywh , score ) in zip (det_results [:, - 1 ], det_results [:, :4 ], det_results [:, 4 ])]
227+
228+ else :
229+ detections = []
230+
231+ # Do some updates
232+ unconfirmed = []
233+ tracked_stracks = [] # type: list[C_BIoUSTrack]
234+ for track in self .tracked_stracks :
235+ if not track .is_activated :
236+ unconfirmed .append (track )
237+ else :
238+ tracked_stracks .append (track )
239+
240+ """step 2. association with IoU in level 1"""
241+ strack_pool = joint_stracks (tracked_stracks , self .lost_stracks )
242+ # cal IoU dist
243+ IoU_mat = matching .buffered_iou_distance (strack_pool , detections , level = 1 )
244+
245+ matched_pair0 , u_tracks0_idx , u_dets0_idx = matching .linear_assignment (IoU_mat , thresh = 0.9 )
246+
247+ for itracked , idet in matched_pair0 : # for those who matched successfully
248+ track = strack_pool [itracked ]
249+ det = detections [idet ]
250+
251+ if track .state == TrackState .Tracked :
252+ track .update (det , self .frame_id )
253+ activated_starcks .append (track )
254+
255+ else :
256+ track .re_activate (det , self .frame_id , new_id = False )
257+ refind_stracks .append (track )
258+
259+ # tracks and detections that not matched
260+ u_tracks0 = [strack_pool [i ] for i in u_tracks0_idx ]
261+ u_dets0 = [detections [i ] for i in u_dets0_idx ]
262+
263+ """step 3. association with IoU in level 2"""
264+ IoU_mat = matching .buffered_iou_distance (u_tracks0 , u_dets0 , level = 2 )
265+
266+ matched_pair1 , u_tracks1_idx , u_dets1_idx = matching .linear_assignment (IoU_mat , thresh = 0.5 )
267+
268+ for itracked , idet in matched_pair1 :
269+ track = u_tracks0 [itracked ]
270+ det = u_dets0 [idet ]
271+
272+ if track .state == TrackState .Tracked :
273+ track .update (det , self .frame_id )
274+ activated_starcks .append (track )
275+ else :
276+ track .re_activate (det , self .frame_id , new_id = False )
277+ refind_stracks .append (track )
278+
279+ u_tracks1 = [u_tracks0 [i ] for i in u_tracks1_idx ]
280+ u_dets1 = [u_dets0 [i ] for i in u_dets1_idx ]
281+
282+ """step 3'. match unconfirmed tracks"""
283+ IoU_mat = matching .buffered_iou_distance (unconfirmed , u_dets1 , level = 1 )
284+ matched_pair_unconfirmed , u_tracks_unconfirmed_idx , u_dets_unconfirmed_idx = \
285+ matching .linear_assignment (IoU_mat , thresh = 0.7 )
286+
287+ for itracked , idet in matched_pair_unconfirmed :
288+ track = unconfirmed [itracked ]
289+ det = u_dets1 [idet ]
290+ track .update (det , self .frame_id )
291+ activated_starcks .append (track )
292+
293+ for idx in u_tracks_unconfirmed_idx :
294+ track = unconfirmed [idx ]
295+ track .mark_removed ()
296+ removed_stracks .append (track )
297+
298+ # new tracks
299+ for idx in u_dets_unconfirmed_idx :
300+ det = u_dets1 [idx ]
301+ if det .score > self .det_thresh + 0.1 :
302+ det .activate (self .frame_id )
303+ activated_starcks .append (det )
304+
305+ """ Step 4. deal with rest tracks"""
306+ for u_track in u_tracks1 :
307+ if self .frame_id - u_track .end_frame > self .max_time_lost :
308+ u_track .mark_removed ()
309+ removed_stracks .append (u_track )
310+ else :
311+ u_track .mark_lost ()
312+ u_track .time_since_update = self .frame_id - u_track .end_frame # u_track.time_since_update += 1
313+ lost_stracks .append (u_track )
314+
315+
316+ # update all
317+ self .tracked_stracks = [t for t in self .tracked_stracks if t .state == TrackState .Tracked ]
318+ self .tracked_stracks = joint_stracks (self .tracked_stracks , activated_starcks )
319+ self .tracked_stracks = joint_stracks (self .tracked_stracks , refind_stracks )
320+ # self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
321+ self .lost_stracks = sub_stracks (self .lost_stracks , self .tracked_stracks )
322+ self .lost_stracks .extend (lost_stracks )
323+ self .lost_stracks = sub_stracks (self .lost_stracks , self .removed_stracks )
324+ self .removed_stracks .extend (removed_stracks )
325+ self .tracked_stracks , self .lost_stracks = remove_duplicate_stracks (self .tracked_stracks , self .lost_stracks )
326+
327+
328+ # print
329+ if self .debug_mode :
330+ print ('===========Frame {}==========' .format (self .frame_id ))
331+ print ('Activated: {}' .format ([track .track_id for track in activated_starcks ]))
332+ print ('Refind: {}' .format ([track .track_id for track in refind_stracks ]))
333+ print ('Lost: {}' .format ([track .track_id for track in lost_stracks ]))
334+ print ('Removed: {}' .format ([track .track_id for track in removed_stracks ]))
335+ return [track for track in self .tracked_stracks if track .is_activated ]
336+
337+
338+ def joint_stracks (tlista , tlistb ):
339+ exists = {}
340+ res = []
341+ for t in tlista :
342+ exists [t .track_id ] = 1
343+ res .append (t )
344+ for t in tlistb :
345+ tid = t .track_id
346+ if not exists .get (tid , 0 ):
347+ exists [tid ] = 1
348+ res .append (t )
349+ return res
350+
351+ def sub_stracks (tlista , tlistb ):
352+ stracks = {}
353+ for t in tlista :
354+ stracks [t .track_id ] = t
355+ for t in tlistb :
356+ tid = t .track_id
357+ if stracks .get (tid , 0 ):
358+ del stracks [tid ]
359+ return list (stracks .values ())
360+
361+ def remove_duplicate_stracks (stracksa , stracksb ):
362+ pdist = matching .iou_distance (stracksa , stracksb )
363+ pairs = np .where (pdist < 0.15 )
364+ dupa , dupb = list (), list ()
365+ for p ,q in zip (* pairs ):
366+ timep = stracksa [p ].frame_id - stracksa [p ].start_frame
367+ timeq = stracksb [q ].frame_id - stracksb [q ].start_frame
368+ if timep > timeq :
369+ dupb .append (q )
370+ else :
371+ dupa .append (p )
372+ resa = [t for i ,t in enumerate (stracksa ) if not i in dupa ]
373+ resb = [t for i ,t in enumerate (stracksb ) if not i in dupb ]
374+ return resa , resb
0 commit comments