Commit dfdcf94

update track_demo
1 parent 75d99ca commit dfdcf94

4 files changed: +70 -25 lines


tracker/botsort.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def multi_gmc(stracks, H=np.eye(2, 3)):
 
 
 class BoTSORT(BaseTracker):
-    def __init__(self, opts, frame_rate=30, gamma=0.02, use_GMC=False, *args, **kwargs) -> None:
+    def __init__(self, opts, frame_rate=30, gamma=0.02, use_GMC=True, *args, **kwargs) -> None:
         super().__init__(opts, frame_rate, *args, **kwargs)
 
         self.use_apperance_model = False
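
The only functional change here is flipping the use_GMC default to True, so BoT-SORT now runs global motion compensation out of the box. GMC estimates a 2x3 warp between consecutive frames (the H passed to multi_gmc above) and applies it to each track's predicted box before association. A minimal sketch of that warp on raw tlbr boxes, for illustration only: the repo's multi_gmc operates on STrack Kalman states, and warp_boxes below is a hypothetical helper.

    import numpy as np

    def warp_boxes(tlbr_boxes, H=np.eye(2, 3)):
        # Apply a 2x3 affine warp (linear part M, translation t)
        # to both corners of each tlbr box.
        M, t = H[:, :2], H[:, 2]
        out = []
        for x1, y1, x2, y2 in tlbr_boxes:
            p1 = M @ np.array([x1, y1]) + t  # warped top-left
            p2 = M @ np.array([x2, y2]) + t  # warped bottom-right
            out.append([p1[0], p1[1], p2[0], p2[1]])
        return np.asarray(out)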

tracker/config_files/uavdt.yaml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ CATEGORY_DICT:
   0: 'car'
 
 CERTAIN_SEQS:
-  - 'M0101'
+  -
 IGNORE_SEQS:  # Seqs you want to ignore
   -
 
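
Clearing CERTAIN_SEQS presumably removes the single-sequence whitelist so the tracker runs on every UAVDT sequence except those in IGNORE_SEQS. A sketch of how such whitelist/blacklist filtering is typically applied (helper name and semantics assumed, not taken from this repo):

    def select_seqs(all_seqs, certain_seqs=None, ignore_seqs=None):
        # Keep only whitelisted sequences when a whitelist is given,
        # then drop anything on the blacklist.
        seqs = [s for s in all_seqs if not certain_seqs or s in certain_seqs]
        return [s for s in seqs if s not in (ignore_seqs or [])]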

tracker/track_demo.py

Lines changed: 66 additions & 21 deletions
@@ -30,6 +30,7 @@
     from models.experimental import attempt_load
     from evaluate import evaluate
     from utils.torch_utils import select_device, time_synchronized, TracedModel
+    from utils.general import non_max_suppression, scale_coords, check_img_size
     print('Note: running yolo v7 detector')
 
 except:
@@ -64,8 +65,11 @@ def main(opts):
     1. load model
     """
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
     ckpt = torch.load(opts.model_path, map_location=device)
     model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()  # for yolo v7
+    stride = int(model.stride.max())  # model stride
+    opts.img_size = check_img_size(opts.img_size, s=stride)  # check img_size
 
     if opts.trace:
         print(opts.img_size)
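
The new stride/check_img_size lines guard against input sizes the network cannot consume: YOLOv7 downsamples by its maximum stride (32), so the input's spatial dims must be multiples of it. A sketch of what the utility does, assuming the usual YOLOv5/v7 semantics:

    import math

    def check_img_size_sketch(img_size, s=32):
        # Round up to the nearest multiple of the model stride s;
        # 1280 passes through unchanged, 1000 would become 1024.
        new_size = math.ceil(img_size / s) * s
        if new_size != img_size:
            print(f'WARNING: --img_size {img_size} is not a multiple of {s}, using {new_size}')
        return new_size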
@@ -116,13 +120,15 @@ def main(opts):
         if not is_valid:
             break  # end of reading
 
-        img = resize_a_frame(img0, [opts.img_size, opts.img_size])
+        img, img0 = preprocess_v7(ori_img=img0, model_size=(opts.img_size, opts.img_size), model_stride=stride)
 
         timer.tic()  # start timing this img
         img = img.unsqueeze(0)  # (C, H, W) -> (bs == 1, C, H, W)
         out = model(img.to(device))  # model forward
         out = out[0]  # NOTE: for yolo v7
-
+
+        out = post_process_v7(out, img_size=img.shape[2:], ori_img_size=img0.shape)
+
         if len(out.shape) == 3:  # case (bs, num_obj, ...)
             # out = out.squeeze()
             # NOTE: assert batch size == 1
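
post_process_v7 receives both the letterboxed network input size (img.shape[2:]) and the original frame size (img0.shape) because scale_coords must undo the letterbox transform to express boxes in source-video pixels. Under the standard YOLO convention, that inverse mapping looks roughly like this (a sketch with assumed names; the script itself defers to the upstream scale_coords):

    def unletterbox_boxes(boxes_xyxy, net_hw, ori_hw):
        # Map xyxy boxes from the padded network input back to the original
        # image: subtract the letterbox padding, then divide out the resize gain.
        gain = min(net_hw[0] / ori_hw[0], net_hw[1] / ori_hw[1])
        pad_w = (net_hw[1] - ori_hw[1] * gain) / 2  # x padding per side
        pad_h = (net_hw[0] - ori_hw[0] * gain) / 2  # y padding per side
        boxes_xyxy[:, [0, 2]] = (boxes_xyxy[:, [0, 2]] - pad_w) / gain
        boxes_xyxy[:, [1, 3]] = (boxes_xyxy[:, [1, 3]] - pad_h) / gain
        return boxes_xyxy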
@@ -157,7 +163,7 @@ def main(opts):
         results.append((frame_id + 1, cur_id, cur_tlwh, cur_cls))
         timer.toc()  # end timing this image
 
-        plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls], save_dir=os.path.join(SAVE_FOLDER, 'reuslt_images', obj_name))
+        plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls], save_dir=os.path.join(SAVE_FOLDER, 'result_images', obj_name))
 
         frame_id += 1
 
@@ -191,25 +197,63 @@ def is_empty(self):
         return self.start_idx == len(self.arr)
 
 
-def resize_a_frame(frame, target_size):
-    """
-    resize a frame to target size
-
-    frame: np.ndarray, shape (H, W, C)
-    target_size: List[int, int] | Tuple[int, int]
+def post_process_v7(out, img_size, ori_img_size):
+    """ post process for v5 and v7
+
     """
-    # resize to input to the YOLO net
-    frame_resized = cv2.resize(frame, (target_size[0], target_size[1]))  # (H', W', C)
-    # convert BGR to RGB and to (C, H, W)
-    frame_resized = frame_resized[:, :, ::-1].transpose(2, 0, 1)
 
-    frame_resized = np.ascontiguousarray(frame_resized, dtype=np.float32)
-    frame_resized /= 255.0
+    out = non_max_suppression(out, conf_thres=0.01, )[0]
+    out[:, :4] = scale_coords(img_size, out[:, :4], ori_img_size, ratio_pad=None).round()
 
-    frame_resized = torch.from_numpy(frame_resized)
+    # out: tlbr, conf, cls
 
-    return frame_resized
+    return out
 
+def preprocess_v7(ori_img, model_size, model_stride):
+    """ simple preprocess for a single image
+
+    """
+    img_resized = _letterbox(ori_img, new_shape=model_size, stride=model_stride)[0]
+
+    img_resized = img_resized[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB
+    img_resized = np.ascontiguousarray(img_resized)
+
+    img_resized = torch.from_numpy(img_resized).float()
+    img_resized /= 255.0
+
+    return img_resized, ori_img
+
+def _letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+    # Resize and pad image while meeting stride-multiple constraints
+    shape = img.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        new_shape = (new_shape, new_shape)
+
+    # Scale ratio (new / old)
+    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+    if not scaleup:  # only scale down, do not scale up (for better test mAP)
+        r = min(r, 1.0)
+
+    # Compute padding
+    ratio = r, r  # width, height ratios
+    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+    if auto:  # minimum rectangle
+        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
+    elif scaleFill:  # stretch
+        dw, dh = 0.0, 0.0
+        new_unpad = (new_shape[1], new_shape[0])
+        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
+
+    dw /= 2  # divide padding into 2 sides
+    dh /= 2
+
+    if shape[::-1] != new_unpad:  # resize
+        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+    return img, ratio, (dw, dh)
 
 def save_results(obj_name, results, data_type='default'):
     """
@@ -273,7 +317,8 @@ def save_videos(obj_name):
         obj_name = [obj_name]
 
     for seq in obj_name:
-        images_path = os.path.join(SAVE_FOLDER, 'reuslt_images', seq)
+        if 'mp4' in seq: seq = seq[:-4]
+        images_path = os.path.join(SAVE_FOLDER, 'result_images', seq)
         images_name = sorted(os.listdir(images_path))
 
         to_video_path = os.path.join(images_path, '../', seq + '.mp4')
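
The new guard strips the '.mp4' suffix so the frames under result_images/<name> and the assembled video share a bare sequence name. Note that 'mp4' in seq is a substring test; os.path.splitext would be the stricter spelling (an alternative for illustration, not what the commit does):

    import os

    seq, _ = os.path.splitext('demo.mp4')  # -> 'demo'; leaves extension-free names intact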
@@ -303,12 +348,12 @@ def get_color(idx):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--obj', type=str, default='M1305.mp4', help='video NAME or images FOLDER NAME')
+    parser.add_argument('--obj', type=str, default='demo.mp4', help='video NAME or images FOLDER NAME')
 
     parser.add_argument('--save_txt', type=bool, default=False, help='whether save txt')
 
     parser.add_argument('--tracker', type=str, default='sort', help='sort, deepsort, etc')
-    parser.add_argument('--model_path', type=str, default='./weights/best.pt', help='model path')
+    parser.add_argument('--model_path', type=str, default='./weights/yolov7_UAVDT_35epochs_20230507.pt', help='model path')
     parser.add_argument('--trace', type=bool, default=False, help='traced model of YOLO v7')
 
     parser.add_argument('--img_size', type=int, default=1280, help='[train, test] image sizes')
@@ -319,7 +364,7 @@ def get_color(idx):
     parser.add_argument('--dhn_path', type=str, default='./weights/DHN.pth', help='path of DHN path for DeepMOT')
 
     # threshs
-    parser.add_argument('--conf_thresh', type=float, default=0.5, help='filter tracks')
+    parser.add_argument('--conf_thresh', type=float, default=0.1, help='filter tracks')
     parser.add_argument('--nms_thresh', type=float, default=0.7, help='thresh for NMS')
    parser.add_argument('--iou_thresh', type=float, default=0.5, help='IOU thresh to filter tracks')
 
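
Lowering --conf_thresh to 0.1 works together with the very permissive conf_thres=0.01 inside post_process_v7: NMS keeps nearly every candidate, and the real score filtering happens at the tracker. Keeping low-score boxes around is deliberate in ByteTrack-style trackers (BoT-SORT included), which associate high- and low-score detections in two passes. A minimal sketch of that split, with hypothetical names:

    def split_by_score(dets, high_thresh=0.1, low_thresh=0.01):
        # dets: (N, 6) array of [x1, y1, x2, y2, conf, cls] after NMS.
        # High-score boxes drive the first association; low-score boxes can
        # rescue occluded tracks in a second pass (the ByteTrack idea).
        high = dets[dets[:, 4] >= high_thresh]
        low = dets[(dets[:, 4] >= low_thresh) & (dets[:, 4] < high_thresh)]
        return high, low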

tracker/track_yolov5.py

Lines changed: 2 additions & 2 deletions
@@ -167,7 +167,7 @@ def main(opts, cfgs):
         timer.toc()  # end timing this image
 
         if opts.save_images:
-            plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls], save_dir=os.path.join(DATASET_ROOT, 'reuslt_images', seq))
+            plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls], save_dir=os.path.join(DATASET_ROOT, 'result_images', seq))
 
         frame_id += 1
 
@@ -297,7 +297,7 @@ def save_videos(seq_names):
         seq_names = [seq_names]
 
     for seq in seq_names:
-        images_path = os.path.join(DATASET_ROOT, 'reuslt_images', seq)
+        images_path = os.path.join(DATASET_ROOT, 'result_images', seq)
         images_name = sorted(os.listdir(images_path))
 
         to_video_path = os.path.join(images_path, '../', seq + '.mp4')
