6
6
from glob import glob
7
7
8
8
from numba import njit
9
- from numpy import arange , array , bincount , empty , in1d , ones , uint32 , unique , zeros
9
+ from numpy import (
10
+ arange ,
11
+ array ,
12
+ bincount ,
13
+ empty ,
14
+ in1d ,
15
+ ones ,
16
+ uint32 ,
17
+ unique ,
18
+ where ,
19
+ zeros ,
20
+ )
10
21
11
22
from ..generic import build_index , wrap_longitude
12
23
from ..poly import bbox_intersection , vertice_overlap
@@ -71,6 +82,32 @@ def __init__(self, *args, **kwargs):
71
82
super ().__init__ (* args , ** kwargs )
72
83
self ._index_network = None
73
84
85
+ def find_segments_relative (self , obs , stopped = None , order = 1 ):
86
+ """
87
+ find all relative segments within an event from an order.
88
+
89
+ :param int obs: indice of event after the event
90
+ :param int stopped: indice of event before the event
91
+ :param int order: order of relatives accepted
92
+
93
+ :return: all segments relatives
94
+ :rtype: EddiesObservations
95
+ """
96
+
97
+ # extraction of network where the event is
98
+ network_id = self .tracks [obs ]
99
+ nw = self .network (network_id )
100
+
101
+ # indice of observation in new subnetwork
102
+ i_obs = where (nw .segment == self .segment [obs ])[0 ][0 ]
103
+
104
+ if stopped is None :
105
+ return nw .relatives (i_obs , order = order )
106
+
107
+ else :
108
+ i_stopped = where (nw .segment == self .segment [stopped ])[0 ][0 ]
109
+ return nw .relatives ([i_obs , i_stopped ], order = order )
110
+
74
111
@property
75
112
def index_network (self ):
76
113
if self ._index_network is None :
@@ -229,12 +266,38 @@ def segment_relative_order(self, seg_origine):
229
266
230
267
def relative (self , i_obs , order = 2 , direct = True , only_past = False , only_future = False ):
231
268
"""
232
- Extract the segments at a certain order.
269
+ Extract the segments at a certain order from one observation.
270
+
271
+ :param list obs: indice of observation for relative computation
272
+ :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...
273
+
274
+ :return: all segments relatives
275
+ :rtype: EddiesObservations
233
276
"""
277
+
234
278
d = self .segment_relative_order (self .segment [i_obs ])
235
279
m = (d <= order ) * (d != - 1 )
236
280
return self .extract_with_mask (m )
237
281
282
+ def relatives (self , obs , order = 2 , direct = True , only_past = False , only_future = False ):
283
+ """
284
+ Extract the segments at a certain order from multiple observations.
285
+
286
+ :param list obs: indices of observation for relatives computation
287
+ :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ...
288
+
289
+ :return: all segments relatives
290
+ :rtype: EddiesObservations
291
+ """
292
+
293
+ mask = zeros (self .segment .shape , dtype = bool )
294
+
295
+ for i_obs in obs :
296
+ d = self .segment_relative_order (self .segment [i_obs ])
297
+ mask += (d <= order ) * (d != - 1 )
298
+
299
+ return self .extract_with_mask (mask )
300
+
238
301
def numbering_segment (self ):
239
302
"""
240
303
New numbering of segment
@@ -278,7 +341,14 @@ def median_filter(self, half_window, xfield, yfield, inplace=True):
278
341
return result
279
342
280
343
def display_timeline (
281
- self , ax , event = True , field = None , method = None , factor = 1 , ** kwargs
344
+ self ,
345
+ ax ,
346
+ event = True ,
347
+ field = None ,
348
+ method = None ,
349
+ factor = 1 ,
350
+ colors_mode = "roll" ,
351
+ ** kwargs ,
282
352
):
283
353
"""
284
354
Plot a timeline of a network.
@@ -289,6 +359,7 @@ def display_timeline(
289
359
:param str,array field: yaxis values, if None, segments are used
290
360
:param str method: if None, mean values are used
291
361
:param float factor: to multiply field
362
+ :param str colors_mode: color of lines. "roll" means looping through colors, "y" means color adapt the y values (for matching color plots)
292
363
:return: plot mappable
293
364
"""
294
365
self .only_one_network ()
@@ -302,9 +373,16 @@ def display_timeline(
302
373
)
303
374
line_kw .update (kwargs )
304
375
mappables = dict (lines = list ())
376
+
305
377
if event :
306
378
mappables .update (
307
- self .event_timeline (ax , field = field , method = method , factor = factor )
379
+ self .event_timeline (
380
+ ax ,
381
+ field = field ,
382
+ method = method ,
383
+ factor = factor ,
384
+ colors_mode = colors_mode ,
385
+ )
308
386
)
309
387
for i , b0 , b1 in self .iter_on ("segment" ):
310
388
x = self .time [i ]
@@ -317,14 +395,25 @@ def display_timeline(
317
395
y = self [field ][i ] * factor
318
396
else :
319
397
y = self [field ][i ].mean () * ones (x .shape ) * factor
320
- line = ax .plot (x , y , ** line_kw , color = self .COLORS [j % self .NB_COLORS ])[0 ]
398
+
399
+ if colors_mode == "roll" :
400
+ _color = self .get_color (j )
401
+ elif colors_mode == "y" :
402
+ _color = self .get_color (b0 - 1 )
403
+ else :
404
+ raise NotImplementedError (f"colors_mode '{ colors_mode } ' not defined" )
405
+
406
+ line = ax .plot (x , y , ** line_kw , color = _color )[0 ]
321
407
mappables ["lines" ].append (line )
322
408
j += 1
323
409
324
410
return mappables
325
411
326
- def event_timeline (self , ax , field = None , method = None , factor = 1 ):
412
+ def event_timeline (self , ax , field = None , method = None , factor = 1 , colors_mode = "roll" ):
413
+ """mark events in plot"""
327
414
j = 0
415
+ events = dict (spliting = [], merging = [])
416
+
328
417
# TODO : fill mappables dict
329
418
y_seg = dict ()
330
419
if field is not None and method != "all" :
@@ -337,7 +426,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
337
426
x = self .time [i ]
338
427
if x .shape [0 ] == 0 :
339
428
continue
340
- event_kw = dict (color = self .COLORS [j % self .NB_COLORS ], ls = "-" , zorder = 1 )
429
+
430
+ if colors_mode == "roll" :
431
+ _color = self .get_color (j )
432
+ elif colors_mode == "y" :
433
+ _color = self .get_color (b0 - 1 )
434
+ else :
435
+ raise NotImplementedError (f"colors_mode '{ colors_mode } ' not defined" )
436
+
437
+ event_kw = dict (color = _color , ls = "-" , zorder = 1 )
438
+
341
439
i_n , i_p = (
342
440
self .next_obs [i .stop - 1 ],
343
441
self .previous_obs [i .start ],
@@ -361,7 +459,8 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
361
459
)
362
460
)
363
461
ax .plot ((x [- 1 ], self .time [i_n ]), (y0 , y1 ), ** event_kw )[0 ]
364
- ax .plot (x [- 1 ], y0 , color = "k" , marker = "H" , markersize = 10 , zorder = - 1 )[0 ]
462
+ events ["merging" ].append ((x [- 1 ], y0 ))
463
+
365
464
if i_p != - 1 :
366
465
seg_previous = self .segment [i_p ]
367
466
if field is not None and method == "all" :
@@ -376,8 +475,25 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
376
475
)
377
476
)
378
477
ax .plot ((x [0 ], self .time [i_p ]), (y0 , y1 ), ** event_kw )[0 ]
379
- ax .plot (x [0 ], y0 , color = "k" , marker = "*" , markersize = 12 , zorder = - 1 )[0 ]
478
+ events ["spliting" ].append ((x [0 ], y0 ))
479
+
380
480
j += 1
481
+
482
+ kwargs = dict (color = "k" , zorder = - 1 , linestyle = " " )
483
+ if len (events ["spliting" ]) > 0 :
484
+ X , Y = list (zip (* events ["spliting" ]))
485
+ ref = ax .plot (
486
+ X , Y , marker = "*" , markersize = 12 , label = "spliting events" , ** kwargs
487
+ )[0 ]
488
+ mappables .setdefault ("events" , []).append (ref )
489
+
490
+ if len (events ["merging" ]) > 0 :
491
+ X , Y = list (zip (* events ["merging" ]))
492
+ ref = ax .plot (
493
+ X , Y , marker = "H" , markersize = 10 , label = "merging events" , ** kwargs
494
+ )[0 ]
495
+ mappables .setdefault ("events" , []).append (ref )
496
+
381
497
return mappables
382
498
383
499
def mean_by_segment (self , y , ** kw ):
@@ -404,23 +520,49 @@ def map_segment(self, method, y, same=True, **kw):
404
520
out = array (out )
405
521
return out
406
522
407
- def map_network (self , method , y , same = True , ** kw ):
523
+ def map_network (self , method , y , same = True , return_dict = False , ** kw ):
524
+ """
525
+ transform data `y` with method `method` for each track.
526
+
527
+ :param Callable method: method to apply on each tracks
528
+ :param np.array y: data where to apply method
529
+ :param bool same: if True, return array same size from y. else, return list with track edited
530
+ :param bool return_dict: if None, mean values are used
531
+ :param float kw: to multiply field
532
+ :return: array or dict of result from method for each network
533
+ """
534
+
535
+ if same and return_dict :
536
+ raise NotImplementedError (
537
+ "both condition 'same' and 'return_dict' should no be true"
538
+ )
539
+
408
540
if same :
409
541
out = empty (y .shape , ** kw )
542
+
543
+ elif return_dict :
544
+ out = dict ()
545
+
410
546
else :
411
547
out = list ()
548
+
412
549
for i , b0 , b1 in self .iter_on (self .track ):
413
550
res = method (y [i ])
414
551
if same :
415
552
out [i ] = res
553
+
554
+ elif return_dict :
555
+ out [b0 ] = res
556
+
416
557
else :
417
558
if isinstance (i , slice ):
418
559
if i .start == i .stop :
419
560
continue
420
561
elif len (i ) == 0 :
421
562
continue
422
563
out .append (res )
423
- if not same :
564
+
565
+ if not same and not return_dict :
424
566
out = array (out )
425
567
return out
426
568
@@ -588,7 +730,7 @@ def death_event(self):
588
730
indices .append (i .stop - 1 )
589
731
return self .extract_event (list (set (indices )))
590
732
591
- def merging_event (self , triplet = False ):
733
+ def merging_event (self , triplet = False , only_index = False ):
592
734
"""Return observation after a merging event.
593
735
594
736
If `triplet=True` return the eddy after a merging event, the eddy before the merging event,
@@ -611,13 +753,24 @@ def merging_event(self, triplet=False):
611
753
idx_m1 .append (i_n )
612
754
613
755
if triplet :
614
- return (
615
- self .extract_event (list (idx_m1 )),
616
- self .extract_event (list (idx_m0 )),
617
- self .extract_event (list (idx_m0_stop )),
618
- )
756
+ if only_index :
757
+ return (
758
+ idx_m1 ,
759
+ idx_m0 ,
760
+ idx_m0_stop ,
761
+ )
762
+
763
+ else :
764
+ return (
765
+ self .extract_event (idx_m1 ),
766
+ self .extract_event (idx_m0 ),
767
+ self .extract_event (idx_m0_stop ),
768
+ )
619
769
else :
620
- return self .extract_event (list (set (idx_m1 )))
770
+ if only_index :
771
+ return self .extract_event (set (idx_m1 ))
772
+ else :
773
+ return list (set (idx_m1 ))
621
774
622
775
def spliting_event (self , triplet = False ):
623
776
"""Return observation before a splitting event.
0 commit comments