|
| 1 | +import numpy as np |
| 2 | +from datetime import datetime |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +from matplotlib.projections import register_projection |
| 5 | +from matplotlib.widgets import Slider |
| 6 | +import py_eddy_tracker_sample as sample |
| 7 | +from .generic import flatten_line_matrix, split_line |
| 8 | +from .observations.tracking import TrackEddiesObservations |
| 9 | + |
| 10 | + |
| 11 | +try: |
| 12 | + from pylook.axes import PlatCarreAxes |
| 13 | +except: |
| 14 | + from matplotlib.axes import Axes |
| 15 | + |
| 16 | + class PlatCarreAxes(Axes): |
| 17 | + def __init__(self, *args, **kwargs): |
| 18 | + super().__init__(*args, **kwargs) |
| 19 | + self.set_aspect("equal") |
| 20 | + |
| 21 | + |
| 22 | +class GUIAxes(PlatCarreAxes): |
| 23 | + name = "full_axes" |
| 24 | + |
| 25 | + def end_pan(self, *args, **kwargs): |
| 26 | + (x0, x1), (y0, y1) = self.get_xlim(), self.get_ylim() |
| 27 | + x, y = (x1 + x0) / 2, (y1 + y0) / 2 |
| 28 | + dx, dy = (x1 - x0) / 2.0, (y1 - y0) / 2.0 |
| 29 | + r_coord = dx / dy |
| 30 | + # r_axe |
| 31 | + _, _, w_ax, h_ax = self.get_position(original=True).bounds |
| 32 | + w_fig, h_fig = self.figure.get_size_inches() |
| 33 | + r_ax = w_ax / h_ax * w_fig / h_fig |
| 34 | + if r_ax < r_coord: |
| 35 | + y0, y1 = y - dx / r_ax, y + dx / r_ax |
| 36 | + self.set_ylim(y0, y1) |
| 37 | + else: |
| 38 | + x0, x1 = x - dy * r_ax, x + dy * r_ax |
| 39 | + self.set_xlim(x0, x1) |
| 40 | + super().end_pan(*args, **kwargs) |
| 41 | + |
| 42 | + |
| 43 | +register_projection(GUIAxes) |
| 44 | + |
| 45 | + |
| 46 | +class A(TrackEddiesObservations): |
| 47 | + pass |
| 48 | + |
| 49 | + |
| 50 | +def no(*args, **kwargs): |
| 51 | + return False |
| 52 | + |
| 53 | + |
| 54 | +class GUI: |
| 55 | + __slots__ = ( |
| 56 | + "datasets", |
| 57 | + "figure", |
| 58 | + "map", |
| 59 | + "time_ax", |
| 60 | + "param_ax", |
| 61 | + "settings", |
| 62 | + "m", |
| 63 | + "last_event", |
| 64 | + ) |
| 65 | + COLORS = ("r", "g", "b", "y", "k") |
| 66 | + KEYTIME = dict(down=-1, up=1, pagedown=-5, pageup=5) |
| 67 | + |
| 68 | + def __init__(self, **datasets): |
| 69 | + self.datasets = datasets |
| 70 | + self.m = dict() |
| 71 | + self.set_initial_values() |
| 72 | + self.setup() |
| 73 | + self.last_event = datetime.now() |
| 74 | + self.draw() |
| 75 | + self.event() |
| 76 | + |
| 77 | + def set_initial_values(self): |
| 78 | + t0, t1 = 1e6, 0 |
| 79 | + for dataset in self.datasets.values(): |
| 80 | + t0_, t1_ = dataset.period |
| 81 | + t0, t1 = min(t0, t0_), max(t1, t1_) |
| 82 | + |
| 83 | + self.settings = dict(period=(t0, t1), now=20000,) |
| 84 | + # self.settings = dict(period=(t0, t1), now=t0,) |
| 85 | + |
| 86 | + @property |
| 87 | + def now(self): |
| 88 | + return self.settings["now"] |
| 89 | + |
| 90 | + @property |
| 91 | + def period(self): |
| 92 | + return self.settings["period"] |
| 93 | + |
| 94 | + @property |
| 95 | + def bbox(self): |
| 96 | + return self.map.get_xlim(), self.map.get_ylim() |
| 97 | + |
| 98 | + def indexs(self, dataset): |
| 99 | + (x0, x1), (y0, y1) = self.bbox |
| 100 | + x, y = dataset.longitude, dataset.latitude |
| 101 | + m = (x0 < x) & (x < x1) & (y0 < y) & (y < y1) & (self.now == dataset.time) |
| 102 | + return np.where(m)[0] |
| 103 | + |
| 104 | + def med(self): |
| 105 | + self.map.set_xlim(-6, 37) |
| 106 | + self.map.set_ylim(30, 46) |
| 107 | + |
| 108 | + def setup(self): |
| 109 | + self.figure = plt.figure() |
| 110 | + # map |
| 111 | + self.map = self.figure.add_axes((0, 0.25, 1, 0.75), projection="full_axes") |
| 112 | + self.map.grid() |
| 113 | + self.map.tick_params("x", pad=-12) |
| 114 | + self.map.tick_params("y", pad=-22) |
| 115 | + # time ax |
| 116 | + self.time_ax = self.figure.add_axes((0, 0.15, 1, 0.1), facecolor=".95") |
| 117 | + self.time_ax.can_pan |
| 118 | + self.time_ax.set_xlim(*self.period) |
| 119 | + self.time_ax.press = False |
| 120 | + self.time_ax.can_pan = self.time_ax.can_zoom = no |
| 121 | + for i, dataset in enumerate(self.datasets.values()): |
| 122 | + self.time_ax.hist( |
| 123 | + dataset.time, |
| 124 | + bins=np.arange(self.period[0] - 0.5, self.period[1] + 0.51), |
| 125 | + color=self.COLORS[i], |
| 126 | + histtype="step", |
| 127 | + ) |
| 128 | + # param |
| 129 | + self.param_ax = self.figure.add_axes((0, 0, 1, 0.15), facecolor="0.2") |
| 130 | + |
| 131 | + def draw(self): |
| 132 | + # map |
| 133 | + for i, (name, dataset) in enumerate(self.datasets.items()): |
| 134 | + self.m[name] = dict( |
| 135 | + contour_s=self.map.plot( |
| 136 | + [], [], color=self.COLORS[i], lw=0.5, label=name |
| 137 | + )[0], |
| 138 | + contour_e=self.map.plot([], [], color=self.COLORS[i], lw=0.5)[0], |
| 139 | + path_previous=self.map.plot([], [], color=self.COLORS[i], lw=0.5)[0], |
| 140 | + path_future=self.map.plot([], [], color=self.COLORS[i], lw=0.2, ls=":")[ |
| 141 | + 0 |
| 142 | + ], |
| 143 | + ) |
| 144 | + self.m["title"] = self.map.set_title("") |
| 145 | + # time_ax |
| 146 | + self.m["time_vline"] = self.time_ax.axvline(0, color="k", lw=1) |
| 147 | + self.m["time_text"] = self.time_ax.text( |
| 148 | + 0, 0, "", fontsize=8, bbox=dict(facecolor="w", alpha=0.75) |
| 149 | + ) |
| 150 | + |
| 151 | + def update(self): |
| 152 | + # text = [] |
| 153 | + # map |
| 154 | + xs, ys, ns = list(), list(), list() |
| 155 | + for j, (name, dataset) in enumerate(self.datasets.items()): |
| 156 | + i = self.indexs(dataset) |
| 157 | + self.m[name]["contour_s"].set_label(f"{name} {len(i)} eddies") |
| 158 | + if len(i) == 0: |
| 159 | + self.m[name]["contour_s"].set_data([], []) |
| 160 | + else: |
| 161 | + self.m[name]["contour_s"].set_data( |
| 162 | + flatten_line_matrix(dataset["contour_lon_s"][i]), |
| 163 | + flatten_line_matrix(dataset["contour_lat_s"][i]), |
| 164 | + ) |
| 165 | + # text.append(f"{i.shape[0]}") |
| 166 | + local_path = dataset.extract_ids(dataset["track"][i]) |
| 167 | + x, y, t, n, tr = ( |
| 168 | + local_path.longitude, |
| 169 | + local_path.latitude, |
| 170 | + local_path.time, |
| 171 | + local_path["n"], |
| 172 | + local_path["track"], |
| 173 | + ) |
| 174 | + m = t <= self.now |
| 175 | + if m.sum(): |
| 176 | + x_, y_ = split_line(x[m], y[m], tr[m]) |
| 177 | + self.m[name]["path_previous"].set_data(x_, y_) |
| 178 | + else: |
| 179 | + self.m[name]["path_previous"].set_data([], []) |
| 180 | + m = t >= self.now |
| 181 | + if m.sum(): |
| 182 | + x_, y_ = split_line(x[m], y[m], tr[m]) |
| 183 | + self.m[name]["path_future"].set_data(x_, y_) |
| 184 | + else: |
| 185 | + self.m[name]["path_future"].set_data([], []) |
| 186 | + m = t == self.now |
| 187 | + xs.append(x[m]), ys.append(y[m]), ns.append(n[m]) |
| 188 | + |
| 189 | + x, y, n = np.concatenate(xs), np.concatenate(ys), np.concatenate(ns) |
| 190 | + n_min = 0 |
| 191 | + if len(n) > 50: |
| 192 | + n_ = n.copy() |
| 193 | + n_.sort() |
| 194 | + n_min = n_[-50] |
| 195 | + for text in self.m.pop("texts", list()): |
| 196 | + text.remove() |
| 197 | + self.m["texts"] = [ |
| 198 | + self.map.text(x_, y_, n_) for x_, y_, n_ in zip(x, y, n) if n_ >= n_min |
| 199 | + ] |
| 200 | + |
| 201 | + self.m["title"].set_text(self.now) |
| 202 | + self.map.legend() |
| 203 | + # time ax |
| 204 | + x, y = self.m["time_vline"].get_data() |
| 205 | + self.m["time_vline"].set_data(self.now, y) |
| 206 | + # self.m["time_text"].set_text("\n".join(text)) |
| 207 | + self.m["time_text"].set_position((self.now, 0)) |
| 208 | + # force update |
| 209 | + self.map.figure.canvas.draw() |
| 210 | + |
| 211 | + def event(self): |
| 212 | + self.figure.canvas.mpl_connect("resize_event", self.adjust) |
| 213 | + self.figure.canvas.mpl_connect("scroll_event", self.scroll) |
| 214 | + self.figure.canvas.mpl_connect("button_press_event", self.press) |
| 215 | + self.figure.canvas.mpl_connect("motion_notify_event", self.move) |
| 216 | + self.figure.canvas.mpl_connect("button_release_event", self.release) |
| 217 | + self.figure.canvas.mpl_connect("key_press_event", self.keyboard) |
| 218 | + |
| 219 | + def keyboard(self, event): |
| 220 | + if event.key in self.KEYTIME: |
| 221 | + self.settings["now"] += self.KEYTIME[event.key] |
| 222 | + self.update() |
| 223 | + elif event.key == "home": |
| 224 | + self.settings["now"] = self.period[0] |
| 225 | + self.update() |
| 226 | + elif event.key == "end": |
| 227 | + self.settings["now"] = self.period[1] |
| 228 | + self.update() |
| 229 | + |
| 230 | + def press(self, event): |
| 231 | + if event.inaxes == self.time_ax and self.m["time_vline"].contains(event)[0]: |
| 232 | + self.time_ax.press = True |
| 233 | + self.time_ax.bg_cache = self.figure.canvas.copy_from_bbox(self.time_ax.bbox) |
| 234 | + |
| 235 | + def move(self, event): |
| 236 | + if event.inaxes == self.time_ax and self.time_ax.press: |
| 237 | + x, y = self.m["time_vline"].get_data() |
| 238 | + self.m["time_vline"].set_data(event.xdata, y) |
| 239 | + self.figure.canvas.restore_region(self.time_ax.bg_cache) |
| 240 | + self.time_ax.draw_artist(self.m["time_vline"]) |
| 241 | + self.figure.canvas.blit(self.time_ax.bbox) |
| 242 | + |
| 243 | + def release(self, event): |
| 244 | + if self.time_ax.press: |
| 245 | + self.time_ax.press = False |
| 246 | + self.settings["now"] = int(self.m["time_vline"].get_data()[0]) |
| 247 | + self.update() |
| 248 | + |
| 249 | + def scroll(self, event): |
| 250 | + if event.inaxes != self.map: |
| 251 | + return |
| 252 | + if event.button == "up": |
| 253 | + self.settings["now"] += 1 |
| 254 | + if event.button == "down": |
| 255 | + self.settings["now"] -= 1 |
| 256 | + self.update() |
| 257 | + |
| 258 | + def adjust(self, event=None): |
| 259 | + self.map._pan_start = None |
| 260 | + self.map.end_pan() |
| 261 | + |
| 262 | + def show(self): |
| 263 | + self.update() |
| 264 | + plt.show() |
| 265 | + |
| 266 | + |
| 267 | +if __name__ == "__main__": |
| 268 | + |
| 269 | + # a_ = A.load_file( |
| 270 | + # "/home/toto/dev/work/pet/20200611_example_dataset/tracking/Anticyclonic_track_too_short.nc" |
| 271 | + # ) |
| 272 | + # c_ = A.load_file( |
| 273 | + # "/home/toto/dev/work/pet/20200611_example_dataset/tracking/Cyclonic_track_too_short.nc" |
| 274 | + # ) |
| 275 | + a = A.load_file(sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr")) |
| 276 | + # c = A.load_file(sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr")) |
| 277 | + # g = GUI(Acyc=a, Cyc=c, Acyc_short=a_, Cyc_short=c_) |
| 278 | + g = GUI(Acyc=a) |
| 279 | + # g = GUI(Acyc_short=a_) |
| 280 | + # g = GUI(Acyc_short=a_, Cyc_short=c_) |
| 281 | + g.med() |
| 282 | + g.show() |
0 commit comments