@@ -113,6 +113,20 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.reset_index()
 
+    def __repr__(self):
+        m_event, s_event = self.merging_event(only_index=True, triplet=True)[0], self.splitting_event(only_index=True, triplet=True)[0]
+        period = (self.period[1] - self.period[0]) / 365.25
+        nb_by_network = self.network_size()
+        big = 50_000
+        infos = [
+            f"Atlas with {self.nb_network} networks ({self.nb_network / period:0.0f} networks/year),"
+            f" {self.nb_segment} segments ({self.nb_segment / period:0.0f} segments/year), {len(self)} observations ({len(self) / period:0.0f} observations/year)",
+            f" {m_event.size} merging ({m_event.size / period:0.0f} merging/year), {s_event.size} splitting ({s_event.size / period:0.0f} splitting/year)",
+            f" with {(nb_by_network > big).sum()} networks with more than {big} obs, the biggest has {nb_by_network.max()} observations ({nb_by_network[nb_by_network > big].sum()} observations cumulated)",
+            f" {nb_by_network[0]} observations in trash"
+        ]
+        return "\n".join(infos)
+
     def reset_index(self):
         self._index_network = None
         self._index_segment_track = None
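For context, the new __repr__ gives a one-glance summary of the atlas. A hedged usage sketch follows (the load_file entry point and file name are assumptions for illustration, and every number in the printed output is invented; only the layout comes from the format strings above):

    from py_eddy_tracker.observations.network import NetworkObservations

    # Hypothetical atlas file; load_file is assumed to be inherited from the
    # observations base class.
    n = NetworkObservations.load_file("network_atlas.nc")
    print(n)
    # Atlas with 1048 networks (42 networks/year), 5230 segments (209 segments/year), 512345 observations (20494 observations/year)
    #  1890 merging (76 merging/year), 1755 splitting (70 splitting/year)
    #  with 2 networks with more than 50000 obs, the biggest has 61234 observations (112345 observations cumulated)
    #  8042 observations in trash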
@@ -313,13 +327,19 @@ def correct_close_events(self, nb_days_max=20):
         """
         Transform event where
         segment A splits from segment B, then x days after segment B merges with A
-
         to
-
         segment A splits from segment B then x days after segment A merges with B (B will be longer)
-
         These events have to last less than `nb_days_max` to be changed.
 
+
+                ------------------- A
+               /              /
+        B --------------------
+        to
+               --A--
+              /     \
+        B -----------------------------------
+
         :param float nb_days_max: maximum time to search for splitting-merging event
         """
 
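A brief usage sketch of the correction described above (it assumes the n atlas from the earlier sketch; the 10-day threshold is arbitrary):

    # Rewrite short split-then-merge detours (here, events closer than 10 days)
    # so that the long branch keeps a single segment id; see the diagram above.
    n.correct_close_events(nb_days_max=10)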
@@ -342,7 +362,7 @@ def correct_close_events(self, nb_days_max=20):
             segments_connexion[seg] = [i, i_p, i_n]
 
         for seg in sorted(segments_connexion.keys()):
-            seg_slice, i_seg_p, i_seg_n = segments_connexion[seg]
+            seg_slice, _, i_seg_n = segments_connexion[seg]
 
             # the segment ID has to be corrected, because we may have changed it since
             seg_corrected = segment[seg_slice.stop - 1]
@@ -370,8 +390,6 @@ def correct_close_events(self, nb_days_max=20):
 
             segments_connexion[seg_corrected][0] = my_slice
 
-        self.segment[:] = segment_copy
-        self.previous_obs[:] = previous_obs
         return self.sort()
 
     def sort(self, order=("track", "segment", "time")):
@@ -495,35 +513,38 @@ def func_backward(seg, indice):
         return self.extract_with_mask(mask)
 
     def connexions(self, multi_network=False):
-        """
-        Create dictionnary for each segment, gives the segments in interaction with
+        """Create a dictionary that, for each segment, gives the segments it interacts with
+
+        :param bool multi_network: use segment_track_array instead of segment, defaults to False
+        :return dict: dict of sets; for each segment id, the set of segment ids sharing an event with it
         """
         if multi_network:
             segment = self.segment_track_array
         else:
             self.only_one_network()
             segment = self.segment
         segments_connexion = dict()
-
-        def add_seg(father, child):
-            if father not in segments_connexion:
-                segments_connexion[father] = set()
-            segments_connexion[father].add(child)
-
-        previous_obs, next_obs = self.previous_obs, self.next_obs
-        for i, seg, _ in self.iter_on(segment):
-            if i.start == i.stop:
-                continue
-            i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1]
-            # segment in interaction
-            p_seg, n_seg = segment[i_p], segment[i_n]
-            # Where segment are called
-            if i_p != -1:
-                add_seg(p_seg, seg)
-                add_seg(seg, p_seg)
-            if i_n != -1:
-                add_seg(n_seg, seg)
-                add_seg(seg, n_seg)
+        def add_seg(s1, s2):
+            if s1 not in segments_connexion:
+                segments_connexion[s1] = set()
+            if s2 not in segments_connexion:
+                segments_connexion[s2] = set()
+            segments_connexion[s1].add(s2), segments_connexion[s2].add(s1)
+        # Get index for each segment
+        i0, i1, _ = self.index_segment_track
+        i1 = i1 - 1
+        # Check if segment ends with a merging
+        i_next = self.next_obs[i1]
+        m_n = i_next != -1
+        # Check if segment comes from a splitting
+        i_previous = self.previous_obs[i0]
+        m_p = i_previous != -1
+        # For each split
+        for s1, s2 in zip(segment[i_previous[m_p]], segment[i0[m_p]]):
+            add_seg(s1, s2)
+        # For each merge
+        for s1, s2 in zip(segment[i_next[m_n]], segment[i1[m_n]]):
+            add_seg(s1, s2)
         return segments_connexion
 
     @classmethod
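To make the vectorized bookkeeping concrete, here is a small self-contained sketch of the same idea on hand-built numpy arrays (toy values, not library data; index_segment_track is emulated by the i0/i1 arrays):

    import numpy as np

    # Toy flattened atlas: observations 0-2 form segment 0, 3-5 segment 1, 6-8 segment 2.
    segment = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
    # previous_obs / next_obs point to the linked observation in another segment, -1 means none.
    previous_obs = np.array([-1, -1, -1, 1, -1, -1, -1, -1, -1])  # segment 1 splits from obs 1 (in segment 0)
    next_obs = np.array([-1, -1, 7, -1, -1, -1, -1, -1, -1])      # segment 0 merges into obs 7 (in segment 2)

    # First and last observation of each segment (what index_segment_track provides).
    i0 = np.array([0, 3, 6])
    i1 = np.array([2, 5, 8])

    connexions = {}
    def add_seg(s1, s2):
        connexions.setdefault(s1, set()).add(s2)
        connexions.setdefault(s2, set()).add(s1)

    m_p = previous_obs[i0] != -1  # segments born by a splitting
    m_n = next_obs[i1] != -1      # segments ending in a merging
    for s1, s2 in zip(segment[previous_obs[i0][m_p]], segment[i0[m_p]]):
        add_seg(int(s1), int(s2))
    for s1, s2 in zip(segment[next_obs[i1][m_n]], segment[i1[m_n]]):
        add_seg(int(s1), int(s2))
    print(connexions)  # {0: {1, 2}, 1: {0}, 2: {0}}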
@@ -1089,68 +1110,57 @@ def segment_track_array(self):
         return self._segment_track_array
 
     def birth_event(self):
-        """Extract birth events.
-        Advice : individual eddies (self.track == 0) should be removed before -> apply remove_trash."""
-        # FIXME how to manage group 0
-        indices = list()
-        previous_obs = self.previous_obs
-        for i, _, _ in self.iter_on(self.segment_track_array):
-            nb = i.stop - i.start
-            if nb == 0:
-                continue
-            i_p = previous_obs[i.start]
-            if i_p == -1:
-                indices.append(i.start)
-        return self.extract_event(list(set(indices)))
+        """Extract birth events."""
+        i_start, _, _ = self.index_segment_track
+        indices = i_start[self.previous_obs[i_start] == -1]
+        if self.first_is_trash():
+            indices = indices[1:]
+        return self.extract_event(indices)
+    generation_event = birth_event
 
     def death_event(self):
-        """Extract death events.
-        Advice : individual eddies (self.track == 0) should be removed before -> apply remove_trash."""
-        # FIXME how to manage group 0
-        indices = list()
-        next_obs = self.next_obs
-        for i, _, _ in self.iter_on(self.segment_track_array):
-            nb = i.stop - i.start
-            if nb == 0:
-                continue
-            i_n = next_obs[i.stop - 1]
-            if i_n == -1:
-                indices.append(i.stop - 1)
-        return self.extract_event(list(set(indices)))
+        """Extract death events."""
+        _, i_stop, _ = self.index_segment_track
+        indices = i_stop[self.next_obs[i_stop - 1] == -1] - 1
+        if self.first_is_trash():
+            indices = indices[1:]
+        return self.extract_event(indices)
+    dissipation_event = death_event
 
     def merging_event(self, triplet=False, only_index=False):
         """Return observation after a merging event.
 
         If `triplet=True` return the eddy after a merging event, the eddy before the merging event,
         and the eddy stopped due to merging.
         """
-        idx_m1 = list()
+        # Get start and stop for each segment (there is no empty segment)
+        _, i1, _ = self.index_segment_track
+        # Get last index for each segment
+        i_stop = i1 - 1
+        # Get target index
+        idx_m1 = self.next_obs[i_stop]
+        # Get mask and valid target
+        m = idx_m1 != -1
+        idx_m1 = idx_m1[m]
+        # Sort by event time
+        i = self.time[idx_m1].argsort()
+        idx_m1 = idx_m1[i]
         if triplet:
-            idx_m0_stop = list()
-            idx_m0 = list()
-        next_obs, previous_obs = self.next_obs, self.previous_obs
-        for i, _, _ in self.iter_on(self.segment_track_array):
-            nb = i.stop - i.start
-            if nb == 0:
-                continue
-            i_n = next_obs[i.stop - 1]
-            if i_n != -1:
-                if triplet:
-                    idx_m0_stop.append(i.stop - 1)
-                    idx_m0.append(previous_obs[i_n])
-                idx_m1.append(i_n)
+            # Get obs before target
+            idx_m0_stop = i_stop[m][i]
+            idx_m0 = self.previous_obs[idx_m1].copy()
 
         if triplet:
             if only_index:
-                return array(idx_m1), array(idx_m0), array(idx_m0_stop)
+                return idx_m1, idx_m0, idx_m0_stop
             else:
                 return (
                     self.extract_event(idx_m1),
                     self.extract_event(idx_m0),
                     self.extract_event(idx_m0_stop),
                 )
         else:
-            idx_m1 = list(set(idx_m1))
+            idx_m1 = unique(idx_m1)
            if only_index:
                 return idx_m1
             else:
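birth_event, death_event and merging_event now share one pattern: take the first or last observation of every segment, mask it with previous_obs/next_obs equal (or not equal) to -1, and sort the surviving indices by event time. A minimal numpy sketch with toy values:

    import numpy as np

    # Toy atlas: segments own observations 0-2, 3-5 and 6-8.
    i1 = np.array([2, 5, 8])                                  # last obs of each segment
    time = np.array([0, 1, 2, 0, 1, 2, 1, 2, 3])
    next_obs = np.array([-1, -1, 8, -1, -1, 7, -1, -1, -1])   # link toward the obs after merging

    idx_m1 = next_obs[i1]          # candidate "obs after merging" for each segment
    m = idx_m1 != -1               # keep only segments that really end with a merging
    idx_m1 = idx_m1[m]
    order = np.argsort(time[idx_m1])
    print(idx_m1[order])           # [7 8], merge targets sorted by event time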
@@ -1162,25 +1172,24 @@ def splitting_event(self, triplet=False, only_index=False):
         If `triplet=True` return the eddy before a splitting event, the eddy after the splitting event,
         and the eddy starting due to splitting.
         """
-        idx_s0 = list()
+        # Get start and stop for each segment (there is no empty segment)
+        i_start, _, _ = self.index_segment_track
+        # Get target index
+        idx_s0 = self.previous_obs[i_start]
+        # Get mask and valid target
+        m = idx_s0 != -1
+        idx_s0 = idx_s0[m]
+        # Sort by event time
+        i = self.time[idx_s0].argsort()
+        idx_s0 = idx_s0[i]
         if triplet:
-            idx_s1_start = list()
-            idx_s1 = list()
-        next_obs, previous_obs = self.next_obs, self.previous_obs
-        for i, _, _ in self.iter_on(self.segment_track_array):
-            nb = i.stop - i.start
-            if nb == 0:
-                continue
-            i_p = previous_obs[i.start]
-            if i_p != -1:
-                if triplet:
-                    idx_s1_start.append(i.start)
-                    idx_s1.append(next_obs[i_p])
-                idx_s0.append(i_p)
+            # Get obs after target
+            idx_s1_start = i_start[m][i]
+            idx_s1 = self.next_obs[idx_s0].copy()
 
         if triplet:
             if only_index:
-                return array(idx_s0), array(idx_s1), array(idx_s1_start)
+                return idx_s0, idx_s1, idx_s1_start
             else:
                 return (
                     self.extract_event(idx_s0),
@@ -1189,7 +1198,7 @@ def splitting_event(self, triplet=False, only_index=False):
                 )
 
         else:
-            idx_s0 = list(set(idx_s0))
+            idx_s0 = unique(idx_s0)
             if only_index:
                 return idx_s0
             else:
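A hedged usage sketch of the event accessors after this change (continuing the n atlas assumed earlier; the variable names are illustrative):

    # Indices of the observations just after each merging, sorted by event time
    i_after_merge = n.merging_event(only_index=True)
    # Triplet form: obs after the merging, obs before it, and the obs that stops
    after, before, stopped = n.merging_event(triplet=True)
    # Splitting counterpart: obs before the split, obs after it, and the obs that starts
    before_s, after_s, started = n.splitting_event(triplet=True)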
@@ -1199,7 +1208,7 @@ def dissociate_network(self):
         """
         Dissociate networks with no known interaction (splitting/merging)
         """
-        tags = self.tag_segment(multi_network=True)
+        tags = self.tag_segment()
         if self.track[0] == 0:
             tags -= 1
         self.track[:] = tags[self.segment_track_array]
@@ -1345,16 +1354,22 @@ def __tag_segment(cls, seg, tag, groups, connexions):
             # For each connexion we apply same function
             cls.__tag_segment(seg, tag, groups, connexions)
 
-    def tag_segment(self, multi_network=False):
-        if multi_network:
-            nb = self.segment_track_array[-1] + 1
-        else:
-            nb = self.segment.max() + 1
+    def tag_segment(self):
+        """Give each segment a new network id, such that connected segments share the same id
+
+        :return array: for each unique segment id, the new network id
+        """
+        nb = self.segment_track_array[-1] + 1
         sub_group = zeros(nb, dtype="u4")
-        c = self.connexions(multi_network=multi_network)
+        c = self.connexions(multi_network=True)
         j = 1
         # for each available id
         for i in range(nb):
+            # No connexions, no need to explore
+            if i not in c:
+                sub_group[i] = j
+                j += 1
+                continue
             # Skip if already set
             if sub_group[i] != 0:
                 continue
@@ -1363,15 +1378,31 @@ def tag_segment(self, multi_network=False):
             j += 1
         return sub_group
 
+
     def fully_connected(self):
+        """Suspicious
+        """
+        raise Exception("Must be checked")
         self.only_one_network()
         return self.tag_segment().shape[0] == 1
 
+    def first_is_trash(self):
+        """Check if the first network is the trash
+
+        :return bool: True if the first network is the trash
+        """
+        i_start, i_stop, _ = self.index_segment_track
+        sl = slice(i_start[0], i_stop[0])
+        return (self.previous_obs[sl] == -1).all() and (self.next_obs[sl] == -1).all()
+
     def remove_trash(self):
         """
         Remove the lonely eddies (only 1 obs in segment, associated network number is 0)
         """
-        return self.extract_with_mask(self.track != 0)
+        if self.first_is_trash():
+            return self.extract_with_mask(self.track != 0)
+        else:
+            return self
 
     def plot(self, ax, ref=None, color_cycle=None, **kwargs):
         """
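tag_segment amounts to a connected-component labelling of the connexions graph: every group of segments linked by splitting/merging events gets one network id, and an isolated segment gets its own. A self-contained sketch of the idea, iterative where the class uses the recursive __tag_segment (toy connexions, not library data):

    import numpy as np

    # Toy connexions: segments 0-1-2 form one network, 3 is isolated, 4-5 another network.
    connexions = {0: {1}, 1: {0, 2}, 2: {1}, 4: {5}, 5: {4}}
    nb = 6
    sub_group = np.zeros(nb, dtype="u4")

    j = 1
    for i in range(nb):
        if sub_group[i] != 0:
            continue
        # Flood-fill every segment reachable from i with the same network id.
        stack = [i]
        while stack:
            s = stack.pop()
            if sub_group[s] != 0:
                continue
            sub_group[s] = j
            stack.extend(connexions.get(s, ()))
        j += 1
    print(sub_group)  # [1 1 1 2 3 3]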
@@ -1551,12 +1582,11 @@ def extract_with_mask(self, mask):
             logger.debug(
                 f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})"
             )
-            for field in self.obs.dtype.descr:
+            for field in self.fields:
                 if field in ("next_obs", "previous_obs"):
                     continue
                 logger.debug("Copy of field %s ...", field)
-                var = field[0]
-                new.obs[var] = self.obs[var][mask]
+                new.obs[field] = self.obs[field][mask]
             # n & p must be re-index
             n, p = self.next_obs[mask], self.previous_obs[mask]
             # we add 2 for -1 index return index -1
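The "re-index" comment refers to remapping the surviving next_obs/previous_obs values from their old positions to the compacted ones; the exact library code sits outside this hunk, so the following is only one common way to do it, with a spare slot so that -1 keeps mapping to -1:

    import numpy as np

    mask = np.array([True, False, True, True])
    next_obs = np.array([2, -1, 3, -1])     # indices into the unmasked array, -1 = no link

    # Translation table old index -> new index; the extra last slot catches index -1.
    translate = -np.ones(mask.size + 1, dtype="i4")
    translate[:-1][mask] = np.arange(mask.sum())
    print(translate[next_obs[mask]])        # [ 1  2 -1]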
@@ -1682,9 +1712,9 @@ def date2file(julian_day):
 
             return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc"
         """
-
-        itb_final = -ones((self.obs.size, 2), dtype="i4")
-        ptb_final = zeros((self.obs.size, 2), dtype="i1")
+        shape = len(self), 2
+        itb_final = -ones(shape, dtype="i4")
+        ptb_final = zeros(shape, dtype="i1")
 
         t_start, t_end = int(self.period[0]), int(self.period[1])
 
@@ -1760,9 +1790,9 @@ def date2file(julian_day):
 
             return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc"
         """
-
-        itf_final = -ones((self.obs.size, 2), dtype="i4")
-        ptf_final = zeros((self.obs.size, 2), dtype="i1")
+        shape = len(self), 2
+        itf_final = -ones(shape, dtype="i4")
+        ptf_final = zeros(shape, dtype="i1")
 
         t_start, t_end = int(self.period[0]), int(self.period[1])
 