Skip to content

Commit c3af262

Browse files
committed
add C_BIoU and modified some codes
1 parent 15f57f3 commit c3af262

File tree

3 files changed

+398
-2
lines changed

3 files changed

+398
-2
lines changed

tracker/bytetrack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def update(self, det_results, ori_img):
165165
""" Step 4. deal with rest tracks and dets"""
166166
# deal with final unmatched tracks
167167
for idx in u_tracks1_idx:
168-
track = strack_pool[idx]
168+
track = u_tracks0[idx]
169169
track.mark_lost()
170170
lost_stracks.append(track)
171171

tracker/c_biou_tracker.py

Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
"""
2+
Unofficial Implementation of paper
3+
Hard to Track Objects with Irregular Motions and Similar Appearances? Make It Easier by Buffering the Matching Space(arxiv 2212)
4+
"""
5+
6+
import numpy as np
7+
from collections import deque
8+
from basetrack import TrackState, BaseTrack, BaseTracker
9+
import matching
10+
import torch
11+
from torchvision.ops import nms
12+
13+
14+
"""
15+
The paper drops the Kalman filter, so the STrack class is rewritten here and renamed C_BIoUSTrack.
16+
"""
17+
class C_BIoUSTrack(BaseTrack):
    """Track (or detection) record used by the C-BIoU tracker; no Kalman filter.

    Each instance stores its latest box in tlwh form, two buffered (enlarged)
    versions of that box, and two "motion state" boxes that the tracker hands
    to the matching code.  A short history of observed boxes is kept so that
    the motion state can be extrapolated as
    s^{t+delta} = o^t + (delta / n) * (o^t - o^{t-n}).
    """

    def __init__(self, cls, tlwh, score) -> None:
        """
        cls: category of this object
        tlwh: position as (x, y, w, h) of the top-left corner plus size
        score: confidence score
        """
        super().__init__()
        self.cls = cls
        self._tlwh = np.asarray(tlwh, dtype=np.float32)  # current box
        self.score = score

        self.is_activated = False
        self.tracklet_len = 0

        self.track_id = None
        self.start_frame = None
        self.frame_id = None
        # \delta in the paper: frames elapsed since the last matched detection
        self.time_since_update = 0

        # buffer scales for the two matching levels and history length n
        self.b1, self.b2, self.n = 0.3, 0.5, 2
        # Observed boxes o^{t-n} .. o^t.  Seeded with the initial box and
        # bounded by maxlen, so appending drops the oldest entry automatically
        # and the window never has to be popped by hand.
        self.origin_bbox_buffer = deque([self._tlwh.copy()], maxlen=self.n + 1)
        # buffered bbox, two buffer sizes
        self.buffer_bbox1 = self.get_buffer_bbox(level=1)
        self.buffer_bbox2 = self.get_buffer_bbox(level=2)
        # motion state starts out equal to the buffered initial box
        self.motion_state1 = self.buffer_bbox1.copy()
        self.motion_state2 = self.buffer_bbox2.copy()

    def get_buffer_bbox(self, level=1, bbox=None):
        """
        Return a buffered box: (x, y, w, h) -> (x - b*w, y - b*h, w + 2*b*w, h + 2*b*h),
        with every component clipped at 0.

        level: 1 uses self.b1, 2 uses self.b2
        bbox: tlwh box to enlarge; when None, self._tlwh is used
        """
        assert level in [1, 2], 'level must be 1 or 2'

        b = self.b1 if level == 1 else self.b2

        # `bbox` is normally a NumPy array, whose truth value is ambiguous,
        # so the "not given" case must be detected with an identity check.
        box = self._tlwh if bbox is None else np.asarray(bbox)
        buffer_bbox = box + np.array([-b * box[2], -b * box[3], 2 * b * box[2], 2 * b * box[3]])
        return np.maximum(0.0, buffer_bbox)

    @property
    def tlwh(self):
        """Copy of the current box as (x, y, w, h)."""
        return self._tlwh.copy()

    @property
    def tlbr(self):
        """Current box as (x1, y1, x2, y2)."""
        ret = self._tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    def activate(self, frame_id):
        """
        Initialise a new track: assign an id and mark it Tracked.
        Only tracks created on frame 1 are confirmed immediately.
        """
        self.track_id = BaseTrack.next_id()
        self.state = TrackState.Tracked

        if frame_id == 1:
            self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        """
        Reactivate a lost track from the detection `new_track`.

        new_id: when True a fresh track id is assigned.
        """
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()
        self.score = new_track.score

        # copy so the track does not alias the detection's array
        self._tlwh = np.array(new_track._tlwh, dtype=np.float32)
        self.origin_bbox_buffer.append(self._tlwh.copy())
        # the track has just been matched, so no frames are outstanding
        self.time_since_update = 0

        self.buffer_bbox1 = self.get_buffer_bbox(level=1)
        self.buffer_bbox2 = self.get_buffer_bbox(level=2)
        self.motion_state1 = self.buffer_bbox1.copy()
        self.motion_state2 = self.buffer_bbox2.copy()

    def update(self, new_track, frame_id):
        """
        Update a matched track with the detection `new_track`.

        Stores the new box and score, confirms the track, records the box in
        the history window and refreshes both motion states.
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        # update position and score
        new_tlwh = np.array(new_track.tlwh, dtype=np.float32)
        self._tlwh = new_tlwh
        self.score = new_track.score
        self.buffer_bbox1 = self.get_buffer_bbox(level=1)
        self.buffer_bbox2 = self.get_buffer_bbox(level=2)

        # a matched track is tracked and confirmed
        self.state = TrackState.Tracked
        self.is_activated = True

        # update stored bbox (deque maxlen discards the oldest entry)
        self.origin_bbox_buffer.append(new_tlwh)

        # update motion state
        delta = self.time_since_update
        if delta and len(self.origin_bbox_buffer) > self.n:
            # s^{t + \delta} = o^t + (\delta / n) * (o^t - o^{t - n})
            newest = self.origin_bbox_buffer[-1]
            oldest = self.origin_bbox_buffer[0]
            reference = newest + (delta / self.n) * (newest - oldest)
        else:
            # no unmatched frames, or not enough history yet:
            # use the current detection as motion state
            reference = new_tlwh
        self.motion_state1 = self.get_buffer_bbox(level=1, bbox=reference)
        self.motion_state2 = self.get_buffer_bbox(level=2, bbox=reference)

        # \delta has been consumed; do not reuse a stale value next frame.
        # NOTE(review): per the paper's formula the extrapolation describes an
        # *unmatched* track; applying it to lost tracks before matching would
        # have to be done by the tracker - confirm intended design.
        self.time_since_update = 0

    @staticmethod
    def xywh2tlbr(xywh):
        """
        Convert (xc, yc, w, h) to (x1, y1, x2, y2), clipped at 0.
        Accepts a single box of shape (4,) or a batch of shape (N, 4);
        the input is not modified.
        """
        result = np.asarray(xywh).copy()
        # floor division is kept from the original conversion
        result[..., :2] -= result[..., 2:] // 2
        result[..., 2:] = result[..., :2] + result[..., 2:]
        return np.maximum(0.0, result)  # in case negative values exist

    @staticmethod
    def xywh2tlwh(xywh):
        """
        Convert (xc, yc, w, h) to (x1, y1, w, h).
        Accepts shape (4,) or (N, 4); the input is not modified.
        """
        result = np.asarray(xywh).copy()
        result[..., :2] -= result[..., 2:] // 2
        return result

    @staticmethod
    def tlwh2tlbr(tlwh):
        """
        Convert (x1, y1, w, h) to (x1, y1, x2, y2).
        Accepts shape (4,) or (N, 4); the input is not modified.
        """
        result = np.asarray(tlwh).copy()
        result[..., 2:] += result[..., :2]
        return result
178+
179+
180+
class C_BIoUTracker(BaseTracker):
    """Multi-object tracker that associates tracks and detections in two
    buffered-IoU stages (level 1, then level 2) without a Kalman filter."""

    def __init__(self, opts, frame_rate=30, *args, **kwargs) -> None:
        super().__init__(opts, frame_rate, *args, **kwargs)

        self.kalman = None  # The paper drops the Kalman filter, so it is not used

    @staticmethod
    def _apply_matches(pairs, tracks, dets, frame_id, activated, refound):
        """Update each matched track; Tracked ones go to `activated`,
        all others are re-activated and go to `refound`."""
        for itracked, idet in pairs:
            track = tracks[itracked]
            det = dets[idet]
            if track.state == TrackState.Tracked:
                track.update(det, frame_id)
                activated.append(track)
            else:
                track.re_activate(det, frame_id, new_id=False)
                refound.append(track)

    def update(self, det_results, ori_img):
        """
        Process one frame; called once per time step.

        det_results: numpy.ndarray or torch.Tensor, shape (N, 6); columns 0-3 are
            the box as (xc, yc, w, h), column 4 the confidence, last column the class
        ori_img: original image, np.ndarray or torch.Tensor, shape (H, W, C)

        Returns the list of currently tracked tracks whose `is_activated` is True.
        """
        if isinstance(det_results, torch.Tensor):
            det_results = det_results.cpu().numpy()
        if isinstance(ori_img, torch.Tensor):
            # .cpu() first: Tensor.numpy() only works on CPU tensors
            ori_img = ori_img.cpu().numpy()

        self.frame_id += 1
        activated_stracks = []  # tracks matched (or created) in the current frame
        refind_stracks = []     # lost tracks matched again in the current frame
        lost_stracks = []       # tracks unmatched in the current frame but not yet removed
        removed_stracks = []

        """step 1. filter results and init tracks"""
        # boolean indexing returns a copy, so the scaling below does not
        # write into the caller's array
        det_results = det_results[det_results[:, 4] > self.det_thresh]

        # convert the scale to origin size
        # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
        img_h, img_w = ori_img.shape[0], ori_img.shape[1]
        ratio_h = img_h / self.model_img_size[0]  # usually > 1
        ratio_w = img_w / self.model_img_size[1]
        det_results[:, [0, 2]] = det_results[:, [0, 2]] * ratio_w
        det_results[:, [1, 3]] = det_results[:, [1, 3]] * ratio_h

        if det_results.shape[0] > 0:
            if self.NMS:
                # nms needs boxes in tlbr format
                bbox_temp = C_BIoUSTrack.xywh2tlbr(det_results[:, :4])
                nms_indices = nms(torch.from_numpy(bbox_temp), torch.from_numpy(det_results[:, 4]),
                                  self.opts.nms_thresh)
                det_results = det_results[nms_indices.numpy()]

            detections = [C_BIoUSTrack(cls, C_BIoUSTrack.xywh2tlwh(xywh), score)
                          for (cls, xywh, score) in zip(det_results[:, -1], det_results[:, :4], det_results[:, 4])]
        else:
            detections = []

        # split existing tracks into confirmed and not-yet-confirmed
        unconfirmed = [t for t in self.tracked_stracks if not t.is_activated]
        tracked_stracks = [t for t in self.tracked_stracks if t.is_activated]

        """step 2. association with IoU in level 1"""
        strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
        IoU_mat = matching.buffered_iou_distance(strack_pool, detections, level=1)
        matched_pair0, u_tracks0_idx, u_dets0_idx = matching.linear_assignment(IoU_mat, thresh=0.9)

        self._apply_matches(matched_pair0, strack_pool, detections, self.frame_id,
                            activated_stracks, refind_stracks)

        # tracks and detections that were not matched
        u_tracks0 = [strack_pool[i] for i in u_tracks0_idx]
        u_dets0 = [detections[i] for i in u_dets0_idx]

        """step 3. association with IoU in level 2"""
        IoU_mat = matching.buffered_iou_distance(u_tracks0, u_dets0, level=2)
        matched_pair1, u_tracks1_idx, u_dets1_idx = matching.linear_assignment(IoU_mat, thresh=0.5)

        self._apply_matches(matched_pair1, u_tracks0, u_dets0, self.frame_id,
                            activated_stracks, refind_stracks)

        u_tracks1 = [u_tracks0[i] for i in u_tracks1_idx]
        u_dets1 = [u_dets0[i] for i in u_dets1_idx]

        """step 3'. match unconfirmed tracks"""
        IoU_mat = matching.buffered_iou_distance(unconfirmed, u_dets1, level=1)
        matched_pair_unconfirmed, u_tracks_unconfirmed_idx, u_dets_unconfirmed_idx = \
            matching.linear_assignment(IoU_mat, thresh=0.7)

        for itracked, idet in matched_pair_unconfirmed:
            track = unconfirmed[itracked]
            track.update(u_dets1[idet], self.frame_id)
            activated_stracks.append(track)

        # unconfirmed tracks that found no detection are dropped right away
        for idx in u_tracks_unconfirmed_idx:
            track = unconfirmed[idx]
            track.mark_removed()
            removed_stracks.append(track)

        # new tracks
        for idx in u_dets_unconfirmed_idx:
            det = u_dets1[idx]
            if det.score > self.det_thresh + 0.1:
                det.activate(self.frame_id)
                activated_stracks.append(det)

        """ Step 4. deal with rest tracks"""
        for u_track in u_tracks1:
            frames_unmatched = self.frame_id - u_track.end_frame
            if frames_unmatched > self.max_time_lost:
                u_track.mark_removed()
                removed_stracks.append(u_track)
            else:
                u_track.mark_lost()
                u_track.time_since_update = frames_unmatched
                lost_stracks.append(u_track)

        # update all
        self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
        self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)
        self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
        self.lost_stracks.extend(lost_stracks)
        # Register this frame's removals BEFORE subtracting them, otherwise a
        # track removed now stays in the lost pool for one more frame, can be
        # matched again, and is appended to removed_stracks a second time.
        self.removed_stracks.extend(removed_stracks)
        self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
        self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)

        # print
        if self.debug_mode:
            print('===========Frame {}=========='.format(self.frame_id))
            print('Activated: {}'.format([track.track_id for track in activated_stracks]))
            print('Refind: {}'.format([track.track_id for track in refind_stracks]))
            print('Lost: {}'.format([track.track_id for track in lost_stracks]))
            print('Removed: {}'.format([track.track_id for track in removed_stracks]))
        return [track for track in self.tracked_stracks if track.is_activated]
336+
337+
338+
def joint_stracks(tlista, tlistb):
    """Return tlista followed by the tracks of tlistb whose track_id is new.

    Every element of tlista is kept as-is; an element of tlistb is added only
    if its track_id has not been seen yet (in tlista or earlier in tlistb).
    """
    merged = list(tlista)
    seen_ids = {track.track_id for track in merged}
    for track in tlistb:
        if track.track_id in seen_ids:
            continue
        seen_ids.add(track.track_id)
        merged.append(track)
    return merged
350+
351+
def sub_stracks(tlista, tlistb):
    """Return the tracks of tlista whose track_id does not occur in tlistb.

    Tracks are keyed by track_id, so if tlista holds several tracks with the
    same id only the last one survives (at the position of the first).
    """
    by_id = {track.track_id: track for track in tlista}
    excluded_ids = {track.track_id for track in tlistb}
    return [track for tid, track in by_id.items() if tid not in excluded_ids]
360+
361+
def remove_duplicate_stracks(stracksa, stracksb):
    """Drop one track from every (a, b) pair whose IoU distance is below 0.15.

    For each such pair the track with the shorter life span
    (frame_id - start_frame) is discarded; on a tie the one from stracksa
    is discarded.  Returns the two filtered lists.
    """
    distances = matching.iou_distance(stracksa, stracksb)
    drop_a, drop_b = set(), set()
    for ia, ib in zip(*np.where(distances < 0.15)):
        span_a = stracksa[ia].frame_id - stracksa[ia].start_frame
        span_b = stracksb[ib].frame_id - stracksb[ib].start_frame
        if span_a > span_b:
            drop_b.add(int(ib))
        else:
            drop_a.add(int(ia))
    kept_a = [track for idx, track in enumerate(stracksa) if idx not in drop_a]
    kept_b = [track for idx, track in enumerate(stracksb) if idx not in drop_b]
    return kept_a, kept_b

0 commit comments

Comments
 (0)