Skip to content

Commit 04ca2b2

Browse files
committed
-Add validation with Tensor
1 parent bfbc2b1 commit 04ca2b2

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

motrackers/track.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import numpy as np
2+
import torch
3+
24
from motrackers.kalman_tracker import KFTracker2D, KFTrackerSORT, KFTracker4D
35

46

@@ -238,7 +240,10 @@ def predict(self):
238240
def update(self, frame_id, bbox, detection_confidence, class_id=None, lost=0, iou_score=0., **kwargs):
239241
super().update(
240242
frame_id, bbox, detection_confidence, class_id=class_id, lost=lost, iou_score=iou_score, **kwargs)
241-
self.kf.update(bbox.copy())
243+
if isinstance(bbox, torch.Tensor):
244+
self.kf.update(bbox.clone().cpu().numpy())
245+
else:
246+
self.kf.update(bbox.copy())
242247

243248

244249
class KFTrackCentroid(Track):

0 commit comments

Comments
 (0)