@@ -106,12 +106,14 @@ def fix_next_previous_obs(next_obs, previous_obs, flag_virtual):
106106
107107class NetworkObservations (GroupEddiesObservations ):
108108
109- __slots__ = ("_index_network" ,)
110-
109+ __slots__ = ("_index_network" , "_index_segment_track" , "_segment_track_array" )
111110 NOGROUP = 0
112111
113112 def __init__ (self , * args , ** kwargs ):
114113 super ().__init__ (* args , ** kwargs )
114+ self .reset_index ()
115+
116+ def reset_index (self ):
115117 self ._index_network = None
116118 self ._index_segment_track = None
117119 self ._segment_track_array = None
@@ -251,9 +253,8 @@ def elements(self):
251253
252254 def astype (self , cls ):
253255 new = cls .new_like (self , self .shape )
254- print ()
255- for k in new .obs .dtype .names :
256- if k in self .obs .dtype .names :
256+ for k in new .fields :
257+ if k in self .fields :
257258 new [k ][:] = self [k ][:]
258259 new .sign_type = self .sign_type
259260 return new
@@ -371,23 +372,27 @@ def correct_close_events(self, nb_days_max=20):
371372
372373 self .segment [:] = segment_copy
373374 self .previous_obs [:] = previous_obs
374-
375- self .sort ()
375+ return self .sort ()
376376
377377 def sort (self , order = ("track" , "segment" , "time" )):
378378 """
379379 Sort observations
380380
381381 :param tuple order: order or sorting. Given to :func:`numpy.argsort`
382382 """
383- index_order = self .obs .argsort (order = order )
384- for field in self .elements :
383+ index_order = self .obs .argsort (order = order , kind = "mergesort" )
384+ self .reset_index ()
385+ for field in self .fields :
385386 self [field ][:] = self [field ][index_order ]
386387
387- translate = - ones (index_order .max () + 2 , dtype = "i4" )
388- translate [index_order ] = arange (index_order .shape [0 ])
388+ nb_obs = len (self )
389+ # we add 1 for -1 index return index -1
390+ translate = - ones (nb_obs + 1 , dtype = "i4" )
391+ translate [index_order ] = arange (nb_obs )
392+ # next & previous must be re-indexed
389393 self .next_obs [:] = translate [self .next_obs ]
390394 self .previous_obs [:] = translate [self .previous_obs ]
395+ return index_order , translate
391396
392397 def obs_relative_order (self , i_obs ):
393398 self .only_one_network ()
@@ -654,16 +659,16 @@ def normalize_longitude(self):
654659 lon0 = (self .lon [i_start ] - 180 ).repeat (i_stop - i_start )
655660 logger .debug ("Normalize longitude" )
656661 self .lon [:] = (self .lon - lon0 ) % 360 + lon0
657- if "lon_max" in self .obs . dtype . names :
662+ if "lon_max" in self .fields :
658663 logger .debug ("Normalize longitude_max" )
659664 self .lon_max [:] = (self .lon_max - self .lon + 180 ) % 360 + self .lon - 180
660665 if not self .raw_data :
661- if "contour_lon_e" in self .obs . dtype . names :
666+ if "contour_lon_e" in self .fields :
662667 logger .debug ("Normalize effective contour longitude" )
663668 self .contour_lon_e [:] = (
664669 (self .contour_lon_e .T - self .lon + 180 ) % 360 + self .lon - 180
665670 ).T
666- if "contour_lon_s" in self .obs . dtype . names :
671+ if "contour_lon_s" in self .fields :
667672 logger .debug ("Normalize speed contour longitude" )
668673 self .contour_lon_s [:] = (
669674 (self .contour_lon_s .T - self .lon + 180 ) % 360 + self .lon - 180
@@ -1071,7 +1076,7 @@ def extract_event(self, indices):
10711076 raw_data = self .raw_data ,
10721077 )
10731078
1074- for k in new .obs . dtype . names :
1079+ for k in new .fields :
10751080 new [k ][:] = self [k ][indices ]
10761081 new .sign_type = self .sign_type
10771082 return new
@@ -1194,27 +1199,11 @@ def dissociate_network(self):
11941199 """
11951200 Dissociate networks with no known interaction (splitting/merging)
11961201 """
1197-
11981202 tags = self .tag_segment (multi_network = True )
11991203 if self .track [0 ] == 0 :
12001204 tags -= 1
1201-
12021205 self .track [:] = tags [self .segment_track_array ]
1203-
1204- i_sort = self .obs .argsort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
1205- # Sort directly obs, with hope to save memory
1206- self .obs .sort (order = ("track" , "segment" , "time" ), kind = "mergesort" )
1207- self ._index_network = None
1208-
1209- # n & p must be re-indexed
1210- n , p = self .next_obs , self .previous_obs
1211- # we add 2 for -1 index return index -1
1212- nb_obs = len (self )
1213- translate = - ones (nb_obs + 1 , dtype = "i4" )
1214- translate [:- 1 ][i_sort ] = arange (nb_obs )
1215- self .next_obs [:] = translate [n ]
1216- self .previous_obs [:] = translate [p ]
1217- return translate
1206+ return self .sort ()
12181207
12191208 def network_segment (self , id_network , id_segment ):
12201209 return self .extract_with_mask (self .segment_slice (id_network , id_segment ))
0 commit comments