Commit 652d5ed

Merge pull request AntSimi#73 from ludwigVonKoopa/master
add new features
2 parents 3d237a2 + d01bb3f commit 652d5ed

File tree: 2 files changed (+229 / -4 lines)

examples/16_network/pet_relative.py

Lines changed: 45 additions & 0 deletions
@@ -136,6 +136,51 @@
 ax.set_title(f"Clean network ({n_clean.infos()})")
 _ = n_clean.display_timeline(ax)
 
+
+# %%
+# change splitting-merging events
+# -------------------------------
+# change events where seg A splits to B, then A merges into B,
+# into: A splits to B, then B merges into A
+fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12), dpi=120)
+
+ax1.set_title(f"Clean network ({n_clean.infos()})")
+n_clean.display_timeline(ax1)
+
+clean_modified = n_clean.copy()
+clean_modified.correct_close_events(100)
+ax2.set_title(f"re-split network ({clean_modified.infos()})")
+_ = clean_modified.display_timeline(ax2)
+
+
+# %%
+# keep only observations where water could propagate from an observation
+# ----------------------------------------------------------------------
+
+fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12), dpi=120)
+i_observation = 600
+only_linked = n_clean.find_link(i_observation)
+
+for ax, dataset in zip([ax1, ax2], [n_clean, only_linked]):
+    dataset.display_timeline(
+        ax, field="segment", marker="+", lw=2, markersize=5, colors_mode="y"
+    )
+    ax.scatter(
+        n_clean.time[i_observation],
+        n_clean.segment[i_observation],
+        marker="s",
+        s=50,
+        color="black",
+        zorder=200,
+        label="observation start",
+        alpha=0.6,
+    )
+    ax.legend()
+
+ax1.set_title(f"full example ({n_clean.infos()})")
+ax2.set_title(f"only linked observations ({only_linked.infos()})")
+ax2.set_xlim(ax1.get_xlim())
+ax2.set_ylim(ax1.get_ylim())
+
 # %%
 # For further figures we will use the clean path
 n = n_clean
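
The new example exercises `find_link` only in its default forward mode. A minimal sketch of the other call patterns, assuming the same `n_clean` network as above (index values are arbitrary; the signature comes from the `find_link` method added below):

```python
# backward only: observations from which water could have reached obs 600
sources = n_clean.find_link(600, forward=False, backward=True)

# find_link also accepts an iterable of indices, and both directions at once
linked = n_clean.find_link([600, 1200], forward=True, backward=True)
```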

src/py_eddy_tracker/observations/network.py

Lines changed: 184 additions & 4 deletions
@@ -240,10 +240,188 @@ def from_split_network(cls, group_dataset, indexs, **kwargs):
     def infos(self, label=""):
         return f"{len(self)} obs {unique(self.segment).shape[0]} segments"
 
+    def correct_close_events(self, nb_days_max=20):
+        """
+        Transform events where
+
+        segment A splits to B, then A merges into B
+
+        into
+
+        segment A splits to B, then B merges into A.
+
+        These events are filtered with `nb_days_max`: the split and the
+        merge must take place within `nb_days_max` days of each other.
+
+        :param float nb_days_max: maximum time to search for a
+            splitting-merging event
+        """
+
+        _time = self.time
+        # working segment array, used to apply and track corrections
+        segment = self.segment_track_array.copy()
+        # final segment array, which will be copied into self.segment
+        segment_copy = self.segment
+
+        segments_connexion = dict()
+
+        previous_obs, next_obs = self.previous_obs, self.next_obs
+
+        # record, for every segment, its slice and the indices of the
+        # next & previous observations
+        for i, seg, _ in self.iter_on(segment):
+            if i.start == i.stop:
+                continue
+
+            i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1]
+            segments_connexion[seg] = [i, i_p, i_n]
+
+        for seg in sorted(segments_connexion.keys()):
+            seg_slice, i_seg_p, i_seg_n = segments_connexion[seg]
+
+            # the segment ID has to be looked up again, because it may have
+            # been changed by a previous iteration
+            seg_corrected = segment[seg_slice.stop - 1]
+
+            # we keep the real segment number
+            seg_corrected_copy = segment_copy[seg_slice.stop - 1]
+
+            n_seg = segment[i_seg_n]
+
+            # if the segment merges into another one
+            if i_seg_n != -1:
+                seg2_slice, i2_seg_p, i2_seg_n = segments_connexion[n_seg]
+                p2_seg = segment[i2_seg_p]
+
+                # if the merge target split from this segment within `nb_days_max`
+                if (p2_seg == seg_corrected) and (
+                    _time[i_seg_n] - _time[i2_seg_p] < nb_days_max
+                ):
+                    my_slice = slice(i_seg_n, seg2_slice.stop)
+                    # correct the working segment array
+                    segment[my_slice] = seg_corrected
+                    # correct the real segment array
+                    segment_copy[my_slice] = seg_corrected_copy
+                    previous_obs[i_seg_n] = seg_slice.stop - 1
+
+                    segments_connexion[seg_corrected][0] = my_slice
+
+        self.segment[:] = segment_copy
+        self.previous_obs[:] = previous_obs
+
+        self.sort()
+
+    def sort(self, order=("track", "segment", "time")):
+        """
+        Sort observations
+
+        :param tuple order: sort order, passed to `numpy.argsort`
+        """
+
+        index_order = self.obs.argsort(order=order)
+        for field in self.elements:
+            self[field][:] = self[field][index_order]
+
+        # remap next/previous pointers to the new positions (-1 stays -1)
+        translate = -ones(index_order.max() + 2, dtype="i4")
+        translate[index_order] = arange(index_order.shape[0])
+        self.next_obs[:] = translate[self.next_obs]
+        self.previous_obs[:] = translate[self.previous_obs]
+
     def obs_relative_order(self, i_obs):
         self.only_one_network()
         return self.segment_relative_order(self.segment[i_obs])

+    def find_link(self, i_observations, forward=True, backward=False):
+        """
+        Find all observations linked to observation `i_observations`,
+        in the future or in the past.
+
+        If forward=True, search all observations where water
+        from obs `i_observations` could go.
+
+        If backward=True, search all observations
+        where water from obs `i_observations` could come from.
+
+        :param int,iterable(int) i_observations: indices of observations,
+            an int or an iterable of int
+        :param bool forward: if True, search observations after obs
+        :param bool backward: if True, search observations before obs
+        """
+
+        i_obs = (
+            [i_observations]
+            if not hasattr(i_observations, "__iter__")
+            else i_observations
+        )
+
+        segment = self.segment_track_array
+        previous_obs, next_obs = self.previous_obs, self.next_obs
+
+        segments_connexion = dict()
+
+        for i_slice, seg, _ in self.iter_on(segment):
+            if i_slice.start == i_slice.stop:
+                continue
+
+            i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1]
+            p_seg, n_seg = segment[i_p], segment[i_n]
+
+            # store the slice of this segment in the dict
+            if seg not in segments_connexion:
+                segments_connexion[seg] = [i_slice, [], []]
+            else:
+                segments_connexion[seg][0] = i_slice
+
+            if i_p != -1:
+                if p_seg not in segments_connexion:
+                    segments_connexion[p_seg] = [None, [], []]
+
+                # backward
+                segments_connexion[seg][2].append((i_slice.start, i_p, p_seg))
+                # forward
+                segments_connexion[p_seg][1].append((i_p, i_slice.start, seg))
+
+            if i_n != -1:
+                if n_seg not in segments_connexion:
+                    segments_connexion[n_seg] = [None, [], []]
+
+                # forward
+                segments_connexion[seg][1].append((i_slice.stop - 1, i_n, n_seg))
+                # backward
+                segments_connexion[n_seg][2].append((i_n, i_slice.stop - 1, seg))
+
+        mask = zeros(segment.size, dtype=bool)
+
+        def func_forward(seg, indice):
+            seg_slice, _forward, _ = segments_connexion[seg]
+
+            mask[indice : seg_slice.stop] = True
+            for i_begin, i_end, seg2 in _forward:
+                if i_begin < indice:
+                    continue
+
+                if not mask[i_end]:
+                    func_forward(seg2, i_end)
+
+        def func_backward(seg, indice):
+            seg_slice, _, _backward = segments_connexion[seg]
+
+            mask[seg_slice.start : indice + 1] = True
+            for i_begin, i_end, seg2 in _backward:
+                if i_begin > indice:
+                    continue
+
+                if not mask[i_end]:
+                    func_backward(seg2, i_end)
+
+        for indice in i_obs:
+            if forward:
+                func_forward(segment[indice], indice)
+
+            if backward:
+                func_backward(segment[indice], indice)
+
+        return self.extract_with_mask(mask)
+
     def connexions(self, multi_network=False):
         """
         Create a dictionary for each segment, giving the segments which interact with
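
The subtle part of the new `sort` is the pointer fix-up: after every field is reordered, `next_obs`/`previous_obs` still hold pre-sort indices. The `translate` array remaps old indices to new positions, and its one extra trailing cell stays at -1 so the "no link" sentinel maps to itself. A standalone sketch of that trick with toy arrays (not the library's data):

```python
import numpy as np

# toy linked list: next_obs[i] is the index of the observation after i, -1 = none
next_obs = np.array([1, 2, -1, 4, -1])
index_order = np.array([3, 4, 0, 1, 2])  # as produced by argsort

# old index -> new position; the extra last cell stays -1,
# so indexing with the sentinel -1 returns -1
translate = -np.ones(index_order.max() + 2, dtype="i4")
translate[index_order] = np.arange(index_order.shape[0])

reordered = next_obs[index_order]  # fields are reordered first
print(translate[reordered])        # [ 1 -1  3  4 -1]: pointers follow the sort
```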
@@ -490,14 +668,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
 
         # TODO : fill mappables dict
         y_seg = dict()
+        _time = self.time
+
         if field is not None and method != "all":
             for i, b0, _ in self.iter_on("segment"):
                 y = self[field][i]
                 if y.shape[0] != 0:
                     y_seg[b0] = y.mean() * factor
         mappables = dict()
         for i, b0, b1 in self.iter_on("segment"):
-            x = self.time[i]
+            x = _time[i]
             if x.shape[0] == 0:
                 continue
 
@@ -532,7 +712,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
                         else y_seg[seg_next]
                     )
                 )
-                ax.plot((x[-1], self.time[i_n]), (y0, y1), **event_kw)[0]
+                ax.plot((x[-1], _time[i_n]), (y0, y1), **event_kw)[0]
                 events["merging"].append((x[-1], y0))
 
             if i_p != -1:
@@ -548,7 +728,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
                         else y_seg[seg_previous]
                     )
                 )
-                ax.plot((x[0], self.time[i_p]), (y0, y1), **event_kw)[0]
+                ax.plot((x[0], _time[i_p]), (y0, y1), **event_kw)[0]
                 events["spliting"].append((x[0], y0))
 
             j += 1
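
The `_time = self.time` changes in `event_timeline` are a hoisting optimization: the attribute is resolved once per call instead of on every loop iteration. A toy standalone illustration of the pattern (hypothetical `Obs` class, not the library's):

```python
import timeit

class Obs:
    # hypothetical stand-in: .time is computed on access, like a property
    def __init__(self, n):
        self._t = list(range(n))

    @property
    def time(self):
        return self._t

o = Obs(10_000)
per_access = timeit.timeit(lambda: [o.time[i] for i in range(10_000)], number=50)
_time = o.time  # hoisted once, as in the patch
hoisted = timeit.timeit(lambda: [_time[i] for i in range(10_000)], number=50)
print(f"per-access: {per_access:.3f}s  hoisted: {hoisted:.3f}s")
```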
@@ -1045,7 +1225,7 @@ def extract_with_mask(self, mask):
             logger.warning("Empty dataset will be created")
         else:
             logger.info(
-                f"{nb_obs} observations will be extract ({nb_obs * 100. / self.shape[0]}%)"
+                f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})"
             )
         for field in self.obs.dtype.descr:
             if field in ("next_obs", "previous_obs"):
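
The logging change swaps manual percentage arithmetic for Python's `%` format spec, which multiplies by 100, fixes the precision, and appends the sign. A quick standalone check:

```python
nb_obs, total = 75, 2000
print(f"{nb_obs * 100. / total}%")  # old style -> 3.75% (precision varies)
print(f"{nb_obs / total:.3%}")      # new style -> 3.750% (always 3 decimals)
```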
