Skip to content

Commit 7565ca1

Browse files
ludwigVonKoopaAntSimi
authored andcommitted
add function insert_virtual and class GroupEddiesCollection
1 parent d746166 commit 7565ca1

File tree

4 files changed

+198
-33
lines changed

4 files changed

+198
-33
lines changed

examples/16_network/pet_relative.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@
199199
i_event = 5
200200
# %%
201201
# then see some order of relatives
202+
203+
202204
@FuncFormatter
203205
def formatter(x, pos):
204206
return (datetime.timedelta(x) + datetime.datetime(1950, 1, 1)).strftime("%d/%m/%Y")
@@ -209,7 +211,7 @@ def formatter(x, pos):
209211
max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2))
210212
)
211213

212-
axs[0].set_title(f"full network", weight="bold")
214+
axs[0].set_title("full network", weight="bold")
213215
axs[0].xaxis.set_major_formatter(formatter), axs[0].grid()
214216
mappables = n.display_timeline(axs[0], colors_mode="y")
215217
axs[0].legend()
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import logging
2+
from abc import ABC, abstractmethod
3+
4+
from numba import njit
5+
from numpy import arange, int32, interp, median, zeros
6+
7+
from .observation import EddiesObservations
8+
9+
logger = logging.getLogger("pet")
10+
11+
12+
@njit(cache=True)
13+
def get_missing_indices(
14+
array_time, array_track, dt=1, flag_untrack=True, indice_untrack=0
15+
):
16+
"""return indices where it misses values
17+
18+
:param np.array(int) array_time : array of strictly increasing int representing time
19+
:param np.array(int) array_track: N° track where observation belong
20+
:param int,float dt: theorical timedelta between 2 observation
21+
:param bool flag_untrack: if True, ignore observations where n°track equal `indice_untrack`
22+
:param int indice_untrack: n° representing where observations are untrack
23+
24+
25+
ex : array_time = np.array([67, 68, 70, 71, 74, 75])
26+
array_track= np.array([ 1, 1, 1, 1, 1, 1])
27+
return : np.array([2, 4, 4])
28+
"""
29+
30+
t0 = array_time[0]
31+
t1 = t0
32+
33+
tr0 = array_track[0]
34+
tr1 = tr0
35+
36+
nbr_step = zeros(array_time.shape, dtype=int32)
37+
38+
for i in range(array_time.size - 1):
39+
t0 = t1
40+
tr0 = tr1
41+
42+
t1 = array_time[i + 1]
43+
tr1 = array_track[i + 1]
44+
45+
if flag_untrack & (tr1 == indice_untrack):
46+
continue
47+
48+
if tr1 != tr0:
49+
continue
50+
51+
diff = t1 - t0
52+
if diff > dt:
53+
nbr_step[i] = int(diff / dt) - 1
54+
55+
indices = zeros(nbr_step.sum(), dtype=int32)
56+
57+
j = 0
58+
for i in range(array_time.size - 1):
59+
nbr_missing = nbr_step[i]
60+
61+
if nbr_missing != 0:
62+
for k in range(nbr_missing):
63+
indices[j] = i + 1
64+
j += 1
65+
return indices
66+
67+
68+
class GroupEddiesObservations(EddiesObservations, ABC):
69+
@abstractmethod
70+
def fix_next_previous_obs(self):
71+
pass
72+
73+
@abstractmethod
74+
def get_missing_indices(self, dt):
75+
"find indices where observations is missing"
76+
pass
77+
78+
def filled_by_interpolation(self, mask):
79+
"""Filled selected values by interpolation
80+
81+
:param array(bool) mask: True if must be filled by interpolation
82+
83+
.. minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation
84+
"""
85+
86+
nb_filled = mask.sum()
87+
logger.info("%d obs will be filled (unobserved)", nb_filled)
88+
89+
nb_obs = len(self)
90+
index = arange(nb_obs)
91+
92+
for field in self.obs.dtype.descr:
93+
# print(f"field : {field}")
94+
var = field[0]
95+
if (
96+
var in ["n", "virtual", "track", "cost_association"]
97+
or var in self.array_variables
98+
):
99+
continue
100+
self.obs[var][mask] = interp(
101+
index[mask], index[~mask], self.obs[var][~mask]
102+
)
103+
104+
def insert_virtual(self):
105+
106+
dt_theorical = median(self.time[1:] - self.time[:-1])
107+
indices = self.get_missing_indices(dt_theorical)
108+
109+
logger.info("%d virtual observation will be added", indices.size)
110+
111+
# new observation size
112+
size_obs_corrected = self.time.size + indices.size
113+
114+
# correction of indices for new size
115+
indices_corrected = indices + arange(indices.size)
116+
117+
# creating mask with indices
118+
mask = zeros(size_obs_corrected, dtype=bool)
119+
mask[indices_corrected] = 1
120+
121+
# time2 = np.empty(n.time.size+indices.size, dtype=n.time.dtype)
122+
# time2[mask] = -1
123+
# time2[~mask] = n.time
124+
125+
# new_TEO = TrackEddiesObservations.new_like(n, size_obs_corrected)
126+
new_TEO = self.new_like(self, size_obs_corrected)
127+
new_TEO.obs[~mask] = self.obs
128+
new_TEO.filled_by_interpolation(mask)
129+
new_TEO.virtual[:] = mask
130+
new_TEO.fix_next_previous_obs()
131+
return new_TEO

src/py_eddy_tracker/observations/network.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from ..generic import build_index, wrap_longitude
2424
from ..poly import bbox_intersection, vertice_overlap
25+
from .groups import GroupEddiesObservations, get_missing_indices
2526
from .observation import EddiesObservations
2627
from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter
2728

@@ -73,7 +74,28 @@ def load_contour(self, filename):
7374
return self.DATA[filename]
7475

7576

76-
class NetworkObservations(EddiesObservations):
77+
@njit(cache=True)
78+
def fix_next_previous_obs(next_obs, previous_obs, flag_virtual):
79+
"""when an observation is virtual, we have to fix the previous and next obs
80+
81+
:param np.array(int) next_obs : indice of next observation from network
82+
:param np.array(int previous_obs: indice of previous observation from network
83+
:param np.array(bool) flag_virtual: if observation is virtual or not
84+
"""
85+
86+
for i_o in range(next_obs.size):
87+
if not flag_virtual[i_o]:
88+
continue
89+
90+
# if there is many virtual side by side, there is some values writted multiple times.
91+
# but it should not be slow
92+
next_obs[i_o - 1] = i_o
93+
next_obs[i_o] = i_o + 1
94+
previous_obs[i_o] = i_o - 1
95+
previous_obs[i_o + 1] = i_o
96+
97+
98+
class NetworkObservations(GroupEddiesObservations):
7799

78100
__slots__ = ("_index_network",)
79101

@@ -109,6 +131,25 @@ def find_segments_relative(self, obs, stopped=None, order=1):
109131
i_stopped = where(nw.segment == self.segment[stopped])[0][0]
110132
return nw.relatives([i_obs, i_stopped], order=order)
111133

134+
def get_missing_indices(self, dt):
135+
"""find indices where observations is missing.
136+
137+
As network have all untrack observation in tracknumber `self.NOGROUP`,
138+
we don't compute them
139+
140+
:param int,float dt: theorical delta time between 2 observations
141+
"""
142+
return get_missing_indices(
143+
self.time, self.track, dt=dt, flag_untrack=True, indice_untrack=self.NOGROUP
144+
)
145+
146+
def fix_next_previous_obs(self):
147+
"""function used after 'insert_virtual', to correct next_obs and
148+
previous obs.
149+
"""
150+
151+
fix_next_previous_obs(self.next_obs, self.previous_obs, self.virtual)
152+
112153
@property
113154
def index_network(self):
114155
if self._index_network is None:
@@ -712,10 +753,6 @@ def scatter(
712753
mappables["scatter"] = ax.scatter(x, self.latitude, **kwargs)
713754
return mappables
714755

715-
def insert_virtual(self):
716-
# TODO
717-
pass
718-
719756
def extract_event(self, indices):
720757
nb = len(indices)
721758
new = EddiesObservations(

src/py_eddy_tracker/observations/tracking.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
degrees,
1717
empty,
1818
histogram,
19-
interp,
2019
median,
2120
nan,
2221
ones,
@@ -29,12 +28,12 @@
2928
from .. import VAR_DESCR_inv, __version__
3029
from ..generic import build_index, cumsum_by_track, distance, split_line, wrap_longitude
3130
from ..poly import bbox_intersection, merge, vertice_overlap
32-
from .observation import EddiesObservations
31+
from .groups import GroupEddiesObservations, get_missing_indices
3332

3433
logger = logging.getLogger("pet")
3534

3635

37-
class TrackEddiesObservations(EddiesObservations):
36+
class TrackEddiesObservations(GroupEddiesObservations):
3837
"""Class to practice Tracking on observations"""
3938

4039
__slots__ = ("__obs_by_track", "__first_index_of_track", "__nb_track")
@@ -77,6 +76,26 @@ def iter_track(self):
7776
continue
7877
yield self.index(slice(i0, i0 + nb))
7978

79+
def get_missing_indices(self, dt):
80+
"""find indices where observations is missing.
81+
82+
:param int,float dt: theorical delta time between 2 observations
83+
"""
84+
return get_missing_indices(
85+
self.time,
86+
self.track,
87+
dt=dt,
88+
flag_untrack=False,
89+
indice_untrack=self.NOGROUP,
90+
)
91+
92+
def fix_next_previous_obs(self):
93+
"""function used after 'insert_virtual', to correct next_obs and
94+
previous obs.
95+
"""
96+
97+
pass
98+
8099
@property
81100
def nb_tracks(self):
82101
"""
@@ -146,30 +165,6 @@ def distance_to_next(self):
146165
d_[-1] = 0
147166
return d_
148167

149-
def filled_by_interpolation(self, mask):
150-
"""Filled selected values by interpolation
151-
152-
:param array(bool) mask: True if must be filled by interpolation
153-
154-
.. minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation
155-
"""
156-
nb_filled = mask.sum()
157-
logger.info("%d obs will be filled (unobserved)", nb_filled)
158-
159-
nb_obs = len(self)
160-
index = arange(nb_obs)
161-
162-
for field in self.obs.dtype.descr:
163-
var = field[0]
164-
if (
165-
var in ["n", "virtual", "track", "cost_association"]
166-
or var in self.array_variables
167-
):
168-
continue
169-
self.obs[var][mask] = interp(
170-
index[mask], index[~mask], self.obs[var][~mask]
171-
)
172-
173168
def normalize_longitude(self):
174169
"""Normalize all longitude
175170

0 commit comments

Comments
 (0)