66from glob import glob
77
88from numba import njit
9- from numpy import arange , array , bincount , empty , in1d , ones , uint32 , unique , zeros
9+ from numpy import (
10+ arange ,
11+ array ,
12+ bincount ,
13+ empty ,
14+ in1d ,
15+ ones ,
16+ uint32 ,
17+ unique ,
18+ where ,
19+ zeros ,
20+ )
1021
1122from ..generic import build_index , wrap_longitude
1223from ..poly import bbox_intersection , vertice_overlap
@@ -71,6 +82,32 @@ def __init__(self, *args, **kwargs):
7182 super ().__init__ (* args , ** kwargs )
7283 self ._index_network = None
7384
85+ def find_segments_relative (self , obs , stopped = None , order = 1 ):
86+ """
87+ find all relative segments within an event from an order.
88+
89+ :param int obs: indice of event after the event
90+ :param int stopped: indice of event before the event
91+ :param int order: order of relatives accepted
92+
93+ :return: all segments relatives
94+ :rtype: EddiesObservations
95+ """
96+
97+ # extraction of network where the event is
98+ network_id = self .tracks [obs ]
99+ nw = self .network (network_id )
100+
101+ # indice of observation in new subnetwork
102+ i_obs = where (nw .segment == self .segment [obs ])[0 ][0 ]
103+
104+ if stopped is None :
105+ return nw .relatives (i_obs , order = order )
106+
107+ else :
108+ i_stopped = where (nw .segment == self .segment [stopped ])[0 ][0 ]
109+ return nw .relatives ([i_obs , i_stopped ], order = order )
110+
74111 @property
75112 def index_network (self ):
76113 if self ._index_network is None :
@@ -229,12 +266,38 @@ def segment_relative_order(self, seg_origine):
229266
230267 def relative (self , i_obs , order = 2 , direct = True , only_past = False , only_future = False ):
231268 """
232- Extract the segments at a certain order.
269+ Extract the segments at a certain order from one observation.
270+
271+ :param list obs: indice of observation for relative computation
272+ :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...
273+
274+ :return: all segments relatives
275+ :rtype: EddiesObservations
233276 """
277+
234278 d = self .segment_relative_order (self .segment [i_obs ])
235279 m = (d <= order ) * (d != - 1 )
236280 return self .extract_with_mask (m )
237281
282+ def relatives (self , obs , order = 2 , direct = True , only_past = False , only_future = False ):
283+ """
284+ Extract the segments at a certain order from multiple observations.
285+
286+ :param list obs: indices of observation for relatives computation
287+ :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...
288+
289+ :return: all segments relatives
290+ :rtype: EddiesObservations
291+ """
292+
293+ mask = zeros (self .segment .shape , dtype = bool )
294+
295+ for i_obs in obs :
296+ d = self .segment_relative_order (self .segment [i_obs ])
297+ mask += (d <= order ) * (d != - 1 )
298+
299+ return self .extract_with_mask (mask )
300+
238301 def numbering_segment (self ):
239302 """
240303 New numbering of segment
@@ -278,7 +341,14 @@ def median_filter(self, half_window, xfield, yfield, inplace=True):
278341 return result
279342
280343 def display_timeline (
281- self , ax , event = True , field = None , method = None , factor = 1 , ** kwargs
344+ self ,
345+ ax ,
346+ event = True ,
347+ field = None ,
348+ method = None ,
349+ factor = 1 ,
350+ colors_mode = "roll" ,
351+ ** kwargs ,
282352 ):
283353 """
284354 Plot a timeline of a network.
@@ -289,6 +359,7 @@ def display_timeline(
289359 :param str,array field: yaxis values, if None, segments are used
290360 :param str method: if None, mean values are used
291361 :param float factor: to multiply field
362+ :param str colors_mode: color of lines. "roll" means looping through colors, "y" means color adapt the y values (for matching color plots)
292363 :return: plot mappable
293364 """
294365 self .only_one_network ()
@@ -302,9 +373,16 @@ def display_timeline(
302373 )
303374 line_kw .update (kwargs )
304375 mappables = dict (lines = list ())
376+
305377 if event :
306378 mappables .update (
307- self .event_timeline (ax , field = field , method = method , factor = factor )
379+ self .event_timeline (
380+ ax ,
381+ field = field ,
382+ method = method ,
383+ factor = factor ,
384+ colors_mode = colors_mode ,
385+ )
308386 )
309387 for i , b0 , b1 in self .iter_on ("segment" ):
310388 x = self .time [i ]
@@ -317,14 +395,25 @@ def display_timeline(
317395 y = self [field ][i ] * factor
318396 else :
319397 y = self [field ][i ].mean () * ones (x .shape ) * factor
320- line = ax .plot (x , y , ** line_kw , color = self .COLORS [j % self .NB_COLORS ])[0 ]
398+
399+ if colors_mode == "roll" :
400+ _color = self .get_color (j )
401+ elif colors_mode == "y" :
402+ _color = self .get_color (b0 - 1 )
403+ else :
404+ raise NotImplementedError (f"colors_mode '{ colors_mode } ' not defined" )
405+
406+ line = ax .plot (x , y , ** line_kw , color = _color )[0 ]
321407 mappables ["lines" ].append (line )
322408 j += 1
323409
324410 return mappables
325411
326- def event_timeline (self , ax , field = None , method = None , factor = 1 ):
412+ def event_timeline (self , ax , field = None , method = None , factor = 1 , colors_mode = "roll" ):
413+ """mark events in plot"""
327414 j = 0
415+ events = dict (spliting = [], merging = [])
416+
328417 # TODO : fill mappables dict
329418 y_seg = dict ()
330419 if field is not None and method != "all" :
@@ -337,7 +426,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
337426 x = self .time [i ]
338427 if x .shape [0 ] == 0 :
339428 continue
340- event_kw = dict (color = self .COLORS [j % self .NB_COLORS ], ls = "-" , zorder = 1 )
429+
430+ if colors_mode == "roll" :
431+ _color = self .get_color (j )
432+ elif colors_mode == "y" :
433+ _color = self .get_color (b0 - 1 )
434+ else :
435+ raise NotImplementedError (f"colors_mode '{ colors_mode } ' not defined" )
436+
437+ event_kw = dict (color = _color , ls = "-" , zorder = 1 )
438+
341439 i_n , i_p = (
342440 self .next_obs [i .stop - 1 ],
343441 self .previous_obs [i .start ],
@@ -361,7 +459,8 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
361459 )
362460 )
363461 ax .plot ((x [- 1 ], self .time [i_n ]), (y0 , y1 ), ** event_kw )[0 ]
364- ax .plot (x [- 1 ], y0 , color = "k" , marker = "H" , markersize = 10 , zorder = - 1 )[0 ]
462+ events ["merging" ].append ((x [- 1 ], y0 ))
463+
365464 if i_p != - 1 :
366465 seg_previous = self .segment [i_p ]
367466 if field is not None and method == "all" :
@@ -376,8 +475,25 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
376475 )
377476 )
378477 ax .plot ((x [0 ], self .time [i_p ]), (y0 , y1 ), ** event_kw )[0 ]
379- ax .plot (x [0 ], y0 , color = "k" , marker = "*" , markersize = 12 , zorder = - 1 )[0 ]
478+ events ["spliting" ].append ((x [0 ], y0 ))
479+
380480 j += 1
481+
482+ kwargs = dict (color = "k" , zorder = - 1 , linestyle = " " )
483+ if len (events ["spliting" ]) > 0 :
484+ X , Y = list (zip (* events ["spliting" ]))
485+ ref = ax .plot (
486+ X , Y , marker = "*" , markersize = 12 , label = "spliting events" , ** kwargs
487+ )[0 ]
488+ mappables .setdefault ("events" , []).append (ref )
489+
490+ if len (events ["merging" ]) > 0 :
491+ X , Y = list (zip (* events ["merging" ]))
492+ ref = ax .plot (
493+ X , Y , marker = "H" , markersize = 10 , label = "merging events" , ** kwargs
494+ )[0 ]
495+ mappables .setdefault ("events" , []).append (ref )
496+
381497 return mappables
382498
383499 def mean_by_segment (self , y , ** kw ):
@@ -404,23 +520,49 @@ def map_segment(self, method, y, same=True, **kw):
404520 out = array (out )
405521 return out
406522
407- def map_network (self , method , y , same = True , ** kw ):
523+ def map_network (self , method , y , same = True , return_dict = False , ** kw ):
524+ """
525+ transform data `y` with method `method` for each track.
526+
527+ :param Callable method: method to apply on each tracks
528+ :param np.array y: data where to apply method
529+ :param bool same: if True, return array same size from y. else, return list with track edited
530+ :param bool return_dict: if None, mean values are used
531+ :param float kw: to multiply field
532+ :return: array or dict of result from method for each network
533+ """
534+
535+ if same and return_dict :
536+ raise NotImplementedError (
537+ "both condition 'same' and 'return_dict' should no be true"
538+ )
539+
408540 if same :
409541 out = empty (y .shape , ** kw )
542+
543+ elif return_dict :
544+ out = dict ()
545+
410546 else :
411547 out = list ()
548+
412549 for i , b0 , b1 in self .iter_on (self .track ):
413550 res = method (y [i ])
414551 if same :
415552 out [i ] = res
553+
554+ elif return_dict :
555+ out [b0 ] = res
556+
416557 else :
417558 if isinstance (i , slice ):
418559 if i .start == i .stop :
419560 continue
420561 elif len (i ) == 0 :
421562 continue
422563 out .append (res )
423- if not same :
564+
565+ if not same and not return_dict :
424566 out = array (out )
425567 return out
426568
@@ -588,7 +730,7 @@ def death_event(self):
588730 indices .append (i .stop - 1 )
589731 return self .extract_event (list (set (indices )))
590732
591- def merging_event (self , triplet = False ):
733+ def merging_event (self , triplet = False , only_index = False ):
592734 """Return observation after a merging event.
593735
594736 If `triplet=True` return the eddy after a merging event, the eddy before the merging event,
@@ -611,13 +753,24 @@ def merging_event(self, triplet=False):
611753 idx_m1 .append (i_n )
612754
613755 if triplet :
614- return (
615- self .extract_event (list (idx_m1 )),
616- self .extract_event (list (idx_m0 )),
617- self .extract_event (list (idx_m0_stop )),
618- )
756+ if only_index :
757+ return (
758+ idx_m1 ,
759+ idx_m0 ,
760+ idx_m0_stop ,
761+ )
762+
763+ else :
764+ return (
765+ self .extract_event (idx_m1 ),
766+ self .extract_event (idx_m0 ),
767+ self .extract_event (idx_m0_stop ),
768+ )
619769 else :
620- return self .extract_event (list (set (idx_m1 )))
770+ if only_index :
771+ return self .extract_event (set (idx_m1 ))
772+ else :
773+ return list (set (idx_m1 ))
621774
622775 def spliting_event (self , triplet = False ):
623776 """Return observation before a splitting event.
0 commit comments