Skip to content

Commit abd4433

Browse files
committed
-lazy cube management
-event statistics
1 parent e31d0a7 commit abd4433

File tree

4 files changed

+194
-29
lines changed

4 files changed

+194
-29
lines changed

src/py_eddy_tracker/dataset/grid.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -304,14 +304,26 @@ def __init__(
304304
"We assume pixel position of grid is centered for %s", filename
305305
)
306306
if not unset:
307-
self.load_general_features()
308-
self.load()
307+
self.populate()
309308

310309
def populate(self):
    """Lazily load the grid: fetch metadata and data only when not already loaded."""
    if self.dimensions is not None:
        # Already populated, nothing to do
        return
    self.load_general_features()
    self.load()
314313

314+
def clean(self):
    """Release every loaded field so the grid can be re-populated lazily later.

    Resets all cached metadata/coordinate attributes to None and empties
    the variables store.
    """
    # All attributes that hold loaded state, reset in one pass
    for attr in (
        "dimensions",
        "variables_description",
        "global_attrs",
        "x_c",
        "y_c",
        "x_bounds",
        "y_bounds",
        "x_dim",
        "y_dim",
        "contours",
    ):
        setattr(self, attr, None)
    self.vars = dict()
326+
315327
@property
316328
def is_centered(self):
317329
"""Give True if pixel is described with its center's position or
@@ -429,7 +441,7 @@ def c_to_bounds(c):
429441
def setup_coordinates(self):
430442
x_name, y_name = self.coordinates
431443
if self.is_centered:
432-
logger.info("Grid center")
444+
# logger.info("Grid center")
433445
self.x_c = self.vars[x_name].astype("float64")
434446
self.y_c = self.vars[y_name].astype("float64")
435447

@@ -1968,14 +1980,21 @@ def interp(self, grid_name, lons, lats, method="bilinear"):
19681980
self.x_c, self.y_c, g, m, lons, lats, nearest=method == "nearest"
19691981
)
19701982

1971-
def uv_for_advection(self, u_name, v_name, time_step=600, backward=False, factor=1):
1983+
def uv_for_advection(self, u_name=None, v_name=None, time_step=600, h_name=None, backward=False, factor=1):
19721984
"""
19731985
Get U,V to be used in degrees with precomputed time step
19741986
1975-
:param str,array u_name: U field to advect obs
1976-
:param str,array v_name: V field to advect obs
1987+
:param None,str,array u_name: U field to advect obs, if h_name is None
1988+
:param None,str,array v_name: V field to advect obs, if h_name is None
1989+
:param None,str,array h_name: H field to compute UV to advect obs, if u_name and v_name are None
19771990
:param int time_step: Number of second for each advection
19781991
"""
1992+
if h_name is not None:
1993+
u_name, v_name = 'u', 'v'
1994+
if u_name not in self.vars:
1995+
self.add_uv(h_name)
1996+
self.vars.pop(h_name, None)
1997+
19791998
u = (self.grid(u_name) if isinstance(u_name, str) else u_name).copy()
19801999
v = (self.grid(v_name) if isinstance(v_name, str) else v_name).copy()
19812000
# N seconds / 1 degrees in m
@@ -2318,6 +2337,14 @@ def from_netcdf_list(
23182337
new.datasets.append((_t, d))
23192338
return new
23202339

2340+
@property
def are_loaded(self):
    """Give a boolean array with one flag per dataset: True where loaded.

    :return array: boolean mask over ``self.datasets``
    """
    # dtype=bool is explicit so an empty dataset list yields an empty bool
    # array; the previous `~array([d.dimensions is None ...])` produced a
    # float64 array for an empty list, on which `~` raises TypeError.
    return array([d.dimensions is not None for _, d in self.datasets], dtype=bool)
2343+
2344+
def __repr__(self):
    """Human-readable summary, e.g. ``3/10 datasets loaded``."""
    total = len(self.datasets)
    loaded = self.are_loaded.sum()
    return f"{loaded}/{total} datasets loaded"
2347+
23212348
def shift_files(self, t, filename, heigth=None, **rgd_kwargs):
23222349
"""Add next file to the list and remove the oldest"""
23232350

@@ -2440,36 +2467,42 @@ def filament(
24402467
t += dt
24412468
yield t, f_x, f_y
24422469

2470+
def reset_grids(self, N=None):
    """Unload every loaded grid when more than *N* are currently loaded.

    :param int,None N: threshold of loaded grids; None disables any cleaning
    """
    if N is None:
        return
    loaded = self.are_loaded
    if loaded.sum() > N:
        # Clean all loaded grids (they will be reloaded lazily on demand)
        for index in where(loaded)[0]:
            self.datasets[index][1].clean()
2476+
24432477
def advect(
24442478
self,
24452479
x,
24462480
y,
2447-
u_name,
2448-
v_name,
24492481
t_init,
24502482
mask_particule=None,
24512483
nb_step=10,
24522484
time_step=600,
24532485
rk4=True,
2486+
reset_grid=None,
24542487
**kw,
24552488
):
24562489
"""
24572490
At each call it will update position in place with u & v field
24582491
24592492
:param array x: Longitude of obs to move
24602493
:param array y: Latitude of obs to move
2461-
:param str,array u_name: U field to advect obs
2462-
:param str,array v_name: V field to advect obs
24632494
:param float t_init: time to start advection
24642495
:param array,None mask_particule: advect only if mask is True
24652496
:param int nb_step: Number of iteration before to release data
24662497
:param int time_step: Number of second for each advection
24672498
:param bool rk4: Use rk4 algorithm instead of finite difference
2499+
:param int reset_grid: Delete all loaded data in cube if there are more than N grid loaded
24682500
24692501
:return: t,x,y position
24702502
24712503
.. minigallery:: py_eddy_tracker.GridCollection.advect
24722504
"""
2505+
self.reset_grids(reset_grid)
24732506
backward = kw.get("backward", False)
24742507
if backward:
24752508
generator = self.get_previous_time_step(t_init)
@@ -2480,9 +2513,9 @@ def advect(
24802513
dt = nb_step * time_step
24812514
t_step = time_step
24822515
t0, d0 = generator.__next__()
2483-
u0, v0, m0 = d0.uv_for_advection(u_name, v_name, time_step, **kw)
2516+
u0, v0, m0 = d0.uv_for_advection(time_step=time_step, **kw)
24842517
t1, d1 = generator.__next__()
2485-
u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw)
2518+
u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw)
24862519
t0 = t0 * 86400
24872520
t1 = t1 * 86400
24882521
t = t_init * 86400
@@ -2497,7 +2530,7 @@ def advect(
24972530
t0, u0, v0, m0 = t1, u1, v1, m1
24982531
t1, d1 = generator.__next__()
24992532
t1 = t1 * 86400
2500-
u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw)
2533+
u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw)
25012534
w = 1 - (arange(t, t + dt, t_step) - t0) / (t1 - t0)
25022535
half_w = t_step / 2.0 / (t1 - t0)
25032536
advect_(

src/py_eddy_tracker/observations/groups.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def particle_candidate_step(
114114
# Advect particles
115115
kw = dict(nb_step=day_fraction, time_step=86400 / day_fraction)
116116
p = c.advect(x, y, t_init=t_start, **kwargs, **kw)
117-
for _ in range(dt):
117+
for _ in range(abs(dt)):
118118
_, x, y = p.__next__()
119119
m = ~(isnan(x) + isnan(y))
120120
i_end = full(x.shape, -1, dtype="i4")
@@ -352,7 +352,7 @@ def keep_tracks_by_date(self, date, nb_days):
352352
return self.extract_with_mask(mask)
353353

354354
def particle_candidate_atlas(
355-
self, cube, space_step, dt, start_intern=False, end_intern=False, **kwargs
355+
self, cube, space_step, dt, start_intern=False, end_intern=False, callback_coherence=None, finalize_coherence=None, **kwargs
356356
):
357357
"""Select particles within eddies, advect them, return target observation and associated percentages
358358
@@ -361,7 +361,9 @@ def particle_candidate_atlas(
361361
:param int dt: duration of advection
362362
:param bool start_intern: Use intern or extern contour at injection, defaults to False
363363
:param bool end_intern: Use intern or extern contour at end of advection, defaults to False
364-
:params dict kwargs: dict of params given to advection
364+
:param dict kwargs: dict of params given to advection
365+
:param func callback_coherence: if None we will use cls.fill_coherence
366+
:param func finalize_coherence: to apply on results of callback_coherence
365367
:return (np.array,np.array): return target index and percent associate
366368
"""
367369
t_start, t_end = int(self.period[0]), int(self.period[1])
@@ -374,23 +376,62 @@ def particle_candidate_atlas(
374376
i_target, pct = full(shape, -1, dtype="i4"), zeros(shape, dtype="i1")
375377
# Backward or forward
376378
times = arange(t_start, t_end - dt) if dt > 0 else arange(t_start + dt, t_end)
379+
380+
if callback_coherence is None:
381+
callback_coherence = self.fill_coherence
382+
indexs = dict()
383+
results = list()
384+
kw_coherence = dict(space_step=space_step, dt=dt, c=cube)
385+
kw_coherence.update(kwargs)
377386
for t in times:
387+
logger.info("Coherence for time step : %s in [%s:%s]", t, times[0], times[-1])
378388
# Get index for origin
379389
i = t - t_start
380390
indexs0 = i_sort[i_start[i] : i_end[i]]
381391
# Get index for end
382392
i = t + dt - t_start
383393
indexs1 = i_sort[i_start[i] : i_end[i]]
384-
# Get contour data
385-
contours0 = [self[label][indexs0] for label in self.intern(start_intern)]
386-
contours1 = [self[label][indexs1] for label in self.intern(end_intern)]
387-
# Get local result
388-
i_target_, pct_ = particle_candidate_step(
389-
t, contours0, contours1, space_step, dt, cube, **kwargs
390-
)
391-
# Merge result
392-
m = i_target_ != -1
393-
i_target_[m] = indexs1[i_target_[m]]
394-
i_target[indexs0] = i_target_
395-
pct[indexs0] = pct_
394+
if indexs0.size == 0 or indexs1.size == 0:
395+
continue
396+
397+
results.append(callback_coherence(self, i_target, pct, indexs0, indexs1, start_intern, end_intern, t_start=t, **kw_coherence))
398+
indexs[results[-1]] = indexs0, indexs1
399+
400+
if finalize_coherence is not None:
401+
finalize_coherence(results, indexs, i_target, pct)
396402
return i_target, pct
403+
404+
@classmethod
def fill_coherence(cls, network, i_targets, percents, i_origin, i_end, start_intern, end_intern, **kwargs):
    """Compute particle coherence between origin and end contours, then merge
    the local result into the global arrays.

    :param network: observations collection providing contour fields
    :param array i_targets: global target indices (updated in place)
    :param array percents: global percentages (updated in place)
    :param array i_origin: indices of origins
    :param array i_end: indices of ends
    :param bool start_intern: Use intern or extern contour at injection
    :param bool end_intern: Use intern or extern contour at end of advection
    :param dict kwargs: extra parameters forwarded to particle_candidate_step
    """
    def contours_for(indices, intern_flag):
        # Gather the contour arrays restricted to the requested observations
        return [network[label][indices] for label in cls.intern(intern_flag)]

    # Compute local coherence between origin and end contours
    i_local_targets, local_percents = particle_candidate_step(
        contours_start=contours_for(i_origin, start_intern),
        contours_end=contours_for(i_end, end_intern),
        **kwargs,
    )
    # Store local result with global indexation
    cls.merge_particle_result(i_targets, percents, i_local_targets, local_percents, i_origin, i_end)
422+
423+
@staticmethod
424+
def merge_particle_result(i_targets, percents, i_local_targets, local_percents, i_origin, i_end):
425+
"""Copy local result in merged result with global indexation
426+
427+
:param array i_targets: global target
428+
:param array percents:
429+
:param array i_local_targets: local index target
430+
:param array local_percents:
431+
:param array i_origin: indices of origins
432+
:param array i_end: indices of ends
433+
"""
434+
m = i_local_targets != -1
435+
i_local_targets[m] = i_end[i_local_targets[m]]
436+
i_targets[i_origin] = i_local_targets
437+
percents[i_origin] = local_percents

src/py_eddy_tracker/observations/network.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88

99
import netCDF4
10-
from numba import njit
10+
from numba import njit, types as nb_types
1111
from numba.typed import List
1212
from numpy import (
1313
arange,
@@ -23,6 +23,8 @@
2323
unique,
2424
where,
2525
zeros,
26+
percentile,
27+
nan
2628
)
2729
import zarr
2830

@@ -1743,6 +1745,7 @@ def segment_coherence_forward(
17431745
step_mesh=1.0 / 50,
17441746
contour_start="speed",
17451747
contour_end="speed",
1748+
**kwargs,
17461749
):
17471750

17481751
"""
@@ -1801,6 +1804,7 @@ def date2file(julian_day):
18011804
n_days=n_days,
18021805
contour_start=contour_start,
18031806
contour_end=contour_end,
1807+
**kwargs
18041808
)
18051809
logger.info(
18061810
(
@@ -1996,7 +2000,69 @@ def build_dataset(self, group, raw_data=True):
19962000
print()
19972001
eddies.track[new_i] = group
19982002
return eddies
2003+
2004+
@njit(cache=True)
def get_percentile_on_following_obs(i, indexs, percents, follow_obs, t, segment, i_target, window, q=50, nb_min=1):
    """Get stat on a part of segment close of an event

    Walks forward from observation ``i`` along ``follow_obs``, collecting the
    coherence percents that point into the target's segment, then reduces
    them to a single percentile.

    :param int i: index to follow
    :param array indexs: indexs from coherence
    :param array percents: percent from coherence
    :param array[int] follow_obs: give index for the following observation
    :param array t: time for each observation
    :param array segment: segment for each observation
    :param int i_target: index of target
    :param int window: time window of search
    :param int q: Percentile from 0 to 100, defaults to 50
    :param int nb_min: Number minimal of observation to provide statistics, defaults to 1
    :return float : return statistic
    """
    # Remember starting time and segment: the walk stops when it leaves the
    # time window, hits -1 (no follower), or crosses into another segment.
    last_t, segment_follow = t[i], segment[i]
    segment_target = segment[i_target]
    # NOTE(review): buffer sized to `window` entries, but each visited obs can
    # contribute one entry per column of `indexs` (primary & secondary) — if
    # both can match segment_target within the window, j could exceed the
    # buffer; confirm upstream guarantees before relying on this bound.
    percent_target = empty(window, dtype=percents.dtype)
    j = 0
    # NOTE(review): when i becomes -1, t[i]/segment[i] read the last element
    # (negative indexing) before the `i != -1` test short-circuits the loop;
    # harmless but order-dependent.
    while abs(last_t - t[i]) < window and i != -1 and segment_follow == segment[i]:
        # Iter on primary & secondary
        for index, percent in zip(indexs[i], percents[i]):
            if index != -1 and segment[index] == segment_target:
                percent_target[j] = percent
                j += 1
        i = follow_obs[i]
    # Not enough samples to give a meaningful statistic
    if j < nb_min:
        return nan
    return percentile(percent_target[:j], q)
19992034

2035+
@njit(cache=True)
def get_percentile_around_event(i, i1, i2, ind, pct, follow_obs, t, segment, window=10, follow_parent=False, q=50, nb_min=1):
    """Get stat around event

    :param array[int] i: Indexs of target
    :param array[int] i1: Indexs of primary origin
    :param array[int] i2: Indexs of secondary origin
    :param array ind: indexs from coherence
    :param array pct: percent from coherence
    :param array[int] follow_obs: give index for the following observation
    :param array t: time for each observation
    :param array segment: segment for each observation
    :param int window: time window of search, defaults to 10
    :param bool follow_parent: Follow parent instead of child, defaults to False
    :param int q: Percentile from 0 to 100, defaults to 50
    :param int nb_min: Number minimal of observation to provide statistics, defaults to 1
    :return (array,array) : statistic for each event
    """
    nb_event = i.size
    stat1 = empty(nb_event, dtype=nb_types.float32)
    stat2 = empty(nb_event, dtype=nb_types.float32)
    # Iterate over events; each yields one statistic per origin branch
    for j in range(nb_event):
        i_, i1_, i2_ = i[j], i1[j], i2[j]
        if follow_parent:
            # We follow parent: walk starts at the child, targets each parent
            start1, start2 = i_, i_
            target1, target2 = i1_, i2_
        else:
            # We follow child: walk starts at each parent, targets the child
            start1, start2 = i1_, i2_
            target1, target2 = i_, i_
        stat1[j] = get_percentile_on_following_obs(start1, ind, pct, follow_obs, t, segment, target1, window, q, nb_min)
        stat2[j] = get_percentile_on_following_obs(start2, ind, pct, follow_obs, t, segment, target2, window, q, nb_min)
    return stat1, stat2
20002066

20012067
@njit(cache=True)
20022068
def get_next_index(gr):

src/py_eddy_tracker/observations/observation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,6 +1656,31 @@ def create_variable(
16561656
except ValueError:
16571657
logger.warning("Data is empty")
16581658

1659+
@staticmethod
def get_filters_zarr(name):
    """Get filters to store in zarr for known variable

    :param str name: private variable name
    :return list: filters list
    """
    content = VAR_DESCR.get(name)
    filters = list()
    store_dtype = content["output_type"]
    scale_factor, add_offset = content.get("scale_factor", None), content.get("add_offset", None)
    if scale_factor is not None or add_offset is not None:
        # Defaults mirror netCDF packing conventions: missing offset -> 0,
        # missing scale -> 1. Previously a variable declaring only
        # add_offset crashed here on `1 / None`.
        if add_offset is None:
            add_offset = 0
        if scale_factor is None:
            scale_factor = 1
        filters.append(
            zarr.FixedScaleOffset(
                offset=add_offset,
                scale=1 / scale_factor,
                dtype=content["nc_type"],
                astype=store_dtype,
            )
        )
    # Variable may also declare extra codecs of its own
    filters.extend(content.get("filters", []))
    return filters
1683+
16591684
def create_variable_zarr(
16601685
self,
16611686
handler_zarr,

0 commit comments

Comments
 (0)