Skip to content

Commit 9eea120

Browse files
committed
- filter on each segment of network
- factor could be apply on timeline - new numbering of segment - extract network greater than X days
1 parent 196d431 commit 9eea120

File tree

2 files changed

+90
-18
lines changed

2 files changed

+90
-18
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Added
2727
- Color and text field for contour in **EddyAnim** could be choose
2828
- Save EddyAnim in mp4
2929
- Add method to get eddy contour which enclosed obs defined with (x,y) coordinates
30+
- Add **EddyNetworkSubSetter** to subset network which need special tool and operation after subset
3031

3132
[3.3.0] - 2020-12-03
3233
--------------------

src/py_eddy_tracker/observations/network.py

Lines changed: 89 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from glob import glob
77

88
from numba import njit
9-
from numpy import arange, array, bincount, empty, ones, uint32, unique
9+
from numpy import arange, array, bincount, empty, ones, uint32, unique, zeros
1010

1111
from ..generic import build_index, wrap_longitude
1212
from ..poly import bbox_intersection, vertice_overlap
1313
from .observation import EddiesObservations
14-
from .tracking import TrackEddiesObservations, track_median_filter
14+
from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter
1515

1616
logger = logging.getLogger("pet")
1717

@@ -71,6 +71,26 @@ def elements(self):
7171
elements.extend(["track", "segment", "next_obs", "previous_obs"])
7272
return list(set(elements))
7373

74+
def longer_than(self, nb_day_min=-1, nb_day_max=-1):
75+
"""
76+
Select network on time duration
77+
78+
:param int nb_day_min: Minimal number of day which must be covered by one network, if negative -> not used
79+
:param int nb_day_max: Maximal number of day which must be covered by one network, if negative -> not used
80+
"""
81+
if nb_day_max < 0:
82+
nb_day_max = 1000000000000
83+
mask = zeros(self.shape, dtype="bool")
84+
for i, b0, b1 in self.iter_on(self.segment_track_array()):
85+
nb = i.stop - i.start
86+
if nb == 0:
87+
continue
88+
t = self.time[i]
89+
dt = t.max() - t.min()
90+
if nb_day_min <= dt <= nb_day_max:
91+
mask[i] = True
92+
return self.extract_with_mask(mask)
93+
7494
@classmethod
7595
def from_split_network(cls, group_dataset, indexs, **kwargs):
7696
"""
@@ -160,6 +180,13 @@ def relative(self, i_obs, order=2, direct=True, only_past=False, only_future=Fal
160180
m = (d <= order) * (d != -1)
161181
return self.extract_with_mask(m)
162182

183+
def numbering_segment(self):
184+
"""
185+
New numbering of segment
186+
"""
187+
for i, _, _ in self.iter_on("track"):
188+
new_numbering(self.segment[i])
189+
163190
def only_one_network(self):
164191
"""
165192
Raise a warning or error?
@@ -168,17 +195,35 @@ def only_one_network(self):
168195
# TODO
169196
pass
170197

198+
def position_filter(self, median_half_window, loess_half_window):
199+
self.median_filter(median_half_window, "time", "lon").loess_filter(
200+
loess_half_window, "time", "lon"
201+
)
202+
self.median_filter(median_half_window, "time", "lat").loess_filter(
203+
loess_half_window, "time", "lat"
204+
)
205+
206+
def loess_filter(self, half_window, xfield, yfield, inplace=True):
207+
result = track_loess_filter(
208+
half_window, self.obs[xfield], self.obs[yfield], self.segment_track_array()
209+
)
210+
if inplace:
211+
self.obs[yfield] = result
212+
return self
213+
return result
214+
171215
def median_filter(self, half_window, xfield, yfield, inplace=True):
172-
# FIXME: segments is not enough with several network
173216
result = track_median_filter(
174-
half_window, self[xfield], self[yfield], self.segment
217+
half_window, self[xfield], self[yfield], self.segment_track_array()
175218
)
176219
if inplace:
177220
self[yfield][:] = result
178221
return self
179222
return result
180223

181-
def display_timeline(self, ax, event=True, field=None, method=None):
224+
def display_timeline(
225+
self, ax, event=True, field=None, method=None, factor=1, **kwargs
226+
):
182227
"""
183228
Must be call on only one network
184229
"""
@@ -191,9 +236,12 @@ def display_timeline(self, ax, event=True, field=None, method=None):
191236
zorder=1,
192237
lw=3,
193238
)
239+
line_kw.update(kwargs)
194240
mappables = dict(lines=list())
195241
if event:
196-
mappables.update(self.event_timeline(ax, field=field, method=method))
242+
mappables.update(
243+
self.event_timeline(ax, field=field, method=method, factor=factor)
244+
)
197245
for i, b0, b1 in self.iter_on("segment"):
198246
x = self.time[i]
199247
if x.shape[0] == 0:
@@ -202,24 +250,24 @@ def display_timeline(self, ax, event=True, field=None, method=None):
202250
y = b0 * ones(x.shape)
203251
else:
204252
if method == "all":
205-
y = self[field][i]
253+
y = self[field][i] * factor
206254
else:
207-
y = self[field][i].mean() * ones(x.shape)
255+
y = self[field][i].mean() * ones(x.shape) * factor
208256
line = ax.plot(x, y, **line_kw, color=self.COLORS[j % self.NB_COLORS])[0]
209257
mappables["lines"].append(line)
210258
j += 1
211259

212260
return mappables
213261

214-
def event_timeline(self, ax, field=None, method=None):
262+
def event_timeline(self, ax, field=None, method=None, factor=1):
215263
j = 0
216264
# TODO : fill mappables dict
217265
y_seg = dict()
218266
if field is not None and method != "all":
219267
for i, b0, _ in self.iter_on("segment"):
220268
y = self[field][i]
221269
if y.shape[0] != 0:
222-
y_seg[b0] = y.mean()
270+
y_seg[b0] = y.mean() * factor
223271
mappables = dict()
224272
for i, b0, b1 in self.iter_on("segment"):
225273
x = self.time[i]
@@ -234,26 +282,34 @@ def event_timeline(self, ax, field=None, method=None):
234282
y0 = b0
235283
else:
236284
if method == "all":
237-
y0 = self[field][i.stop - 1]
285+
y0 = self[field][i.stop - 1] * factor
238286
else:
239287
y0 = y_seg[b0]
240288
if i_n != -1:
241289
seg_next = self.segment[i_n]
242290
y1 = (
243291
seg_next
244292
if field is None
245-
else (self[field][i_n] if method == "all" else y_seg[seg_next])
293+
else (
294+
self[field][i_n] * factor
295+
if method == "all"
296+
else y_seg[seg_next]
297+
)
246298
)
247299
ax.plot((x[-1], self.time[i_n]), (y0, y1), **event_kw)[0]
248300
ax.plot(x[-1], y0, color="k", marker=">", markersize=10, zorder=-1)[0]
249301
if i_p != -1:
250302
seg_previous = self.segment[i_p]
251303
if field is not None and method == "all":
252-
y0 = self[field][i.start]
304+
y0 = self[field][i.start] * factor
253305
y1 = (
254306
seg_previous
255307
if field is None
256-
else (self[field][i_p] if method == "all" else y_seg[seg_previous])
308+
else (
309+
self[field][i_p] * factor
310+
if method == "all"
311+
else y_seg[seg_previous]
312+
)
257313
)
258314
ax.plot((x[0], self.time[i_p]), (y0, y1), **event_kw)[0]
259315
ax.plot(x[0], y0, color="k", marker="*", markersize=12, zorder=-1)[0]
@@ -300,7 +356,7 @@ def segment_track_array(self):
300356
def birth_event(self):
301357
# FIXME how to manage group 0
302358
indices = list()
303-
for i, b0, b1 in self.iter_on(self.segment_track_array()):
359+
for i, _, _ in self.iter_on(self.segment_track_array()):
304360
nb = i.stop - i.start
305361
if nb == 0:
306362
continue
@@ -312,7 +368,7 @@ def birth_event(self):
312368
def death_event(self):
313369
# FIXME how to manage group 0
314370
indices = list()
315-
for i, b0, b1 in self.iter_on(self.segment_track_array()):
371+
for i, _, _ in self.iter_on(self.segment_track_array()):
316372
nb = i.stop - i.start
317373
if nb == 0:
318374
continue
@@ -323,7 +379,7 @@ def death_event(self):
323379

324380
def merging_event(self):
325381
indices = list()
326-
for i, b0, b1 in self.iter_on(self.segment_track_array()):
382+
for i, _, _ in self.iter_on(self.segment_track_array()):
327383
nb = i.stop - i.start
328384
if nb == 0:
329385
continue
@@ -334,7 +390,7 @@ def merging_event(self):
334390

335391
def spliting_event(self):
336392
indices = list()
337-
for i, b0, b1 in self.iter_on(self.segment_track_array()):
393+
for i, _, _ in self.iter_on(self.segment_track_array()):
338394
nb = i.stop - i.start
339395
if nb == 0:
340396
continue
@@ -425,6 +481,9 @@ def extract_with_mask(self, mask):
425481
if nb_obs == 0:
426482
logger.warning("Empty dataset will be created")
427483
else:
484+
logger.info(
485+
f"{nb_obs} observations will be extract ({nb_obs * 100. / self.shape[0]}%)"
486+
)
428487
for field in self.obs.dtype.descr:
429488
if field in ("next_obs", "previous_obs"):
430489
continue
@@ -592,3 +651,15 @@ def build_unique_array(id1, id2):
592651
new_id[i] = k
593652
id1_previous, id2_previous = id1_, id2_
594653
return new_id
654+
655+
656+
@njit(cache=True)
657+
def new_numbering(segs):
658+
nb = len(segs)
659+
s0 = segs[0]
660+
j = 0
661+
for i in range(nb):
662+
if segs[i] != s0:
663+
s0 = segs[i]
664+
j += 1
665+
segs[i] = j

0 commit comments

Comments
 (0)