From a778b554d1cf78b525d86fd5b1dc07b671ea26b5 Mon Sep 17 00:00:00 2001 From: z-zawhtet-a Date: Sat, 27 Jan 2024 12:10:59 +0700 Subject: [PATCH] modified return values --- .gitignore | 30 +- Cargo.lock | 77 +- Cargo.toml | 78 +- LICENSE | 42 +- README.md | 76 +- examples/demo.rs | 320 ++-- examples/votchallenge/.python-version | 2 +- examples/votchallenge/README.md | 186 +-- examples/votchallenge/config.yaml | 6 +- examples/votchallenge/main.rs | 328 ++--- examples/votchallenge/trackers.template.ini | 14 +- examples/votchallenge/trax_protocol.rs | 426 +++--- index.html | 231 +-- multi-object-debugger.html | 85 ++ src/lib.rs | 1462 ++++++++++--------- src/wasm.rs | 224 +-- 16 files changed, 1890 insertions(+), 1697 deletions(-) create mode 100644 multi-object-debugger.html diff --git a/.gitignore b/.gitignore index c01f9ad..5f38f3c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,15 +1,15 @@ -target/ -.DS_Store - -# python venv for the `vot` tool -examples/votchallenge/vot_venv - -# folders generated by the `vot` tool -examples/votchallenge/results -examples/votchallenge/sequences -examples/votchallenge/logs -examples/votchallenge/analysis -examples/votchallenge/cache - -# config file for `vot` that needs to be hand-edited per machine -examples/votchallenge/trackers.ini +target/ +.DS_Store + +# python venv for the `vot` tool +examples/votchallenge/vot_venv + +# folders generated by the `vot` tool +examples/votchallenge/results +examples/votchallenge/sequences +examples/votchallenge/logs +examples/votchallenge/analysis +examples/votchallenge/cache + +# config file for `vot` that needs to be hand-edited per machine +examples/votchallenge/trackers.ini diff --git a/Cargo.lock b/Cargo.lock index 4cbef43..338acb5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,6 +262,12 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + [[package]] name = "jpeg-decoder" version = "0.2.6" @@ -330,6 +336,8 @@ dependencies = [ "log", "rustfft", "rusttype", + "serde", + "serde_json", "time", "wasm-bindgen", ] @@ -494,18 +502,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.43" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a2ca2c61bc9f3d74d2886294ab7b9853abd9c1ad903a3ac7815c58989bb7bab" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.21" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -631,6 +639,12 @@ dependencies = [ "owned_ttf_parser", ] +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + [[package]] name = "safe_arch" version = "0.6.0" @@ -646,6 +660,37 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "serde" +version = "1.0.185" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be9b6f69f1dfd54c3b568ffa45c310d6973a5e5148fd40cf515acaf38cf5bc31" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.185" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc59dfdcbad1437773485e0367fea4b090a2e0a16d9ffc46af47764536a298ec" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "simba" version = "0.7.2" @@ -667,9 +712,9 @@ checksum = "a3ff2f71c82567c565ba4b3009a9350a96a7269eaa4001ebedae926230bc2254" [[package]] name = "syn" -version = "1.0.99" +version = "2.0.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58dbef6ec655055e20b86b15a8cc6d439cca19b667537ac6a1369572d151ab13" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" dependencies = [ "proc-macro2", "quote", @@ -731,9 +776,9 @@ checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" [[package]] name = "wasm-bindgen" -version = "0.2.82" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7652e3f6c4706c8d9cd54832c4a4ccb9b5336e2c3bd154d5cccfbf1c1f5f7d" +checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -741,9 +786,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.82" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "662cd44805586bd52971b9586b1df85cdbbd9112e4ef4d8f41559c334dc6ac3f" +checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" dependencies = [ "bumpalo", "log", @@ -756,9 +801,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.82" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b260f13d3012071dfb1512849c033b1925038373aea48ced3012c09df952c602" +checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -766,9 +811,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.82" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5be8e654bdd9b79216c2929ab90721aa82faf65c48cdf08bdc4e7f51357b80da" +checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" dependencies = [ "proc-macro2", "quote", @@ -779,9 +824,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.82" +version = "0.2.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6598dd0bd3c7d51095ff6531a5b23e02acdc81804e30d8f07afb77b7215a140a" +checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" [[package]] name = "wide" diff --git a/Cargo.toml b/Cargo.toml index a878996..017c01c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,38 +1,40 @@ -[package] -name = "mosse" -version = "0.1.0" -edition = "2021" -authors = ["Jurriaan Barkey Wolf "] -description = "A proof-of-concept implementation of the MOSSE video object tracking algorithm by Bolme et al." -repository = "https://github.com/jjhbw/mosse-tracker" -license = "MIT" - -[lib] -crate-type = ["cdylib", "rlib"] - -[features] -default = ["rayon"] -rayon = ["imageproc/rayon", "image/jpeg_rayon"] - -[dependencies] -image = { version = "0.24.2", default-features = false, features = [ - "png", - "jpeg", -] } -rustfft = "6.0.1" -imageproc = { version = "0.23.0", default-features = false } - -# for font rendering on output/debug frames (same version as imageproc uses) -rusttype = "0.9.2" - -[target.wasm32-unknown-unknown.dependencies] -wasm-bindgen = { version = "0.2" } - -[dev-dependencies] -anyhow = "1.0.65" -env_logger = "0.9.1" -log = "0.4.17" -time = "0.3.11" - -[profile.release] -lto = true +[package] +name = "mosse" +version = "0.1.0" +edition = "2021" +authors = ["Jurriaan Barkey Wolf "] +description = "A proof-of-concept implementation of the MOSSE video object tracking algorithm by Bolme et al." +repository = "https://github.com/jjhbw/mosse-tracker" +license = "MIT" + +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["rayon"] +rayon = ["imageproc/rayon", "image/jpeg_rayon"] + +[dependencies] +image = { version = "0.24.2", default-features = false, features = [ + "png", + "jpeg", +] } +rustfft = "6.0.1" +imageproc = { version = "0.23.0", default-features = false } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# for font rendering on output/debug frames (same version as imageproc uses) +rusttype = "0.9.2" + +[target.wasm32-unknown-unknown.dependencies] +wasm-bindgen = { version = "0.2.88" } + +[dev-dependencies] +anyhow = "1.0.65" +env_logger = "0.9.1" +log = "0.4.17" +time = "0.3.11" + +[profile.release] +lto = true diff --git a/LICENSE b/LICENSE index 9a6af7f..d2c41cb 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,21 @@ -MIT License - -Copyright (c) 2022 Jurriaan Barkey Wolf - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +MIT License + +Copyright (c) 2022 Jurriaan Barkey Wolf + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index b06c5a9..bc55450 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,38 @@ -# MOSSE tracker in Rust - -A Rust implementation of the Minimum Output Sum of Squared Error (MOSSE) tracking algorithm, as presented in the 2010 paper [Visual Object Tracking using Adaptive Correlation Filters](https://www.cs.colostate.edu/~vision/publications/bolme_cvpr10.pdf) by David S. Bolme et al. - -![example](example.gif) - -For a bit of extra context, check out the accompanying blog post at https://barkeywolf.consulting/posts/mosse-tracker/. - -## Running it - -### Cut up a video into frames - -```bash -ffmpeg -i ./testdata/traffic.mp4 -vf fps=30 ./testdata/traffic/img%04d.png -``` - -### Run the example binary - -Running a debug build (not using the `--release` flag) will dump the state of the filter at each frame to a file and will output additional debug information. Note that the image filenames need to be provided in order. Below commands should result in `test_tracking.mp4`. - -```bash -cargo run --release --example demo $(ls ./testdata/traffic/img0*.png) &&\ -ffmpeg -y -framerate 30 -i ./predicted_image_%4d.png -pix_fmt yuv420p test_tracking.mp4 &&\ -rm *.png -``` - -### Run web example - -```bash -wasm-pack build --no-default-features --target web -python3 -m http.server -``` - -Open [http://localhost:8000](http://localhost:8000) and allow webcam access. - -# Evaluate on the votchallenge dataset - -See [/examples/votchallenge](/examples/votchallenge). Thanks @alsuren for contributing the necessary code! +# MOSSE tracker in Rust + +A Rust implementation of the Minimum Output Sum of Squared Error (MOSSE) tracking algorithm, as presented in the 2010 paper [Visual Object Tracking using Adaptive Correlation Filters](https://www.cs.colostate.edu/~vision/publications/bolme_cvpr10.pdf) by David S. Bolme et al. + +![example](example.gif) + +For a bit of extra context, check out the accompanying blog post at https://barkeywolf.consulting/posts/mosse-tracker/. + +## Running it + +### Cut up a video into frames + +```bash +ffmpeg -i ./testdata/traffic.mp4 -vf fps=30 ./testdata/traffic/img%04d.png +``` + +### Run the example binary + +Running a debug build (not using the `--release` flag) will dump the state of the filter at each frame to a file and will output additional debug information. Note that the image filenames need to be provided in order. Below commands should result in `test_tracking.mp4`. + +```bash +cargo run --release --example demo $(ls ./testdata/traffic/img0*.png) &&\ +ffmpeg -y -framerate 30 -i ./predicted_image_%4d.png -pix_fmt yuv420p test_tracking.mp4 &&\ +rm *.png +``` + +### Run web example + +```bash +wasm-pack build --no-default-features --target web +python3 -m http.server +``` + +Open [http://localhost:8000](http://localhost:8000) and allow webcam access. + +# Evaluate on the votchallenge dataset + +See [/examples/votchallenge](/examples/votchallenge). Thanks @alsuren for contributing the necessary code! diff --git a/examples/demo.rs b/examples/demo.rs index d66aa2a..ec4e28f 100644 --- a/examples/demo.rs +++ b/examples/demo.rs @@ -1,160 +1,160 @@ -extern crate image; -extern crate imageproc; -extern crate mosse; -extern crate rusttype; -extern crate time; - -use image::Rgba; -use imageproc::drawing::{draw_cross_mut, draw_hollow_rect_mut, draw_text_mut}; -use imageproc::rect::Rect; -use mosse::{MosseTrackerSettings, MultiMosseTracker}; -use rusttype::{Font, Scale}; -use std::env; -use std::time::Instant; - -fn main() { - // Collect all elements in the iterator that contains the command line arguments - let args: Vec = env::args().collect(); - - // remove the first element from the list of arguments, which is the call to the binary - let inputfiles = &args[1..]; - - if inputfiles.len() == 0 { - panic!("no input files specified"); - } - let mut images = inputfiles.iter().map(|path| image::open(path).unwrap()); - let first = images.next().unwrap(); - - // initialize a new model - let (width, height) = first.to_rgb8().dimensions(); - let window_size = 64; //size of the tracking window - let psr_thresh = 7.0; // how high the psr must be before prediction is considered succesful. - let settings = MosseTrackerSettings { - window_size: window_size, - width, - height, - regularization: 0.001, - learning_rate: 0.05, - psr_threshold: psr_thresh, - }; - let desperation_threshold = 3; // how many frames the tracker should try to re-acquire the target until we consider it failed - let mut multi_tracker = MultiMosseTracker::new(settings, desperation_threshold); - - // coordinates of the target objects to track in the intial frame - let target_coords = vec![ - (143, 766), - (232, 653), - (291, 731), - (1298, 664), - (479, 642), - (574, 629), - (666, 627), - (762, 609), - ]; - - // Add all the targets on the first image to the multitracker - let first_img = first.to_luma8(); - for (i, coords) in target_coords.into_iter().enumerate() { - let start = Instant::now(); - multi_tracker.add_or_replace_target(i as u32, coords, &first_img); - println!( - "Added object on initial frame to multi-tracker in {} ms", - start.elapsed().as_millis() - ); - } - - for (i, dyn_img) in images.enumerate() { - // add leading zeroes for easier downstream proc with ffmpeg - let img_id = format!("{:<04}", i + 1); - - // track the objects on the new frame - let start = Instant::now(); - let predictions = multi_tracker.track(&dyn_img.to_luma8()); - - println!( - "Processed sample image no. {} in {} ms. Active trackers: {}.", - img_id, - start.elapsed().as_millis(), - multi_tracker.size(), - ); - - let mut img_copy = dyn_img; - for (obj_id, pred) in predictions.iter() { - // color changes when psr is low - let mut color = Rgba([125u8, 255u8, 0u8, 0u8]); - if pred.psr < psr_thresh { - color = Rgba([255u8, 0u8, 0u8, 0u8]) - } - - // Indicate the locations of the predictions by drawing on the image. - draw_cross_mut( - &mut img_copy, - Rgba([255u8, 0u8, 0u8, 0u8]), - pred.location.0 as i32, - pred.location.1 as i32, - ); - draw_hollow_rect_mut( - &mut img_copy, - Rect::at( - pred.location.0.saturating_sub(window_size / 2) as i32, - pred.location.1.saturating_sub(window_size / 2) as i32, - ) - .of_size(window_size, window_size), - color, - ); - - let font_data = include_bytes!("./Arial.ttf"); - let font = Font::try_from_bytes(font_data as &[u8]).unwrap(); - - const FONT_SCALE: f32 = 10.0; - - // render the object ID - draw_text_mut( - &mut img_copy, - Rgba([125u8, 255u8, 0u8, 0u8]), - (pred.location.0 - (window_size / 2)).try_into().unwrap(), - (pred.location.1 - (window_size / 2)).try_into().unwrap(), - Scale::uniform(FONT_SCALE), - &font, - &format!("#{}", obj_id), - ); - - // render the PSR on top of the rectangle - draw_text_mut( - &mut img_copy, - color, - (pred.location.0 - (window_size / 2)).try_into().unwrap(), - (pred.location.1 - (window_size / 2) + FONT_SCALE as u32).try_into().unwrap(), - Scale::uniform(FONT_SCALE), - &font, - &format!("PSR: {:.2}", pred.psr), - ); - - println!("Object {} PSR: {}", obj_id, pred.psr) - } - - // additional debug info - #[cfg(debug_assertions)] - { - // save the filters - multi_tracker - .dump_filter_reals() - .iter() - .enumerate() - .for_each(|(i, f)| { - f.save(format!("filter_real_obj{}_fig{}.png", i, img_id)) - .unwrap() - }) - } - - img_copy - .save(format!("predicted_image_{}.png", img_id)) - .unwrap(); - - // Break off multi tracker if all targets lost - if multi_tracker.size() == 0 { - println!("No more active trackers. Stopping demo."); - break; - } - } -} +extern crate image; +extern crate imageproc; +extern crate mosse; +extern crate rusttype; +extern crate time; + +use image::Rgba; +use imageproc::drawing::{draw_cross_mut, draw_hollow_rect_mut, draw_text_mut}; +use imageproc::rect::Rect; +use mosse::{MosseTrackerSettings, MultiMosseTracker}; +use rusttype::{Font, Scale}; +use std::env; +use std::time::Instant; + +fn main() { + // Collect all elements in the iterator that contains the command line arguments + let args: Vec = env::args().collect(); + + // remove the first element from the list of arguments, which is the call to the binary + let inputfiles = &args[1..]; + + if inputfiles.len() == 0 { + panic!("no input files specified"); + } + let mut images = inputfiles.iter().map(|path| image::open(path).unwrap()); + let first = images.next().unwrap(); + + // initialize a new model + let (width, height) = first.to_rgb8().dimensions(); + let window_size = 64; //size of the tracking window + let psr_thresh = 7.0; // how high the psr must be before prediction is considered succesful. + let settings = MosseTrackerSettings { + window_size: window_size, + width, + height, + regularization: 0.001, + learning_rate: 0.05, + psr_threshold: psr_thresh, + }; + let desperation_threshold = 3; // how many frames the tracker should try to re-acquire the target until we consider it failed + let mut multi_tracker = MultiMosseTracker::new(settings, desperation_threshold); + + // coordinates of the target objects to track in the intial frame + let target_coords = vec![ + (143, 766), + (232, 653), + (291, 731), + (1298, 664), + (479, 642), + (574, 629), + (666, 627), + (762, 609), + ]; + + // Add all the targets on the first image to the multitracker + let first_img = first.to_luma8(); + for (i, coords) in target_coords.into_iter().enumerate() { + let start = Instant::now(); + multi_tracker.add_or_replace_target(i as u32, coords, &first_img); + println!( + "Added object on initial frame to multi-tracker in {} ms", + start.elapsed().as_millis() + ); + } + + for (i, dyn_img) in images.enumerate() { + // add leading zeroes for easier downstream proc with ffmpeg + let img_id = format!("{:<04}", i + 1); + + // track the objects on the new frame + let start = Instant::now(); + let predictions = multi_tracker.track(&dyn_img.to_luma8()); + + println!( + "Processed sample image no. {} in {} ms. Active trackers: {}.", + img_id, + start.elapsed().as_millis(), + multi_tracker.size(), + ); + + let mut img_copy = dyn_img; + for (obj_id, pred) in predictions.iter() { + // color changes when psr is low + let mut color = Rgba([125u8, 255u8, 0u8, 0u8]); + if pred.psr < psr_thresh { + color = Rgba([255u8, 0u8, 0u8, 0u8]) + } + + // Indicate the locations of the predictions by drawing on the image. + draw_cross_mut( + &mut img_copy, + Rgba([255u8, 0u8, 0u8, 0u8]), + pred.location.0 as i32, + pred.location.1 as i32, + ); + draw_hollow_rect_mut( + &mut img_copy, + Rect::at( + pred.location.0.saturating_sub(window_size / 2) as i32, + pred.location.1.saturating_sub(window_size / 2) as i32, + ) + .of_size(window_size, window_size), + color, + ); + + let font_data = include_bytes!("./Arial.ttf"); + let font = Font::try_from_bytes(font_data as &[u8]).unwrap(); + + const FONT_SCALE: f32 = 10.0; + + // render the object ID + draw_text_mut( + &mut img_copy, + Rgba([125u8, 255u8, 0u8, 0u8]), + (pred.location.0 - (window_size / 2)).try_into().unwrap(), + (pred.location.1 - (window_size / 2)).try_into().unwrap(), + Scale::uniform(FONT_SCALE), + &font, + &format!("#{}", obj_id), + ); + + // render the PSR on top of the rectangle + draw_text_mut( + &mut img_copy, + color, + (pred.location.0 - (window_size / 2)).try_into().unwrap(), + (pred.location.1 - (window_size / 2) + FONT_SCALE as u32).try_into().unwrap(), + Scale::uniform(FONT_SCALE), + &font, + &format!("PSR: {:.2}", pred.psr), + ); + + println!("Object {} PSR: {}", obj_id, pred.psr) + } + + // additional debug info + #[cfg(debug_assertions)] + { + // save the filters + multi_tracker + .dump_filter_reals() + .iter() + .enumerate() + .for_each(|(i, f)| { + f.save(format!("filter_real_obj{}_fig{}.png", i, img_id)) + .unwrap() + }) + } + + img_copy + .save(format!("predicted_image_{}.png", img_id)) + .unwrap(); + + // Break off multi tracker if all targets lost + if multi_tracker.size() == 0 { + println!("No more active trackers. Stopping demo."); + break; + } + } +} diff --git a/examples/votchallenge/.python-version b/examples/votchallenge/.python-version index 2515b63..85fa9f2 100644 --- a/examples/votchallenge/.python-version +++ b/examples/votchallenge/.python-version @@ -1 +1 @@ -3.9.14 +3.9.14 diff --git a/examples/votchallenge/README.md b/examples/votchallenge/README.md index 20fe59b..37c8716 100644 --- a/examples/votchallenge/README.md +++ b/examples/votchallenge/README.md @@ -1,93 +1,93 @@ -# Running the votchallenge.net benchmarks against mosse-tracker - -These instructions are adapted from https://www.votchallenge.net/howto/tutorial_python.html - -All instructions assume that you are in this directory. - -## Installing vot tool in a python virtualenv - -First, you need to install the vot python tool, to run the benchmarks. - -At the time of writing, vot does not work with python 3.10 on macos (import error when starting up), so you may need to use an older version. - -```bash -cd examples/votchallenge # if you are not already here - -python3.9 -m venv vot_venv -source vot_venv/bin/activate -pip install git+https://github.com/votchallenge/vot-toolkit-python -``` - -## Set things up for this project - -```bash -cd examples/votchallenge # if you are not already here - -cargo build --release --example=votchallenge -cp trackers.template.ini trackers.ini -``` - -Then change the last line of your new `trackers.ini`, to point at your -`target/release/examples/votchallenge` executable. This must be an absolute path. - -## Check with a dummy sequence - -```bash -cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here - -vot test MosseRust -``` - -## Run the full benchmark suite - -```bash -cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here - -vot test MosseRust -``` - -This only uses a couple of cores, and take around 30 minutes, so go make yourself a cup of tea. You should see output like this: - -``` -Downloading sequence dataset "VOT2020" with 60 sequences. - Downloading |███████████████████████████████████████████████████████████████████████████| 100% [02:30<00:00] -Download completed - Loading dataset |███████████████████████████████████████████████████████████████████████████| 100% [00:00<00:00] -Loaded workspace in '/Users/alsuren/src/mosse-tracker/examples/votchallenge' -Found data for 1 trackers -Evaluating tracker MosseRust - MosseRust/baseline |███████████████████████████████████████████████████████████████████████████| 100% [13:24<00:00] - MosseRust/realtime |███████████████████████████████████████████████████████████████████████████| 100% [13:11<00:00] - MosseRust/unsupervis |███████████████████████████████████████████████████████████████████████████| 100% [01:34<00:00] -Evaluation concluded successfuly -``` - -## Checking your scores - -```bash -cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here - -vot analysis MosseRust -``` - -This is a bit quicker, and will give you something like: - -``` - Loading dataset |██████████████████████████████████████████████████████████████████████████████████| 100% [00:00<00:00] -Loaded workspace in '/Users/alsuren/src/mosse-tracker/examples/votchallenge' -Found data for 1 trackers - Running analysis |██████████████████████████████████████████████████████████████████████████████████| 100% [00:21<00:00] -Analysis successful, report available as 2022-10-06T22-54-20.997015 -``` - -You can then open ./analysis/2022-10-06T22-54-20.997015/report.html in your web browser, to view the results. - -## Rerunning the analysis after making a change - -```bash -cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here - -cargo build --release --example=votchallenge -rm -rf cache/ results/ -vot evaluate MosseRust && vot analysis MosseRust -``` +# Running the votchallenge.net benchmarks against mosse-tracker + +These instructions are adapted from https://www.votchallenge.net/howto/tutorial_python.html + +All instructions assume that you are in this directory. + +## Installing vot tool in a python virtualenv + +First, you need to install the vot python tool, to run the benchmarks. + +At the time of writing, vot does not work with python 3.10 on macos (import error when starting up), so you may need to use an older version. + +```bash +cd examples/votchallenge # if you are not already here + +python3.9 -m venv vot_venv +source vot_venv/bin/activate +pip install git+https://github.com/votchallenge/vot-toolkit-python +``` + +## Set things up for this project + +```bash +cd examples/votchallenge # if you are not already here + +cargo build --release --example=votchallenge +cp trackers.template.ini trackers.ini +``` + +Then change the last line of your new `trackers.ini`, to point at your +`target/release/examples/votchallenge` executable. This must be an absolute path. + +## Check with a dummy sequence + +```bash +cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here + +vot test MosseRust +``` + +## Run the full benchmark suite + +```bash +cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here + +vot test MosseRust +``` + +This only uses a couple of cores, and take around 30 minutes, so go make yourself a cup of tea. You should see output like this: + +``` +Downloading sequence dataset "VOT2020" with 60 sequences. + Downloading |███████████████████████████████████████████████████████████████████████████| 100% [02:30<00:00] +Download completed + Loading dataset |███████████████████████████████████████████████████████████████████████████| 100% [00:00<00:00] +Loaded workspace in '/Users/alsuren/src/mosse-tracker/examples/votchallenge' +Found data for 1 trackers +Evaluating tracker MosseRust + MosseRust/baseline |███████████████████████████████████████████████████████████████████████████| 100% [13:24<00:00] + MosseRust/realtime |███████████████████████████████████████████████████████████████████████████| 100% [13:11<00:00] + MosseRust/unsupervis |███████████████████████████████████████████████████████████████████████████| 100% [01:34<00:00] +Evaluation concluded successfuly +``` + +## Checking your scores + +```bash +cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here + +vot analysis MosseRust +``` + +This is a bit quicker, and will give you something like: + +``` + Loading dataset |██████████████████████████████████████████████████████████████████████████████████| 100% [00:00<00:00] +Loaded workspace in '/Users/alsuren/src/mosse-tracker/examples/votchallenge' +Found data for 1 trackers + Running analysis |██████████████████████████████████████████████████████████████████████████████████| 100% [00:21<00:00] +Analysis successful, report available as 2022-10-06T22-54-20.997015 +``` + +You can then open ./analysis/2022-10-06T22-54-20.997015/report.html in your web browser, to view the results. + +## Rerunning the analysis after making a change + +```bash +cd examples/votchallenge && source vot_venv/bin/activate # if you are not already here + +cargo build --release --example=votchallenge +rm -rf cache/ results/ +vot evaluate MosseRust && vot analysis MosseRust +``` diff --git a/examples/votchallenge/config.yaml b/examples/votchallenge/config.yaml index 2ce38fe..149f6ef 100644 --- a/examples/votchallenge/config.yaml +++ b/examples/votchallenge/config.yaml @@ -1,3 +1,3 @@ -registry: -- ./trackers.ini -stack: vot2020 +registry: +- ./trackers.ini +stack: vot2020 diff --git a/examples/votchallenge/main.rs b/examples/votchallenge/main.rs index b932637..fed91a2 100644 --- a/examples/votchallenge/main.rs +++ b/examples/votchallenge/main.rs @@ -1,164 +1,164 @@ -mod trax_protocol; - -use std::io::stdin; - -use mosse::{MosseTrackerSettings, MultiMosseTracker}; - -use crate::trax_protocol::{ - ChannelType, Image, ImageType, Region, RegionType, TraxMessageFromClient, TraxMessageFromServer, -}; - -#[derive(Debug)] -pub enum ServerState { - Introduction, - Initialization, - Reporting { - multi_tracker: MultiMosseTracker, - first_region: Region, - }, - Termination, -} - -struct MosseTraxServer { - state: ServerState, -} -impl Default for MosseTraxServer { - fn default() -> Self { - Self { - state: ServerState::Introduction, - } - } -} -impl MosseTraxServer { - fn run(mut self) { - log::info!("starting run"); - - println!("{}", self.make_hello_message()); - - for line in stdin().lines() { - let line = line.unwrap(); - log::trace!("handling line: {line:?}"); - let message: TraxMessageFromClient = line.parse().unwrap(); - let response = self.process_message(message); - println!("{}", response); - } - } - - fn make_hello_message(&mut self) -> TraxMessageFromServer { - TraxMessageFromServer::Hello { - version: 1, - name: "MosseRust".to_string(), - identifier: "mosse-tracker-rust".to_string(), - image: ImageType::Path, - region: RegionType::Rectangle, - channels: vec![ChannelType::Color], - } - } - - fn process_message(&mut self, message: TraxMessageFromClient) -> TraxMessageFromServer { - match message { - TraxMessageFromClient::Initialize { image, region } => self.process_init(image, region), - TraxMessageFromClient::Frame { images } => self.process_frame(images), - // FIXME: return Result from this function, and make the outer loop print "quit" and exit on error? - TraxMessageFromClient::Quit => panic!("client sent quit message"), - } - } - - fn process_init(&mut self, image: Image, region: Region) -> TraxMessageFromServer { - let first = image.open().unwrap(); - - // initialize a new model - let (width, height) = first.to_rgb8().dimensions(); - // FIXME: This tracks a square that entirely encloses the target region, so it may fixate - // on the background for tall or wide targets. - let window_size = f64::max(region.width, region.height) as u32; // size of the tracking window - let psr_thresh = 7.0; // how high the psr must be before prediction is considered succesful. - let settings = MosseTrackerSettings { - window_size: window_size, - width, - height, - regularization: 0.001, - learning_rate: 0.05, - psr_threshold: psr_thresh, - }; - - // FIXME: Could I get away with a single MosseTracker here? This would make things simpler, - // but wouldn't change the results of the benchmark. - let desperation_threshold = 300000; // how many frames the tracker should try to re-acquire the target until we consider it failed - let mut multi_tracker = MultiMosseTracker::new(settings, desperation_threshold); - - let coords = ( - (region.x + region.width / 2.) as u32, - (region.y + region.height / 2.) as u32, - ); - multi_tracker.add_or_replace_target(0, coords, &first.to_luma8()); - - self.state = ServerState::Reporting { - multi_tracker, - first_region: region.clone(), - }; - - // if we were being honest, we would return the square region that we've - // actually fed into the model, but it probably doesn't matter that much. - TraxMessageFromServer::State { region } - } - - fn process_frame(&mut self, images: Vec) -> TraxMessageFromServer { - assert_eq!( - images.len(), - 1, - "TODO: handle multiple images in the same frame message?" - ); - - // FIXME: use let...else for this when it becomes stable - let (multi_tracker, first_region) = if let ServerState::Reporting { - ref mut multi_tracker, - ref first_region, - } = self.state - { - (multi_tracker, first_region) - } else { - panic!("received `frame` message when not in the Reporting state") - }; - - let frame = &images[0].open().unwrap(); - let predictions = multi_tracker.track(&frame.to_luma8()); - assert_eq!(predictions.len(), 1); - let (_obj_id, pred) = &predictions[0]; - - let region = Region { - x: pred - .location - .0 - .saturating_sub((first_region.width / 2.) as u32) as f64, - y: pred - .location - .1 - .saturating_sub((first_region.height / 2.) as u32) as f64, - height: first_region.height, - width: first_region.width, - }; - - #[cfg(debug_assertions)] - { - let mut img_copy = frame.clone(); - imageproc::drawing::draw_hollow_rect_mut( - &mut img_copy, - imageproc::rect::Rect::at(region.x as i32, region.y as i32) - .of_size(region.width as u32, region.height as u32), - image::Rgba([125u8, 255u8, 0u8, 0u8]), - ); - img_copy - .save(images[0].path.with_extension(".predicted.png")) - .unwrap(); - } - TraxMessageFromServer::State { region } - } -} - -fn main() { - env_logger::init(); - - let server = MosseTraxServer::default(); - server.run(); -} +mod trax_protocol; + +use std::io::stdin; + +use mosse::{MosseTrackerSettings, MultiMosseTracker}; + +use crate::trax_protocol::{ + ChannelType, Image, ImageType, Region, RegionType, TraxMessageFromClient, TraxMessageFromServer, +}; + +#[derive(Debug)] +pub enum ServerState { + Introduction, + Initialization, + Reporting { + multi_tracker: MultiMosseTracker, + first_region: Region, + }, + Termination, +} + +struct MosseTraxServer { + state: ServerState, +} +impl Default for MosseTraxServer { + fn default() -> Self { + Self { + state: ServerState::Introduction, + } + } +} +impl MosseTraxServer { + fn run(mut self) { + log::info!("starting run"); + + println!("{}", self.make_hello_message()); + + for line in stdin().lines() { + let line = line.unwrap(); + log::trace!("handling line: {line:?}"); + let message: TraxMessageFromClient = line.parse().unwrap(); + let response = self.process_message(message); + println!("{}", response); + } + } + + fn make_hello_message(&mut self) -> TraxMessageFromServer { + TraxMessageFromServer::Hello { + version: 1, + name: "MosseRust".to_string(), + identifier: "mosse-tracker-rust".to_string(), + image: ImageType::Path, + region: RegionType::Rectangle, + channels: vec![ChannelType::Color], + } + } + + fn process_message(&mut self, message: TraxMessageFromClient) -> TraxMessageFromServer { + match message { + TraxMessageFromClient::Initialize { image, region } => self.process_init(image, region), + TraxMessageFromClient::Frame { images } => self.process_frame(images), + // FIXME: return Result from this function, and make the outer loop print "quit" and exit on error? + TraxMessageFromClient::Quit => panic!("client sent quit message"), + } + } + + fn process_init(&mut self, image: Image, region: Region) -> TraxMessageFromServer { + let first = image.open().unwrap(); + + // initialize a new model + let (width, height) = first.to_rgb8().dimensions(); + // FIXME: This tracks a square that entirely encloses the target region, so it may fixate + // on the background for tall or wide targets. + let window_size = f64::max(region.width, region.height) as u32; // size of the tracking window + let psr_thresh = 7.0; // how high the psr must be before prediction is considered succesful. + let settings = MosseTrackerSettings { + window_size: window_size, + width, + height, + regularization: 0.001, + learning_rate: 0.05, + psr_threshold: psr_thresh, + }; + + // FIXME: Could I get away with a single MosseTracker here? This would make things simpler, + // but wouldn't change the results of the benchmark. + let desperation_threshold = 300000; // how many frames the tracker should try to re-acquire the target until we consider it failed + let mut multi_tracker = MultiMosseTracker::new(settings, desperation_threshold); + + let coords = ( + (region.x + region.width / 2.) as u32, + (region.y + region.height / 2.) as u32, + ); + multi_tracker.add_or_replace_target(0, coords, &first.to_luma8()); + + self.state = ServerState::Reporting { + multi_tracker, + first_region: region.clone(), + }; + + // if we were being honest, we would return the square region that we've + // actually fed into the model, but it probably doesn't matter that much. + TraxMessageFromServer::State { region } + } + + fn process_frame(&mut self, images: Vec) -> TraxMessageFromServer { + assert_eq!( + images.len(), + 1, + "TODO: handle multiple images in the same frame message?" + ); + + // FIXME: use let...else for this when it becomes stable + let (multi_tracker, first_region) = if let ServerState::Reporting { + ref mut multi_tracker, + ref first_region, + } = self.state + { + (multi_tracker, first_region) + } else { + panic!("received `frame` message when not in the Reporting state") + }; + + let frame = &images[0].open().unwrap(); + let predictions = multi_tracker.track(&frame.to_luma8()); + assert_eq!(predictions.len(), 1); + let (_obj_id, pred) = &predictions[0]; + + let region = Region { + x: pred + .location + .0 + .saturating_sub((first_region.width / 2.) as u32) as f64, + y: pred + .location + .1 + .saturating_sub((first_region.height / 2.) as u32) as f64, + height: first_region.height, + width: first_region.width, + }; + + #[cfg(debug_assertions)] + { + let mut img_copy = frame.clone(); + imageproc::drawing::draw_hollow_rect_mut( + &mut img_copy, + imageproc::rect::Rect::at(region.x as i32, region.y as i32) + .of_size(region.width as u32, region.height as u32), + image::Rgba([125u8, 255u8, 0u8, 0u8]), + ); + img_copy + .save(images[0].path.with_extension(".predicted.png")) + .unwrap(); + } + TraxMessageFromServer::State { region } + } +} + +fn main() { + env_logger::init(); + + let server = MosseTraxServer::default(); + server.run(); +} diff --git a/examples/votchallenge/trackers.template.ini b/examples/votchallenge/trackers.template.ini index 4f121fc..bd5b05e 100644 --- a/examples/votchallenge/trackers.template.ini +++ b/examples/votchallenge/trackers.template.ini @@ -1,7 +1,7 @@ -[MosseRust] -label = MosseRust -protocol = trax - -## Change this to point at your `target/release/examples/votchallenge` executable. -## This must be an absolute path. -command = __FIX_ME_IN_TRACKERS.INI_FILE__/mosse-tracker/target/release/examples/votchallenge +[MosseRust] +label = MosseRust +protocol = trax + +## Change this to point at your `target/release/examples/votchallenge` executable. +## This must be an absolute path. +command = __FIX_ME_IN_TRACKERS.INI_FILE__/mosse-tracker/target/release/examples/votchallenge diff --git a/examples/votchallenge/trax_protocol.rs b/examples/votchallenge/trax_protocol.rs index 23f2f2e..0df5dcb 100644 --- a/examples/votchallenge/trax_protocol.rs +++ b/examples/votchallenge/trax_protocol.rs @@ -1,213 +1,213 @@ -#![allow(dead_code)] -///! This module implements the trax protocol as described in https://trax.readthedocs.io/en/latest/protocol.html -// FIXME: split this out into its own crate? -use std::{fmt::Display, path::PathBuf, str::FromStr}; - -use image::{DynamicImage, ImageError}; - -/// messages defined by https://trax.readthedocs.io/en/latest/protocol.html#protocol-messages-and-states -pub enum TraxMessageFromServer { - Hello { - /// Specifies the supported version of the protocol. If not present, version 1 is assumed. - version: i32, - /// Specifies the name of the tracker. The name can be used by the client to verify that the correct algorithm is executed. - name: String, - /// Specifies the identifier of the current implementation. The identifier can be used to determine the version of the tracker. - identifier: String, - /// Specifies the supported image format. See Section Image formats for the list of supported formats. By default it is assumed that the tracker can accept file paths as image source. - image: ImageType, - /// Specifies the supported region format. See Section Region formats for the list of supported formats. By default it is assumed that the tracker can accept rectangles as region specification. - region: RegionType, - /// Specifies support for multi-modal images. See Section Image channels for more information - channels: Vec, - }, - State { - region: Region, - }, - Quit, -} -impl Display for TraxMessageFromServer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "@@TRAX:")?; - match self { - TraxMessageFromServer::Hello { - version, - name, - identifier, - image, - region, - channels, - } => { - let channels = channels - .iter() - .map(ToString::to_string) - .collect::>() - .join(","); - write!(f, "hello trax.version={version} trax.name={name} trax.identifier={identifier} trax.image={image} trax.region={region} trax.channels={channels}") - } - TraxMessageFromServer::State { region } => { - write!(f, "state {region}") - } - TraxMessageFromServer::Quit => todo!(), - } - } -} -#[derive(Debug)] -pub enum TraxMessageFromClient { - Initialize { image: Image, region: Region }, - Frame { images: Vec }, - Quit, -} - -impl FromStr for TraxMessageFromClient { - // we could probably use anyhow::Error or here - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - let s = s.trim_end().strip_prefix("@@TRAX:").unwrap(); - let (type_, rest) = s.split_once(' ').unwrap(); - let res = match type_ { - "initialize" => { - // FIXME: - // * strip out quotes and whitespace properly (tempdir might have spaces in on windows?) - let (image, region) = rest.split_once(' ').unwrap(); - Self::Initialize { - image: Image::from_str(strip_quotes_from_ends(image)?)?, - region: Region::from_str(strip_quotes_from_ends(region)?)?, - } - } - "frame" => Self::Frame { - // FIXME: https://trax.readthedocs.io/en/latest/protocol.html#protocol-messages-and-states - // says "or multiple images", which is why I made it a Vec, but I'm not sure how this - // should be handled by the server (split it and treat it as if it were mutiple "frame" messages?) - // so it might be better to flatten out the Vec into a single Image. - images: vec![Image::from_str(strip_quotes_from_ends(rest)?)?], - }, - _ => anyhow::bail!("don't understand message: {s:?}"), - }; - Ok(res) - } -} - -// I feel like there should be something like this in the standard library somewhere, but this will do for now. -fn strip_quotes_from_ends(s: &str) -> anyhow::Result<&str> { - s.strip_prefix('"') - .ok_or(anyhow::anyhow!("no leading quote on {s:?}"))? - .strip_suffix('"') - .ok_or(anyhow::anyhow!("no trailing quote on {s:?}")) -} - -#[derive(Debug, PartialEq)] -pub enum ImageType { - /// In practice, we only plan to implement the `Path` image type in our server. - Path, - Memory, - Data, - Url, -} -impl Display for ImageType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - assert_eq!( - self, - &ImageType::Path, - "only `path` image type is supported for now", - ); - match self { - ImageType::Path => write!(f, "path"), - ImageType::Memory => write!(f, "memory"), - ImageType::Data => write!(f, "data"), - ImageType::Url => write!(f, "url"), - } - } -} - -// In practice, we only plan to implement the `Path` image type in our server, otherwise I would have made this an enum as well. -#[derive(Debug)] -pub struct Image { - pub path: PathBuf, -} -impl FromStr for Image { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - if let Some(rest) = s.strip_prefix("file://") { - let path = PathBuf::from(rest); - Ok(Self { path }) - } else { - anyhow::bail!("could not decode path from {s}") - } - } -} -impl Image { - pub fn open(&self) -> Result { - image::open(&self.path) - } -} - -/// In practice, we only plan to implement the `Rectangle` region type in our server. -pub enum RegionType { - Rectangle, - Polygon, -} -impl Display for RegionType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - RegionType::Rectangle => write!(f, "rectangle"), - RegionType::Polygon => write!(f, "polygon"), - } - } -} - -// In practice, we only plan to implement the `Rectangle` region type in our server, otherwise I would have made this an enum as well. -#[derive(Debug, Clone)] -pub struct Region { - pub x: f64, - pub y: f64, - pub width: f64, - pub height: f64, -} -impl FromStr for Region { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - let [x, y, width, height]: [f64; 4] = s - .split(|c| c == ',' || c == '\t') - .map(|n| f64::from_str(n)) - .collect::, _>>()? - .try_into() - .map_err(|v| anyhow::anyhow!("{v:?} could not be coerced into a [f64; 4]"))?; - Ok(Self { - x, - y, - width, - height, - }) - } -} -impl Display for Region { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let Self { - x, - y, - width, - height, - } = self; - write!(f, "\"{x:.3},{y:.3},{width:.3},{height:.3}\"") - } -} - -/// In practice, we only plan to implement a single `Color` channel type in our server. -pub enum ChannelType { - Color, - Depth, - InfraRed, -} -impl Display for ChannelType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ChannelType::Color => write!(f, "color"), - ChannelType::Depth => write!(f, "depth"), - ChannelType::InfraRed => write!(f, "ir"), - } - } -} +#![allow(dead_code)] +///! This module implements the trax protocol as described in https://trax.readthedocs.io/en/latest/protocol.html +// FIXME: split this out into its own crate? +use std::{fmt::Display, path::PathBuf, str::FromStr}; + +use image::{DynamicImage, ImageError}; + +/// messages defined by https://trax.readthedocs.io/en/latest/protocol.html#protocol-messages-and-states +pub enum TraxMessageFromServer { + Hello { + /// Specifies the supported version of the protocol. If not present, version 1 is assumed. + version: i32, + /// Specifies the name of the tracker. The name can be used by the client to verify that the correct algorithm is executed. + name: String, + /// Specifies the identifier of the current implementation. The identifier can be used to determine the version of the tracker. + identifier: String, + /// Specifies the supported image format. See Section Image formats for the list of supported formats. By default it is assumed that the tracker can accept file paths as image source. + image: ImageType, + /// Specifies the supported region format. See Section Region formats for the list of supported formats. By default it is assumed that the tracker can accept rectangles as region specification. + region: RegionType, + /// Specifies support for multi-modal images. See Section Image channels for more information + channels: Vec, + }, + State { + region: Region, + }, + Quit, +} +impl Display for TraxMessageFromServer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "@@TRAX:")?; + match self { + TraxMessageFromServer::Hello { + version, + name, + identifier, + image, + region, + channels, + } => { + let channels = channels + .iter() + .map(ToString::to_string) + .collect::>() + .join(","); + write!(f, "hello trax.version={version} trax.name={name} trax.identifier={identifier} trax.image={image} trax.region={region} trax.channels={channels}") + } + TraxMessageFromServer::State { region } => { + write!(f, "state {region}") + } + TraxMessageFromServer::Quit => todo!(), + } + } +} +#[derive(Debug)] +pub enum TraxMessageFromClient { + Initialize { image: Image, region: Region }, + Frame { images: Vec }, + Quit, +} + +impl FromStr for TraxMessageFromClient { + // we could probably use anyhow::Error or here + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let s = s.trim_end().strip_prefix("@@TRAX:").unwrap(); + let (type_, rest) = s.split_once(' ').unwrap(); + let res = match type_ { + "initialize" => { + // FIXME: + // * strip out quotes and whitespace properly (tempdir might have spaces in on windows?) + let (image, region) = rest.split_once(' ').unwrap(); + Self::Initialize { + image: Image::from_str(strip_quotes_from_ends(image)?)?, + region: Region::from_str(strip_quotes_from_ends(region)?)?, + } + } + "frame" => Self::Frame { + // FIXME: https://trax.readthedocs.io/en/latest/protocol.html#protocol-messages-and-states + // says "or multiple images", which is why I made it a Vec, but I'm not sure how this + // should be handled by the server (split it and treat it as if it were mutiple "frame" messages?) + // so it might be better to flatten out the Vec into a single Image. + images: vec![Image::from_str(strip_quotes_from_ends(rest)?)?], + }, + _ => anyhow::bail!("don't understand message: {s:?}"), + }; + Ok(res) + } +} + +// I feel like there should be something like this in the standard library somewhere, but this will do for now. +fn strip_quotes_from_ends(s: &str) -> anyhow::Result<&str> { + s.strip_prefix('"') + .ok_or(anyhow::anyhow!("no leading quote on {s:?}"))? + .strip_suffix('"') + .ok_or(anyhow::anyhow!("no trailing quote on {s:?}")) +} + +#[derive(Debug, PartialEq)] +pub enum ImageType { + /// In practice, we only plan to implement the `Path` image type in our server. + Path, + Memory, + Data, + Url, +} +impl Display for ImageType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + assert_eq!( + self, + &ImageType::Path, + "only `path` image type is supported for now", + ); + match self { + ImageType::Path => write!(f, "path"), + ImageType::Memory => write!(f, "memory"), + ImageType::Data => write!(f, "data"), + ImageType::Url => write!(f, "url"), + } + } +} + +// In practice, we only plan to implement the `Path` image type in our server, otherwise I would have made this an enum as well. +#[derive(Debug)] +pub struct Image { + pub path: PathBuf, +} +impl FromStr for Image { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + if let Some(rest) = s.strip_prefix("file://") { + let path = PathBuf::from(rest); + Ok(Self { path }) + } else { + anyhow::bail!("could not decode path from {s}") + } + } +} +impl Image { + pub fn open(&self) -> Result { + image::open(&self.path) + } +} + +/// In practice, we only plan to implement the `Rectangle` region type in our server. +pub enum RegionType { + Rectangle, + Polygon, +} +impl Display for RegionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RegionType::Rectangle => write!(f, "rectangle"), + RegionType::Polygon => write!(f, "polygon"), + } + } +} + +// In practice, we only plan to implement the `Rectangle` region type in our server, otherwise I would have made this an enum as well. +#[derive(Debug, Clone)] +pub struct Region { + pub x: f64, + pub y: f64, + pub width: f64, + pub height: f64, +} +impl FromStr for Region { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let [x, y, width, height]: [f64; 4] = s + .split(|c| c == ',' || c == '\t') + .map(|n| f64::from_str(n)) + .collect::, _>>()? + .try_into() + .map_err(|v| anyhow::anyhow!("{v:?} could not be coerced into a [f64; 4]"))?; + Ok(Self { + x, + y, + width, + height, + }) + } +} +impl Display for Region { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let Self { + x, + y, + width, + height, + } = self; + write!(f, "\"{x:.3},{y:.3},{width:.3},{height:.3}\"") + } +} + +/// In practice, we only plan to implement a single `Color` channel type in our server. +pub enum ChannelType { + Color, + Depth, + InfraRed, +} +impl Display for ChannelType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ChannelType::Color => write!(f, "color"), + ChannelType::Depth => write!(f, "depth"), + ChannelType::InfraRed => write!(f, "ir"), + } + } +} diff --git a/index.html b/index.html index f8360b9..eee8350 100644 --- a/index.html +++ b/index.html @@ -1,93 +1,138 @@ - - - - - Mosse Tracker - - - - -

Mosse Multitracker Example

-

Click on the image to track something.

- -
- -
- - - - + + + + + Multi-Object Tracking Debugger + + + + +

Multi-Object Tracking Debugger

+

Click on the video to track something.

+ +
+ + +
+ + + + diff --git a/multi-object-debugger.html b/multi-object-debugger.html new file mode 100644 index 0000000..77dd91f --- /dev/null +++ b/multi-object-debugger.html @@ -0,0 +1,85 @@ + + + + + Multi-Object Tracking Debugger + + + + + +
+

Multi-Object Tracking Debugger

+
+ +
+

Webcam Or Upload Video

+
+ + +
+ +
+ +
+
+ + +
+
+ + +
+ + +
+ +
+

Tracks Visualization

+ +
+
+
+ + diff --git a/src/lib.rs b/src/lib.rs index 7a1f078..f4cba32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,730 +1,732 @@ -extern crate image; -extern crate imageproc; -extern crate rustfft; - -use image::{imageops, GrayImage, ImageBuffer, Luma}; -use imageproc::geometric_transformations::Projection; -use imageproc::geometric_transformations::{rotate_about_center, warp, Interpolation}; -use rustfft::num_complex::Complex; -use rustfft::num_traits::Zero; -use rustfft::{Fft, FftPlanner}; -use std::cmp::Ordering; -use std::f32; -use std::fmt::Debug; -use std::sync::Arc; - -#[cfg(target_arch = "wasm32")] -pub mod wasm; - -// TODO: use constant declarations wherever possible -// TODO: refactor the unwrap statement into match statements wherever we can't be certain a result exists. -// TODO: behaviour at edge of frame: target may not leave frame, but filter will screw up anyway due to cropping. Move target coord freely within template? -// TODO: improve initial filter quality: additional affine perturbations, like scaling (zooming)? -// TODO: 11x11 window around peak for PSR calculation is arbitrary and seems biased towards larger video feeds? -// TODO: make k (number of perturbarions) a hyperparameter. k = 0 should not be allowed as it is senseless. -// TODO: FFT objects may be thread safe (Arc), but are they blocking during concurrent calls? See https://docs.rs/crate/rustfft/2.1.0/source/examples/concurrency.rs -// TODO: Double check: prevent division by zero (everywhere)? Or use div_checked? Inf is not acceptable!! - -// // OPTIMIZATIONS -// TODO: call add_target asynchronously to avoid blocking on the relatively long call to .train()? -// TODO: use stack-based variable length data types https://gist.github.com/jFransham/369a86eff00e5f280ed25121454acec1#use-stack-based-variable-length-datatypes -// TODO: something stack-allocated like arrayvec = "0.4.7"? -// TODO: training: preprocess is called for each perturbation. May be best to have preprocess return an image, or have it modify in place. -// TODO: carefully track data dependencies in predict function (but get a working version first!) -// TODO: in general: avoid .collect()'ing iterators where possible -// TODO: update routine can use more in-place modifications to reduce space complexity and allocs -// TODO: update routine: benchmark initialization of Gaussian peak on target coordinates. -// TODO: in general: remove allocating functions by reusing buffers where possible (such as self.prev's) - -fn preprocess(image: &GrayImage) -> Vec { - let mut prepped: Vec = image - .pixels() - // convert the pixel to u8 and then to f32 - .map(|p| p[0] as f32) - // add 1, and take the natural logarithm - .map(|p| (p + 1.0).ln()) - .collect(); - - // normalize to mean = 0 (subtract image-wide mean from each pixel) - let sum: f32 = prepped.iter().sum(); - let mean: f32 = sum / prepped.len() as f32; - prepped.iter_mut().for_each(|p| *p = *p - mean); - - // normalize to norm = 1, if possible - let u: f32 = prepped.iter().map(|a| a * a).sum(); - let norm = u.sqrt(); - if norm != 0.0 { - prepped.iter_mut().for_each(|e| *e = *e / norm) - } - - // multiply each pixel by a cosine window - let (width, height) = image.dimensions(); - let mut position = 0; - for i in 0..width { - for j in 0..height { - let cww = ((f32::consts::PI * i as f32) / (width - 1) as f32).sin(); - let cwh = ((f32::consts::PI * j as f32) / (height - 1) as f32).sin(); - prepped[position] = cww.min(cwh) * prepped[position]; - position += 1; - } - } - - return prepped; -} - -type Identifier = u32; - -#[derive(Debug)] -pub struct MultiMosseTracker { - // we also store the tracker's numeric ID, and the amount of times it did not make the PSR threshold. - trackers: Vec<(Identifier, u32, MosseTracker)>, - - // the global tracker settings - settings: MosseTrackerSettings, - - // how many times a tracker is allowed to fail the PSR threshold - desperation_level: u32, -} - -impl MultiMosseTracker { - pub fn new(settings: MosseTrackerSettings, desperation_level: u32) -> MultiMosseTracker { - return MultiMosseTracker { - trackers: Vec::new(), - settings: settings, - desperation_level: desperation_level, - }; - } - - pub fn add_or_replace_target(&mut self, id: Identifier, coords: (u32, u32), frame: &GrayImage) { - // Add a target by specifying its coords and a new ID. - // Specify an existing ID to replace an existing tracked target. - - // create a new tracker for this target and train it - let mut new_tracker = MosseTracker::new(&self.settings); - new_tracker.train(frame, coords); - - match self.trackers.iter_mut().find(|tracker| tracker.0 == id) { - Some(tuple) => { - tuple.1 = 0; - tuple.2 = new_tracker; - } - // add the tracker to the map - _ => self.trackers.push((id, 0, new_tracker)), - }; - } - - pub fn track(&mut self, frame: &GrayImage) -> Vec<(Identifier, Prediction)> { - let mut predictions: Vec<(Identifier, Prediction)> = Vec::new(); - for (id, death_watch, tracker) in &mut self.trackers { - // compute the location of the object in the new frame and save it - let pred = tracker.track_new_frame(frame); - predictions.push((*id, pred)); - - // if the tracker made the PSR threshold, update it. - // if not, we increment its death ticker. - if tracker.last_psr > self.settings.psr_threshold { - tracker.update(frame); - *death_watch = 0u32; - } else { - *death_watch += 1; - } - } - - // prune all filters with an expired death ticker - let level = &self.desperation_level; - self.trackers - .retain(|(_id, death_count, _tracker)| death_count < level); - - return predictions; - } - - pub fn dump_filter_reals(&self) -> Vec { - return self.trackers.iter().map(|t| t.2.dump_filter().0).collect(); - } - - pub fn size(&self) -> usize { - self.trackers.len() - } -} - -pub struct Prediction { - pub location: (u32, u32), - pub psr: f32, -} - -pub struct MosseTracker { - filter: Vec>, - - // constants frame height - frame_width: u32, - frame_height: u32, - - // stores dimensions of tracking window and its center - // window is square for now, this variable contains the size of the square edge - window_size: u32, - current_target_center: (u32, u32), // represents center in frame - - // the 'target' (G). A single Gaussian peak centered at the tracking window. - target: Vec>, - - // constants: learning rate and PSR threshold - eta: f32, - regularization: f32, // not super important for MOSSE: see paper fig 4. - - // the previous Ai and Bi - last_top: Vec>, - last_bottom: Vec>, - - // the previous psr - pub last_psr: f32, - - // thread-safe FFT objects containing precomputed parameters for this input data size. - fft: Arc>, - inv_fft: Arc>, -} - -impl Debug for MosseTracker { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MosseTracker") - .field("filter", &self.filter) - .field("frame_width", &self.frame_width) - .field("frame_height", &self.frame_height) - .field("window_size", &self.window_size) - .field("current_target_center", &self.current_target_center) - .field("target", &self.target) - .field("eta", &self.eta) - .field("regularization", &self.regularization) - .field("last_top", &self.last_top) - .field("last_bottom", &self.last_bottom) - .field("last_psr", &self.last_psr) - // These fields don't implement Debug, so I can't use the #[derive(Debug)] impl. - // .field("fft", &self.fft) - // .field("inv_fft", &self.inv_fft) - .finish() - } -} - -#[derive(Debug)] -pub struct MosseTrackerSettings { - pub width: u32, - pub height: u32, - pub window_size: u32, - pub learning_rate: f32, - pub psr_threshold: f32, - pub regularization: f32, -} - -#[allow(non_snake_case)] -impl MosseTracker { - pub fn new(settings: &MosseTrackerSettings) -> MosseTracker { - // parameterize the FFT objects - let mut planner = FftPlanner::new(); - let mut inv_planner = FftPlanner::new(); - - // NOTE: we initialize the FFTs based on the size of the window - let length = (settings.window_size * settings.window_size) as usize; - let fft = planner.plan_fft_forward(length); - let inv_fft = inv_planner.plan_fft_inverse(length); - - // initialize the filter and its top and bottom parts with zeroes. - let filter = vec![Complex::zero(); length]; - let top = vec![Complex::zero(); length]; - let bottom = vec![Complex::zero(); length]; - - // initialize the target output map (G), with a compact Gaussian peak centered on the target object. - // In the Bolme paper, this map is called gi. - let mut target: Vec> = - build_target(settings.window_size, settings.window_size) - .into_iter() - .map(|p| Complex::new(p as f32, 0.0)) - .collect(); - fft.process(&mut target); - - return MosseTracker { - filter, - last_top: top, - last_bottom: bottom, - last_psr: 0.0, - eta: settings.learning_rate, - regularization: settings.regularization, - target, - fft, - inv_fft, - frame_width: settings.width, - frame_height: settings.height, - window_size: settings.window_size, - current_target_center: (0, 0), - }; - } - - fn compute_2dfft(&self, imagedata: Vec) -> Vec> { - let mut buffer: Vec> = imagedata - .into_iter() - .map(|p| Complex::new(p as f32, 0.0)) - .collect(); - - // fft.process() CONSUMES the input buffer as scratch space, make sure it is not reused - self.fft.process(&mut buffer); - - return buffer; - } - - // Train a new filter on the first frame in which the object occurs - pub fn train(&mut self, input_frame: &GrayImage, target_center: (u32, u32)) { - // store the target center as the current - self.current_target_center = target_center; - - // cut out the training template by cropping - let window = &window_crop( - input_frame, - self.window_size, - self.window_size, - target_center, - ); - - #[cfg(debug_assertions)] - { - window.save("WINDOW.png").unwrap(); - } - - // build an iterator that produces training frames that have been slightly rotated according to a theta value. - let rotated_frames = [ - 0.02, -0.02, 0.05, -0.05, 0.07, -0.07, 0.09, -0.09, 1.1, -1.1, 1.3, -1.3, 1.5, -1.5, - 2.0, -2.0, - ] - .iter() - .map(|rad| { - // Rotate an image clockwise about its center by theta radians. - let training_frame = - rotate_about_center(window, *rad, Interpolation::Nearest, Luma([0])); - - #[cfg(debug_assertions)] - { - training_frame - .save(format!("training_frame_rotated_theta_{}.png", rad)) - .unwrap(); - } - - return training_frame; - }); - - // build an iterator that produces training frames that have been slightly scaled to various degrees ('zoomed') - let scaled_frames = [0.8, 0.9, 1.1, 1.2].into_iter().map(|scalefactor| { - let scale = Projection::scale(scalefactor, scalefactor); - - let scaled_training_frame = warp(&window, &scale, Interpolation::Nearest, Luma([0])); - - #[cfg(debug_assertions)] - { - scaled_training_frame - .save(format!("training_frame_scaled_{}.png", scalefactor)) - .unwrap(); - } - - return scaled_training_frame; - }); - - // Chain these iterators together. - // Note that we add the initial, unperturbed training frame as first in line. - let training_frames = std::iter::once(window) - .cloned() - .chain(rotated_frames) - .chain(scaled_frames); - // TODO: scaling is not ready yet - // .chain(scaled_frames); - - let mut training_frame_count = 0; - for training_frame in training_frames { - // preprocess the training frame using preprocess() - let vectorized = preprocess(&training_frame); - - // calculate the 2D FFT of the preprocessed frame: FFT(fi) = Fi - let Fi = self.compute_2dfft(vectorized); - - // compute the complex conjugate of Fi, Fi*. - let Fi_star: Vec> = Fi.iter().map(|e| e.conj()).collect(); - - // compute the initial filter - let top = self.target.iter().zip(Fi_star.iter()).map(|(g, f)| g * f); - let bottom = Fi.iter().zip(Fi_star.iter()).map(|(f, f_star)| f * f_star); - - // // add the values to the running sum - self.last_top - .iter_mut() - .zip(top) - .for_each(|(running, new)| *running += new); - - self.last_bottom - .iter_mut() - .zip(bottom) - .for_each(|(running, new)| *running += new); - - training_frame_count += 1 - } - - // divide the values of the top and bottom filters by the number of training perturbations used - self.last_top - .iter_mut() - .for_each(|e| *e /= training_frame_count as f32); - - self.last_bottom - .iter_mut() - .for_each(|e| *e /= training_frame_count as f32); - - // compute the filter by dividing Ai and Bi elementwise - // note that we add a small quantity to avoid dividing by zero, which would yield NaN's. - self.filter = self - .last_top - .iter() - .zip(&self.last_bottom) - .map(|(a, b)| a / b + self.regularization) - .collect(); - - #[cfg(debug_assertions)] - { - println!( - "current center of target in frame: x={}, y={}", - self.current_target_center.0, self.current_target_center.1 - ); - } - } - - pub fn track_new_frame(&mut self, frame: &GrayImage) -> Prediction { - // cut out the training template by cropping - let window = window_crop( - frame, - self.window_size, - self.window_size, - self.current_target_center, - ); - - // preprocess the image using preprocess() - let vectorized = preprocess(&window); - - // calculate the 2D FFT of the preprocessed image: FFT(fi) = Fi - let Fi = self.compute_2dfft(vectorized); - - // elementwise multiplication of F with filter H gives Gi - let mut corr_map_gi: Vec> = - Fi.iter().zip(&self.filter).map(|(a, b)| a * b).collect(); - - // NOTE: Gi is garbage after this call - self.inv_fft.process(&mut corr_map_gi); - - // find the max value of the filtered image 'gi', along with the position of the maximum - let (maxind, max_complex) = corr_map_gi - .iter() - .enumerate() - .max_by(|a, b| { - // filtered (gi) is still complex at this point, we only care about the real part - a.1.re.partial_cmp(&b.1.re).unwrap_or(Ordering::Equal) - }) - .unwrap(); // we can unwrap the result of max_by(), as we are sure filtered.len() > 0 - - // convert the array index of the max to the coordinates in the window - let max_coord_in_window = index_to_coords(self.window_size, maxind as u32); - - let window_half = (self.window_size / 2) as i32; - let x_delta = max_coord_in_window.0 as i32 - window_half; - let y_delta = max_coord_in_window.1 as i32 - window_half; - let x_max = self.frame_width as i32 - window_half; - let y_max = self.frame_height as i32 - window_half; - - #[cfg(debug_assertions)] - { - println!( - "distance of new in-window max from window center: x = {}, y = {}", - x_delta, y_delta, - ); - } - - // compute the max coord in the frame by looking at the shift of the window center - let new_x = (self.current_target_center.0 as i32 + x_delta) - .min(x_max) - .max(window_half); - - let new_y = (self.current_target_center.1 as i32 + y_delta) - .min(y_max) - .max(window_half); - - self.current_target_center = (new_x as u32, new_y as u32); - - // compute PSR - // Note that we re-use the computed max and its coordinate for downstream simplicity - self.last_psr = compute_psr( - &corr_map_gi, - self.window_size, - self.window_size, - max_complex.re, - max_coord_in_window, - ); - - return Prediction { - location: self.current_target_center, - psr: self.last_psr, - }; - } - - // update the filter - fn update(&mut self, frame: &GrayImage) { - // cut out the training template by cropping - let window = window_crop( - frame, - self.window_size, - self.window_size, - self.current_target_center, - ); - - // preprocess the image using preprocess() - let vectorized = preprocess(&window); - - // calculate the 2D FFT of the preprocessed image: FFT(fi) = Fi - let new_Fi = self.compute_2dfft(vectorized); - - //// Update the filter using the prediction - // compute the complex conjugate of Fi, Fi*. - let Fi_star: Vec> = new_Fi.iter().map(|e| e.conj()).collect(); - - // compute Ai (top) and Bi (bottom) using F*, G, and the learning rate (see paper) - let one_minus_eta = 1.0 - self.eta; - - // update the 'top' of the filter update equation - self.last_top = self - .target - .iter() - .zip(&Fi_star) - .zip(&self.last_top) - .map(|((g, f), prev)| self.eta * (g * f) + (one_minus_eta * prev)) - .collect(); - - // update the 'bottom' of the filter update equation - self.last_bottom = new_Fi - .iter() - .zip(&Fi_star) - .zip(&self.last_bottom) - .map(|((f, f_star), prev)| self.eta * (f * f_star) + (one_minus_eta * prev)) - .collect(); - - // compute the new filter H* by dividing Ai and Bi elementwise - self.filter = self - .last_top - .iter() - .zip(&self.last_bottom) - .map(|(a, b)| a / b) - .collect(); - } - - // debug method to dump the latest filter to an inspectable image - pub fn dump_filter( - &self, - ) -> ( - ImageBuffer, Vec>, - ImageBuffer, Vec>, - ) { - // get the filter out of fourier space - // NOTE: input is garbage after this call to inv_fft.process(), so we clone the filter first. - let mut h = self.filter.clone(); - self.inv_fft.process(&mut h); - - // turn the real and imaginary values of the filter into separate grayscale images - let realfilter = h.iter().map(|c| c.re).collect(); - let imfilter = h.iter().map(|c| c.im).collect(); - - return ( - to_imgbuf(&realfilter, self.window_size, self.window_size), - to_imgbuf(&imfilter, self.window_size, self.window_size), - ); - } -} - -fn window_crop( - input_frame: &GrayImage, - window_width: u32, - window_height: u32, - center: (u32, u32), -) -> GrayImage { - let window = imageops::crop( - &mut input_frame.clone(), - center - .0 - .saturating_sub(window_width / 2) - .min(input_frame.width() - window_width), - center - .1 - .saturating_sub(window_height / 2) - .min(input_frame.height() - window_height), - window_width, - window_height, - ) - .to_image(); - - return window; -} - -fn build_target(window_width: u32, window_height: u32) -> Vec { - let mut target_gi = vec![0f32; (window_width * window_height) as usize]; - - // Optional: let the sigma depend on the window size (Galoogahi et al. (2015). Correlation Filters with Limited Boundaries) - // let sigma = ((window_width * window_height) as f32).sqrt() / 16.0; - // let variance = sigma * sigma; - let variance = 2.0; - - // create gaussian peak at the center coordinates - let center_x = window_width / 2; - let center_y = window_height / 2; - for x in 0..window_width { - for y in 0..window_height { - let distx: f32 = x as f32 - center_x as f32; - let disty: f32 = y as f32 - center_y as f32; - - // apply a crude univariate Gaussian density function - target_gi[((y * window_width) + x) as usize] = - (-((distx * distx) + (disty * disty) / variance)).exp() - } - } - - return target_gi; -} - -// function for debugging the shape of the target -// output only depends on the provided target_coords -pub fn dump_target(window_width: u32, window_height: u32) -> ImageBuffer, Vec> { - let trgt = build_target(window_width, window_height); - - let normalized = trgt.iter().map(|a| a * 255.0).collect(); - - return to_imgbuf(&normalized, window_width, window_height); -} - -fn compute_psr( - predicted: &Vec>, - width: u32, - height: u32, - max: f32, - maxpos: (u32, u32), -) -> f32 { - // uses running updates of standard deviation and mean - let mut running_sum = 0.0; - let mut running_sd = 0.0; - for e in predicted { - running_sum += e.re; - running_sd += e.re * e.re; - } - - // subtract the values of a 11*11 window around the max from the running sd and sum - // TODO: look up: why 11*11, and not something simpler like 12*12? - let max_x = maxpos.0 as i32; - let max_y = maxpos.1 as i32; - let window_left = (max_x - 5).max(0); - let window_right = (max_x + 6).min(width as i32); - let window_top = (max_y - 5).min(0); // note: named according to CG conventions - let window_bottom = (max_y + 6).min(height as i32); - for x in window_left..window_right { - for y in window_bottom..window_top { - let ind = (y * width as i32 + x) as usize; - let val = predicted[ind].re; - running_sd -= val * val; - running_sum -= val; - } - } - - // we need to subtract 11*11 window from predicted.len() to get the sidelobe_size - let sidelobe_size = (predicted.len() - (11 * 11)) as f32; - let mean_sl = running_sum / sidelobe_size; - let sd_sl = ((running_sd / sidelobe_size) - (mean_sl * mean_sl)).sqrt(); - let psr = (max - mean_sl) / sd_sl; - - return psr; -} - -fn index_to_coords(width: u32, index: u32) -> (u32, u32) { - // modulo/remainder ops are theoretically O(1) - // checked_rem returns None if rhs == 0, which would indicate an upstream error (width == 0). - let x = index.checked_rem(width).unwrap(); - - // checked sub returns None if overflow occurred, which is also a panicable offense. - // checked_div returns None if rhs == 0, which would indicate an upstream error (width == 0). - let y = (index.checked_sub(x).unwrap()).checked_div(width).unwrap(); - return (x, y); -} - -pub fn to_imgbuf(buf: &Vec, width: u32, height: u32) -> ImageBuffer, Vec> { - ImageBuffer::from_vec(width, height, buf.iter().map(|c| *c as u8).collect()).unwrap() -} - -// TODO: below tests are used as a scratch pad and for syntax experiments, not serious unit testing. -#[cfg(test)] -mod tests { - - use super::*; - - #[test] - fn sanity_test_max_by() { - let filtered: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; - let maxel = filtered - .iter() - .enumerate() - .max_by(|a, b| { - // filtered (gi) is still complex at this point, we only care about the real part - a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal) - }) - .unwrap(); - assert_eq!(maxel, (4usize, &5.0f32)); - } - - #[test] - fn am_i_still_sane() { - assert_eq!( - Complex::new(1.0, -3.0) * Complex::new(2.0, 5.0), - Complex::new(17.0, -1.0) - ); - } - - #[test] - fn unique_identifier() { - let width = 64; - let height = 64; - let frame = GrayImage::new(width, height); - let settings = MosseTrackerSettings { - window_size: 16, - width, - height, - regularization: 0.001, - learning_rate: 0.05, - psr_threshold: 7.0, - }; - let mut multi_tracker = MultiMosseTracker::new(settings, 3); - assert_eq!(multi_tracker.size(), 0); - multi_tracker.add_or_replace_target(0, (0, 0), &frame); - - assert_eq!(multi_tracker.size(), 1); - assert_eq!( - multi_tracker - .trackers - .iter() - .find(|t| t.0 == 0) - .unwrap() - .2 - .current_target_center, - (0, 0) - ); - - multi_tracker.add_or_replace_target(1, (10, 0), &frame); - - assert_eq!(multi_tracker.size(), 2); - - multi_tracker.add_or_replace_target(0, (10, 0), &frame); - - assert_eq!(multi_tracker.size(), 2); - assert_eq!( - multi_tracker - .trackers - .iter() - .find(|t| t.0 == 0) - .unwrap() - .2 - .current_target_center, - (10, 0) - ); - } -} +extern crate image; +extern crate imageproc; +extern crate rustfft; + +use image::{imageops, GrayImage, ImageBuffer, Luma}; +use imageproc::geometric_transformations::Projection; +use imageproc::geometric_transformations::{rotate_about_center, warp, Interpolation}; +use rustfft::num_complex::Complex; +use rustfft::num_traits::Zero; +use rustfft::{Fft, FftPlanner}; +use std::cmp::Ordering; +use std::f32; +use std::fmt::Debug; +use std::sync::Arc; +use serde::{Serialize, Deserialize}; + +#[cfg(target_arch = "wasm32")] +pub mod wasm; + +// TODO: use constant declarations wherever possible +// TODO: refactor the unwrap statement into match statements wherever we can't be certain a result exists. +// TODO: behaviour at edge of frame: target may not leave frame, but filter will screw up anyway due to cropping. Move target coord freely within template? +// TODO: improve initial filter quality: additional affine perturbations, like scaling (zooming)? +// TODO: 11x11 window around peak for PSR calculation is arbitrary and seems biased towards larger video feeds? +// TODO: make k (number of perturbarions) a hyperparameter. k = 0 should not be allowed as it is senseless. +// TODO: FFT objects may be thread safe (Arc), but are they blocking during concurrent calls? See https://docs.rs/crate/rustfft/2.1.0/source/examples/concurrency.rs +// TODO: Double check: prevent division by zero (everywhere)? Or use div_checked? Inf is not acceptable!! + +// // OPTIMIZATIONS +// TODO: call add_target asynchronously to avoid blocking on the relatively long call to .train()? +// TODO: use stack-based variable length data types https://gist.github.com/jFransham/369a86eff00e5f280ed25121454acec1#use-stack-based-variable-length-datatypes +// TODO: something stack-allocated like arrayvec = "0.4.7"? +// TODO: training: preprocess is called for each perturbation. May be best to have preprocess return an image, or have it modify in place. +// TODO: carefully track data dependencies in predict function (but get a working version first!) +// TODO: in general: avoid .collect()'ing iterators where possible +// TODO: update routine can use more in-place modifications to reduce space complexity and allocs +// TODO: update routine: benchmark initialization of Gaussian peak on target coordinates. +// TODO: in general: remove allocating functions by reusing buffers where possible (such as self.prev's) + +fn preprocess(image: &GrayImage) -> Vec { + let mut prepped: Vec = image + .pixels() + // convert the pixel to u8 and then to f32 + .map(|p| p[0] as f32) + // add 1, and take the natural logarithm + .map(|p| (p + 1.0).ln()) + .collect(); + + // normalize to mean = 0 (subtract image-wide mean from each pixel) + let sum: f32 = prepped.iter().sum(); + let mean: f32 = sum / prepped.len() as f32; + prepped.iter_mut().for_each(|p| *p = *p - mean); + + // normalize to norm = 1, if possible + let u: f32 = prepped.iter().map(|a| a * a).sum(); + let norm = u.sqrt(); + if norm != 0.0 { + prepped.iter_mut().for_each(|e| *e = *e / norm) + } + + // multiply each pixel by a cosine window + let (width, height) = image.dimensions(); + let mut position = 0; + for i in 0..width { + for j in 0..height { + let cww = ((f32::consts::PI * i as f32) / (width - 1) as f32).sin(); + let cwh = ((f32::consts::PI * j as f32) / (height - 1) as f32).sin(); + prepped[position] = cww.min(cwh) * prepped[position]; + position += 1; + } + } + + return prepped; +} + +pub type Identifier = u32; + +#[derive(Debug)] +pub struct MultiMosseTracker { + // we also store the tracker's numeric ID, and the amount of times it did not make the PSR threshold. + trackers: Vec<(Identifier, u32, MosseTracker)>, + + // the global tracker settings + settings: MosseTrackerSettings, + + // how many times a tracker is allowed to fail the PSR threshold + desperation_level: u32, +} + +impl MultiMosseTracker { + pub fn new(settings: MosseTrackerSettings, desperation_level: u32) -> MultiMosseTracker { + return MultiMosseTracker { + trackers: Vec::new(), + settings: settings, + desperation_level: desperation_level, + }; + } + + pub fn add_or_replace_target(&mut self, id: Identifier, coords: (u32, u32), frame: &GrayImage) { + // Add a target by specifying its coords and a new ID. + // Specify an existing ID to replace an existing tracked target. + + // create a new tracker for this target and train it + let mut new_tracker = MosseTracker::new(&self.settings); + new_tracker.train(frame, coords); + + match self.trackers.iter_mut().find(|tracker| tracker.0 == id) { + Some(tuple) => { + tuple.1 = 0; + tuple.2 = new_tracker; + } + // add the tracker to the map + _ => self.trackers.push((id, 0, new_tracker)), + }; + } + + pub fn track(&mut self, frame: &GrayImage) -> Vec<(Identifier, Prediction)> { + let mut predictions: Vec<(Identifier, Prediction)> = Vec::new(); + for (id, death_watch, tracker) in &mut self.trackers { + // compute the location of the object in the new frame and save it + let pred = tracker.track_new_frame(frame); + predictions.push((*id, pred)); + + // if the tracker made the PSR threshold, update it. + // if not, we increment its death ticker. + if tracker.last_psr > self.settings.psr_threshold { + tracker.update(frame); + *death_watch = 0u32; + } else { + *death_watch += 1; + } + } + + // prune all filters with an expired death ticker + let level = &self.desperation_level; + self.trackers + .retain(|(_id, death_count, _tracker)| death_count < level); + + return predictions; + } + + pub fn dump_filter_reals(&self) -> Vec { + return self.trackers.iter().map(|t| t.2.dump_filter().0).collect(); + } + + pub fn size(&self) -> usize { + self.trackers.len() + } +} + +#[derive(Serialize, Deserialize)] +pub struct Prediction { + pub location: (u32, u32), + pub psr: f32, +} + +pub struct MosseTracker { + filter: Vec>, + + // constants frame height + frame_width: u32, + frame_height: u32, + + // stores dimensions of tracking window and its center + // window is square for now, this variable contains the size of the square edge + window_size: u32, + current_target_center: (u32, u32), // represents center in frame + + // the 'target' (G). A single Gaussian peak centered at the tracking window. + target: Vec>, + + // constants: learning rate and PSR threshold + eta: f32, + regularization: f32, // not super important for MOSSE: see paper fig 4. + + // the previous Ai and Bi + last_top: Vec>, + last_bottom: Vec>, + + // the previous psr + pub last_psr: f32, + + // thread-safe FFT objects containing precomputed parameters for this input data size. + fft: Arc>, + inv_fft: Arc>, +} + +impl Debug for MosseTracker { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MosseTracker") + .field("filter", &self.filter) + .field("frame_width", &self.frame_width) + .field("frame_height", &self.frame_height) + .field("window_size", &self.window_size) + .field("current_target_center", &self.current_target_center) + .field("target", &self.target) + .field("eta", &self.eta) + .field("regularization", &self.regularization) + .field("last_top", &self.last_top) + .field("last_bottom", &self.last_bottom) + .field("last_psr", &self.last_psr) + // These fields don't implement Debug, so I can't use the #[derive(Debug)] impl. + // .field("fft", &self.fft) + // .field("inv_fft", &self.inv_fft) + .finish() + } +} + +#[derive(Debug)] +pub struct MosseTrackerSettings { + pub width: u32, + pub height: u32, + pub window_size: u32, + pub learning_rate: f32, + pub psr_threshold: f32, + pub regularization: f32, +} + +#[allow(non_snake_case)] +impl MosseTracker { + pub fn new(settings: &MosseTrackerSettings) -> MosseTracker { + // parameterize the FFT objects + let mut planner = FftPlanner::new(); + let mut inv_planner = FftPlanner::new(); + + // NOTE: we initialize the FFTs based on the size of the window + let length = (settings.window_size * settings.window_size) as usize; + let fft = planner.plan_fft_forward(length); + let inv_fft = inv_planner.plan_fft_inverse(length); + + // initialize the filter and its top and bottom parts with zeroes. + let filter = vec![Complex::zero(); length]; + let top = vec![Complex::zero(); length]; + let bottom = vec![Complex::zero(); length]; + + // initialize the target output map (G), with a compact Gaussian peak centered on the target object. + // In the Bolme paper, this map is called gi. + let mut target: Vec> = + build_target(settings.window_size, settings.window_size) + .into_iter() + .map(|p| Complex::new(p as f32, 0.0)) + .collect(); + fft.process(&mut target); + + return MosseTracker { + filter, + last_top: top, + last_bottom: bottom, + last_psr: 0.0, + eta: settings.learning_rate, + regularization: settings.regularization, + target, + fft, + inv_fft, + frame_width: settings.width, + frame_height: settings.height, + window_size: settings.window_size, + current_target_center: (0, 0), + }; + } + + fn compute_2dfft(&self, imagedata: Vec) -> Vec> { + let mut buffer: Vec> = imagedata + .into_iter() + .map(|p| Complex::new(p as f32, 0.0)) + .collect(); + + // fft.process() CONSUMES the input buffer as scratch space, make sure it is not reused + self.fft.process(&mut buffer); + + return buffer; + } + + // Train a new filter on the first frame in which the object occurs + pub fn train(&mut self, input_frame: &GrayImage, target_center: (u32, u32)) { + // store the target center as the current + self.current_target_center = target_center; + + // cut out the training template by cropping + let window = &window_crop( + input_frame, + self.window_size, + self.window_size, + target_center, + ); + + #[cfg(debug_assertions)] + { + window.save("WINDOW.png").unwrap(); + } + + // build an iterator that produces training frames that have been slightly rotated according to a theta value. + let rotated_frames = [ + 0.02, -0.02, 0.05, -0.05, 0.07, -0.07, 0.09, -0.09, 1.1, -1.1, 1.3, -1.3, 1.5, -1.5, + 2.0, -2.0, + ] + .iter() + .map(|rad| { + // Rotate an image clockwise about its center by theta radians. + let training_frame = + rotate_about_center(window, *rad, Interpolation::Nearest, Luma([0])); + + #[cfg(debug_assertions)] + { + training_frame + .save(format!("training_frame_rotated_theta_{}.png", rad)) + .unwrap(); + } + + return training_frame; + }); + + // build an iterator that produces training frames that have been slightly scaled to various degrees ('zoomed') + let scaled_frames = [0.8, 0.9, 1.1, 1.2].into_iter().map(|scalefactor| { + let scale = Projection::scale(scalefactor, scalefactor); + + let scaled_training_frame = warp(&window, &scale, Interpolation::Nearest, Luma([0])); + + #[cfg(debug_assertions)] + { + scaled_training_frame + .save(format!("training_frame_scaled_{}.png", scalefactor)) + .unwrap(); + } + + return scaled_training_frame; + }); + + // Chain these iterators together. + // Note that we add the initial, unperturbed training frame as first in line. + let training_frames = std::iter::once(window) + .cloned() + .chain(rotated_frames) + .chain(scaled_frames); + // TODO: scaling is not ready yet + // .chain(scaled_frames); + + let mut training_frame_count = 0; + for training_frame in training_frames { + // preprocess the training frame using preprocess() + let vectorized = preprocess(&training_frame); + + // calculate the 2D FFT of the preprocessed frame: FFT(fi) = Fi + let Fi = self.compute_2dfft(vectorized); + + // compute the complex conjugate of Fi, Fi*. + let Fi_star: Vec> = Fi.iter().map(|e| e.conj()).collect(); + + // compute the initial filter + let top = self.target.iter().zip(Fi_star.iter()).map(|(g, f)| g * f); + let bottom = Fi.iter().zip(Fi_star.iter()).map(|(f, f_star)| f * f_star); + + // // add the values to the running sum + self.last_top + .iter_mut() + .zip(top) + .for_each(|(running, new)| *running += new); + + self.last_bottom + .iter_mut() + .zip(bottom) + .for_each(|(running, new)| *running += new); + + training_frame_count += 1 + } + + // divide the values of the top and bottom filters by the number of training perturbations used + self.last_top + .iter_mut() + .for_each(|e| *e /= training_frame_count as f32); + + self.last_bottom + .iter_mut() + .for_each(|e| *e /= training_frame_count as f32); + + // compute the filter by dividing Ai and Bi elementwise + // note that we add a small quantity to avoid dividing by zero, which would yield NaN's. + self.filter = self + .last_top + .iter() + .zip(&self.last_bottom) + .map(|(a, b)| a / b + self.regularization) + .collect(); + + #[cfg(debug_assertions)] + { + println!( + "current center of target in frame: x={}, y={}", + self.current_target_center.0, self.current_target_center.1 + ); + } + } + + pub fn track_new_frame(&mut self, frame: &GrayImage) -> Prediction { + // cut out the training template by cropping + let window = window_crop( + frame, + self.window_size, + self.window_size, + self.current_target_center, + ); + + // preprocess the image using preprocess() + let vectorized = preprocess(&window); + + // calculate the 2D FFT of the preprocessed image: FFT(fi) = Fi + let Fi = self.compute_2dfft(vectorized); + + // elementwise multiplication of F with filter H gives Gi + let mut corr_map_gi: Vec> = + Fi.iter().zip(&self.filter).map(|(a, b)| a * b).collect(); + + // NOTE: Gi is garbage after this call + self.inv_fft.process(&mut corr_map_gi); + + // find the max value of the filtered image 'gi', along with the position of the maximum + let (maxind, max_complex) = corr_map_gi + .iter() + .enumerate() + .max_by(|a, b| { + // filtered (gi) is still complex at this point, we only care about the real part + a.1.re.partial_cmp(&b.1.re).unwrap_or(Ordering::Equal) + }) + .unwrap(); // we can unwrap the result of max_by(), as we are sure filtered.len() > 0 + + // convert the array index of the max to the coordinates in the window + let max_coord_in_window = index_to_coords(self.window_size, maxind as u32); + + let window_half = (self.window_size / 2) as i32; + let x_delta = max_coord_in_window.0 as i32 - window_half; + let y_delta = max_coord_in_window.1 as i32 - window_half; + let x_max = self.frame_width as i32 - window_half; + let y_max = self.frame_height as i32 - window_half; + + #[cfg(debug_assertions)] + { + println!( + "distance of new in-window max from window center: x = {}, y = {}", + x_delta, y_delta, + ); + } + + // compute the max coord in the frame by looking at the shift of the window center + let new_x = (self.current_target_center.0 as i32 + x_delta) + .min(x_max) + .max(window_half); + + let new_y = (self.current_target_center.1 as i32 + y_delta) + .min(y_max) + .max(window_half); + + self.current_target_center = (new_x as u32, new_y as u32); + + // compute PSR + // Note that we re-use the computed max and its coordinate for downstream simplicity + self.last_psr = compute_psr( + &corr_map_gi, + self.window_size, + self.window_size, + max_complex.re, + max_coord_in_window, + ); + + return Prediction { + location: self.current_target_center, + psr: self.last_psr, + }; + } + + // update the filter + fn update(&mut self, frame: &GrayImage) { + // cut out the training template by cropping + let window = window_crop( + frame, + self.window_size, + self.window_size, + self.current_target_center, + ); + + // preprocess the image using preprocess() + let vectorized = preprocess(&window); + + // calculate the 2D FFT of the preprocessed image: FFT(fi) = Fi + let new_Fi = self.compute_2dfft(vectorized); + + //// Update the filter using the prediction + // compute the complex conjugate of Fi, Fi*. + let Fi_star: Vec> = new_Fi.iter().map(|e| e.conj()).collect(); + + // compute Ai (top) and Bi (bottom) using F*, G, and the learning rate (see paper) + let one_minus_eta = 1.0 - self.eta; + + // update the 'top' of the filter update equation + self.last_top = self + .target + .iter() + .zip(&Fi_star) + .zip(&self.last_top) + .map(|((g, f), prev)| self.eta * (g * f) + (one_minus_eta * prev)) + .collect(); + + // update the 'bottom' of the filter update equation + self.last_bottom = new_Fi + .iter() + .zip(&Fi_star) + .zip(&self.last_bottom) + .map(|((f, f_star), prev)| self.eta * (f * f_star) + (one_minus_eta * prev)) + .collect(); + + // compute the new filter H* by dividing Ai and Bi elementwise + self.filter = self + .last_top + .iter() + .zip(&self.last_bottom) + .map(|(a, b)| a / b) + .collect(); + } + + // debug method to dump the latest filter to an inspectable image + pub fn dump_filter( + &self, + ) -> ( + ImageBuffer, Vec>, + ImageBuffer, Vec>, + ) { + // get the filter out of fourier space + // NOTE: input is garbage after this call to inv_fft.process(), so we clone the filter first. + let mut h = self.filter.clone(); + self.inv_fft.process(&mut h); + + // turn the real and imaginary values of the filter into separate grayscale images + let realfilter = h.iter().map(|c| c.re).collect(); + let imfilter = h.iter().map(|c| c.im).collect(); + + return ( + to_imgbuf(&realfilter, self.window_size, self.window_size), + to_imgbuf(&imfilter, self.window_size, self.window_size), + ); + } +} + +fn window_crop( + input_frame: &GrayImage, + window_width: u32, + window_height: u32, + center: (u32, u32), +) -> GrayImage { + let window = imageops::crop( + &mut input_frame.clone(), + center + .0 + .saturating_sub(window_width / 2) + .min(input_frame.width() - window_width), + center + .1 + .saturating_sub(window_height / 2) + .min(input_frame.height() - window_height), + window_width, + window_height, + ) + .to_image(); + + return window; +} + +fn build_target(window_width: u32, window_height: u32) -> Vec { + let mut target_gi = vec![0f32; (window_width * window_height) as usize]; + + // Optional: let the sigma depend on the window size (Galoogahi et al. (2015). Correlation Filters with Limited Boundaries) + // let sigma = ((window_width * window_height) as f32).sqrt() / 16.0; + // let variance = sigma * sigma; + let variance = 2.0; + + // create gaussian peak at the center coordinates + let center_x = window_width / 2; + let center_y = window_height / 2; + for x in 0..window_width { + for y in 0..window_height { + let distx: f32 = x as f32 - center_x as f32; + let disty: f32 = y as f32 - center_y as f32; + + // apply a crude univariate Gaussian density function + target_gi[((y * window_width) + x) as usize] = + (-((distx * distx) + (disty * disty) / variance)).exp() + } + } + + return target_gi; +} + +// function for debugging the shape of the target +// output only depends on the provided target_coords +pub fn dump_target(window_width: u32, window_height: u32) -> ImageBuffer, Vec> { + let trgt = build_target(window_width, window_height); + + let normalized = trgt.iter().map(|a| a * 255.0).collect(); + + return to_imgbuf(&normalized, window_width, window_height); +} + +fn compute_psr( + predicted: &Vec>, + width: u32, + height: u32, + max: f32, + maxpos: (u32, u32), +) -> f32 { + // uses running updates of standard deviation and mean + let mut running_sum = 0.0; + let mut running_sd = 0.0; + for e in predicted { + running_sum += e.re; + running_sd += e.re * e.re; + } + + // subtract the values of a 11*11 window around the max from the running sd and sum + // TODO: look up: why 11*11, and not something simpler like 12*12? + let max_x = maxpos.0 as i32; + let max_y = maxpos.1 as i32; + let window_left = (max_x - 5).max(0); + let window_right = (max_x + 6).min(width as i32); + let window_top = (max_y - 5).min(0); // note: named according to CG conventions + let window_bottom = (max_y + 6).min(height as i32); + for x in window_left..window_right { + for y in window_bottom..window_top { + let ind = (y * width as i32 + x) as usize; + let val = predicted[ind].re; + running_sd -= val * val; + running_sum -= val; + } + } + + // we need to subtract 11*11 window from predicted.len() to get the sidelobe_size + let sidelobe_size = (predicted.len() - (11 * 11)) as f32; + let mean_sl = running_sum / sidelobe_size; + let sd_sl = ((running_sd / sidelobe_size) - (mean_sl * mean_sl)).sqrt(); + let psr = (max - mean_sl) / sd_sl; + + return psr; +} + +fn index_to_coords(width: u32, index: u32) -> (u32, u32) { + // modulo/remainder ops are theoretically O(1) + // checked_rem returns None if rhs == 0, which would indicate an upstream error (width == 0). + let x = index.checked_rem(width).unwrap(); + + // checked sub returns None if overflow occurred, which is also a panicable offense. + // checked_div returns None if rhs == 0, which would indicate an upstream error (width == 0). + let y = (index.checked_sub(x).unwrap()).checked_div(width).unwrap(); + return (x, y); +} + +pub fn to_imgbuf(buf: &Vec, width: u32, height: u32) -> ImageBuffer, Vec> { + ImageBuffer::from_vec(width, height, buf.iter().map(|c| *c as u8).collect()).unwrap() +} + +// TODO: below tests are used as a scratch pad and for syntax experiments, not serious unit testing. +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn sanity_test_max_by() { + let filtered: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let maxel = filtered + .iter() + .enumerate() + .max_by(|a, b| { + // filtered (gi) is still complex at this point, we only care about the real part + a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal) + }) + .unwrap(); + assert_eq!(maxel, (4usize, &5.0f32)); + } + + #[test] + fn am_i_still_sane() { + assert_eq!( + Complex::new(1.0, -3.0) * Complex::new(2.0, 5.0), + Complex::new(17.0, -1.0) + ); + } + + #[test] + fn unique_identifier() { + let width = 64; + let height = 64; + let frame = GrayImage::new(width, height); + let settings = MosseTrackerSettings { + window_size: 16, + width, + height, + regularization: 0.001, + learning_rate: 0.05, + psr_threshold: 7.0, + }; + let mut multi_tracker = MultiMosseTracker::new(settings, 3); + assert_eq!(multi_tracker.size(), 0); + multi_tracker.add_or_replace_target(0, (0, 0), &frame); + + assert_eq!(multi_tracker.size(), 1); + assert_eq!( + multi_tracker + .trackers + .iter() + .find(|t| t.0 == 0) + .unwrap() + .2 + .current_target_center, + (0, 0) + ); + + multi_tracker.add_or_replace_target(1, (10, 0), &frame); + + assert_eq!(multi_tracker.size(), 2); + + multi_tracker.add_or_replace_target(0, (10, 0), &frame); + + assert_eq!(multi_tracker.size(), 2); + assert_eq!( + multi_tracker + .trackers + .iter() + .find(|t| t.0 == 0) + .unwrap() + .2 + .current_target_center, + (10, 0) + ); + } +} diff --git a/src/wasm.rs b/src/wasm.rs index 8f72ef0..c740ec7 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -1,105 +1,119 @@ -use crate::{MosseTrackerSettings, MultiMosseTracker}; -use image::Rgba; -use imageproc::drawing::{draw_cross_mut, draw_hollow_rect_mut, draw_text_mut}; -use imageproc::rect::Rect; -use rusttype::{Font, Scale}; -use wasm_bindgen::prelude::*; - -#[wasm_bindgen] -pub struct MultiMosseTrackerJS { - tracker: MultiMosseTracker, -} - -#[wasm_bindgen] -impl MultiMosseTrackerJS { - #[wasm_bindgen(constructor)] - pub fn new(width: u32, height: u32) -> Self { - let window_size = 48; - let psr_threshold = 7.0; - let settings = MosseTrackerSettings { - window_size, - width, - height, - regularization: 0.002, - learning_rate: 0.05, - psr_threshold, - }; - let desperation_threshold = 4; - let multi_tracker = MultiMosseTracker::new(settings, desperation_threshold); - Self { - tracker: multi_tracker, - } - } - - #[wasm_bindgen] - pub fn set_target(&mut self, x: u32, y: u32, img_data: &[u8]) { - let img = image::load_from_memory_with_format(img_data, image::ImageFormat::Png).unwrap(); - self.tracker - .add_or_replace_target(1, (x, y), &img.to_luma8()); - } - - #[wasm_bindgen] - pub fn track(&mut self, img_data: &[u8]) -> Vec { - let image = image::load_from_memory_with_format(img_data, image::ImageFormat::Png).unwrap(); - let predictions = self.tracker.track(&image.to_luma8()); - let mut img_copy = image.to_rgba8(); - for (obj_id, pred) in predictions.iter() { - let mut color = Rgba([125u8, 255u8, 0u8, 0u8]); - if pred.psr < self.tracker.settings.psr_threshold { - color = Rgba([255u8, 0u8, 0u8, 0u8]) - } - draw_cross_mut( - &mut img_copy, - Rgba([255u8, 0u8, 0u8, 0u8]), - pred.location.0 as i32, - pred.location.1 as i32, - ); - let window_size = self.tracker.settings.window_size; - draw_hollow_rect_mut( - &mut img_copy, - Rect::at( - pred.location.0.saturating_sub(window_size / 2) as i32, - pred.location.1.saturating_sub(window_size / 2) as i32, - ) - .of_size(window_size, window_size), - color, - ); - - let font_data = include_bytes!("../examples/Arial.ttf"); - let font = Font::try_from_bytes(font_data as &[u8]).unwrap(); - - const FONT_SCALE: f32 = 10.0; - - draw_text_mut( - &mut img_copy, - Rgba([125u8, 255u8, 0u8, 0u8]), - (pred.location.0 - (window_size / 2)).try_into().unwrap(), - (pred.location.1 - (window_size / 2)).try_into().unwrap(), - Scale::uniform(FONT_SCALE), - &font, - &format!("#{}", obj_id), - ); - - draw_text_mut( - &mut img_copy, - color, - (pred.location.0 - (window_size / 2)).try_into().unwrap(), - (pred.location.1 - (window_size / 2) + FONT_SCALE as u32) - .try_into() - .unwrap(), - Scale::uniform(FONT_SCALE), - &font, - &format!("PSR: {:.2}", pred.psr), - ); - } - - let mut image_data: Vec = Vec::new(); - img_copy - .write_to( - &mut std::io::Cursor::new(&mut image_data), - image::ImageFormat::Png, - ) - .unwrap(); - image_data - } -} +use crate::{MosseTrackerSettings, MultiMosseTracker, Identifier}; +// use image::Rgba; +// use imageproc::drawing::{draw_cross_mut, draw_hollow_rect_mut, draw_text_mut}; +// use imageproc::rect::Rect; +use serde_json; +// use rusttype::{Font, Scale}; +use wasm_bindgen::prelude::*; + +#[wasm_bindgen] +pub struct MultiMosseTrackerJS { + tracker: MultiMosseTracker, +} + +#[wasm_bindgen] +impl MultiMosseTrackerJS { + #[wasm_bindgen(constructor)] + pub fn new(width: u32, height: u32) -> Self { + let window_size = 48; + let psr_threshold = 7.0; + let settings = MosseTrackerSettings { + window_size, + width, + height, + regularization: 0.002, + learning_rate: 0.05, + psr_threshold, + }; + let desperation_threshold = 4; + let multi_tracker = MultiMosseTracker::new(settings, desperation_threshold); + Self { + tracker: multi_tracker, + } + } + + // Updated set_target to accept an ID parameter and use it to add or replace a target. + #[wasm_bindgen] + pub fn set_target(&mut self, id: Identifier, x: u32, y: u32, img_data: &[u8]) { + let img = image::load_from_memory_with_format(img_data, image::ImageFormat::Png).unwrap(); + self.tracker.add_or_replace_target(id, (x, y), &img.to_luma8()); + } + + // Updated track to return a string representation of the predictions instead of an image. + #[wasm_bindgen] + pub fn track(&mut self, img_data: &[u8]) -> JsValue { + let image = image::load_from_memory_with_format(img_data, image::ImageFormat::Png).unwrap(); + let predictions = self.tracker.track(&image.to_luma8()); + + // Serialize the predictions into a JSON string. + let predictions_json = serde_json::to_string(&predictions).unwrap(); + + // Convert the JSON string into a JavaScript value. + JsValue::from_str(&predictions_json) + } + + // #[wasm_bindgen] + // pub fn track(&mut self, img_data: &[u8]) -> Vec { + // let image = image::load_from_memory_with_format(img_data, image::ImageFormat::Png).unwrap(); + // let predictions = self.tracker.track(&image.to_luma8()); + // let mut img_copy = image.to_rgba8(); + // for (obj_id, pred) in predictions.iter() { + // let mut color = Rgba([125u8, 255u8, 0u8, 0u8]); + // if pred.psr < self.tracker.settings.psr_threshold { + // color = Rgba([255u8, 0u8, 0u8, 0u8]) + // } + // draw_cross_mut( + // &mut img_copy, + // Rgba([255u8, 0u8, 0u8, 0u8]), + // pred.location.0 as i32, + // pred.location.1 as i32, + // ); + // let window_size = self.tracker.settings.window_size; + // draw_hollow_rect_mut( + // &mut img_copy, + // Rect::at( + // pred.location.0.saturating_sub(window_size / 2) as i32, + // pred.location.1.saturating_sub(window_size / 2) as i32, + // ) + // .of_size(window_size, window_size), + // color, + // ); + + // let font_data = include_bytes!("../examples/Arial.ttf"); + // let font = Font::try_from_bytes(font_data as &[u8]).unwrap(); + + // const FONT_SCALE: f32 = 10.0; + + // draw_text_mut( + // &mut img_copy, + // Rgba([125u8, 255u8, 0u8, 0u8]), + // (pred.location.0 - (window_size / 2)).try_into().unwrap(), + // (pred.location.1 - (window_size / 2)).try_into().unwrap(), + // Scale::uniform(FONT_SCALE), + // &font, + // &format!("#{}", obj_id), + // ); + + // draw_text_mut( + // &mut img_copy, + // color, + // (pred.location.0 - (window_size / 2)).try_into().unwrap(), + // (pred.location.1 - (window_size / 2) + FONT_SCALE as u32) + // .try_into() + // .unwrap(), + // Scale::uniform(FONT_SCALE), + // &font, + // &format!("PSR: {:.2}", pred.psr), + // ); + // } + + // let mut image_data: Vec = Vec::new(); + // img_copy + // .write_to( + // &mut std::io::Cursor::new(&mut image_data), + // image::ImageFormat::Png, + // ) + // .unwrap(); + // image_data + // } +}