@@ -164,6 +164,9 @@ def obs_relative_order(self, i_obs):
164164 return self .segment_relative_order (self .segment [i_obs ])
165165
166166 def connexions (self , multi_network = False ):
167+ """
168+ create dictionnary for each segments, gives the segments which interact with
169+ """
167170 if multi_network :
168171 segment = self .segment_track_array
169172 else :
@@ -648,10 +651,12 @@ def dissociate_network(self):
648651 """
649652 Dissociate network with no known interaction (spliting/merging)
650653 """
651- self .only_one_network ()
652- tags = self .tag_segment ()
653- # FIXME : Ok if only one network
654- self .track [:] = tags [self .segment - 1 ]
654+
655+ tags = self .tag_segment (multi_network = True )
656+ if self .track [0 ] == 0 :
657+ tags -= 1
658+
659+ self .track [:] = tags [self .segment_track_array ]
655660
656661 i_sort = self .obs .argsort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
657662 # Sort directly obs, with hope to save memory
@@ -672,23 +677,40 @@ def network(self, id_network):
672677
673678 @classmethod
674679 def __tag_segment (cls , seg , tag , groups , connexions ):
680+ """
681+ Will set same temporary ID for each connected segment.
682+
683+ :param int seg: current ID of seg
684+ :param ing tag: temporary ID to set for seg and its connexion
685+ :param array[int] groups: array where tag will be stored
686+ :param dict connexions: gives for one ID of seg all seg connected
687+ """
688+ # If seg are already used we stop recursivity
675689 if groups [seg ] != 0 :
676690 return
691+ # We set tag for this seg
677692 groups [seg ] = tag
678- segs = connexions .get (seg + 1 , None )
693+ # Get all connexions of this seg
694+ segs = connexions .get (seg , None )
679695 if segs is not None :
680696 for seg in segs :
681- cls .__tag_segment (seg - 1 , tag , groups , connexions )
697+ # For each connexion we apply same function
698+ cls .__tag_segment (seg , tag , groups , connexions )
682699
683- def tag_segment (self ):
684- self .only_one_network ()
685- nb = self .segment .max ()
700+ def tag_segment (self , multi_network = False ):
701+ if multi_network :
702+ nb = self .segment_track_array [- 1 ] + 1
703+ else :
704+ nb = self .segment .max () + 1
686705 sub_group = zeros (nb , dtype = "u4" )
687- c = self .connexions ()
706+ c = self .connexions (multi_network = multi_network )
688707 j = 1
708+ # for each available id
689709 for i in range (nb ):
710+ # Skip if already set
690711 if sub_group [i ] != 0 :
691712 continue
713+ # we tag an unset segments and explore all connexions
692714 self .__tag_segment (i , j , sub_group , c )
693715 j += 1
694716 return sub_group
0 commit comments