@@ -68,9 +68,27 @@ class NetworkObservations(EddiesObservations):
68
68
@property
69
69
def elements (self ):
70
70
elements = super ().elements
71
- elements .extend (["track" , "segment" , "next_obs" , "previous_obs" ])
71
+ elements .extend (
72
+ [
73
+ "track" ,
74
+ "segment" ,
75
+ "next_obs" ,
76
+ "previous_obs" ,
77
+ "next_cost" ,
78
+ "previous_cost" ,
79
+ ]
80
+ )
72
81
return list (set (elements ))
73
82
83
+ def astype (self , cls ):
84
+ new = cls .new_like (self , self .shape )
85
+ print ()
86
+ for k in new .obs .dtype .names :
87
+ if k in self .obs .dtype .names :
88
+ new [k ][:] = self [k ][:]
89
+ new .sign_type = self .sign_type
90
+ return new
91
+
74
92
def longer_than (self , nb_day_min = - 1 , nb_day_max = - 1 ):
75
93
"""
76
94
Select network on time duration
@@ -81,7 +99,7 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
81
99
if nb_day_max < 0 :
82
100
nb_day_max = 1000000000000
83
101
mask = zeros (self .shape , dtype = "bool" )
84
- for i , b0 , b1 in self .iter_on (self .segment_track_array () ):
102
+ for i , b0 , b1 in self .iter_on (self .segment_track_array ):
85
103
nb = i .stop - i .start
86
104
if nb == 0 :
87
105
continue
@@ -115,6 +133,8 @@ def from_split_network(cls, group_dataset, indexs, **kwargs):
115
133
translate [index_order ] = arange (index_order .shape [0 ])
116
134
network .next_obs [:] = translate [n ]
117
135
network .previous_obs [:] = translate [p ]
136
+ network .next_cost [:] = indexs ["next_cost" ][index_order ]
137
+ network .previous_cost [:] = indexs ["previous_cost" ][index_order ]
118
138
return network
119
139
120
140
def infos (self , label = "" ):
@@ -205,7 +225,7 @@ def position_filter(self, median_half_window, loess_half_window):
205
225
206
226
def loess_filter (self , half_window , xfield , yfield , inplace = True ):
207
227
result = track_loess_filter (
208
- half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array ()
228
+ half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array
209
229
)
210
230
if inplace :
211
231
self .obs [yfield ] = result
@@ -214,7 +234,7 @@ def loess_filter(self, half_window, xfield, yfield, inplace=True):
214
234
215
235
def median_filter (self , half_window , xfield , yfield , inplace = True ):
216
236
result = track_median_filter (
217
- half_window , self [xfield ], self [yfield ], self .segment_track_array ()
237
+ half_window , self [xfield ], self [yfield ], self .segment_track_array
218
238
)
219
239
if inplace :
220
240
self [yfield ][:] = result
@@ -316,18 +336,59 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
316
336
j += 1
317
337
return mappables
318
338
319
- def scatter_timeline (self , ax , name , factor = 1 , event = True , ** kwargs ):
339
+ def mean_by_segment (self , y , ** kw ):
340
+ kw ["dtype" ] = y .dtype
341
+ return self .map_segment (lambda x : x .mean (), y , ** kw )
342
+
343
+ def map_segment (self , method , y , same = True , ** kw ):
344
+ if same :
345
+ out = empty (y .shape , ** kw )
346
+ else :
347
+ out = list ()
348
+ for i , b0 , b1 in self .iter_on (self .segment_track_array ):
349
+ res = method (y [i ])
350
+ if same :
351
+ out [i ] = res
352
+ else :
353
+ if isinstance (i , slice ):
354
+ if i .start == i .stop :
355
+ continue
356
+ elif len (i ) == 0 :
357
+ continue
358
+ out .append (res )
359
+ if not same :
360
+ out = array (out )
361
+ return out
362
+
363
+ def scatter_timeline (
364
+ self ,
365
+ ax ,
366
+ name ,
367
+ factor = 1 ,
368
+ event = True ,
369
+ yfield = None ,
370
+ yfactor = 1 ,
371
+ method = None ,
372
+ ** kwargs ,
373
+ ):
320
374
"""
321
375
Must be call on only one network
322
376
"""
323
377
self .only_one_network ()
378
+ y = (self .segment if yfield is None else self [yfield ]) * yfactor
379
+ if method == "all" :
380
+ pass
381
+ else :
382
+ y = self .mean_by_segment (y )
324
383
mappables = dict ()
325
384
if event :
326
- mappables .update (self .event_timeline (ax ))
385
+ mappables .update (
386
+ self .event_timeline (ax , field = yfield , method = method , factor = yfactor )
387
+ )
327
388
if "c" not in kwargs :
328
389
v = self .parse_varname (name )
329
390
kwargs ["c" ] = v * factor
330
- mappables ["scatter" ] = ax .scatter (self .time , self . segment , ** kwargs )
391
+ mappables ["scatter" ] = ax .scatter (self .time , y , ** kwargs )
331
392
return mappables
332
393
333
394
def insert_virtual (self ):
@@ -350,13 +411,14 @@ def extract_event(self, indices):
350
411
new .sign_type = self .sign_type
351
412
return new
352
413
414
+ @property
353
415
def segment_track_array (self ):
354
416
return build_unique_array (self .segment , self .track )
355
417
356
418
def birth_event (self ):
357
419
# FIXME how to manage group 0
358
420
indices = list ()
359
- for i , _ , _ in self .iter_on (self .segment_track_array () ):
421
+ for i , _ , _ in self .iter_on (self .segment_track_array ):
360
422
nb = i .stop - i .start
361
423
if nb == 0 :
362
424
continue
@@ -368,7 +430,7 @@ def birth_event(self):
368
430
def death_event (self ):
369
431
# FIXME how to manage group 0
370
432
indices = list ()
371
- for i , _ , _ in self .iter_on (self .segment_track_array () ):
433
+ for i , _ , _ in self .iter_on (self .segment_track_array ):
372
434
nb = i .stop - i .start
373
435
if nb == 0 :
374
436
continue
@@ -379,7 +441,7 @@ def death_event(self):
379
441
380
442
def merging_event (self ):
381
443
indices = list ()
382
- for i , _ , _ in self .iter_on (self .segment_track_array () ):
444
+ for i , _ , _ in self .iter_on (self .segment_track_array ):
383
445
nb = i .stop - i .start
384
446
if nb == 0 :
385
447
continue
@@ -390,7 +452,7 @@ def merging_event(self):
390
452
391
453
def spliting_event (self ):
392
454
indices = list ()
393
- for i , _ , _ in self .iter_on (self .segment_track_array () ):
455
+ for i , _ , _ in self .iter_on (self .segment_track_array ):
394
456
nb = i .stop - i .start
395
457
if nb == 0 :
396
458
continue
@@ -403,7 +465,7 @@ def fully_connected(self):
403
465
self .only_one_network ()
404
466
# TODO
405
467
406
- def plot (self , ax , ref = None , ** kwargs ):
468
+ def plot (self , ax , ref = None , color_cycle = None , ** kwargs ):
407
469
"""
408
470
This function will draw path of each trajectory
409
471
@@ -412,17 +474,25 @@ def plot(self, ax, ref=None, **kwargs):
412
474
:param dict kwargs: keyword arguments for Axes.plot
413
475
:return: a list of matplotlib mappables
414
476
"""
477
+ nb_colors = 0
478
+ if color_cycle is not None :
479
+ kwargs = kwargs .copy ()
480
+ nb_colors = len (color_cycle )
415
481
mappables = list ()
416
482
if "label" in kwargs :
417
483
kwargs ["label" ] = self .format_label (kwargs ["label" ])
418
- for i , b0 , b1 in self .iter_on ("segment" ):
484
+ j = 0
485
+ for i , _ , _ in self .iter_on ("segment" ):
419
486
nb = i .stop - i .start
420
487
if nb == 0 :
421
488
continue
489
+ if nb_colors :
490
+ kwargs ["color" ] = color_cycle [j % nb_colors ]
422
491
x , y = self .lon [i ], self .lat [i ]
423
492
if ref is not None :
424
493
x , y = wrap_longitude (x , y , ref , cut = True )
425
494
mappables .append (ax .plot (x , y , ** kwargs )[0 ])
495
+ j += 1
426
496
return mappables
427
497
428
498
def remove_dead_branch (self , nobs = 3 ):
0 commit comments