Skip to content

Commit 3c92d82

Browse files
committed
Example how to run again a segmentation
1 parent 9f626f1 commit 3c92d82

File tree

1 file changed

+210
-0
lines changed

1 file changed

+210
-0
lines changed
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
"""
2+
Replay segmentation
3+
===================
4+
Case from figure 10 from https://doi.org/10.1002/2017JC013158
5+
6+
"""
7+
from datetime import datetime, timedelta
8+
9+
import numpy as np
10+
from matplotlib import pyplot as plt
11+
from matplotlib.animation import FuncAnimation
12+
from matplotlib.ticker import FuncFormatter
13+
14+
import py_eddy_tracker.gui
15+
from py_eddy_tracker.appli.gui import Anim
16+
from py_eddy_tracker.observations.network import NetworkObservations
17+
from py_eddy_tracker.observations.tracking import TrackEddiesObservations
18+
19+
20+
# %%
21+
# Function used to do quick display
22+
class VideoAnimation(FuncAnimation):
23+
def _repr_html_(self, *args, **kwargs):
24+
"""To get video in html and have a player"""
25+
return self.to_html5_video()
26+
27+
def save(self, *args, **kwargs):
28+
if args[0].endswith("gif"):
29+
# In this case gif is use to create thumbnail which are not use but consume same time than video
30+
# So we create an empty file, to save time
31+
with open(args[0], "w") as h:
32+
pass
33+
return
34+
return super().save(*args, **kwargs)
35+
36+
37+
@FuncFormatter
38+
def formatter(x, pos):
39+
return (timedelta(x) + datetime(1950, 1, 1)).strftime("%d/%m/%Y")
40+
41+
42+
def start_axes(title=""):
43+
fig = plt.figure(figsize=(13, 6))
44+
ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection="full_axes")
45+
ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)
46+
ax.set_aspect("equal")
47+
ax.set_title(title, weight="bold")
48+
ax.update_env()
49+
return ax
50+
51+
52+
def timeline_axes(title=""):
53+
fig = plt.figure(figsize=(15, 5))
54+
ax = fig.add_axes([0.04, 0.06, 0.89, 0.88])
55+
ax.set_title(title, weight="bold")
56+
ax.xaxis.set_major_formatter(formatter), ax.grid()
57+
return ax
58+
59+
60+
def update_axes(ax, mappable=None):
61+
ax.grid(True)
62+
if mappable:
63+
return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))
64+
65+
66+
# %%
67+
# Class for new_segmentation
68+
# --------------------------
69+
# The oldest win
70+
class MyTrackEddiesObservations(TrackEddiesObservations):
71+
__slots__ = tuple()
72+
73+
@classmethod
74+
def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs):
75+
"""
76+
Method to overwrite behaviour in merging.
77+
78+
We will give the point to the older one
79+
"""
80+
while i_next != -1:
81+
# Flag
82+
used[i_next] = True
83+
# Assign id
84+
ids["track"][i_next] = track_id
85+
# Search next
86+
i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs)
87+
if i_next_ == -1:
88+
break
89+
ids["next_obs"][i_next] = i_next_
90+
# Target was previously used
91+
if used[i_next_]:
92+
# if ids["next_cost"][i_next] == ids["previous_cost"][i_next_]:
93+
# print(ids[i_next])
94+
# print(ids[i_next_])
95+
# m = ids["track"][i_next_:] == ids["track"][i_next_]
96+
# ids["track"][i_next_:][m] = track_id
97+
# ids["previous_obs"][i_next_] = i_next
98+
i_next_ = -1
99+
else:
100+
ids["previous_obs"][i_next_] = i_next
101+
i_next = i_next_
102+
103+
104+
def get_obs(dataset):
105+
"Function to isolate a specific obs"
106+
return np.where(
107+
(dataset.lat > 33)
108+
* (dataset.lat < 34)
109+
* (dataset.lon > 22)
110+
* (dataset.lon < 23)
111+
* (dataset.time > 20630)
112+
* (dataset.time < 20650)
113+
)[0][0]
114+
115+
116+
# %%
117+
# Get original network, we will isolate only relative at order *2*
118+
n = NetworkObservations.load_file(
119+
"/tmp/Anticyclonic_seg.nc"
120+
# "/data/adelepoulle/work/Eddies/20201217_network_build/tracking/med/Anticyclonic_seg.nc"
121+
)
122+
123+
n = n.extract_with_mask(n.track == n.track[get_obs(n)])
124+
n_ = n.relative(get_obs(n), order=2)
125+
126+
# %%
127+
ax = start_axes(n_.infos())
128+
n_.plot(ax, color_cycle=n.COLORS)
129+
update_axes(ax)
130+
fig = plt.figure(figsize=(15, 5))
131+
ax = fig.add_axes([0.04, 0.05, 0.92, 0.92])
132+
ax.xaxis.set_major_formatter(formatter), ax.grid()
133+
_ = n_.display_timeline(ax)
134+
135+
# %%
136+
# Run a new segmentation
137+
# ----------------------
138+
e = n.astype(MyTrackEddiesObservations)
139+
e.obs.sort(order=("track", "time"), kind="stable")
140+
split_matrix = e.split_network(intern=False, window=7)
141+
n_ = NetworkObservations.from_split_network(e, split_matrix)
142+
n_ = n_.relative(get_obs(n_), order=2)
143+
n_.numbering_segment()
144+
145+
# %%
146+
# New version
147+
# -----------
148+
ax = start_axes(n_.infos())
149+
n_.plot(ax, color_cycle=n_.COLORS)
150+
update_axes(ax)
151+
fig = plt.figure(figsize=(15, 5))
152+
ax = fig.add_axes([0.04, 0.05, 0.92, 0.92])
153+
ax.xaxis.set_major_formatter(formatter), ax.grid()
154+
_ = n_.display_timeline(ax)
155+
156+
# %%
157+
# Parameter timeline
158+
# ------------------
159+
kw = dict(s=35, cmap=plt.get_cmap("Spectral_r", 8), zorder=10)
160+
ax = timeline_axes()
161+
n_.median_filter(15, "time", "latitude")
162+
m = n_.scatter_timeline(ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lat")
163+
cb = update_axes(ax, m["scatter"])
164+
cb.set_label("Effective shape error")
165+
166+
ax = timeline_axes()
167+
n_.median_filter(15, "time", "latitude")
168+
m = n_.scatter_timeline(
169+
ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lat", method="all"
170+
)
171+
cb = update_axes(ax, m["scatter"])
172+
cb.set_label("Effective shape error")
173+
ax.set_ylabel("Latitude")
174+
175+
ax = timeline_axes()
176+
n_.median_filter(15, "time", "latitude")
177+
kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30 ** 2 * 20
178+
m = n_.scatter_timeline(
179+
ax,
180+
"shape_error_e",
181+
vmin=14,
182+
vmax=70,
183+
**kw,
184+
yfield="lon",
185+
method="all",
186+
)
187+
ax.set_ylabel("Longitude")
188+
cb = update_axes(ax, m["scatter"])
189+
cb.set_label("Effective shape error")
190+
191+
# %%
192+
# Cost association plot
193+
# ---------------------
194+
n_copy = n_.copy()
195+
n_copy.median_filter(2, "time", "next_cost")
196+
for b0, b1 in [
197+
(datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008)
198+
]:
199+
200+
ref, delta = datetime(1950, 1, 1), 20
201+
b0_, b1_ = (b0 - ref).days, (b1 - ref).days
202+
ax = timeline_axes()
203+
ax.set_xlim(b0_ - delta, b1_ + delta)
204+
ax.set_ylim(0, 1)
205+
ax.axvline(b0_, color="k", lw=1.5, ls="--"), ax.axvline(
206+
b1_, color="k", lw=1.5, ls="--"
207+
)
208+
n_copy.display_timeline(ax, field="next_cost", method="all", lw=4, markersize=8)
209+
210+
n_.display_timeline(ax, field="next_cost", method="all", lw=0.5, markersize=0)

0 commit comments

Comments
 (0)