@@ -240,10 +240,188 @@ def from_split_network(cls, group_dataset, indexs, **kwargs):
240240 def infos (self , label = "" ):
241241 return f"{ len (self )} obs { unique (self .segment ).shape [0 ]} segments"
242242
243+ def correct_close_events (self , nb_days_max = 20 ):
244+ """
245+ transform event where
246+ segment A split to B, then A merge into B
247+
248+ to
249+
250+ segment A split to B, then B merge to A
251+
252+ these events are filtered with `nb_days_max`, which the event have to take place in less than `nb_days_max`
253+
254+ :param float nb_days_max: maximum time to search for splitting-merging event
255+ """
256+
257+ _time = self .time
258+ # segment used to correct and track changes
259+ segment = self .segment_track_array .copy ()
260+ # final segment used to copy into self.segment
261+ segment_copy = self .segment
262+
263+ segments_connexion = dict ()
264+
265+ previous_obs , next_obs = self .previous_obs , self .next_obs
266+
267+ # record for every segments, the slice, indice of next obs & indice of previous obs
268+ for i , seg , _ in self .iter_on (segment ):
269+ if i .start == i .stop :
270+ continue
271+
272+ i_p , i_n = previous_obs [i .start ], next_obs [i .stop - 1 ]
273+ segments_connexion [seg ] = [i , i_p , i_n ]
274+
275+ for seg in sorted (segments_connexion .keys ()):
276+ seg_slice , i_seg_p , i_seg_n = segments_connexion [seg ]
277+
278+ # the segment ID has to be corrected, because we may have changed it since
279+ seg_corrected = segment [seg_slice .stop - 1 ]
280+
281+ # we keep the real segment number
282+ seg_corrected_copy = segment_copy [seg_slice .stop - 1 ]
283+
284+ n_seg = segment [i_seg_n ]
285+
286+ # if segment has splitting
287+ if i_seg_n != - 1 :
288+ seg2_slice , i2_seg_p , i2_seg_n = segments_connexion [n_seg ]
289+ p2_seg = segment [i2_seg_p ]
290+
291+ # if it merge on the first in a certain time
292+ if (p2_seg == seg_corrected ) and (
293+ _time [i_seg_n ] - _time [i2_seg_p ] < nb_days_max
294+ ):
295+ my_slice = slice (i_seg_n , seg2_slice .stop )
296+ # correct the factice segment
297+ segment [my_slice ] = seg_corrected
298+ # correct the good segment
299+ segment_copy [my_slice ] = seg_corrected_copy
300+ previous_obs [i_seg_n ] = seg_slice .stop - 1
301+
302+ segments_connexion [seg_corrected ][0 ] = my_slice
303+
304+ self .segment [:] = segment_copy
305+ self .previous_obs [:] = previous_obs
306+
307+ self .sort ()
308+
309+ def sort (self , order = ("track" , "segment" , "time" )):
310+ """
311+ sort observations
312+
313+ :param tuple order: order or sorting. Passed to `np.argsort`
314+ """
315+
316+ index_order = self .obs .argsort (order = order )
317+ for field in self .elements :
318+ self [field ][:] = self [field ][index_order ]
319+
320+ translate = - ones (index_order .max () + 2 , dtype = "i4" )
321+ translate [index_order ] = arange (index_order .shape [0 ])
322+ self .next_obs [:] = translate [self .next_obs ]
323+ self .previous_obs [:] = translate [self .previous_obs ]
324+
243325 def obs_relative_order (self , i_obs ):
244326 self .only_one_network ()
245327 return self .segment_relative_order (self .segment [i_obs ])
246328
329+ def find_link (self , i_observations , forward = True , backward = False ):
330+ """
331+ find all observations where obs `i_observation` could be
332+ in future or past.
333+
334+ if forward=True, search all observation where water
335+ from obs "i_observation" could go
336+
337+ if backward=True, search all observation
338+ where water from obs `i_observation` could come from
339+
340+ :param int,iterable(int) i_observation:
341+ indices of observation. Can be
342+ int, or iterable of int.
343+ :param bool forward, backward:
344+ if forward, search observations after obs.
345+ else mode==backward search before obs
346+
347+ """
348+
349+ i_obs = (
350+ [i_observations ]
351+ if not hasattr (i_observations , "__iter__" )
352+ else i_observations
353+ )
354+
355+ segment = self .segment_track_array
356+ previous_obs , next_obs = self .previous_obs , self .next_obs
357+
358+ segments_connexion = dict ()
359+
360+ for i_slice , seg , _ in self .iter_on (segment ):
361+ if i_slice .start == i_slice .stop :
362+ continue
363+
364+ i_p , i_n = previous_obs [i_slice .start ], next_obs [i_slice .stop - 1 ]
365+ p_seg , n_seg = segment [i_p ], segment [i_n ]
366+
367+ # dumping slice into dict
368+ if seg not in segments_connexion :
369+ segments_connexion [seg ] = [i_slice , [], []]
370+ else :
371+ segments_connexion [seg ][0 ] = i_slice
372+
373+ if i_p != - 1 :
374+
375+ if p_seg not in segments_connexion :
376+ segments_connexion [p_seg ] = [None , [], []]
377+
378+ # backward
379+ segments_connexion [seg ][2 ].append ((i_slice .start , i_p , p_seg ))
380+ # forward
381+ segments_connexion [p_seg ][1 ].append ((i_p , i_slice .start , seg ))
382+
383+ if i_n != - 1 :
384+ if n_seg not in segments_connexion :
385+ segments_connexion [n_seg ] = [None , [], []]
386+
387+ # forward
388+ segments_connexion [seg ][1 ].append ((i_slice .stop - 1 , i_n , n_seg ))
389+ # backward
390+ segments_connexion [n_seg ][2 ].append ((i_n , i_slice .stop - 1 , seg ))
391+
392+ mask = zeros (segment .size , dtype = bool )
393+
394+ def func_forward (seg , indice ):
395+ seg_slice , _forward , _ = segments_connexion [seg ]
396+
397+ mask [indice : seg_slice .stop ] = True
398+ for i_begin , i_end , seg2 in _forward :
399+ if i_begin < indice :
400+ continue
401+
402+ if not mask [i_end ]:
403+ func_forward (seg2 , i_end )
404+
405+ def func_backward (seg , indice ):
406+ seg_slice , _ , _backward = segments_connexion [seg ]
407+
408+ mask [seg_slice .start : indice + 1 ] = True
409+ for i_begin , i_end , seg2 in _backward :
410+ if i_begin > indice :
411+ continue
412+
413+ if not mask [i_end ]:
414+ func_backward (seg2 , i_end )
415+
416+ for indice in i_obs :
417+ if forward :
418+ func_forward (segment [indice ], indice )
419+
420+ if backward :
421+ func_backward (segment [indice ], indice )
422+
423+ return self .extract_with_mask (mask )
424+
247425 def connexions (self , multi_network = False ):
248426 """
249427 create dictionnary for each segments, gives the segments which interact with
@@ -490,14 +668,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
490668
491669 # TODO : fill mappables dict
492670 y_seg = dict ()
671+ _time = self .time
672+
493673 if field is not None and method != "all" :
494674 for i , b0 , _ in self .iter_on ("segment" ):
495675 y = self [field ][i ]
496676 if y .shape [0 ] != 0 :
497677 y_seg [b0 ] = y .mean () * factor
498678 mappables = dict ()
499679 for i , b0 , b1 in self .iter_on ("segment" ):
500- x = self . time [i ]
680+ x = _time [i ]
501681 if x .shape [0 ] == 0 :
502682 continue
503683
@@ -532,7 +712,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
532712 else y_seg [seg_next ]
533713 )
534714 )
535- ax .plot ((x [- 1 ], self . time [i_n ]), (y0 , y1 ), ** event_kw )[0 ]
715+ ax .plot ((x [- 1 ], _time [i_n ]), (y0 , y1 ), ** event_kw )[0 ]
536716 events ["merging" ].append ((x [- 1 ], y0 ))
537717
538718 if i_p != - 1 :
@@ -548,7 +728,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
548728 else y_seg [seg_previous ]
549729 )
550730 )
551- ax .plot ((x [0 ], self . time [i_p ]), (y0 , y1 ), ** event_kw )[0 ]
731+ ax .plot ((x [0 ], _time [i_p ]), (y0 , y1 ), ** event_kw )[0 ]
552732 events ["spliting" ].append ((x [0 ], y0 ))
553733
554734 j += 1
@@ -1045,7 +1225,7 @@ def extract_with_mask(self, mask):
10451225 logger .warning ("Empty dataset will be created" )
10461226 else :
10471227 logger .info (
1048- f"{ nb_obs } observations will be extract ({ nb_obs * 100. / self .shape [0 ]} % )"
1228+ f"{ nb_obs } observations will be extract ({ nb_obs / self .shape [0 ]:.3% } )"
10491229 )
10501230 for field in self .obs .dtype .descr :
10511231 if field in ("next_obs" , "previous_obs" ):
0 commit comments