Skip to content

Commit a792e21

Browse files
committed
Add network example with particle
1 parent 930c9c6 commit a792e21

File tree

8 files changed

+435
-7
lines changed

8 files changed

+435
-7
lines changed

examples/07_cube_manipulation/pet_lavd_detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def update_axes(ax, mappable=None):
178178
_ = update_axes(ax, mappable)
179179

180180
# %%
181-
# Contourdetection
182-
# ----------------
181+
# Contour detection
182+
# -----------------
183183
# To extract contour from LAVD grid, we will used method design for SSH, with some hacks and adapted options.
184184
# It will produce false amplitude and speed.
185185
kw_ident = dict(
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""
2+
Follow particle
3+
===============
4+
5+
"""
6+
import re
7+
8+
from matplotlib import colors
9+
from matplotlib import pyplot as plt
10+
from matplotlib.animation import FuncAnimation
11+
from numba import njit
12+
from numba import types as nb_types
13+
from numpy import arange, meshgrid, ones, unique, where, zeros
14+
15+
from py_eddy_tracker import start_logger
16+
from py_eddy_tracker.appli.gui import Anim
17+
from py_eddy_tracker.data import get_path
18+
from py_eddy_tracker.dataset.grid import GridCollection
19+
from py_eddy_tracker.observations.network import NetworkObservations
20+
21+
start_logger().setLevel("ERROR")
22+
23+
24+
# %%
25+
class VideoAnimation(FuncAnimation):
26+
def _repr_html_(self, *args, **kwargs):
27+
"""To get video in html and have a player"""
28+
content = self.to_html5_video()
29+
return re.sub(
30+
r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content
31+
)
32+
33+
def save(self, *args, **kwargs):
34+
if args[0].endswith("gif"):
35+
# In this case gif is use to create thumbnail which are not use but consume same time than video
36+
# So we create an empty file, to save time
37+
with open(args[0], "w") as _:
38+
pass
39+
return
40+
return super().save(*args, **kwargs)
41+
42+
43+
# %%
44+
n = NetworkObservations.load_file(get_path("network_med.nc")).network(651)
45+
n = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269))
46+
n = n.remove_dead_end(nobs=0, ndays=10)
47+
n.numbering_segment()
48+
c = GridCollection.from_netcdf_cube(
49+
get_path("dt_med_allsat_phy_l4_2005T2.nc"),
50+
"longitude",
51+
"latitude",
52+
"time",
53+
heigth="adt",
54+
)
55+
56+
# %%
57+
# Schema
58+
# ------
59+
fig = plt.figure(figsize=(12, 6))
60+
ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])
61+
_ = n.display_timeline(ax, field="longitude", marker="+", lw=2, markersize=5)
62+
63+
# %%
64+
# Animation
65+
# ---------
66+
# Particle settings
67+
t_snapshot = 20200
68+
step = 1 / 50.0
69+
x, y = meshgrid(arange(20, 36, step), arange(30, 46, step))
70+
N = 6
71+
x_f, y_f = x[::N, ::N].copy(), y[::N, ::N].copy()
72+
x, y = x.reshape(-1), y.reshape(-1)
73+
x_f, y_f = x_f.reshape(-1), y_f.reshape(-1)
74+
n_ = n.extract_with_mask(n.time == t_snapshot)
75+
index = n_.contains(x, y, intern=True)
76+
m = index != -1
77+
index = n_.segment[index[m]]
78+
index_ = unique(index)
79+
x, y = x[m], y[m]
80+
m = ~n_.inside(x_f, y_f, intern=True)
81+
x_f, y_f = x_f[m], y_f[m]
82+
83+
# %%
84+
# Animation
85+
cmap = colors.ListedColormap(list(n.COLORS), name="from_list", N=n.segment.max() + 1)
86+
a = Anim(
87+
n,
88+
intern=False,
89+
figsize=(12, 6),
90+
nb_step=1,
91+
dpi=60,
92+
field_color="segment",
93+
field_txt="segment",
94+
cmap=cmap,
95+
)
96+
a.fig.suptitle(""), a.ax.set_xlim(24, 36), a.ax.set_ylim(30, 36)
97+
a.txt.set_position((25, 31))
98+
99+
step = 0.25
100+
kw_p = dict(nb_step=2, time_step=86400 * step * 0.5, t_init=t_snapshot - 2 * step)
101+
102+
mappables = dict()
103+
particules = c.advect(x, y, "u", "v", **kw_p)
104+
filament = c.filament(x_f, y_f, "u", "v", **kw_p, filament_size=3)
105+
kw = dict(ls="", marker=".", markersize=0.25)
106+
for k in index_:
107+
m = k == index
108+
mappables[k] = a.ax.plot([], [], color=cmap(k), **kw)[0]
109+
m_filament = a.ax.plot([], [], lw=0.25, color="gray")[0]
110+
111+
112+
def update(frame):
113+
tt, xt, yt = particules.__next__()
114+
for k, mappable in mappables.items():
115+
m = index == k
116+
mappable.set_data(xt[m], yt[m])
117+
tt, xt, yt = filament.__next__()
118+
m_filament.set_data(xt, yt)
119+
if frame % 1 == 0:
120+
a.func_animation(frame)
121+
122+
123+
ani = VideoAnimation(a.fig, update, frames=arange(20200, 20269, step), interval=200)
124+
125+
126+
# %%
127+
# In which observations are the particle
128+
# --------------------------------------
129+
def advect(x, y, c, t0, delta_t):
130+
"""
131+
Advect particle from t0 to t0 + delta_t, with data cube.
132+
"""
133+
kw = dict(nb_step=6, time_step=86400 / 6)
134+
if delta_t < 0:
135+
kw["backward"] = True
136+
delta_t = -delta_t
137+
p = c.advect(x, y, "u", "v", t_init=t0, **kw)
138+
for _ in range(delta_t):
139+
t, x, y = p.__next__()
140+
return t, x, y
141+
142+
143+
def particle_candidate(x, y, c, eddies, t_start, i_target, pct, **kwargs):
144+
# Obs from initial time
145+
m_start = eddies.time == t_start
146+
e = eddies.extract_with_mask(m_start)
147+
# to be able to get global index
148+
translate_start = where(m_start)[0]
149+
# Identify particle in eddies(only in core)
150+
i_start = e.contains(x, y, intern=True)
151+
m = i_start != -1
152+
x, y, i_start = x[m], y[m], i_start[m]
153+
# Advect
154+
t_end, x, y = advect(x, y, c, t_start, **kwargs)
155+
# eddies at last date
156+
m_end = eddies.time == t_end / 86400
157+
e_end = eddies.extract_with_mask(m_end)
158+
# to be able to get global index
159+
translate_end = where(m_end)[0]
160+
# Id eddies for each alive particle(in core and extern)
161+
i_end = e_end.contains(x, y)
162+
# compute matrix and filled target array
163+
get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct)
164+
165+
166+
@njit(cache=True)
167+
def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct):
168+
nb_start, nb_end = translate_start.size, translate_end.size
169+
# Matrix which will store count for every couple
170+
count = zeros((nb_start, nb_end), dtype=nb_types.int32)
171+
# Number of particle in each origin observation
172+
ref = zeros(nb_start, dtype=nb_types.int32)
173+
# For each particle
174+
for i in range(i_start.size):
175+
i_end_ = i_end[i]
176+
i_start_ = i_start[i]
177+
if i_end_ != -1:
178+
count[i_start_, i_end_] += 1
179+
ref[i_start_] += 1
180+
for i in range(nb_start):
181+
for j in range(nb_end):
182+
pct_ = count[i, j]
183+
# If there are particle from i to j
184+
if pct_ != 0:
185+
# Get percent
186+
pct_ = pct_ / ref[i] * 100.0
187+
# Get indices in full dataset
188+
i_, j_ = translate_start[i], translate_end[j]
189+
pct_0 = pct[i_, 0]
190+
if pct_ > pct_0:
191+
pct[i_, 1] = pct_0
192+
pct[i_, 0] = pct_
193+
i_target[i_, 1] = i_target[i_, 0]
194+
i_target[i_, 0] = j_
195+
elif pct_ > pct[i_, 1]:
196+
pct[i_, 1] = pct_
197+
i_target[i_, 1] = j_
198+
return i_target, pct
199+
200+
201+
# %%
202+
# Particle advection
203+
# ^^^^^^^^^^^^^^^^^^
204+
step = 1 / 60.0
205+
206+
x, y = meshgrid(arange(20, 36, step), arange(30, 46, step))
207+
x0, y0 = x.reshape(-1), y.reshape(-1)
208+
t_start, t_end = n.period
209+
dt = 14
210+
211+
shape = (n.obs.size, 2)
212+
# Forward run
213+
i_target_f, pct_target_f = -ones(shape, dtype="i4"), zeros(shape, dtype="i1")
214+
for t in range(t_start, t_end - dt):
215+
particle_candidate(x0, y0, c, n, t, i_target_f, pct_target_f, delta_t=dt)
216+
217+
# Backward run
218+
i_target_b, pct_target_b = -ones(shape, dtype="i4"), zeros(shape, dtype="i1")
219+
for t in range(t_start + dt, t_end):
220+
particle_candidate(x0, y0, c, n, t, i_target_b, pct_target_b, delta_t=-dt)
221+
222+
# %%
223+
fig = plt.figure(figsize=(10, 10))
224+
ax_1st_b = fig.add_axes([0.05, 0.52, 0.45, 0.45])
225+
ax_2nd_b = fig.add_axes([0.05, 0.05, 0.45, 0.45])
226+
ax_1st_f = fig.add_axes([0.52, 0.52, 0.45, 0.45])
227+
ax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45])
228+
ax_1st_b.set_title("Backward advection for each time step")
229+
ax_1st_f.set_title("Forward advection for each time step")
230+
231+
232+
def color_alpha(target, pct, vmin=5, vmax=80):
233+
color = cmap(n.segment[target])
234+
# We will hide under 5 % and from 80% to 100 % it will be 1
235+
alpha = (pct - vmin) / (vmax - vmin)
236+
alpha[alpha < 0] = 0
237+
alpha[alpha > 1] = 1
238+
color[:, 3] = alpha
239+
return color
240+
241+
242+
kw = dict(
243+
name=None, yfield="longitude", event=False, zorder=-100, s=(n.speed_area / 20e6)
244+
)
245+
n.scatter_timeline(ax_1st_b, c=color_alpha(i_target_b.T[0], pct_target_b.T[0]), **kw)
246+
n.scatter_timeline(ax_2nd_b, c=color_alpha(i_target_b.T[1], pct_target_b.T[1]), **kw)
247+
n.scatter_timeline(ax_1st_f, c=color_alpha(i_target_f.T[0], pct_target_f.T[0]), **kw)
248+
n.scatter_timeline(ax_2nd_f, c=color_alpha(i_target_f.T[1], pct_target_f.T[1]), **kw)
249+
for ax in (ax_1st_b, ax_2nd_b, ax_1st_f, ax_2nd_f):
250+
n.display_timeline(ax, field="longitude", marker="+", lw=2, markersize=5)
251+
ax.grid()

examples/16_network/pet_segmentation_anim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_obs(dataset):
4848

4949

5050
# %%
51-
# Overlaod of class to pick up
51+
# Hack to pick up each step of segmentation
5252
TRACKS = list()
5353
INDICES = list()
5454

notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@
163163
"cell_type": "markdown",
164164
"metadata": {},
165165
"source": [
166-
"## Contourdetection\nTo extract contour from LAVD grid, we will used method design for SSH, with some hacks and adapted options.\nIt will produce false amplitude and speed.\n\n"
166+
"## Contour detection\nTo extract contour from LAVD grid, we will used method design for SSH, with some hacks and adapted options.\nIt will produce false amplitude and speed.\n\n"
167167
]
168168
},
169169
{

0 commit comments

Comments
 (0)