Skip to content

Commit f1fc005

Browse files
committed
Speed up iter_on with numba function
1 parent 1633dfa commit f1fc005

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

src/py_eddy_tracker/observations/observation.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
reverse_index,
5454
bbox_indice_regular,
5555
hist_numba,
56+
build_index,
5657
)
5758
from ..poly import (
5859
bbox_intersection,
@@ -159,7 +160,7 @@ def __init__(
159160
raise Exception("Unknown element : %s" % elt)
160161
self.observations = zeros(size, dtype=self.dtype)
161162
self.sign_type = None
162-
163+
163164
@property
164165
def tracks(self):
165166
return self.observations["track"]
@@ -452,18 +453,22 @@ def iter_on(self, xname, bins=None):
452453
i = digitize(x, bins) - 1
453454
# Not monotonous
454455
if (d < 0).any():
455-
for i_ in unique(i):
456-
if i_ == -1 or i_ == nb_bins:
456+
i_sort = i.argsort()
457+
i0, i1, _ = build_index(i[i_sort])
458+
m = ~(i0 == i1)
459+
i0, i1 = i0[m], i1[m]
460+
for i0_, i1_ in zip(i0, i1):
461+
i_bins = i[i_sort[i0_]]
462+
if i_bins == -1 or i_bins == nb_bins:
457463
continue
458-
index = where(i_ == i)[0]
459-
yield index, bins[i_], bins[i_ + 1]
464+
yield i_sort[i0_:i1_], bins[i_bins], bins[i_bins + 1]
460465
else:
461-
# TODO : need improvement
462-
for i_ in unique(i):
463-
if i_ == -1 or i_ == nb_bins:
464-
continue
465-
index = where(i_ == i)[0]
466-
yield slice(index[0], index[-1] + 1), bins[i_], bins[i_ + 1]
466+
i0, i1, _ = build_index(i)
467+
m = ~(i0 == i1)
468+
i0, i1 = i0[m], i1[m]
469+
for i0_, i1_ in zip(i0, i1):
470+
i_bins = i[i0_]
471+
yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1]
467472

468473
def align_on(self, other, var_name="time", **kwargs):
469474
"""
@@ -703,6 +708,7 @@ def load_from_netcdf(
703708
nb_obs = len(h_nc.dimensions[obs_dim])
704709
if indexs is not None and obs_dim in indexs:
705710
sl = indexs[obs_dim]
711+
sl = slice(sl.start, min(sl.stop, nb_obs))
706712
if sl.stop is not None:
707713
nb_obs = sl.stop
708714
if sl.start is not None:

0 commit comments

Comments
 (0)