diff --git a/motrackers/track.py b/motrackers/track.py index ec53ebe..98f3694 100644 --- a/motrackers/track.py +++ b/motrackers/track.py @@ -1,4 +1,6 @@ import numpy as np +import torch + from motrackers.kalman_tracker import KFTracker2D, KFTrackerSORT, KFTracker4D @@ -238,7 +240,10 @@ def predict(self): def update(self, frame_id, bbox, detection_confidence, class_id=None, lost=0, iou_score=0., **kwargs): super().update( frame_id, bbox, detection_confidence, class_id=class_id, lost=lost, iou_score=iou_score, **kwargs) - self.kf.update(bbox.copy()) + if isinstance(bbox, torch.Tensor): + self.kf.update(bbox.clone().cpu().numpy()) + else: + self.kf.update(bbox.copy()) class KFTrackCentroid(Track):