@@ -345,7 +345,6 @@ def __init__(self, opts, frame_rate=30, *args, **kwargs) -> None:
345345 self .model_img_size = opts .img_size
346346
347347 self .debug_mode = False
348-
349348 def update (self , det_results , ori_img ):
350349 """
351350 this func is called by every time step
@@ -472,6 +471,56 @@ def update(self, det_results, ori_img):
472471 self .tracked_stracks , self .lost_stracks = remove_duplicate_stracks (self .tracked_stracks , self .lost_stracks )
473472
474473
474+ # print
475+ if self .debug_mode :
476+ print ('===========Frame {}==========' .format (self .frame_id ))
477+ print ('Activated: {}' .format ([track .track_id for track in activated_starcks ]))
478+ print ('Refind: {}' .format ([track .track_id for track in refind_stracks ]))
479+ print ('Lost: {}' .format ([track .track_id for track in lost_stracks ]))
480+ print ('Removed: {}' .format ([track .track_id for track in removed_stracks ]))
481+ return [track for track in self .tracked_stracks if track .is_activated ]
482+
483+ def update_without_detection (self , det_results , ori_img ):
484+ """
485+ update tracks when no detection
486+ only predict current tracks
487+ """
488+ if isinstance (ori_img , torch .Tensor ):
489+ ori_img = ori_img .numpy ()
490+
491+ self .frame_id += 1
492+ activated_starcks = [] # for storing active tracks, for the current frame
493+ refind_stracks = [] # Lost Tracks whose detections are obtained in the current frame
494+ lost_stracks = [] # The tracks which are not obtained in the current frame but are not removed.(Lost for some time lesser than the threshold for removing)
495+ removed_stracks = []
496+
497+ """step 1. init tracks"""
498+
499+ # Do some updates
500+ unconfirmed = []
501+ tracked_stracks = [] # type: list[STrack]
502+ for track in self .tracked_stracks :
503+ if not track .is_activated :
504+ unconfirmed .append (track )
505+ else :
506+ tracked_stracks .append (track )
507+
508+ """step 2. predict Kalman without updating"""
509+ strack_pool = joint_stracks (tracked_stracks , self .lost_stracks )
510+ STrack .multi_predict (stracks = strack_pool , kalman = self .kalman )
511+
512+ # update all
513+ self .tracked_stracks = [t for t in self .tracked_stracks if t .state == TrackState .Tracked ]
514+ self .tracked_stracks = joint_stracks (self .tracked_stracks , activated_starcks )
515+ self .tracked_stracks = joint_stracks (self .tracked_stracks , refind_stracks )
516+ # self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost] # type: list[STrack]
517+ self .lost_stracks = sub_stracks (self .lost_stracks , self .tracked_stracks )
518+ self .lost_stracks .extend (lost_stracks )
519+ self .lost_stracks = sub_stracks (self .lost_stracks , self .removed_stracks )
520+ self .removed_stracks .extend (removed_stracks )
521+ self .tracked_stracks , self .lost_stracks = remove_duplicate_stracks (self .tracked_stracks , self .lost_stracks )
522+
523+
475524 # print
476525 if self .debug_mode :
477526 print ('===========Frame {}==========' .format (self .frame_id ))
0 commit comments