Skip to content

Commit 943bbf3

Browse files
committed
Rewrite the event-extraction methods to speed them up
1 parent 3e73e63 commit 943bbf3

File tree

1 file changed

+133
-103
lines changed

1 file changed

+133
-103
lines changed

src/py_eddy_tracker/observations/network.py

Lines changed: 133 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,20 @@ def __init__(self, *args, **kwargs):
113113
super().__init__(*args, **kwargs)
114114
self.reset_index()
115115

116+
def __repr__(self):
    """Return a multi-line summary of the atlas: networks, segments, observations and events."""
    # Only the index arrays of the events are needed here, not the extracted observations
    m_event = self.merging_event(only_index=True, triplet=True)[0]
    s_event = self.splitting_event(only_index=True, triplet=True)[0]
    # Duration covered by the atlas, in years (time unit presumably days — TODO confirm)
    period = (self.period[1] - self.period[0]) / 365.25
    nb_by_network = self.network_size()
    # Threshold above which a network is reported as unusually large
    big = 50_000
    infos = [
        f"Atlas with {self.nb_network} networks ({self.nb_network / period:0.0f} networks/year),"
        f" {self.nb_segment} segments ({self.nb_segment / period:0.0f} segments/year), {len(self)} observations ({len(self) / period:0.0f} observations/year)",
        f" {m_event.size} merging ({m_event.size / period:0.0f} merging/year), {s_event.size} splitting ({s_event.size / period:0.0f} splitting/year)",
        f" with {(nb_by_network > big).sum()} network with more than {big} obs and the biggest have {nb_by_network.max()} observations ({nb_by_network[nb_by_network> big].sum()} observations cumulate)",
        f" {nb_by_network[0]} observations in trash"
    ]
    return "\n".join(infos)
129+
116130
def reset_index(self):
117131
self._index_network = None
118132
self._index_segment_track = None
@@ -313,13 +327,19 @@ def correct_close_events(self, nb_days_max=20):
313327
"""
314328
Transform event where
315329
segment A splits from segment B, then x days after segment B merges with A
316-
317330
to
318-
319331
segment A splits from segment B then x days after segment A merges with B (B will be longer)
320-
321332
These events have to last less than `nb_days_max` to be changed.
322333
334+
335+
------------------- A
336+
/ /
337+
B --------------------
338+
to
339+
--A--
340+
/ \
341+
B -----------------------------------
342+
323343
:param float nb_days_max: maximum time to search for splitting-merging event
324344
"""
325345

@@ -342,7 +362,7 @@ def correct_close_events(self, nb_days_max=20):
342362
segments_connexion[seg] = [i, i_p, i_n]
343363

344364
for seg in sorted(segments_connexion.keys()):
345-
seg_slice, i_seg_p, i_seg_n = segments_connexion[seg]
365+
seg_slice, _, i_seg_n = segments_connexion[seg]
346366

347367
# the segment ID has to be corrected, because we may have changed it since
348368
seg_corrected = segment[seg_slice.stop - 1]
@@ -370,8 +390,6 @@ def correct_close_events(self, nb_days_max=20):
370390

371391
segments_connexion[seg_corrected][0] = my_slice
372392

373-
self.segment[:] = segment_copy
374-
self.previous_obs[:] = previous_obs
375393
return self.sort()
376394

377395
def sort(self, order=("track", "segment", "time")):
@@ -495,35 +513,38 @@ def func_backward(seg, indice):
495513
return self.extract_with_mask(mask)
496514

497515
def connexions(self, multi_network=False):
    """Build, for each segment, the set of segments sharing an event with it.

    :param bool multi_network: use segment_track_array instead of segment, defaults to False
    :return dict: for each segment id, the set of segment ids linked to it by a splitting or merging event
    """
    if multi_network:
        segment = self.segment_track_array
    else:
        self.only_one_network()
        segment = self.segment
    segments_connexion = dict()

    def bind(a, b):
        # Record the relation in both directions
        segments_connexion.setdefault(a, set()).add(b)
        segments_connexion.setdefault(b, set()).add(a)

    # First index and (exclusive) stop index of every segment
    i0, i1, _ = self.index_segment_track
    # Index of the last observation of every segment
    i1 = i1 - 1
    # Segments whose last observation merges into another segment
    i_next = self.next_obs[i1]
    m_n = i_next != -1
    # Segments whose first observation comes from a splitting
    i_previous = self.previous_obs[i0]
    m_p = i_previous != -1
    # Register every splitting relation
    for parent, child in zip(segment[i_previous[m_p]], segment[i0[m_p]]):
        bind(parent, child)
    # Register every merging relation
    for target, source in zip(segment[i_next[m_n]], segment[i1[m_n]]):
        bind(target, source)
    return segments_connexion
528549

529550
@classmethod
@@ -1089,68 +1110,57 @@ def segment_track_array(self):
10891110
return self._segment_track_array
10901111

10911112
def birth_event(self):
    """Extract birth events, i.e. the first observation of each segment without a previous observation."""
    i_start, _, _ = self.index_segment_track
    # A segment is born when its first observation has no previous one
    no_parent = self.previous_obs[i_start] == -1
    indices = i_start[no_parent]
    # The first network may be the trash (lonely eddies, see first_is_trash): skip its index
    if self.first_is_trash():
        indices = indices[1:]
    return self.extract_event(indices)

#: Alias: a birth is also called a generation
generation_event = birth_event
11051120

11061121
def death_event(self):
    """Extract death events, i.e. the last observation of each segment without a next observation."""
    _, i_stop, _ = self.index_segment_track
    # A segment dies when its last observation (i_stop is exclusive) has no next one
    no_child = self.next_obs[i_stop - 1] == -1
    indices = i_stop[no_child] - 1
    # The first network may be the trash (lonely eddies, see first_is_trash): skip its index
    if self.first_is_trash():
        indices = indices[1:]
    return self.extract_event(indices)

#: Alias: a death is also called a dissipation
dissipation_event = death_event
11201129

11211130
def merging_event(self, triplet=False, only_index=False):
11221131
"""Return observation after a merging event.
11231132
11241133
If `triplet=True` return the eddy after a merging event, the eddy before the merging event,
11251134
and the eddy stopped due to merging.
11261135
"""
1127-
idx_m1 = list()
1136+
# Get start and stop for each segment, there is no empty segment
1137+
_, i1, _ = self.index_segment_track
1138+
# Get last index for each segment
1139+
i_stop = i1 - 1
1140+
# Get target index
1141+
idx_m1 = self.next_obs[i_stop]
1142+
# Get mask and valid target
1143+
m = idx_m1 != -1
1144+
idx_m1 = idx_m1[m]
1145+
# Sort by time event
1146+
i = self.time[idx_m1].argsort()
1147+
idx_m1 = idx_m1[i]
11281148
if triplet:
1129-
idx_m0_stop = list()
1130-
idx_m0 = list()
1131-
next_obs, previous_obs = self.next_obs, self.previous_obs
1132-
for i, _, _ in self.iter_on(self.segment_track_array):
1133-
nb = i.stop - i.start
1134-
if nb == 0:
1135-
continue
1136-
i_n = next_obs[i.stop - 1]
1137-
if i_n != -1:
1138-
if triplet:
1139-
idx_m0_stop.append(i.stop - 1)
1140-
idx_m0.append(previous_obs[i_n])
1141-
idx_m1.append(i_n)
1149+
# Get obs before target
1150+
idx_m0_stop = i_stop[m][i]
1151+
idx_m0 = self.previous_obs[idx_m1].copy()
11421152

11431153
if triplet:
11441154
if only_index:
1145-
return array(idx_m1), array(idx_m0), array(idx_m0_stop)
1155+
return idx_m1, idx_m0, idx_m0_stop
11461156
else:
11471157
return (
11481158
self.extract_event(idx_m1),
11491159
self.extract_event(idx_m0),
11501160
self.extract_event(idx_m0_stop),
11511161
)
11521162
else:
1153-
idx_m1 = list(set(idx_m1))
1163+
idx_m1 = unique(idx_m1)
11541164
if only_index:
11551165
return idx_m1
11561166
else:
@@ -1162,25 +1172,24 @@ def splitting_event(self, triplet=False, only_index=False):
11621172
If `triplet=True` return the eddy before a splitting event, the eddy after the splitting event,
11631173
and the eddy starting due to splitting.
11641174
"""
1165-
idx_s0 = list()
1175+
# Get start and stop for each segment, there is no empty segment
1176+
i_start, _, _ = self.index_segment_track
1177+
# Get target index
1178+
idx_s0 = self.previous_obs[i_start]
1179+
# Get mask and valid target
1180+
m = idx_s0 != -1
1181+
idx_s0 = idx_s0[m]
1182+
# Sort by time event
1183+
i = self.time[idx_s0].argsort()
1184+
idx_s0 = idx_s0[i]
11661185
if triplet:
1167-
idx_s1_start = list()
1168-
idx_s1 = list()
1169-
next_obs, previous_obs = self.next_obs, self.previous_obs
1170-
for i, _, _ in self.iter_on(self.segment_track_array):
1171-
nb = i.stop - i.start
1172-
if nb == 0:
1173-
continue
1174-
i_p = previous_obs[i.start]
1175-
if i_p != -1:
1176-
if triplet:
1177-
idx_s1_start.append(i.start)
1178-
idx_s1.append(next_obs[i_p])
1179-
idx_s0.append(i_p)
1186+
# Get obs after target
1187+
idx_s1_start = i_start[m][i]
1188+
idx_s1 = self.next_obs[idx_s0].copy()
11801189

11811190
if triplet:
11821191
if only_index:
1183-
return array(idx_s0), array(idx_s1), array(idx_s1_start)
1192+
return idx_s0, idx_s1, idx_s1_start
11841193
else:
11851194
return (
11861195
self.extract_event(idx_s0),
@@ -1189,7 +1198,7 @@ def splitting_event(self, triplet=False, only_index=False):
11891198
)
11901199

11911200
else:
1192-
idx_s0 = list(set(idx_s0))
1201+
idx_s0 = unique(idx_s0)
11931202
if only_index:
11941203
return idx_s0
11951204
else:
@@ -1199,7 +1208,7 @@ def dissociate_network(self):
11991208
"""
12001209
Dissociate networks with no known interaction (splitting/merging)
12011210
"""
1202-
tags = self.tag_segment(multi_network=True)
1211+
tags = self.tag_segment()
12031212
if self.track[0] == 0:
12041213
tags -= 1
12051214
self.track[:] = tags[self.segment_track_array]
@@ -1345,16 +1354,22 @@ def __tag_segment(cls, seg, tag, groups, connexions):
13451354
# For each connexion we apply same function
13461355
cls.__tag_segment(seg, tag, groups, connexions)
13471356

1348-
def tag_segment(self, multi_network=False):
1349-
if multi_network:
1350-
nb = self.segment_track_array[-1] + 1
1351-
else:
1352-
nb = self.segment.max() + 1
1357+
def tag_segment(self):
1358+
"""For each segment, method give a new network id, and all segment are connected
1359+
1360+
:return array: for each unique seg id, it return new network id
1361+
"""
1362+
nb = self.segment_track_array[-1] + 1
13531363
sub_group = zeros(nb, dtype="u4")
1354-
c = self.connexions(multi_network=multi_network)
1364+
c = self.connexions(multi_network=True)
13551365
j = 1
13561366
# for each available id
13571367
for i in range(nb):
1368+
# No connexions, no need to explore
1369+
if i not in c:
1370+
sub_group[i] = j
1371+
j+= 1
1372+
continue
13581373
# Skip if already set
13591374
if sub_group[i] != 0:
13601375
continue
@@ -1363,15 +1378,31 @@ def tag_segment(self, multi_network=False):
13631378
j += 1
13641379
return sub_group
13651380

1381+
13661382
def fully_connected(self):
    """Presumably test whether all segments form a single connected network.

    NOTE(review): implementation is disabled — it unconditionally raises until the
    logic below (only_one_network + tag_segment) has been verified. Original author
    marked it "Suspicious".
    """
    raise Exception("Must be check")
    # Unreachable until the guard above is removed
    self.only_one_network()
    return self.tag_segment().shape[0] == 1
13691388

1389+
def first_is_trash(self):
    """Check if the first network is the trash, i.e. has no event at all.

    :return bool: True if first network is trash
    """
    i_start, i_stop, _ = self.index_segment_track
    # Observations of the very first segment
    first = slice(i_start[0], i_stop[0])
    no_split = (self.previous_obs[first] == -1).all()
    no_merge = (self.next_obs[first] == -1).all()
    return no_split and no_merge
1397+
13701398
def remove_trash(self):
    """
    Remove the lonely eddies (only 1 obs in segment, associated network number is 0)
    """
    # Nothing to do when the first network is not the trash
    if not self.first_is_trash():
        return self
    return self.extract_with_mask(self.track != 0)
13751406

13761407
def plot(self, ax, ref=None, color_cycle=None, **kwargs):
13771408
"""
@@ -1551,12 +1582,11 @@ def extract_with_mask(self, mask):
15511582
logger.debug(
15521583
f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})"
15531584
)
1554-
for field in self.obs.dtype.descr:
1585+
for field in self.fields:
15551586
if field in ("next_obs", "previous_obs"):
15561587
continue
15571588
logger.debug("Copy of field %s ...", field)
1558-
var = field[0]
1559-
new.obs[var] = self.obs[var][mask]
1589+
new.obs[field] = self.obs[field][mask]
15601590
# n & p must be re-index
15611591
n, p = self.next_obs[mask], self.previous_obs[mask]
15621592
# we add 2 for -1 index return index -1
@@ -1682,9 +1712,9 @@ def date2file(julian_day):
16821712
16831713
return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc"
16841714
"""
1685-
1686-
itb_final = -ones((self.obs.size, 2), dtype="i4")
1687-
ptb_final = zeros((self.obs.size, 2), dtype="i1")
1715+
shape = len(self), 2
1716+
itb_final = -ones(shape, dtype="i4")
1717+
ptb_final = zeros(shape, dtype="i1")
16881718

16891719
t_start, t_end = int(self.period[0]), int(self.period[1])
16901720

@@ -1760,9 +1790,9 @@ def date2file(julian_day):
17601790
17611791
return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc"
17621792
"""
1763-
1764-
itf_final = -ones((self.obs.size, 2), dtype="i4")
1765-
ptf_final = zeros((self.obs.size, 2), dtype="i1")
1793+
shape = len(self), 2
1794+
itf_final = -ones(shape, dtype="i4")
1795+
ptf_final = zeros(shape, dtype="i1")
17661796

17671797
t_start, t_end = int(self.period[0]), int(self.period[1])
17681798

0 commit comments

Comments
 (0)