@@ -166,6 +166,9 @@ def obs_relative_order(self, i_obs):
166166 return self .segment_relative_order (self .segment [i_obs ])
167167
168168 def connexions (self , multi_network = False ):
169+ """
170+ create dictionnary for each segments, gives the segments which interact with
171+ """
169172 if multi_network :
170173 segment = self .segment_track_array
171174 else :
@@ -650,10 +653,12 @@ def dissociate_network(self):
650653 """
651654 Dissociate network with no known interaction (spliting/merging)
652655 """
653- self .only_one_network ()
654- tags = self .tag_segment ()
655- # FIXME : Ok if only one network
656- self .track [:] = tags [self .segment - 1 ]
656+
657+ tags = self .tag_segment (multi_network = True )
658+ if self .track [0 ] == 0 :
659+ tags -= 1
660+
661+ self .track [:] = tags [self .segment_track_array ]
657662
658663 i_sort = self .obs .argsort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
659664 # Sort directly obs, with hope to save memory
@@ -674,23 +679,40 @@ def network(self, id_network):
674679
675680 @classmethod
676681 def __tag_segment (cls , seg , tag , groups , connexions ):
682+ """
683+ Will set same temporary ID for each connected segment.
684+
685+ :param int seg: current ID of seg
686+ :param ing tag: temporary ID to set for seg and its connexion
687+ :param array[int] groups: array where tag will be stored
688+ :param dict connexions: gives for one ID of seg all seg connected
689+ """
690+ # If seg are already used we stop recursivity
677691 if groups [seg ] != 0 :
678692 return
693+ # We set tag for this seg
679694 groups [seg ] = tag
680- segs = connexions .get (seg + 1 , None )
695+ # Get all connexions of this seg
696+ segs = connexions .get (seg , None )
681697 if segs is not None :
682698 for seg in segs :
683- cls .__tag_segment (seg - 1 , tag , groups , connexions )
699+ # For each connexion we apply same function
700+ cls .__tag_segment (seg , tag , groups , connexions )
684701
685- def tag_segment (self ):
686- self .only_one_network ()
687- nb = self .segment .max ()
702+ def tag_segment (self , multi_network = False ):
703+ if multi_network :
704+ nb = self .segment_track_array [- 1 ] + 1
705+ else :
706+ nb = self .segment .max () + 1
688707 sub_group = zeros (nb , dtype = "u4" )
689- c = self .connexions ()
708+ c = self .connexions (multi_network = multi_network )
690709 j = 1
710+ # for each available id
691711 for i in range (nb ):
712+ # Skip if already set
692713 if sub_group [i ] != 0 :
693714 continue
715+ # we tag an unset segments and explore all connexions
694716 self .__tag_segment (i , j , sub_group , c )
695717 j += 1
696718 return sub_group
0 commit comments