3131from scipy import spatial
3232from dateutil import parser
3333from numpy import meshgrid , zeros , array , where , ma , argmin , vstack , ones , \
34- newaxis , sqrt , diff , r_
34+ newaxis , sqrt , diff , r_ , arange
35+ from scipy .interpolate import interp1d
3536import logging
3637from netCDF4 import Dataset
3738
@@ -60,6 +61,8 @@ class AvisoGrid(BaseData):
6061 'fillval' ,
6162 '_angle' ,
6263 'sla_coeffs' ,
64+ 'xinterp' ,
65+ 'yinterp' ,
6366 'uspd_coeffs' ,
6467 '__lon' ,
6568 '__lat' ,
@@ -77,22 +80,37 @@ def __init__(self, aviso_file, the_domain,
7780 super (AvisoGrid , self ).__init__ ()
7881 logging .info ('Initialising the *AVISO_grid*' )
7982 self .grid_filename = aviso_file
80- self .domain = the_domain
81- self .lonmin = float (lonmin )
82- self .lonmax = float (lonmax )
83- self .latmin = float (latmin )
84- self .latmax = float (latmax )
85- self .grid_filename = aviso_file
86-
83+
8784 self .lon_name = lon_name
8885 self .lat_name = lat_name
8986 self .grid_name = grid_name
9087
9188 self ._lon = self .read_nc (self .lon_name )
9289 self ._lat = self .read_nc (self .lat_name )
90+ if the_domain is None :
91+ self .domain = 'Automatic Domain'
92+ dlon = abs (self ._lon [1 ] - self ._lon [0 ])
93+ dlat = abs (self ._lat [1 ] - self ._lat [0 ])
94+ self .lonmin = float (self ._lon .min ()) + dlon * 2
95+ self .lonmax = float (self ._lon .max ()) - dlon * 2
96+ self .latmin = float (self ._lat .min ()) + dlat * 2
97+ self .latmax = float (self ._lat .max ()) - dlat * 2
98+ if ((self ._lon [- 1 ] + dlon ) % 360 ) == self ._lon [0 ]:
99+ self .domain = 'Global'
100+ self .lonmin = - 100.
101+ self .lonmax = 290.
102+ self .latmin = - 80.
103+ self .latmax = 80.
104+ else :
105+ self .domain = the_domain
106+ self .lonmin = float (lonmin )
107+ self .lonmax = float (lonmax )
108+ self .latmin = float (latmin )
109+ self .latmax = float (latmax )
110+
93111 self .fillval = self .read_nc_att (self .grid_name , '_FillValue' )
94112
95- if lonmin < 0 and lonmax <= 0 :
113+ if self . lonmin < 0 and self . lonmax <= 0 :
96114 self ._lon -= 360.
97115 self ._lon , self ._lat = meshgrid (self ._lon , self ._lat )
98116 self ._angle = zeros (self ._lon .shape )
@@ -102,7 +120,7 @@ def __init__(self, aviso_file, the_domain,
102120
103121 # zero_crossing, used for handling a longitude range that
104122 # crosses zero degree meridian
105- if lonmin < 0 and lonmax >= 0 and 'MedSea' not in self .domain :
123+ if self . lonmin < 0 and self . lonmax >= 0 and 'MedSea' not in self .domain :
106124 if ((self .lonmax < self ._lon .max ()) and (self .lonmax > self ._lon .min ()) and (self .lonmin < self ._lon .max ()) and (self .lonmin > self ._lon .min ())):
107125 pass
108126 else :
@@ -119,9 +137,15 @@ def __init__(self, aviso_file, the_domain,
119137 self .get_aviso_f_pm_pn ()
120138 self .set_u_v_eke ()
121139 self .shape = self .lon .shape
122- # pad2 = 2 * self.pad
123- # self.shape = (self.f_coriolis.shape[0] - pad2,
124- # self.f_coriolis.shape[1] - pad2)
140+
141+ # self.init_pos_interpolator()
142+
143+ def init_pos_interpolator (self ):
144+ self .xinterp = interp1d (self .lon [0 ].copy (), arange (self .lon .shape [1 ]), assume_sorted = True , copy = False , fill_value = (0 , - 1 ), bounds_error = False , kind = 'nearest' )
145+ self .yinterp = interp1d (self .lat [:, 0 ].copy (), arange (self .lon .shape [0 ]), assume_sorted = True , copy = False , fill_value = (0 , - 1 ), bounds_error = False , kind = 'nearest' )
146+
147+ def nearest_indice (self , lon , lat ):
148+ return self .xinterp (lon ), self .yinterp (lat )
125149
126150 def set_filename (self , file_name ):
127151 self .grid_filename = file_name
@@ -137,7 +161,7 @@ def get_aviso_data(self, aviso_file, dimensions=None):
137161 if units not in self .KNOWN_UNITS :
138162 raise Exception ('Unknown units : %s' % units )
139163
140- with Dataset (self .grid_filename ) as h_nc :
164+ with Dataset (self .grid_filename . decode ( 'utf-8' ) ) as h_nc :
141165 grid_dims = array (h_nc .variables [self .grid_name ].dimensions )
142166 lat_dim = h_nc .variables [self .lat_name ].dimensions [0 ]
143167 lon_dim = h_nc .variables [self .lon_name ].dimensions [0 ]
@@ -186,44 +210,6 @@ def set_mask(self, sla):
186210 sea_label = self .labels [plus9 , plus200 ]
187211 self .mask += self .labels != sea_label
188212
189- def fillmask (self , data , mask ):
190- """
191- Fill missing values in an array with an average of nearest
192- neighbours
193- From http://permalink.gmane.org/gmane.comp.python.scientific.user/19610
194- """
195- raise Exception ('Use convolution to fill data' )
196- assert data .ndim == 2 , 'data must be a 2D array.'
197- fill_value = 9999.99
198- data [mask == 0 ] = fill_value
199-
200- # Create (i, j) point arrays for good and bad data.
201- # Bad data are marked by the fill_value, good data elsewhere.
202- igood = vstack (where (data != fill_value )).T
203- ibad = vstack (where (data == fill_value )).T
204-
205- # Create a tree for the bad points, the points to be filled
206- tree = spatial .cKDTree (igood )
207-
208- # Get the four closest points to the bad points
209- # here, distance is squared
210- dist , iquery = tree .query (ibad , k = 4 , p = 2 )
211-
212- # Create a normalised weight, the nearest points are weighted as 1.
213- # Points greater than one are then set to zero
214- weight = dist / (dist .min (axis = 1 )[:, newaxis ])
215- weight *= ones (dist .shape )
216- weight [weight > 1. ] = 0.
217-
218- # Multiply the queried good points by the weight, selecting only the
219- # nearest points. Divide by the number of nearest points to get average
220- xfill = weight * data [igood [:, 0 ][iquery ], igood [:, 1 ][iquery ]]
221- xfill = (xfill / weight .sum (axis = 1 )[:, newaxis ]).sum (axis = 1 )
222-
223- # Place average of nearest good points, xfill, into bad point locations
224- data [ibad [:, 0 ], ibad [:, 1 ]] = xfill
225- return data
226-
227213 @property
228214 def lon (self ):
229215 if self .__lon is None :
0 commit comments