Skip to content

Commit 5f5a622

Browse files
committed
- add method to map method on segment
- allow to appy color_cycle on map - add next_cost and previous_cost in variable
1 parent 482b7fd commit 5f5a622

File tree

3 files changed

+109
-17
lines changed

3 files changed

+109
-17
lines changed

src/py_eddy_tracker/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,26 @@ def parse_args(self, *args, **kwargs):
403403
comment="Index of next obs, if there are a merging",
404404
),
405405
),
406+
previous_cost=dict(
407+
attr_name=None,
408+
nc_name="previous_cost",
409+
nc_type="float32",
410+
nc_dims=("obs",),
411+
nc_attr=dict(
412+
long_name="Previous cost for previous obs",
413+
comment="",
414+
),
415+
),
416+
next_cost=dict(
417+
attr_name=None,
418+
nc_name="next_cost",
419+
nc_type="float32",
420+
nc_dims=("obs",),
421+
nc_attr=dict(
422+
long_name="Next cost for next obs",
423+
comment="",
424+
),
425+
),
406426
n=dict(
407427
attr_name=None,
408428
nc_name="observation_number",

src/py_eddy_tracker/observations/network.py

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,27 @@ class NetworkObservations(EddiesObservations):
6868
@property
6969
def elements(self):
7070
elements = super().elements
71-
elements.extend(["track", "segment", "next_obs", "previous_obs"])
71+
elements.extend(
72+
[
73+
"track",
74+
"segment",
75+
"next_obs",
76+
"previous_obs",
77+
"next_cost",
78+
"previous_cost",
79+
]
80+
)
7281
return list(set(elements))
7382

83+
def astype(self, cls):
84+
new = cls.new_like(self, self.shape)
85+
print()
86+
for k in new.obs.dtype.names:
87+
if k in self.obs.dtype.names:
88+
new[k][:] = self[k][:]
89+
new.sign_type = self.sign_type
90+
return new
91+
7492
def longer_than(self, nb_day_min=-1, nb_day_max=-1):
7593
"""
7694
Select network on time duration
@@ -81,7 +99,7 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
8199
if nb_day_max < 0:
82100
nb_day_max = 1000000000000
83101
mask = zeros(self.shape, dtype="bool")
84-
for i, b0, b1 in self.iter_on(self.segment_track_array()):
102+
for i, b0, b1 in self.iter_on(self.segment_track_array):
85103
nb = i.stop - i.start
86104
if nb == 0:
87105
continue
@@ -115,6 +133,8 @@ def from_split_network(cls, group_dataset, indexs, **kwargs):
115133
translate[index_order] = arange(index_order.shape[0])
116134
network.next_obs[:] = translate[n]
117135
network.previous_obs[:] = translate[p]
136+
network.next_cost[:] = indexs["next_cost"][index_order]
137+
network.previous_cost[:] = indexs["previous_cost"][index_order]
118138
return network
119139

120140
def infos(self, label=""):
@@ -205,7 +225,7 @@ def position_filter(self, median_half_window, loess_half_window):
205225

206226
def loess_filter(self, half_window, xfield, yfield, inplace=True):
207227
result = track_loess_filter(
208-
half_window, self.obs[xfield], self.obs[yfield], self.segment_track_array()
228+
half_window, self.obs[xfield], self.obs[yfield], self.segment_track_array
209229
)
210230
if inplace:
211231
self.obs[yfield] = result
@@ -214,7 +234,7 @@ def loess_filter(self, half_window, xfield, yfield, inplace=True):
214234

215235
def median_filter(self, half_window, xfield, yfield, inplace=True):
216236
result = track_median_filter(
217-
half_window, self[xfield], self[yfield], self.segment_track_array()
237+
half_window, self[xfield], self[yfield], self.segment_track_array
218238
)
219239
if inplace:
220240
self[yfield][:] = result
@@ -316,18 +336,59 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
316336
j += 1
317337
return mappables
318338

319-
def scatter_timeline(self, ax, name, factor=1, event=True, **kwargs):
339+
def mean_by_segment(self, y, **kw):
340+
kw["dtype"] = y.dtype
341+
return self.map_segment(lambda x: x.mean(), y, **kw)
342+
343+
def map_segment(self, method, y, same=True, **kw):
344+
if same:
345+
out = empty(y.shape, **kw)
346+
else:
347+
out = list()
348+
for i, b0, b1 in self.iter_on(self.segment_track_array):
349+
res = method(y[i])
350+
if same:
351+
out[i] = res
352+
else:
353+
if isinstance(i, slice):
354+
if i.start == i.stop:
355+
continue
356+
elif len(i) == 0:
357+
continue
358+
out.append(res)
359+
if not same:
360+
out = array(out)
361+
return out
362+
363+
def scatter_timeline(
364+
self,
365+
ax,
366+
name,
367+
factor=1,
368+
event=True,
369+
yfield=None,
370+
yfactor=1,
371+
method=None,
372+
**kwargs,
373+
):
320374
"""
321375
Must be call on only one network
322376
"""
323377
self.only_one_network()
378+
y = (self.segment if yfield is None else self[yfield]) * yfactor
379+
if method == "all":
380+
pass
381+
else:
382+
y = self.mean_by_segment(y)
324383
mappables = dict()
325384
if event:
326-
mappables.update(self.event_timeline(ax))
385+
mappables.update(
386+
self.event_timeline(ax, field=yfield, method=method, factor=yfactor)
387+
)
327388
if "c" not in kwargs:
328389
v = self.parse_varname(name)
329390
kwargs["c"] = v * factor
330-
mappables["scatter"] = ax.scatter(self.time, self.segment, **kwargs)
391+
mappables["scatter"] = ax.scatter(self.time, y, **kwargs)
331392
return mappables
332393

333394
def insert_virtual(self):
@@ -350,13 +411,14 @@ def extract_event(self, indices):
350411
new.sign_type = self.sign_type
351412
return new
352413

414+
@property
353415
def segment_track_array(self):
354416
return build_unique_array(self.segment, self.track)
355417

356418
def birth_event(self):
357419
# FIXME how to manage group 0
358420
indices = list()
359-
for i, _, _ in self.iter_on(self.segment_track_array()):
421+
for i, _, _ in self.iter_on(self.segment_track_array):
360422
nb = i.stop - i.start
361423
if nb == 0:
362424
continue
@@ -368,7 +430,7 @@ def birth_event(self):
368430
def death_event(self):
369431
# FIXME how to manage group 0
370432
indices = list()
371-
for i, _, _ in self.iter_on(self.segment_track_array()):
433+
for i, _, _ in self.iter_on(self.segment_track_array):
372434
nb = i.stop - i.start
373435
if nb == 0:
374436
continue
@@ -379,7 +441,7 @@ def death_event(self):
379441

380442
def merging_event(self):
381443
indices = list()
382-
for i, _, _ in self.iter_on(self.segment_track_array()):
444+
for i, _, _ in self.iter_on(self.segment_track_array):
383445
nb = i.stop - i.start
384446
if nb == 0:
385447
continue
@@ -390,7 +452,7 @@ def merging_event(self):
390452

391453
def spliting_event(self):
392454
indices = list()
393-
for i, _, _ in self.iter_on(self.segment_track_array()):
455+
for i, _, _ in self.iter_on(self.segment_track_array):
394456
nb = i.stop - i.start
395457
if nb == 0:
396458
continue
@@ -403,7 +465,7 @@ def fully_connected(self):
403465
self.only_one_network()
404466
# TODO
405467

406-
def plot(self, ax, ref=None, **kwargs):
468+
def plot(self, ax, ref=None, color_cycle=None, **kwargs):
407469
"""
408470
This function will draw path of each trajectory
409471
@@ -412,17 +474,25 @@ def plot(self, ax, ref=None, **kwargs):
412474
:param dict kwargs: keyword arguments for Axes.plot
413475
:return: a list of matplotlib mappables
414476
"""
477+
nb_colors = 0
478+
if color_cycle is not None:
479+
kwargs = kwargs.copy()
480+
nb_colors = len(color_cycle)
415481
mappables = list()
416482
if "label" in kwargs:
417483
kwargs["label"] = self.format_label(kwargs["label"])
418-
for i, b0, b1 in self.iter_on("segment"):
484+
j = 0
485+
for i, _, _ in self.iter_on("segment"):
419486
nb = i.stop - i.start
420487
if nb == 0:
421488
continue
489+
if nb_colors:
490+
kwargs["color"] = color_cycle[j % nb_colors]
422491
x, y = self.lon[i], self.lat[i]
423492
if ref is not None:
424493
x, y = wrap_longitude(x, y, ref, cut=True)
425494
mappables.append(ax.plot(x, y, **kwargs)[0])
495+
j += 1
426496
return mappables
427497

428498
def remove_dead_branch(self, nobs=3):

src/py_eddy_tracker/observations/tracking.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def set_tracks(self, x, y, ids, window, **kwargs):
678678
self.follow_obs(i, track_id, used, ids, x, y, *time_index, window, **kwargs)
679679
track_id += 1
680680
# Search a possible ancestor (backward)
681-
self.previous_obs(i, ids, x, y, *time_index, window, **kwargs)
681+
self.get_previous_obs(i, ids, x, y, *time_index, window, **kwargs)
682682

683683
@classmethod
684684
def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs):
@@ -690,7 +690,7 @@ def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs):
690690
# Assign id
691691
ids["track"][i_next] = track_id
692692
# Search next
693-
i_next_ = cls.next_obs(i_next, ids, *args, **kwargs)
693+
i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs)
694694
if i_next_ == -1:
695695
break
696696
ids["next_obs"][i_next] = i_next_
@@ -706,7 +706,9 @@ def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs):
706706
i_next = i_next_
707707

708708
@staticmethod
709-
def previous_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):
709+
def get_previous_obs(
710+
i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs
711+
):
710712
"""Backward association of observations to the segments"""
711713

712714
time_cur = ids["time"][i_current]
@@ -736,7 +738,7 @@ def previous_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwarg
736738
break
737739

738740
@staticmethod
739-
def next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):
741+
def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):
740742
"""Forward association of observations to the segments"""
741743
time_max = time_e.shape[0] - 1
742744
time_cur = ids["time"][i_current]

0 commit comments

Comments
 (0)