From 63e49bfe5c1ff42ed9b152d5b0e06462e9598f59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment?= <49512274+ludwigVonKoopa@users.noreply.github.com> Date: Thu, 25 Feb 2021 13:38:33 +0100 Subject: [PATCH 1/5] update animation example --- examples/16_network/pet_segmentation_anim.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/16_network/pet_segmentation_anim.py b/examples/16_network/pet_segmentation_anim.py index cc0dc23c..81ee99af 100644 --- a/examples/16_network/pet_segmentation_anim.py +++ b/examples/16_network/pet_segmentation_anim.py @@ -37,12 +37,14 @@ def save(self, *args, **kwargs): # %% # Overlaod of class to pick up TRACKS = list() +INDICES = list() class MyTrack(TrackEddiesObservations): @staticmethod def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs): TRACKS.append(ids["track"].copy()) + INDICES.append(i_current) return TrackEddiesObservations.get_next_obs( i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs ) @@ -70,9 +72,16 @@ def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwarg def update(i_frame): tr = TRACKS[i_frame] mappable_tracks.set_array(tr) - s = 80 * ones(tr.shape) + s = 40 * ones(tr.shape) s[tr == 0] = 4 mappable_tracks.set_sizes(s) + + indices_frames = INDICES[i_frame] + mappable_CONTOUR.set_data( + e.contour_lon_e[indices_frames], + e.contour_lat_e[indices_frames], + ) + mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)]) return (mappable_tracks,) @@ -85,6 +94,9 @@ def update(i_frame): mappable_tracks = ax.scatter( e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20 ) +mappable_CONTOUR = ax.plot( + e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0] +)[0] ani = VideoAnimation( fig, update, frames=range(1, len(TRACKS), 4), interval=125, blit=True ) From e08d16a5bb9491a002fd497e1554bddebc79ff67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment?= <49512274+ludwigVonKoopa@users.noreply.github.com> Date: Thu, 25 Feb 2021 13:38:33 +0100 Subject: [PATCH 2/5] dissociate_network for multiple networks --- src/py_eddy_tracker/observations/network.py | 42 ++++++++++++++++----- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index d8c339cf..7d9ef523 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -164,6 +164,9 @@ def obs_relative_order(self, i_obs): return self.segment_relative_order(self.segment[i_obs]) def connexions(self, multi_network=False): + """ + create dictionnary for each segments, gives the segments which interact with + """ if multi_network: segment = self.segment_track_array else: @@ -648,10 +651,12 @@ def dissociate_network(self): """ Dissociate network with no known interaction (spliting/merging) """ - self.only_one_network() - tags = self.tag_segment() - # FIXME : Ok if only one network - self.track[:] = tags[self.segment - 1] + + tags = self.tag_segment(multi_network=True) + if self.track[0] == 0: + tags -= 1 + + self.track[:] = tags[self.segment_track_array] i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort") # Sort directly obs, with hope to save memory @@ -672,23 +677,40 @@ def network(self, id_network): @classmethod def __tag_segment(cls, seg, tag, groups, connexions): + """ + Will set same temporary ID for each connected segment. + + :param int seg: current ID of seg + :param ing tag: temporary ID to set for seg and its connexion + :param array[int] groups: array where tag will be stored + :param dict connexions: gives for one ID of seg all seg connected + """ + # If seg are already used we stop recursivity if groups[seg] != 0: return + # We set tag for this seg groups[seg] = tag - segs = connexions.get(seg + 1, None) + # Get all connexions of this seg + segs = connexions.get(seg, None) if segs is not None: for seg in segs: - cls.__tag_segment(seg - 1, tag, groups, connexions) + # For each connexion we apply same function + cls.__tag_segment(seg, tag, groups, connexions) - def tag_segment(self): - self.only_one_network() - nb = self.segment.max() + def tag_segment(self, multi_network=False): + if multi_network: + nb = self.segment_track_array[-1] + 1 + else: + nb = self.segment.max() + 1 sub_group = zeros(nb, dtype="u4") - c = self.connexions() + c = self.connexions(multi_network=multi_network) j = 1 + # for each available id for i in range(nb): + # Skip if already set if sub_group[i] != 0: continue + # we tag an unset segments and explore all connexions self.__tag_segment(i, j, sub_group, c) j += 1 return sub_group From 2d7517e5fe442c4c1a927ed9a2b1b68d150f1058 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment?= <49512274+ludwigVonKoopa@users.noreply.github.com> Date: Tue, 2 Mar 2021 12:10:25 +0100 Subject: [PATCH 3/5] correction of repr when EddiesObservation is empty --- CHANGELOG.rst | 1 + src/py_eddy_tracker/observations/observation.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0b1a100c..1af11b87 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ Changed Fixed ^^^^^ - Use `safe_load` for yaml load +- repr of EddiesObservation when the collection is empty (time attribute empty array) Added ^^^^^ diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py index 4e671147..b62227bc 100644 --- a/src/py_eddy_tracker/observations/observation.py +++ b/src/py_eddy_tracker/observations/observation.py @@ -2117,13 +2117,16 @@ def interp_grid( @property def period(self): """ - Give the time coverage + Give the time coverage. If collection is empty, return nan,nan :return: first and last date :rtype: (int,int) """ if self.period_ is None: - self.period_ = self.time.min(), self.time.max() + if self.time.size < 1: + self.period_ = nan, nan + else: + self.period_ = self.time.min(), self.time.max() return self.period_ @property From 425c18e73d2a757766d7ef40960563e4e4a2a681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment?= <49512274+ludwigVonKoopa@users.noreply.github.com> Date: Tue, 2 Mar 2021 12:24:20 +0100 Subject: [PATCH 4/5] fix play_timeline and event_timeline to sync colors, and one plot for event --- CHANGELOG.rst | 2 + src/py_eddy_tracker/observations/network.py | 50 ++++++++++++++++--- .../observations/observation.py | 4 ++ 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1af11b87..d355d385 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,8 @@ Fixed ^^^^^ - Use `safe_load` for yaml load - repr of EddiesObservation when the collection is empty (time attribute empty array) +- display_timeline and event_timeline can now use colors according to 'y' values. +- event_timeline now plot all merging event in one plot, instead of one plot per merging. Same for splitting. (avoid bad legend) Added ^^^^^ diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index 33b857ca..1aa05e3a 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -278,7 +278,7 @@ def median_filter(self, half_window, xfield, yfield, inplace=True): return result def display_timeline( - self, ax, event=True, field=None, method=None, factor=1, **kwargs + self, ax, event=True, field=None, method=None, factor=1, colors_mode="roll", **kwargs ): """ Plot a timeline of a network. @@ -289,6 +289,7 @@ def display_timeline( :param str,array field: yaxis values, if None, segments are used :param str method: if None, mean values are used :param float factor: to multiply field + :param str colors_mode: color of lines. "roll" means looping through colors, "y" means color adapt the y values (for matching color plots) :return: plot mappable """ self.only_one_network() @@ -302,9 +303,10 @@ def display_timeline( ) line_kw.update(kwargs) mappables = dict(lines=list()) + if event: mappables.update( - self.event_timeline(ax, field=field, method=method, factor=factor) + self.event_timeline(ax, field=field, method=method, factor=factor, colors_mode=colors_mode) ) for i, b0, b1 in self.iter_on("segment"): x = self.time[i] @@ -317,14 +319,25 @@ def display_timeline( y = self[field][i] * factor else: y = self[field][i].mean() * ones(x.shape) * factor - line = ax.plot(x, y, **line_kw, color=self.COLORS[j % self.NB_COLORS])[0] + + if colors_mode == "roll": + _color = self.get_color(j) + elif colors_mode == "y": + _color = self.get_color(b0-1) + else: + raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") + + line = ax.plot(x, y, **line_kw, color=_color)[0] mappables["lines"].append(line) j += 1 return mappables - def event_timeline(self, ax, field=None, method=None, factor=1): + def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="roll"): + """mark events in plot""" j = 0 + events = dict(spliting=[], merging=[]) + # TODO : fill mappables dict y_seg = dict() if field is not None and method != "all": @@ -337,7 +350,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1): x = self.time[i] if x.shape[0] == 0: continue - event_kw = dict(color=self.COLORS[j % self.NB_COLORS], ls="-", zorder=1) + + if colors_mode == "roll": + _color = self.get_color(j) + elif colors_mode == "y": + _color = self.get_color(b0-1) + else: + raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") + + event_kw = dict(color=_color, ls="-", zorder=1) + i_n, i_p = ( self.next_obs[i.stop - 1], self.previous_obs[i.start], @@ -361,7 +383,8 @@ def event_timeline(self, ax, field=None, method=None, factor=1): ) ) ax.plot((x[-1], self.time[i_n]), (y0, y1), **event_kw)[0] - ax.plot(x[-1], y0, color="k", marker="H", markersize=10, zorder=-1)[0] + events["merging"].append((x[-1], y0)) + if i_p != -1: seg_previous = self.segment[i_p] if field is not None and method == "all": @@ -376,8 +399,21 @@ def event_timeline(self, ax, field=None, method=None, factor=1): ) ) ax.plot((x[0], self.time[i_p]), (y0, y1), **event_kw)[0] - ax.plot(x[0], y0, color="k", marker="*", markersize=12, zorder=-1)[0] + events["spliting"].append((x[0], y0)) + j += 1 + + kwargs = dict(color="k", zorder=-1, linestyle=" ") + if len(events["spliting"]) > 0: + X, Y = list(zip(*events["spliting"])) + ref = ax.plot(X, Y, marker="*", markersize=12, label="spliting events", **kwargs)[0] + mappables.setdefault("events",[]).append(ref) + + if len(events["merging"]) > 0: + X, Y = list(zip(*events["merging"])) + ref = ax.plot(X, Y, marker="H", markersize=10, label="merging events", **kwargs)[0] + mappables.setdefault("events",[]).append(ref) + return mappables def mean_by_segment(self, y, **kw): diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py index b62227bc..8509910a 100644 --- a/src/py_eddy_tracker/observations/observation.py +++ b/src/py_eddy_tracker/observations/observation.py @@ -199,6 +199,10 @@ def __eq__(self, other): return False return array_equal(self.obs, other.obs) + ### colors methods + def get_color(self, i): + return self.COLORS[i % self.NB_COLORS] + @property def sign_legend(self): return "Cyclonic" if self.sign_type != 1 else "Anticyclonic" From 2445f531364a9ab23fe47dad102d99d1fbaf58c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment?= <49512274+ludwigVonKoopa@users.noreply.github.com> Date: Tue, 2 Mar 2021 12:28:41 +0100 Subject: [PATCH 5/5] add find_segments_relative --- CHANGELOG.rst | 1 + examples/16_network/pet_relative.py | 41 ++++++ src/py_eddy_tracker/observations/network.py | 155 +++++++++++++++++--- 3 files changed, 178 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d355d385..dc599b80 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -31,6 +31,7 @@ Added - Save EddyAnim in mp4 - Add method to get eddy contour which enclosed obs defined with (x,y) coordinates - Add **EddyNetworkSubSetter** to subset network which need special tool and operation after subset +- Add functions to find relatives segments [3.3.0] - 2020-12-03 -------------------- diff --git a/examples/16_network/pet_relative.py b/examples/16_network/pet_relative.py index 4a57062a..affdf44b 100644 --- a/examples/16_network/pet_relative.py +++ b/examples/16_network/pet_relative.py @@ -3,7 +3,10 @@ ========================== """ +import datetime + from matplotlib import pyplot as plt +from matplotlib.ticker import FuncFormatter import py_eddy_tracker.gui from py_eddy_tracker import data @@ -186,6 +189,44 @@ ax.set_title(f"Close segments ({close_to_i3.infos()})") _ = close_to_i3.display_timeline(ax) +# %% +# Keep relatives to an event +# -------------------------- +# When you want to investigate one particular event and select only the closest segments +# +# First choose an event in the network +after, before, stopped = n.merging_event(triplet=True, only_index=True) +i_event = 5 +# %% +# then see some order of relatives +@FuncFormatter +def formatter(x, pos): + return (datetime.timedelta(x) + datetime.datetime(1950, 1, 1)).strftime("%d/%m/%Y") + + +max_order = 2 +fig, axs = plt.subplots( + max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2)) +) + +axs[0].set_title(f"full network", weight="bold") +axs[0].xaxis.set_major_formatter(formatter), axs[0].grid() +mappables = n.display_timeline(axs[0], colors_mode="y") +axs[0].legend() + +for k in range(0, max_order + 1): + + ax = axs[k + 1] + sub_network = n.find_segments_relative(after[i_event], stopped[i_event], order=k) + + ax.set_title(f"relatives order={k}", weight="bold") + ax.xaxis.set_major_formatter(formatter), ax.grid() + + mappables = sub_network.display_timeline(ax, colors_mode="y") + ax.legend() + _ = ax.set_ylim(axs[0].get_ylim()) + + # %% # Display track on map # -------------------- diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index 1aa05e3a..bbe1b956 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -6,7 +6,18 @@ from glob import glob from numba import njit -from numpy import arange, array, bincount, empty, in1d, ones, uint32, unique, zeros +from numpy import ( + arange, + array, + bincount, + empty, + in1d, + ones, + uint32, + unique, + where, + zeros, +) from ..generic import build_index, wrap_longitude from ..poly import bbox_intersection, vertice_overlap @@ -71,6 +82,32 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._index_network = None + def find_segments_relative(self, obs, stopped=None, order=1): + """ + find all relative segments within an event from an order. + + :param int obs: indice of event after the event + :param int stopped: indice of event before the event + :param int order: order of relatives accepted + + :return: all segments relatives + :rtype: EddiesObservations + """ + + # extraction of network where the event is + network_id = self.tracks[obs] + nw = self.network(network_id) + + # indice of observation in new subnetwork + i_obs = where(nw.segment == self.segment[obs])[0][0] + + if stopped is None: + return nw.relatives(i_obs, order=order) + + else: + i_stopped = where(nw.segment == self.segment[stopped])[0][0] + return nw.relatives([i_obs, i_stopped], order=order) + @property def index_network(self): if self._index_network is None: @@ -229,12 +266,38 @@ def segment_relative_order(self, seg_origine): def relative(self, i_obs, order=2, direct=True, only_past=False, only_future=False): """ - Extract the segments at a certain order. + Extract the segments at a certain order from one observation. + + :param list obs: indice of observation for relative computation + :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ... + + :return: all segments relatives + :rtype: EddiesObservations """ + d = self.segment_relative_order(self.segment[i_obs]) m = (d <= order) * (d != -1) return self.extract_with_mask(m) + def relatives(self, obs, order=2, direct=True, only_past=False, only_future=False): + """ + Extract the segments at a certain order from multiple observations. + + :param list obs: indices of observation for relatives computation + :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ... + + :return: all segments relatives + :rtype: EddiesObservations + """ + + mask = zeros(self.segment.shape, dtype=bool) + + for i_obs in obs: + d = self.segment_relative_order(self.segment[i_obs]) + mask += (d <= order) * (d != -1) + + return self.extract_with_mask(mask) + def numbering_segment(self): """ New numbering of segment @@ -278,7 +341,14 @@ def median_filter(self, half_window, xfield, yfield, inplace=True): return result def display_timeline( - self, ax, event=True, field=None, method=None, factor=1, colors_mode="roll", **kwargs + self, + ax, + event=True, + field=None, + method=None, + factor=1, + colors_mode="roll", + **kwargs, ): """ Plot a timeline of a network. @@ -306,7 +376,13 @@ def display_timeline( if event: mappables.update( - self.event_timeline(ax, field=field, method=method, factor=factor, colors_mode=colors_mode) + self.event_timeline( + ax, + field=field, + method=method, + factor=factor, + colors_mode=colors_mode, + ) ) for i, b0, b1 in self.iter_on("segment"): x = self.time[i] @@ -323,7 +399,7 @@ def display_timeline( if colors_mode == "roll": _color = self.get_color(j) elif colors_mode == "y": - _color = self.get_color(b0-1) + _color = self.get_color(b0 - 1) else: raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") @@ -354,7 +430,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol if colors_mode == "roll": _color = self.get_color(j) elif colors_mode == "y": - _color = self.get_color(b0-1) + _color = self.get_color(b0 - 1) else: raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") @@ -406,13 +482,17 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol kwargs = dict(color="k", zorder=-1, linestyle=" ") if len(events["spliting"]) > 0: X, Y = list(zip(*events["spliting"])) - ref = ax.plot(X, Y, marker="*", markersize=12, label="spliting events", **kwargs)[0] - mappables.setdefault("events",[]).append(ref) + ref = ax.plot( + X, Y, marker="*", markersize=12, label="spliting events", **kwargs + )[0] + mappables.setdefault("events", []).append(ref) if len(events["merging"]) > 0: X, Y = list(zip(*events["merging"])) - ref = ax.plot(X, Y, marker="H", markersize=10, label="merging events", **kwargs)[0] - mappables.setdefault("events",[]).append(ref) + ref = ax.plot( + X, Y, marker="H", markersize=10, label="merging events", **kwargs + )[0] + mappables.setdefault("events", []).append(ref) return mappables @@ -440,15 +520,40 @@ def map_segment(self, method, y, same=True, **kw): out = array(out) return out - def map_network(self, method, y, same=True, **kw): + def map_network(self, method, y, same=True, return_dict=False, **kw): + """ + transform data `y` with method `method` for each track. + + :param Callable method: method to apply on each tracks + :param np.array y: data where to apply method + :param bool same: if True, return array same size from y. else, return list with track edited + :param bool return_dict: if None, mean values are used + :param float kw: to multiply field + :return: array or dict of result from method for each network + """ + + if same and return_dict: + raise NotImplementedError( + "both condition 'same' and 'return_dict' should no be true" + ) + if same: out = empty(y.shape, **kw) + + elif return_dict: + out = dict() + else: out = list() + for i, b0, b1 in self.iter_on(self.track): res = method(y[i]) if same: out[i] = res + + elif return_dict: + out[b0] = res + else: if isinstance(i, slice): if i.start == i.stop: @@ -456,7 +561,8 @@ def map_network(self, method, y, same=True, **kw): elif len(i) == 0: continue out.append(res) - if not same: + + if not same and not return_dict: out = array(out) return out @@ -624,7 +730,7 @@ def death_event(self): indices.append(i.stop - 1) return self.extract_event(list(set(indices))) - def merging_event(self, triplet=False): + def merging_event(self, triplet=False, only_index=False): """Return observation after a merging event. If `triplet=True` return the eddy after a merging event, the eddy before the merging event, @@ -647,13 +753,24 @@ def merging_event(self, triplet=False): idx_m1.append(i_n) if triplet: - return ( - self.extract_event(list(idx_m1)), - self.extract_event(list(idx_m0)), - self.extract_event(list(idx_m0_stop)), - ) + if only_index: + return ( + idx_m1, + idx_m0, + idx_m0_stop, + ) + + else: + return ( + self.extract_event(idx_m1), + self.extract_event(idx_m0), + self.extract_event(idx_m0_stop), + ) else: - return self.extract_event(list(set(idx_m1))) + if only_index: + return self.extract_event(set(idx_m1)) + else: + return list(set(idx_m1)) def spliting_event(self, triplet=False): """Return observation before a splitting event.