Skip to content

Commit e08d16a

Browse files
dissociate_network for multiple networks
1 parent 63e49bf commit e08d16a

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

src/py_eddy_tracker/observations/network.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ def obs_relative_order(self, i_obs):
164164
return self.segment_relative_order(self.segment[i_obs])
165165

166166
def connexions(self, multi_network=False):
167+
"""
168+
create dictionnary for each segments, gives the segments which interact with
169+
"""
167170
if multi_network:
168171
segment = self.segment_track_array
169172
else:
@@ -648,10 +651,12 @@ def dissociate_network(self):
648651
"""
649652
Dissociate network with no known interaction (spliting/merging)
650653
"""
651-
self.only_one_network()
652-
tags = self.tag_segment()
653-
# FIXME : Ok if only one network
654-
self.track[:] = tags[self.segment - 1]
654+
655+
tags = self.tag_segment(multi_network=True)
656+
if self.track[0] == 0:
657+
tags -= 1
658+
659+
self.track[:] = tags[self.segment_track_array]
655660

656661
i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort")
657662
# Sort directly obs, with hope to save memory
@@ -672,23 +677,40 @@ def network(self, id_network):
672677

673678
@classmethod
674679
def __tag_segment(cls, seg, tag, groups, connexions):
680+
"""
681+
Will set same temporary ID for each connected segment.
682+
683+
:param int seg: current ID of seg
684+
:param ing tag: temporary ID to set for seg and its connexion
685+
:param array[int] groups: array where tag will be stored
686+
:param dict connexions: gives for one ID of seg all seg connected
687+
"""
688+
# If seg are already used we stop recursivity
675689
if groups[seg] != 0:
676690
return
691+
# We set tag for this seg
677692
groups[seg] = tag
678-
segs = connexions.get(seg + 1, None)
693+
# Get all connexions of this seg
694+
segs = connexions.get(seg, None)
679695
if segs is not None:
680696
for seg in segs:
681-
cls.__tag_segment(seg - 1, tag, groups, connexions)
697+
# For each connexion we apply same function
698+
cls.__tag_segment(seg, tag, groups, connexions)
682699

683-
def tag_segment(self):
684-
self.only_one_network()
685-
nb = self.segment.max()
700+
def tag_segment(self, multi_network=False):
701+
if multi_network:
702+
nb = self.segment_track_array[-1] + 1
703+
else:
704+
nb = self.segment.max() + 1
686705
sub_group = zeros(nb, dtype="u4")
687-
c = self.connexions()
706+
c = self.connexions(multi_network=multi_network)
688707
j = 1
708+
# for each available id
689709
for i in range(nb):
710+
# Skip if already set
690711
if sub_group[i] != 0:
691712
continue
713+
# we tag an unset segments and explore all connexions
692714
self.__tag_segment(i, j, sub_group, c)
693715
j += 1
694716
return sub_group

0 commit comments

Comments
 (0)