Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion examples/16_network/pet_segmentation_anim.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ def save(self, *args, **kwargs):
# %%
# Overlaod of class to pick up
TRACKS = list()
INDICES = list()


class MyTrack(TrackEddiesObservations):
@staticmethod
def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):
TRACKS.append(ids["track"].copy())
INDICES.append(i_current)
return TrackEddiesObservations.get_next_obs(
i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs
)
Expand Down Expand Up @@ -70,9 +72,16 @@ def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwarg
def update(i_frame):
tr = TRACKS[i_frame]
mappable_tracks.set_array(tr)
s = 80 * ones(tr.shape)
s = 40 * ones(tr.shape)
s[tr == 0] = 4
mappable_tracks.set_sizes(s)

indices_frames = INDICES[i_frame]
mappable_CONTOUR.set_data(
e.contour_lon_e[indices_frames],
e.contour_lat_e[indices_frames],
)
mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])
return (mappable_tracks,)


Expand All @@ -85,4 +94,7 @@ def update(i_frame):
mappable_tracks = ax.scatter(
e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20
)
mappable_CONTOUR = ax.plot(
e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0]
)[0]
ani = VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125)
42 changes: 32 additions & 10 deletions src/py_eddy_tracker/observations/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def obs_relative_order(self, i_obs):
return self.segment_relative_order(self.segment[i_obs])

def connexions(self, multi_network=False):
"""
create dictionnary for each segments, gives the segments which interact with
"""
if multi_network:
segment = self.segment_track_array
else:
Expand Down Expand Up @@ -650,10 +653,12 @@ def dissociate_network(self):
"""
Dissociate network with no known interaction (spliting/merging)
"""
self.only_one_network()
tags = self.tag_segment()
# FIXME : Ok if only one network
self.track[:] = tags[self.segment - 1]

tags = self.tag_segment(multi_network=True)
if self.track[0] == 0:
tags -= 1

self.track[:] = tags[self.segment_track_array]

i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort")
# Sort directly obs, with hope to save memory
Expand All @@ -674,23 +679,40 @@ def network(self, id_network):

@classmethod
def __tag_segment(cls, seg, tag, groups, connexions):
"""
Will set same temporary ID for each connected segment.

:param int seg: current ID of seg
:param ing tag: temporary ID to set for seg and its connexion
:param array[int] groups: array where tag will be stored
:param dict connexions: gives for one ID of seg all seg connected
"""
# If seg are already used we stop recursivity
if groups[seg] != 0:
return
# We set tag for this seg
groups[seg] = tag
segs = connexions.get(seg + 1, None)
# Get all connexions of this seg
segs = connexions.get(seg, None)
if segs is not None:
for seg in segs:
cls.__tag_segment(seg - 1, tag, groups, connexions)
# For each connexion we apply same function
cls.__tag_segment(seg, tag, groups, connexions)

def tag_segment(self):
self.only_one_network()
nb = self.segment.max()
def tag_segment(self, multi_network=False):
if multi_network:
nb = self.segment_track_array[-1] + 1
else:
nb = self.segment.max() + 1
sub_group = zeros(nb, dtype="u4")
c = self.connexions()
c = self.connexions(multi_network=multi_network)
j = 1
# for each available id
for i in range(nb):
# Skip if already set
if sub_group[i] != 0:
continue
# we tag an unset segments and explore all connexions
self.__tag_segment(i, j, sub_group, c)
j += 1
return sub_group
Expand Down