Skip to content

Commit de7a243

Browse files
committed
modify track_demo.py
1 parent bebcbac commit de7a243

File tree

2 files changed

+67
-43
lines changed

2 files changed

+67
-43
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## 0. 更新记录
44

5+
**2023.2.28**优化了`track_demo.py`, 减少了内存占用.
6+
57
**2023.2.24**加入了**推理单个视频或图片文件夹**以及**YOLO v8**的推理功能, 对应的代码为`tracker/track_demo.py`和`tracker/track_yolov8.py`. 推理单个视频或图片文件夹不需要指定数据集与真值, 也没有评测指标的功能, 只需要在命令行中指定`obj`即可, 例如:
68

79
```shell

tracker/track_demo.py

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
except:
3636
pass
3737

38-
SAVE_FOLDER = 'demo_result'
39-
CATEGORY_DICT = {0: 'car'}
38+
SAVE_FOLDER = 'demo_result' # NOTE: set your save path here
39+
CATEGORY_DICT = {0: 'car'} # NOTE: set the categories in your videos here,
40+
# format: class_id(start from 0): class_name
4041

4142
timer = Timer()
4243
seq_fps = [] # list to store time used for every seq
@@ -81,42 +82,16 @@ def main(opts):
8182
resized_images_queue = [] # List[torch.Tensor] store resized images
8283
images_queue = [] # List[torch.Tensor] store origin images
8384

84-
"""
85-
func: resize a frame to target size
86-
"""
87-
def resize_a_frame(frame, target_size):
88-
# resize to input to the YOLO net
89-
frame_resized = cv2.resize(frame, (target_size[0], target_size[1])) # (H', W', C)
90-
# convert BGR to RGB and to (C, H, W)
91-
frame_resized = frame_resized[:, :, ::-1].transpose(2, 0, 1)
92-
93-
frame_resized = np.ascontiguousarray(frame_resized, dtype=np.float32)
94-
frame_resized /= 255.0
95-
96-
frame_resized = torch.from_numpy(frame_resized)
97-
98-
return frame_resized
99-
85+
# check path
86+
assert os.path.exists(obj_name), 'the path does not exist! '
87+
obj, get_next_frame = None, None # init obj
10088
if 'mp4' in opts.obj: # if it is a video
101-
assert os.path.exists(obj_name), 'the path does not exist! '
102-
103-
video = cv2.VideoCapture(obj_name)
104-
while True:
105-
result, frame = video.read() # frame: np.ndarray, shape (H, W, C)
106-
if not result: break # end to the video
107-
frame_resized = resize_a_frame(frame, [opts.img_size, opts.img_size])
108-
109-
resized_images_queue.append(frame_resized)
110-
images_queue.append(frame)
89+
obj = cv2.VideoCapture(obj_name)
90+
get_next_frame = lambda _ : obj.read()
91+
11192
else:
112-
assert os.path.exists(obj_name), 'the path does not exist! '
113-
frames = os.listdir(obj_name)
114-
for item in frames:
115-
frame = cv2.imread(item)
116-
frame_resized = resize_a_frame(frame, [opts.img_size, opts.img_size])
117-
118-
resized_images_queue.append(frame_resized)
119-
images_queue.append(frame)
93+
obj = my_queue(os.listdir(obj_name))
94+
get_next_frame = lambda _ : obj.pop_front()
12095

12196

12297
"""
@@ -125,9 +100,18 @@ def resize_a_frame(frame, target_size):
125100
tracker = TRACKER_DICT[opts.tracker](opts, frame_rate=30, gamma=opts.gamma) # instantiate tracker TODO: finish init params
126101
results = [] # store current seq results
127102
frame_id = 0
128-
pbar = tqdm.tqdm(desc="demo--", ncols=80)
129-
for i, (img, img0) in enumerate(zip(resized_images_queue, images_queue)):
130-
pbar.update()
103+
104+
while True:
105+
print(f'----------processing frame {frame_id}----------')
106+
107+
# end condition
108+
is_valid, img0 = get_next_frame(None) # img0: (H, W, C)
109+
110+
if not is_valid:
111+
break # end of reading
112+
113+
img = resize_a_frame(img0, [opts.img_size, opts.img_size])
114+
131115
timer.tic() # start timing this img
132116
img = img.unsqueeze(0) # (C, H, W) -> (bs == 1, C, H, W)
133117
out = model(img.to(device)) # model forward
@@ -171,17 +155,55 @@ def resize_a_frame(frame, target_size):
171155

172156
frame_id += 1
173157

174-
seq_fps.append(i / timer.total_time) # cal fps for current seq
158+
seq_fps.append(frame_id / timer.total_time) # cal fps for current seq
175159
timer.clear() # clear for next seq
176-
pbar.close()
177160
# thirdly, save results
178161
# every time assign a different name
179162
if opts.save_txt: save_results(obj_name, '', results)
180163

181164
## finally, save videos
182165
save_videos(obj_name)
183166

167+
168+
class my_queue:
    """
    A simple FIFO queue over a list of image file paths.

    It mimics the ``(ok, frame)`` read interface of ``cv2.VideoCapture``
    so that video input and image-folder input can share one
    frame-fetch loop in the demo.
    """

    def __init__(self, arr: list) -> None:
        # NOTE(review): callers pass os.listdir() output here, whose order
        # is arbitrary — the caller should sort the list (and join it with
        # the folder path) for a stable, readable frame sequence.
        self.arr = arr
        self.start_idx = 0  # index of the next item to be consumed

    def push_back(self, item):
        """Append one image path to the tail of the queue."""
        self.arr.append(item)

    def pop_front(self):
        """
        Read and return the next image.

        Returns:
            tuple: ``(True, frame)`` while items remain; ``(False, None)``
            once the queue is exhausted — matching the semantics of
            ``cv2.VideoCapture.read()`` expected by the caller.
        """
        # Bug fix: the original checked emptiness *after* consuming an item,
        # so the last frame was reported as invalid (and silently dropped by
        # the caller's `if not is_valid: break`), and a call on an already
        # empty queue raised IndexError. Guard before reading instead.
        if self.is_empty():
            return False, None
        ret = cv2.imread(self.arr[self.start_idx])
        self.start_idx += 1
        return True, ret

    def is_empty(self):
        """Return True when every queued item has been consumed."""
        return self.start_idx == len(self.arr)
186+
187+
188+
def resize_a_frame(frame, target_size):
    """
    Resize a BGR frame and convert it into a normalized CHW float tensor
    suitable as YOLO network input.

    frame: np.ndarray, shape (H, W, C), BGR channel order (OpenCV default)
    target_size: List[int, int] | Tuple[int, int], passed to cv2.resize
    """
    # resize to the network input resolution
    resized = cv2.resize(frame, (target_size[0], target_size[1]))  # (H', W', C)
    # BGR -> RGB, then HWC -> CHW
    chw_rgb = resized[:, :, ::-1].transpose(2, 0, 1)
    # contiguous float32 copy, scaled from [0, 255] to [0, 1]
    normalized = np.ascontiguousarray(chw_rgb, dtype=np.float32)
    normalized /= 255.0
    return torch.from_numpy(normalized)
206+
185207

186208
def save_results(obj_name, results, data_type='default'):
187209
"""
@@ -227,7 +249,7 @@ def plot_img(img, frame_id, results, save_dir):
227249
# draw a rect
228250
cv2.rectangle(img_, tlbr[:2], tlbr[2:], get_color(id), thickness=3, )
229251
# note the id and cls
230-
text = f'{CATEGORY_DICT[cls]}-{id}'
252+
text = f'car-{id}'
231253
cv2.putText(img_, text, (tlbr[0], tlbr[1]), fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1,
232254
color=(255, 164, 0), thickness=2)
233255

@@ -282,7 +304,7 @@ def get_color(idx):
282304
parser.add_argument('--model_path', type=str, default='./weights/best.pt', help='model path')
283305
parser.add_argument('--trace', type=bool, default=False, help='traced model of YOLO v7')
284306

285-
parser.add_argument('--img_size', nargs='+', type=int, default=1280, help='[train, test] image sizes')
307+
parser.add_argument('--img_size', type=int, default=1280, help='[train, test] image sizes')
286308

287309
"""For tracker"""
288310
# model path

0 commit comments

Comments
 (0)