Skip to content

Commit 9721ac7

Browse files
committed
add tracker
1 parent bf91332 commit 9721ac7

File tree

2 files changed

+81
-14
lines changed

2 files changed

+81
-14
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from ..observations.observation import EddiesObservations as Model
2+
from ..dataset.grid import RegularGridDataset
3+
from numpy import ma
4+
import logging
5+
6+
logger = logging.getLogger("pet")
7+
8+
9+
class AreaTracker(Model):
10+
def tracking(self, other):
11+
shape = (self.shape[0], other.shape[0])
12+
i, j, c = self.match(other, intern=False)
13+
cost_mat = ma.empty(shape, dtype="f4")
14+
cost_mat.mask = ma.ones(shape, dtype="bool")
15+
m = c > 0
16+
i, j, c = i[m], j[m], c[m]
17+
cost_mat[i, j] = 1 - c
18+
19+
i_self, i_other = self.solve_function(cost_mat)
20+
i_self, i_other = self.post_process_link(other, i_self, i_other)
21+
logger.debug("%d matched with previous", i_self.shape[0])
22+
return i_self, i_other, cost_mat[i_self, i_other]
23+
24+
def propagate(self, previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model):
25+
virtual = super().propagate(previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model)
26+
nb_dead = len(previous_obs)
27+
nb_virtual_extend = nb_next - nb_dead
28+
for key in model.elements:
29+
if "contour_" not in key:
30+
continue
31+
virtual[key][:nb_dead] = current_obs[key]
32+
if nb_virtual_extend > 0:
33+
virtual[key][nb_dead:] = obs_to_extend[key]
34+
return virtual

src/py_eddy_tracker/observations/tracking.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,24 @@
2727
===========================================================================
2828
2929
"""
30-
from numpy import empty, arange, where, unique, interp, ones, bool_, zeros, array, median
30+
from numpy import (
31+
empty,
32+
arange,
33+
where,
34+
unique,
35+
interp,
36+
ones,
37+
bool_,
38+
zeros,
39+
array,
40+
median,
41+
)
3142
from .. import VAR_DESCR_inv
3243
import logging
3344
from datetime import datetime, timedelta
3445
from .observation import EddiesObservations
3546
from numba import njit
47+
from ..generic import split_line
3648

3749
logger = logging.getLogger("pet")
3850

@@ -224,28 +236,30 @@ def extract_ids(self, tracks):
224236
return self.__extract_with_mask(mask)
225237

226238
def extract_first_obs_in_box(self, res):
227-
data = empty(self.obs.shape, dtype=[('lon', 'f4'), ('lat', 'f4'), ('track', 'i4')])
228-
data['lon'] = self.longitude - self.longitude % res
229-
data['lat'] = self.latitude - self.latitude % res
230-
data['track'] = self.obs["track"]
239+
data = empty(
240+
self.obs.shape, dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")]
241+
)
242+
data["lon"] = self.longitude - self.longitude % res
243+
data["lat"] = self.latitude - self.latitude % res
244+
data["track"] = self.obs["track"]
231245
_, indexs = unique(data, return_index=True)
232-
mask = zeros(self.obs.shape, dtype='bool')
246+
mask = zeros(self.obs.shape, dtype="bool")
233247
mask[indexs] = True
234248
return self.__extract_with_mask(mask)
235249

236250
def extract_in_direction(self, direction, value=0):
237251
nb_obs = self.nb_obs_by_track
238252
i_start = self.index_from_track
239253
i_stop = i_start + nb_obs - 1
240-
if direction in ('S', 'N'):
254+
if direction in ("S", "N"):
241255
d_lat = self.latitude[i_stop] - self.latitude[i_start]
242-
mask = d_lat < 0 if 'S' == direction else d_lat > 0
256+
mask = d_lat < 0 if "S" == direction else d_lat > 0
243257
mask &= abs(d_lat) > value
244258
else:
245-
lon_start , lon_end = self.longitude[i_start], self.longitude[i_stop]
259+
lon_start, lon_end = self.longitude[i_start], self.longitude[i_stop]
246260
lon_end = (lon_end - (lon_start - 180)) % 360 + lon_start - 180
247261
d_lon = lon_end - lon_start
248-
mask = d_lon < 0 if 'W' == direction else d_lon > 0
262+
mask = d_lon < 0 if "W" == direction else d_lon > 0
249263
mask &= abs(d_lon) > value
250264
mask = mask.repeat(nb_obs)
251265
return self.__extract_with_mask(mask)
@@ -280,7 +294,12 @@ def median_filter(self, half_window, xfield, yfield, inplace=True):
280294
self.obs[yfield] = result
281295

282296
def __extract_with_mask(
283-
self, mask, full_path=False, remove_incomplete=False, compress_id=False, reject_virtual=False,
297+
self,
298+
mask,
299+
full_path=False,
300+
remove_incomplete=False,
301+
compress_id=False,
302+
reject_virtual=False,
284303
):
285304
"""
286305
Extract a subset of observations
@@ -302,7 +321,7 @@ def __extract_with_mask(
302321

303322
if full_path:
304323
if reject_virtual:
305-
mask *= ~self.obs['virtual'].astype('bool')
324+
mask *= ~self.obs["virtual"].astype("bool")
306325
tracks = unique(self.tracks[mask])
307326
mask = self.get_mask_from_id(tracks)
308327
elif remove_incomplete:
@@ -333,6 +352,14 @@ def __extract_with_mask(
333352
new.obs["track"] = id_translate[new.obs["track"]]
334353
return new
335354

355+
def plot(self, ax, ref=None, ** kwargs):
356+
if "label" in kwargs:
357+
kwargs["label"] += " (%s eddies)" % (self.nb_obs_by_track != 0).sum()
358+
x, y = split_line(self.longitude, self.latitude, self.tracks)
359+
if ref is not None:
360+
x = (x - ref) % 360 + ref
361+
return ax.plot(x, y, **kwargs)
362+
336363

337364
@njit(cache=True)
338365
def compute_index(tracks, index, number):
@@ -373,7 +400,9 @@ def track_loess_filter(half_window, x, y, track):
373400
if i != 0:
374401
i_previous = i - 1
375402
dx = x[i] - x[i_previous]
376-
while dx < half_window and i_previous != 0 and cur_track == track[i_previous]:
403+
while (
404+
dx < half_window and i_previous != 0 and cur_track == track[i_previous]
405+
):
377406
w = (1 - (dx / half_window) ** 3) ** 3
378407
y_sum += y[i_previous] * w
379408
w_sum += w
@@ -412,7 +441,11 @@ def track_median_filter(half_window, x, y, track):
412441
cur_track = track[i]
413442
while x[i] - x[i_previous] > half_window or cur_track != track[i_previous]:
414443
i_previous += 1
415-
while i_next < nb and x[i_next] - x[i] <= half_window and cur_track == track[i_next]:
444+
while (
445+
i_next < nb
446+
and x[i_next] - x[i] <= half_window
447+
and cur_track == track[i_next]
448+
):
416449
i_next += 1
417450
y_new[i] = median(y[i_previous:i_next])
418451
return y_new

0 commit comments

Comments
 (0)