Skip to content

Commit 09a9e73

Browse files
committed
vectorize 1d interpolators
1 parent d26144d commit 09a9e73

File tree

3 files changed

+73
-40
lines changed

3 files changed

+73
-40
lines changed

xarray/core/missing.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def __init__(
203203

204204
self.method = method
205205
self.cons_kwargs = kwargs
206+
del self.cons_kwargs["axis"]
206207
self.call_kwargs = {"nu": nu, "ext": ext}
207208

208209
if fill_value is not None:
@@ -479,7 +480,8 @@ def _get_interpolator(
479480
interp1d_methods = get_args(Interp1dOptions)
480481
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))
481482

482-
# prioritize scipy.interpolate
483+
# prefer numpy.interp for 1d linear interpolation. This function cannot
484+
# take higher dimensional data but scipy.interp1d can.
483485
if (
484486
method == "linear"
485487
and not kwargs.get("fill_value", None) == "extrapolate"
@@ -489,25 +491,31 @@ def _get_interpolator(
489491
interp_class = NumpyInterpolator
490492

491493
elif method in valid_methods:
494+
kwargs.update(axis=-1)
492495
if method in interp1d_methods:
493496
kwargs.update(method=method)
494497
interp_class = ScipyInterpolator
495-
elif vectorizeable_only:
496-
raise ValueError(
497-
f"{method} is not a vectorizeable interpolator. "
498-
f"Available methods are {interp1d_methods}"
499-
)
500498
elif method == "barycentric":
501499
interp_class = _import_interpolant("BarycentricInterpolator", method)
502500
elif method in ["krogh", "krog"]:
503501
interp_class = _import_interpolant("KroghInterpolator", method)
504502
elif method == "pchip":
505503
interp_class = _import_interpolant("PchipInterpolator", method)
506504
elif method == "spline":
505+
# utils.emit_user_level_warning(
506+
# "The 1d SplineInterpolator class is performing an incorrect calculation and "
507+
# "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
508+
# PendingDeprecationWarning,
509+
# )
510+
if vectorizeable_only:
511+
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
507512
kwargs.update(method=method)
508513
interp_class = SplineInterpolator
509514
elif method == "akima":
510515
interp_class = _import_interpolant("Akima1DInterpolator", method)
516+
elif method == "makima":
517+
kwargs.update(method="makima")
518+
interp_class = _import_interpolant("Akima1DInterpolator", method)
511519
else:
512520
raise ValueError(f"{method} is not a valid scipy interpolator")
513521
else:
@@ -525,6 +533,7 @@ def _get_interpolator_nd(method, **kwargs):
525533

526534
if method in valid_methods:
527535
kwargs.update(method=method)
536+
kwargs.update(bounds_error=False)
528537
interp_class = _import_interpolant("interpn", method)
529538
else:
530539
raise ValueError(
@@ -614,9 +623,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
614623
if not indexes_coords:
615624
return var.copy()
616625

617-
# default behavior
618-
kwargs["bounds_error"] = kwargs.get("bounds_error", False)
619-
620626
result = var
621627
# decompose the interpolation into a succession of independent interpolation
622628
for indep_indexes_coords in decompose_interp(indexes_coords):
@@ -755,8 +761,9 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
755761

756762
def _interp1d(var, x, new_x, func, kwargs):
757763
# x, new_x are tuples of size 1.
758-
x, new_x = x[0], new_x[0]
759-
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
764+
x, new_x = x[0].data, new_x[0].data
765+
766+
rslt = func(x, var, **kwargs)(np.ravel(new_x))
760767
if new_x.ndim > 1:
761768
return reshape(rslt, (var.shape[:-1] + new_x.shape))
762769
if new_x.ndim == 0:

xarray/core/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ def copy(
228228
Interp1dOptions = Literal[
229229
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
230230
]
231-
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
231+
InterpolantOptions = Literal[
232+
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
233+
]
232234
InterpOptions = Union[Interp1dOptions, InterpolantOptions]
233235

234236
DatetimeUnitOptions = Literal[

xarray/tests/test_interp.py

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,9 @@
1010
import xarray as xr
1111
from xarray.coding.cftimeindex import _parse_array_of_cftime_strings
1212
from xarray.core.types import InterpOptions
13-
from xarray.tests import (
14-
assert_allclose,
15-
assert_equal,
16-
assert_identical,
17-
has_dask,
18-
has_scipy,
19-
requires_cftime,
20-
requires_dask,
21-
requires_scipy,
22-
)
13+
from xarray.tests import (assert_allclose, assert_equal, assert_identical,
14+
has_dask, has_scipy, requires_cftime, requires_dask,
15+
requires_scipy)
2316
from xarray.tests.test_dataset import create_test_data
2417

2518
try:
@@ -132,29 +125,57 @@ def func(obj, new_x):
132125
assert_allclose(actual, expected)
133126

134127

135-
@pytest.mark.parametrize("use_dask", [False, True])
136-
def test_interpolate_vectorize(use_dask: bool) -> None:
128+
@pytest.mark.parametrize(
129+
"use_dask, method",
130+
(
131+
(False, "linear"),
132+
(False, "akima"),
133+
(False, "makima"),
134+
(True, "linear"),
135+
(True, "akima"),
136+
),
137+
)
138+
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
137139
if not has_scipy:
138140
pytest.skip("scipy is not installed.")
139141

140142
if not has_dask and use_dask:
141143
pytest.skip("dask is not installed in the environment.")
142144

143145
# scipy interpolation for the reference
144-
def func(obj, dim, new_x):
146+
def func(obj, dim, new_x, method):
147+
scipy_kwargs = {}
148+
interpolant_options = {
149+
"barycentric": "BarycentricInterpolator",
150+
"krogh": "KroghInterpolator",
151+
"pchip": "PchipInterpolator",
152+
"akima": "Akima1DInterpolator",
153+
"makima": "Akima1DInterpolator",
154+
}
155+
145156
shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
146157
for s in new_x.shape[::-1]:
147158
shape.insert(obj.get_axis_num(dim), s)
148159

149-
return scipy.interpolate.interp1d(
150-
da[dim],
151-
obj.data,
152-
axis=obj.get_axis_num(dim),
153-
bounds_error=False,
154-
fill_value=np.nan,
155-
)(new_x).reshape(shape)
160+
if method in interpolant_options:
161+
from scipy import interpolate
162+
163+
interpolant = getattr(interpolate, interpolant_options[method])
164+
if method == "makima":
165+
scipy_kwargs["method"] = method
166+
return interpolant(
167+
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
168+
)(new_x).reshape(shape)
169+
else:
170+
scipy_kwargs["kind"] = method
171+
scipy_kwargs["bounds_error"] = False
172+
scipy_kwargs["fill_value"] = np.nan
173+
return scipy.interpolate.interp1d(
174+
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
175+
)(new_x).reshape(shape)
156176

157177
da = get_example_data(0)
178+
158179
if use_dask:
159180
da = da.chunk({"y": 5})
160181

@@ -165,17 +186,17 @@ def func(obj, dim, new_x):
165186
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
166187
)
167188

168-
actual = da.interp(x=xdest, method="linear")
189+
actual = da.interp(x=xdest, method=method)
169190

170191
expected = xr.DataArray(
171-
func(da, "x", xdest),
192+
func(da, "x", xdest, method),
172193
dims=["z", "y"],
173194
coords={
174195
"z": xdest["z"],
175196
"z2": xdest["z2"],
176197
"y": da["y"],
177198
"x": ("z", xdest.values),
178-
"x2": ("z", func(da["x2"], "x", xdest)),
199+
"x2": ("z", func(da["x2"], "x", xdest, method)),
179200
},
180201
)
181202
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
@@ -191,18 +212,18 @@ def func(obj, dim, new_x):
191212
},
192213
)
193214

194-
actual = da.interp(x=xdest, method="linear")
215+
actual = da.interp(x=xdest, method=method)
195216

196217
expected = xr.DataArray(
197-
func(da, "x", xdest),
218+
func(da, "x", xdest, method),
198219
dims=["z", "w", "y"],
199220
coords={
200221
"z": xdest["z"],
201222
"w": xdest["w"],
202223
"z2": xdest["z2"],
203224
"y": da["y"],
204225
"x": (("z", "w"), xdest.data),
205-
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
226+
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
206227
},
207228
)
208229
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
@@ -404,7 +425,7 @@ def test_errors(use_dask: bool) -> None:
404425
pytest.skip("dask is not installed in the environment.")
405426
da = da.chunk()
406427

407-
for method in ["akima", "spline"]:
428+
for method in ["spline"]:
408429
with pytest.raises(ValueError):
409430
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]
410431

@@ -922,7 +943,10 @@ def test_interp1d_bounds_error() -> None:
922943
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
923944
],
924945
)
925-
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
946+
def test_coord_attrs(
947+
x,
948+
expect_same_attrs: bool,
949+
) -> None:
926950
base_attrs = dict(foo="bar")
927951
ds = xr.Dataset(
928952
data_vars=dict(a=2 * np.arange(5)),

0 commit comments

Comments
 (0)