@@ -117,13 +117,14 @@ def __repr__(self):
117117 m_event , s_event = self .merging_event (only_index = True , triplet = True )[0 ], self .splitting_event (only_index = True , triplet = True )[0 ]
118118 period = (self .period [1 ] - self .period [0 ]) / 365.25
119119 nb_by_network = self .network_size ()
120+ nb_trash = 0 if self .ref_index != 0 else nb_by_network [0 ]
120121 big = 50_000
121122 infos = [
122123 f"Atlas with { self .nb_network } networks ({ self .nb_network / period :0.0f} networks/year),"
123124 f" { self .nb_segment } segments ({ self .nb_segment / period :0.0f} segments/year), { len (self )} observations ({ len (self ) / period :0.0f} observations/year)" ,
124125 f" { m_event .size } merging ({ m_event .size / period :0.0f} merging/year), { s_event .size } splitting ({ s_event .size / period :0.0f} splitting/year)" ,
125126 f" with { (nb_by_network > big ).sum ()} network with more than { big } obs and the biggest have { nb_by_network .max ()} observations ({ nb_by_network [nb_by_network > big ].sum ()} observations cumulate)" ,
126- f" { nb_by_network [ 0 ] } observations in trash"
127+ f" { nb_trash } observations in trash"
127128 ]
128129 return "\n " .join (infos )
129130
@@ -369,26 +370,29 @@ def correct_close_events(self, nb_days_max=20):
369370
370371 # we keep the real segment number
371372 seg_corrected_copy = segment_copy [seg_slice .stop - 1 ]
373+ if i_seg_n == - 1 :
374+ continue
372375
376+ # if segment is split
373377 n_seg = segment [i_seg_n ]
374378
375- # if segment is split
376- if i_seg_n ! = - 1 :
377- seg2_slice , i2_seg_p , i2_seg_n = segments_connexion [ n_seg ]
378- p2_seg = segment [i2_seg_p ]
379-
380- # if it merges on the first in a certain time
381- if (p2_seg == seg_corrected ) and (
382- _time [i_seg_n ] - _time [i2_seg_p ] < nb_days_max
383- ):
384- my_slice = slice (i_seg_n , seg2_slice .stop )
385- # correct the factice segment
386- segment [my_slice ] = seg_corrected
387- # correct the good segment
388- segment_copy [my_slice ] = seg_corrected_copy
389- previous_obs [i_seg_n ] = seg_slice .stop - 1
390-
391- segments_connexion [seg_corrected ][0 ] = my_slice
379+ seg2_slice , i2_seg_p , _ = segments_connexion [ n_seg ]
380+ if i2_seg_p = = - 1 :
381+ continue
382+ p2_seg = segment [i2_seg_p ]
383+
384+ # if it merges on the first in a certain time
385+ if (p2_seg == seg_corrected ) and (
386+ _time [i_seg_n ] - _time [i2_seg_p ] < nb_days_max
387+ ):
388+ my_slice = slice (i_seg_n , seg2_slice .stop )
389+ # correct the factice segment
390+ segment [my_slice ] = seg_corrected
391+ # correct the good segment
392+ segment_copy [my_slice ] = seg_corrected_copy
393+ previous_obs [i_seg_n ] = seg_slice .stop - 1
394+
395+ segments_connexion [seg_corrected ][0 ] = my_slice
392396
393397 return self .sort ()
394398
@@ -789,6 +793,8 @@ def display_timeline(
789793 colors_mode = colors_mode ,
790794 )
791795 )
796+ if field is not None :
797+ field = self .parse_varname (field )
792798 for i , b0 , b1 in self .iter_on ("segment" ):
793799 x = self .time [i ]
794800 if x .shape [0 ] == 0 :
@@ -797,9 +803,9 @@ def display_timeline(
797803 y = b0 * ones (x .shape )
798804 else :
799805 if method == "all" :
800- y = self [ field ] [i ] * factor
806+ y = field [i ] * factor
801807 else :
802- y = self [ field ] [i ].mean () * ones (x .shape ) * factor
808+ y = field [i ].mean () * ones (x .shape ) * factor
803809
804810 if colors_mode == "roll" :
805811 _color = self .get_color (j )
@@ -825,7 +831,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
825831
826832 if field is not None and method != "all" :
827833 for i , b0 , _ in self .iter_on ("segment" ):
828- y = self [ field ] [i ]
834+ y = self . parse_varname ( field ) [i ]
829835 if y .shape [0 ] != 0 :
830836 y_seg [b0 ] = y .mean () * factor
831837 mappables = dict ()
@@ -851,7 +857,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
851857 y0 = b0
852858 else :
853859 if method == "all" :
854- y0 = self [ field ] [i .stop - 1 ] * factor
860+ y0 = self . parse_varname ( field ) [i .stop - 1 ] * factor
855861 else :
856862 y0 = y_seg [b0 ]
857863 if i_n != - 1 :
@@ -860,7 +866,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
860866 seg_next
861867 if field is None
862868 else (
863- self [ field ] [i_n ] * factor
869+ self . parse_varname ( field ) [i_n ] * factor
864870 if method == "all"
865871 else y_seg [seg_next ]
866872 )
@@ -876,7 +882,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
876882 seg_previous
877883 if field is None
878884 else (
879- self [ field ] [i_p ] * factor
885+ self . parse_varname ( field ) [i_p ] * factor
880886 if method == "all"
881887 else y_seg [seg_previous ]
882888 )
@@ -1446,35 +1452,54 @@ def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None):
14461452 .. warning::
14471453 It will remove short segment that splits from then merges with the same segment
14481454 """
1449- segments_keep = list ()
14501455 connexions = self .connexions (multi_network = True )
1451- t = self .time
1452- for i , b0 , _ in self .iter_on (self .segment_track_array ):
1453- if mask and mask [i ].any ():
1454- segments_keep .append (b0 )
1455- continue
1456- nb = i .stop - i .start
1457- dt = t [i .stop - 1 ] - t [i .start ]
1458- if (nb < nobs or dt < ndays ) and len (connexions .get (b0 , tuple ())) < 2 :
1459- continue
1460- segments_keep .append (b0 )
1456+ i0 , i1 , _ = self .index_segment_track
1457+ dt = self .time [i1 - 1 ] - self .time [i0 ] + 1
1458+ nb = i1 - i0
1459+ m = (dt >= ndays ) * (nb >= nobs )
1460+ nb_connexions = array ([len (connexions .get (i , tuple ())) for i in where (~ m )[0 ]])
1461+ m [~ m ] = nb_connexions >= 2
1462+ segments_keep = where (m )[0 ]
1463+ if mask is not None :
1464+ segments_keep = unique (concatenate ((segments_keep , self .segment_track_array [mask ])))
1465+ # get mask for selected obs
1466+ m = ~ self .segment_mask (segments_keep )
1467+ self .track [m ] = 0
1468+ self .segment [m ] = 0
1469+ self .previous_obs [m ] = - 1
1470+ self .previous_cost [m ] = 0
1471+ self .next_obs [m ] = - 1
1472+ self .next_cost [m ] = 0
1473+
1474+ m_previous = m [self .previous_obs ]
1475+ self .previous_obs [m_previous ] = - 1
1476+ self .previous_cost [m_previous ] = 0
1477+ m_next = m [self .next_obs ]
1478+ self .next_obs [m_next ] = - 1
1479+ self .next_cost [m_next ] = 0
1480+
1481+ self .sort ()
14611482 if recursive > 0 :
1462- return self .extract_segment (segments_keep , absolute = True ).remove_dead_end (
1463- nobs , ndays , recursive - 1
1464- )
1465- return self .extract_segment (segments_keep , absolute = True )
1483+ self .remove_dead_end (nobs , ndays , recursive - 1 )
14661484
14671485 def extract_segment (self , segments , absolute = False ):
1468- mask = ones (self .shape , dtype = "bool" )
1469- segments = array (segments )
1470- values = self .segment_track_array if absolute else "segment"
1471- keep = ones (values .max () + 1 , dtype = "bool" )
1472- v = unique (values )
1473- keep [v ] = in1d (v , segments )
1474- for i , b0 , b1 in self .iter_on (values ):
1475- if not keep [b0 ]:
1476- mask [i ] = False
1477- return self .extract_with_mask (mask )
1486+ """Extract given segments
1487+
1488+ :param array,tuple,list segments: list of segment to extract
1489+ :param bool absolute: keep for compatibility, defaults to False
1490+ :return NetworkObservations: Return observations from selected segment
1491+ """
1492+ if not absolute :
1493+ raise Exception ("Not implemented" )
1494+ return self .extract_with_mask (self .segment_mask (segments ))
1495+
1496+ def segment_mask (self , segments ):
1497+ """Get mask from list of segment
1498+
1499+ :param list,array segments: absolute id of segment
1500+ """
1501+ return generate_mask_from_ids (array (segments ), len (self ), * self .index_segment_track )
1502+
14781503
14791504 def get_mask_with_period (self , period ):
14801505 """
0 commit comments