From 04ca2b246a02ab122f9df6c7bb76a31a4284d1b0 Mon Sep 17 00:00:00 2001 From: Carlos Fernandez Date: Sat, 3 Sep 2022 05:43:23 -0500 Subject: [PATCH] -Add validation with Tensor --- motrackers/track.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/motrackers/track.py b/motrackers/track.py index ec53ebe..98f3694 100644 --- a/motrackers/track.py +++ b/motrackers/track.py @@ -1,4 +1,6 @@ import numpy as np +import torch + from motrackers.kalman_tracker import KFTracker2D, KFTrackerSORT, KFTracker4D @@ -238,7 +240,10 @@ def predict(self): def update(self, frame_id, bbox, detection_confidence, class_id=None, lost=0, iou_score=0., **kwargs): super().update( frame_id, bbox, detection_confidence, class_id=class_id, lost=lost, iou_score=iou_score, **kwargs) - self.kf.update(bbox.copy()) + if isinstance(bbox, torch.Tensor): + self.kf.update(bbox.clone().cpu().numpy()) + else: + self.kf.update(bbox.copy()) class KFTrackCentroid(Track):