6
6
from glob import glob
7
7
8
8
from numba import njit
9
- from numpy import arange , array , bincount , empty , ones , uint32 , unique , zeros
9
+ from numpy import arange , array , bincount , empty , in1d , ones , uint32 , unique , zeros
10
10
11
11
from ..generic import build_index , wrap_longitude
12
12
from ..poly import bbox_intersection , vertice_overlap
@@ -119,13 +119,12 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
119
119
if nb_day_max < 0 :
120
120
nb_day_max = 1000000000000
121
121
mask = zeros (self .shape , dtype = "bool" )
122
- for i , b0 , b1 in self .iter_on (self .segment_track_array ):
122
+ t = self .time
123
+ for i , b0 , b1 in self .iter_on (self .track ):
123
124
nb = i .stop - i .start
124
125
if nb == 0 :
125
126
continue
126
- t = self .time [i ]
127
- dt = t .max () - t .min ()
128
- if nb_day_min <= dt <= nb_day_max :
127
+ if nb_day_min <= ptp (t [i ]) <= nb_day_max :
129
128
mask [i ] = True
130
129
return self .extract_with_mask (mask )
131
130
@@ -164,21 +163,26 @@ def obs_relative_order(self, i_obs):
164
163
self .only_one_network ()
165
164
return self .segment_relative_order (self .segment [i_obs ])
166
165
167
- def connexions (self ):
168
- self .only_one_network ()
166
+ def connexions (self , multi_network = False ):
167
+ if multi_network :
168
+ segment = self .segment_track_array
169
+ else :
170
+ self .only_one_network ()
171
+ segment = self .segment
169
172
segments_connexion = dict ()
170
173
171
174
def add_seg (father , child ):
172
175
if father not in segments_connexion :
173
176
segments_connexion [father ] = list ()
174
177
segments_connexion [father ].append (child )
175
178
176
- for i , seg , _ in self .iter_on ("segment" ):
179
+ previous_obs , next_obs = self .previous_obs , self .next_obs
180
+ for i , seg , _ in self .iter_on (segment ):
177
181
if i .start == i .stop :
178
182
continue
179
- i_p , i_n = self . previous_obs [i .start ], self . next_obs [i .stop - 1 ]
183
+ i_p , i_n = previous_obs [i .start ], next_obs [i .stop - 1 ]
180
184
# segment of interaction
181
- p_seg , n_seg = self . segment [i_p ], self . segment [i_n ]
185
+ p_seg , n_seg = segment [i_p ], segment [i_n ]
182
186
# Where segment are called
183
187
if i_p != - 1 :
184
188
add_seg (p_seg , seg )
@@ -395,6 +399,26 @@ def map_segment(self, method, y, same=True, **kw):
395
399
out = array (out )
396
400
return out
397
401
402
+ def map_network (self , method , y , same = True , ** kw ):
403
+ if same :
404
+ out = empty (y .shape , ** kw )
405
+ else :
406
+ out = list ()
407
+ for i , b0 , b1 in self .iter_on (self .track ):
408
+ res = method (y [i ])
409
+ if same :
410
+ out [i ] = res
411
+ else :
412
+ if isinstance (i , slice ):
413
+ if i .start == i .stop :
414
+ continue
415
+ elif len (i ) == 0 :
416
+ continue
417
+ out .append (res )
418
+ if not same :
419
+ out = array (out )
420
+ return out
421
+
398
422
def scatter_timeline (
399
423
self ,
400
424
ax ,
@@ -410,7 +434,7 @@ def scatter_timeline(
410
434
Must be call on only one network
411
435
"""
412
436
self .only_one_network ()
413
- y = (self .segment if yfield is None else self [ yfield ] ) * yfactor
437
+ y = (self .segment if yfield is None else self . parse_varname ( yfield ) ) * yfactor
414
438
if method == "all" :
415
439
pass
416
440
else :
@@ -536,23 +560,25 @@ def segment_track_array(self):
536
560
def birth_event (self ):
537
561
# FIXME how to manage group 0
538
562
indices = list ()
563
+ previous_obs = self .previous_obs
539
564
for i , _ , _ in self .iter_on (self .segment_track_array ):
540
565
nb = i .stop - i .start
541
566
if nb == 0 :
542
567
continue
543
- i_p = self . previous_obs [i .start ]
568
+ i_p = previous_obs [i .start ]
544
569
if i_p == - 1 :
545
570
indices .append (i .start )
546
571
return self .extract_event (list (set (indices )))
547
572
548
573
def death_event (self ):
549
574
# FIXME how to manage group 0
550
575
indices = list ()
576
+ next_obs = self .next_obs
551
577
for i , _ , _ in self .iter_on (self .segment_track_array ):
552
578
nb = i .stop - i .start
553
579
if nb == 0 :
554
580
continue
555
- i_n = self . next_obs [i .stop - 1 ]
581
+ i_n = next_obs [i .stop - 1 ]
556
582
if i_n == - 1 :
557
583
indices .append (i .stop - 1 )
558
584
return self .extract_event (list (set (indices )))
@@ -567,16 +593,16 @@ def merging_event(self, triplet=False):
567
593
if triplet :
568
594
idx_m0_stop = list ()
569
595
idx_m0 = list ()
570
-
596
+ next_obs , previous_obs = self . next_obs , self . previous_obs
571
597
for i , _ , _ in self .iter_on (self .segment_track_array ):
572
598
nb = i .stop - i .start
573
599
if nb == 0 :
574
600
continue
575
- i_n = self . next_obs [i .stop - 1 ]
601
+ i_n = next_obs [i .stop - 1 ]
576
602
if i_n != - 1 :
577
603
if triplet :
578
604
idx_m0_stop .append (i .stop - 1 )
579
- idx_m0 .append (self . previous_obs [i_n ])
605
+ idx_m0 .append (previous_obs [i_n ])
580
606
idx_m1 .append (i_n )
581
607
582
608
if triplet :
@@ -598,15 +624,16 @@ def spliting_event(self, triplet=False):
598
624
if triplet :
599
625
idx_s1_start = list ()
600
626
idx_s1 = list ()
627
+ next_obs , previous_obs = self .next_obs , self .previous_obs
601
628
for i , _ , _ in self .iter_on (self .segment_track_array ):
602
629
nb = i .stop - i .start
603
630
if nb == 0 :
604
631
continue
605
- i_p = self . previous_obs [i .start ]
632
+ i_p = previous_obs [i .start ]
606
633
if i_p != - 1 :
607
634
if triplet :
608
635
idx_s1_start .append (i .start )
609
- idx_s1 .append (self . next_obs [i_p ])
636
+ idx_s1 .append (next_obs [i_p ])
610
637
idx_s0 .append (i_p )
611
638
if triplet :
612
639
return (
@@ -700,32 +727,38 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
700
727
j += 1
701
728
return mappables
702
729
703
- def remove_dead_end (self , nobs = 3 , recursive = 0 , mask = None ):
730
+ def remove_dead_end (self , nobs = 3 , ndays = 0 , recursive = 0 , mask = None ):
704
731
"""
705
732
.. warning::
706
733
It will remove short segment which splits than merges with same segment
707
734
"""
708
- self .only_one_network ()
709
735
segments_keep = list ()
710
- connexions = self .connexions ()
711
- for i , b0 , _ in self .iter_on ( "segment" ):
712
- nb = i . stop - i . start
736
+ connexions = self .connexions (multi_network = True )
737
+ t = self .time
738
+ for i , b0 , _ in self . iter_on ( self . segment_track_array ):
713
739
if mask and mask [i ].any ():
714
740
segments_keep .append (b0 )
715
741
continue
716
- if nb < nobs and len (connexions .get (b0 , tuple ())) < 2 :
742
+ nb = i .stop - i .start
743
+ dt = t [i .stop - 1 ] - t [i .start ]
744
+ if (nb < nobs or dt < ndays ) and len (connexions .get (b0 , tuple ())) < 2 :
717
745
continue
718
746
segments_keep .append (b0 )
719
747
if recursive > 0 :
720
- return self .extract_segment (segments_keep ).remove_dead_end (
721
- nobs , recursive - 1
748
+ return self .extract_segment (segments_keep , absolute = True ).remove_dead_end (
749
+ nobs , ndays , recursive - 1
722
750
)
723
- return self .extract_segment (segments_keep )
751
+ return self .extract_segment (segments_keep , absolute = True )
724
752
725
- def extract_segment (self , segments ):
753
+ def extract_segment (self , segments , absolute = False ):
726
754
mask = ones (self .shape , dtype = "bool" )
727
- for i , b0 , b1 in self .iter_on ("segment" ):
728
- if b0 not in segments :
755
+ segments = array (segments )
756
+ values = self .segment_track_array if absolute else "segment"
757
+ keep = ones (values .max () + 1 , dtype = "bool" )
758
+ v = unique (values )
759
+ keep [v ] = in1d (v , segments )
760
+ for i , b0 , b1 in self .iter_on (values ):
761
+ if not keep [b0 ]:
729
762
mask [i ] = False
730
763
return self .extract_with_mask (mask )
731
764
@@ -929,3 +962,8 @@ def new_numbering(segs):
929
962
s0 = segs [i ]
930
963
j += 1
931
964
segs [i ] = j
965
+
966
+
967
+ @njit (cache = True )
968
+ def ptp (values ):
969
+ return values .max () - values .min ()
0 commit comments