1+ """
2+ Only track a video or image seqs, without evaluate
3+ """
4+
5+ import numpy as np
6+ import torch
7+ import cv2
8+ from PIL import Image
9+ import tqdm
10+
11+ import argparse
12+ import os
13+ from time import gmtime , strftime
14+ from timer import Timer
15+ import yaml
16+
17+ from basetrack import BaseTracker # for framework
18+ from deepsort import DeepSORT
19+ from bytetrack import ByteTrack
20+ from deepmot import DeepMOT
21+ from botsort import BoTSORT
22+ from uavmot import UAVMOT
23+ from strongsort import StrongSORT
24+ from c_biou_tracker import C_BIoUTracker
25+
26+ try : # import package that outside the tracker folder For yolo v7
27+ import sys
28+ sys .path .append (os .getcwd ())
29+
30+ from models .experimental import attempt_load
31+ from evaluate import evaluate
32+ from utils .torch_utils import select_device , time_synchronized , TracedModel
33+ print ('Note: running yolo v7 detector' )
34+
35+ except :
36+ pass
37+
38+ SAVE_FOLDER = 'demo_result' # NOTE: set your save path here
39+ CATEGORY_DICT = {0 : 'car' } # NOTE: set the categories in your videos here,
40+ # format: class_id(start from 0): class_name
41+
42+ timer = Timer ()
43+ seq_fps = [] # list to store time used for every seq
44+
45+ def main (opts ):
46+ TRACKER_DICT = {
47+ 'sort' : BaseTracker ,
48+ 'deepsort' : DeepSORT ,
49+ 'bytetrack' : ByteTrack ,
50+ 'deepmot' : DeepMOT ,
51+ 'botsort' : BoTSORT ,
52+ 'uavmot' : UAVMOT ,
53+ 'strongsort' : StrongSORT ,
54+ 'c_biou' : C_BIoUTracker ,
55+ } # dict for trackers, key: str, value: class(BaseTracker)
56+
57+ # NOTE: ATTENTION: make kalman and tracker compatible
58+ if opts .tracker == 'botsort' :
59+ opts .kalman_format = 'botsort'
60+ elif opts .tracker == 'strongsort' :
61+ opts .kalman_format = 'strongsort'
62+
63+ """
64+ 1. load model
65+ """
66+ device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
67+ ckpt = torch .load (opts .model_path , map_location = device )
68+ model = ckpt ['ema' if ckpt .get ('ema' ) else 'model' ].float ().fuse ().eval () # for yolo v7
69+
70+ if opts .trace :
71+ print (opts .img_size )
72+ model = TracedModel (model , device , opts .img_size )
73+ else :
74+ model .to (device )
75+
76+ """
77+ 2. load videos or images
78+ """
79+ obj_name = opts .obj
80+ # if read video, then put every frame into a queue
81+ # if read image seqs, the same as video
82+ resized_images_queue = [] # List[torch.Tensor] store resized images
83+ images_queue = [] # List[torch.Tensor] store origin images
84+
85+ # check path
86+ assert os .path .exists (obj_name ), 'the path does not exist! '
87+ obj , get_next_frame = None , None # init obj
88+ if 'mp4' in opts .obj : # if it is a video
89+ obj = cv2 .VideoCapture (obj_name )
90+ get_next_frame = lambda _ : obj .read ()
91+
92+ else :
93+ obj = my_queue (os .listdir (obj_name ))
94+ get_next_frame = lambda _ : obj .pop_front ()
95+
96+
97+ """
98+ 3. start tracking
99+ """
100+ tracker = TRACKER_DICT [opts .tracker ](opts , frame_rate = 30 , gamma = opts .gamma ) # instantiate tracker TODO: finish init params
101+ results = [] # store current seq results
102+ frame_id = 0
103+
104+ while True :
105+ print (f'----------processing frame { frame_id } ----------' )
106+
107+ # end condition
108+ is_valid , img0 = get_next_frame (None ) # img0: (H, W, C)
109+
110+ if not is_valid :
111+ break # end of reading
112+
113+ img = resize_a_frame (img0 , [opts .img_size , opts .img_size ])
114+
115+ timer .tic () # start timing this img
116+ img = img .unsqueeze (0 ) # (C, H, W) -> (bs == 1, C, H, W)
117+ out = model (img .to (device )) # model forward
118+ out = out [0 ] # NOTE: for yolo v7
119+
120+ if len (out .shape ) == 3 : # case (bs, num_obj, ...)
121+ # out = out.squeeze()
122+ # NOTE: assert batch size == 1
123+ out = out .squeeze (0 )
124+ # remove some low conf detections
125+ out = out [out [:, 4 ] > 0.001 ]
126+
127+
128+ # NOTE: yolo v7 origin out format: [xc, yc, w, h, conf, cls0_conf, cls1_conf, ..., clsn_conf]
129+ cls_conf , cls_idx = torch .max (out [:, 5 :], dim = 1 )
130+ # out[:, 4] *= cls_conf # fuse object and cls conf
131+ out [:, 5 ] = cls_idx
132+ out = out [:, :6 ]
133+
134+ current_tracks = tracker .update (out , img0 ) # List[class(STracks)]
135+
136+
137+ # save results
138+ cur_tlwh , cur_id , cur_cls = [], [], []
139+ for trk in current_tracks :
140+ bbox = trk .tlwh
141+ id = trk .track_id
142+ cls = trk .cls
143+
144+ # filter low area bbox
145+ if bbox [2 ] * bbox [3 ] > opts .min_area :
146+ cur_tlwh .append (bbox )
147+ cur_id .append (id )
148+ cur_cls .append (cls )
149+ # results.append((frame_id + 1, id, bbox, cls))
150+
151+ results .append ((frame_id + 1 , cur_id , cur_tlwh , cur_cls ))
152+ timer .toc () # end timing this image
153+
154+ plot_img (img0 , frame_id , [cur_tlwh , cur_id , cur_cls ], save_dir = os .path .join (SAVE_FOLDER , 'reuslt_images' , obj_name ))
155+
156+ frame_id += 1
157+
158+ seq_fps .append (frame_id / timer .total_time ) # cal fps for current seq
159+ timer .clear () # clear for next seq
160+ # thirdly, save results
161+ # every time assign a different name
162+ if opts .save_txt : save_results (obj_name , '' , results )
163+
164+ ## finally, save videos
165+ save_videos (obj_name )
166+
167+
168+ class my_queue :
169+ """
170+ implement a queue for image seq reading
171+ """
172+ def __init__ (self , arr : list ) -> None :
173+ self .arr = arr
174+ self .start_idx = 0
175+
176+ def push_back (self , item ):
177+ self .arr .append (item )
178+
179+ def pop_front (self ):
180+ ret = cv2 .imread (self .arr [self .start_idx ])
181+ self .start_idx += 1
182+ return not self .is_empty (), ret
183+
184+ def is_empty (self ):
185+ return self .start_idx == len (self .arr )
186+
187+
188+ def resize_a_frame (frame , target_size ):
189+ """
190+ resize a frame to target size
191+
192+ frame: np.ndarray, shape (H, W, C)
193+ target_size: List[int, int] | Tuple[int, int]
194+ """
195+ # resize to input to the YOLO net
196+ frame_resized = cv2 .resize (frame , (target_size [0 ], target_size [1 ])) # (H', W', C)
197+ # convert BGR to RGB and to (C, H, W)
198+ frame_resized = frame_resized [:, :, ::- 1 ].transpose (2 , 0 , 1 )
199+
200+ frame_resized = np .ascontiguousarray (frame_resized , dtype = np .float32 )
201+ frame_resized /= 255.0
202+
203+ frame_resized = torch .from_numpy (frame_resized )
204+
205+ return frame_resized
206+
207+
208+ def save_results (obj_name , results , data_type = 'default' ):
209+ """
210+ write results to txt file
211+
212+ results: list row format: frame id, target id, box coordinate, class(optional)
213+ to_file: file path(optional)
214+ data_type: write data format
215+ """
216+ assert len (results )
217+ if not data_type == 'default' :
218+ raise NotImplementedError # TODO
219+
220+ with open (os .path .join (SAVE_FOLDER , obj_name + '.txt' ), 'w' ) as f :
221+ for frame_id , target_ids , tlwhs , clses in results :
222+ if data_type == 'default' :
223+
224+ # f.write(f'{frame_id},{target_id},{tlwh[0]},{tlwh[1]},\
225+ # {tlwh[2]},{tlwh[3]},{cls}\n')
226+ for id , tlwh , cls in zip (target_ids , tlwhs , clses ):
227+ f .write (f'{ frame_id } ,{ id } ,{ tlwh [0 ]:.2f} ,{ tlwh [1 ]:.2f} ,{ tlwh [2 ]:.2f} ,{ tlwh [3 ]:.2f} ,{ int (cls )} \n ' )
228+ f .close ()
229+
230+ def plot_img (img , frame_id , results , save_dir ):
231+ """
232+ img: np.ndarray: (H, W, C)
233+ frame_id: int
234+ results: [tlwhs, ids, clses]
235+ save_dir: sr
236+
237+ plot images with bboxes of a seq
238+ """
239+ if not os .path .exists (save_dir ):
240+ os .makedirs (save_dir )
241+
242+ img_ = np .ascontiguousarray (np .copy (img ))
243+
244+ tlwhs , ids , clses = results [0 ], results [1 ], results [2 ]
245+ for tlwh , id , cls in zip (tlwhs , ids , clses ):
246+
247+ # convert tlwh to tlbr
248+ tlbr = tuple ([int (tlwh [0 ]), int (tlwh [1 ]), int (tlwh [0 ] + tlwh [2 ]), int (tlwh [1 ] + tlwh [3 ])])
249+ # draw a rect
250+ cv2 .rectangle (img_ , tlbr [:2 ], tlbr [2 :], get_color (id ), thickness = 3 , )
251+ # note the id and cls
252+ text = f'{ CATEGORY_DICT [cls ]} -{ id } '
253+ cv2 .putText (img_ , text , (tlbr [0 ], tlbr [1 ]), fontFace = cv2 .FONT_HERSHEY_PLAIN , fontScale = 1 ,
254+ color = (255 , 164 , 0 ), thickness = 2 )
255+
256+ cv2 .imwrite (filename = os .path .join (save_dir , f'{ frame_id :05d} .jpg' ), img = img_ )
257+
258+
259+ def save_videos (obj_name ):
260+ """
261+ convert imgs to a video
262+
263+ seq_names: List[str] or str, seqs that will be generated
264+ """
265+ if not isinstance (obj_name , list ):
266+ obj_name = [obj_name ]
267+
268+ for seq in obj_name :
269+ images_path = os .path .join (SAVE_FOLDER , 'reuslt_images' , seq )
270+ images_name = sorted (os .listdir (images_path ))
271+
272+ to_video_path = os .path .join (images_path , '../' , seq + '.mp4' )
273+ fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
274+
275+ img0 = Image .open (os .path .join (images_path , images_name [0 ]))
276+ vw = cv2 .VideoWriter (to_video_path , fourcc , 15 , img0 .size )
277+
278+ for img in images_name :
279+ if img .endswith ('.jpg' ):
280+ frame = cv2 .imread (os .path .join (images_path , img ))
281+ vw .write (frame )
282+
283+ print ('Save videos Done!!' )
284+
285+
286+ def get_color (idx ):
287+ """
288+ aux func for plot_seq
289+ get a unique color for each id
290+ """
291+ idx = idx * 3
292+ color = ((37 * idx ) % 255 , (17 * idx ) % 255 , (29 * idx ) % 255 )
293+
294+ return color
295+
296+ if __name__ == '__main__' :
297+ parser = argparse .ArgumentParser ()
298+
299+ parser .add_argument ('--obj' , type = str , default = 'M1305.mp4' , help = 'video NAME or images FOLDER NAME' )
300+
301+ parser .add_argument ('--save_txt' , type = bool , default = False , help = 'whether save txt' )
302+
303+ parser .add_argument ('--tracker' , type = str , default = 'sort' , help = 'sort, deepsort, etc' )
304+ parser .add_argument ('--model_path' , type = str , default = './weights/best.pt' , help = 'model path' )
305+ parser .add_argument ('--trace' , type = bool , default = False , help = 'traced model of YOLO v7' )
306+
307+ parser .add_argument ('--img_size' , type = int , default = 1280 , help = '[train, test] image sizes' )
308+
309+ """For tracker"""
310+ # model path
311+ parser .add_argument ('--reid_model_path' , type = str , default = './weights/ckpt.t7' , help = 'path for reid model path' )
312+ parser .add_argument ('--dhn_path' , type = str , default = './weights/DHN.pth' , help = 'path of DHN path for DeepMOT' )
313+
314+ # threshs
315+ parser .add_argument ('--conf_thresh' , type = float , default = 0.5 , help = 'filter tracks' )
316+ parser .add_argument ('--nms_thresh' , type = float , default = 0.7 , help = 'thresh for NMS' )
317+ parser .add_argument ('--iou_thresh' , type = float , default = 0.5 , help = 'IOU thresh to filter tracks' )
318+
319+ # other options
320+ parser .add_argument ('--track_buffer' , type = int , default = 30 , help = 'tracking buffer' )
321+ parser .add_argument ('--gamma' , type = float , default = 0.1 , help = 'param to control fusing motion and apperance dist' )
322+ parser .add_argument ('--kalman_format' , type = str , default = 'default' , help = 'use what kind of Kalman, default, naive, strongsort or bot-sort like' )
323+ parser .add_argument ('--min_area' , type = float , default = 150 , help = 'use to filter small bboxs' )
324+
325+ opts = parser .parse_args ()
326+
327+ if not os .path .exists (SAVE_FOLDER ): # demo save to a particular folder
328+ os .makedirs (SAVE_FOLDER )
329+ os .makedirs (os .path .join (SAVE_FOLDER , 'result_images' ))
330+ main (opts )
0 commit comments