Skip to content

Commit f8bbcfc

Browse files
committed
Allow to set options in yaml file for tracking
1 parent 8688672 commit f8bbcfc

File tree

6 files changed

+231
-112
lines changed

6 files changed

+231
-112
lines changed

src/py_eddy_tracker/featured_tracking/area_tracker.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66

77

88
class AreaTracker(Model):
9+
10+
__slots__ = ("cmin",)
11+
12+
def __init__(self, *args, cmin=0.2, **kwargs):
13+
super().__init__(*args, **kwargs)
14+
self.cmin = cmin
15+
916
@classmethod
1017
def needed_variable(cls):
1118
vars = ["longitude", "latitude"]
@@ -17,7 +24,7 @@ def tracking(self, other):
1724
i, j, c = self.match(other, intern=False)
1825
cost_mat = ma.empty(shape, dtype="f4")
1926
cost_mat.mask = ma.ones(shape, dtype="bool")
20-
m = c > 0.2
27+
m = c > self.cmin
2128
i, j, c = i[m], j[m], c[m]
2229
cost_mat[i, j] = 1 - c
2330

src/py_eddy_tracker/featured_tracking/old_tracker_reference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77

88
class CheltonTracker(Model):
9+
10+
__slots__ = tuple()
11+
912
GROUND = RegularGridDataset(
1013
path.join(path.dirname(__file__), "../data/mask_1_60.nc"), "lon", "lat"
1114
)

src/py_eddy_tracker/observations/observation.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,16 @@ def distance(self, other):
539539
other eddies."""
540540
return distance_grid(self.lon, self.lat, other.lon, other.lat)
541541

542+
def __copy__(self):
543+
eddies = self.new_like(self, len(self))
544+
for k in self.obs.dtype.names:
545+
eddies[k][:] = self[k][:]
546+
eddies.sign_type = self.sign_type
547+
return eddies
548+
549+
def copy(self):
550+
return self.__copy__()
551+
542552
@staticmethod
543553
def new_like(eddies, new_size: int):
544554
return eddies.__class__(
@@ -599,6 +609,8 @@ def load_file(cls, filename, **kwargs):
599609
filename_ = (
600610
filename.filename if isinstance(filename, ExFileObject) else filename
601611
)
612+
if isinstance(filename, zarr.storage.MutableMapping):
613+
return cls.load_from_zarr(filename, **kwargs)
602614
end = b".zarr" if isinstance(filename_, bytes) else ".zarr"
603615
if filename_.endswith(end):
604616
return cls.load_from_zarr(filename, **kwargs)
@@ -614,6 +626,7 @@ def load_from_zarr(
614626
include_vars=None,
615627
indexs=None,
616628
buffer_size=5000000,
629+
**class_kwargs,
617630
):
618631
"""Load data from zarr.
619632
@@ -623,6 +636,7 @@ def load_from_zarr(
623636
:param None,list(str) include_vars: If defined only this variable will be loaded
624637
:param None,dict indexs: Indexs to laad only a slice of data
625638
:param int buffer_size: Size of buffer used to load zarr data
639+
:param class_kwargs: argument to set up observations class
626640
:return: Obsevations selected
627641
:return type: class
628642
"""
@@ -667,6 +681,7 @@ def load_from_zarr(
667681
kwargs["only_variables"] = (
668682
None if include_vars is None else [VAR_DESCR_inv[i] for i in include_vars]
669683
)
684+
kwargs.update(class_kwargs)
670685
eddies = cls(size=nb_obs, **kwargs)
671686
for variable in var_list:
672687
var_inv = VAR_DESCR_inv[variable]
@@ -727,7 +742,7 @@ def copy_data_to_zarr(
727742
if i_stop is None:
728743
i_stop = handler_zarr.shape[0]
729744
for i in range(i_start, i_stop, buffer_size):
730-
sl_in = slice(i, i + buffer_size)
745+
sl_in = slice(i, min(i + buffer_size, i_stop))
731746
data = handler_zarr[sl_in]
732747
if factor != 1:
733748
data *= factor
@@ -741,7 +756,13 @@ def copy_data_to_zarr(
741756

742757
@classmethod
743758
def load_from_netcdf(
744-
cls, filename, raw_data=False, remove_vars=None, include_vars=None, indexs=None
759+
cls,
760+
filename,
761+
raw_data=False,
762+
remove_vars=None,
763+
include_vars=None,
764+
indexs=None,
765+
**class_kwargs,
745766
):
746767
"""Load data from netcdf.
747768
@@ -750,6 +771,7 @@ def load_from_netcdf(
750771
:param None,list(str) remove_vars: List of variable name which will be not loaded
751772
:param None,list(str) include_vars: If defined only this variable will be loaded
752773
:param None,dict indexs: Indexs to laad only a slice of data
774+
:param class_kwargs: argument to set up observations class
753775
:return: Obsevations selected
754776
:return type: class
755777
"""
@@ -799,6 +821,7 @@ def load_from_netcdf(
799821
if include_vars is None
800822
else [VAR_DESCR_inv[i] for i in include_vars]
801823
)
824+
kwargs.update(class_kwargs)
802825
eddies = cls(size=nb_obs, **kwargs)
803826
for variable in var_list:
804827
var_inv = VAR_DESCR_inv[variable]
@@ -1439,7 +1462,7 @@ def create_variable_zarr(
14391462
add_offset=None,
14401463
filters=None,
14411464
compressor=None,
1442-
chunck_size=2500000
1465+
chunck_size=2500000,
14431466
):
14441467
kwargs_variable["shape"] = data.shape
14451468
kwargs_variable["compressor"] = (

0 commit comments

Comments
 (0)