Skip to content

Commit 3975a22

Browse files
add find_segments_relative (#59)
* correction of repr when EddiesObservation is empty * fix play_timeline and event_timeline to sync colors, and one plot for event * add find_segments_relative
1 parent fa09c06 commit 3975a22

File tree

4 files changed

+225
-20
lines changed

4 files changed

+225
-20
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ Changed
1717
Fixed
1818
^^^^^
1919
- Use `safe_load` for yaml load
20+
- repr of EddiesObservation when the collection is empty (time attribute empty array)
21+
- display_timeline and event_timeline can now use colors according to 'y' values.
22+
- event_timeline now plot all merging event in one plot, instead of one plot per merging. Same for splitting. (avoid bad legend)
2023

2124
Added
2225
^^^^^
@@ -28,6 +31,7 @@ Added
2831
- Save EddyAnim in mp4
2932
- Add method to get eddy contour which enclosed obs defined with (x,y) coordinates
3033
- Add **EddyNetworkSubSetter** to subset network which need special tool and operation after subset
34+
- Add functions to find relatives segments
3135

3236
[3.3.0] - 2020-12-03
3337
--------------------

examples/16_network/pet_relative.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
==========================
44
"""
55

6+
import datetime
7+
68
from matplotlib import pyplot as plt
9+
from matplotlib.ticker import FuncFormatter
710

811
import py_eddy_tracker.gui
912
from py_eddy_tracker import data
@@ -186,6 +189,44 @@
186189
ax.set_title(f"Close segments ({close_to_i3.infos()})")
187190
_ = close_to_i3.display_timeline(ax)
188191

192+
# %%
193+
# Keep relatives to an event
194+
# --------------------------
195+
# When you want to investigate one particular event and select only the closest segments
196+
#
197+
# First choose an event in the network
198+
after, before, stopped = n.merging_event(triplet=True, only_index=True)
199+
i_event = 5
200+
# %%
201+
# then see some order of relatives
202+
@FuncFormatter
203+
def formatter(x, pos):
204+
return (datetime.timedelta(x) + datetime.datetime(1950, 1, 1)).strftime("%d/%m/%Y")
205+
206+
207+
max_order = 2
208+
fig, axs = plt.subplots(
209+
max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2))
210+
)
211+
212+
axs[0].set_title(f"full network", weight="bold")
213+
axs[0].xaxis.set_major_formatter(formatter), axs[0].grid()
214+
mappables = n.display_timeline(axs[0], colors_mode="y")
215+
axs[0].legend()
216+
217+
for k in range(0, max_order + 1):
218+
219+
ax = axs[k + 1]
220+
sub_network = n.find_segments_relative(after[i_event], stopped[i_event], order=k)
221+
222+
ax.set_title(f"relatives order={k}", weight="bold")
223+
ax.xaxis.set_major_formatter(formatter), ax.grid()
224+
225+
mappables = sub_network.display_timeline(ax, colors_mode="y")
226+
ax.legend()
227+
_ = ax.set_ylim(axs[0].get_ylim())
228+
229+
189230
# %%
190231
# Display track on map
191232
# --------------------

src/py_eddy_tracker/observations/network.py

Lines changed: 171 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@
66
from glob import glob
77

88
from numba import njit
9-
from numpy import arange, array, bincount, empty, in1d, ones, uint32, unique, zeros
9+
from numpy import (
10+
arange,
11+
array,
12+
bincount,
13+
empty,
14+
in1d,
15+
ones,
16+
uint32,
17+
unique,
18+
where,
19+
zeros,
20+
)
1021

1122
from ..generic import build_index, wrap_longitude
1223
from ..poly import bbox_intersection, vertice_overlap
@@ -71,6 +82,32 @@ def __init__(self, *args, **kwargs):
7182
super().__init__(*args, **kwargs)
7283
self._index_network = None
7384

85+
def find_segments_relative(self, obs, stopped=None, order=1):
86+
"""
87+
find all relative segments within an event from an order.
88+
89+
:param int obs: indice of event after the event
90+
:param int stopped: indice of event before the event
91+
:param int order: order of relatives accepted
92+
93+
:return: all segments relatives
94+
:rtype: EddiesObservations
95+
"""
96+
97+
# extraction of network where the event is
98+
network_id = self.tracks[obs]
99+
nw = self.network(network_id)
100+
101+
# indice of observation in new subnetwork
102+
i_obs = where(nw.segment == self.segment[obs])[0][0]
103+
104+
if stopped is None:
105+
return nw.relatives(i_obs, order=order)
106+
107+
else:
108+
i_stopped = where(nw.segment == self.segment[stopped])[0][0]
109+
return nw.relatives([i_obs, i_stopped], order=order)
110+
74111
@property
75112
def index_network(self):
76113
if self._index_network is None:
@@ -229,12 +266,38 @@ def segment_relative_order(self, seg_origine):
229266

230267
def relative(self, i_obs, order=2, direct=True, only_past=False, only_future=False):
231268
"""
232-
Extract the segments at a certain order.
269+
Extract the segments at a certain order from one observation.
270+
271+
:param list obs: indice of observation for relative computation
272+
:param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...
273+
274+
:return: all segments relatives
275+
:rtype: EddiesObservations
233276
"""
277+
234278
d = self.segment_relative_order(self.segment[i_obs])
235279
m = (d <= order) * (d != -1)
236280
return self.extract_with_mask(m)
237281

282+
def relatives(self, obs, order=2, direct=True, only_past=False, only_future=False):
283+
"""
284+
Extract the segments at a certain order from multiple observations.
285+
286+
:param list obs: indices of observation for relatives computation
287+
:param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...
288+
289+
:return: all segments relatives
290+
:rtype: EddiesObservations
291+
"""
292+
293+
mask = zeros(self.segment.shape, dtype=bool)
294+
295+
for i_obs in obs:
296+
d = self.segment_relative_order(self.segment[i_obs])
297+
mask += (d <= order) * (d != -1)
298+
299+
return self.extract_with_mask(mask)
300+
238301
def numbering_segment(self):
239302
"""
240303
New numbering of segment
@@ -278,7 +341,14 @@ def median_filter(self, half_window, xfield, yfield, inplace=True):
278341
return result
279342

280343
def display_timeline(
281-
self, ax, event=True, field=None, method=None, factor=1, **kwargs
344+
self,
345+
ax,
346+
event=True,
347+
field=None,
348+
method=None,
349+
factor=1,
350+
colors_mode="roll",
351+
**kwargs,
282352
):
283353
"""
284354
Plot a timeline of a network.
@@ -289,6 +359,7 @@ def display_timeline(
289359
:param str,array field: yaxis values, if None, segments are used
290360
:param str method: if None, mean values are used
291361
:param float factor: to multiply field
362+
:param str colors_mode: color of lines. "roll" means looping through colors, "y" means color adapt the y values (for matching color plots)
292363
:return: plot mappable
293364
"""
294365
self.only_one_network()
@@ -302,9 +373,16 @@ def display_timeline(
302373
)
303374
line_kw.update(kwargs)
304375
mappables = dict(lines=list())
376+
305377
if event:
306378
mappables.update(
307-
self.event_timeline(ax, field=field, method=method, factor=factor)
379+
self.event_timeline(
380+
ax,
381+
field=field,
382+
method=method,
383+
factor=factor,
384+
colors_mode=colors_mode,
385+
)
308386
)
309387
for i, b0, b1 in self.iter_on("segment"):
310388
x = self.time[i]
@@ -317,14 +395,25 @@ def display_timeline(
317395
y = self[field][i] * factor
318396
else:
319397
y = self[field][i].mean() * ones(x.shape) * factor
320-
line = ax.plot(x, y, **line_kw, color=self.COLORS[j % self.NB_COLORS])[0]
398+
399+
if colors_mode == "roll":
400+
_color = self.get_color(j)
401+
elif colors_mode == "y":
402+
_color = self.get_color(b0 - 1)
403+
else:
404+
raise NotImplementedError(f"colors_mode '{colors_mode}' not defined")
405+
406+
line = ax.plot(x, y, **line_kw, color=_color)[0]
321407
mappables["lines"].append(line)
322408
j += 1
323409

324410
return mappables
325411

326-
def event_timeline(self, ax, field=None, method=None, factor=1):
412+
def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="roll"):
413+
"""mark events in plot"""
327414
j = 0
415+
events = dict(spliting=[], merging=[])
416+
328417
# TODO : fill mappables dict
329418
y_seg = dict()
330419
if field is not None and method != "all":
@@ -337,7 +426,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
337426
x = self.time[i]
338427
if x.shape[0] == 0:
339428
continue
340-
event_kw = dict(color=self.COLORS[j % self.NB_COLORS], ls="-", zorder=1)
429+
430+
if colors_mode == "roll":
431+
_color = self.get_color(j)
432+
elif colors_mode == "y":
433+
_color = self.get_color(b0 - 1)
434+
else:
435+
raise NotImplementedError(f"colors_mode '{colors_mode}' not defined")
436+
437+
event_kw = dict(color=_color, ls="-", zorder=1)
438+
341439
i_n, i_p = (
342440
self.next_obs[i.stop - 1],
343441
self.previous_obs[i.start],
@@ -361,7 +459,8 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
361459
)
362460
)
363461
ax.plot((x[-1], self.time[i_n]), (y0, y1), **event_kw)[0]
364-
ax.plot(x[-1], y0, color="k", marker="H", markersize=10, zorder=-1)[0]
462+
events["merging"].append((x[-1], y0))
463+
365464
if i_p != -1:
366465
seg_previous = self.segment[i_p]
367466
if field is not None and method == "all":
@@ -376,8 +475,25 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
376475
)
377476
)
378477
ax.plot((x[0], self.time[i_p]), (y0, y1), **event_kw)[0]
379-
ax.plot(x[0], y0, color="k", marker="*", markersize=12, zorder=-1)[0]
478+
events["spliting"].append((x[0], y0))
479+
380480
j += 1
481+
482+
kwargs = dict(color="k", zorder=-1, linestyle=" ")
483+
if len(events["spliting"]) > 0:
484+
X, Y = list(zip(*events["spliting"]))
485+
ref = ax.plot(
486+
X, Y, marker="*", markersize=12, label="spliting events", **kwargs
487+
)[0]
488+
mappables.setdefault("events", []).append(ref)
489+
490+
if len(events["merging"]) > 0:
491+
X, Y = list(zip(*events["merging"]))
492+
ref = ax.plot(
493+
X, Y, marker="H", markersize=10, label="merging events", **kwargs
494+
)[0]
495+
mappables.setdefault("events", []).append(ref)
496+
381497
return mappables
382498

383499
def mean_by_segment(self, y, **kw):
@@ -404,23 +520,49 @@ def map_segment(self, method, y, same=True, **kw):
404520
out = array(out)
405521
return out
406522

407-
def map_network(self, method, y, same=True, **kw):
523+
def map_network(self, method, y, same=True, return_dict=False, **kw):
524+
"""
525+
transform data `y` with method `method` for each track.
526+
527+
:param Callable method: method to apply on each tracks
528+
:param np.array y: data where to apply method
529+
:param bool same: if True, return array same size from y. else, return list with track edited
530+
:param bool return_dict: if None, mean values are used
531+
:param float kw: to multiply field
532+
:return: array or dict of result from method for each network
533+
"""
534+
535+
if same and return_dict:
536+
raise NotImplementedError(
537+
"both condition 'same' and 'return_dict' should no be true"
538+
)
539+
408540
if same:
409541
out = empty(y.shape, **kw)
542+
543+
elif return_dict:
544+
out = dict()
545+
410546
else:
411547
out = list()
548+
412549
for i, b0, b1 in self.iter_on(self.track):
413550
res = method(y[i])
414551
if same:
415552
out[i] = res
553+
554+
elif return_dict:
555+
out[b0] = res
556+
416557
else:
417558
if isinstance(i, slice):
418559
if i.start == i.stop:
419560
continue
420561
elif len(i) == 0:
421562
continue
422563
out.append(res)
423-
if not same:
564+
565+
if not same and not return_dict:
424566
out = array(out)
425567
return out
426568

@@ -588,7 +730,7 @@ def death_event(self):
588730
indices.append(i.stop - 1)
589731
return self.extract_event(list(set(indices)))
590732

591-
def merging_event(self, triplet=False):
733+
def merging_event(self, triplet=False, only_index=False):
592734
"""Return observation after a merging event.
593735
594736
If `triplet=True` return the eddy after a merging event, the eddy before the merging event,
@@ -611,13 +753,24 @@ def merging_event(self, triplet=False):
611753
idx_m1.append(i_n)
612754

613755
if triplet:
614-
return (
615-
self.extract_event(list(idx_m1)),
616-
self.extract_event(list(idx_m0)),
617-
self.extract_event(list(idx_m0_stop)),
618-
)
756+
if only_index:
757+
return (
758+
idx_m1,
759+
idx_m0,
760+
idx_m0_stop,
761+
)
762+
763+
else:
764+
return (
765+
self.extract_event(idx_m1),
766+
self.extract_event(idx_m0),
767+
self.extract_event(idx_m0_stop),
768+
)
619769
else:
620-
return self.extract_event(list(set(idx_m1)))
770+
if only_index:
771+
return self.extract_event(set(idx_m1))
772+
else:
773+
return list(set(idx_m1))
621774

622775
def spliting_event(self, triplet=False):
623776
"""Return observation before a splitting event.

0 commit comments

Comments
 (0)