1
+ """
2
+ Unofficial Implementation of paper
3
+ Hard to Track Objects with Irregular Motions and Similar Appearances? Make It Easier by Buffering the Matching Space(arxiv 2212)
4
+ """
5
+
6
+ import numpy as np
7
+ from collections import deque
8
+ from basetrack import TrackState , BaseTrack , BaseTracker
9
+ import matching
10
+ import torch
11
+ from torchvision .ops import nms
12
+
13
+
14
+ """
15
+ Because the paper drops Kalman so we rewrite the STrack class and rename it as C_BIoUSTrack
16
+ """
17
+ class C_BIoUSTrack (BaseTrack ):
18
+ def __init__ (self , cls , tlwh , score ) -> None :
19
+ """
20
+ cls: category of this obj
21
+ tlwh: positoin score: conf score
22
+ """
23
+ super ().__init__ ()
24
+ self .cls = cls
25
+ self ._tlwh = np .asarray (tlwh , dtype = np .float32 ) # init tlwh
26
+ self .score = score
27
+
28
+ self .is_activated = False
29
+ self .tracklet_len = 0
30
+
31
+ self .track_id = None
32
+ self .start_frame = None
33
+ self .frame_id = None
34
+ self .time_since_update = 0 # \delta in paper, use to calculate motion state
35
+
36
+ # params in motion state
37
+ self .b1 , self .b2 , self .n = 0.3 , 0.5 , 2
38
+ self .origin_bbox_buffer = deque () # a deque store the original bbox(tlwh) from t - self.n to t, where t is the last time detected
39
+ # buffered bbox, two buffer sizes
40
+ self .buffer_bbox1 = self .get_buffer_bbox (level = 1 )
41
+ self .buffer_bbox2 = self .get_buffer_bbox (level = 2 )
42
+ # motion state, s^{t + \delta} = o^t + (\delta / n) * \sum_{i=t-n+1}^t(o^i - o^{i-1}) = o^t + (\delta / n) * (o^t - o^{t - n})
43
+ self .motion_state1 = self .buffer_bbox1 .copy ()
44
+ self .motion_state2 = self .buffer_bbox2 .copy ()
45
+
46
+
47
+ def get_buffer_bbox (self , level = 1 , bbox = None ):
48
+ """
49
+ get buffered bbox as: (top, left, w, h) -> (top - bw, y - bh, w + 2bw, h + 2bh)
50
+ level = 1: b = self.b1 level = 2: b = self.b2
51
+ bbox: if not None, use bbox to calculate buffer_bbox, else use self._tlwh
52
+ """
53
+ assert level in [1 , 2 ], 'level must be 1 or 2'
54
+
55
+ b = self .b1 if level == 1 else self .b2
56
+
57
+ if not bbox :
58
+ buffer_bbox = self ._tlwh + np .array ([- b * self ._tlwh [2 ], - b * self ._tlwh [3 ], 2 * b * self ._tlwh [2 ], 2 * b * self ._tlwh [3 ]])
59
+ else :
60
+ buffer_bbox = bbox + np .array ([- b * bbox [2 ], - b * bbox [3 ], 2 * b * bbox [2 ], 2 * b * bbox [3 ]])
61
+ return np .maximum (0.0 , buffer_bbox )
62
+
63
+
64
+ @property
65
+ def tlbr (self ):
66
+ ret = self ._tlwh .copy ()
67
+ ret [2 :] += ret [:2 ]
68
+ return ret
69
+
70
+ def activate (self , frame_id ):
71
+ """
72
+ init a new track
73
+ """
74
+ self .track_id = BaseTrack .next_id ()
75
+ self .state = TrackState .Tracked
76
+
77
+ if frame_id == 1 :
78
+ self .is_activated = True
79
+ # self.is_activated = True
80
+ self .frame_id = frame_id
81
+ self .start_frame = frame_id
82
+
83
+ def re_activate (self , new_track , frame_id , new_id = False ):
84
+ """
85
+ reactivate a lost track
86
+ """
87
+ self .tracklet_len = 0
88
+ self .state = TrackState .Tracked
89
+ self .is_activated = True
90
+ self .frame_id = frame_id
91
+ if new_id :
92
+ self .track_id = self .next_id ()
93
+ self .score = new_track .score
94
+
95
+ self ._tlwh = new_track ._tlwh
96
+ self .buffer_bbox1 = self .get_buffer_bbox (level = 1 )
97
+ self .buffer_bbox2 = self .get_buffer_bbox (level = 2 )
98
+ self .motion_state1 = self .buffer_bbox1 .copy ()
99
+ self .motion_state2 = self .buffer_bbox2 .copy ()
100
+
101
+ def update (self , new_track , frame_id ):
102
+ """
103
+ update a track
104
+ """
105
+ self .frame_id = frame_id
106
+ self .tracklet_len += 1
107
+
108
+ # update position and score
109
+ new_tlwh = new_track .tlwh
110
+ self .score = new_track .score
111
+
112
+ # update stored bbox
113
+ if (self .frame_id > self .n ):
114
+ self .origin_bbox_buffer .popleft ()
115
+ self .origin_bbox_buffer .append (new_tlwh )
116
+ else :
117
+ self .origin_bbox_buffer .append (new_tlwh )
118
+
119
+ # update motion state
120
+ if self .time_since_update : # have some unmatched frames
121
+ if self .frame_id < self .n :
122
+ self .motion_state1 = self .get_buffer_bbox (level = 1 , bbox = new_tlwh )
123
+ self .motion_state2 = self .get_buffer_bbox (level = 2 , bbox = new_tlwh )
124
+ else : # s^{t + \delta} = o^t + (\delta / n) * (o^t - o^{t - n})
125
+ motion_state = self .origin_bbox_buffer [- 1 ] + \
126
+ (self .time_since_update / self .n ) * (self .origin_bbox_buffer [- 1 ] - self .origin_bbox_buffer [0 ])
127
+ self .motion_state1 = self .get_buffer_bbox (level = 1 , bbox = motion_state )
128
+ self .motion_state2 = self .get_buffer_bbox (level = 2 , bbox = motion_state )
129
+
130
+ else : # no unmatched frames, use current detection as motion state
131
+ self .motion_state1 = self .get_buffer_bbox (level = 1 , bbox = new_tlwh )
132
+ self .motion_state2 = self .get_buffer_bbox (level = 2 , bbox = new_tlwh )
133
+
134
+ @staticmethod
135
+ def xywh2tlbr (xywh ):
136
+ """
137
+ convert xc, yc, wh to tlbr
138
+ """
139
+ if len (xywh .shape ) > 1 : # case shape (N, 4) used for Tracker update
140
+ result = np .asarray (xywh ).copy ()
141
+ result [:, :2 ] -= result [:, 2 :] // 2
142
+ result [:, 2 :] = result [:, :2 ] + result [:, 2 :]
143
+ result = np .maximum (0.0 , result ) # in case exists minus
144
+ else :
145
+ result = np .asarray (xywh ).copy ()
146
+ result [:2 ] -= result [2 :] // 2
147
+ result [2 :] = result [:2 ] + result [2 :]
148
+ result = np .maximum (0.0 , result )
149
+ return result
150
+
151
+ @staticmethod
152
+ def xywh2tlwh (xywh ):
153
+ """
154
+ convert xc, yc, wh to tlwh
155
+ """
156
+ if len (xywh .shape ) > 1 :
157
+ result = np .asarray (xywh ).copy ()
158
+ result [:, :2 ] -= result [:, 2 :] // 2
159
+ else :
160
+ result = np .asarray (xywh ).copy ()
161
+ result [:2 ] -= result [2 :] // 2
162
+
163
+ return result
164
+
165
+ @staticmethod
166
+ def tlwh2tlbr (tlwh ):
167
+ """
168
+ convert top, left, wh to tlbr
169
+ """
170
+ if len (tlwh .shape ) > 1 :
171
+ result = np .asarray (tlwh ).copy ()
172
+ result [:, 2 :] += result [:, :2 ]
173
+ else :
174
+ result = np .asarray (tlwh ).copy ()
175
+ result [2 :] += result [:2 ]
176
+
177
+ return result
178
+
179
+
180
+ class C_BIoUTracker (BaseTracker ):
181
+ def __init__ (self , opts , frame_rate = 30 , * args , ** kwargs ) -> None :
182
+ super ().__init__ (opts , frame_rate , * args , ** kwargs )
183
+
184
+ self .kalman = None # The paper drops Kalman Filter so we donot use it
185
+
186
+ def update (self , det_results , ori_img ):
187
+ """
188
+ this func is called by every time step
189
+
190
+ det_results: numpy.ndarray or torch.Tensor, shape(N, 6), 6 includes bbox, conf_score, cls
191
+ ori_img: original image, np.ndarray, shape(H, W, C)
192
+ """
193
+
194
+ if isinstance (det_results , torch .Tensor ):
195
+ det_results = det_results .cpu ().numpy ()
196
+ if isinstance (ori_img , torch .Tensor ):
197
+ ori_img = ori_img .numpy ()
198
+
199
+ self .frame_id += 1
200
+ activated_starcks = [] # for storing active tracks, for the current frame
201
+ refind_stracks = [] # Lost Tracks whose detections are obtained in the current frame
202
+ lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
203
+ removed_stracks = []
204
+
205
+ """step 1. filter results and init tracks"""
206
+ det_results = det_results [det_results [:, 4 ] > self .det_thresh ]
207
+
208
+ # convert the scale to origin size
209
+ # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
210
+ # TODO: check here, if nesscessary use two ratio
211
+ img_h , img_w = ori_img .shape [0 ], ori_img .shape [1 ]
212
+ ratio = [img_h / self .model_img_size [0 ], img_w / self .model_img_size [1 ]] # usually > 1
213
+ det_results [:, 0 ], det_results [:, 2 ] = det_results [:, 0 ]* ratio [1 ], det_results [:, 2 ]* ratio [1 ]
214
+ det_results [:, 1 ], det_results [:, 3 ] = det_results [:, 1 ]* ratio [0 ], det_results [:, 3 ]* ratio [0 ]
215
+
216
+ if det_results .shape [0 ] > 0 :
217
+ if self .NMS :
218
+ # TODO: Note nms need tlbr format
219
+ bbox_temp = C_BIoUSTrack .xywh2tlbr (det_results [:, :4 ])
220
+ nms_indices = nms (torch .from_numpy (bbox_temp ), torch .from_numpy (det_results [:, 4 ]),
221
+ self .opts .nms_thresh )
222
+ det_results = det_results [nms_indices .numpy ()]
223
+
224
+ # detections: List[Strack]
225
+ detections = [C_BIoUSTrack (cls , C_BIoUSTrack .xywh2tlwh (xywh ), score )
226
+ for (cls , xywh , score ) in zip (det_results [:, - 1 ], det_results [:, :4 ], det_results [:, 4 ])]
227
+
228
+ else :
229
+ detections = []
230
+
231
+ # Do some updates
232
+ unconfirmed = []
233
+ tracked_stracks = [] # type: list[C_BIoUSTrack]
234
+ for track in self .tracked_stracks :
235
+ if not track .is_activated :
236
+ unconfirmed .append (track )
237
+ else :
238
+ tracked_stracks .append (track )
239
+
240
+ """step 2. association with IoU in level 1"""
241
+ strack_pool = joint_stracks (tracked_stracks , self .lost_stracks )
242
+ # cal IoU dist
243
+ IoU_mat = matching .buffered_iou_distance (strack_pool , detections , level = 1 )
244
+
245
+ matched_pair0 , u_tracks0_idx , u_dets0_idx = matching .linear_assignment (IoU_mat , thresh = 0.9 )
246
+
247
+ for itracked , idet in matched_pair0 : # for those who matched successfully
248
+ track = strack_pool [itracked ]
249
+ det = detections [idet ]
250
+
251
+ if track .state == TrackState .Tracked :
252
+ track .update (det , self .frame_id )
253
+ activated_starcks .append (track )
254
+
255
+ else :
256
+ track .re_activate (det , self .frame_id , new_id = False )
257
+ refind_stracks .append (track )
258
+
259
+ # tracks and detections that not matched
260
+ u_tracks0 = [strack_pool [i ] for i in u_tracks0_idx ]
261
+ u_dets0 = [detections [i ] for i in u_dets0_idx ]
262
+
263
+ """step 3. association with IoU in level 2"""
264
+ IoU_mat = matching .buffered_iou_distance (u_tracks0 , u_dets0 , level = 2 )
265
+
266
+ matched_pair1 , u_tracks1_idx , u_dets1_idx = matching .linear_assignment (IoU_mat , thresh = 0.5 )
267
+
268
+ for itracked , idet in matched_pair1 :
269
+ track = u_tracks0 [itracked ]
270
+ det = u_dets0 [idet ]
271
+
272
+ if track .state == TrackState .Tracked :
273
+ track .update (det , self .frame_id )
274
+ activated_starcks .append (track )
275
+ else :
276
+ track .re_activate (det , self .frame_id , new_id = False )
277
+ refind_stracks .append (track )
278
+
279
+ u_tracks1 = [u_tracks0 [i ] for i in u_tracks1_idx ]
280
+ u_dets1 = [u_dets0 [i ] for i in u_dets1_idx ]
281
+
282
+ """step 3'. match unconfirmed tracks"""
283
+ IoU_mat = matching .buffered_iou_distance (unconfirmed , u_dets1 , level = 1 )
284
+ matched_pair_unconfirmed , u_tracks_unconfirmed_idx , u_dets_unconfirmed_idx = \
285
+ matching .linear_assignment (IoU_mat , thresh = 0.7 )
286
+
287
+ for itracked , idet in matched_pair_unconfirmed :
288
+ track = unconfirmed [itracked ]
289
+ det = u_dets1 [idet ]
290
+ track .update (det , self .frame_id )
291
+ activated_starcks .append (track )
292
+
293
+ for idx in u_tracks_unconfirmed_idx :
294
+ track = unconfirmed [idx ]
295
+ track .mark_removed ()
296
+ removed_stracks .append (track )
297
+
298
+ # new tracks
299
+ for idx in u_dets_unconfirmed_idx :
300
+ det = u_dets1 [idx ]
301
+ if det .score > self .det_thresh + 0.1 :
302
+ det .activate (self .frame_id )
303
+ activated_starcks .append (det )
304
+
305
+ """ Step 4. deal with rest tracks"""
306
+ for u_track in u_tracks1 :
307
+ if self .frame_id - u_track .end_frame > self .max_time_lost :
308
+ u_track .mark_removed ()
309
+ removed_stracks .append (u_track )
310
+ else :
311
+ u_track .mark_lost ()
312
+ u_track .time_since_update = self .frame_id - u_track .end_frame # u_track.time_since_update += 1
313
+ lost_stracks .append (u_track )
314
+
315
+
316
+ # update all
317
+ self .tracked_stracks = [t for t in self .tracked_stracks if t .state == TrackState .Tracked ]
318
+ self .tracked_stracks = joint_stracks (self .tracked_stracks , activated_starcks )
319
+ self .tracked_stracks = joint_stracks (self .tracked_stracks , refind_stracks )
320
+ # self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
321
+ self .lost_stracks = sub_stracks (self .lost_stracks , self .tracked_stracks )
322
+ self .lost_stracks .extend (lost_stracks )
323
+ self .lost_stracks = sub_stracks (self .lost_stracks , self .removed_stracks )
324
+ self .removed_stracks .extend (removed_stracks )
325
+ self .tracked_stracks , self .lost_stracks = remove_duplicate_stracks (self .tracked_stracks , self .lost_stracks )
326
+
327
+
328
+ # print
329
+ if self .debug_mode :
330
+ print ('===========Frame {}==========' .format (self .frame_id ))
331
+ print ('Activated: {}' .format ([track .track_id for track in activated_starcks ]))
332
+ print ('Refind: {}' .format ([track .track_id for track in refind_stracks ]))
333
+ print ('Lost: {}' .format ([track .track_id for track in lost_stracks ]))
334
+ print ('Removed: {}' .format ([track .track_id for track in removed_stracks ]))
335
+ return [track for track in self .tracked_stracks if track .is_activated ]
336
+
337
+
338
+ def joint_stracks (tlista , tlistb ):
339
+ exists = {}
340
+ res = []
341
+ for t in tlista :
342
+ exists [t .track_id ] = 1
343
+ res .append (t )
344
+ for t in tlistb :
345
+ tid = t .track_id
346
+ if not exists .get (tid , 0 ):
347
+ exists [tid ] = 1
348
+ res .append (t )
349
+ return res
350
+
351
+ def sub_stracks (tlista , tlistb ):
352
+ stracks = {}
353
+ for t in tlista :
354
+ stracks [t .track_id ] = t
355
+ for t in tlistb :
356
+ tid = t .track_id
357
+ if stracks .get (tid , 0 ):
358
+ del stracks [tid ]
359
+ return list (stracks .values ())
360
+
361
+ def remove_duplicate_stracks (stracksa , stracksb ):
362
+ pdist = matching .iou_distance (stracksa , stracksb )
363
+ pairs = np .where (pdist < 0.15 )
364
+ dupa , dupb = list (), list ()
365
+ for p ,q in zip (* pairs ):
366
+ timep = stracksa [p ].frame_id - stracksa [p ].start_frame
367
+ timeq = stracksb [q ].frame_id - stracksb [q ].start_frame
368
+ if timep > timeq :
369
+ dupb .append (q )
370
+ else :
371
+ dupa .append (p )
372
+ resa = [t for i ,t in enumerate (stracksa ) if not i in dupa ]
373
+ resb = [t for i ,t in enumerate (stracksb ) if not i in dupb ]
374
+ return resa , resb
0 commit comments