Skip to content

Commit 85cc6fa

Browse files
committed
speed up /generalize some network method
1 parent 9643697 commit 85cc6fa

File tree

4 files changed

+75
-35
lines changed

4 files changed

+75
-35
lines changed

examples/06_grid_manipulation/pet_lavd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def update_axes(ax, mappable=None):
4444
orientation="horizontal",
4545
)
4646
cb.set_label("Vorticity integration along trajectory at initial position")
47+
return cb
4748

4849

4950
kw_vorticity = dict(vmin=0, vmax=2e-5, cmap="viridis")
@@ -90,7 +91,8 @@ def save(self, *args, **kwargs):
9091
# Display vorticity field
9192
fig, ax, _ = start_ax()
9293
mappable = g.display(ax, abs(g.grid("vort")), **kw_vorticity)
93-
update_axes(ax, mappable)
94+
cb = update_axes(ax, mappable)
95+
cb.set_label("Vorticity")
9496

9597
# %%
9698
# Particles

notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@
253253
},
254254
"outputs": [],
255255
"source": [
256-
"ax = start_axes(\"Speed Radius (km)\")\nkwargs = dict(vmin=10, vmax=50, s=80, ref=-10, cmap=\"magma_r\", factor=0.001)\na.scatter(ax, \"radius_s\", **kwargs)\nm = c.scatter(\n ax, \"radius_s\", **kwargs\n)\nupdate_axes(ax, m)"
256+
"ax = start_axes(\"Speed Radius (km)\")\nkwargs = dict(vmin=10, vmax=50, s=80, ref=-10, cmap=\"magma_r\", factor=0.001)\na.scatter(ax, \"radius_s\", **kwargs)\nm = c.scatter(ax, \"radius_s\", **kwargs)\nupdate_axes(ax, m)"
257257
]
258258
},
259259
{
@@ -271,7 +271,7 @@
271271
},
272272
"outputs": [],
273273
"source": [
274-
"ax = start_axes(\"Effective Radius (km)\")\nkwargs = dict(vmin=10, vmax=80, cmap=\"magma_r\", factor=0.001, lut=14, ref=-10)\na.filled(ax, \"effective_radius\", **kwargs)\nm = c.filled(\n ax, \"radius_e\", **kwargs\n)\nupdate_axes(ax, m)"
274+
"ax = start_axes(\"Effective Radius (km)\")\nkwargs = dict(vmin=10, vmax=80, cmap=\"magma_r\", factor=0.001, lut=14, ref=-10)\na.filled(ax, \"effective_radius\", **kwargs)\nm = c.filled(ax, \"radius_e\", **kwargs)\nupdate_axes(ax, m)"
275275
]
276276
}
277277
],

notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
},
3838
"outputs": [],
3939
"source": [
40-
"def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(16, 9), dpi=dpi)\n ax = fig.add_axes([0, 0, 1, 1], projection=\"full_axes\")\n ax.set_xlim(0, 32), ax.set_ylim(28, 46)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"Vorticity integration along trajectory at initial position\")\n\n\nkw_vorticity = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")"
40+
"def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(16, 9), dpi=dpi)\n ax = fig.add_axes([0, 0, 1, 1], projection=\"full_axes\")\n ax.set_xlim(0, 32), ax.set_ylim(28, 46)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"Vorticity integration along trajectory at initial position\")\n return cb\n\n\nkw_vorticity = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")"
4141
]
4242
},
4343
{
@@ -84,7 +84,7 @@
8484
},
8585
"outputs": [],
8686
"source": [
87-
"fig, ax, _ = start_ax()\nmappable = g.display(ax, abs(g.grid(\"vort\")), **kw_vorticity)\nupdate_axes(ax, mappable)"
87+
"fig, ax, _ = start_ax()\nmappable = g.display(ax, abs(g.grid(\"vort\")), **kw_vorticity)\ncb = update_axes(ax, mappable)\ncb.set_label(\"Vorticity\")"
8888
]
8989
},
9090
{

src/py_eddy_tracker/observations/network.py

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from glob import glob
77

88
from numba import njit
9-
from numpy import arange, array, bincount, empty, ones, uint32, unique, zeros
9+
from numpy import arange, array, bincount, empty, in1d, ones, uint32, unique, zeros
1010

1111
from ..generic import build_index, wrap_longitude
1212
from ..poly import bbox_intersection, vertice_overlap
@@ -119,13 +119,12 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
119119
if nb_day_max < 0:
120120
nb_day_max = 1000000000000
121121
mask = zeros(self.shape, dtype="bool")
122-
for i, b0, b1 in self.iter_on(self.segment_track_array):
122+
t = self.time
123+
for i, b0, b1 in self.iter_on(self.track):
123124
nb = i.stop - i.start
124125
if nb == 0:
125126
continue
126-
t = self.time[i]
127-
dt = t.max() - t.min()
128-
if nb_day_min <= dt <= nb_day_max:
127+
if nb_day_min <= ptp(t[i]) <= nb_day_max:
129128
mask[i] = True
130129
return self.extract_with_mask(mask)
131130

@@ -164,21 +163,26 @@ def obs_relative_order(self, i_obs):
164163
self.only_one_network()
165164
return self.segment_relative_order(self.segment[i_obs])
166165

167-
def connexions(self):
168-
self.only_one_network()
166+
def connexions(self, multi_network=False):
167+
if multi_network:
168+
segment = self.segment_track_array
169+
else:
170+
self.only_one_network()
171+
segment = self.segment
169172
segments_connexion = dict()
170173

171174
def add_seg(father, child):
172175
if father not in segments_connexion:
173176
segments_connexion[father] = list()
174177
segments_connexion[father].append(child)
175178

176-
for i, seg, _ in self.iter_on("segment"):
179+
previous_obs, next_obs = self.previous_obs, self.next_obs
180+
for i, seg, _ in self.iter_on(segment):
177181
if i.start == i.stop:
178182
continue
179-
i_p, i_n = self.previous_obs[i.start], self.next_obs[i.stop - 1]
183+
i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1]
180184
# segment of interaction
181-
p_seg, n_seg = self.segment[i_p], self.segment[i_n]
185+
p_seg, n_seg = segment[i_p], segment[i_n]
182186
# Where segment are called
183187
if i_p != -1:
184188
add_seg(p_seg, seg)
@@ -395,6 +399,26 @@ def map_segment(self, method, y, same=True, **kw):
395399
out = array(out)
396400
return out
397401

402+
def map_network(self, method, y, same=True, **kw):
403+
if same:
404+
out = empty(y.shape, **kw)
405+
else:
406+
out = list()
407+
for i, b0, b1 in self.iter_on(self.track):
408+
res = method(y[i])
409+
if same:
410+
out[i] = res
411+
else:
412+
if isinstance(i, slice):
413+
if i.start == i.stop:
414+
continue
415+
elif len(i) == 0:
416+
continue
417+
out.append(res)
418+
if not same:
419+
out = array(out)
420+
return out
421+
398422
def scatter_timeline(
399423
self,
400424
ax,
@@ -410,7 +434,7 @@ def scatter_timeline(
410434
Must be call on only one network
411435
"""
412436
self.only_one_network()
413-
y = (self.segment if yfield is None else self[yfield]) * yfactor
437+
y = (self.segment if yfield is None else self.parse_varname(yfield)) * yfactor
414438
if method == "all":
415439
pass
416440
else:
@@ -536,23 +560,25 @@ def segment_track_array(self):
536560
def birth_event(self):
537561
# FIXME how to manage group 0
538562
indices = list()
563+
previous_obs = self.previous_obs
539564
for i, _, _ in self.iter_on(self.segment_track_array):
540565
nb = i.stop - i.start
541566
if nb == 0:
542567
continue
543-
i_p = self.previous_obs[i.start]
568+
i_p = previous_obs[i.start]
544569
if i_p == -1:
545570
indices.append(i.start)
546571
return self.extract_event(list(set(indices)))
547572

548573
def death_event(self):
549574
# FIXME how to manage group 0
550575
indices = list()
576+
next_obs = self.next_obs
551577
for i, _, _ in self.iter_on(self.segment_track_array):
552578
nb = i.stop - i.start
553579
if nb == 0:
554580
continue
555-
i_n = self.next_obs[i.stop - 1]
581+
i_n = next_obs[i.stop - 1]
556582
if i_n == -1:
557583
indices.append(i.stop - 1)
558584
return self.extract_event(list(set(indices)))
@@ -567,16 +593,16 @@ def merging_event(self, triplet=False):
567593
if triplet:
568594
idx_m0_stop = list()
569595
idx_m0 = list()
570-
596+
next_obs, previous_obs = self.next_obs, self.previous_obs
571597
for i, _, _ in self.iter_on(self.segment_track_array):
572598
nb = i.stop - i.start
573599
if nb == 0:
574600
continue
575-
i_n = self.next_obs[i.stop - 1]
601+
i_n = next_obs[i.stop - 1]
576602
if i_n != -1:
577603
if triplet:
578604
idx_m0_stop.append(i.stop - 1)
579-
idx_m0.append(self.previous_obs[i_n])
605+
idx_m0.append(previous_obs[i_n])
580606
idx_m1.append(i_n)
581607

582608
if triplet:
@@ -598,15 +624,16 @@ def spliting_event(self, triplet=False):
598624
if triplet:
599625
idx_s1_start = list()
600626
idx_s1 = list()
627+
next_obs, previous_obs = self.next_obs, self.previous_obs
601628
for i, _, _ in self.iter_on(self.segment_track_array):
602629
nb = i.stop - i.start
603630
if nb == 0:
604631
continue
605-
i_p = self.previous_obs[i.start]
632+
i_p = previous_obs[i.start]
606633
if i_p != -1:
607634
if triplet:
608635
idx_s1_start.append(i.start)
609-
idx_s1.append(self.next_obs[i_p])
636+
idx_s1.append(next_obs[i_p])
610637
idx_s0.append(i_p)
611638
if triplet:
612639
return (
@@ -700,32 +727,38 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
700727
j += 1
701728
return mappables
702729

703-
def remove_dead_end(self, nobs=3, recursive=0, mask=None):
730+
def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None):
704731
"""
705732
.. warning::
706733
It will remove short segment which splits than merges with same segment
707734
"""
708-
self.only_one_network()
709735
segments_keep = list()
710-
connexions = self.connexions()
711-
for i, b0, _ in self.iter_on("segment"):
712-
nb = i.stop - i.start
736+
connexions = self.connexions(multi_network=True)
737+
t = self.time
738+
for i, b0, _ in self.iter_on(self.segment_track_array):
713739
if mask and mask[i].any():
714740
segments_keep.append(b0)
715741
continue
716-
if nb < nobs and len(connexions.get(b0, tuple())) < 2:
742+
nb = i.stop - i.start
743+
dt = t[i.stop - 1] - t[i.start]
744+
if (nb < nobs or dt < ndays) and len(connexions.get(b0, tuple())) < 2:
717745
continue
718746
segments_keep.append(b0)
719747
if recursive > 0:
720-
return self.extract_segment(segments_keep).remove_dead_end(
721-
nobs, recursive - 1
748+
return self.extract_segment(segments_keep, absolute=True).remove_dead_end(
749+
nobs, ndays, recursive - 1
722750
)
723-
return self.extract_segment(segments_keep)
751+
return self.extract_segment(segments_keep, absolute=True)
724752

725-
def extract_segment(self, segments):
753+
def extract_segment(self, segments, absolute=False):
726754
mask = ones(self.shape, dtype="bool")
727-
for i, b0, b1 in self.iter_on("segment"):
728-
if b0 not in segments:
755+
segments = array(segments)
756+
values = self.segment_track_array if absolute else "segment"
757+
keep = ones(values.max() + 1, dtype="bool")
758+
v = unique(values)
759+
keep[v] = in1d(v, segments)
760+
for i, b0, b1 in self.iter_on(values):
761+
if not keep[b0]:
729762
mask[i] = False
730763
return self.extract_with_mask(mask)
731764

@@ -929,3 +962,8 @@ def new_numbering(segs):
929962
s0 = segs[i]
930963
j += 1
931964
segs[i] = j
965+
966+
967+
@njit(cache=True)
968+
def ptp(values):
969+
return values.max() - values.min()

0 commit comments

Comments
 (0)