|
| 1 | +from collections import OrderedDict |
| 2 | +import numpy as np |
| 3 | +from scipy.spatial import distance |
| 4 | +from motrackers.utils.misc import get_centroid |
| 5 | +from motrackers.track import Track |
| 6 | + |
| 7 | + |
| 8 | +class SimpleTracker2: |
| 9 | + """ |
| 10 | + Greedy Tracker with class label check. |
| 11 | + """ |
| 12 | + def __init__(self, max_lost=5): |
| 13 | + """ |
| 14 | +
|
| 15 | + Parameters |
| 16 | + ---------- |
| 17 | + max_lost : int |
| 18 | + maximum number of consecutive frames object was not detected. |
| 19 | + """ |
| 20 | + |
| 21 | + self.next_track_id = 0 # ID of next object |
| 22 | + self.tracks = OrderedDict() |
| 23 | + |
| 24 | + self.max_lost = max_lost |
| 25 | + self.frame_count = 0 |
| 26 | + |
| 27 | + def _add_object(self, centroid, bbox, class_id): |
| 28 | + """ |
| 29 | + Add a newly detected object to the queue |
| 30 | +
|
| 31 | + Parameters |
| 32 | + ---------- |
| 33 | + centroid : tuple |
| 34 | + centroid coordinate (x, y) in pixels of the bounding box. |
| 35 | + bbox : tuple |
| 36 | + bounding box of the object being tracked as top right and bottom right coordinates. |
| 37 | + class_id : int |
| 38 | + class label |
| 39 | +
|
| 40 | + Returns |
| 41 | + ------- |
| 42 | +
|
| 43 | + """ |
| 44 | + # store new object location |
| 45 | + self.tracks[self.next_track_id] = Track(track_id=self.next_track_id, |
| 46 | + centroid=centroid, |
| 47 | + bbox=bbox, |
| 48 | + class_id=class_id) |
| 49 | + self.next_track_id += 1 |
| 50 | + |
| 51 | + def _remove_track(self, track_id): |
| 52 | + """ |
| 53 | + Remove tracker data after object is lost |
| 54 | +
|
| 55 | + Parameters |
| 56 | + ---------- |
| 57 | + track_id : int |
| 58 | + track_id of the track lost while tracking |
| 59 | +
|
| 60 | + Returns |
| 61 | + ------- |
| 62 | +
|
| 63 | + """ |
| 64 | + del self.tracks[track_id] |
| 65 | + |
| 66 | + def get_tracks(self, tracks): |
| 67 | + """ |
| 68 | + Output the information of tracks |
| 69 | +
|
| 70 | + Parameters |
| 71 | + ---------- |
| 72 | + tracks : OrderedDict |
| 73 | + Dictionary of Tracks or objects being tracked with keys as track_id |
| 74 | + and values as corresponding `Track` objects. |
| 75 | +
|
| 76 | + Returns |
| 77 | + ------- |
| 78 | + outputs : list |
| 79 | + List of tracks being currently tracked by the tracker. |
| 80 | + Each element of this list contains the following tuple: |
| 81 | + (frame#, trackid, class_id, centroid, bbox, info_dict). |
| 82 | + class_id is the id for label of the detection. |
| 83 | + centroid represents the pixel coordinates of the centroid of bounding box, i.e., (x, y). |
| 84 | + bbox is the bounding box coordinates as (x_top_left, y_top_left, x_bottom_right, y_bottom_right). |
| 85 | + info_dict is the dictionary of information which may be useful from the tracker (example: |
| 86 | + number of times tracker was lost while tracking.). |
| 87 | +
|
| 88 | + """ |
| 89 | + outputs = [] |
| 90 | + |
| 91 | + for trackid, track in tracks.items(): |
| 92 | + track.info['lost'] = track.lost |
| 93 | + op = (self.frame_count, trackid, track.class_id, track.centroid, track.bbox, track.info) |
| 94 | + outputs.append(op) |
| 95 | + |
| 96 | + return outputs |
| 97 | + |
| 98 | + def update(self, bboxes: list, class_ids: list): |
| 99 | + """ |
| 100 | + Update the tracker based on the new bboxes as input. |
| 101 | +
|
| 102 | + Parameters |
| 103 | + ---------- |
| 104 | + bboxes : list |
| 105 | + List of bounding boxes detected in the current frame/timestep. Each element of the list represent |
| 106 | + coordinates of bounding box as tuple (top-left-x, top-left-y, bottom-right-x, bottom-right-y). |
| 107 | + class_ids : list |
| 108 | + List of class_ids (int) corresponding to labels of the detected object. Default is `None`. |
| 109 | +
|
| 110 | + Returns |
| 111 | + ------- |
| 112 | + outputs : list |
| 113 | + List of tracks being currently tracked by the tracker. |
| 114 | + Each element of this list contains the following tuple: |
| 115 | + (frame#, trackid, class_id, centroid, bbox, info_dict). |
| 116 | + class_id is the id for label of the detection. |
| 117 | + centroid represents the pixel coordinates of the centroid of bounding box, i.e., (x, y). |
| 118 | + bbox is the bounding box coordinates as (x_top_left, y_top_left, x_bottom_right, y_bottom_right). |
| 119 | + info_dict is the dictionary of information which may be useful from the tracker (example: |
| 120 | + number of times tracker was lost while tracking.). |
| 121 | +
|
| 122 | + """ |
| 123 | + self.frame_count += 1 |
| 124 | + |
| 125 | + if len(bboxes) == 0: # if no object detected |
| 126 | + lost_ids = list(self.tracks.keys()) |
| 127 | + for track_id in lost_ids: |
| 128 | + self.tracks[track_id].lost += 1 |
| 129 | + if self.tracks[track_id].lost > self.max_lost: |
| 130 | + self._remove_track(track_id) |
| 131 | + |
| 132 | + outputs = self.get_tracks(self.tracks) |
| 133 | + return outputs |
| 134 | + |
| 135 | + new_class_ids = np.array(class_ids, dtype='int') |
| 136 | + new_centroids = np.zeros((len(bboxes), 2), dtype="int") |
| 137 | + for (i, bbox) in enumerate(bboxes): |
| 138 | + new_centroids[i] = get_centroid(bbox) |
| 139 | + |
| 140 | + if len(self.tracks): |
| 141 | + track_ids = list(self.tracks.keys()) |
| 142 | + old_centroids = np.array([self.tracks[tid].centroid for tid in track_ids]) |
| 143 | + D = distance.cdist(old_centroids, new_centroids) # (row, col) = distance between old (row) and new (col) |
| 144 | + row_idxs = D.min(axis=1).argsort() # old tracks sorted as per min distance from new |
| 145 | + col_idxs = D.argmin(axis=1)[row_idxs] # new tracks sorted as per min distance from old |
| 146 | + |
| 147 | + assigned_rows, assigned_cols = set(), set() |
| 148 | + for (row_idx, col_idx) in zip(row_idxs, col_idxs): |
| 149 | + if row_idx in assigned_rows or col_idx in assigned_cols: |
| 150 | + continue |
| 151 | + |
| 152 | + track_id = track_ids[row_idx] |
| 153 | + |
| 154 | + if self.tracks[track_id].class_id == new_class_ids[col_idx]: |
| 155 | + self.tracks[track_id].centroid = new_centroids[col_idx] |
| 156 | + self.tracks[track_id].bbox = bboxes[col_idx] |
| 157 | + self.tracks[track_id].lost = 0 |
| 158 | + assigned_rows.add(row_idx) |
| 159 | + assigned_cols.add(col_idx) |
| 160 | + |
| 161 | + unassigned_rows = set(range(0, D.shape[0])).difference(assigned_rows) |
| 162 | + unassigned_cols = set(range(0, D.shape[1])).difference(assigned_cols) |
| 163 | + |
| 164 | + if D.shape[0] >= D.shape[1]: |
| 165 | + for row_idx in unassigned_rows: |
| 166 | + track_id = track_ids[row_idx] |
| 167 | + self.tracks[track_id].lost += 1 |
| 168 | + |
| 169 | + if self.tracks[track_id].lost > self.max_lost: |
| 170 | + self._remove_track(track_id) |
| 171 | + else: |
| 172 | + for col_idx in unassigned_cols: |
| 173 | + self._add_object(new_centroids[col_idx], bboxes[col_idx], class_ids[col_idx]) |
| 174 | + else: |
| 175 | + for i in range(0, len(bboxes)): |
| 176 | + self._add_object(new_centroids[i], bboxes[i], class_ids[i]) |
| 177 | + |
| 178 | + outputs = self.get_tracks(self.tracks) |
| 179 | + return outputs |
0 commit comments