Skip to content

Commit d54a743

Browse files
committed
return sorting argument
1 parent f1260a0 commit d54a743

File tree

4 files changed

+38
-45
lines changed

4 files changed

+38
-45
lines changed

src/py_eddy_tracker/observations/network.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,14 @@ def fix_next_previous_obs(next_obs, previous_obs, flag_virtual):
106106

107107
class NetworkObservations(GroupEddiesObservations):
108108

109-
__slots__ = ("_index_network",)
110-
109+
__slots__ = ("_index_network", "_index_segment_track", "_segment_track_array")
111110
NOGROUP = 0
112111

113112
def __init__(self, *args, **kwargs):
114113
super().__init__(*args, **kwargs)
114+
self.reset_index()
115+
116+
def reset_index(self):
115117
self._index_network = None
116118
self._index_segment_track = None
117119
self._segment_track_array = None
@@ -251,9 +253,8 @@ def elements(self):
251253

252254
def astype(self, cls):
253255
new = cls.new_like(self, self.shape)
254-
print()
255-
for k in new.obs.dtype.names:
256-
if k in self.obs.dtype.names:
256+
for k in new.fields:
257+
if k in self.fields:
257258
new[k][:] = self[k][:]
258259
new.sign_type = self.sign_type
259260
return new
@@ -371,23 +372,27 @@ def correct_close_events(self, nb_days_max=20):
371372

372373
self.segment[:] = segment_copy
373374
self.previous_obs[:] = previous_obs
374-
375-
self.sort()
375+
return self.sort()
376376

377377
def sort(self, order=("track", "segment", "time")):
378378
"""
379379
Sort observations
380380
381381
:param tuple order: order or sorting. Given to :func:`numpy.argsort`
382382
"""
383-
index_order = self.obs.argsort(order=order)
384-
for field in self.elements:
383+
index_order = self.obs.argsort(order=order, kind="mergesort")
384+
self.reset_index()
385+
for field in self.fields:
385386
self[field][:] = self[field][index_order]
386387

387-
translate = -ones(index_order.max() + 2, dtype="i4")
388-
translate[index_order] = arange(index_order.shape[0])
388+
nb_obs = len(self)
389+
# we add 1 for -1 index return index -1
390+
translate = -ones(nb_obs + 1, dtype="i4")
391+
translate[index_order] = arange(nb_obs)
392+
# next & previous must be re-indexed
389393
self.next_obs[:] = translate[self.next_obs]
390394
self.previous_obs[:] = translate[self.previous_obs]
395+
return index_order, translate
391396

392397
def obs_relative_order(self, i_obs):
393398
self.only_one_network()
@@ -654,16 +659,16 @@ def normalize_longitude(self):
654659
lon0 = (self.lon[i_start] - 180).repeat(i_stop - i_start)
655660
logger.debug("Normalize longitude")
656661
self.lon[:] = (self.lon - lon0) % 360 + lon0
657-
if "lon_max" in self.obs.dtype.names:
662+
if "lon_max" in self.fields:
658663
logger.debug("Normalize longitude_max")
659664
self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180
660665
if not self.raw_data:
661-
if "contour_lon_e" in self.obs.dtype.names:
666+
if "contour_lon_e" in self.fields:
662667
logger.debug("Normalize effective contour longitude")
663668
self.contour_lon_e[:] = (
664669
(self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180
665670
).T
666-
if "contour_lon_s" in self.obs.dtype.names:
671+
if "contour_lon_s" in self.fields:
667672
logger.debug("Normalize speed contour longitude")
668673
self.contour_lon_s[:] = (
669674
(self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180
@@ -1071,7 +1076,7 @@ def extract_event(self, indices):
10711076
raw_data=self.raw_data,
10721077
)
10731078

1074-
for k in new.obs.dtype.names:
1079+
for k in new.fields:
10751080
new[k][:] = self[k][indices]
10761081
new.sign_type = self.sign_type
10771082
return new
@@ -1194,27 +1199,11 @@ def dissociate_network(self):
11941199
"""
11951200
Dissociate networks with no known interaction (splitting/merging)
11961201
"""
1197-
11981202
tags = self.tag_segment(multi_network=True)
11991203
if self.track[0] == 0:
12001204
tags -= 1
1201-
12021205
self.track[:] = tags[self.segment_track_array]
1203-
1204-
i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort")
1205-
# Sort directly obs, with hope to save memory
1206-
self.obs.sort(order=("track", "segment", "time"), kind="mergesort")
1207-
self._index_network = None
1208-
1209-
# n & p must be re-indexed
1210-
n, p = self.next_obs, self.previous_obs
1211-
# we add 2 for -1 index return index -1
1212-
nb_obs = len(self)
1213-
translate = -ones(nb_obs + 1, dtype="i4")
1214-
translate[:-1][i_sort] = arange(nb_obs)
1215-
self.next_obs[:] = translate[n]
1216-
self.previous_obs[:] = translate[p]
1217-
return translate
1206+
return self.sort()
12181207

12191208
def network_segment(self, id_network, id_segment):
12201209
return self.extract_with_mask(self.segment_slice(id_network, id_segment))

src/py_eddy_tracker/observations/observation.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,16 @@ def box_display(value):
306306
"""Return values evenly spaced with few numbers"""
307307
return "".join([f"{v_:10.2f}" for v_ in value])
308308

309+
@property
310+
def fields(self):
311+
return list(self.obs.dtype.names)
312+
309313
def field_table(self):
310314
"""
311315
Produce description table of the fields available in this object
312316
"""
313317
rows = [("Name (Unit)", "Long name", "Scale factor", "Offset")]
314-
names = list(self.obs.dtype.names)
318+
names = self.fields
315319
names.sort()
316320
for field in names:
317321
infos = VAR_DESCR[field]
@@ -414,7 +418,7 @@ def remove_fields(self, *fields):
414418
"""
415419
nb_obs = self.obs.shape[0]
416420
fields = set(fields)
417-
only_variables = set(self.obs.dtype.names) - fields
421+
only_variables = set(self.fields) - fields
418422
track_extra_variables = set(self.track_extra_variables) - fields
419423
array_variables = set(self.array_variables) - fields
420424
new = self.__class__(
@@ -426,7 +430,7 @@ def remove_fields(self, *fields):
426430
raw_data=self.raw_data,
427431
)
428432
new.sign_type = self.sign_type
429-
for name in new.obs.dtype.names:
433+
for name in new.fields:
430434
logger.debug("Copy of field %s ...", name)
431435
new.obs[name] = self.obs[name]
432436
return new
@@ -444,12 +448,12 @@ def add_fields(self, fields=list(), array_fields=list()):
444448
track_array_variables=self.track_array_variables,
445449
array_variables=list(concatenate((self.array_variables, array_fields))),
446450
only_variables=list(
447-
concatenate((self.obs.dtype.names, fields, array_fields))
451+
concatenate((self.fields, fields, array_fields))
448452
),
449453
raw_data=self.raw_data,
450454
)
451455
new.sign_type = self.sign_type
452-
for name in self.obs.dtype.names:
456+
for name in self.fields:
453457
logger.debug("Copy of field %s ...", name)
454458
new.obs[name] = self.obs[name]
455459
return new
@@ -467,8 +471,8 @@ def circle_contour(self, only_virtual=False, factor=1):
467471
"""
468472
angle = radians(linspace(0, 360, self.track_array_variables))
469473
x_norm, y_norm = cos(angle), sin(angle)
470-
radius_s = "contour_lon_s" in self.obs.dtype.names
471-
radius_e = "contour_lon_e" in self.obs.dtype.names
474+
radius_s = "contour_lon_s" in self.fields
475+
radius_e = "contour_lon_e" in self.fields
472476
for i, obs in enumerate(self):
473477
if only_virtual and not obs["virtual"]:
474478
continue
@@ -684,7 +688,7 @@ def distance(self, other):
684688

685689
def __copy__(self):
686690
eddies = self.new_like(self, len(self))
687-
for k in self.obs.dtype.names:
691+
for k in self.fields:
688692
eddies[k][:] = self[k][:]
689693
eddies.sign_type = self.sign_type
690694
return eddies

src/py_eddy_tracker/observations/tracking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,16 @@ def normalize_longitude(self):
183183
lon0 = (self.lon[self.index_from_track] - 180).repeat(self.nb_obs_by_track)
184184
logger.debug("Normalize longitude")
185185
self.lon[:] = (self.lon - lon0) % 360 + lon0
186-
if "lon_max" in self.obs.dtype.names:
186+
if "lon_max" in self.fields:
187187
logger.debug("Normalize longitude_max")
188188
self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180
189189
if not self.raw_data:
190-
if "contour_lon_e" in self.obs.dtype.names:
190+
if "contour_lon_e" in self.fields:
191191
logger.debug("Normalize effective contour longitude")
192192
self.contour_lon_e[:] = (
193193
(self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180
194194
).T
195-
if "contour_lon_s" in self.obs.dtype.names:
195+
if "contour_lon_s" in self.fields:
196196
logger.debug("Normalize speed contour longitude")
197197
self.contour_lon_s[:] = (
198198
(self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180

src/py_eddy_tracker/tracking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def merge(self, until=-1, raw_data=True):
658658
# Set type of eddy with first file
659659
eddies.sign_type = self.current_obs.sign_type
660660
# Fields to copy
661-
fields = self.current_obs.obs.dtype.names
661+
fields = self.current_obs.fields
662662

663663
# To know if the track start
664664
first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape, dtype=bool_)
@@ -707,7 +707,7 @@ def merge(self, until=-1, raw_data=True):
707707
# Index in the current file
708708
index_current = self[i]["out"]
709709

710-
if "cost_association" in eddies.obs.dtype.names:
710+
if "cost_association" in eddies.fields:
711711
eddies["cost_association"][index_final - 1] = self[i]["cost_value"]
712712
# Copy all variable
713713
for field in fields:

0 commit comments

Comments
 (0)