Skip to content

Commit 66f9905

Browse files
committed
Remove reference to obs or observation to be easily replace by store later
1 parent c7fbbd7 commit 66f9905

File tree

3 files changed

+27
-33
lines changed

3 files changed

+27
-33
lines changed

src/py_eddy_tracker/observations/groups.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,15 +292,14 @@ def filled_by_interpolation(self, mask):
292292
nb_obs = len(self)
293293
index = arange(nb_obs)
294294

295-
for field in self.obs.dtype.descr:
296-
var = field[0]
295+
for field in self.fields:
297296
if (
298-
var in ["n", "virtual", "track", "cost_association"]
299-
or var in self.array_variables
297+
field in ["n", "virtual", "track", "cost_association"]
298+
or field in self.array_variables
300299
):
301300
continue
302-
self.obs[var][mask] = interp(
303-
index[mask], index[~mask], self.obs[var][~mask]
301+
self.obs[field][mask] = interp(
302+
index[mask], index[~mask], self.obs[field][~mask]
304303
)
305304

306305
def insert_virtual(self):

src/py_eddy_tracker/observations/observation.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def get_infos(self):
264264
bins_lat=(-90, -60, -15, 15, 60, 90),
265265
bins_amplitude=array((0, 1, 2, 3, 4, 5, 10, 500)),
266266
bins_radius=array((0, 15, 30, 45, 60, 75, 100, 200, 2000)),
267-
nb_obs=self.observations.shape[0],
267+
nb_obs=len(self),
268268
)
269269
t0, t1 = self.period
270270
infos["t0"], infos["t1"] = t0, t1
@@ -341,7 +341,7 @@ def __repr__(self):
341341
bins_lat = (-90, -60, -15, 15, 60, 90)
342342
bins_amplitude = array((0, 1, 2, 3, 4, 5, 10, 500))
343343
bins_radius = array((0, 15, 30, 45, 60, 75, 100, 200, 2000))
344-
nb_obs = self.observations.shape[0]
344+
nb_obs = len(self)
345345

346346
return f""" | {nb_obs} observations from {t0} to {t1} ({period} days, ~{nb_obs / period:.0f} obs/day)
347347
| Speed area : {self.speed_area.sum() / period / 1e12:.2f} Mkm²/day
@@ -416,7 +416,7 @@ def remove_fields(self, *fields):
416416
"""
417417
Copy with fields listed remove
418418
"""
419-
nb_obs = self.obs.shape[0]
419+
nb_obs = len(self)
420420
fields = set(fields)
421421
only_variables = set(self.fields) - fields
422422
track_extra_variables = set(self.track_extra_variables) - fields
@@ -439,7 +439,7 @@ def add_fields(self, fields=list(), array_fields=list()):
439439
"""
440440
Add a new field.
441441
"""
442-
nb_obs = self.obs.shape[0]
442+
nb_obs = len(self)
443443
new = self.__class__(
444444
size=nb_obs,
445445
track_extra_variables=list(
@@ -547,9 +547,9 @@ def merge(self, other):
547547
nb_obs_self = len(self)
548548
nb_obs = nb_obs_self + len(other)
549549
eddies = self.new_like(self, nb_obs)
550-
other_keys = other.obs.dtype.fields.keys()
551-
self_keys = self.obs.dtype.fields.keys()
552-
for key in eddies.obs.dtype.fields.keys():
550+
other_keys = other.fields
551+
self_keys = self.fields
552+
for key in eddies.fields:
553553
eddies.obs[key][:nb_obs_self] = self.obs[key][:]
554554
if key in other_keys:
555555
eddies.obs[key][nb_obs_self:] = other.obs[key][:]
@@ -657,8 +657,8 @@ def insert_observations(self, other, index):
657657
"""Insert other obs in self at the given index."""
658658
if not self.coherence(other):
659659
raise Exception("Observations with no coherence")
660-
insert_size = len(other.obs)
661-
self_size = len(self.obs)
660+
insert_size = len(other)
661+
self_size = len(self)
662662
new_size = self_size + insert_size
663663
if self_size == 0:
664664
self.observations = other.obs
@@ -1542,8 +1542,7 @@ def to_zarr(self, handler, **kwargs):
15421542
handler.attrs["track_array_variables"] = self.track_array_variables
15431543
handler.attrs["array_variables"] = ",".join(self.array_variables)
15441544
# Iter on variables to create:
1545-
fields = [field[0] for field in self.observations.dtype.descr]
1546-
for ori_name in fields:
1545+
for ori_name in self.fields:
15471546
# Patch for a transition
15481547
name = ori_name
15491548
#
@@ -1588,12 +1587,11 @@ def to_netcdf(self, handler, **kwargs):
15881587
handler.track_array_variables = self.track_array_variables
15891588
handler.array_variables = ",".join(self.array_variables)
15901589
# Iter on variables to create:
1591-
fields = [field[0] for field in self.observations.dtype.descr]
15921590
fields_ = array(
1593-
[VAR_DESCR[field[0]]["nc_name"] for field in self.observations.dtype.descr]
1591+
[VAR_DESCR[field]["nc_name"] for field in self.fields]
15941592
)
15951593
i = fields_.argsort()
1596-
for ori_name in array(fields)[i]:
1594+
for ori_name in array(self.fields)[i]:
15971595
# Patch for a transition
15981596
name = ori_name
15991597
#
@@ -1865,10 +1863,9 @@ def extract_with_mask(self, mask):
18651863
if nb_obs == 0:
18661864
logger.warning("Empty dataset will be created")
18671865
else:
1868-
for field in self.obs.dtype.descr:
1866+
for field in self.fields:
18691867
logger.debug("Copy of field %s ...", field)
1870-
var = field[0]
1871-
new.obs[var] = self.obs[var][mask]
1868+
new.obs[field] = self.obs[field][mask]
18721869
return new
18731870

18741871
def scatter(self, ax, name=None, ref=None, factor=1, **kwargs):

src/py_eddy_tracker/observations/tracking.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __repr__(self):
118118
t0, t1 = self.period
119119
period = t1 - t0 + 1
120120
nb = self.nb_obs_by_track
121-
nb_obs = self.observations.shape[0]
121+
nb_obs = len(self)
122122
m = self.virtual.astype("bool")
123123
nb_m = m.sum()
124124
bins_t = (1, 30, 90, 180, 270, 365, 1000, 10000)
@@ -147,7 +147,7 @@ def __repr__(self):
147147

148148
def add_distance(self):
149149
"""Add a field of distance (m) between two consecutive observations, 0 for the last observation of each track"""
150-
if "distance_next" in self.observations.dtype.descr:
150+
if "distance_next" in self.fields:
151151
return self
152152
new = self.add_fields(("distance_next",))
153153
new["distance_next"][:1] = self.distance_to_next()
@@ -205,10 +205,9 @@ def extract_longer_eddies(self, nb_min, nb_obs, compress_id=True):
205205
logger.info("Selection of %d observations", nb_obs_select)
206206
eddies = self.__class__.new_like(self, nb_obs_select)
207207
eddies.sign_type = self.sign_type
208-
for field in self.obs.dtype.descr:
208+
for field in self.fields:
209209
logger.debug("Copy of field %s ...", field)
210-
var = field[0]
211-
eddies.obs[var] = self.obs[var][mask]
210+
eddies.obs[field] = self.obs[field][mask]
212211
if compress_id:
213212
list_id = unique(eddies.obs.track)
214213
list_id.sort()
@@ -387,13 +386,13 @@ def extract_toward_direction(self, west=True, delta_lon=None):
387386

388387
def extract_first_obs_in_box(self, res):
389388
data = empty(
390-
self.obs.shape, dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")]
389+
len(self), dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")]
391390
)
392391
data["lon"] = self.longitude - self.longitude % res
393392
data["lat"] = self.latitude - self.latitude % res
394393
data["track"] = self.track
395394
_, indexs = unique(data, return_index=True)
396-
mask = zeros(self.obs.shape, dtype="bool")
395+
mask = zeros(len(self), dtype="bool")
397396
mask[indexs] = True
398397
return self.extract_with_mask(mask)
399398

@@ -508,10 +507,9 @@ def extract_with_mask(
508507
if nb_obs == 0:
509508
logger.info("Empty dataset will be created")
510509
else:
511-
for field in self.obs.dtype.descr:
510+
for field in self.fields:
512511
logger.debug("Copy of field %s ...", field)
513-
var = field[0]
514-
new.obs[var] = self.obs[var][mask]
512+
new.obs[field] = self.obs[field][mask]
515513
if compress_id:
516514
list_id = unique(new.track)
517515
list_id.sort()

0 commit comments

Comments
 (0)