diff --git a/examples/16_network/pet_relative.py b/examples/16_network/pet_relative.py index f18d79a0..9993104c 100644 --- a/examples/16_network/pet_relative.py +++ b/examples/16_network/pet_relative.py @@ -199,6 +199,8 @@ i_event = 5 # %% # then see some order of relatives + + @FuncFormatter def formatter(x, pos): return (datetime.timedelta(x) + datetime.datetime(1950, 1, 1)).strftime("%d/%m/%Y") @@ -209,7 +211,7 @@ def formatter(x, pos): max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2)) ) -axs[0].set_title(f"full network", weight="bold") +axs[0].set_title("full network", weight="bold") axs[0].xaxis.set_major_formatter(formatter), axs[0].grid() mappables = n.display_timeline(axs[0], colors_mode="y") axs[0].legend() diff --git a/src/py_eddy_tracker/observations/groups.py b/src/py_eddy_tracker/observations/groups.py new file mode 100644 index 00000000..b5534964 --- /dev/null +++ b/src/py_eddy_tracker/observations/groups.py @@ -0,0 +1,131 @@ +import logging +from abc import ABC, abstractmethod + +from numba import njit +from numpy import arange, int32, interp, median, zeros + +from .observation import EddiesObservations + +logger = logging.getLogger("pet") + + +@njit(cache=True) +def get_missing_indices( + array_time, array_track, dt=1, flag_untrack=True, indice_untrack=0 +): + """return indices where it misses values + + :param np.array(int) array_time : array of strictly increasing int representing time + :param np.array(int) array_track: N° track where observation belong + :param int,float dt: theorical timedelta between 2 observation + :param bool flag_untrack: if True, ignore observations where n°track equal `indice_untrack` + :param int indice_untrack: n° representing where observations are untrack + + + ex : array_time = np.array([67, 68, 70, 71, 74, 75]) + array_track= np.array([ 1, 1, 1, 1, 1, 1]) + return : np.array([2, 4, 4]) + """ + + t0 = array_time[0] + t1 = t0 + + tr0 = array_track[0] + tr1 = tr0 + + nbr_step = zeros(array_time.shape, dtype=int32) + + for i in range(array_time.size - 1): + t0 = t1 + tr0 = tr1 + + t1 = array_time[i + 1] + tr1 = array_track[i + 1] + + if flag_untrack & (tr1 == indice_untrack): + continue + + if tr1 != tr0: + continue + + diff = t1 - t0 + if diff > dt: + nbr_step[i] = int(diff / dt) - 1 + + indices = zeros(nbr_step.sum(), dtype=int32) + + j = 0 + for i in range(array_time.size - 1): + nbr_missing = nbr_step[i] + + if nbr_missing != 0: + for k in range(nbr_missing): + indices[j] = i + 1 + j += 1 + return indices + + +class GroupEddiesObservations(EddiesObservations, ABC): + @abstractmethod + def fix_next_previous_obs(self): + pass + + @abstractmethod + def get_missing_indices(self, dt): + "find indices where observations is missing" + pass + + def filled_by_interpolation(self, mask): + """Filled selected values by interpolation + + :param array(bool) mask: True if must be filled by interpolation + + .. minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation + """ + + nb_filled = mask.sum() + logger.info("%d obs will be filled (unobserved)", nb_filled) + + nb_obs = len(self) + index = arange(nb_obs) + + for field in self.obs.dtype.descr: + # print(f"field : {field}") + var = field[0] + if ( + var in ["n", "virtual", "track", "cost_association"] + or var in self.array_variables + ): + continue + self.obs[var][mask] = interp( + index[mask], index[~mask], self.obs[var][~mask] + ) + + def insert_virtual(self): + + dt_theorical = median(self.time[1:] - self.time[:-1]) + indices = self.get_missing_indices(dt_theorical) + + logger.info("%d virtual observation will be added", indices.size) + + # new observation size + size_obs_corrected = self.time.size + indices.size + + # correction of indices for new size + indices_corrected = indices + arange(indices.size) + + # creating mask with indices + mask = zeros(size_obs_corrected, dtype=bool) + mask[indices_corrected] = 1 + + # time2 = np.empty(n.time.size+indices.size, dtype=n.time.dtype) + # time2[mask] = -1 + # time2[~mask] = n.time + + # new_TEO = TrackEddiesObservations.new_like(n, size_obs_corrected) + new_TEO = self.new_like(self, size_obs_corrected) + new_TEO.obs[~mask] = self.obs + new_TEO.filled_by_interpolation(mask) + new_TEO.virtual[:] = mask + new_TEO.fix_next_previous_obs() + return new_TEO diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index 72f86d4b..edbab2e4 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -22,6 +22,7 @@ from ..generic import build_index, wrap_longitude from ..poly import bbox_intersection, vertice_overlap +from .groups import GroupEddiesObservations, get_missing_indices from .observation import EddiesObservations from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter @@ -73,7 +74,28 @@ def load_contour(self, filename): return self.DATA[filename] -class NetworkObservations(EddiesObservations): +@njit(cache=True) +def fix_next_previous_obs(next_obs, previous_obs, flag_virtual): + """when an observation is virtual, we have to fix the previous and next obs + + :param np.array(int) next_obs : indice of next observation from network + :param np.array(int previous_obs: indice of previous observation from network + :param np.array(bool) flag_virtual: if observation is virtual or not + """ + + for i_o in range(next_obs.size): + if not flag_virtual[i_o]: + continue + + # if there is many virtual side by side, there is some values writted multiple times. + # but it should not be slow + next_obs[i_o - 1] = i_o + next_obs[i_o] = i_o + 1 + previous_obs[i_o] = i_o - 1 + previous_obs[i_o + 1] = i_o + + +class NetworkObservations(GroupEddiesObservations): __slots__ = ("_index_network",) @@ -109,6 +131,25 @@ def find_segments_relative(self, obs, stopped=None, order=1): i_stopped = where(nw.segment == self.segment[stopped])[0][0] return nw.relatives([i_obs, i_stopped], order=order) + def get_missing_indices(self, dt): + """find indices where observations is missing. + + As network have all untrack observation in tracknumber `self.NOGROUP`, + we don't compute them + + :param int,float dt: theorical delta time between 2 observations + """ + return get_missing_indices( + self.time, self.track, dt=dt, flag_untrack=True, indice_untrack=self.NOGROUP + ) + + def fix_next_previous_obs(self): + """function used after 'insert_virtual', to correct next_obs and + previous obs. + """ + + fix_next_previous_obs(self.next_obs, self.previous_obs, self.virtual) + @property def index_network(self): if self._index_network is None: @@ -712,10 +753,6 @@ def scatter( mappables["scatter"] = ax.scatter(x, self.latitude, **kwargs) return mappables - def insert_virtual(self): - # TODO - pass - def extract_event(self, indices): nb = len(indices) new = EddiesObservations( diff --git a/src/py_eddy_tracker/observations/tracking.py b/src/py_eddy_tracker/observations/tracking.py index 6a716622..5902462d 100644 --- a/src/py_eddy_tracker/observations/tracking.py +++ b/src/py_eddy_tracker/observations/tracking.py @@ -16,7 +16,6 @@ degrees, empty, histogram, - interp, median, nan, ones, @@ -29,12 +28,12 @@ from .. import VAR_DESCR_inv, __version__ from ..generic import build_index, cumsum_by_track, distance, split_line, wrap_longitude from ..poly import bbox_intersection, merge, vertice_overlap -from .observation import EddiesObservations +from .groups import GroupEddiesObservations, get_missing_indices logger = logging.getLogger("pet") -class TrackEddiesObservations(EddiesObservations): +class TrackEddiesObservations(GroupEddiesObservations): """Class to practice Tracking on observations""" __slots__ = ("__obs_by_track", "__first_index_of_track", "__nb_track") @@ -77,6 +76,26 @@ def iter_track(self): continue yield self.index(slice(i0, i0 + nb)) + def get_missing_indices(self, dt): + """find indices where observations is missing. + + :param int,float dt: theorical delta time between 2 observations + """ + return get_missing_indices( + self.time, + self.track, + dt=dt, + flag_untrack=False, + indice_untrack=self.NOGROUP, + ) + + def fix_next_previous_obs(self): + """function used after 'insert_virtual', to correct next_obs and + previous obs. + """ + + pass + @property def nb_tracks(self): """ @@ -146,30 +165,6 @@ def distance_to_next(self): d_[-1] = 0 return d_ - def filled_by_interpolation(self, mask): - """Filled selected values by interpolation - - :param array(bool) mask: True if must be filled by interpolation - - .. minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation - """ - nb_filled = mask.sum() - logger.info("%d obs will be filled (unobserved)", nb_filled) - - nb_obs = len(self) - index = arange(nb_obs) - - for field in self.obs.dtype.descr: - var = field[0] - if ( - var in ["n", "virtual", "track", "cost_association"] - or var in self.array_variables - ): - continue - self.obs[var][mask] = interp( - index[mask], index[~mask], self.obs[var][~mask] - ) - def normalize_longitude(self): """Normalize all longitude