Skip to content

Commit 7b62de8

Browse files
committed
speed up of iter_on with numba on digitize
1 parent 332e016 commit 7b62de8

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

src/py_eddy_tracker/observations/observation.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def iter_on(self, xname, bins=None):
491491
if bins is None:
492492
bins = arange(x.min(), x.max() + 2)
493493
nb_bins = len(bins) - 1
494-
i = digitize(x, bins) - 1
494+
i = numba_digitize(x, bins) - 1
495495
# Not monotonous
496496
if (d < 0).any():
497497
i_sort = i.argsort()
@@ -2380,3 +2380,30 @@ def sum_row_column(mask):
23802380
row_sum[i] += 1
23812381
column_sum[j] += 1
23822382
return row_sum, column_sum
2383+
2384+
2385+
@njit(cache=True)
2386+
def numba_digitize(values, bins):
2387+
# Check if bins are regular
2388+
nb_bins = bins.shape[0]
2389+
step = bins[1] - bins[0]
2390+
bin_previous = bins[1]
2391+
for i in range(2, nb_bins):
2392+
bin_current = bins[i]
2393+
if step != (bin_current - bin_previous):
2394+
# If bins are not regular
2395+
return digitize(values, bins)
2396+
bin_previous = bin_current
2397+
nb_values = values.shape[0]
2398+
out = empty(nb_values, dtype=numba_types.int64)
2399+
up, down = bins[0], bins[-1]
2400+
for i in range(nb_values):
2401+
v_ = values[i]
2402+
if v_ >= down:
2403+
out[i] = nb_bins
2404+
continue
2405+
if v_ < up:
2406+
out[i] = 0
2407+
continue
2408+
out[i] = (v_ - bins[0]) / step + 1
2409+
return out

0 commit comments

Comments
 (0)