Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
add find_segments_relative
  • Loading branch information
ludwigVonKoopa committed Mar 2, 2021
commit 2445f531364a9ab23fe47dad102d99d1fbaf58c5
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Added
- Save EddyAnim in mp4
- Add method to get eddy contour which enclosed obs defined with (x,y) coordinates
- Add **EddyNetworkSubSetter** to subset network which need special tool and operation after subset
- Add functions to find relatives segments

[3.3.0] - 2020-12-03
--------------------
Expand Down
41 changes: 41 additions & 0 deletions examples/16_network/pet_relative.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
==========================
"""

import datetime

from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter

import py_eddy_tracker.gui
from py_eddy_tracker import data
Expand Down Expand Up @@ -186,6 +189,44 @@
ax.set_title(f"Close segments ({close_to_i3.infos()})")
_ = close_to_i3.display_timeline(ax)

# %%
# Keep relatives to an event
# --------------------------
# When you want to investigate one particular event and select only the closest segments
#
# First choose an event in the network
after, before, stopped = n.merging_event(triplet=True, only_index=True)
i_event = 5
# %%
# then see some order of relatives
@FuncFormatter
def formatter(x, pos):
return (datetime.timedelta(x) + datetime.datetime(1950, 1, 1)).strftime("%d/%m/%Y")


max_order = 2
fig, axs = plt.subplots(
max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2))
)

axs[0].set_title(f"full network", weight="bold")
axs[0].xaxis.set_major_formatter(formatter), axs[0].grid()
mappables = n.display_timeline(axs[0], colors_mode="y")
axs[0].legend()

for k in range(0, max_order + 1):

ax = axs[k + 1]
sub_network = n.find_segments_relative(after[i_event], stopped[i_event], order=k)

ax.set_title(f"relatives order={k}", weight="bold")
ax.xaxis.set_major_formatter(formatter), ax.grid()

mappables = sub_network.display_timeline(ax, colors_mode="y")
ax.legend()
_ = ax.set_ylim(axs[0].get_ylim())


# %%
# Display track on map
# --------------------
Expand Down
155 changes: 136 additions & 19 deletions src/py_eddy_tracker/observations/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
from glob import glob

from numba import njit
from numpy import arange, array, bincount, empty, in1d, ones, uint32, unique, zeros
from numpy import (
arange,
array,
bincount,
empty,
in1d,
ones,
uint32,
unique,
where,
zeros,
)

from ..generic import build_index, wrap_longitude
from ..poly import bbox_intersection, vertice_overlap
Expand Down Expand Up @@ -71,6 +82,32 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._index_network = None

def find_segments_relative(self, obs, stopped=None, order=1):
"""
find all relative segments within an event from an order.

:param int obs: indice of event after the event
:param int stopped: indice of event before the event
:param int order: order of relatives accepted

:return: all segments relatives
:rtype: EddiesObservations
"""

# extraction of network where the event is
network_id = self.tracks[obs]
nw = self.network(network_id)

# indice of observation in new subnetwork
i_obs = where(nw.segment == self.segment[obs])[0][0]

if stopped is None:
return nw.relatives(i_obs, order=order)

else:
i_stopped = where(nw.segment == self.segment[stopped])[0][0]
return nw.relatives([i_obs, i_stopped], order=order)

@property
def index_network(self):
if self._index_network is None:
Expand Down Expand Up @@ -229,12 +266,38 @@ def segment_relative_order(self, seg_origine):

def relative(self, i_obs, order=2, direct=True, only_past=False, only_future=False):
"""
Extract the segments at a certain order.
Extract the segments at a certain order from one observation.

:param list obs: indice of observation for relative computation
:param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...

:return: all segments relatives
:rtype: EddiesObservations
"""

d = self.segment_relative_order(self.segment[i_obs])
m = (d <= order) * (d != -1)
return self.extract_with_mask(m)

def relatives(self, obs, order=2, direct=True, only_past=False, only_future=False):
"""
Extract the segments at a certain order from multiple observations.

:param list obs: indices of observation for relatives computation
:param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...

:return: all segments relatives
:rtype: EddiesObservations
"""

mask = zeros(self.segment.shape, dtype=bool)

for i_obs in obs:
d = self.segment_relative_order(self.segment[i_obs])
mask += (d <= order) * (d != -1)

return self.extract_with_mask(mask)

def numbering_segment(self):
"""
New numbering of segment
Expand Down Expand Up @@ -278,7 +341,14 @@ def median_filter(self, half_window, xfield, yfield, inplace=True):
return result

def display_timeline(
self, ax, event=True, field=None, method=None, factor=1, colors_mode="roll", **kwargs
self,
ax,
event=True,
field=None,
method=None,
factor=1,
colors_mode="roll",
**kwargs,
):
"""
Plot a timeline of a network.
Expand Down Expand Up @@ -306,7 +376,13 @@ def display_timeline(

if event:
mappables.update(
self.event_timeline(ax, field=field, method=method, factor=factor, colors_mode=colors_mode)
self.event_timeline(
ax,
field=field,
method=method,
factor=factor,
colors_mode=colors_mode,
)
)
for i, b0, b1 in self.iter_on("segment"):
x = self.time[i]
Expand All @@ -323,7 +399,7 @@ def display_timeline(
if colors_mode == "roll":
_color = self.get_color(j)
elif colors_mode == "y":
_color = self.get_color(b0-1)
_color = self.get_color(b0 - 1)
else:
raise NotImplementedError(f"colors_mode '{colors_mode}' not defined")

Expand Down Expand Up @@ -354,7 +430,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
if colors_mode == "roll":
_color = self.get_color(j)
elif colors_mode == "y":
_color = self.get_color(b0-1)
_color = self.get_color(b0 - 1)
else:
raise NotImplementedError(f"colors_mode '{colors_mode}' not defined")

Expand Down Expand Up @@ -406,13 +482,17 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
kwargs = dict(color="k", zorder=-1, linestyle=" ")
if len(events["spliting"]) > 0:
X, Y = list(zip(*events["spliting"]))
ref = ax.plot(X, Y, marker="*", markersize=12, label="spliting events", **kwargs)[0]
mappables.setdefault("events",[]).append(ref)
ref = ax.plot(
X, Y, marker="*", markersize=12, label="spliting events", **kwargs
)[0]
mappables.setdefault("events", []).append(ref)

if len(events["merging"]) > 0:
X, Y = list(zip(*events["merging"]))
ref = ax.plot(X, Y, marker="H", markersize=10, label="merging events", **kwargs)[0]
mappables.setdefault("events",[]).append(ref)
ref = ax.plot(
X, Y, marker="H", markersize=10, label="merging events", **kwargs
)[0]
mappables.setdefault("events", []).append(ref)

return mappables

Expand Down Expand Up @@ -440,23 +520,49 @@ def map_segment(self, method, y, same=True, **kw):
out = array(out)
return out

def map_network(self, method, y, same=True, **kw):
def map_network(self, method, y, same=True, return_dict=False, **kw):
"""
transform data `y` with method `method` for each track.

:param Callable method: method to apply on each tracks
:param np.array y: data where to apply method
:param bool same: if True, return array same size from y. else, return list with track edited
:param bool return_dict: if None, mean values are used
:param float kw: to multiply field
:return: array or dict of result from method for each network
"""

if same and return_dict:
raise NotImplementedError(
"both condition 'same' and 'return_dict' should no be true"
)

if same:
out = empty(y.shape, **kw)

elif return_dict:
out = dict()

else:
out = list()

for i, b0, b1 in self.iter_on(self.track):
res = method(y[i])
if same:
out[i] = res

elif return_dict:
out[b0] = res

else:
if isinstance(i, slice):
if i.start == i.stop:
continue
elif len(i) == 0:
continue
out.append(res)
if not same:

if not same and not return_dict:
out = array(out)
return out

Expand Down Expand Up @@ -624,7 +730,7 @@ def death_event(self):
indices.append(i.stop - 1)
return self.extract_event(list(set(indices)))

def merging_event(self, triplet=False):
def merging_event(self, triplet=False, only_index=False):
"""Return observation after a merging event.

If `triplet=True` return the eddy after a merging event, the eddy before the merging event,
Expand All @@ -647,13 +753,24 @@ def merging_event(self, triplet=False):
idx_m1.append(i_n)

if triplet:
return (
self.extract_event(list(idx_m1)),
self.extract_event(list(idx_m0)),
self.extract_event(list(idx_m0_stop)),
)
if only_index:
return (
idx_m1,
idx_m0,
idx_m0_stop,
)

else:
return (
self.extract_event(idx_m1),
self.extract_event(idx_m0),
self.extract_event(idx_m0_stop),
)
else:
return self.extract_event(list(set(idx_m1)))
if only_index:
return self.extract_event(set(idx_m1))
else:
return list(set(idx_m1))

def spliting_event(self, triplet=False):
"""Return observation before a splitting event.
Expand Down