@@ -253,6 +253,7 @@ class GridDataset(object):
253253 "filename" ,
254254 "dimensions" ,
255255 "indexs" ,
256+ "nc4file" ,
256257 "variables_description" ,
257258 "global_attrs" ,
258259 "vars" ,
@@ -275,6 +276,7 @@ def __init__(
275276 indexs = None ,
276277 unset = False ,
277278 nan_masking = False ,
279+ nc4file = None ,
278280 ):
279281 """
280282 :param str filename: Filename to load
@@ -301,6 +303,7 @@ def __init__(
301303 self .coordinates = x_name , y_name
302304 self .vars = dict ()
303305 self .indexs = dict () if indexs is None else indexs
306+ self .nc4file = Dataset (filename , "r" ) if nc4file is None else nc4file
304307 if centered is None :
305308 logger .warning (
306309 "We assume pixel position of grid is centered for %s" , filename
@@ -344,25 +347,25 @@ def load_general_features(self):
344347 logger .debug (
345348 "Load general feature from %(filename)s" , dict (filename = self .filename )
346349 )
347- with Dataset ( self .filename ) as h :
348- # Load generals
349- self .dimensions = {i : len (v ) for i , v in h .dimensions .items ()}
350- self .variables_description = dict ()
351- for i , v in h .variables .items ():
352- args = (i , v .datatype )
353- kwargs = dict (dimensions = v .dimensions , zlib = True )
354- if hasattr (v , "_FillValue" ):
355- kwargs ["fill_value" ] = (v ._FillValue ,)
356- attrs = dict ()
357- for attr in v .ncattrs ():
358- if attr in kwargs .keys ():
359- continue
360- if attr == "_FillValue" :
361- continue
362- attrs [attr ] = getattr (v , attr )
363- self .variables_description [i ] = dict (
364- args = args , kwargs = kwargs , attrs = attrs , infos = dict ()
365- )
350+ h = self .nc4file
351+ # Load generals
352+ self .dimensions = {i : len (v ) for i , v in h .dimensions .items ()}
353+ self .variables_description = dict ()
354+ for i , v in h .variables .items ():
355+ args = (i , v .datatype )
356+ kwargs = dict (dimensions = v .dimensions , zlib = True )
357+ if hasattr (v , "_FillValue" ):
358+ kwargs ["fill_value" ] = (v ._FillValue ,)
359+ attrs = dict ()
360+ for attr in v .ncattrs ():
361+ if attr in kwargs .keys ():
362+ continue
363+ if attr == "_FillValue" :
364+ continue
365+ attrs [attr ] = getattr (v , attr )
366+ self .variables_description [i ] = dict (
367+ args = args , kwargs = kwargs , attrs = attrs , infos = dict ()
368+ )
366369 self .global_attrs = {attr : getattr (h , attr ) for attr in h .ncattrs ()}
367370
368371 def write (self , filename ):
@@ -407,14 +410,14 @@ def load(self):
407410 Get coordinates and setup coordinates function
408411 """
409412 x_name , y_name = self .coordinates
410- with Dataset ( self .filename ) as h :
411- self .x_dim = h .variables [x_name ].dimensions
412- self .y_dim = h .variables [y_name ].dimensions
413+ h = self .nc4file
414+ self .x_dim = h .variables [x_name ].dimensions
415+ self .y_dim = h .variables [y_name ].dimensions
413416
414- sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
415- sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
416- self .vars [x_name ] = h .variables [x_name ][sl_x ]
417- self .vars [y_name ] = h .variables [y_name ][sl_y ]
417+ sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
418+ sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
419+ self .vars [x_name ] = h .variables [x_name ][sl_x ]
420+ self .vars [y_name ] = h .variables [y_name ][sl_y ]
418421
419422 self .setup_coordinates ()
420423
@@ -481,10 +484,10 @@ def units(self, varname):
481484 stored_units = self .variables_description [varname ]["attrs" ].get ("units" , None )
482485 if stored_units is not None :
483486 return stored_units
484- with Dataset ( self .filename ) as h :
485- var = h .variables [varname ]
486- if hasattr (var , "units" ):
487- return var .units
487+ h = self .nc4file
488+ var = h .variables [varname ]
489+ if hasattr (var , "units" ):
490+ return var .units
488491
489492 @property
490493 def variables (self ):
@@ -535,24 +538,24 @@ def grid(self, varname, indexs=None):
535538 "Load %(varname)s from %(filename)s" ,
536539 dict (varname = varname , filename = self .filename ),
537540 )
538- with Dataset ( self .filename ) as h :
539- dims = h .variables [varname ].dimensions
540- sl = [
541- indexs .get (
542- dim ,
543- self .indexs .get (
544- dim , slice (None ) if dim in coordinates_dims else 0
545- ),
546- )
547- for dim in dims
548- ]
549- self .vars [varname ] = h .variables [varname ][sl ]
550- if len (self .x_dim ) == 1 :
551- i_x = where (array (dims ) == self .x_dim )[0 ][0 ]
552- i_y = where (array (dims ) == self .y_dim )[0 ][0 ]
553- if i_x > i_y :
554- self .variables_description [varname ]["infos" ]["transpose" ] = True
555- self .vars [varname ] = self .vars [varname ].T
541+ h = self .nc4file
542+ dims = h .variables [varname ].dimensions
543+ sl = [
544+ indexs .get (
545+ dim ,
546+ self .indexs .get (
547+ dim , slice (None ) if dim in coordinates_dims else 0
548+ ),
549+ )
550+ for dim in dims
551+ ]
552+ self .vars [varname ] = h .variables [varname ][sl ]
553+ if len (self .x_dim ) == 1 :
554+ i_x = where (array (dims ) == self .x_dim )[0 ][0 ]
555+ i_y = where (array (dims ) == self .y_dim )[0 ][0 ]
556+ if i_x > i_y :
557+ self .variables_description [varname ]["infos" ]["transpose" ] = True
558+ self .vars [varname ] = self .vars [varname ].T
556559 if self .nan_mask :
557560 self .vars [varname ] = ma .array (
558561 self .vars [varname ],
@@ -578,20 +581,20 @@ def grid_tiles(self, varname, slice_x, slice_y):
578581 slice_x = slice_x ,
579582 ),
580583 )
581- with Dataset ( self .filename ) as h :
582- dims = h .variables [varname ].dimensions
583- sl = [
584- (slice_x if dim in list (self .x_dim ) else slice_y )
585- if dim in coordinates_dims
586- else 0
587- for dim in dims
588- ]
589- data = h .variables [varname ][sl ]
590- if len (self .x_dim ) == 1 :
591- i_x = where (array (dims ) == self .x_dim )[0 ][0 ]
592- i_y = where (array (dims ) == self .y_dim )[0 ][0 ]
593- if i_x > i_y :
594- data = data .T
584+ h = self .nc4file
585+ dims = h .variables [varname ].dimensions
586+ sl = [
587+ (slice_x if dim in list (self .x_dim ) else slice_y )
588+ if dim in coordinates_dims
589+ else 0
590+ for dim in dims
591+ ]
592+ data = h .variables [varname ][sl ]
593+ if len (self .x_dim ) == 1 :
594+ i_x = where (array (dims ) == self .x_dim )[0 ][0 ]
595+ i_y = where (array (dims ) == self .y_dim )[0 ][0 ]
596+ if i_x > i_y :
597+ data = data .T
595598 if not hasattr (data , "mask" ):
596599 data = ma .array (data , mask = zeros (data .shape , dtype = "bool" ))
597600 return data
@@ -1086,19 +1089,19 @@ class UnRegularGridDataset(GridDataset):
10861089 def load (self ):
10871090 """Load variable (data)"""
10881091 x_name , y_name = self .coordinates
1089- with Dataset ( self .filename ) as h :
1090- self .x_dim = h .variables [x_name ].dimensions
1091- self .y_dim = h .variables [y_name ].dimensions
1092+ h = self .nc4file
1093+ self .x_dim = h .variables [x_name ].dimensions
1094+ self .y_dim = h .variables [y_name ].dimensions
10921095
1093- sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
1094- sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
1095- self .vars [x_name ] = h .variables [x_name ][sl_x ]
1096- self .vars [y_name ] = h .variables [y_name ][sl_y ]
1096+ sl_x = [self .indexs .get (dim , slice (None )) for dim in self .x_dim ]
1097+ sl_y = [self .indexs .get (dim , slice (None )) for dim in self .y_dim ]
1098+ self .vars [x_name ] = h .variables [x_name ][sl_x ]
1099+ self .vars [y_name ] = h .variables [y_name ][sl_y ]
10971100
1098- self .x_c = self .vars [x_name ]
1099- self .y_c = self .vars [y_name ]
1101+ self .x_c = self .vars [x_name ]
1102+ self .y_c = self .vars [y_name ]
11001103
1101- self .init_pos_interpolator ()
1104+ self .init_pos_interpolator ()
11021105
11031106 @property
11041107 def bounds (self ):
0 commit comments