@@ -69,16 +69,20 @@ def __init__(self, *args, **kwargs):
69
69
super ().__init__ (* args , ** kwargs )
70
70
self ._index_network = None
71
71
72
+ @property
73
+ def index_network (self ):
74
+ if self ._index_network is None :
75
+ self ._index_network = build_index (self .track )
76
+ return self ._index_network
77
+
72
78
def network_slice (self , id_network ):
73
79
"""
74
80
Return slice for one network
75
81
76
82
:param int id_network: id to identify network
77
83
"""
78
- if self ._index_network is None :
79
- self ._index_network = build_index (self .track )
80
- i = id_network - self ._index_network [2 ]
81
- i_start , i_stop = self ._index_network [0 ][i ], self ._index_network [1 ][i ]
84
+ i = id_network - self .index_network [2 ]
85
+ i_start , i_stop = self .index_network [0 ][i ], self .index_network [1 ][i ]
82
86
return slice (i_start , i_stop )
83
87
84
88
@property
@@ -228,8 +232,9 @@ def only_one_network(self):
228
232
Raise a warning or error?
229
233
if there are more than one network
230
234
"""
231
- # TODO
232
- pass
235
+ _ , i_start , _ = self .index_network
236
+ if len (i_start ) > 1 :
237
+ raise Exception ("Several network" )
233
238
234
239
def position_filter (self , median_half_window , loess_half_window ):
235
240
self .median_filter (median_half_window , "time" , "lon" ).loess_filter (
@@ -568,17 +573,19 @@ def dissociate_network(self):
568
573
# FIXME : Ok if only one network
569
574
self .track [:] = tags [self .segment - 1 ]
570
575
571
- self .obs .sort (order = ("track" , "segment" , "time" ))
576
+ i_sort = self .obs .argsort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
577
+ # Sort directly obs, with hope to save memory
578
+ self .obs .sort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
572
579
self ._index_network = None
573
580
574
- # FIXME
575
581
# n & p must be re-index
576
- # n, p = self.next_obs[mask] , self.previous_obs[mask]
582
+ n , p = self .next_obs , self .previous_obs
577
583
# we add 2 for -1 index return index -1
578
- # translate = -ones(len(self) + 1, dtype="i4")
579
- # translate[:-1][mask] = arange(nb_obs)
580
- # new.next_obs[:] = translate[n]
581
- # new.previous_obs[:] = translate[p]
584
+ nb_obs = len (self )
585
+ translate = - ones (nb_obs + 1 , dtype = "i4" )
586
+ translate [:- 1 ][i_sort ] = arange (nb_obs )
587
+ self .next_obs [:] = translate [n ]
588
+ self .previous_obs [:] = translate [p ]
582
589
583
590
def network (self , id_network ):
584
591
return self .extract_with_mask (self .network_slice (id_network ))
@@ -640,40 +647,27 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
640
647
j += 1
641
648
return mappables
642
649
643
- def remove_dead_branch (self , nobs = 3 ):
644
- """"""
645
- # TODO: bug when spliting
650
+ def remove_dead_end (self , nobs = 3 , recursive = 0 , mask = None ):
651
+ """
652
+ .. warning::
653
+ It will remove short segment which splits than merges with same segment
654
+ """
646
655
self .only_one_network ()
647
-
648
656
segments_keep = list ()
649
- interaction_segments = dict ()
650
- segments_connexion = dict ()
657
+ connexions = self .connexions ()
651
658
for i , b0 , b1 in self .iter_on ("segment" ):
652
659
nb = i .stop - i .start
653
- i_p , i_n = self .previous_obs [i .start ], self .next_obs [i .stop - 1 ]
654
- seg = self .segment [i .start ]
655
- # segment of interaction
656
- p_seg , n_seg = self .segment [i_p ], self .segment [i_n ]
657
- if nb >= nobs :
658
- segments_keep .append (seg )
659
- else :
660
- interaction_segments [seg ] = (
661
- p_seg if i_p != - 1 else - 1 ,
662
- n_seg if i_n != - 1 else - 1 ,
663
- )
664
- # Where segment are called
665
- if i_p != - 1 :
666
- if p_seg not in segments_connexion :
667
- segments_connexion [p_seg ] = list ()
668
- segments_connexion [p_seg ].append (seg )
669
- if i_n != - 1 :
670
- if n_seg not in segments_connexion :
671
- segments_connexion [n_seg ] = list ()
672
- segments_connexion [n_seg ].append (seg )
673
- print (interaction_segments )
674
- print (segments_connexion )
675
- print (segments_keep )
676
- return self .extract_segment (tuple (segments_keep ))
660
+ if mask and mask [i ].any ():
661
+ segments_keep .append (b0 )
662
+ continue
663
+ if nb < nobs and len (connexions .get (b0 , tuple ())) < 2 :
664
+ continue
665
+ segments_keep .append (b0 )
666
+ if recursive > 0 :
667
+ return self .extract_segment (segments_keep ).remove_dead_end (
668
+ nobs , recursive - 1
669
+ )
670
+ return self .extract_segment (segments_keep )
677
671
678
672
def extract_segment (self , segments ):
679
673
mask = ones (self .shape , dtype = "bool" )
0 commit comments