44from  scipy  import  spatial 
55from  dateutil  import  parser 
66from  numpy  import  meshgrid , zeros , array , where , ma , argmin , vstack , ones , \
7-     newaxis , sqrt , diff , r_ 
7+     newaxis , sqrt , diff , r_ , arange 
8+ from  scipy .interpolate   import  interp1d 
89import  logging 
910from  netCDF4  import  Dataset 
1011
@@ -33,6 +34,8 @@ class AvisoGrid(BaseData):
3334        'fillval' ,
3435        '_angle' ,
3536        'sla_coeffs' ,
37+         'xinterp' ,
38+         'yinterp' ,
3639        'uspd_coeffs' ,
3740        '__lon' ,
3841        '__lat' ,
@@ -50,22 +53,37 @@ def __init__(self, aviso_file, the_domain,
5053        super (AvisoGrid , self ).__init__ ()
5154        logging .info ('Initialising the *AVISO_grid*' )
5255        self .grid_filename  =  aviso_file 
53-         self .domain  =  the_domain 
54-         self .lonmin  =  float (lonmin )
55-         self .lonmax  =  float (lonmax )
56-         self .latmin  =  float (latmin )
57-         self .latmax  =  float (latmax )
58-         self .grid_filename  =  aviso_file 
59-         
56+ 
6057        self .lon_name  =  lon_name 
6158        self .lat_name  =  lat_name 
6259        self .grid_name  =  grid_name 
6360
6461        self ._lon  =  self .read_nc (self .lon_name )
6562        self ._lat  =  self .read_nc (self .lat_name )
63+         if  the_domain  is  None :
64+             self .domain  =  'Automatic Domain' 
65+             dlon  =  abs (self ._lon [1 ] -  self ._lon [0 ])
66+             dlat  =  abs (self ._lat [1 ] -  self ._lat [0 ])
67+             self .lonmin  =  float (self ._lon .min ()) +  dlon  *  2 
68+             self .lonmax  =  float (self ._lon .max ()) -  dlon  *  2 
69+             self .latmin  =  float (self ._lat .min ()) +  dlat  *  2 
70+             self .latmax  =  float (self ._lat .max ()) -  dlat  *  2 
71+             if  ((self ._lon [- 1 ] +  dlon ) %  360 ) ==  self ._lon [0 ]:
72+                 self .domain  =  'Global' 
73+                 self .lonmin  =  - 100. 
74+                 self .lonmax  =  290. 
75+                 self .latmin  =  - 80. 
76+                 self .latmax  =  80. 
77+         else :
78+             self .domain  =  the_domain 
79+             self .lonmin  =  float (lonmin )
80+             self .lonmax  =  float (lonmax )
81+             self .latmin  =  float (latmin )
82+             self .latmax  =  float (latmax )
83+ 
6684        self .fillval  =  self .read_nc_att (self .grid_name , '_FillValue' )
6785
68-         if  lonmin  <  0  and  lonmax  <=  0 :
86+         if  self . lonmin  <  0  and  self . lonmax  <=  0 :
6987            self ._lon  -=  360. 
7088        self ._lon , self ._lat  =  meshgrid (self ._lon , self ._lat )
7189        self ._angle  =  zeros (self ._lon .shape )
@@ -75,7 +93,7 @@ def __init__(self, aviso_file, the_domain,
7593
7694        # zero_crossing, used for handling a longitude range that 
7795        # crosses zero degree meridian 
78-         if  lonmin  <  0  and  lonmax  >=  0  and  'MedSea'  not  in   self .domain :
96+         if  self . lonmin  <  0  and  self . lonmax  >=  0  and  'MedSea'  not  in   self .domain :
7997            if  ((self .lonmax  <  self ._lon .max ()) and  (self .lonmax  >  self ._lon .min ()) and  (self .lonmin  <  self ._lon .max ()) and  (self .lonmin  >  self ._lon .min ())):
8098                pass 
8199            else :
@@ -92,9 +110,15 @@ def __init__(self, aviso_file, the_domain,
92110        self .get_aviso_f_pm_pn ()
93111        self .set_u_v_eke ()
94112        self .shape  =  self .lon .shape 
95- #         pad2 = 2 * self.pad 
96- #         self.shape = (self.f_coriolis.shape[0] - pad2, 
97- #                       self.f_coriolis.shape[1] - pad2) 
113+ 
114+         # self.init_pos_interpolator() 
115+ 
116+     def  init_pos_interpolator (self ):
117+         self .xinterp  =  interp1d (self .lon [0 ].copy (), arange (self .lon .shape [1 ]), assume_sorted = True , copy = False , fill_value = (0 , - 1 ), bounds_error = False , kind = 'nearest' )
118+         self .yinterp  =  interp1d (self .lat [:, 0 ].copy (), arange (self .lon .shape [0 ]), assume_sorted = True , copy = False , fill_value = (0 , - 1 ), bounds_error = False , kind = 'nearest' )
119+ 
120+     def  nearest_indice (self , lon , lat ):
121+         return  self .xinterp (lon ), self .yinterp (lat )
98122
99123    def  set_filename (self , file_name ):
100124        self .grid_filename  =  file_name 
@@ -110,7 +134,7 @@ def get_aviso_data(self, aviso_file, dimensions=None):
110134        if  units  not  in   self .KNOWN_UNITS :
111135            raise  Exception ('Unknown units : %s'  %  units )
112136
113-         with  Dataset (self .grid_filename ) as  h_nc :
137+         with  Dataset (self .grid_filename . decode ( 'utf-8' ) ) as  h_nc :
114138            grid_dims  =  array (h_nc .variables [self .grid_name ].dimensions )
115139            lat_dim  =  h_nc .variables [self .lat_name ].dimensions [0 ]
116140            lon_dim  =  h_nc .variables [self .lon_name ].dimensions [0 ]
@@ -159,44 +183,6 @@ def set_mask(self, sla):
159183                sea_label  =  self .labels [plus9 , plus200 ]
160184                self .mask  +=  self .labels  !=  sea_label 
161185
162-     def  fillmask (self , data , mask ):
163-         """ 
164-         Fill missing values in an array with an average of nearest 
165-         neighbours 
166-         From http://permalink.gmane.org/gmane.comp.python.scientific.user/19610 
167-         """ 
168-         raise  Exception ('Use convolution to fill data' )
169-         assert  data .ndim  ==  2 , 'data must be a 2D array.' 
170-         fill_value  =  9999.99 
171-         data [mask  ==  0 ] =  fill_value 
172- 
173-         # Create (i, j) point arrays for good and bad data. 
174-         # Bad data are marked by the fill_value, good data elsewhere. 
175-         igood  =  vstack (where (data  !=  fill_value )).T 
176-         ibad  =  vstack (where (data  ==  fill_value )).T 
177- 
178-         # Create a tree for the bad points, the points to be filled 
179-         tree  =  spatial .cKDTree (igood )
180- 
181-         # Get the four closest points to the bad points 
182-         # here, distance is squared 
183-         dist , iquery  =  tree .query (ibad , k = 4 , p = 2 )
184- 
185-         # Create a normalised weight, the nearest points are weighted as 1. 
186-         #   Points greater than one are then set to zero 
187-         weight  =  dist  /  (dist .min (axis = 1 )[:, newaxis ])
188-         weight  *=  ones (dist .shape )
189-         weight [weight  >  1. ] =  0. 
190- 
191-         # Multiply the queried good points by the weight, selecting only the 
192-         # nearest points. Divide by the number of nearest points to get average 
193-         xfill  =  weight  *  data [igood [:, 0 ][iquery ], igood [:, 1 ][iquery ]]
194-         xfill  =  (xfill  /  weight .sum (axis = 1 )[:, newaxis ]).sum (axis = 1 )
195- 
196-         # Place average of nearest good points, xfill, into bad point locations 
197-         data [ibad [:, 0 ], ibad [:, 1 ]] =  xfill 
198-         return  data 
199- 
200186    @property  
201187    def  lon (self ):
202188        if  self .__lon  is  None :
0 commit comments