Skip to content

Commit 493aed8

Browse files
committed
add new method for network
1 parent 3df15b9 commit 493aed8

File tree

3 files changed

+182
-49
lines changed

3 files changed

+182
-49
lines changed

examples/16_network/pet_relative.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
"""
55

66
from matplotlib import pyplot as plt
7+
from matplotlib.animation import FuncAnimation
8+
from numpy import arange
79

810
import py_eddy_tracker.gui
911
from py_eddy_tracker import data
12+
from py_eddy_tracker.appli.gui import Anim
1013
from py_eddy_tracker.observations.network import NetworkObservations
1114
from py_eddy_tracker.observations.tracking import TrackEddiesObservations
1215

@@ -30,13 +33,41 @@
3033
# Display timeline
3134
fig = plt.figure(figsize=(15, 5))
3235
ax = fig.add_axes([0.04, 0.04, 0.92, 0.92])
33-
n.display_timeline(ax)
36+
_ = n.display_timeline(ax)
3437

3538
# %%
3639
# Display timeline without event
3740
fig = plt.figure(figsize=(15, 5))
3841
ax = fig.add_axes([0.04, 0.04, 0.92, 0.92])
39-
n.display_timeline(ax, event=False)
42+
_ = n.display_timeline(ax, event=False)
43+
44+
# %%
45+
# Timeline by latitude mean
46+
# -------------------------
47+
fig = plt.figure(figsize=(15, 5))
48+
ax = fig.add_axes([0.04, 0.04, 0.92, 0.92])
49+
_ = n.display_timeline(ax, field="latitude")
50+
51+
# %%
52+
# Timeline by radius mean
53+
# -----------------------
54+
fig = plt.figure(figsize=(15, 5))
55+
ax = fig.add_axes([0.04, 0.04, 0.92, 0.92])
56+
_ = n.display_timeline(ax, field="radius_e")
57+
58+
# %%
59+
# Timeline by latitude
60+
# --------------------
61+
fig = plt.figure(figsize=(15, 5))
62+
ax = fig.add_axes([0.04, 0.05, 0.92, 0.92])
63+
_ = n.display_timeline(ax, field="lat", method="all")
64+
65+
# %%
66+
fig = plt.figure(figsize=(15, 5))
67+
ax = fig.add_axes([0.04, 0.05, 0.92, 0.92])
68+
n_copy = n.copy()
69+
n_copy.median_filter(15, "time", "latitude")
70+
_ = n_copy.display_timeline(ax, field="lat", method="all")
4071

4172
# %%
4273
# Parameters timeline
@@ -138,9 +169,12 @@
138169
# Display track on map
139170
# --------------------
140171
fig = plt.figure(figsize=(15, 8))
141-
ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection="full_axes")
172+
ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection="full_axes")
142173
close_to_i2.plot(ax)
143-
_ = ax.set_xlim(-10, 20), ax.set_ylim(-37, -21), ax.grid()
174+
ax.set_xlim(-13, 20), ax.set_ylim(-36.5, -20), ax.grid()
175+
ax = fig.add_axes([0.08, 0.67, 0.55, 0.3])
176+
_ = close_to_i2.display_timeline(ax, field="latitude")
177+
144178

145179
# %%
146180
# Get merging event
@@ -161,3 +195,23 @@
161195
spliting.display(ax)
162196
ax.set_xlim(-10, 20), ax.set_ylim(-37, -21), ax.grid()
163197
spliting
198+
199+
# %%
200+
# Get birth event
201+
# ------------------
202+
fig = plt.figure(figsize=(15, 8))
203+
ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection="full_axes")
204+
birth = close_to_i2.birth_event()
205+
birth.display(ax)
206+
ax.set_xlim(-10, 20), ax.set_ylim(-37, -21), ax.grid()
207+
birth
208+
209+
# %%
210+
# Get death event
211+
# ------------------
212+
fig = plt.figure(figsize=(15, 8))
213+
ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection="full_axes")
214+
death = close_to_i2.death_event()
215+
death.display(ax)
216+
ax.set_xlim(-10, 20), ax.set_ylim(-37, -21), ax.grid()
217+
death

src/py_eddy_tracker/observations/network.py

Lines changed: 114 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ..generic import build_index, wrap_longitude
1212
from ..poly import bbox_intersection, vertice_overlap
1313
from .observation import EddiesObservations
14-
from .tracking import TrackEddiesObservations
14+
from .tracking import TrackEddiesObservations, track_median_filter
1515

1616
logger = logging.getLogger("pet")
1717

@@ -168,7 +168,17 @@ def only_one_network(self):
168168
# TODO
169169
pass
170170

171-
def display_timeline(self, ax, event=True):
171+
def median_filter(self, half_window, xfield, yfield, inplace=True):
172+
# FIXME: segments is not enough with several network
173+
result = track_median_filter(
174+
half_window, self[xfield], self[yfield], self.segment
175+
)
176+
if inplace:
177+
self[yfield][:] = result
178+
return self
179+
return result
180+
181+
def display_timeline(self, ax, event=True, field=None, method=None):
172182
"""
173183
Must be call on only one network
174184
"""
@@ -183,21 +193,33 @@ def display_timeline(self, ax, event=True):
183193
)
184194
mappables = dict(lines=list())
185195
if event:
186-
mappables.update(self.event_timeline(ax))
196+
mappables.update(self.event_timeline(ax, field=field, method=method))
187197
for i, b0, b1 in self.iter_on("segment"):
188198
x = self.time[i]
189199
if x.shape[0] == 0:
190200
continue
191-
y = b0 * ones(x.shape)
201+
if field is None:
202+
y = b0 * ones(x.shape)
203+
else:
204+
if method == "all":
205+
y = self[field][i]
206+
else:
207+
y = self[field][i].mean() * ones(x.shape)
192208
line = ax.plot(x, y, **line_kw, color=self.COLORS[j % self.NB_COLORS])[0]
193209
mappables["lines"].append(line)
194210
j += 1
195211

196212
return mappables
197213

198-
def event_timeline(self, ax):
214+
def event_timeline(self, ax, field=None, method=None):
199215
j = 0
200216
# TODO : fill mappables dict
217+
y_seg = dict()
218+
if field is not None and method != "all":
219+
for i, b0, _ in self.iter_on("segment"):
220+
y = self[field][i]
221+
if y.shape[0] != 0:
222+
y_seg[b0] = y.mean()
201223
mappables = dict()
202224
for i, b0, b1 in self.iter_on("segment"):
203225
x = self.time[i]
@@ -208,12 +230,33 @@ def event_timeline(self, ax):
208230
self.next_obs[i.stop - 1],
209231
self.previous_obs[i.start],
210232
)
233+
if field is None:
234+
y0 = b0
235+
else:
236+
if method == "all":
237+
y0 = self[field][i.stop - 1]
238+
else:
239+
y0 = y_seg[b0]
211240
if i_n != -1:
212-
ax.plot((x[-1], self.time[i_n]), (b0, self.segment[i_n]), **event_kw)
213-
ax.plot(x[-1], b0, color="k", marker=">", markersize=10, zorder=-1)
241+
seg_next = self.segment[i_n]
242+
y1 = (
243+
seg_next
244+
if field is None
245+
else (self[field][i_n] if method == "all" else y_seg[seg_next])
246+
)
247+
ax.plot((x[-1], self.time[i_n]), (y0, y1), **event_kw)[0]
248+
ax.plot(x[-1], y0, color="k", marker=">", markersize=10, zorder=-1)[0]
214249
if i_p != -1:
215-
ax.plot((x[0], self.time[i_p]), (b0, self.segment[i_p]), **event_kw)
216-
ax.plot(x[0], b0, color="k", marker="*", markersize=12, zorder=-1)
250+
seg_previous = self.segment[i_p]
251+
if field is not None and method == "all":
252+
y0 = self[field][i.start]
253+
y1 = (
254+
seg_previous
255+
if field is None
256+
else (self[field][i_p] if method == "all" else y_seg[seg_previous])
257+
)
258+
ax.plot((x[0], self.time[i_p]), (y0, y1), **event_kw)[0]
259+
ax.plot(x[0], y0, color="k", marker="*", markersize=12, zorder=-1)[0]
217260
j += 1
218261
return mappables
219262

@@ -235,16 +278,7 @@ def insert_virtual(self):
235278
# TODO
236279
pass
237280

238-
def merging_event(self):
239-
indices = list()
240-
for i, b0, b1 in self.iter_on("segment"):
241-
nb = i.stop - i.start
242-
if nb == 0:
243-
continue
244-
i_n = self.next_obs[i.stop - 1]
245-
if i_n != -1:
246-
indices.append(i.stop - 1)
247-
indices = list(set(indices))
281+
def extract_event(self, indices):
248282
nb = len(indices)
249283
new = EddiesObservations(
250284
nb,
@@ -260,30 +294,54 @@ def merging_event(self):
260294
new.sign_type = self.sign_type
261295
return new
262296

297+
def segment_track_array(self):
298+
return build_unique_array(self.segment, self.track)
299+
300+
def birth_event(self):
301+
# FIXME how to manage group 0
302+
indices = list()
303+
for i, b0, b1 in self.iter_on(self.segment_track_array()):
304+
nb = i.stop - i.start
305+
if nb == 0:
306+
continue
307+
i_p = self.previous_obs[i.start]
308+
if i_p == -1:
309+
indices.append(i.start)
310+
return self.extract_event(list(set(indices)))
311+
312+
def death_event(self):
313+
# FIXME how to manage group 0
314+
indices = list()
315+
for i, b0, b1 in self.iter_on(self.segment_track_array()):
316+
nb = i.stop - i.start
317+
if nb == 0:
318+
continue
319+
i_n = self.next_obs[i.stop - 1]
320+
if i_n == -1:
321+
indices.append(i.stop - 1)
322+
return self.extract_event(list(set(indices)))
323+
324+
def merging_event(self):
325+
indices = list()
326+
for i, b0, b1 in self.iter_on(self.segment_track_array()):
327+
nb = i.stop - i.start
328+
if nb == 0:
329+
continue
330+
i_n = self.next_obs[i.stop - 1]
331+
if i_n != -1:
332+
indices.append(i.stop - 1)
333+
return self.extract_event(list(set(indices)))
334+
263335
def spliting_event(self):
264336
indices = list()
265-
for i, b0, b1 in self.iter_on("segment"):
337+
for i, b0, b1 in self.iter_on(self.segment_track_array()):
266338
nb = i.stop - i.start
267339
if nb == 0:
268340
continue
269341
i_p = self.previous_obs[i.start]
270342
if i_p != -1:
271343
indices.append(i.start)
272-
indices = list(set(indices))
273-
nb = len(indices)
274-
new = EddiesObservations(
275-
nb,
276-
track_extra_variables=self.track_extra_variables,
277-
track_array_variables=self.track_array_variables,
278-
array_variables=self.array_variables,
279-
only_variables=self.only_variables,
280-
raw_data=self.raw_data,
281-
)
282-
283-
for k in new.obs.dtype.names:
284-
new[k][:] = self[k][indices]
285-
new.sign_type = self.sign_type
286-
return new
344+
return self.extract_event(list(set(indices)))
287345

288346
def fully_connected(self):
289347
self.only_one_network()
@@ -463,14 +521,16 @@ def group_observations(self, **kwargs):
463521
print()
464522

465523
gr = self.get_group_array(results, nb_obs)
524+
nb_alone, nb_obs, nb_gr = (gr == self.NOGROUP).sum(), len(gr), len(unique(gr))
466525
logger.info(
467-
f"{(gr == self.NOGROUP).sum()} alone / {len(gr)} obs, {len(unique(gr))} groups"
526+
f"{nb_alone} alone / {nb_obs} obs, {nb_gr} groups, "
527+
f"{nb_alone *100./nb_obs:.2f} % alone, {(nb_obs - nb_alone) / (nb_gr - 1):.1f} obs/group"
468528
)
469529
return gr
470530

471531
def build_dataset(self, group):
472532
nb_obs = group.shape[0]
473-
model = EddiesObservations.load_file(self.filenames[-1], raw_data=True)
533+
model = TrackEddiesObservations.load_file(self.filenames[-1], raw_data=True)
474534
eddies = TrackEddiesObservations.new_like(model, nb_obs)
475535
eddies.sign_type = model.sign_type
476536
# Get new index to re-order observation by group
@@ -485,17 +545,16 @@ def build_dataset(self, group):
485545
if self.memory:
486546
# Only if netcdf
487547
with open(filename, "rb") as h:
488-
e = EddiesObservations.load_file(h, raw_data=True)
548+
e = TrackEddiesObservations.load_file(h, raw_data=True)
489549
else:
490-
e = EddiesObservations.load_file(filename, raw_data=True)
550+
e = TrackEddiesObservations.load_file(filename, raw_data=True)
491551
stop = i + len(e)
492552
sl = slice(i, stop)
493553
for element in elements:
494554
eddies[element][new_i[sl]] = e[element]
495555
i = stop
496556
if display_iteration:
497557
print()
498-
eddies = eddies.add_fields(("track",))
499558
eddies.track[new_i] = group
500559
return eddies
501560

@@ -518,3 +577,18 @@ def apply_replace(x, x0, x1):
518577
for i in range(nb):
519578
if x[i] == x0:
520579
x[i] = x1
580+
581+
582+
@njit(cache=True)
583+
def build_unique_array(id1, id2):
584+
k = 0
585+
new_id = empty(id1.shape, dtype=id1.dtype)
586+
id1_previous = id1[0]
587+
id2_previous = id2[0]
588+
for i in range(id1.shape[0]):
589+
id1_, id2_ = id1[i], id2[i]
590+
if id1_ != id1_previous or id2_ != id2_previous:
591+
k += 1
592+
new_id[i] = k
593+
id1_previous, id2_previous = id1_, id2_
594+
return new_id

src/py_eddy_tracker/observations/tracking.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,13 +432,13 @@ def loess_filter(self, half_window, xfield, yfield, inplace=True):
432432
return self
433433

434434
def median_filter(self, half_window, xfield, yfield, inplace=True):
435-
track = self.track
436-
x = self.obs[xfield]
437-
y = self.obs[yfield]
438-
result = track_median_filter(half_window, x, y, track)
435+
result = track_median_filter(
436+
half_window, self[xfield], self[yfield], self.track
437+
)
439438
if inplace:
440-
self.obs[yfield] = result
439+
self[yfield][:] = result
441440
return self
441+
return result
442442

443443
def position_filter(self, median_half_window, loess_half_window):
444444
self.median_filter(median_half_window, "time", "lon").loess_filter(
@@ -620,9 +620,12 @@ def split_network(self, intern=True, **kwargs):
620620
# and ids["next_obs"] == -1 means the end of a non-merged segment
621621

622622
xname, yname = self.intern(intern)
623+
display_iteration = logger.getEffectiveLevel() == logging.INFO
623624
for i_s, i_e in zip(track_s, track_e):
624625
if i_s == i_e or self.tracks[i_s] == self.NOGROUP:
625626
continue
627+
if display_iteration:
628+
print(f"Network obs from {i_s} to {i_e} on {track_e[-1]}", end="\r")
626629
sl = slice(i_s, i_e)
627630
local_ids = ids[sl]
628631
# built segments with local indices
@@ -632,6 +635,8 @@ def split_network(self, intern=True, **kwargs):
632635
local_ids["previous_obs"][m] += i_s
633636
m = local_ids["next_obs"] != -1
634637
local_ids["next_obs"][m] += i_s
638+
if display_iteration:
639+
print()
635640
return ids
636641

637642
def set_tracks(self, x, y, ids, window, **kwargs):

0 commit comments

Comments
 (0)