@@ -55,6 +55,14 @@ def lat(self):
55
55
BasePath .lat = lat
56
56
57
57
58
+ @njit (cache = True )
59
+ def prepare_for_kdtree (x_val , y_val ):
60
+ data = empty ((x_val .shape [0 ], 2 ))
61
+ data [:, 0 ] = x_val
62
+ data [:, 1 ] = y_val
63
+ return data
64
+
65
+
58
66
@njit (cache = True )
59
67
def uniform_resample_stack (vertices , num_fac = 2 , fixed_size = None ):
60
68
x_val , y_val = vertices [:, 0 ], vertices [:, 1 ]
@@ -196,6 +204,7 @@ class GridDataset(object):
196
204
'coordinates' ,
197
205
'filename' ,
198
206
'dimensions' ,
207
+ 'indexs' ,
199
208
'variables_description' ,
200
209
'global_attrs' ,
201
210
'vars' ,
@@ -209,7 +218,7 @@ class GridDataset(object):
209
218
# EARTH_RADIUS = 6378136.3
210
219
N = 1
211
220
212
- def __init__ (self , filename , x_name , y_name , centered = None ):
221
+ def __init__ (self , filename , x_name , y_name , centered = None , indexs = None ):
213
222
self .dimensions = None
214
223
self .variables_description = None
215
224
self .global_attrs = None
@@ -226,6 +235,7 @@ def __init__(self, filename, x_name, y_name, centered=None):
226
235
self .filename = filename
227
236
self .coordinates = x_name , y_name
228
237
self .vars = dict ()
238
+ self .indexs = None if indexs is None else indexs
229
239
self .interpolators = dict ()
230
240
if centered is None :
231
241
logger .warning ('We assume the position of grid is the center'
@@ -312,8 +322,10 @@ def load(self):
312
322
self .x_dim = h .variables [x_name ].dimensions
313
323
self .y_dim = h .variables [y_name ].dimensions
314
324
315
- self .vars [x_name ] = h .variables [x_name ][:]
316
- self .vars [y_name ] = h .variables [y_name ][:]
325
+ sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
326
+ sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
327
+ self .vars [x_name ] = h .variables [x_name ][sl_x ]
328
+ self .vars [y_name ] = h .variables [y_name ][sl_y ]
317
329
318
330
if self .is_centered :
319
331
logger .info ('Grid center' )
@@ -382,16 +394,18 @@ def copy(self, grid_in, grid_out):
382
394
)
383
395
self .vars [grid_out ] = self .grid (grid_in ).copy ()
384
396
385
- def grid (self , varname ):
397
+ def grid (self , varname , indexs = None ):
386
398
"""give grid required
387
399
"""
400
+ if indexs is None :
401
+ indexs = dict ()
388
402
if varname not in self .vars :
389
403
coordinates_dims = list (self .x_dim )
390
404
coordinates_dims .extend (list (self .y_dim ))
391
405
logger .debug ('Load %(varname)s from %(filename)s' , dict (varname = varname , filename = self .filename ))
392
406
with Dataset (self .filename ) as h :
393
407
dims = h .variables [varname ].dimensions
394
- sl = [slice (None ) if dim in coordinates_dims else 0 for dim in dims ]
408
+ sl = [indexs . get ( dim , self . indexs . get ( dim , slice (None ) if dim in coordinates_dims else 0 )) for dim in dims ]
395
409
self .vars [varname ] = h .variables [varname ][sl ]
396
410
if len (self .x_dim ) == 1 :
397
411
i_x = where (array (dims ) == self .x_dim )[0 ][0 ]
@@ -483,6 +497,10 @@ def eddy_identification(self, grid_height, uname, vname, date, step=0.005, shape
483
497
484
498
# Get h grid
485
499
data = self .grid (grid_height ).astype ('f8' )
500
+ # In case of a reduce mask
501
+ if len (data .mask .shape ) == 0 and not data .mask :
502
+ data .mask = zeros (data .shape , dtype = 'bool' )
503
+ # we remove noisy information
486
504
if precision is not None :
487
505
data = (data / precision ).round () * precision
488
506
# Compute levels for ssh
@@ -753,6 +771,24 @@ class UnRegularGridDataset(GridDataset):
753
771
'_speed_norm' ,
754
772
)
755
773
774
+ def load (self ):
775
+ """Load variable (data)
776
+ """
777
+ x_name , y_name = self .coordinates
778
+ with Dataset (self .filename ) as h :
779
+ self .x_dim = h .variables [x_name ].dimensions
780
+ self .y_dim = h .variables [y_name ].dimensions
781
+
782
+ sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
783
+ sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
784
+ self .vars [x_name ] = h .variables [x_name ][sl_x ]
785
+ self .vars [y_name ] = h .variables [y_name ][sl_y ]
786
+
787
+ self .x_c = self .vars [x_name ]
788
+ self .y_c = self .vars [y_name ]
789
+
790
+ self .init_pos_interpolator ()
791
+
756
792
def bbox_indice (self , vertices ):
757
793
dist , idx = self .index_interp .query (vertices , k = 1 )
758
794
i_y = idx % self .x_c .shape [1 ]
@@ -761,8 +797,9 @@ def bbox_indice(self, vertices):
761
797
762
798
def get_pixels_in (self , contour ):
763
799
(x_start , x_stop ), (y_start , y_stop ) = contour .bbox_slice
764
- pts = array ((self .x_c [x_start :x_stop , y_start :x_stop ].reshape (- 1 ),
765
- self .y_c [x_start :y_stop , y_start :y_stop ].reshape (- 1 ))).T
800
+ pts = array ((self .x_c [x_start :x_stop , y_start :y_stop ].reshape (- 1 ),
801
+ self .y_c [x_start :x_stop , y_start :y_stop ].reshape (- 1 ))).T
802
+ x_stop = min (x_stop , self .x_c .shape [0 ])
766
803
mask = contour .contains_points (pts ).reshape ((x_stop - x_start , - 1 ))
767
804
i_x , i_y = where (mask )
768
805
i_x += x_start
@@ -785,10 +822,8 @@ def compute_pixel_path(self, x0, y0, x1, y1):
785
822
def init_pos_interpolator (self ):
786
823
logger .debug ('Create a KdTree could be long ...' )
787
824
self .index_interp = cKDTree (
788
- uniform_resample_stack ((
789
- self .x_c .reshape (- 1 ),
790
- self .y_c .reshape (- 1 )
791
- )))
825
+ prepare_for_kdtree (self .x_c .reshape (- 1 ), self .y_c .reshape (- 1 )))
826
+
792
827
logger .debug ('... OK' )
793
828
794
829
def _low_filter (self , grid_name , x_cut , y_cut , factor = 40. ):
0 commit comments