Skip to content

Commit 97caffa

Browse files
committed
Add method to remove dead end
1 parent 403fbbb commit 97caffa

File tree

2 files changed

+39
-47
lines changed

2 files changed

+39
-47
lines changed

examples/16_network/pet_relative.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from matplotlib import pyplot as plt
77

8+
import py_eddy_tracker.gui
89
from py_eddy_tracker import data
910
from py_eddy_tracker.observations.network import NetworkObservations
1011
from py_eddy_tracker.observations.tracking import TrackEddiesObservations
@@ -106,11 +107,8 @@
106107
# Remove dead branch
107108
# ------------------
108109
# Remove all tiny segment with less than N obs which didn't join two segments
109-
#
110-
# .. warning::
111-
# Must be explore, no solution to solve all the case
112110

113-
n_clean = n.remove_dead_branch(nobs=51)
111+
n_clean = n.remove_dead_end(nobs=10)
114112
fig = plt.figure(figsize=(15, 8))
115113
ax = fig.add_axes([0.04, 0.54, 0.90, 0.40])
116114
ax.set_title(f"Original network ({n.infos()})")

src/py_eddy_tracker/observations/network.py

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,20 @@ def __init__(self, *args, **kwargs):
6969
super().__init__(*args, **kwargs)
7070
self._index_network = None
7171

72+
@property
73+
def index_network(self):
74+
if self._index_network is None:
75+
self._index_network = build_index(self.track)
76+
return self._index_network
77+
7278
def network_slice(self, id_network):
7379
"""
7480
Return slice for one network
7581
7682
:param int id_network: id to identify network
7783
"""
78-
if self._index_network is None:
79-
self._index_network = build_index(self.track)
80-
i = id_network - self._index_network[2]
81-
i_start, i_stop = self._index_network[0][i], self._index_network[1][i]
84+
i = id_network - self.index_network[2]
85+
i_start, i_stop = self.index_network[0][i], self.index_network[1][i]
8286
return slice(i_start, i_stop)
8387

8488
@property
@@ -228,8 +232,9 @@ def only_one_network(self):
228232
Raise a warning or error?
229233
if there are more than one network
230234
"""
231-
# TODO
232-
pass
235+
_, i_start, _ = self.index_network
236+
if len(i_start) > 1:
237+
raise Exception("Several network")
233238

234239
def position_filter(self, median_half_window, loess_half_window):
235240
self.median_filter(median_half_window, "time", "lon").loess_filter(
@@ -568,17 +573,19 @@ def dissociate_network(self):
568573
# FIXME : Ok if only one network
569574
self.track[:] = tags[self.segment - 1]
570575

571-
self.obs.sort(order=("track", "segment", "time"))
576+
i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort")
577+
# Sort directly obs, with hope to save memory
578+
self.obs.sort(order=("track", "segment", "time"), kind="mergesort")
572579
self._index_network = None
573580

574-
# FIXME
575581
# n & p must be re-index
576-
# n, p = self.next_obs[mask], self.previous_obs[mask]
582+
n, p = self.next_obs, self.previous_obs
577583
# we add 2 for -1 index return index -1
578-
# translate = -ones(len(self) + 1, dtype="i4")
579-
# translate[:-1][mask] = arange(nb_obs)
580-
# new.next_obs[:] = translate[n]
581-
# new.previous_obs[:] = translate[p]
584+
nb_obs = len(self)
585+
translate = -ones(nb_obs + 1, dtype="i4")
586+
translate[:-1][i_sort] = arange(nb_obs)
587+
self.next_obs[:] = translate[n]
588+
self.previous_obs[:] = translate[p]
582589

583590
def network(self, id_network):
584591
return self.extract_with_mask(self.network_slice(id_network))
@@ -640,40 +647,27 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
640647
j += 1
641648
return mappables
642649

643-
def remove_dead_branch(self, nobs=3):
644-
""""""
645-
# TODO: bug when spliting
650+
def remove_dead_end(self, nobs=3, recursive=0, mask=None):
651+
"""
652+
.. warning::
653+
It will remove short segment which splits than merges with same segment
654+
"""
646655
self.only_one_network()
647-
648656
segments_keep = list()
649-
interaction_segments = dict()
650-
segments_connexion = dict()
657+
connexions = self.connexions()
651658
for i, b0, b1 in self.iter_on("segment"):
652659
nb = i.stop - i.start
653-
i_p, i_n = self.previous_obs[i.start], self.next_obs[i.stop - 1]
654-
seg = self.segment[i.start]
655-
# segment of interaction
656-
p_seg, n_seg = self.segment[i_p], self.segment[i_n]
657-
if nb >= nobs:
658-
segments_keep.append(seg)
659-
else:
660-
interaction_segments[seg] = (
661-
p_seg if i_p != -1 else -1,
662-
n_seg if i_n != -1 else -1,
663-
)
664-
# Where segment are called
665-
if i_p != -1:
666-
if p_seg not in segments_connexion:
667-
segments_connexion[p_seg] = list()
668-
segments_connexion[p_seg].append(seg)
669-
if i_n != -1:
670-
if n_seg not in segments_connexion:
671-
segments_connexion[n_seg] = list()
672-
segments_connexion[n_seg].append(seg)
673-
print(interaction_segments)
674-
print(segments_connexion)
675-
print(segments_keep)
676-
return self.extract_segment(tuple(segments_keep))
660+
if mask and mask[i].any():
661+
segments_keep.append(b0)
662+
continue
663+
if nb < nobs and len(connexions.get(b0, tuple())) < 2:
664+
continue
665+
segments_keep.append(b0)
666+
if recursive > 0:
667+
return self.extract_segment(segments_keep).remove_dead_end(
668+
nobs, recursive - 1
669+
)
670+
return self.extract_segment(segments_keep)
677671

678672
def extract_segment(self, segments):
679673
mask = ones(self.shape, dtype="bool")

0 commit comments

Comments
 (0)