Skip to content

Commit 191d31d

Browse files
author
adelepoulle
committed
simplify numpy call
1 parent 411148b commit 191d31d

File tree

5 files changed

+97
-92
lines changed

5 files changed

+97
-92
lines changed

src/py_eddy_tracker/grid/__init__.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from scipy import interpolate
44
from scipy import spatial
55
from pyproj import Proj
6-
import numpy as np
6+
from numpy import unique, array, unravel_index, r_, floor, interp, arange, \
7+
sin, cos, deg2rad, arctan2, sqrt, pi, zeros, reciprocal, ma, empty, \
8+
concatenate
79
import logging
810
from ..tracking_objects import nearest
911

@@ -106,8 +108,8 @@ def read_nc_att(self, varname, att):
106108

107109
@property
108110
def is_regular(self):
109-
steps_lon = np.unique(self._lon[0, 1:] - self._lon[0,:-1])
110-
steps_lat = np.unique(self._lat[1:, 0] - self._lat[:-1, 0])
111+
steps_lon = unique(self._lon[0, 1:] - self._lon[0,:-1])
112+
steps_lat = unique(self._lat[1:, 0] - self._lat[:-1, 0])
111113
return len(steps_lon) == 1 and len(steps_lat) == 1 and \
112114
steps_lon[0] != 0. and steps_lat[0] != 0.
113115

@@ -135,14 +137,14 @@ def kdt(lon, lat, limits, k=4):
135137
136138
Don't use cKDTree for regular grid
137139
"""
138-
ppoints = np.array([lon.ravel(), lat.ravel()]).T
140+
ppoints = array([lon.ravel(), lat.ravel()]).T
139141
ptree = spatial.cKDTree(ppoints)
140142
pindices = ptree.query(limits, k=k)[1]
141-
iind, jind = np.array([], dtype=int), np.array([], dtype=int)
143+
iind, jind = array([], dtype=int), array([], dtype=int)
142144
for pind in pindices.ravel():
143-
j, i = np.unravel_index(pind, lon.shape)
144-
iind = np.r_[iind, i]
145-
jind = np.r_[jind, j]
145+
j, i = unravel_index(pind, lon.shape)
146+
iind = r_[iind, i]
147+
jind = r_[jind, j]
146148
return iind, jind
147149

148150
if 'AvisoGrid' in self.__class__.__name__:
@@ -152,16 +154,16 @@ def kdt(lon, lat, limits, k=4):
152154
Used for a zero crossing, e.g., across Agulhas region
153155
"""
154156
if self.is_regular:
155-
i_1 = int(np.floor(np.interp((lonmin - 0.5) % 360,
157+
i_1 = int(floor(interp((lonmin - 0.5) % 360,
156158
self._lon[0],
157-
np.arange(len(self._lon[0])))))
158-
i_0 = int(np.floor(np.interp((lonmax + 0.5) % 360,
159+
arange(len(self._lon[0])))))
160+
i_0 = int(floor(interp((lonmax + 0.5) % 360,
159161
self._lon[0],
160-
np.arange(len(self._lon[0])))
162+
arange(len(self._lon[0])))
161163
) + 1)
162164
else:
163165
def half_limits(lon, lat):
164-
return np.array([[lon.min(), lon.max(),
166+
return array([[lon.min(), lon.max(),
165167
lon.max(), lon.min()],
166168
[lat.min(), lat.min(),
167169
lat.max(), lat.max()]]).T
@@ -230,12 +232,12 @@ def haversine_dist(self, lon1, lat1, lon2, lat2):
230232
Return:
231233
distance (m)
232234
"""
233-
sin_dlat = np.sin(np.deg2rad(lat2 - lat1) * 0.5)
234-
sin_dlon = np.sin(np.deg2rad(lon2 - lon1) * 0.5)
235-
cos_lat1 = np.cos(np.deg2rad(lat1))
236-
cos_lat2 = np.cos(np.deg2rad(lat2))
235+
sin_dlat = sin(deg2rad(lat2 - lat1) * 0.5)
236+
sin_dlon = sin(deg2rad(lon2 - lon1) * 0.5)
237+
cos_lat1 = cos(deg2rad(lat1))
238+
cos_lat2 = cos(deg2rad(lat2))
237239
a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2
238-
c_val = 2 * np.arctan2(np.sqrt(a_val), np.sqrt(1 - a_val))
240+
c_val = 2 * arctan2(sqrt(a_val), sqrt(1 - a_val))
239241
return 6371315.0 * c_val # Return the distance
240242

241243
def nearest_point(self, lon, lat):
@@ -260,7 +262,7 @@ def get_aviso_f_pm_pn(self):
260262
logging.info('--- Computing Coriolis (f), d_x(p_m),'
261263
'd_y (p_n) for padded grid')
262264
# Get GRAVITY / Coriolis
263-
self._gof = np.sin(np.deg2rad(self.latpad)) * 4. * np.pi / 86400.
265+
self._gof = sin(deg2rad(self.latpad)) * 4. * pi / 86400.
264266
self._f_val = self._gof.copy()
265267
self._gof = self.GRAVITY / self._gof
266268

@@ -270,29 +272,29 @@ def get_aviso_f_pm_pn(self):
270272
latv = self.half_interp(self.latpad[:-1], self.latpad[1:])
271273

272274
# Get p_m and p_n
273-
p_m = np.zeros_like(self.lonpad)
275+
p_m = zeros(self.lonpad.shape)
274276
p_m[:, 1:-1] = self.haversine_dist(lonu[:, :-1], latu[:, :-1],
275277
lonu[:, 1:], latu[:, 1:])
276278
p_m[:, 0] = p_m[:, 1]
277279
p_m[:, -1] = p_m[:, -2]
278280
self._dx = p_m
279-
self._pm = np.reciprocal(p_m)
281+
self._pm = reciprocal(p_m)
280282

281-
p_n = np.zeros_like(self.lonpad)
283+
p_n = zeros(self.lonpad.shape)
282284
p_n[1:-1] = self.haversine_dist(lonv[:-1], latv[:-1],
283285
lonv[1:], latv[1:])
284286
p_n[0] = p_n[1]
285287
p_n[-1] = p_n[-2]
286288
self._dy = p_n
287-
self._pn = np.reciprocal(p_n)
289+
self._pn = reciprocal(p_n)
288290
return self
289291

290292
def u2rho_2d(self, uu_in):
291293
"""
292294
Convert a 2D field at u_val points to a field at rho points
293295
"""
294296
def uu2ur(uu_in, m_p, l_p):
295-
u_out = np.zeros((m_p, l_p))
297+
u_out = zeros((m_p, l_p))
296298
u_out[:, 1:-1] = self.half_interp(uu_in[:, :-1], uu_in[:, 1:])
297299
u_out[:, 0] = u_out[:, 1]
298300
u_out[:, -1] = u_out[:, -2]
@@ -303,7 +305,7 @@ def uu2ur(uu_in, m_p, l_p):
303305
def v2rho_2d(self, vv_in):
304306
# Convert a 2D field at v_val points to a field at rho points
305307
def vv2vr(vv_in, m_p, l_p):
306-
v_out = np.zeros((m_p, l_p))
308+
v_out = zeros((m_p, l_p))
307309
v_out[1:-1] = self.half_interp(vv_in[:-1], vv_in[1:])
308310
v_out[0] = v_out[1]
309311
v_out[-1] = v_out[-2]
@@ -372,13 +374,13 @@ def set_geostrophic_velocity(self, zeta):
372374
zeta1, zeta2 = zeta.data[1:].view(), zeta.data[:-1].view()
373375
pn1, pn2 = self.p_n[1:].view(), self.p_n[:-1].view()
374376
self.upad[:] = self.v2rho_2d(
375-
np.ma.array((zeta1 - zeta2) * 0.5 * (pn1 + pn2), mask= self.vmask))
377+
ma.array((zeta1 - zeta2) * 0.5 * (pn1 + pn2), mask= self.vmask))
376378
self.upad *= -self.gof
377379

378380
zeta1, zeta2 = zeta.data[:, 1:].view(), zeta.data[:, :-1].view()
379381
pm1, pm2 = self.p_m[:, 1:].view(), self.p_m[:, :-1].view()
380382
self.vpad[:] = self.u2rho_2d(
381-
np.ma.array((zeta1 - zeta2) *0.5 * (pm1 + pm2), mask=self.umask))
383+
ma.array((zeta1 - zeta2) *0.5 * (pm1 + pm2), mask=self.umask))
382384
self.vpad *= self.gof
383385
return self
384386

@@ -388,13 +390,13 @@ def set_u_v_eke(self, pad=2):
388390
"""
389391
j_size = self.slice_j_pad.stop - self.slice_j_pad.start
390392
if self.zero_crossing:
391-
u_1 = np.empty((j_size, self.slice_i_pad.start))
392-
u_0 = np.empty((j_size, self._lon.shape[1] - self.slice_i_pad.stop))
393-
self.upad = np.ma.concatenate((u_0, u_1), axis=1)
393+
u_1 = empty((j_size, self.slice_i_pad.start))
394+
u_0 = empty((j_size, self._lon.shape[1] - self.slice_i_pad.stop))
395+
self.upad = ma.concatenate((u_0, u_1), axis=1)
394396
else:
395-
self.upad = np.empty((j_size,
397+
self.upad = empty((j_size,
396398
self.slice_i_pad.stop - self.slice_i_pad.start))
397-
self.vpad = np.empty_like(self.upad)
399+
self.vpad = empty(self.upad.shape)
398400

399401
def get_eke(self):
400402
"""
@@ -412,7 +414,7 @@ def uspd(self):
412414
if hasattr(uspd, 'mask'):
413415
uspd.mask += self.mask[self.view_unpad]
414416
else:
415-
uspd = np.ma.array(uspd, mask=self.mask[self.view_unpad])
417+
uspd = ma.array(uspd, mask=self.mask[self.view_unpad])
416418
return uspd
417419

418420
def set_interp_coeffs(self, sla, uspd):
@@ -429,8 +431,8 @@ def set_interp_coeffs(self, sla, uspd):
429431
def create_index_inverse(slice_to_inverse, size):
430432
"""Return an array of index
431433
"""
432-
index = np.concatenate((np.arange(slice_to_inverse.stop, size),
433-
np.arange(slice_to_inverse.start)))
434+
index = concatenate((arange(slice_to_inverse.stop, size),
435+
arange(slice_to_inverse.start)))
434436
return index
435437

436438
def gaussian_resolution(self, zwl, mwl):

src/py_eddy_tracker/grid/aviso.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from scipy import ndimage
44
from scipy import spatial
55
from dateutil import parser
6-
import numpy as np
6+
from numpy import meshgrid, zeros, array, where, ma, argmin, vstack, ones \
7+
newaxis, sqrt, diff, r_
78
import logging
89
from netCDF4 import Dataset
910

@@ -15,6 +16,11 @@ class AvisoGrid(PyEddyTracker):
1516
Class to satisfy the need of the eddy tracker
1617
to have a grid class
1718
"""
19+
KNOWN_UNITS = dict(
20+
m=100.,
21+
cm=1.,
22+
)
23+
1824
def __init__(self, aviso_file, the_domain,
1925
lonmin, lonmax, latmin, latmax, grid_name, lon_name,
2026
lat_name,with_pad=True):
@@ -41,8 +47,8 @@ def __init__(self, aviso_file, the_domain,
4147

4248
if lonmin < 0 and lonmax <= 0:
4349
self._lon -= 360.
44-
self._lon, self._lat = np.meshgrid(self._lon, self._lat)
45-
self._angle = np.zeros_like(self._lon)
50+
self._lon, self._lat = meshgrid(self._lon, self._lat)
51+
self._angle = zeros(self._lon.shape)
4652

4753
if 'MedSea' in self.the_domain:
4854
self._lon -= 360.
@@ -71,23 +77,19 @@ def get_aviso_data(self, aviso_file):
7177
"""
7278
Read nc data from AVISO file
7379
"""
74-
KNOWN_UNITS = dict(
75-
m=100.,
76-
cm=1.,
77-
)
7880
self.grid_filename = aviso_file
7981
units = self.read_nc_att(self.grid_name, 'units')
80-
if units not in KNOWN_UNITS:
82+
if units not in self.KNOWN_UNITS:
8183
raise Exception('Unknown units : %s' % units)
8284

8385
with Dataset(self.grid_filename) as h_nc:
84-
grid_dims = np.array(h_nc.variables[self.grid_name].dimensions)
86+
grid_dims = array(h_nc.variables[self.grid_name].dimensions)
8587
lat_dim = h_nc.variables[self.lat_name].dimensions[0]
8688
lon_dim = h_nc.variables[self.lon_name].dimensions[0]
8789

8890
i_list = []
8991
transpose = False
90-
if np.where(grid_dims == lat_dim)[0][0] > np.where(grid_dims == lon_dim)[0][0]:
92+
if where(grid_dims == lat_dim)[0][0] > where(grid_dims == lon_dim)[0][0]:
9193
transpose = True
9294
for grid_dim in grid_dims:
9395
if grid_dim == lat_dim:
@@ -101,11 +103,11 @@ def get_aviso_data(self, aviso_file):
101103
if transpose:
102104
zeta = zeta.T
103105

104-
zeta *= KNOWN_UNITS[units] # units to cm
106+
zeta *= self.KNOWN_UNITS[units] # units to cm
105107
if hasattr(zeta, 'mask'):
106108
return zeta
107109
else:
108-
return np.ma.array(zeta)
110+
return ma.array(zeta)
109111

110112
def set_mask(self, sla):
111113
"""
@@ -117,7 +119,7 @@ def set_mask(self, sla):
117119
if 'Global' in self.the_domain:
118120

119121
# Close Drake Passage
120-
minus70 = np.argmin(np.abs(self.lonpad[0] + 70))
122+
minus70 = argmin(abs(self.lonpad[0] + 70))
121123
self.mask[:125, minus70] = True
122124

123125
# DT10 mask is open around Panama, so close it...
@@ -150,8 +152,8 @@ def set_mask(self, sla):
150152
self.labels = ndimage.label(-self.mask)[0]
151153

152154
# Set to known sea point
153-
plus200 = np.argmin(np.abs(self.lonpad[0] - 200))
154-
plus9 = np.argmin(np.abs(self.latpad[:, 0] - 9))
155+
plus200 = argmin(abs(self.lonpad[0] - 200))
156+
plus9 = argmin(abs(self.latpad[:, 0] - 9))
155157
sea_label = self.labels[plus9, plus200]
156158
self.mask += self.labels != sea_label
157159
return self
@@ -168,8 +170,8 @@ def fillmask(self, data, mask):
168170

169171
# Create (i, j) point arrays for good and bad data.
170172
# Bad data are marked by the fill_value, good data elsewhere.
171-
igood = np.vstack(np.where(data != fill_value)).T
172-
ibad = np.vstack(np.where(data == fill_value)).T
173+
igood = vstack(where(data != fill_value)).T
174+
ibad = vstack(where(data == fill_value)).T
173175

174176
# Create a tree for the bad points, the points to be filled
175177
tree = spatial.cKDTree(igood)
@@ -180,14 +182,14 @@ def fillmask(self, data, mask):
180182

181183
# Create a normalised weight, the nearest points are weighted as 1.
182184
# Points greater than one are then set to zero
183-
weight = dist / (dist.min(axis=1)[:, np.newaxis])
184-
weight *= np.ones_like(dist)
185-
np.place(weight, weight > 1., 0.)
185+
weight = dist / (dist.min(axis=1)[:, newaxis])
186+
weight *= ones(dist.shape)
187+
weight[weight > 1.] = 0.
186188

187189
# Multiply the queried good points by the weight, selecting only the
188190
# nearest points. Divide by the number of nearest points to get average
189191
xfill = weight * data[igood[:, 0][iquery], igood[:, 1][iquery]]
190-
xfill = (xfill / weight.sum(axis=1)[:, np.newaxis]).sum(axis=1)
192+
xfill = (xfill / weight.sum(axis=1)[:, newaxis]).sum(axis=1)
191193

192194
# Place average of nearest good points, xfill, into bad point locations
193195
data[ibad[:, 0], ibad[:, 1]] = xfill
@@ -262,8 +264,8 @@ def p_n(self): # Reciprocal of d_y
262264

263265
@property
264266
def resolution(self):
265-
return np.sqrt(np.diff(self.lon[1:], axis=1) *
266-
np.diff(self.lat[:, 1:], axis=0)).mean()
267+
return sqrt(diff(self.lon[1:], axis=1) *
268+
diff(self.lat[:, 1:], axis=0)).mean()
267269

268270
@property
269271
def boundary(self):
@@ -274,8 +276,8 @@ def boundary(self):
274276
Returns:
275277
lon/lat boundary points
276278
"""
277-
lon = np.r_[(self.lon[:, 0], self.lon[-1],
279+
lon = r_[(self.lon[:, 0], self.lon[-1],
278280
self.lon[::-1, -1], self.lon[0, ::-1])]
279-
lat = np.r_[(self.lat[:, 0], self.lat[-1],
281+
lat = r_[(self.lat[:, 0], self.lat[-1],
280282
self.lat[::-1, -1], self.lat[0, ::-1])]
281283
return lon, lat

src/py_eddy_tracker/grid/roms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
'''
2+
"""
33
===========================================================================
44
This file is part of py-eddy-tracker.
55
@@ -30,9 +30,9 @@
3030
Useful when the output files don't contain the grid
3131
information
3232
33-
'''
33+
"""
34+
3435
import netCDF4 as netcdf
35-
# from matplotlib.mlab import load
3636
import numpy as np
3737
import matplotlib.path as Path
3838
from . import PyEddyTracker

src/py_eddy_tracker/observations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def merge(self, other):
141141
return eddies
142142

143143
def reset(self):
144-
self.observations = np.zeros(0, dtype=self.dtype)
144+
self.observations = zeros(0, dtype=self.dtype)
145145

146146
@property
147147
def obs(self):

0 commit comments

Comments
 (0)