@@ -117,13 +117,14 @@ def __repr__(self):
117
117
m_event , s_event = self .merging_event (only_index = True , triplet = True )[0 ], self .splitting_event (only_index = True , triplet = True )[0 ]
118
118
period = (self .period [1 ] - self .period [0 ]) / 365.25
119
119
nb_by_network = self .network_size ()
120
+ nb_trash = 0 if self .ref_index != 0 else nb_by_network [0 ]
120
121
big = 50_000
121
122
infos = [
122
123
f"Atlas with { self .nb_network } networks ({ self .nb_network / period :0.0f} networks/year),"
123
124
f" { self .nb_segment } segments ({ self .nb_segment / period :0.0f} segments/year), { len (self )} observations ({ len (self ) / period :0.0f} observations/year)" ,
124
125
f" { m_event .size } merging ({ m_event .size / period :0.0f} merging/year), { s_event .size } splitting ({ s_event .size / period :0.0f} splitting/year)" ,
125
126
f" with { (nb_by_network > big ).sum ()} network with more than { big } obs and the biggest have { nb_by_network .max ()} observations ({ nb_by_network [nb_by_network > big ].sum ()} observations cumulate)" ,
126
- f" { nb_by_network [ 0 ] } observations in trash"
127
+ f" { nb_trash } observations in trash"
127
128
]
128
129
return "\n " .join (infos )
129
130
@@ -369,26 +370,29 @@ def correct_close_events(self, nb_days_max=20):
369
370
370
371
# we keep the real segment number
371
372
seg_corrected_copy = segment_copy [seg_slice .stop - 1 ]
373
+ if i_seg_n == - 1 :
374
+ continue
372
375
376
+ # if segment is split
373
377
n_seg = segment [i_seg_n ]
374
378
375
- # if segment is split
376
- if i_seg_n ! = - 1 :
377
- seg2_slice , i2_seg_p , i2_seg_n = segments_connexion [ n_seg ]
378
- p2_seg = segment [i2_seg_p ]
379
-
380
- # if it merges on the first in a certain time
381
- if (p2_seg == seg_corrected ) and (
382
- _time [i_seg_n ] - _time [i2_seg_p ] < nb_days_max
383
- ):
384
- my_slice = slice (i_seg_n , seg2_slice .stop )
385
- # correct the factice segment
386
- segment [my_slice ] = seg_corrected
387
- # correct the good segment
388
- segment_copy [my_slice ] = seg_corrected_copy
389
- previous_obs [i_seg_n ] = seg_slice .stop - 1
390
-
391
- segments_connexion [seg_corrected ][0 ] = my_slice
379
+ seg2_slice , i2_seg_p , _ = segments_connexion [ n_seg ]
380
+ if i2_seg_p = = - 1 :
381
+ continue
382
+ p2_seg = segment [i2_seg_p ]
383
+
384
+ # if it merges on the first in a certain time
385
+ if (p2_seg == seg_corrected ) and (
386
+ _time [i_seg_n ] - _time [i2_seg_p ] < nb_days_max
387
+ ):
388
+ my_slice = slice (i_seg_n , seg2_slice .stop )
389
+ # correct the factice segment
390
+ segment [my_slice ] = seg_corrected
391
+ # correct the good segment
392
+ segment_copy [my_slice ] = seg_corrected_copy
393
+ previous_obs [i_seg_n ] = seg_slice .stop - 1
394
+
395
+ segments_connexion [seg_corrected ][0 ] = my_slice
392
396
393
397
return self .sort ()
394
398
@@ -789,6 +793,8 @@ def display_timeline(
789
793
colors_mode = colors_mode ,
790
794
)
791
795
)
796
+ if field is not None :
797
+ field = self .parse_varname (field )
792
798
for i , b0 , b1 in self .iter_on ("segment" ):
793
799
x = self .time [i ]
794
800
if x .shape [0 ] == 0 :
@@ -797,9 +803,9 @@ def display_timeline(
797
803
y = b0 * ones (x .shape )
798
804
else :
799
805
if method == "all" :
800
- y = self [ field ] [i ] * factor
806
+ y = field [i ] * factor
801
807
else :
802
- y = self [ field ] [i ].mean () * ones (x .shape ) * factor
808
+ y = field [i ].mean () * ones (x .shape ) * factor
803
809
804
810
if colors_mode == "roll" :
805
811
_color = self .get_color (j )
@@ -825,7 +831,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
825
831
826
832
if field is not None and method != "all" :
827
833
for i , b0 , _ in self .iter_on ("segment" ):
828
- y = self [ field ] [i ]
834
+ y = self . parse_varname ( field ) [i ]
829
835
if y .shape [0 ] != 0 :
830
836
y_seg [b0 ] = y .mean () * factor
831
837
mappables = dict ()
@@ -851,7 +857,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
851
857
y0 = b0
852
858
else :
853
859
if method == "all" :
854
- y0 = self [ field ] [i .stop - 1 ] * factor
860
+ y0 = self . parse_varname ( field ) [i .stop - 1 ] * factor
855
861
else :
856
862
y0 = y_seg [b0 ]
857
863
if i_n != - 1 :
@@ -860,7 +866,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
860
866
seg_next
861
867
if field is None
862
868
else (
863
- self [ field ] [i_n ] * factor
869
+ self . parse_varname ( field ) [i_n ] * factor
864
870
if method == "all"
865
871
else y_seg [seg_next ]
866
872
)
@@ -876,7 +882,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
876
882
seg_previous
877
883
if field is None
878
884
else (
879
- self [ field ] [i_p ] * factor
885
+ self . parse_varname ( field ) [i_p ] * factor
880
886
if method == "all"
881
887
else y_seg [seg_previous ]
882
888
)
@@ -1446,35 +1452,54 @@ def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None):
1446
1452
.. warning::
1447
1453
It will remove short segment that splits from then merges with the same segment
1448
1454
"""
1449
- segments_keep = list ()
1450
1455
connexions = self .connexions (multi_network = True )
1451
- t = self .time
1452
- for i , b0 , _ in self .iter_on (self .segment_track_array ):
1453
- if mask and mask [i ].any ():
1454
- segments_keep .append (b0 )
1455
- continue
1456
- nb = i .stop - i .start
1457
- dt = t [i .stop - 1 ] - t [i .start ]
1458
- if (nb < nobs or dt < ndays ) and len (connexions .get (b0 , tuple ())) < 2 :
1459
- continue
1460
- segments_keep .append (b0 )
1456
+ i0 , i1 , _ = self .index_segment_track
1457
+ dt = self .time [i1 - 1 ] - self .time [i0 ] + 1
1458
+ nb = i1 - i0
1459
+ m = (dt >= ndays ) * (nb >= nobs )
1460
+ nb_connexions = array ([len (connexions .get (i , tuple ())) for i in where (~ m )[0 ]])
1461
+ m [~ m ] = nb_connexions >= 2
1462
+ segments_keep = where (m )[0 ]
1463
+ if mask is not None :
1464
+ segments_keep = unique (concatenate ((segments_keep , self .segment_track_array [mask ])))
1465
+ # get mask for selected obs
1466
+ m = ~ self .segment_mask (segments_keep )
1467
+ self .track [m ] = 0
1468
+ self .segment [m ] = 0
1469
+ self .previous_obs [m ] = - 1
1470
+ self .previous_cost [m ] = 0
1471
+ self .next_obs [m ] = - 1
1472
+ self .next_cost [m ] = 0
1473
+
1474
+ m_previous = m [self .previous_obs ]
1475
+ self .previous_obs [m_previous ] = - 1
1476
+ self .previous_cost [m_previous ] = 0
1477
+ m_next = m [self .next_obs ]
1478
+ self .next_obs [m_next ] = - 1
1479
+ self .next_cost [m_next ] = 0
1480
+
1481
+ self .sort ()
1461
1482
if recursive > 0 :
1462
- return self .extract_segment (segments_keep , absolute = True ).remove_dead_end (
1463
- nobs , ndays , recursive - 1
1464
- )
1465
- return self .extract_segment (segments_keep , absolute = True )
1483
+ self .remove_dead_end (nobs , ndays , recursive - 1 )
1466
1484
1467
1485
def extract_segment (self , segments , absolute = False ):
1468
- mask = ones (self .shape , dtype = "bool" )
1469
- segments = array (segments )
1470
- values = self .segment_track_array if absolute else "segment"
1471
- keep = ones (values .max () + 1 , dtype = "bool" )
1472
- v = unique (values )
1473
- keep [v ] = in1d (v , segments )
1474
- for i , b0 , b1 in self .iter_on (values ):
1475
- if not keep [b0 ]:
1476
- mask [i ] = False
1477
- return self .extract_with_mask (mask )
1486
+ """Extract given segments
1487
+
1488
+ :param array,tuple,list segments: list of segment to extract
1489
+ :param bool absolute: keep for compatibility, defaults to False
1490
+ :return NetworkObservations: Return observations from selected segment
1491
+ """
1492
+ if not absolute :
1493
+ raise Exception ("Not implemented" )
1494
+ return self .extract_with_mask (self .segment_mask (segments ))
1495
+
1496
+ def segment_mask (self , segments ):
1497
+ """Get mask from list of segment
1498
+
1499
+ :param list,array segments: absolute id of segment
1500
+ """
1501
+ return generate_mask_from_ids (array (segments ), len (self ), * self .index_segment_track )
1502
+
1478
1503
1479
1504
def get_mask_with_period (self , period ):
1480
1505
"""
0 commit comments