
Commit e9e4aec

Author: yixu.cui
Commit message: merge origin/master into here
Parents: f37b452 + 7eb4f8a

File tree: 4 files changed, +724 −6 lines


README.md

Lines changed: 15 additions & 0 deletions
@@ -2,6 +2,16 @@

## 0. Update Log

**2023.2.28** Optimized `track_demo.py` to reduce memory usage.

**2023.2.24** Added **inference on a single video or image folder** as well as **YOLO v8** inference support; the corresponding code is `tracker/track_demo.py` and `tracker/track_yolov8.py`. Inference on a single video or image folder requires no dataset or ground truth and provides no evaluation metrics; just specify `obj` on the command line, for example:

```shell
python tracker/track_demo.py --obj demo.mp4
```

The YOLO v8 code takes exactly the same arguments as before. For installing YOLO v8 and the training steps, see [YOLO v8](https://github.com/ultralytics/ultralytics).
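An image folder should go through the same flag; a minimal sketch, assuming `demo_images` is a hypothetical directory of frame images (the script treats any `--obj` without `mp4` in its name as a folder):

```shell
python tracker/track_demo.py --obj demo_images
```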
**2023.2.11** Fixed the TrackEval path error; see [issue35](https://github.com/JackWoo0831/Yolov7-tracker/issues/35).

**2023.2.10** Revised the [DeepSORT](https://github.com/JackWoo0831/Yolov7-tracker/blob/master/tracker/deepsort.py) code and related parts to follow the **cascade matching and cosine distance computation** principles of the original DeepSORT paper, and fixed the unexplained drifting of tracking boxes in the previous DeepSORT code.
@@ -207,6 +217,11 @@ python tracker/track.py --dataset visdrone --data_format origin --tracker c_biou
python tracker/track.py --dataset mot17 --data_format yolo --tracker ${TRACKER} --model_path ${MODEL_PATH}
```

***Inference on a single video or image sequence***:

```shell
python tracker/track_demo.py --obj demo.mp4
```
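The demo entry point also accepts the usual tracker arguments from `tracker/track_demo.py`'s parser; a hedged example, where the weight paths are just the script's defaults:

```shell
python tracker/track_demo.py --obj demo.mp4 --tracker deepsort --model_path ./weights/best.pt --reid_model_path ./weights/ckpt.t7
```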

> For the download address of the OSNet weights used in StrongSORT, see https://github.com/mikel-brostrom/Yolov5_StrongSORT_OSNet/blob/master/strong_sort/deep/reid_model_factory.py

tracker/track_demo.py

Lines changed: 330 additions & 0 deletions
@@ -0,0 +1,330 @@
"""
Only track a video or image seqs, without evaluate
"""
import numpy as np
import torch
import cv2
from PIL import Image
import tqdm

import argparse
import os
from time import gmtime, strftime
from timer import Timer
import yaml

from basetrack import BaseTracker  # for framework
from deepsort import DeepSORT
from bytetrack import ByteTrack
from deepmot import DeepMOT
from botsort import BoTSORT
from uavmot import UAVMOT
from strongsort import StrongSORT
from c_biou_tracker import C_BIoUTracker

try:  # import packages outside the tracker folder, for yolo v7
    import sys
    sys.path.append(os.getcwd())

    from models.experimental import attempt_load
    from evaluate import evaluate
    from utils.torch_utils import select_device, time_synchronized, TracedModel
    print('Note: running yolo v7 detector')

except ImportError:
    pass

SAVE_FOLDER = 'demo_result'  # NOTE: set your save path here
CATEGORY_DICT = {0: 'car'}   # NOTE: set the categories in your videos here,
                             # format: class_id (start from 0): class_name

timer = Timer()
seq_fps = []  # list to store time used for every seq

def main(opts):
    TRACKER_DICT = {
        'sort': BaseTracker,
        'deepsort': DeepSORT,
        'bytetrack': ByteTrack,
        'deepmot': DeepMOT,
        'botsort': BoTSORT,
        'uavmot': UAVMOT,
        'strongsort': StrongSORT,
        'c_biou': C_BIoUTracker,
    }  # dict for trackers, key: str, value: class(BaseTracker)

    # NOTE: ATTENTION: make kalman and tracker compatible
    if opts.tracker == 'botsort':
        opts.kalman_format = 'botsort'
    elif opts.tracker == 'strongsort':
        opts.kalman_format = 'strongsort'

    """
    1. load model
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ckpt = torch.load(opts.model_path, map_location=device)
    model = ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()  # for yolo v7

    if opts.trace:
        print(opts.img_size)
        model = TracedModel(model, device, opts.img_size)
    else:
        model.to(device)

    """
    2. load videos or images
    """
    obj_name = opts.obj
    # whether a video or an image seq, every frame is pulled through the same
    # get_next_frame interface below

    # check path
    assert os.path.exists(obj_name), 'the path does not exist! '
    if 'mp4' in opts.obj:  # if it is a video
        obj = cv2.VideoCapture(obj_name)
        get_next_frame = lambda _: obj.read()
    else:  # an image folder; join the folder path since os.listdir returns bare names
        frame_paths = sorted(os.path.join(obj_name, f) for f in os.listdir(obj_name))
        obj = my_queue(frame_paths)
        get_next_frame = lambda _: obj.pop_front()

    """
    3. start tracking
    """
    tracker = TRACKER_DICT[opts.tracker](opts, frame_rate=30, gamma=opts.gamma)  # instantiate tracker  TODO: finish init params
    results = []  # store current seq results
    frame_id = 0

    while True:
        print(f'----------processing frame {frame_id}----------')

        # end condition
        is_valid, img0 = get_next_frame(None)  # img0: (H, W, C)

        if not is_valid:
            break  # end of reading

        img = resize_a_frame(img0, [opts.img_size, opts.img_size])

        timer.tic()  # start timing this img
        img = img.unsqueeze(0)  # (C, H, W) -> (bs == 1, C, H, W)
        out = model(img.to(device))  # model forward
        out = out[0]  # NOTE: for yolo v7

        if len(out.shape) == 3:  # case (bs, num_obj, ...)
            out = out.squeeze(0)  # NOTE: assumes batch size == 1
        # remove some low conf detections
        out = out[out[:, 4] > 0.001]

        # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
        cls_conf, cls_idx = torch.max(out[:, 5:], dim=1)
        # out[:, 4] *= cls_conf  # fuse object and cls conf
        out[:, 5] = cls_idx
        out = out[:, :6]

        current_tracks = tracker.update(out, img0)  # List[class(STracks)]

        # save results
        cur_tlwh, cur_id, cur_cls = [], [], []
        for trk in current_tracks:
            bbox = trk.tlwh
            id = trk.track_id
            cls = trk.cls

            # filter low area bbox
            if bbox[2] * bbox[3] > opts.min_area:
                cur_tlwh.append(bbox)
                cur_id.append(id)
                cur_cls.append(cls)

        results.append((frame_id + 1, cur_id, cur_tlwh, cur_cls))
        timer.toc()  # end timing this image

        plot_img(img0, frame_id, [cur_tlwh, cur_id, cur_cls],
                 save_dir=os.path.join(SAVE_FOLDER, 'result_images', obj_name))

        frame_id += 1

    seq_fps.append(frame_id / timer.total_time)  # cal fps for current seq
    timer.clear()  # clear for next seq

    # save txt results
    if opts.save_txt:
        save_results(obj_name, results)

    # finally, save videos
    save_videos(obj_name)
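
# NOTE on the detection post-processing in main(): yolo v7 emits rows of
# [xc, yc, w, h, obj_conf, cls0_conf, ..., clsn_conf]; the loop above keeps obj_conf
# in column 4 and overwrites column 5 with the argmax class index, so the tracker's
# update() receives detections shaped [xc, yc, w, h, conf, cls].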

class my_queue:
    """
    implement a queue for image seq reading
    """
    def __init__(self, arr: list) -> None:
        self.arr = arr
        self.start_idx = 0

    def push_back(self, item):
        self.arr.append(item)

    def pop_front(self):
        # returns (is_valid, frame); emptiness is checked before reading,
        # so the last frame of a seq is not dropped
        if self.is_empty():
            return False, None
        ret = cv2.imread(self.arr[self.start_idx])
        self.start_idx += 1
        return True, ret

    def is_empty(self):
        return self.start_idx == len(self.arr)


def resize_a_frame(frame, target_size):
    """
    resize a frame to target size

    frame: np.ndarray, shape (H, W, C)
    target_size: List[int, int] | Tuple[int, int]
    """
    # resize to the input size of the YOLO net
    frame_resized = cv2.resize(frame, (target_size[0], target_size[1]))  # (H', W', C)
    # convert BGR to RGB and to (C, H, W)
    frame_resized = frame_resized[:, :, ::-1].transpose(2, 0, 1)

    frame_resized = np.ascontiguousarray(frame_resized, dtype=np.float32)
    frame_resized /= 255.0

    frame_resized = torch.from_numpy(frame_resized)

    return frame_resized
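
# NOTE: resize_a_frame squashes the frame straight to img_size x img_size rather than
# letterboxing, so the aspect ratio the detector sees differs from the original frame.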

def save_results(obj_name, results, data_type='default'):
    """
    write results to txt file

    obj_name: str, name of the tracked video or image folder
    results: list, row format: frame id, target id, box coordinate, class (optional)
    data_type: str, write data format
    """
    assert len(results)
    if not data_type == 'default':
        raise NotImplementedError  # TODO

    with open(os.path.join(SAVE_FOLDER, obj_name + '.txt'), 'w') as f:
        for frame_id, target_ids, tlwhs, clses in results:
            if data_type == 'default':
                for id, tlwh, cls in zip(target_ids, tlwhs, clses):
                    f.write(f'{frame_id},{id},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{int(cls)}\n')
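
# A line written by save_results looks like:
#   1,3,25.00,40.00,100.00,200.00,0
# i.e. frame 1, track id 3, tlwh box (25, 40, 100, 200), class 0.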

def plot_img(img, frame_id, results, save_dir):
    """
    plot images with bboxes of a seq

    img: np.ndarray: (H, W, C)
    frame_id: int
    results: [tlwhs, ids, clses]
    save_dir: str
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    img_ = np.ascontiguousarray(np.copy(img))

    tlwhs, ids, clses = results[0], results[1], results[2]
    for tlwh, id, cls in zip(tlwhs, ids, clses):
        # convert tlwh to tlbr
        tlbr = tuple([int(tlwh[0]), int(tlwh[1]), int(tlwh[0] + tlwh[2]), int(tlwh[1] + tlwh[3])])
        # draw a rect
        cv2.rectangle(img_, tlbr[:2], tlbr[2:], get_color(id), thickness=3)
        # note the id and cls
        text = f'{CATEGORY_DICT[cls]}-{id}'
        cv2.putText(img_, text, (tlbr[0], tlbr[1]), fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1,
                    color=(255, 164, 0), thickness=2)

    cv2.imwrite(filename=os.path.join(save_dir, f'{frame_id:05d}.jpg'), img=img_)

def save_videos(obj_name):
    """
    convert result imgs of each seq to a video

    obj_name: List[str] or str, seqs that will be converted
    """
    if not isinstance(obj_name, list):
        obj_name = [obj_name]

    for seq in obj_name:
        images_path = os.path.join(SAVE_FOLDER, 'result_images', seq)
        images_name = sorted(os.listdir(images_path))

        to_video_path = os.path.join(images_path, '../', seq + '.mp4')
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")

        img0 = Image.open(os.path.join(images_path, images_name[0]))
        vw = cv2.VideoWriter(to_video_path, fourcc, 15, img0.size)

        for img in images_name:
            if img.endswith('.jpg'):
                frame = cv2.imread(os.path.join(images_path, img))
                vw.write(frame)
        vw.release()  # flush and finalize the video file

    print('Save videos Done!!')
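
# NOTE: the output frame rate is hardcoded to 15 fps here, independent of the
# source video's real frame rate (and of the frame_rate=30 passed to the tracker).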

def get_color(idx):
    """
    aux func for plot_img: get a unique color for each track id
    """
    idx = idx * 3
    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)

    return color
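
# e.g. get_color(1) -> (111, 51, 87): the hash is deterministic, so a track keeps
# the same BGR color in every frame it appears in.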

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--obj', type=str, default='M1305.mp4', help='video NAME or images FOLDER NAME')

    parser.add_argument('--save_txt', type=bool, default=False, help='whether to save txt results (NOTE: type=bool treats any non-empty string as True)')

    parser.add_argument('--tracker', type=str, default='sort', help='sort, deepsort, etc')
    parser.add_argument('--model_path', type=str, default='./weights/best.pt', help='model path')
    parser.add_argument('--trace', type=bool, default=False, help='use a traced YOLO v7 model')

    parser.add_argument('--img_size', type=int, default=1280, help='[train, test] image sizes')

    """For tracker"""
    # model paths
    parser.add_argument('--reid_model_path', type=str, default='./weights/ckpt.t7', help='path of the ReID model')
    parser.add_argument('--dhn_path', type=str, default='./weights/DHN.pth', help='path of the DHN model for DeepMOT')

    # thresholds
    parser.add_argument('--conf_thresh', type=float, default=0.5, help='confidence thresh to filter tracks')
    parser.add_argument('--nms_thresh', type=float, default=0.7, help='IoU thresh for NMS')
    parser.add_argument('--iou_thresh', type=float, default=0.5, help='IoU thresh to filter tracks')

    # other options
    parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')
    parser.add_argument('--gamma', type=float, default=0.1, help='param to control fusing motion and appearance dist')
    parser.add_argument('--kalman_format', type=str, default='default', help='which Kalman filter to use: default, naive, strongsort or botsort-like')
    parser.add_argument('--min_area', type=float, default=150, help='area threshold to filter small bboxes')

    opts = parser.parse_args()

    if not os.path.exists(SAVE_FOLDER):  # demo results go to their own folder
        os.makedirs(SAVE_FOLDER)
        os.makedirs(os.path.join(SAVE_FOLDER, 'result_images'))
    main(opts)
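For downstream use, the txt rows written by `save_results` above follow the `frame,id,x,y,w,h,cls` format; a minimal sketch for reading them back (the helper `load_results` is mine, not part of the repository):

```python
import csv
from collections import defaultdict

def load_results(txt_path):
    """Read 'frame,id,x,y,w,h,cls' rows into {frame_id: [(track_id, tlwh, cls), ...]}."""
    per_frame = defaultdict(list)
    with open(txt_path) as f:
        for row in csv.reader(f):
            frame_id, track_id = int(row[0]), int(row[1])
            tlwh = tuple(float(v) for v in row[2:6])  # top-left x, y, width, height
            per_frame[frame_id].append((track_id, tlwh, int(row[6])))
    return per_frame

# e.g. per_frame = load_results('demo_result/demo.mp4.txt')
```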
