@@ -55,6 +55,14 @@ def lat(self):
5555BasePath .lat = lat
5656
5757
58+ @njit (cache = True )
59+ def prepare_for_kdtree (x_val , y_val ):
60+ data = empty ((x_val .shape [0 ], 2 ))
61+ data [:, 0 ] = x_val
62+ data [:, 1 ] = y_val
63+ return data
64+
65+
5866@njit (cache = True )
5967def uniform_resample_stack (vertices , num_fac = 2 , fixed_size = None ):
6068 x_val , y_val = vertices [:, 0 ], vertices [:, 1 ]
@@ -196,6 +204,7 @@ class GridDataset(object):
196204 'coordinates' ,
197205 'filename' ,
198206 'dimensions' ,
207+ 'indexs' ,
199208 'variables_description' ,
200209 'global_attrs' ,
201210 'vars' ,
@@ -209,7 +218,7 @@ class GridDataset(object):
209218 # EARTH_RADIUS = 6378136.3
210219 N = 1
211220
212- def __init__ (self , filename , x_name , y_name , centered = None ):
221+ def __init__ (self , filename , x_name , y_name , centered = None , indexs = None ):
213222 self .dimensions = None
214223 self .variables_description = None
215224 self .global_attrs = None
@@ -226,6 +235,7 @@ def __init__(self, filename, x_name, y_name, centered=None):
226235 self .filename = filename
227236 self .coordinates = x_name , y_name
228237 self .vars = dict ()
238+ self .indexs = None if indexs is None else indexs
229239 self .interpolators = dict ()
230240 if centered is None :
231241 logger .warning ('We assume the position of grid is the center'
@@ -312,8 +322,10 @@ def load(self):
312322 self .x_dim = h .variables [x_name ].dimensions
313323 self .y_dim = h .variables [y_name ].dimensions
314324
315- self .vars [x_name ] = h .variables [x_name ][:]
316- self .vars [y_name ] = h .variables [y_name ][:]
325+ sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
326+ sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
327+ self .vars [x_name ] = h .variables [x_name ][sl_x ]
328+ self .vars [y_name ] = h .variables [y_name ][sl_y ]
317329
318330 if self .is_centered :
319331 logger .info ('Grid center' )
@@ -382,16 +394,18 @@ def copy(self, grid_in, grid_out):
382394 )
383395 self .vars [grid_out ] = self .grid (grid_in ).copy ()
384396
385- def grid (self , varname ):
397+ def grid (self , varname , indexs = None ):
386398 """give grid required
387399 """
400+ if indexs is None :
401+ indexs = dict ()
388402 if varname not in self .vars :
389403 coordinates_dims = list (self .x_dim )
390404 coordinates_dims .extend (list (self .y_dim ))
391405 logger .debug ('Load %(varname)s from %(filename)s' , dict (varname = varname , filename = self .filename ))
392406 with Dataset (self .filename ) as h :
393407 dims = h .variables [varname ].dimensions
394- sl = [slice (None ) if dim in coordinates_dims else 0 for dim in dims ]
408+ sl = [indexs . get ( dim , self . indexs . get ( dim , slice (None ) if dim in coordinates_dims else 0 )) for dim in dims ]
395409 self .vars [varname ] = h .variables [varname ][sl ]
396410 if len (self .x_dim ) == 1 :
397411 i_x = where (array (dims ) == self .x_dim )[0 ][0 ]
@@ -483,6 +497,10 @@ def eddy_identification(self, grid_height, uname, vname, date, step=0.005, shape
483497
484498 # Get h grid
485499 data = self .grid (grid_height ).astype ('f8' )
500+ # In case of a reduce mask
501+ if len (data .mask .shape ) == 0 and not data .mask :
502+ data .mask = zeros (data .shape , dtype = 'bool' )
503+ # we remove noisy information
486504 if precision is not None :
487505 data = (data / precision ).round () * precision
488506 # Compute levels for ssh
@@ -753,6 +771,24 @@ class UnRegularGridDataset(GridDataset):
753771 '_speed_norm' ,
754772 )
755773
774+ def load (self ):
775+ """Load variable (data)
776+ """
777+ x_name , y_name = self .coordinates
778+ with Dataset (self .filename ) as h :
779+ self .x_dim = h .variables [x_name ].dimensions
780+ self .y_dim = h .variables [y_name ].dimensions
781+
782+ sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
783+ sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
784+ self .vars [x_name ] = h .variables [x_name ][sl_x ]
785+ self .vars [y_name ] = h .variables [y_name ][sl_y ]
786+
787+ self .x_c = self .vars [x_name ]
788+ self .y_c = self .vars [y_name ]
789+
790+ self .init_pos_interpolator ()
791+
756792 def bbox_indice (self , vertices ):
757793 dist , idx = self .index_interp .query (vertices , k = 1 )
758794 i_y = idx % self .x_c .shape [1 ]
@@ -761,8 +797,9 @@ def bbox_indice(self, vertices):
761797
762798 def get_pixels_in (self , contour ):
763799 (x_start , x_stop ), (y_start , y_stop ) = contour .bbox_slice
764- pts = array ((self .x_c [x_start :x_stop , y_start :x_stop ].reshape (- 1 ),
765- self .y_c [x_start :y_stop , y_start :y_stop ].reshape (- 1 ))).T
800+ pts = array ((self .x_c [x_start :x_stop , y_start :y_stop ].reshape (- 1 ),
801+ self .y_c [x_start :x_stop , y_start :y_stop ].reshape (- 1 ))).T
802+ x_stop = min (x_stop , self .x_c .shape [0 ])
766803 mask = contour .contains_points (pts ).reshape ((x_stop - x_start , - 1 ))
767804 i_x , i_y = where (mask )
768805 i_x += x_start
@@ -785,10 +822,8 @@ def compute_pixel_path(self, x0, y0, x1, y1):
785822 def init_pos_interpolator (self ):
786823 logger .debug ('Create a KdTree could be long ...' )
787824 self .index_interp = cKDTree (
788- uniform_resample_stack ((
789- self .x_c .reshape (- 1 ),
790- self .y_c .reshape (- 1 )
791- )))
825+ prepare_for_kdtree (self .x_c .reshape (- 1 ), self .y_c .reshape (- 1 )))
826+
792827 logger .debug ('... OK' )
793828
794829 def _low_filter (self , grid_name , x_cut , y_cut , factor = 40. ):
0 commit comments