Skip to content

Commit d41f06c

Browse files
add find_link
add function to search which observations is directly linked with observation choosen.
1 parent 14934f2 commit d41f06c

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

examples/16_network/pet_relative.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,36 @@
151151
ax2.set_title(f"resplitted network ({clean_modified.infos()})")
152152
_ = clean_modified.display_timeline(ax2)
153153

154+
155+
# %%
156+
# keep only observations where water could propagate from an observation
157+
# ----------------------------------------------------------------------
158+
159+
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12), dpi=120)
160+
i_observation = 600
161+
only_linked = n_clean.find_link(i_observation)
162+
163+
for ax, dataset in zip([ax1, ax2], [n_clean, only_linked]):
164+
dataset.display_timeline(
165+
ax, field="segment", marker="+", lw=2, markersize=5, colors_mode="y"
166+
)
167+
ax.scatter(
168+
n_clean.time[i_observation],
169+
n_clean.segment[i_observation],
170+
marker="s",
171+
s=50,
172+
color="black",
173+
zorder=200,
174+
label="observation start",
175+
alpha=0.6,
176+
)
177+
ax.legend()
178+
179+
ax1.set_title(f"full example ({n_clean.infos()})")
180+
ax2.set_title(f"only linked observations ({only_linked.infos()})")
181+
ax2.set_xlim(ax1.get_xlim())
182+
ax2.set_ylim(ax1.get_ylim())
183+
154184
# %%
155185
# For further figure we will use clean path
156186
n = n_clean

src/py_eddy_tracker/observations/network.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,102 @@ def obs_relative_order(self, i_obs):
323323
self.only_one_network()
324324
return self.segment_relative_order(self.segment[i_obs])
325325

326+
def find_link(self, i_observations, forward=True, backward=False):
327+
"""
328+
find all observations where obs `i_observation` could be
329+
in future or past.
330+
331+
if forward=True, search all observation where water
332+
from obs "i_observation" could go
333+
334+
if backward=True, search all observation
335+
where water from obs `i_observation` could come from
336+
337+
:param int, iterable(int) i_observation:
338+
indices of observation. Can be
339+
int, or iterable of int.
340+
:param bool forward, backward:
341+
if forward, search observations after obs.
342+
else mode==backward search before obs
343+
344+
"""
345+
346+
i_obs = (
347+
[i_observations]
348+
if not hasattr(i_observations, "__iter__")
349+
else i_observations
350+
)
351+
352+
segment = self.segment_track_array
353+
previous_obs, next_obs = self.previous_obs, self.next_obs
354+
355+
segments_connexion = dict()
356+
357+
for i_slice, seg, _ in self.iter_on(segment):
358+
if i_slice.start == i_slice.stop:
359+
continue
360+
361+
i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1]
362+
p_seg, n_seg = segment[i_p], segment[i_n]
363+
364+
# dumping slice into dict
365+
if seg not in segments_connexion:
366+
segments_connexion[seg] = [i_slice, [], []]
367+
else:
368+
segments_connexion[seg][0] = i_slice
369+
370+
if i_p != -1:
371+
372+
if p_seg not in segments_connexion:
373+
segments_connexion[p_seg] = [None, [], []]
374+
375+
# backward
376+
segments_connexion[seg][2].append((i_slice.start, i_p, p_seg))
377+
# forward
378+
segments_connexion[p_seg][1].append((i_p, i_slice.start, seg))
379+
380+
if i_n != -1:
381+
if n_seg not in segments_connexion:
382+
segments_connexion[n_seg] = [None, [], []]
383+
384+
# forward
385+
segments_connexion[seg][1].append((i_slice.stop - 1, i_n, n_seg))
386+
# backward
387+
segments_connexion[n_seg][2].append((i_n, i_slice.stop - 1, seg))
388+
389+
mask = zeros(segment.size, dtype=bool)
390+
391+
def func_forward(seg, indice):
392+
seg_slice, _forward, _ = segments_connexion[seg]
393+
394+
mask[indice : seg_slice.stop] = True
395+
for i_begin, i_end, seg2 in _forward:
396+
if i_begin < indice:
397+
continue
398+
399+
if not mask[i_end]:
400+
func_forward(seg2, i_end)
401+
402+
def func_backward(seg, indice):
403+
seg_slice, _, _backward = segments_connexion[seg]
404+
405+
mask[seg_slice.start : indice + 1] = True
406+
for i_begin, i_end, seg2 in _backward:
407+
if i_begin > indice:
408+
continue
409+
410+
if not mask[i_end]:
411+
func_backward(seg2, i_end)
412+
413+
for indice in i_obs:
414+
if forward:
415+
func_forward(segment[indice], indice)
416+
417+
if backward:
418+
func_backward(segment[indice], indice)
419+
420+
return self.extract_with_mask(mask)
421+
326422
def connexions(self, multi_network=False):
327423
"""
328424
create dictionnary for each segments, gives the segments which interact with

0 commit comments

Comments
 (0)