Skip to content

Commit f52783d

Browse files
committed
Modify remove dead end
speed up extract_segment
1 parent 9d408e5 commit f52783d

File tree

2 files changed

+76
-49
lines changed

2 files changed

+76
-49
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ and this project adheres to `Semantic Versioning <https://semver.org/spec/v2.0.0
1111
Changed
1212
^^^^^^^
1313

14+
- Remove dead end method for network will move dead end to the trash and not remove observations
15+
1416
Fixed
1517
^^^^^
1618

src/py_eddy_tracker/observations/network.py

Lines changed: 74 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,14 @@ def __repr__(self):
117117
m_event, s_event = self.merging_event(only_index=True, triplet=True)[0], self.splitting_event(only_index=True, triplet=True)[0]
118118
period = (self.period[1] - self.period[0]) / 365.25
119119
nb_by_network = self.network_size()
120+
nb_trash = 0 if self.ref_index != 0 else nb_by_network[0]
120121
big = 50_000
121122
infos = [
122123
f"Atlas with {self.nb_network} networks ({self.nb_network / period:0.0f} networks/year),"
123124
f" {self.nb_segment} segments ({self.nb_segment / period:0.0f} segments/year), {len(self)} observations ({len(self) / period:0.0f} observations/year)",
124125
f" {m_event.size} merging ({m_event.size / period:0.0f} merging/year), {s_event.size} splitting ({s_event.size / period:0.0f} splitting/year)",
125126
f" with {(nb_by_network > big).sum()} network with more than {big} obs and the biggest have {nb_by_network.max()} observations ({nb_by_network[nb_by_network> big].sum()} observations cumulate)",
126-
f" {nb_by_network[0]} observations in trash"
127+
f" {nb_trash} observations in trash"
127128
]
128129
return "\n".join(infos)
129130

@@ -369,26 +370,29 @@ def correct_close_events(self, nb_days_max=20):
369370

370371
# we keep the real segment number
371372
seg_corrected_copy = segment_copy[seg_slice.stop - 1]
373+
if i_seg_n == -1:
374+
continue
372375

376+
# if segment is split
373377
n_seg = segment[i_seg_n]
374378

375-
# if segment is split
376-
if i_seg_n != -1:
377-
seg2_slice, i2_seg_p, i2_seg_n = segments_connexion[n_seg]
378-
p2_seg = segment[i2_seg_p]
379-
380-
# if it merges on the first in a certain time
381-
if (p2_seg == seg_corrected) and (
382-
_time[i_seg_n] - _time[i2_seg_p] < nb_days_max
383-
):
384-
my_slice = slice(i_seg_n, seg2_slice.stop)
385-
# correct the factice segment
386-
segment[my_slice] = seg_corrected
387-
# correct the good segment
388-
segment_copy[my_slice] = seg_corrected_copy
389-
previous_obs[i_seg_n] = seg_slice.stop - 1
390-
391-
segments_connexion[seg_corrected][0] = my_slice
379+
seg2_slice, i2_seg_p, _ = segments_connexion[n_seg]
380+
if i2_seg_p == -1:
381+
continue
382+
p2_seg = segment[i2_seg_p]
383+
384+
# if it merges on the first in a certain time
385+
if (p2_seg == seg_corrected) and (
386+
_time[i_seg_n] - _time[i2_seg_p] < nb_days_max
387+
):
388+
my_slice = slice(i_seg_n, seg2_slice.stop)
389+
# correct the factice segment
390+
segment[my_slice] = seg_corrected
391+
# correct the good segment
392+
segment_copy[my_slice] = seg_corrected_copy
393+
previous_obs[i_seg_n] = seg_slice.stop - 1
394+
395+
segments_connexion[seg_corrected][0] = my_slice
392396

393397
return self.sort()
394398

@@ -789,6 +793,8 @@ def display_timeline(
789793
colors_mode=colors_mode,
790794
)
791795
)
796+
if field is not None:
797+
field = self.parse_varname(field)
792798
for i, b0, b1 in self.iter_on("segment"):
793799
x = self.time[i]
794800
if x.shape[0] == 0:
@@ -797,9 +803,9 @@ def display_timeline(
797803
y = b0 * ones(x.shape)
798804
else:
799805
if method == "all":
800-
y = self[field][i] * factor
806+
y = field[i] * factor
801807
else:
802-
y = self[field][i].mean() * ones(x.shape) * factor
808+
y = field[i].mean() * ones(x.shape) * factor
803809

804810
if colors_mode == "roll":
805811
_color = self.get_color(j)
@@ -825,7 +831,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
825831

826832
if field is not None and method != "all":
827833
for i, b0, _ in self.iter_on("segment"):
828-
y = self[field][i]
834+
y = self.parse_varname(field)[i]
829835
if y.shape[0] != 0:
830836
y_seg[b0] = y.mean() * factor
831837
mappables = dict()
@@ -851,7 +857,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
851857
y0 = b0
852858
else:
853859
if method == "all":
854-
y0 = self[field][i.stop - 1] * factor
860+
y0 = self.parse_varname(field)[i.stop - 1] * factor
855861
else:
856862
y0 = y_seg[b0]
857863
if i_n != -1:
@@ -860,7 +866,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
860866
seg_next
861867
if field is None
862868
else (
863-
self[field][i_n] * factor
869+
self.parse_varname(field)[i_n] * factor
864870
if method == "all"
865871
else y_seg[seg_next]
866872
)
@@ -876,7 +882,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
876882
seg_previous
877883
if field is None
878884
else (
879-
self[field][i_p] * factor
885+
self.parse_varname(field)[i_p] * factor
880886
if method == "all"
881887
else y_seg[seg_previous]
882888
)
@@ -1446,35 +1452,54 @@ def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None):
14461452
.. warning::
14471453
It will remove short segment that splits from then merges with the same segment
14481454
"""
1449-
segments_keep = list()
14501455
connexions = self.connexions(multi_network=True)
1451-
t = self.time
1452-
for i, b0, _ in self.iter_on(self.segment_track_array):
1453-
if mask and mask[i].any():
1454-
segments_keep.append(b0)
1455-
continue
1456-
nb = i.stop - i.start
1457-
dt = t[i.stop - 1] - t[i.start]
1458-
if (nb < nobs or dt < ndays) and len(connexions.get(b0, tuple())) < 2:
1459-
continue
1460-
segments_keep.append(b0)
1456+
i0, i1, _ = self.index_segment_track
1457+
dt = self.time[i1 -1] - self.time[i0] + 1
1458+
nb = i1 - i0
1459+
m = (dt >= ndays) * (nb >= nobs)
1460+
nb_connexions = array([len(connexions.get(i, tuple())) for i in where(~m)[0]])
1461+
m[~m] = nb_connexions >= 2
1462+
segments_keep = where(m)[0]
1463+
if mask is not None:
1464+
segments_keep = unique(concatenate((segments_keep, self.segment_track_array[mask])))
1465+
# get mask for selected obs
1466+
m = ~self.segment_mask(segments_keep)
1467+
self.track[m] = 0
1468+
self.segment[m] = 0
1469+
self.previous_obs[m] = -1
1470+
self.previous_cost[m] = 0
1471+
self.next_obs[m] = -1
1472+
self.next_cost[m] = 0
1473+
1474+
m_previous = m[self.previous_obs]
1475+
self.previous_obs[m_previous] = -1
1476+
self.previous_cost[m_previous] = 0
1477+
m_next = m[self.next_obs]
1478+
self.next_obs[m_next] = -1
1479+
self.next_cost[m_next] = 0
1480+
1481+
self.sort()
14611482
if recursive > 0:
1462-
return self.extract_segment(segments_keep, absolute=True).remove_dead_end(
1463-
nobs, ndays, recursive - 1
1464-
)
1465-
return self.extract_segment(segments_keep, absolute=True)
1483+
self.remove_dead_end(nobs, ndays, recursive - 1)
14661484

14671485
def extract_segment(self, segments, absolute=False):
1468-
mask = ones(self.shape, dtype="bool")
1469-
segments = array(segments)
1470-
values = self.segment_track_array if absolute else "segment"
1471-
keep = ones(values.max() + 1, dtype="bool")
1472-
v = unique(values)
1473-
keep[v] = in1d(v, segments)
1474-
for i, b0, b1 in self.iter_on(values):
1475-
if not keep[b0]:
1476-
mask[i] = False
1477-
return self.extract_with_mask(mask)
1486+
"""Extract given segments
1487+
1488+
:param array,tuple,list segments: list of segment to extract
1489+
:param bool absolute: keep for compatibility, defaults to False
1490+
:return NetworkObservations: Return observations from selected segment
1491+
"""
1492+
if not absolute:
1493+
raise Exception("Not implemented")
1494+
return self.extract_with_mask(self.segment_mask(segments))
1495+
1496+
def segment_mask(self, segments):
1497+
"""Get mask from list of segment
1498+
1499+
:param list,array segments: absolute id of segment
1500+
"""
1501+
return generate_mask_from_ids(array(segments), len(self), *self.index_segment_track)
1502+
14781503

14791504
def get_mask_with_period(self, period):
14801505
"""

0 commit comments

Comments
 (0)