66from glob import glob
77
88from numba import njit
9- from numpy import arange , array , bincount , empty , ones , uint32 , unique , zeros
9+ from numpy import arange , array , bincount , empty , in1d , ones , uint32 , unique , zeros
1010
1111from ..generic import build_index , wrap_longitude
1212from ..poly import bbox_intersection , vertice_overlap
@@ -119,13 +119,12 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
119119 if nb_day_max < 0 :
120120 nb_day_max = 1000000000000
121121 mask = zeros (self .shape , dtype = "bool" )
122- for i , b0 , b1 in self .iter_on (self .segment_track_array ):
122+ t = self .time
123+ for i , b0 , b1 in self .iter_on (self .track ):
123124 nb = i .stop - i .start
124125 if nb == 0 :
125126 continue
126- t = self .time [i ]
127- dt = t .max () - t .min ()
128- if nb_day_min <= dt <= nb_day_max :
127+ if nb_day_min <= ptp (t [i ]) <= nb_day_max :
129128 mask [i ] = True
130129 return self .extract_with_mask (mask )
131130
@@ -164,21 +163,26 @@ def obs_relative_order(self, i_obs):
164163 self .only_one_network ()
165164 return self .segment_relative_order (self .segment [i_obs ])
166165
167- def connexions (self ):
168- self .only_one_network ()
166+ def connexions (self , multi_network = False ):
167+ if multi_network :
168+ segment = self .segment_track_array
169+ else :
170+ self .only_one_network ()
171+ segment = self .segment
169172 segments_connexion = dict ()
170173
171174 def add_seg (father , child ):
172175 if father not in segments_connexion :
173176 segments_connexion [father ] = list ()
174177 segments_connexion [father ].append (child )
175178
176- for i , seg , _ in self .iter_on ("segment" ):
179+ previous_obs , next_obs = self .previous_obs , self .next_obs
180+ for i , seg , _ in self .iter_on (segment ):
177181 if i .start == i .stop :
178182 continue
179- i_p , i_n = self . previous_obs [i .start ], self . next_obs [i .stop - 1 ]
183+ i_p , i_n = previous_obs [i .start ], next_obs [i .stop - 1 ]
180184 # segment of interaction
181- p_seg , n_seg = self . segment [i_p ], self . segment [i_n ]
185+ p_seg , n_seg = segment [i_p ], segment [i_n ]
182186 # Where segment are called
183187 if i_p != - 1 :
184188 add_seg (p_seg , seg )
@@ -395,6 +399,26 @@ def map_segment(self, method, y, same=True, **kw):
395399 out = array (out )
396400 return out
397401
402+ def map_network (self , method , y , same = True , ** kw ):
403+ if same :
404+ out = empty (y .shape , ** kw )
405+ else :
406+ out = list ()
407+ for i , b0 , b1 in self .iter_on (self .track ):
408+ res = method (y [i ])
409+ if same :
410+ out [i ] = res
411+ else :
412+ if isinstance (i , slice ):
413+ if i .start == i .stop :
414+ continue
415+ elif len (i ) == 0 :
416+ continue
417+ out .append (res )
418+ if not same :
419+ out = array (out )
420+ return out
421+
398422 def scatter_timeline (
399423 self ,
400424 ax ,
@@ -410,7 +434,7 @@ def scatter_timeline(
410434 Must be call on only one network
411435 """
412436 self .only_one_network ()
413- y = (self .segment if yfield is None else self [ yfield ] ) * yfactor
437+ y = (self .segment if yfield is None else self . parse_varname ( yfield ) ) * yfactor
414438 if method == "all" :
415439 pass
416440 else :
@@ -536,23 +560,25 @@ def segment_track_array(self):
536560 def birth_event (self ):
537561 # FIXME how to manage group 0
538562 indices = list ()
563+ previous_obs = self .previous_obs
539564 for i , _ , _ in self .iter_on (self .segment_track_array ):
540565 nb = i .stop - i .start
541566 if nb == 0 :
542567 continue
543- i_p = self . previous_obs [i .start ]
568+ i_p = previous_obs [i .start ]
544569 if i_p == - 1 :
545570 indices .append (i .start )
546571 return self .extract_event (list (set (indices )))
547572
548573 def death_event (self ):
549574 # FIXME how to manage group 0
550575 indices = list ()
576+ next_obs = self .next_obs
551577 for i , _ , _ in self .iter_on (self .segment_track_array ):
552578 nb = i .stop - i .start
553579 if nb == 0 :
554580 continue
555- i_n = self . next_obs [i .stop - 1 ]
581+ i_n = next_obs [i .stop - 1 ]
556582 if i_n == - 1 :
557583 indices .append (i .stop - 1 )
558584 return self .extract_event (list (set (indices )))
@@ -567,16 +593,16 @@ def merging_event(self, triplet=False):
567593 if triplet :
568594 idx_m0_stop = list ()
569595 idx_m0 = list ()
570-
596+ next_obs , previous_obs = self . next_obs , self . previous_obs
571597 for i , _ , _ in self .iter_on (self .segment_track_array ):
572598 nb = i .stop - i .start
573599 if nb == 0 :
574600 continue
575- i_n = self . next_obs [i .stop - 1 ]
601+ i_n = next_obs [i .stop - 1 ]
576602 if i_n != - 1 :
577603 if triplet :
578604 idx_m0_stop .append (i .stop - 1 )
579- idx_m0 .append (self . previous_obs [i_n ])
605+ idx_m0 .append (previous_obs [i_n ])
580606 idx_m1 .append (i_n )
581607
582608 if triplet :
@@ -598,15 +624,16 @@ def spliting_event(self, triplet=False):
598624 if triplet :
599625 idx_s1_start = list ()
600626 idx_s1 = list ()
627+ next_obs , previous_obs = self .next_obs , self .previous_obs
601628 for i , _ , _ in self .iter_on (self .segment_track_array ):
602629 nb = i .stop - i .start
603630 if nb == 0 :
604631 continue
605- i_p = self . previous_obs [i .start ]
632+ i_p = previous_obs [i .start ]
606633 if i_p != - 1 :
607634 if triplet :
608635 idx_s1_start .append (i .start )
609- idx_s1 .append (self . next_obs [i_p ])
636+ idx_s1 .append (next_obs [i_p ])
610637 idx_s0 .append (i_p )
611638 if triplet :
612639 return (
@@ -700,32 +727,38 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
700727 j += 1
701728 return mappables
702729
703- def remove_dead_end (self , nobs = 3 , recursive = 0 , mask = None ):
730+ def remove_dead_end (self , nobs = 3 , ndays = 0 , recursive = 0 , mask = None ):
704731 """
705732 .. warning::
706733 It will remove short segment which splits than merges with same segment
707734 """
708- self .only_one_network ()
709735 segments_keep = list ()
710- connexions = self .connexions ()
711- for i , b0 , _ in self .iter_on ( "segment" ):
712- nb = i . stop - i . start
736+ connexions = self .connexions (multi_network = True )
737+ t = self .time
738+ for i , b0 , _ in self . iter_on ( self . segment_track_array ):
713739 if mask and mask [i ].any ():
714740 segments_keep .append (b0 )
715741 continue
716- if nb < nobs and len (connexions .get (b0 , tuple ())) < 2 :
742+ nb = i .stop - i .start
743+ dt = t [i .stop - 1 ] - t [i .start ]
744+ if (nb < nobs or dt < ndays ) and len (connexions .get (b0 , tuple ())) < 2 :
717745 continue
718746 segments_keep .append (b0 )
719747 if recursive > 0 :
720- return self .extract_segment (segments_keep ).remove_dead_end (
721- nobs , recursive - 1
748+ return self .extract_segment (segments_keep , absolute = True ).remove_dead_end (
749+ nobs , ndays , recursive - 1
722750 )
723- return self .extract_segment (segments_keep )
751+ return self .extract_segment (segments_keep , absolute = True )
724752
725- def extract_segment (self , segments ):
753+ def extract_segment (self , segments , absolute = False ):
726754 mask = ones (self .shape , dtype = "bool" )
727- for i , b0 , b1 in self .iter_on ("segment" ):
728- if b0 not in segments :
755+ segments = array (segments )
756+ values = self .segment_track_array if absolute else "segment"
757+ keep = ones (values .max () + 1 , dtype = "bool" )
758+ v = unique (values )
759+ keep [v ] = in1d (v , segments )
760+ for i , b0 , b1 in self .iter_on (values ):
761+ if not keep [b0 ]:
729762 mask [i ] = False
730763 return self .extract_with_mask (mask )
731764
@@ -929,3 +962,8 @@ def new_numbering(segs):
929962 s0 = segs [i ]
930963 j += 1
931964 segs [i ] = j
965+
966+
967+ @njit (cache = True )
968+ def ptp (values ):
969+ return values .max () - values .min ()
0 commit comments