@@ -69,16 +69,20 @@ def __init__(self, *args, **kwargs):
6969 super ().__init__ (* args , ** kwargs )
7070 self ._index_network = None
7171
72+ @property
73+ def index_network (self ):
74+ if self ._index_network is None :
75+ self ._index_network = build_index (self .track )
76+ return self ._index_network
77+
7278 def network_slice (self , id_network ):
7379 """
7480 Return slice for one network
7581
7682 :param int id_network: id to identify network
7783 """
78- if self ._index_network is None :
79- self ._index_network = build_index (self .track )
80- i = id_network - self ._index_network [2 ]
81- i_start , i_stop = self ._index_network [0 ][i ], self ._index_network [1 ][i ]
84+ i = id_network - self .index_network [2 ]
85+ i_start , i_stop = self .index_network [0 ][i ], self .index_network [1 ][i ]
8286 return slice (i_start , i_stop )
8387
8488 @property
@@ -228,8 +232,9 @@ def only_one_network(self):
228232 Raise a warning or error?
229233 if there are more than one network
230234 """
231- # TODO
232- pass
235+ _ , i_start , _ = self .index_network
236+ if len (i_start ) > 1 :
237+ raise Exception ("Several network" )
233238
234239 def position_filter (self , median_half_window , loess_half_window ):
235240 self .median_filter (median_half_window , "time" , "lon" ).loess_filter (
@@ -568,17 +573,19 @@ def dissociate_network(self):
568573 # FIXME : Ok if only one network
569574 self .track [:] = tags [self .segment - 1 ]
570575
571- self .obs .sort (order = ("track" , "segment" , "time" ))
576+ i_sort = self .obs .argsort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
577+ # Sort directly obs, with hope to save memory
578+ self .obs .sort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
572579 self ._index_network = None
573580
574- # FIXME
575581 # n & p must be re-index
576- # n, p = self.next_obs[mask] , self.previous_obs[mask]
582+ n , p = self .next_obs , self .previous_obs
577583 # we add 2 for -1 index return index -1
578- # translate = -ones(len(self) + 1, dtype="i4")
579- # translate[:-1][mask] = arange(nb_obs)
580- # new.next_obs[:] = translate[n]
581- # new.previous_obs[:] = translate[p]
584+ nb_obs = len (self )
585+ translate = - ones (nb_obs + 1 , dtype = "i4" )
586+ translate [:- 1 ][i_sort ] = arange (nb_obs )
587+ self .next_obs [:] = translate [n ]
588+ self .previous_obs [:] = translate [p ]
582589
583590 def network (self , id_network ):
584591 return self .extract_with_mask (self .network_slice (id_network ))
@@ -640,40 +647,27 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
640647 j += 1
641648 return mappables
642649
643- def remove_dead_branch (self , nobs = 3 ):
644- """"""
645- # TODO: bug when spliting
650+ def remove_dead_end (self , nobs = 3 , recursive = 0 , mask = None ):
651+ """
652+ .. warning::
653+ It will remove short segment which splits than merges with same segment
654+ """
646655 self .only_one_network ()
647-
648656 segments_keep = list ()
649- interaction_segments = dict ()
650- segments_connexion = dict ()
657+ connexions = self .connexions ()
651658 for i , b0 , b1 in self .iter_on ("segment" ):
652659 nb = i .stop - i .start
653- i_p , i_n = self .previous_obs [i .start ], self .next_obs [i .stop - 1 ]
654- seg = self .segment [i .start ]
655- # segment of interaction
656- p_seg , n_seg = self .segment [i_p ], self .segment [i_n ]
657- if nb >= nobs :
658- segments_keep .append (seg )
659- else :
660- interaction_segments [seg ] = (
661- p_seg if i_p != - 1 else - 1 ,
662- n_seg if i_n != - 1 else - 1 ,
663- )
664- # Where segment are called
665- if i_p != - 1 :
666- if p_seg not in segments_connexion :
667- segments_connexion [p_seg ] = list ()
668- segments_connexion [p_seg ].append (seg )
669- if i_n != - 1 :
670- if n_seg not in segments_connexion :
671- segments_connexion [n_seg ] = list ()
672- segments_connexion [n_seg ].append (seg )
673- print (interaction_segments )
674- print (segments_connexion )
675- print (segments_keep )
676- return self .extract_segment (tuple (segments_keep ))
660+ if mask and mask [i ].any ():
661+ segments_keep .append (b0 )
662+ continue
663+ if nb < nobs and len (connexions .get (b0 , tuple ())) < 2 :
664+ continue
665+ segments_keep .append (b0 )
666+ if recursive > 0 :
667+ return self .extract_segment (segments_keep ).remove_dead_end (
668+ nobs , recursive - 1
669+ )
670+ return self .extract_segment (segments_keep )
677671
678672 def extract_segment (self , segments ):
679673 mask = ones (self .shape , dtype = "bool" )
0 commit comments