44from scipy import spatial
55from dateutil import parser
66from numpy import meshgrid , zeros , array , where , ma , argmin , vstack , ones , \
7- newaxis , sqrt , diff , r_
7+ newaxis , sqrt , diff , r_ , arange
8+ from scipy .interpolate import interp1d
89import logging
910from netCDF4 import Dataset
1011
@@ -33,6 +34,8 @@ class AvisoGrid(BaseData):
3334 'fillval' ,
3435 '_angle' ,
3536 'sla_coeffs' ,
37+ 'xinterp' ,
38+ 'yinterp' ,
3639 'uspd_coeffs' ,
3740 '__lon' ,
3841 '__lat' ,
@@ -50,22 +53,37 @@ def __init__(self, aviso_file, the_domain,
5053 super (AvisoGrid , self ).__init__ ()
5154 logging .info ('Initialising the *AVISO_grid*' )
5255 self .grid_filename = aviso_file
53- self .domain = the_domain
54- self .lonmin = float (lonmin )
55- self .lonmax = float (lonmax )
56- self .latmin = float (latmin )
57- self .latmax = float (latmax )
58- self .grid_filename = aviso_file
59-
56+
6057 self .lon_name = lon_name
6158 self .lat_name = lat_name
6259 self .grid_name = grid_name
6360
6461 self ._lon = self .read_nc (self .lon_name )
6562 self ._lat = self .read_nc (self .lat_name )
63+ if the_domain is None :
64+ self .domain = 'Automatic Domain'
65+ dlon = abs (self ._lon [1 ] - self ._lon [0 ])
66+ dlat = abs (self ._lat [1 ] - self ._lat [0 ])
67+ self .lonmin = float (self ._lon .min ()) + dlon * 2
68+ self .lonmax = float (self ._lon .max ()) - dlon * 2
69+ self .latmin = float (self ._lat .min ()) + dlat * 2
70+ self .latmax = float (self ._lat .max ()) - dlat * 2
71+ if ((self ._lon [- 1 ] + dlon ) % 360 ) == self ._lon [0 ]:
72+ self .domain = 'Global'
73+ self .lonmin = - 100.
74+ self .lonmax = 290.
75+ self .latmin = - 80.
76+ self .latmax = 80.
77+ else :
78+ self .domain = the_domain
79+ self .lonmin = float (lonmin )
80+ self .lonmax = float (lonmax )
81+ self .latmin = float (latmin )
82+ self .latmax = float (latmax )
83+
6684 self .fillval = self .read_nc_att (self .grid_name , '_FillValue' )
6785
68- if lonmin < 0 and lonmax <= 0 :
86+ if self . lonmin < 0 and self . lonmax <= 0 :
6987 self ._lon -= 360.
7088 self ._lon , self ._lat = meshgrid (self ._lon , self ._lat )
7189 self ._angle = zeros (self ._lon .shape )
@@ -75,7 +93,7 @@ def __init__(self, aviso_file, the_domain,
7593
7694 # zero_crossing, used for handling a longitude range that
7795 # crosses zero degree meridian
78- if lonmin < 0 and lonmax >= 0 and 'MedSea' not in self .domain :
96+ if self . lonmin < 0 and self . lonmax >= 0 and 'MedSea' not in self .domain :
7997 if ((self .lonmax < self ._lon .max ()) and (self .lonmax > self ._lon .min ()) and (self .lonmin < self ._lon .max ()) and (self .lonmin > self ._lon .min ())):
8098 pass
8199 else :
@@ -92,9 +110,15 @@ def __init__(self, aviso_file, the_domain,
92110 self .get_aviso_f_pm_pn ()
93111 self .set_u_v_eke ()
94112 self .shape = self .lon .shape
95- # pad2 = 2 * self.pad
96- # self.shape = (self.f_coriolis.shape[0] - pad2,
97- # self.f_coriolis.shape[1] - pad2)
113+
114+ # self.init_pos_interpolator()
115+
116+ def init_pos_interpolator (self ):
117+ self .xinterp = interp1d (self .lon [0 ].copy (), arange (self .lon .shape [1 ]), assume_sorted = True , copy = False , fill_value = (0 , - 1 ), bounds_error = False , kind = 'nearest' )
118+ self .yinterp = interp1d (self .lat [:, 0 ].copy (), arange (self .lon .shape [0 ]), assume_sorted = True , copy = False , fill_value = (0 , - 1 ), bounds_error = False , kind = 'nearest' )
119+
120+ def nearest_indice (self , lon , lat ):
121+ return self .xinterp (lon ), self .yinterp (lat )
98122
99123 def set_filename (self , file_name ):
100124 self .grid_filename = file_name
@@ -110,7 +134,7 @@ def get_aviso_data(self, aviso_file, dimensions=None):
110134 if units not in self .KNOWN_UNITS :
111135 raise Exception ('Unknown units : %s' % units )
112136
113- with Dataset (self .grid_filename ) as h_nc :
137+ with Dataset (self .grid_filename . decode ( 'utf-8' ) ) as h_nc :
114138 grid_dims = array (h_nc .variables [self .grid_name ].dimensions )
115139 lat_dim = h_nc .variables [self .lat_name ].dimensions [0 ]
116140 lon_dim = h_nc .variables [self .lon_name ].dimensions [0 ]
@@ -159,44 +183,6 @@ def set_mask(self, sla):
159183 sea_label = self .labels [plus9 , plus200 ]
160184 self .mask += self .labels != sea_label
161185
162- def fillmask (self , data , mask ):
163- """
164- Fill missing values in an array with an average of nearest
165- neighbours
166- From http://permalink.gmane.org/gmane.comp.python.scientific.user/19610
167- """
168- raise Exception ('Use convolution to fill data' )
169- assert data .ndim == 2 , 'data must be a 2D array.'
170- fill_value = 9999.99
171- data [mask == 0 ] = fill_value
172-
173- # Create (i, j) point arrays for good and bad data.
174- # Bad data are marked by the fill_value, good data elsewhere.
175- igood = vstack (where (data != fill_value )).T
176- ibad = vstack (where (data == fill_value )).T
177-
178- # Create a tree for the bad points, the points to be filled
179- tree = spatial .cKDTree (igood )
180-
181- # Get the four closest points to the bad points
182- # here, distance is squared
183- dist , iquery = tree .query (ibad , k = 4 , p = 2 )
184-
185- # Create a normalised weight, the nearest points are weighted as 1.
186- # Points greater than one are then set to zero
187- weight = dist / (dist .min (axis = 1 )[:, newaxis ])
188- weight *= ones (dist .shape )
189- weight [weight > 1. ] = 0.
190-
191- # Multiply the queried good points by the weight, selecting only the
192- # nearest points. Divide by the number of nearest points to get average
193- xfill = weight * data [igood [:, 0 ][iquery ], igood [:, 1 ][iquery ]]
194- xfill = (xfill / weight .sum (axis = 1 )[:, newaxis ]).sum (axis = 1 )
195-
196- # Place average of nearest good points, xfill, into bad point locations
197- data [ibad [:, 0 ], ibad [:, 1 ]] = xfill
198- return data
199-
200186 @property
201187 def lon (self ):
202188 if self .__lon is None :
0 commit comments