Skip to content

Commit 938e279

Browse files
committed
method change to speed up close track research
1 parent 9fc061f commit 938e279

File tree

3 files changed

+79
-27
lines changed

3 files changed

+79
-27
lines changed

src/py_eddy_tracker/observations/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def close_network(self, other, nb_obs_min=10, **kwargs):
314314
p0, p1 = self.period
315315
indexs = list()
316316
for i_self, i_other, t0, t1 in self.align_on(other, bins=range(p0, p1 + 2)):
317-
i, j, s = self.index(i_self).match(other.index(i_other), **kwargs)
317+
i, j, s = self.match(other, i_self=i_self, i_other=i_other, **kwargs)
318318
indexs.append(other.re_reference_index(j, i_other))
319319
indexs = concatenate(indexs)
320320
tr, nb = unique(other.track[indexs], return_counts=True)

src/py_eddy_tracker/observations/observation.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,7 @@ def iter_on(self, xname, bins=None):
488488
489489
:param str,array xname:
490490
:param array bins: bounds of each bin ,
491-
:return: Group observations
492-
:rtype: self.__class__
491+
:return: index or mask, bound low, bound up
493492
"""
494493
x = self[xname] if isinstance(xname, str) else xname
495494
d = x[1:] - x[:-1]
@@ -498,19 +497,28 @@ def iter_on(self, xname, bins=None):
498497
elif not isinstance(bins, ndarray):
499498
bins = array(bins)
500499
nb_bins = len(bins) - 1
501-
i = numba_digitize(x, bins) - 1
500+
502501
# Not monotonous
503502
if (d < 0).any():
503+
# If bins cover a small part of value
504+
test, translate, x = iter_mode_reduce(x, bins)
505+
# convert value in bins number
506+
i = numba_digitize(x, bins) - 1
507+
# Order by bins
504508
i_sort = i.argsort()
509+
# If in reduce mode we will translate i_sort in full array index
510+
i_sort_ = translate[i_sort] if test else i_sort
511+
# Bound for each bins in sorting view
505512
i0, i1, _ = build_index(i[i_sort])
506513
m = ~(i0 == i1)
507514
i0, i1 = i0[m], i1[m]
508515
for i0_, i1_ in zip(i0, i1):
509516
i_bins = i[i_sort[i0_]]
510517
if i_bins == -1 or i_bins == nb_bins:
511518
continue
512-
yield i_sort[i0_:i1_], bins[i_bins], bins[i_bins + 1]
519+
yield i_sort_[i0_:i1_], bins[i_bins], bins[i_bins + 1]
513520
else:
521+
i = numba_digitize(x, bins) - 1
514522
i0, i1, _ = build_index(i)
515523
m = ~(i0 == i1)
516524
i0, i1 = i0[m], i1[m]
@@ -522,10 +530,8 @@ def align_on(self, other, var_name="time", **kwargs):
522530
"""
523531
Align the time indexes of two datasets.
524532
"""
525-
iter_self, iter_other = (
526-
self.iter_on(var_name, **kwargs),
527-
other.iter_on(var_name, **kwargs),
528-
)
533+
iter_self = self.iter_on(var_name, **kwargs)
534+
iter_other = other.iter_on(var_name, **kwargs)
529535
indexs_other, b0_other, b1_other = iter_other.__next__()
530536
for indexs_self, b0_self, b1_self in iter_self:
531537
if b0_self > b0_other:
@@ -1038,10 +1044,23 @@ def intern(flag, public_label=False):
10381044
labels = [VAR_DESCR[label]["nc_name"] for label in labels]
10391045
return labels
10401046

1041-
def match(self, other, method="overlap", intern=False, cmin=0, **kwargs):
1047+
def match(
1048+
self,
1049+
other,
1050+
i_self=None,
1051+
i_other=None,
1052+
method="overlap",
1053+
intern=False,
1054+
cmin=0,
1055+
**kwargs,
1056+
):
10421057
"""Return index and score computed on the effective contour.
10431058
10441059
:param EddiesObservations other: Observations to compare
1060+
:param array[bool,int],None i_self:
1061+
Index or mask to subset observations, it could avoid to build a specific dataset.
1062+
:param array[bool,int],None i_other:
1063+
Index or mask to subset observations, it could avoid to build a specific dataset.
10451064
:param str method:
10461065
- "overlap": the score is computed with contours;
10471066
- "circle": circles are computed and used for score (TODO)
@@ -1054,25 +1073,20 @@ def match(self, other, method="overlap", intern=False, cmin=0, **kwargs):
10541073
10551074
.. minigallery:: py_eddy_tracker.EddiesObservations.match
10561075
"""
1057-
# if method is "overlap" method will use contour to compute score,
1058-
# if method is "circle" method will apply a formula of circle overlap
10591076
x_name, y_name = self.intern(intern)
1077+
if i_self is None:
1078+
i_self = slice(None)
1079+
if i_other is None:
1080+
i_other = slice(None)
10601081
if method == "overlap":
1061-
i, j = bbox_intersection(
1062-
self[x_name], self[y_name], other[x_name], other[y_name]
1063-
)
1064-
c = vertice_overlap(
1065-
self[x_name][i],
1066-
self[y_name][i],
1067-
other[x_name][j],
1068-
other[y_name][j],
1069-
**kwargs,
1070-
)
1082+
x0, y0 = self[x_name][i_self], self[y_name][i_self]
1083+
x1, y1 = other[x_name][i_other], other[y_name][i_other]
1084+
i, j = bbox_intersection(x0, y0, x1, y1)
1085+
c = vertice_overlap(x0[i], y0[i], x1[j], y1[j], **kwargs)
10711086
elif method == "close_center":
1072-
i, j, c = close_center(
1073-
self.latitude, self.longitude, other.latitude, other.longitude, **kwargs
1074-
)
1075-
1087+
x0, y0 = self.longitude[i_self], self.latitude[i_self]
1088+
x1, y1 = other.longitude[i_other], other.latitude[i_other]
1089+
i, j, c = close_center(x0, y0, x1, y1, **kwargs)
10761090
m = c >= cmin # ajout >= pour garder la cmin dans la sélection
10771091
return i[m], j[m], c[m]
10781092

@@ -2438,3 +2452,41 @@ def numba_digitize(values, bins):
24382452
continue
24392453
out[i] = (v_ - bins[0]) / step + 1
24402454
return out
2455+
2456+
2457+
@njit(cache=True)
def iter_mode_reduce(x, bins):
    """
    Decide whether grouping can work on a reduced copy of ``x``.

    Reduce mode is selected when at most half of the values of ``x``
    (``x.shape[0] // 2``) lie inside the closed interval
    ``[bins[0], bins[-1]]``.

    :param array x: array to divide in group
    :param array bins: array which defines bounds between each group
    :return:
        - flag, True when reduce mode is selected
        - indices in ``x`` of the kept values (empty array when flag is False)
        - kept values only (``x`` itself, untouched, when flag is False)
    :rtype: (bool, array[int], array)
    """
    size = x.shape[0]
    # Reduce mode is given up as soon as more than half of the values are kept
    max_kept = size // 2
    # Lowest and highest bounds covered by bins
    lower, upper = bins[0], bins[-1]
    inside = empty(size, dtype=numba_types.bool_)
    # Number of values covered by bins
    nb_inside = 0
    for k in range(size):
        value = x[k]
        keep = value >= lower and value <= upper
        inside[k] = keep
        if not keep:
            continue
        nb_inside += 1
        if nb_inside > max_kept:
            # Too many values are covered: caller must work on the full array
            return False, empty(0, dtype=numba_types.int_), x
    # Second pass: store kept values with their position in the full array,
    # which allows translating reduced indices back to full-array indices
    indices = empty(nb_inside, dtype=numba_types.int_)
    reduced = empty(nb_inside, dtype=x.dtype)
    cursor = 0
    for k in range(size):
        if inside[k]:
            indices[cursor] = k
            reduced[cursor] = x[k]
            cursor += 1
    return True, indices, reduced

src/py_eddy_tracker/observations/tracking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def close_tracks(self, other, nb_obs_min=10, **kwargs):
570570
p0, p1 = self.period
571571
indexs = list()
572572
for i_self, i_other, t0, t1 in self.align_on(other, bins=range(p0, p1 + 2)):
573-
i, j, s = self.index(i_self).match(other.index(i_other), **kwargs)
573+
i, j, s = self.match(other, i_self=i_self, i_other=i_other, **kwargs)
574574
indexs.append(other.re_reference_index(j, i_other))
575575
indexs = concatenate(indexs)
576576
tr, nb = unique(other.track[indexs], return_counts=True)

0 commit comments

Comments
 (0)