Skip to content

Commit d7656c3

Browse files
Modification dissociate (#56)
* update animation example * dissociate_network for multiple networks
1 parent 98b19f7 commit d7656c3

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

examples/16_network/pet_segmentation_anim.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ def save(self, *args, **kwargs):
3737
# %%
3838
# Overlaod of class to pick up
3939
TRACKS = list()
40+
INDICES = list()
4041

4142

4243
class MyTrack(TrackEddiesObservations):
4344
@staticmethod
4445
def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):
4546
TRACKS.append(ids["track"].copy())
47+
INDICES.append(i_current)
4648
return TrackEddiesObservations.get_next_obs(
4749
i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs
4850
)
@@ -70,9 +72,16 @@ def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwarg
7072
def update(i_frame):
7173
tr = TRACKS[i_frame]
7274
mappable_tracks.set_array(tr)
73-
s = 80 * ones(tr.shape)
75+
s = 40 * ones(tr.shape)
7476
s[tr == 0] = 4
7577
mappable_tracks.set_sizes(s)
78+
79+
indices_frames = INDICES[i_frame]
80+
mappable_CONTOUR.set_data(
81+
e.contour_lon_e[indices_frames],
82+
e.contour_lat_e[indices_frames],
83+
)
84+
mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])
7685
return (mappable_tracks,)
7786

7887

@@ -85,4 +94,7 @@ def update(i_frame):
8594
mappable_tracks = ax.scatter(
8695
e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20
8796
)
97+
mappable_CONTOUR = ax.plot(
98+
e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0]
99+
)[0]
88100
ani = VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125)

src/py_eddy_tracker/observations/network.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ def obs_relative_order(self, i_obs):
166166
return self.segment_relative_order(self.segment[i_obs])
167167

168168
def connexions(self, multi_network=False):
169+
"""
170+
create dictionnary for each segments, gives the segments which interact with
171+
"""
169172
if multi_network:
170173
segment = self.segment_track_array
171174
else:
@@ -650,10 +653,12 @@ def dissociate_network(self):
650653
"""
651654
Dissociate network with no known interaction (spliting/merging)
652655
"""
653-
self.only_one_network()
654-
tags = self.tag_segment()
655-
# FIXME : Ok if only one network
656-
self.track[:] = tags[self.segment - 1]
656+
657+
tags = self.tag_segment(multi_network=True)
658+
if self.track[0] == 0:
659+
tags -= 1
660+
661+
self.track[:] = tags[self.segment_track_array]
657662

658663
i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort")
659664
# Sort directly obs, with hope to save memory
@@ -674,23 +679,40 @@ def network(self, id_network):
674679

675680
@classmethod
676681
def __tag_segment(cls, seg, tag, groups, connexions):
682+
"""
683+
Will set same temporary ID for each connected segment.
684+
685+
:param int seg: current ID of seg
686+
:param ing tag: temporary ID to set for seg and its connexion
687+
:param array[int] groups: array where tag will be stored
688+
:param dict connexions: gives for one ID of seg all seg connected
689+
"""
690+
# If seg are already used we stop recursivity
677691
if groups[seg] != 0:
678692
return
693+
# We set tag for this seg
679694
groups[seg] = tag
680-
segs = connexions.get(seg + 1, None)
695+
# Get all connexions of this seg
696+
segs = connexions.get(seg, None)
681697
if segs is not None:
682698
for seg in segs:
683-
cls.__tag_segment(seg - 1, tag, groups, connexions)
699+
# For each connexion we apply same function
700+
cls.__tag_segment(seg, tag, groups, connexions)
684701

685-
def tag_segment(self):
686-
self.only_one_network()
687-
nb = self.segment.max()
702+
def tag_segment(self, multi_network=False):
703+
if multi_network:
704+
nb = self.segment_track_array[-1] + 1
705+
else:
706+
nb = self.segment.max() + 1
688707
sub_group = zeros(nb, dtype="u4")
689-
c = self.connexions()
708+
c = self.connexions(multi_network=multi_network)
690709
j = 1
710+
# for each available id
691711
for i in range(nb):
712+
# Skip if already set
692713
if sub_group[i] != 0:
693714
continue
715+
# we tag an unset segments and explore all connexions
694716
self.__tag_segment(i, j, sub_group, c)
695717
j += 1
696718
return sub_group

0 commit comments

Comments
 (0)