
Commit 2253ec9

TensorFlow example with SSD-MobileNet
1 parent 743d368 commit 2253ec9

File tree

- README.md
- assets/sample-output-2.gif
- tensorflow_model_dir/get_ssd_model.sh
- tracking-tensorflow-ssd_mobilenet_v2_coco_2018_03_29.ipynb

4 files changed: +375 −4 lines changed

README.md

Lines changed: 27 additions & 4 deletions
@@ -1,9 +1,15 @@
-[output_video]: ./assets/sample-output.gif "Sample Output"
+[output_video_1]: ./assets/sample-output.gif "Sample Output with YOLO"
+[output_video_2]: ./assets/sample-output-2.gif "Sample Output with SSD"
 
-# multi-object-tracker
-object detection using deep learning and multi-object tracking
+# Multi-Object-Tracker
+Object detection using deep learning and multi-object tracking
+
+#### YOLO
+![Output Sample with YOLO][output_video_1]
+
+#### SSD
+![Output Sample with SSD][output_video_2]
 
-![Output Sample][output_video]
 
 ### Install OpenCV
 Pip install for OpenCV (version 3.4.3 or later) is available [here](https://pypi.org/project/opencv-python/) and can be done with the following command:
@@ -24,6 +30,23 @@ The model and the config files will be downloaded in `./yolo_dir`. These will be
 
 Example video used in above demo: https://flic.kr/p/L6qyxj
 
+### Run with TensorFlow SSD model
+
+1. Open the terminal
+2. Go to `tensorflow_model_dir`: `cd ./tensorflow_model_dir`
+3. Run: `sudo chmod +x ./get_ssd_model.sh`
+4. Run: `./get_ssd_model.sh`
+
+This will download the model and config files into `./tensorflow_model_dir`. They are used by `tracking-tensorflow-ssd_mobilenet_v2_coco_2018_03_29.ipynb`.
+
+**SSD-Mobilenet_v2_coco_2018_03_29** was used for this example.
+Other networks can be downloaded and run as well: see `tracking-tensorflow-ssd_mobilenet_v2_coco_2018_03_29.ipynb` for details.
+
+- The video input can be specified in the cell named `Initiate opencv video capture object` in the notebook.
+- To use the webcam as the source, set `video_src=0`; otherwise provide the path of the video file (example: `video_src="/path/of/videofile.mp4"`).
+
+Video used in SSD-Mobilenet multi-object detection and tracking: https://flic.kr/p/26WeEWy
+
 ### Run with Caffemodel
 - You have to use `tracking-caffe-model.ipynb`.
 - The model for use is provided in the folder named `caffemodel_dir`.
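Before opening the notebook, you can sanity-check that the downloaded SSD files load, with a few lines mirroring what the notebook itself does (a minimal sketch: run it from inside `tensorflow_model_dir`, and `test.jpg` is a placeholder for any image):

```python
import cv2 as cv

# Files produced by get_ssd_model.sh (paths relative to tensorflow_model_dir)
weights = "ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb"
config = "ssd_mobilenet_v2_coco_2018_03_29.pbtxt"
net = cv.dnn.readNetFromTensorflow(weights, config)

# Single forward pass; the SSD network expects 300x300 input
image = cv.imread("test.jpg")  # placeholder: any test image
blob = cv.dnn.blobFromImage(image, size=(300, 300), swapRB=True, crop=False)
net.setInput(blob)
detections = net.forward()
print(detections.shape)  # (1, 1, N, 7); each row: [batchID, classID, confidence, left, top, right, bottom]
```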

assets/sample-output-2.gif

36.9 MB
tensorflow_model_dir/get_ssd_model.sh

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
#!/bin/sh

# Get models from https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API#use-existing-config-file-for-your-model

wget -c http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz -O - | tar -xz
wget https://raw.githubusercontent.com/opencv/opencv_extra/master/testdata/dnn/ssd_mobilenet_v2_coco_2018_03_29.pbtxt

# Tensorflow object detection API: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API#use-existing-config-file-for-your-model
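After the script finishes, the two files the notebook's `model_info` dict points at should be present; a quick check (paths relative to `tensorflow_model_dir`):

```python
import os

# Expected outputs of get_ssd_model.sh, as referenced by the notebook
for path in ("ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb",
             "ssd_mobilenet_v2_coco_2018_03_29.pbtxt"):
    print(path, "OK" if os.path.exists(path) else "MISSING")
```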
tracking-tensorflow-ssd_mobilenet_v2_coco_2018_03_29.ipynb

Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2 as cv\n",
    "from scipy.spatial import distance\n",
    "import numpy as np\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Object Tracking Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tracker:\n",
    "    def __init__(self, maxLost=30):   # maxLost: max consecutive frames an object may go undetected before it is dropped\n",
    "        self.nextObjectID = 0          # ID to assign to the next new object\n",
    "        self.objects = OrderedDict()   # stores objectID : object location\n",
    "        self.lost = OrderedDict()      # stores objectID : count of consecutive frames the object was not detected\n",
    "\n",
    "        self.maxLost = maxLost\n",
    "\n",
    "    def addObject(self, new_object_location):\n",
    "        self.objects[self.nextObjectID] = new_object_location   # store location of the new object\n",
    "        self.lost[self.nextObjectID] = 0                        # initialize its lost-frame counter\n",
    "\n",
    "        self.nextObjectID += 1\n",
    "\n",
    "    def removeObject(self, objectID):   # discard tracker data once an object is lost\n",
    "        del self.objects[objectID]\n",
    "        del self.lost[objectID]\n",
    "\n",
    "    @staticmethod\n",
    "    def getLocation(bounding_box):   # centroid of a (left, top, right, bottom) box\n",
    "        xlt, ylt, xrb, yrb = bounding_box\n",
    "        return (int((xlt + xrb) / 2.0), int((ylt + yrb) / 2.0))\n",
    "\n",
    "    def update(self, detections):\n",
    "\n",
    "        if len(detections) == 0:   # no object detected in this frame\n",
    "            for objectID in list(self.lost.keys()):   # iterate over a copy: removeObject mutates the dict\n",
    "                self.lost[objectID] += 1\n",
    "                if self.lost[objectID] > self.maxLost: self.removeObject(objectID)\n",
    "\n",
    "            return self.objects\n",
    "\n",
    "        new_object_locations = np.zeros((len(detections), 2), dtype=\"int\")   # current object locations\n",
    "\n",
    "        for (i, detection) in enumerate(detections): new_object_locations[i] = self.getLocation(detection)\n",
    "\n",
    "        if len(self.objects) == 0:\n",
    "            for i in range(0, len(detections)): self.addObject(new_object_locations[i])\n",
    "        else:\n",
    "            objectIDs = list(self.objects.keys())\n",
    "            previous_object_locations = np.array(list(self.objects.values()))\n",
    "\n",
    "            D = distance.cdist(previous_object_locations, new_object_locations)   # pairwise distances between previous and current centroids\n",
    "\n",
    "            row_idx = D.min(axis=1).argsort()   # previous objects ordered by distance to their nearest current detection\n",
    "\n",
    "            cols_idx = D.argmin(axis=1)[row_idx]   # for each such row, the index of its nearest current detection\n",
    "\n",
    "            assignedRows, assignedCols = set(), set()\n",
    "\n",
    "            for (row, col) in zip(row_idx, cols_idx):\n",
    "\n",
    "                if row in assignedRows or col in assignedCols:\n",
    "                    continue\n",
    "\n",
    "                objectID = objectIDs[row]\n",
    "                self.objects[objectID] = new_object_locations[col]\n",
    "                self.lost[objectID] = 0\n",
    "\n",
    "                assignedRows.add(row)\n",
    "                assignedCols.add(col)\n",
    "\n",
    "            unassignedRows = set(range(0, D.shape[0])).difference(assignedRows)\n",
    "            unassignedCols = set(range(0, D.shape[1])).difference(assignedCols)\n",
    "\n",
    "            if D.shape[0] >= D.shape[1]:   # at least as many previous objects as detections: some went undetected\n",
    "                for row in unassignedRows:\n",
    "                    objectID = objectIDs[row]\n",
    "                    self.lost[objectID] += 1\n",
    "\n",
    "                    if self.lost[objectID] > self.maxLost:\n",
    "                        self.removeObject(objectID)\n",
    "\n",
    "            else:   # more detections than previous objects: register the extras as new\n",
    "                for col in unassignedCols:\n",
    "                    self.addObject(new_object_locations[col])\n",
    "\n",
    "        return self.objects"
   ]
  },
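  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity check of the Tracker above (synthetic boxes, not part of the detection pipeline):\n",
    "# two boxes are registered, then shifted slightly; each should keep its ID across the two updates.\n",
    "demo_tracker = Tracker(maxLost=5)\n",
    "print(demo_tracker.update([(10, 10, 50, 50), (200, 80, 260, 140)]))   # registers IDs 0 and 1\n",
    "print(demo_tracker.update([(14, 12, 54, 52), (206, 84, 266, 144)]))   # same IDs, updated centroids"
   ]
  },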
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loading Object Detector Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Tensorflow model for Object Detection and Tracking\n",
    "\n",
    "Here, the SSD Object Detection Model is used.\n",
    "\n",
    "For more details about single shot detection (SSD), refer to the following:\n",
    " - **Liu, W., Anguelov, D., Erhan, D., Szegedy, C., Reed, S., Fu, C. Y., & Berg, A. C. (2016, October). SSD: Single shot multibox detector. In European Conference on Computer Vision (pp. 21-37). Springer, Cham.**\n",
    " - Research paper link: https://arxiv.org/abs/1512.02325\n",
    " - The pretrained model: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API#use-existing-config-file-for-your-model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_info = {\"config_path\": \"./tensorflow_model_dir/ssd_mobilenet_v2_coco_2018_03_29.pbtxt\",\n",
    "              \"model_weights_path\": \"./tensorflow_model_dir/ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb\",\n",
    "              \"object_names\": {0: 'background',\n",
    "                               1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus',\n",
    "                               7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant',\n",
    "                               13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat',\n",
    "                               18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',\n",
    "                               24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag',\n",
    "                               32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',\n",
    "                               37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',\n",
    "                               41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',\n",
    "                               46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',\n",
    "                               51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',\n",
    "                               56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',\n",
    "                               61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',\n",
    "                               67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',\n",
    "                               75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven',\n",
    "                               80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock',\n",
    "                               86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'},\n",
    "              \"confidence_threshold\": 0.5,  # minimum detection score to keep\n",
    "              \"threshold\": 0.4              # non-maximum suppression (NMS) overlap threshold\n",
    "              }\n",
    "\n",
    "net = cv.dnn.readNetFromTensorflow(model_info[\"model_weights_path\"], model_info[\"config_path\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(12345)\n",
    "\n",
    "bbox_colors = {key: np.random.randint(0, 255, size=(3,)).tolist() for key in model_info['object_names'].keys()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Instantiate the Tracker Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "maxLost = 5   # maximum number of consecutive frames an object may go undetected before it is dropped\n",
    "tracker = Tracker(maxLost=maxLost)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Initiate opencv video capture object\n",
    "\n",
    "The `video_src` can take two values:\n",
    "1. If `video_src=0`: OpenCV accesses the camera connected through USB\n",
    "2. If `video_src='video_file_path'`: OpenCV will access the video file at the given path (MP4, AVI, etc.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_src = \"./data/video_test5.mp4\"   # set to 0 to use the webcam instead\n",
    "cap = cv.VideoCapture(video_src)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Start object detection and tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cannot read the video feed.\n"
     ]
    }
   ],
   "source": [
    "(H, W) = (None, None)   # input image height and width for the network\n",
    "writer = None\n",
    "while True:\n",
    "\n",
    "    ok, image = cap.read()\n",
    "\n",
    "    if not ok:\n",
    "        print(\"Cannot read the video feed.\")\n",
    "        break\n",
    "\n",
    "    if W is None or H is None: (H, W) = image.shape[:2]\n",
    "\n",
    "    blob = cv.dnn.blobFromImage(image, size=(300, 300), swapRB=True, crop=False)\n",
    "    net.setInput(blob)\n",
    "    detections = net.forward()\n",
    "\n",
    "    detections_bbox = []   # bounding boxes of the kept detections\n",
    "\n",
    "    boxes, confidences, classIDs = [], [], []\n",
    "\n",
    "    for detection in detections[0, 0, :, :]:\n",
    "        classID = detection[1]\n",
    "        confidence = detection[2]\n",
    "\n",
    "        if confidence > model_info['confidence_threshold']:\n",
    "            box = detection[3:7] * np.array([W, H, W, H])\n",
    "\n",
    "            (left, top, right, bottom) = box.astype(\"int\")\n",
    "            width = right - left + 1\n",
    "            height = bottom - top + 1\n",
    "\n",
    "            boxes.append([int(left), int(top), int(width), int(height)])\n",
    "            confidences.append(float(confidence))\n",
    "            classIDs.append(int(classID))\n",
    "\n",
    "    indices = cv.dnn.NMSBoxes(boxes, confidences, model_info[\"confidence_threshold\"], model_info[\"threshold\"])\n",
    "\n",
    "    if len(indices) > 0:\n",
    "        for i in indices.flatten():\n",
    "            x, y, w, h = boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]\n",
    "\n",
    "            detections_bbox.append((x, y, x + w, y + h))\n",
    "\n",
    "            clr = [int(c) for c in bbox_colors[classIDs[i]]]\n",
    "            cv.rectangle(image, (x, y), (x + w, y + h), clr, 2)\n",
    "\n",
    "            label = \"{}:{:.4f}\".format(model_info[\"object_names\"][classIDs[i]], confidences[i])\n",
    "            (label_width, label_height), baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 2)\n",
    "            y_label = max(y, label_height)\n",
    "            cv.rectangle(image, (x, y_label - label_height),\n",
    "                         (x + label_width, y_label + baseLine), (255, 255, 255), cv.FILLED)\n",
    "            cv.putText(image, label, (x, y_label), cv.FONT_HERSHEY_SIMPLEX, 0.5, clr, 2)\n",
    "\n",
    "    objects = tracker.update(detections_bbox)   # update tracker with the newly detected objects\n",
    "\n",
    "    for (objectID, centroid) in objects.items():\n",
    "        text = \"ID {}\".format(objectID)\n",
    "        cv.putText(image, text, (centroid[0] - 10, centroid[1] - 10), cv.FONT_HERSHEY_SIMPLEX,\n",
    "                   0.5, (0, 255, 0), 2)\n",
    "        cv.circle(image, (centroid[0], centroid[1]), 4, (0, 255, 0), -1)\n",
    "\n",
    "    cv.imshow(\"image\", image)\n",
    "\n",
    "    if cv.waitKey(1) & 0xFF == ord('q'):\n",
    "        break\n",
    "\n",
    "    if writer is None:\n",
    "        fourcc = cv.VideoWriter_fourcc(*\"MJPG\")\n",
    "        writer = cv.VideoWriter(\"output.avi\", fourcc, 30, (W, H), True)\n",
    "    writer.write(image)\n",
    "\n",
    "if writer is not None: writer.release()   # guard: writer stays None if the first frame never arrives\n",
    "cap.release()\n",
    "cv.destroyWindow(\"image\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "drlnd",
   "language": "python",
   "name": "drlnd"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
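The matching step inside `Tracker.update` above is a greedy nearest-centroid assignment; the sketch below isolates it with synthetic coordinates (two tracked objects, two detections arriving in swapped order):

```python
import numpy as np
from scipy.spatial import distance

# Previous frame's centroids and the current frame's detections (synthetic values)
previous = np.array([[30, 30], [230, 110]])
current = np.array([[236, 114], [34, 32]])

D = distance.cdist(previous, current)   # pairwise distances, shape (2, 2)
row_idx = D.min(axis=1).argsort()       # previous objects, closest match first
cols_idx = D.argmin(axis=1)[row_idx]    # nearest current detection for each

# Greedy one-to-one assignment, as in Tracker.update
assigned_rows, assigned_cols = set(), set()
for row, col in zip(row_idx, cols_idx):
    if row in assigned_rows or col in assigned_cols:
        continue
    print("object {} -> detection {} (distance {:.1f})".format(row, col, D[row, col]))
    assigned_rows.add(row)
    assigned_cols.add(col)
```

Rows left unassigned are tracked objects with no detection this frame (their lost counter increments); unassigned columns become new objects.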
