Skip to content

Commit 25df7b6

Browse files
committed
Use last obs of network, manage wrapping of network around 0
1 parent 2c78c22 commit 25df7b6

File tree

4 files changed

+25
-262
lines changed

4 files changed

+25
-262
lines changed

examples/12_external_data/pet_SST_collocation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
# %%
2929
# Loading data
30-
# -----------------------------
30+
# ------------
3131
sst = RegularGridDataset(filename=filename_sst, x_name="lon", y_name="lat")
3232
alti = RegularGridDataset(
3333
data.get_path(filename_alt), x_name="longitude", y_name="latitude"
@@ -58,14 +58,14 @@ def update_axes(ax, mappable=None, unit=""):
5858

5959
# %%
6060
# ADT first display
61-
# -----------------------------
61+
# -----------------
6262
ax = start_axes("SLA", extent=extent)
6363
m = sst.display(ax, "sla", vmin=0.05, vmax=0.35)
6464
update_axes(ax, m, unit="[m]")
6565

6666
# %%
6767
# SST first display
68-
# -----------------------------
68+
# -----------------
6969

7070
# %%
7171
# We can now plot SST from `sst`

src/py_eddy_tracker/appli/network.py

Lines changed: 3 additions & 249 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,9 @@
55

66
import logging
77

8-
from netCDF4 import Dataset
9-
from numpy import arange, empty, zeros
10-
from Polygon import Polygon
11-
128
from .. import EddyParser
13-
from ..generic import build_index
149
from ..observations.network import Network
1510
from ..observations.tracking import TrackEddiesObservations
16-
from ..poly import create_vertice_from_2darray, polygon_overlap
1711

1812
logger = logging.getLogger("pet")
1913

@@ -57,246 +51,6 @@ def divide_network():
5751
include_vars=("time", "track", "latitude", "longitude", *contour_name),
5852
)
5953
ids = e.split_network(intern=args.intern, window=args.window)
60-
61-
62-
def split_network(input, output):
63-
"""Divide each group in track"""
64-
sl = slice(None)
65-
with Dataset(input) as h:
66-
group = h.variables["track"][sl]
67-
track_s, track_e, track_ref = build_index(group)
68-
# nb = track_e - track_s
69-
# m = nb > 1500
70-
# print(group[track_s[m]])
71-
72-
track_id = 12003
73-
sls = [slice(track_s[track_id - track_ref], track_e[track_id - track_ref], None)]
74-
for sl in sls:
75-
76-
print(sl)
77-
with Dataset(input) as h:
78-
time = h.variables["time"][sl]
79-
group = h.variables["track"][sl]
80-
x = h.variables["effective_contour_longitude"][sl]
81-
y = h.variables["effective_contour_latitude"][sl]
82-
print(group[0])
83-
ids = empty(
84-
time.shape,
85-
dtype=[
86-
("group", group.dtype),
87-
("time", time.dtype),
88-
("track", "u2"),
89-
("previous_cost", "f4"),
90-
("next_cost", "f4"),
91-
("previous_observation", "i4"),
92-
("next_observation", "i4"),
93-
],
94-
)
95-
ids["group"] = group
96-
ids["time"] = time
97-
# To store id track
98-
ids["track"] = 0
99-
ids["previous_cost"] = 0
100-
ids["next_cost"] = 0
101-
ids["previous_observation"] = -1
102-
ids["next_observation"] = -1
103-
# Cost with previous
104-
track_start, track_end, track_ref = build_index(group)
105-
for i0, i1 in zip(track_start, track_end):
106-
if (i1 - i0) == 0 or group[i0] == Network.NOGROUP:
107-
continue
108-
sl_group = slice(i0, i1)
109-
set_tracks(
110-
x[sl_group],
111-
y[sl_group],
112-
time[sl_group],
113-
i0,
114-
ids["track"][sl_group],
115-
ids["previous_cost"][sl_group],
116-
ids["next_cost"][sl_group],
117-
ids["previous_observation"][sl_group],
118-
ids["next_observation"][sl_group],
119-
window=5,
120-
)
121-
122-
new_i = ids.argsort(order=("group", "track", "time"))
123-
ids_sort = ids[new_i]
124-
# To be able to follow indices sorting
125-
reverse_sort = empty(new_i.shape[0], dtype="u4")
126-
reverse_sort[new_i] = arange(new_i.shape[0])
127-
# Redirect indices
128-
m = ids_sort["next_observation"] != -1
129-
ids_sort["next_observation"][m] = reverse_sort[ids_sort["next_observation"][m]]
130-
m = ids_sort["previous_observation"] != -1
131-
ids_sort["previous_observation"][m] = reverse_sort[
132-
ids_sort["previous_observation"][m]
133-
]
134-
# print(ids_sort)
135-
display_network(
136-
x[new_i],
137-
y[new_i],
138-
ids_sort["track"],
139-
ids_sort["time"],
140-
ids_sort["next_cost"],
141-
)
142-
143-
144-
def next_obs(
145-
i_current, next_cost, previous_cost, polygons, t, t_start, t_end, t_ref, window
146-
):
147-
t_max = t_end.shape[0] - 1
148-
t_cur = t[i_current]
149-
t0, t1 = t_cur + 1 - t_ref, t_cur + window - t_ref
150-
if t0 > t_max:
151-
return -1
152-
t1 = min(t1, t_max)
153-
for t_step in range(t0, t1 + 1):
154-
i0, i1 = t_start[t_step], t_end[t_step]
155-
# No observation at the time step !
156-
if i0 == i1:
157-
continue
158-
sl = slice(i0, i1)
159-
# Intersection / union, to be able to separte in case of multiple inside
160-
c = polygon_overlap(polygons[i_current], polygons[sl])
161-
# We remove low overlap
162-
if (c > 0.1).sum() > 1:
163-
print(c)
164-
c[c < 0.1] = 0
165-
# We get index of maximal overlap
166-
i = c.argmax()
167-
c_i = c[i]
168-
# No overlap found
169-
if c_i == 0:
170-
continue
171-
target = i0 + i
172-
# Check if candidate is already used
173-
c_target = previous_cost[target]
174-
if (c_target != 0 and c_target < c_i) or c_target == 0:
175-
previous_cost[target] = c_i
176-
next_cost[i_current] = c_i
177-
return target
178-
return -1
179-
180-
181-
def set_tracks(
182-
x,
183-
y,
184-
t,
185-
ref_index,
186-
track,
187-
previous_cost,
188-
next_cost,
189-
previous_observation,
190-
next_observation,
191-
window,
192-
):
193-
# Will split one group in tracks
194-
t_start, t_end, t_ref = build_index(t)
195-
nb = x.shape[0]
196-
used = zeros(nb, dtype="bool")
197-
current_track = 1
198-
# build all polygon (need to check if wrap is needed)
199-
polygons = list()
200-
for i in range(nb):
201-
polygons.append(Polygon(create_vertice_from_2darray(x, y, i)))
202-
203-
for i in range(nb):
204-
# If observation already in one track, we go to the next one
205-
if used[i]:
206-
continue
207-
build_track(
208-
i,
209-
current_track,
210-
used,
211-
track,
212-
previous_observation,
213-
next_observation,
214-
ref_index,
215-
next_cost,
216-
previous_cost,
217-
polygons,
218-
t,
219-
t_start,
220-
t_end,
221-
t_ref,
222-
window,
223-
)
224-
current_track += 1
225-
226-
227-
def build_track(
228-
first_index,
229-
track_id,
230-
used,
231-
track,
232-
previous_observation,
233-
next_observation,
234-
ref_index,
235-
next_cost,
236-
previous_cost,
237-
*args,
238-
):
239-
i_next = first_index
240-
while i_next != -1:
241-
# Flag
242-
used[i_next] = True
243-
# Assign id
244-
track[i_next] = track_id
245-
# Search next
246-
i_next_ = next_obs(i_next, next_cost, previous_cost, *args)
247-
if i_next_ == -1:
248-
break
249-
next_observation[i_next] = i_next_ + ref_index
250-
if not used[i_next_]:
251-
previous_observation[i_next_] = i_next + ref_index
252-
# Target was previously used
253-
if used[i_next_]:
254-
if next_cost[i_next] == previous_cost[i_next_]:
255-
m = track[i_next_:] == track[i_next_]
256-
track[i_next_:][m] = track_id
257-
previous_observation[i_next_] = i_next + ref_index
258-
i_next_ = -1
259-
i_next = i_next_
260-
261-
262-
def display_network(x, y, tr, t, c):
263-
tr0, tr1, t_ref = build_index(tr)
264-
import matplotlib.pyplot as plt
265-
266-
cmap = plt.get_cmap("jet")
267-
from ..generic import flatten_line_matrix
268-
269-
fig = plt.figure(figsize=(20, 10))
270-
ax = fig.add_subplot(121, aspect="equal")
271-
ax.grid()
272-
ax_time = fig.add_subplot(122)
273-
ax_time.grid()
274-
i = 0
275-
for s, e in zip(tr0, tr1):
276-
if s == e:
277-
continue
278-
sl = slice(s, e)
279-
color = cmap((tr[s] - tr[tr0[0]]) / (tr[tr0[-1]] - tr[tr0[0]]))
280-
ax.plot(
281-
flatten_line_matrix(x[sl]),
282-
flatten_line_matrix(y[sl]),
283-
color=color,
284-
label=f"{tr[s]} - {e-s} obs from {t[s]} to {t[e-1]}",
285-
)
286-
i += 1
287-
ax_time.plot(
288-
t[sl],
289-
tr[s].repeat(e - s) + c[sl],
290-
color=color,
291-
label=f"{tr[s]} - {e-s} obs",
292-
lw=0.5,
293-
)
294-
ax_time.plot(t[sl], tr[s].repeat(e - s), color=color, lw=1, marker="+")
295-
ax_time.text(t[s], tr[s] + 0.15, f"{x[s].mean():.2f}, {y[s].mean():.2f}")
296-
ax_time.axvline(t[s], color=".75", lw=0.5, ls="--", zorder=-10)
297-
ax_time.text(
298-
t[e - 1], tr[e - 1] - 0.25, f"{x[e-1].mean():.2f}, {y[e-1].mean():.2f}"
299-
)
300-
ax.legend()
301-
ax_time.legend()
302-
plt.show()
54+
e = e.add_fields(("sub_track",))
55+
e.sub_track[:] = ids["track"]
56+
e.write_file(filename=args.out)

src/py_eddy_tracker/generic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def build_index(groups):
8383
first_index[group - i0 + 1 : next_group - i0 + 1] = i + 1
8484
last_index = zeros(amplitude, dtype=numba_types.int_)
8585
last_index[:-1] = first_index[1:]
86-
last_index[-1] = i + 1
86+
# + 2 because we iterate only until -2 and we want upper bound ( 1 + 1)
87+
last_index[-1] = i + 2
8788
return first_index, last_index, i0
8889

8990

src/py_eddy_tracker/observations/tracking.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@
2626
where,
2727
zeros,
2828
)
29-
from Polygon import Polygon
3029

3130
from .. import VAR_DESCR_inv, __version__
3231
from ..generic import build_index, cumsum_by_track, distance, split_line, wrap_longitude
33-
from ..poly import create_vertice_from_2darray, merge, polygon_overlap
32+
from ..poly import bbox_intersection, merge, vertice_overlap
3433
from .observation import EddiesObservations
3534

3635
logger = logging.getLogger("pet")
@@ -642,15 +641,14 @@ def set_tracks(self, x, y, ids, window):
642641
used = zeros(nb, dtype="bool")
643642
track_id = 1
644643
# build all polygon (need to check if wrap is needed)
645-
polygons = [Polygon(create_vertice_from_2darray(x, y, i)) for i in range(nb)]
646644
for i in range(nb):
647645
# If observation already in one track, we go to the next one
648646
if used[i]:
649647
continue
650-
self.follow_obs(i, track_id, used, ids, polygons, *time_index, window)
648+
self.follow_obs(i, track_id, used, ids, x, y, *time_index, window)
651649
track_id += 1
652650
# Search a possible ancestor
653-
self.previous_obs(i, ids, polygons, *time_index, window)
651+
self.previous_obs(i, ids, x, y, *time_index, window)
654652

655653
@classmethod
656654
def follow_obs(cls, i_next, track_id, used, ids, *args):
@@ -676,7 +674,7 @@ def follow_obs(cls, i_next, track_id, used, ids, *args):
676674
i_next = i_next_
677675

678676
@staticmethod
679-
def previous_obs(i_current, ids, polygons, time_s, time_e, time_ref, window):
677+
def previous_obs(i_current, ids, x, y, time_s, time_e, time_ref, window):
680678
time_cur = ids["time"][i_current]
681679
t0, t1 = time_cur - 1 - time_ref, max(time_cur - window - time_ref, 0)
682680
for t_step in range(t0, t1 - 1, -1):
@@ -685,7 +683,12 @@ def previous_obs(i_current, ids, polygons, time_s, time_e, time_ref, window):
685683
if i0 == i1:
686684
continue
687685
# Intersection / union, to be able to separte in case of multiple inside
688-
c = polygon_overlap(polygons[i_current], polygons[i0:i1], minimal_area=True)
686+
xi, yi, xj, yj = x[[i_current]], y[[i_current]], x[i0:i1], y[i0:i1]
687+
ii, ij = bbox_intersection(xi, yi, xj, yj)
688+
if len(ii) == 0:
689+
continue
690+
c = zeros(len(xj))
691+
c[ij] = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], minimal_area=True)
689692
# We remove low overlap
690693
c[c < 0.1] = 0
691694
# We get index of maximal overlap
@@ -699,7 +702,7 @@ def previous_obs(i_current, ids, polygons, time_s, time_e, time_ref, window):
699702
break
700703

701704
@staticmethod
702-
def next_obs(i_current, ids, polygons, time_s, time_e, time_ref, window):
705+
def next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window):
703706
time_max = time_e.shape[0] - 1
704707
time_cur = ids["time"][i_current]
705708
t0, t1 = time_cur + 1 - time_ref, min(time_cur + window - time_ref, time_max)
@@ -711,7 +714,12 @@ def next_obs(i_current, ids, polygons, time_s, time_e, time_ref, window):
711714
if i0 == i1:
712715
continue
713716
# Intersection / union, to be able to separte in case of multiple inside
714-
c = polygon_overlap(polygons[i_current], polygons[i0:i1], minimal_area=True)
717+
xi, yi, xj, yj = x[[i_current]], y[[i_current]], x[i0:i1], y[i0:i1]
718+
ii, ij = bbox_intersection(xi, yi, xj, yj)
719+
if len(ii) == 0:
720+
continue
721+
c = zeros(len(xj))
722+
c[ij] = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], minimal_area=True)
715723
# We remove low overlap
716724
c[c < 0.1] = 0
717725
# We get index of maximal overlap

0 commit comments

Comments
 (0)