forked from adipandas/multi-object-tracker
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path detector_TF_SSDMobileNetV2.py
More file actions
69 lines (52 loc) · 2.03 KB
/
detector_TF_SSDMobileNetV2.py
File metadata and controls
69 lines (52 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse

import cv2 as cv

from motrackers.detectors import TF_SSDMobileNetV2
def main(video_path, model):
    """Run ``model`` on every frame of a video and display the annotated frames.

    Frames are read one at a time from ``video_path`` via ``cv.VideoCapture``,
    passed to ``model.detect`` and then to ``model.draw_bboxes``; the returned
    image is shown in a window titled ``"image"``. Processing stops when a
    frame can no longer be read or when the user presses ``q``.

    Args:
        video_path: Video source handed to ``cv.VideoCapture`` (a file path in
            the default command-line configuration).
        model: Detector object exposing ``detect(image)``, which returns
            ``(bboxes, confidences, class_ids)``, and
            ``draw_bboxes(image, bboxes, confidences, class_ids)``, which
            returns the image to display.
    """
    cap = cv.VideoCapture(video_path)
    frames_read = 0
    # Release the capture and close the display window even if detection,
    # drawing or display raises part-way through the video.
    try:
        while True:
            ok, image = cap.read()
            if not ok:
                # A failed read before any frame means the source could not be
                # opened or decoded; after at least one frame it normally just
                # means the stream has reached its end.
                if frames_read == 0:
                    print("Cannot read the video feed.")
                else:
                    print("End of video feed.")
                break
            frames_read += 1

            bboxes, confidences, class_ids = model.detect(image)
            updated_image = model.draw_bboxes(image, bboxes, confidences, class_ids)
            cv.imshow("image", updated_image)

            # Compare only the low byte of the key code with ord('q'), in case
            # the GUI backend sets higher bits in the value from waitKey.
            if cv.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv.destroyAllWindows()
def str_to_bool(value):
    """Interpret a command-line string as a boolean.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``; this converter reads the text instead.

    Args:
        value: Text such as ``"true"``, ``"False"``, ``"1"``, ``"no"`` (case
            and surrounding whitespace are ignored). A ``bool`` is returned
            unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def build_arg_parser():
    """Build the command-line parser for the TF SSD-MobileNetV2 detection demo.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace has the fields
        ``video``, ``weights``, ``config``, ``labels``, ``gpu``,
        ``confidence_threshold`` and ``nms_threshold``.
    """
    parser = argparse.ArgumentParser(
        description='Object detections in input video using TensorFlow model of MobileNetSSD.')
    parser.add_argument(
        '--video', '-v', type=str, default="./../video_data/cars.mp4", help='Input video path.')
    parser.add_argument(
        '--weights', '-w', type=str,
        default="./../pretrained_models/tensorflow_weights/ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb",
        help='path to weights file of tf-MobileNetSSD (`.pb` file).'
    )
    parser.add_argument(
        '--config', '-c', type=str,
        default="./../pretrained_models/tensorflow_weights/ssd_mobilenet_v2_coco_2018_03_29.pbtxt",
        help='path to config file of tf-MobileNetSSD (`.pbtxt` file).'
    )
    parser.add_argument(
        '--labels', '-l', type=str,
        default="./../pretrained_models/tensorflow_weights/ssd_mobilenet_v2_coco_names.json",
        help='path to labels file of coco dataset (`.json` file).'
    )
    # `type=bool` would make `--gpu False` evaluate to True. Parse the text
    # instead; a bare `--gpu` (no value) means True via `const`.
    parser.add_argument(
        '--gpu', type=str_to_bool, nargs='?', const=True, default=False,
        help='Flag to use gpu to run the deep learning model. Default is `False`'
    )
    parser.add_argument(
        '--confidence-threshold', type=float, default=0.5,
        help='Minimum detection confidence passed to the detector. Default is `0.5`'
    )
    parser.add_argument(
        '--nms-threshold', type=float, default=0.2,
        help='Non-maximum-suppression threshold passed to the detector. Default is `0.2`'
    )
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    model = TF_SSDMobileNetV2(
        weights_path=args.weights,
        configfile_path=args.config,
        labels_path=args.labels,
        confidence_threshold=args.confidence_threshold,
        nms_threshold=args.nms_threshold,
        draw_bboxes=True,
        use_gpu=args.gpu
    )
    main(args.video, model)