
Commit 81eaf15

add coherence function, and minor corrections

add shift_files to GridCollection
change heigth to height
correct FIXME in function documentation
correct tests
1 parent e934346 commit 81eaf15

File tree

3 files changed: +308 -15 lines changed

src/py_eddy_tracker/dataset/grid.py
src/py_eddy_tracker/observations/network.py
src/py_eddy_tracker/observations/observation.py

src/py_eddy_tracker/dataset/grid.py

Lines changed: 10 additions & 0 deletions
@@ -2264,6 +2264,16 @@ def from_netcdf_list(cls, filenames, t, x_name, y_name, indexs=None, heigth=None
             new.datasets.append((t, d))
         return new
 
+    def shift_files(self, t, filename, x_name, y_name, indexs, heigth):
+        """Add next file to the list and remove the oldest"""
+
+        self.datasets = self.datasets[1:]
+
+        d = RegularGridDataset(filename, x_name, y_name, indexs=indexs)
+        if heigth is not None:
+            d.add_uv(heigth)
+        self.datasets.append((t, d))
+
     def interp(self, grid_name, t, lons, lats, method="bilinear"):
         """
         Compute z over lons, lats
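The new `shift_files` turns a `GridCollection` into a fixed-size sliding window over daily velocity files: the oldest dataset is dropped from the head of the list and a new one is appended at the tail, which is what `segment_coherence` below relies on to advance day by day. A minimal sketch of the call pattern, assuming hypothetical file paths and grid variable names (`longitude`, `latitude`, `adt` are illustrative, not from this commit):

import datetime

from py_eddy_tracker.dataset.grid import GridCollection

def date2file(julian_day):
    # Hypothetical layout: one velocity file per julian day since 1950-01-01
    date = datetime.timedelta(days=julian_day) + datetime.datetime(1950, 1, 1)
    return f"/tmp/uv_{date.strftime('%Y%m%d')}.nc"

# Keyword arguments shared by from_netcdf_list and shift_files;
# the "heigth" spelling follows this commit's signatures
uv_params = dict(x_name="longitude", y_name="latitude", indexs=None, heigth="adt")

dates = list(range(23000, 23005))
c = GridCollection.from_netcdf_list([date2file(d) for d in dates], dates, **uv_params)

# Slide the window one day: drop the oldest dataset, load day 23005
c.shift_files(23005, date2file(23005), **uv_params)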

src/py_eddy_tracker/observations/network.py

Lines changed: 297 additions & 14 deletions
@@ -6,6 +6,7 @@
 from glob import glob
 
 from numba import njit
+from numba import types as nb_types
 from numpy import (
     arange,
     array,
@@ -20,13 +21,16 @@
     unique,
     where,
     zeros,
+    meshgrid,
 )
+import zarr
 
 from ..generic import build_index, wrap_longitude
-from ..poly import bbox_intersection, vertice_overlap
+from ..poly import bbox_intersection, vertice_overlap, group_obs
 from .groups import GroupEddiesObservations, get_missing_indices
 from .observation import EddiesObservations
 from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter
+from ..dataset.grid import GridCollection
 
 logger = logging.getLogger("pet")
 
@@ -97,6 +101,109 @@ def fix_next_previous_obs(next_obs, previous_obs, flag_virtual):
         previous_obs[i_o + 1] = i_o
 
 
+def advect(x, y, c, t0, delta_t):
+    """
+    Advect particle from t0 to t0 + delta_t, with data cube.
+
+    :param np.array(float) x: longitude of particles
+    :param np.array(float) y: latitude of particles
+    :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles
+    :param int t0: julian day of advection start
+    :param int delta_t: number of days to advect
+    """
+
+    kw = dict(nb_step=6, time_step=86400 / 6)
+    if delta_t < 0:
+        kw["backward"] = True
+        delta_t = -delta_t
+    p = c.advect(x, y, "u", "v", t_init=t0, **kw)
+    for _ in range(delta_t):
+        t, x, y = p.__next__()
+    return t, x, y
+
+
+def particle_candidate(x, y, c, eddies, t_start, i_target, pct, **kwargs):
+    """Select particles within eddies, advect them, return target observation and associated percentages
+
+    :param np.array(float) x: longitude of particles
+    :param np.array(float) y: latitude of particles
+    :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles
+    :param NetworkObservations eddies: NetworkObservations considered
+    :param int t_start: julian day of the advection
+    :param np.array(int) i_target: corresponding obs where particles are advected
+    :param np.array(int) pct: corresponding percentage of advected particles
+    :param dict kwargs: dict of params given to `advect`
+    """
+
+    # Obs from initial time
+    m_start = eddies.time == t_start
+
+    e = eddies.extract_with_mask(m_start)
+    # to be able to get global index
+    translate_start = where(m_start)[0]
+    # Identify particle in eddies (only in core)
+    i_start = e.contains(x, y, intern=True)
+    m = i_start != -1
+
+    x, y, i_start = x[m], y[m], i_start[m]
+    # Advect
+    t_end, x, y = advect(x, y, c, t_start, **kwargs)
+    # eddies at last date
+    m_end = eddies.time == t_end / 86400
+    e_end = eddies.extract_with_mask(m_end)
+    # to be able to get global index
+    translate_end = where(m_end)[0]
+    # Id eddies for each alive particle (in core and extern)
+    i_end = e_end.contains(x, y)
+    # compute matrix and fill target array
+    get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct)
+
+
+@njit(cache=True)
+def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct):
+    """Compute target observation and associated percentages
+
+    :param np.array(int) i_start: indices of associated contours at starting advection day
+    :param np.array(int) i_end: indices of associated contours after advection
+    :param np.array(int) translate_start: corresponding global indices at starting advection day
+    :param np.array(int) translate_end: corresponding global indices after advection
+    :param np.array(int) i_target: corresponding obs where particles are advected
+    :param np.array(int) pct: corresponding percentage of advected particles
+    """
+
+    nb_start, nb_end = translate_start.size, translate_end.size
+    # Matrix which will store count for every couple
+    count = zeros((nb_start, nb_end), dtype=nb_types.int32)
+    # Number of particles in each origin observation
+    ref = zeros(nb_start, dtype=nb_types.int32)
+    # For each particle
+    for i in range(i_start.size):
+        i_end_ = i_end[i]
+        i_start_ = i_start[i]
+        if i_end_ != -1:
+            count[i_start_, i_end_] += 1
+        ref[i_start_] += 1
+    for i in range(nb_start):
+        for j in range(nb_end):
+            pct_ = count[i, j]
+            # If there are particles from i to j
+            if pct_ != 0:
+                # Get percent
+                pct_ = pct_ / ref[i] * 100.0
+                # Get indices in full dataset
+                i_, j_ = translate_start[i], translate_end[j]
+                pct_0 = pct[i_, 0]
+                if pct_ > pct_0:
+                    pct[i_, 1] = pct_0
+                    pct[i_, 0] = pct_
+                    i_target[i_, 1] = i_target[i_, 0]
+                    i_target[i_, 0] = j_
+                elif pct_ > pct[i_, 1]:
+                    pct[i_, 1] = pct_
+                    i_target[i_, 1] = j_
+    return i_target, pct
+
+
 class NetworkObservations(GroupEddiesObservations):
 
     __slots__ = ("_index_network",)
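Together these three functions form the measurement core: particles are seeded in eddy cores at `t_start`, advected with `advect`, located again with `contains`, and `get_matrix` then keeps, for each origin observation, only the two most visited targets (`i_target[:, 0]`, `pct[:, 0]` for the best, column 1 for the runner-up). A toy check of that bookkeeping, with made-up indices (illustrative only, not from this commit):

import numpy as np
from py_eddy_tracker.observations.network import get_matrix

# Origin contour of each particle, and the contour it lands in (-1 = lost)
i_start = np.array([0, 0, 0, 1], dtype="i4")
i_end = np.array([1, 1, 0, -1], dtype="i4")
# Global observation indices of the 2 origins and the 2 targets
translate_start = np.array([10, 11], dtype="i4")
translate_end = np.array([40, 41], dtype="i4")

# Output arrays sized like the full dataset (50 observations here)
i_target = -np.ones((50, 2), dtype="i4")
pct = np.zeros((50, 2), dtype="i1")
get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct)

# Origin obs 10: 2/3 of its particles reach obs 41, 1/3 reach obs 40
assert (i_target[10] == [41, 40]).all() and pct[10, 0] == 66 and pct[10, 1] == 33
# Origin obs 11 lost its only particle, so it keeps no target
assert (i_target[11] == [-1, -1]).all()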
@@ -109,17 +216,16 @@ def __init__(self, *args, **kwargs):
 
     def find_segments_relative(self, obs, stopped=None, order=1):
         """
-        Find all relative segments linked with merging/splitting events at a specific order.
+        Find all relative segments from obs linked with merging/splitting events at a specific order.
 
-        :param int obs: index of event after the event
-        :param int stopped: index of event before the event
+        :param int obs: index of observation after the event
+        :param int stopped: index of observation before the event
         :param int order: order of relatives accepted
 
         :return: all relative segments
         :rtype: EddiesObservations
         """
 
-        # FIXME : double "event" in the description, please clarify (event = chosen obs?)
 
         # extraction of network where the event is
         network_id = self.tracks[obs]
@@ -247,23 +353,17 @@ def infos(self, label=""):
     def correct_close_events(self, nb_days_max=20):
         """
         Transform event where
-        segment A split to B, then A merge into B
+        segment A splits from segment B, then x days after segment B merges with A
 
         to
 
-        segment A split to B, then B merge to A
+        segment A splits from segment B then x days after segment A merges with B (B will be longer)
 
-        these events are filtered with `nb_days_max`, which the event have to take place in less than `nb_days_max`
+        These events have to last less than `nb_days_max` to be changed.
 
         :param float nb_days_max: maximum time to search for splitting-merging event
         """
 
-        # FIXME : we want to change
-        # segment A splits from segment B, then x days after segment B merges with A
-        # to
-        # segment A splits from segment B then x days after segement A merges with B (B will be longer)
-        # comments are in the wrong way but the example works as wanted
-
         _time = self.time
         # segment used to correct and track changes
         segment = self.segment_track_array.copy()
@@ -1340,6 +1440,189 @@ def extract_with_mask(self, mask):
             new.previous_obs[:] = translate[p]
         return new
 
+    def analysis_coherence(
+        self,
+        date_function,
+        uv_params,
+        advection_mode="both",
+        dt_advect=14,
+        step_mesh=1.0 / 50,
+        output_name=None,
+        dissociate_network=False,
+        correct_close_events=0,
+        remove_dead_end=0,
+    ):
+
+        """Global function to analyse segments coherence, with network preprocessing"""
+
+        if dissociate_network:
+            self.dissociate_network()
+
+        if correct_close_events > 0:
+            self.correct_close_events(nb_days_max=correct_close_events)
+
+        if remove_dead_end > 0:
+            network_clean = self.remove_dead_end(nobs=0, ndays=remove_dead_end)
+        else:
+            network_clean = self
+
+        res = network_clean.segment_coherence(
+            date_function=date_function,
+            uv_params=uv_params,
+            advection_mode=advection_mode,
+            output_name=output_name,
+            dt_advect=dt_advect,
+            step_mesh=step_mesh,
+        )
+
+        return network_clean, res
+
+    def segment_coherence(
+        self,
+        date_function,
+        uv_params,
+        advection_mode="both",
+        dt_advect=14,
+        step_mesh=1.0 / 50,
+        output_name=None,
+    ):
+
+        """
+        Percentage of particles and their targets after forward and/or backward advection from a specific eddy.
+
+        :param callable date_function: python function, takes as param `int` (julian day) and returns
+            the data filename associated to the date
+            ex:
+                def date2file(julian_day):
+                    date = datetime.timedelta(days=julian_day) + datetime.datetime(1950, 1, 1)
+
+                    return f"/tmp/dt_global_allsat_phy_l4_{date.strftime('%Y%m%d')}.nc"
+
+        :param dict uv_params: dict of parameters used by
+            :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list`
+        :param str advection_mode: "backward", "forward" or "both"
+        :param int dt_advect: days for advection
+        :param float step_mesh: step for particle mesh in degrees
+        :param str output_name: if not None, name of file saved in zarr. Else, data will not be saved
+        """
+
+        if advection_mode in ["both", "forward"]:
+            itf_final = -ones((self.obs.size, 2), dtype="i4")
+            ptf_final = zeros((self.obs.size, 2), dtype="i1")
+
+        if advection_mode in ["both", "backward"]:
+            itb_final = -ones((self.obs.size, 2), dtype="i4")
+            ptb_final = zeros((self.obs.size, 2), dtype="i1")
+
+        for slice_track, b0, _ in self.iter_on(self.track):
+            if b0 == 0:
+                continue
+
+            sub_networks = self.network(b0)
+
+            # find extremum to create a mesh of particles
+            lon = sub_networks.contour_lon_s
+            lonMin = lon.min() - 0.1
+            lonMax = lon.max() + 0.1
+
+            lat = sub_networks.contour_lat_s
+            latMin = lat.min() - 0.1
+            latMax = lat.max() + 0.1
+
+            x0, y0 = meshgrid(
+                arange(lonMin, lonMax, step_mesh), arange(latMin, latMax, step_mesh)
+            )
+            x0, y0 = x0.reshape(-1), y0.reshape(-1)
+            _, i = group_obs(x0, y0, 1, 360)
+            x0, y0 = x0[i], y0[i]
+
+            t_start, t_end = sub_networks.period
+            shape = (sub_networks.obs.size, 2)
+
+            if advection_mode in ["both", "forward"]:
+
+                # first dates to load.
+                dates = arange(t_start - 1, t_start + dt_advect + 2)
+                # files associated with dates
+                first_files = [date_function(x) for x in dates]
+
+                c = GridCollection.from_netcdf_list(first_files, dates, **uv_params)
+
+                i_target_f = -ones(shape, dtype="i4")
+                pct_target_f = zeros(shape, dtype="i1")
+
+                for _t in range(t_start, t_end - dt_advect + 1):
+                    t_shift = _t + dt_advect + 2
+
+                    # add next date to GridCollection and delete last date
+                    c.shift_files(t_shift, date_function(int(t_shift)), **uv_params)
+                    particle_candidate(
+                        x0,
+                        y0,
+                        c,
+                        sub_networks,
+                        _t,
+                        i_target_f,
+                        pct_target_f,
+                        delta_t=dt_advect,
+                    )
+
+                itf_final[slice_track] = i_target_f
+                ptf_final[slice_track] = pct_target_f
+
+            if advection_mode in ["both", "backward"]:
+
+                # first dates to load.
+                dates = arange(t_start - 1, t_start + dt_advect + 2)
+                # files associated with dates
+                first_files = [date_function(x) for x in dates]
+
+                c = GridCollection.from_netcdf_list(first_files, dates, **uv_params)
+
+                i_target_b = -ones(shape, dtype="i4")
+                pct_target_b = zeros(shape, dtype="i1")
+
+                for _t in range(t_start + dt_advect + 1, t_end + 1):
+                    t_shift = _t + 1
+
+                    # add next date to GridCollection and delete last date
+                    c.shift_files(t_shift, date_function(int(t_shift)), **uv_params)
+                    particle_candidate(
+                        x0,
+                        y0,
+                        c,
+                        self,
+                        _t,
+                        i_target_b,
+                        pct_target_b,
+                        delta_t=-dt_advect,
+                    )
+
+                itb_final[slice_track] = i_target_b
+                ptb_final[slice_track] = pct_target_b
+
+        if output_name is not None:
+            zg = zarr.open(output_name, "w")
+
+            # zarr compression parameters
+            params_seg = dict()
+            params_pct = dict()
+
+        res = []
+        if advection_mode in ["forward", "both"]:
+            res = res + [itf_final, ptf_final]
+            if output_name is not None:
+                zg.array("i_target_forward", itf_final, **params_seg)
+                zg.array("pct_target_forward", ptf_final, **params_pct)
+
+        if advection_mode in ["backward", "both"]:
+            res = res + [itb_final, ptb_final]
+            if output_name is not None:
+                zg.array("i_target_backward", itb_final, **params_seg)
+                zg.array("pct_target_backward", ptb_final, **params_pct)
+
+        return res
+
 
 class Network:
     __slots__ = (
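End to end, `analysis_coherence` optionally cleans the network (dissociation, close-event correction, dead-end removal) and then delegates to `segment_coherence`. A plausible usage sketch; the network filename, grid variable names, and the `load_file` entry point are illustrative assumptions, only the keyword arguments come from this commit:

import datetime

from py_eddy_tracker.observations.network import NetworkObservations

def date2file(julian_day):
    date = datetime.timedelta(days=julian_day) + datetime.datetime(1950, 1, 1)
    return f"/tmp/dt_global_allsat_phy_l4_{date.strftime('%Y%m%d')}.nc"

n = NetworkObservations.load_file("network.nc")
# indexs is listed explicitly because shift_files gives it no default here
uv_params = dict(x_name="longitude", y_name="latitude", indexs=None, heigth="adt")

network_clean, res = n.analysis_coherence(
    date_function=date2file,
    uv_params=uv_params,
    advection_mode="both",
    dt_advect=14,
    output_name="coherence.zarr",
    correct_close_events=20,
    remove_dead_end=10,
)
# With advection_mode="both":
# res == [i_target_forward, pct_target_forward, i_target_backward, pct_target_backward]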

src/py_eddy_tracker/observations/observation.py

Lines changed: 1 addition & 1 deletion
@@ -2045,7 +2045,7 @@ def is_convex(self, intern=False):
 
     def contains(self, x, y, intern=False):
         """
-        Return index of contour which contain (x,y)
+        Return index of contour containing (x,y)
 
         :param array x: longitude
         :param array y: latitude
