45 changes: 45 additions & 0 deletions examples/16_network/pet_relative.py
@@ -136,6 +136,51 @@
ax.set_title(f"Clean network ({n_clean.infos()})")
_ = n_clean.display_timeline(ax)


# %%
# Change splitting-merging events
# -------------------------------
# Change an event where segment A splits into B, then A merges into B, into
# an event where segment A splits into B, then B merges into A.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12), dpi=120)

ax1.set_title(f"Clean network ({n_clean.infos()})")
n_clean.display_timeline(ax1)

clean_modified = n_clean.copy()
clean_modified.correct_close_events(100)
ax2.set_title(f"resplitted network ({clean_modified.infos()})")
_ = clean_modified.display_timeline(ax2)


# %%
# Keep only the observations that water from a given observation could reach
# ---------------------------------------------------------------------------

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12), dpi=120)
i_observation = 600
only_linked = n_clean.find_link(i_observation)

for ax, dataset in zip([ax1, ax2], [n_clean, only_linked]):
dataset.display_timeline(
ax, field="segment", marker="+", lw=2, markersize=5, colors_mode="y"
)
ax.scatter(
n_clean.time[i_observation],
n_clean.segment[i_observation],
marker="s",
s=50,
color="black",
zorder=200,
label="observation start",
alpha=0.6,
)
ax.legend()

ax1.set_title(f"full example ({n_clean.infos()})")
ax2.set_title(f"only linked observations ({only_linked.infos()})")
ax2.set_xlim(ax1.get_xlim())
ax2.set_ylim(ax1.get_ylim())
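
# %%
# A sketch of the backward mode (not in the original example): keep only the
# observations that water arriving at `i_observation` could come from.
only_sources = n_clean.find_link(i_observation, forward=False, backward=True)
print(only_sources.infos())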

# %%
# For the following figures, we will use the clean network
n = n_clean
188 changes: 184 additions & 4 deletions src/py_eddy_tracker/observations/network.py
@@ -240,10 +240,188 @@ def from_split_network(cls, group_dataset, indexs, **kwargs):
def infos(self, label=""):
return f"{len(self)} obs {unique(self.segment).shape[0]} segments"

def correct_close_events(self, nb_days_max=20):
"""
Transform events where
segment A splits into B, then A merges into B

into

segment A splits into B, then B merges into A.

These events are corrected only when the split and the merge are separated by less than `nb_days_max` days.

:param float nb_days_max: maximum duration, in days, of a splitting-merging event to correct
"""

_time = self.time
# working copy of the segment/track array, used to apply and track changes
segment = self.segment_track_array.copy()
# real segment numbers, written back into self.segment at the end
segment_copy = self.segment

segments_connexion = dict()

previous_obs, next_obs = self.previous_obs, self.next_obs

# record, for every segment, its slice and the indices of the previous & next obs
for i, seg, _ in self.iter_on(segment):
if i.start == i.stop:
continue

i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1]
segments_connexion[seg] = [i, i_p, i_n]

for seg in sorted(segments_connexion.keys()):
seg_slice, i_seg_p, i_seg_n = segments_connexion[seg]

# re-read the segment ID here, because it may have been changed by a previous correction
seg_corrected = segment[seg_slice.stop - 1]

# we keep the real segment number
seg_corrected_copy = segment_copy[seg_slice.stop - 1]

# if the segment merges into another segment
if i_seg_n != -1:
n_seg = segment[i_seg_n]
seg2_slice, i2_seg_p, i2_seg_n = segments_connexion[n_seg]
p2_seg = segment[i2_seg_p]

# if it merges back into the segment it split from, within nb_days_max days
if (p2_seg == seg_corrected) and (
_time[i_seg_n] - _time[i2_seg_p] < nb_days_max
):
my_slice = slice(i_seg_n, seg2_slice.stop)
# correct the working (artificial) segment array
segment[my_slice] = seg_corrected
# correct the real segment numbers
segment_copy[my_slice] = seg_corrected_copy
previous_obs[i_seg_n] = seg_slice.stop - 1

segments_connexion[seg_corrected][0] = my_slice

self.segment[:] = segment_copy
self.previous_obs[:] = previous_obs

self.sort()
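
# Illustrative usage (a sketch; `n` is assumed to be a NetworkObservations
# instance loaded as in examples/16_network/pet_relative.py):
#     n2 = n.copy()
#     n2.correct_close_events(nb_days_max=100)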

def sort(self, order=("track", "segment", "time")):
"""
Sort observations.

:param tuple order: fields used to sort, passed to `numpy.ndarray.argsort`
"""

index_order = self.obs.argsort(order=order)
for field in self.elements:
self[field][:] = self[field][index_order]

translate = -ones(index_order.max() + 2, dtype="i4")
translate[index_order] = arange(index_order.shape[0])
self.next_obs[:] = translate[self.next_obs]
self.previous_obs[:] = translate[self.previous_obs]
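
# Worked illustration of the remapping above (a sketch, not part of the file):
# with index_order = array([2, 0, 1]),
#     translate = -ones(4, dtype="i4"); translate[index_order] = arange(3)
# gives translate == [1, 2, 0, -1]: old index i moves to translate[i], and the
# sentinel last cell keeps -1 (no next/previous link) mapped to -1.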

def obs_relative_order(self, i_obs):
self.only_one_network()
return self.segment_relative_order(self.segment[i_obs])

def find_link(self, i_observations, forward=True, backward=False):
"""
Find all observations linked to the observations `i_observations`,
in the future or in the past.

If forward=True, search all observations that water
from obs `i_observations` could reach.

If backward=True, search all observations
that water in obs `i_observations` could come from.

:param int,iterable(int) i_observations: indices of the starting observations
:param bool forward: if True, follow the network downstream of the starting obs
:param bool backward: if True, follow the network upstream of the starting obs
"""

i_obs = (
[i_observations]
if not hasattr(i_observations, "__iter__")
else i_observations
)

segment = self.segment_track_array
previous_obs, next_obs = self.previous_obs, self.next_obs

segments_connexion = dict()

for i_slice, seg, _ in self.iter_on(segment):
if i_slice.start == i_slice.stop:
continue

i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1]
p_seg, n_seg = segment[i_p], segment[i_n]

# store each segment's slice in the dict
if seg not in segments_connexion:
segments_connexion[seg] = [i_slice, [], []]
else:
segments_connexion[seg][0] = i_slice

if i_p != -1:
if p_seg not in segments_connexion:
segments_connexion[p_seg] = [None, [], []]

# backward
segments_connexion[seg][2].append((i_slice.start, i_p, p_seg))
# forward
segments_connexion[p_seg][1].append((i_p, i_slice.start, seg))

if i_n != -1:
if n_seg not in segments_connexion:
segments_connexion[n_seg] = [None, [], []]

# forward
segments_connexion[seg][1].append((i_slice.stop - 1, i_n, n_seg))
# backward
segments_connexion[n_seg][2].append((i_n, i_slice.stop - 1, seg))
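
# At this point (explanatory note, not in the original file), each entry is:
#     segments_connexion[seg] == [slice of seg,
#                                 [(i_here, i_target, target_seg), ...],  # forward
#                                 [(i_here, i_source, source_seg), ...]]  # backward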

mask = zeros(segment.size, dtype=bool)

def func_forward(seg, indice):
seg_slice, _forward, _ = segments_connexion[seg]

mask[indice : seg_slice.stop] = True
for i_begin, i_end, seg2 in _forward:
if i_begin < indice:
continue

if not mask[i_end]:
func_forward(seg2, i_end)

def func_backward(seg, indice):
seg_slice, _, _backward = segments_connexion[seg]

mask[seg_slice.start : indice + 1] = True
for i_begin, i_end, seg2 in _backward:
if i_begin > indice:
continue

if not mask[i_end]:
func_backward(seg2, i_end)

for indice in i_obs:
if forward:
func_forward(segment[indice], indice)

if backward:
func_backward(segment[indice], indice)

return self.extract_with_mask(mask)

def connexions(self, multi_network=False):
"""
create dictionnary for each segments, gives the segments which interact with
@@ -490,14 +668,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol

# TODO : fill mappables dict
y_seg = dict()
_time = self.time

if field is not None and method != "all":
for i, b0, _ in self.iter_on("segment"):
y = self[field][i]
if y.shape[0] != 0:
y_seg[b0] = y.mean() * factor
mappables = dict()
for i, b0, b1 in self.iter_on("segment"):
x = self.time[i]
x = _time[i]
if x.shape[0] == 0:
continue

Expand Down Expand Up @@ -532,7 +712,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
else y_seg[seg_next]
)
)
ax.plot((x[-1], self.time[i_n]), (y0, y1), **event_kw)[0]
ax.plot((x[-1], _time[i_n]), (y0, y1), **event_kw)[0]
events["merging"].append((x[-1], y0))

if i_p != -1:
@@ -548,7 +728,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
else y_seg[seg_previous]
)
)
ax.plot((x[0], self.time[i_p]), (y0, y1), **event_kw)[0]
ax.plot((x[0], _time[i_p]), (y0, y1), **event_kw)[0]
events["spliting"].append((x[0], y0))

j += 1
@@ -1045,7 +1225,7 @@ def extract_with_mask(self, mask):
logger.warning("Empty dataset will be created")
else:
logger.info(
f"{nb_obs} observations will be extract ({nb_obs * 100. / self.shape[0]}%)"
f"{nb_obs} observations will be extract ({nb_obs / self.shape[0]:.3%})"
)
for field in self.obs.dtype.descr:
if field in ("next_obs", "previous_obs"):