diff --git a/examples/16_network/pet_segmentation_anim.py b/examples/16_network/pet_segmentation_anim.py index cc0dc23c..1cb84452 100644 --- a/examples/16_network/pet_segmentation_anim.py +++ b/examples/16_network/pet_segmentation_anim.py @@ -37,12 +37,14 @@ def save(self, *args, **kwargs): # %% # Overlaod of class to pick up TRACKS = list() +INDICES = list() class MyTrack(TrackEddiesObservations): @staticmethod def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs): TRACKS.append(ids["track"].copy()) + INDICES.append(i_current) return TrackEddiesObservations.get_next_obs( i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs ) @@ -70,9 +72,13 @@ def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwarg def update(i_frame): tr = TRACKS[i_frame] mappable_tracks.set_array(tr) - s = 80 * ones(tr.shape) + s = 40 * ones(tr.shape) s[tr == 0] = 4 mappable_tracks.set_sizes(s) + + indices_frames = INDICES[i_frame] + mappable_CONTOUR.set_data(e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames],) + mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)]) return (mappable_tracks,) @@ -85,6 +91,9 @@ def update(i_frame): mappable_tracks = ax.scatter( e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20 ) +mappable_CONTOUR = ax.plot( + e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0] +)[0] ani = VideoAnimation( fig, update, frames=range(1, len(TRACKS), 4), interval=125, blit=True -) +) \ No newline at end of file diff --git a/src/py_eddy_tracker/generic.py b/src/py_eddy_tracker/generic.py index d6915088..1c066b0a 100644 --- a/src/py_eddy_tracker/generic.py +++ b/src/py_eddy_tracker/generic.py @@ -70,6 +70,8 @@ def build_index(groups): :param array groups: array which contain group to be separated :return: (first_index of each group, last_index of each group, value to shift group) :rtype: (array, array, int) + + [0,0,1,1,1, 2] -> ([0, 2, 5], [2, 5, 6], 0) """ i0, i1 = groups.min(), groups.max() amplitude = i1 - i0 + 1 @@ -191,59 +193,7 @@ def interp2d_geo(x_g, y_g, z_g, m_g, x, y, nearest=False): :return: z interpolated :rtype: array """ - if nearest: - return interp2d_nearest(x_g, y_g, z_g, x, y) - else: - return interp2d_bilinear(x_g, y_g, z_g, m_g, x, y) - - -@njit(cache=True, fastmath=True) -def interp2d_nearest(x_g, y_g, z_g, x, y): - """ - Nearest interpolation with wrapping if circular - - :param array x_g: coordinates of grid - :param array y_g: coordinates of grid - :param array z_g: Grid value - :param array x: coordinate where interpolate z - :param array y: coordinate where interpolate z - :return: z interpolated - :rtype: array - """ - x_ref = x_g[0] - y_ref = y_g[0] - x_step = x_g[1] - x_ref - y_step = y_g[1] - y_ref - nb_x = x_g.shape[0] - nb_y = y_g.shape[0] - is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 - z = empty(x.shape, dtype=z_g.dtype) - for i in prange(x.size): - i0 = int(round((x[i] - x_ref) / x_step)) - j0 = int(round((y[i] - y_ref) / y_step)) - if is_circular: - i0 %= nb_x - if i0 >= nb_x or i0 < 0 or j0 < 0 or j0 >= nb_y: - z[i] = nan - continue - z[i] = z_g[i0, j0] - return z - - -@njit(cache=True, fastmath=True) -def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): - """ - Bilinear interpolation with wrapping if circular - - :param array x_g: coordinates of grid - :param array y_g: coordinates of grid - :param array z_g: Grid value - :param array m_g: Boolean grid, True if value is masked - :param array x: coordinate where interpolate z - :param array y: coordinate where interpolate z - :return: z interpolated - :rtype: array - """ + # TODO : Maybe test if we are out of bounds x_ref = x_g[0] y_ref = y_g[0] x_step = x_g[1] - x_ref @@ -251,46 +201,46 @@ def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): nb_x = x_g.shape[0] nb_y = y_g.shape[0] is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 - # Indices which should be never exist - i0_old, j0_old, masked = -100000000, -10000000, False z = empty(x.shape, dtype=z_g.dtype) for i in prange(x.size): x_ = (x[i] - x_ref) / x_step y_ = (y[i] - y_ref) / y_step i0 = int(floor(x_)) + i1 = i0 + 1 + xd = x_ - i0 j0 = int(floor(y_)) - # corner are the same need only a new xd and yd - if i0 != i0_old or j0 != j0_old: - i1 = i0 + 1 - j1 = j0 + 1 - if is_circular: - i0 %= nb_x - i1 %= nb_x + j1 = j0 + 1 + if is_circular: + i0 %= nb_x + i1 %= nb_x + else: if i1 >= nb_x or i0 < 0 or j0 < 0 or j1 >= nb_y: - masked = True - else: - masked = False - if not masked: - if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]: - masked = True - else: - z00, z01, z10, z11 = ( - z_g[i0, j0], - z_g[i0, j1], - z_g[i1, j0], - z_g[i1, j1], - ) - masked = False - # Need to be store only on change - i0_old, j0_old = i0, j0 - if masked: + z[i] = nan + continue + + yd = y_ - j0 + z00 = z_g[i0, j0] + z01 = z_g[i0, j1] + z10 = z_g[i1, j0] + z11 = z_g[i1, j1] + if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]: z[i] = nan else: - xd = x_ - i0 - yd = y_ - j0 - z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + ( - z01 * (1 - xd) + z11 * xd - ) * yd + if nearest: + if xd <= 0.5: + if yd <= 0.5: + z[i] = z00 + else: + z[i] = z01 + else: + if yd <= 0.5: + z[i] = z10 + else: + z[i] = z11 + else: + z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + ( + z01 * (1 - xd) + z11 * xd + ) * yd return z diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index d8c339cf..f2f30ade 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -164,6 +164,9 @@ def obs_relative_order(self, i_obs): return self.segment_relative_order(self.segment[i_obs]) def connexions(self, multi_network=False): + """ + create dictionnary for each segments, gives the segments which interact with + """ if multi_network: segment = self.segment_track_array else: @@ -648,10 +651,12 @@ def dissociate_network(self): """ Dissociate network with no known interaction (spliting/merging) """ - self.only_one_network() - tags = self.tag_segment() - # FIXME : Ok if only one network - self.track[:] = tags[self.segment - 1] + + tags = self.tag_segment(multi_network=True) + if self.track[0] == 0: + tags -= 1 + + self.track[:] = tags[self.segment_track_array] i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort") # Sort directly obs, with hope to save memory @@ -672,23 +677,42 @@ def network(self, id_network): @classmethod def __tag_segment(cls, seg, tag, groups, connexions): + """ + Will set same temporary ID for each connected segment. + + :param int seg: current ID of seg + :param ing tag: temporary ID to set for seg and its connexion + :param array[int] groups: array where tag will be stored + :param dict connexions: gives for one ID of seg all seg connected + """ + # If seg are already used we stop recursivity + seg_corrected = seg - 1 if groups[seg] != 0: return + # We set tag for this seg groups[seg] = tag - segs = connexions.get(seg + 1, None) + # Get all connexions of this seg + segs = connexions.get(seg, None) if segs is not None: for seg in segs: - cls.__tag_segment(seg - 1, tag, groups, connexions) + # For each connexion we apply same function + cls.__tag_segment(seg, tag, groups, connexions) - def tag_segment(self): - self.only_one_network() - nb = self.segment.max() + + def tag_segment(self, multi_network=False): + if multi_network: + nb = self.segment_track_array[-1]+1 + else: + nb = self.segment.max()+1 sub_group = zeros(nb, dtype="u4") - c = self.connexions() + c = self.connexions(multi_network=multi_network) j = 1 + # for each available id for i in range(nb): + # Skip if already set if sub_group[i] != 0: continue + # we tag an unset segments and explore all connexions self.__tag_segment(i, j, sub_group, c) j += 1 return sub_group