Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions examples/16_network/pet_segmentation_anim.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ def save(self, *args, **kwargs):
# %%
# Overlaod of class to pick up
TRACKS = list()
INDICES = list()


class MyTrack(TrackEddiesObservations):
@staticmethod
def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):
TRACKS.append(ids["track"].copy())
INDICES.append(i_current)
return TrackEddiesObservations.get_next_obs(
i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs
)
Expand Down Expand Up @@ -70,9 +72,13 @@ def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwarg
def update(i_frame):
tr = TRACKS[i_frame]
mappable_tracks.set_array(tr)
s = 80 * ones(tr.shape)
s = 40 * ones(tr.shape)
s[tr == 0] = 4
mappable_tracks.set_sizes(s)

indices_frames = INDICES[i_frame]
mappable_CONTOUR.set_data(e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames],)
mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])
return (mappable_tracks,)


Expand All @@ -85,6 +91,9 @@ def update(i_frame):
mappable_tracks = ax.scatter(
e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20
)
mappable_CONTOUR = ax.plot(
e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0]
)[0]
ani = VideoAnimation(
fig, update, frames=range(1, len(TRACKS), 4), interval=125, blit=True
)
)
118 changes: 34 additions & 84 deletions src/py_eddy_tracker/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def build_index(groups):
:param array groups: array which contain group to be separated
:return: (first_index of each group, last_index of each group, value to shift group)
:rtype: (array, array, int)

[0,0,1,1,1, 2] -> ([0, 2, 5], [2, 5, 6], 0)
"""
i0, i1 = groups.min(), groups.max()
amplitude = i1 - i0 + 1
Expand Down Expand Up @@ -191,106 +193,54 @@ def interp2d_geo(x_g, y_g, z_g, m_g, x, y, nearest=False):
:return: z interpolated
:rtype: array
"""
if nearest:
return interp2d_nearest(x_g, y_g, z_g, x, y)
else:
return interp2d_bilinear(x_g, y_g, z_g, m_g, x, y)


@njit(cache=True, fastmath=True)
def interp2d_nearest(x_g, y_g, z_g, x, y):
"""
Nearest interpolation with wrapping if circular

:param array x_g: coordinates of grid
:param array y_g: coordinates of grid
:param array z_g: Grid value
:param array x: coordinate where interpolate z
:param array y: coordinate where interpolate z
:return: z interpolated
:rtype: array
"""
x_ref = x_g[0]
y_ref = y_g[0]
x_step = x_g[1] - x_ref
y_step = y_g[1] - y_ref
nb_x = x_g.shape[0]
nb_y = y_g.shape[0]
is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5
z = empty(x.shape, dtype=z_g.dtype)
for i in prange(x.size):
i0 = int(round((x[i] - x_ref) / x_step))
j0 = int(round((y[i] - y_ref) / y_step))
if is_circular:
i0 %= nb_x
if i0 >= nb_x or i0 < 0 or j0 < 0 or j0 >= nb_y:
z[i] = nan
continue
z[i] = z_g[i0, j0]
return z


@njit(cache=True, fastmath=True)
def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y):
"""
Bilinear interpolation with wrapping if circular

:param array x_g: coordinates of grid
:param array y_g: coordinates of grid
:param array z_g: Grid value
:param array m_g: Boolean grid, True if value is masked
:param array x: coordinate where interpolate z
:param array y: coordinate where interpolate z
:return: z interpolated
:rtype: array
"""
# TODO : Maybe test if we are out of bounds
x_ref = x_g[0]
y_ref = y_g[0]
x_step = x_g[1] - x_ref
y_step = y_g[1] - y_ref
nb_x = x_g.shape[0]
nb_y = y_g.shape[0]
is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5
# Indices which should be never exist
i0_old, j0_old, masked = -100000000, -10000000, False
z = empty(x.shape, dtype=z_g.dtype)
for i in prange(x.size):
x_ = (x[i] - x_ref) / x_step
y_ = (y[i] - y_ref) / y_step
i0 = int(floor(x_))
i1 = i0 + 1
xd = x_ - i0
j0 = int(floor(y_))
# corner are the same need only a new xd and yd
if i0 != i0_old or j0 != j0_old:
i1 = i0 + 1
j1 = j0 + 1
if is_circular:
i0 %= nb_x
i1 %= nb_x
j1 = j0 + 1
if is_circular:
i0 %= nb_x
i1 %= nb_x
else:
if i1 >= nb_x or i0 < 0 or j0 < 0 or j1 >= nb_y:
masked = True
else:
masked = False
if not masked:
if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]:
masked = True
else:
z00, z01, z10, z11 = (
z_g[i0, j0],
z_g[i0, j1],
z_g[i1, j0],
z_g[i1, j1],
)
masked = False
# Need to be store only on change
i0_old, j0_old = i0, j0
if masked:
z[i] = nan
continue

yd = y_ - j0
z00 = z_g[i0, j0]
z01 = z_g[i0, j1]
z10 = z_g[i1, j0]
z11 = z_g[i1, j1]
if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]:
z[i] = nan
else:
xd = x_ - i0
yd = y_ - j0
z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + (
z01 * (1 - xd) + z11 * xd
) * yd
if nearest:
if xd <= 0.5:
if yd <= 0.5:
z[i] = z00
else:
z[i] = z01
else:
if yd <= 0.5:
z[i] = z10
else:
z[i] = z11
else:
z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + (
z01 * (1 - xd) + z11 * xd
) * yd
return z


Expand Down
44 changes: 34 additions & 10 deletions src/py_eddy_tracker/observations/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def obs_relative_order(self, i_obs):
return self.segment_relative_order(self.segment[i_obs])

def connexions(self, multi_network=False):
"""
create dictionnary for each segments, gives the segments which interact with
"""
if multi_network:
segment = self.segment_track_array
else:
Expand Down Expand Up @@ -648,10 +651,12 @@ def dissociate_network(self):
"""
Dissociate network with no known interaction (spliting/merging)
"""
self.only_one_network()
tags = self.tag_segment()
# FIXME : Ok if only one network
self.track[:] = tags[self.segment - 1]

tags = self.tag_segment(multi_network=True)
if self.track[0] == 0:
tags -= 1

self.track[:] = tags[self.segment_track_array]

i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort")
# Sort directly obs, with hope to save memory
Expand All @@ -672,23 +677,42 @@ def network(self, id_network):

@classmethod
def __tag_segment(cls, seg, tag, groups, connexions):
"""
Will set same temporary ID for each connected segment.

:param int seg: current ID of seg
:param ing tag: temporary ID to set for seg and its connexion
:param array[int] groups: array where tag will be stored
:param dict connexions: gives for one ID of seg all seg connected
"""
# If seg are already used we stop recursivity
seg_corrected = seg - 1
if groups[seg] != 0:
return
# We set tag for this seg
groups[seg] = tag
segs = connexions.get(seg + 1, None)
# Get all connexions of this seg
segs = connexions.get(seg, None)
if segs is not None:
for seg in segs:
cls.__tag_segment(seg - 1, tag, groups, connexions)
# For each connexion we apply same function
cls.__tag_segment(seg, tag, groups, connexions)

def tag_segment(self):
self.only_one_network()
nb = self.segment.max()

def tag_segment(self, multi_network=False):
if multi_network:
nb = self.segment_track_array[-1]+1
else:
nb = self.segment.max()+1
sub_group = zeros(nb, dtype="u4")
c = self.connexions()
c = self.connexions(multi_network=multi_network)
j = 1
# for each available id
for i in range(nb):
# Skip if already set
if sub_group[i] != 0:
continue
# we tag an unset segments and explore all connexions
self.__tag_segment(i, j, sub_group, c)
j += 1
return sub_group
Expand Down