10
10
import xarray as xr
11
11
from xarray .coding .cftimeindex import _parse_array_of_cftime_strings
12
12
from xarray .core .types import InterpOptions
13
- from xarray .tests import (
14
- assert_allclose ,
15
- assert_equal ,
16
- assert_identical ,
17
- has_dask ,
18
- has_scipy ,
19
- requires_cftime ,
20
- requires_dask ,
21
- requires_scipy ,
22
- )
13
+ from xarray .tests import (assert_allclose , assert_equal , assert_identical ,
14
+ has_dask , has_scipy , requires_cftime , requires_dask ,
15
+ requires_scipy )
23
16
from xarray .tests .test_dataset import create_test_data
24
17
25
18
try :
@@ -132,29 +125,57 @@ def func(obj, new_x):
132
125
assert_allclose (actual , expected )
133
126
134
127
135
- @pytest .mark .parametrize ("use_dask" , [False , True ])
136
- def test_interpolate_vectorize (use_dask : bool ) -> None :
128
+ @pytest .mark .parametrize (
129
+ "use_dask, method" ,
130
+ (
131
+ (False , "linear" ),
132
+ (False , "akima" ),
133
+ (False , "makima" ),
134
+ (True , "linear" ),
135
+ (True , "akima" ),
136
+ ),
137
+ )
138
+ def test_interpolate_vectorize (use_dask : bool , method : str ) -> None :
137
139
if not has_scipy :
138
140
pytest .skip ("scipy is not installed." )
139
141
140
142
if not has_dask and use_dask :
141
143
pytest .skip ("dask is not installed in the environment." )
142
144
143
145
# scipy interpolation for the reference
144
- def func (obj , dim , new_x ):
146
+ def func (obj , dim , new_x , method ):
147
+ scipy_kwargs = {}
148
+ interpolant_options = {
149
+ "barycentric" : "BarycentricInterpolator" ,
150
+ "krogh" : "KroghInterpolator" ,
151
+ "pchip" : "PchipInterpolator" ,
152
+ "akima" : "Akima1DInterpolator" ,
153
+ "makima" : "Akima1DInterpolator" ,
154
+ }
155
+
145
156
shape = [s for i , s in enumerate (obj .shape ) if i != obj .get_axis_num (dim )]
146
157
for s in new_x .shape [::- 1 ]:
147
158
shape .insert (obj .get_axis_num (dim ), s )
148
159
149
- return scipy .interpolate .interp1d (
150
- da [dim ],
151
- obj .data ,
152
- axis = obj .get_axis_num (dim ),
153
- bounds_error = False ,
154
- fill_value = np .nan ,
155
- )(new_x ).reshape (shape )
160
+ if method in interpolant_options :
161
+ from scipy import interpolate
162
+
163
+ interpolant = getattr (interpolate , interpolant_options [method ])
164
+ if method == "makima" :
165
+ scipy_kwargs ["method" ] = method
166
+ return interpolant (
167
+ da [dim ], obj .data , axis = obj .get_axis_num (dim ), ** scipy_kwargs
168
+ )(new_x ).reshape (shape )
169
+ else :
170
+ scipy_kwargs ["kind" ] = method
171
+ scipy_kwargs ["bounds_error" ] = False
172
+ scipy_kwargs ["fill_value" ] = np .nan
173
+ return scipy .interpolate .interp1d (
174
+ da [dim ], obj .data , axis = obj .get_axis_num (dim ), ** scipy_kwargs
175
+ )(new_x ).reshape (shape )
156
176
157
177
da = get_example_data (0 )
178
+
158
179
if use_dask :
159
180
da = da .chunk ({"y" : 5 })
160
181
@@ -165,17 +186,17 @@ def func(obj, dim, new_x):
165
186
coords = {"z" : np .random .randn (30 ), "z2" : ("z" , np .random .randn (30 ))},
166
187
)
167
188
168
- actual = da .interp (x = xdest , method = "linear" )
189
+ actual = da .interp (x = xdest , method = method )
169
190
170
191
expected = xr .DataArray (
171
- func (da , "x" , xdest ),
192
+ func (da , "x" , xdest , method ),
172
193
dims = ["z" , "y" ],
173
194
coords = {
174
195
"z" : xdest ["z" ],
175
196
"z2" : xdest ["z2" ],
176
197
"y" : da ["y" ],
177
198
"x" : ("z" , xdest .values ),
178
- "x2" : ("z" , func (da ["x2" ], "x" , xdest )),
199
+ "x2" : ("z" , func (da ["x2" ], "x" , xdest , method )),
179
200
},
180
201
)
181
202
assert_allclose (actual , expected .transpose ("z" , "y" , transpose_coords = True ))
@@ -191,18 +212,18 @@ def func(obj, dim, new_x):
191
212
},
192
213
)
193
214
194
- actual = da .interp (x = xdest , method = "linear" )
215
+ actual = da .interp (x = xdest , method = method )
195
216
196
217
expected = xr .DataArray (
197
- func (da , "x" , xdest ),
218
+ func (da , "x" , xdest , method ),
198
219
dims = ["z" , "w" , "y" ],
199
220
coords = {
200
221
"z" : xdest ["z" ],
201
222
"w" : xdest ["w" ],
202
223
"z2" : xdest ["z2" ],
203
224
"y" : da ["y" ],
204
225
"x" : (("z" , "w" ), xdest .data ),
205
- "x2" : (("z" , "w" ), func (da ["x2" ], "x" , xdest )),
226
+ "x2" : (("z" , "w" ), func (da ["x2" ], "x" , xdest , method )),
206
227
},
207
228
)
208
229
assert_allclose (actual , expected .transpose ("z" , "w" , "y" , transpose_coords = True ))
@@ -404,7 +425,7 @@ def test_errors(use_dask: bool) -> None:
404
425
pytest .skip ("dask is not installed in the environment." )
405
426
da = da .chunk ()
406
427
407
- for method in ["akima" , " spline" ]:
428
+ for method in ["spline" ]:
408
429
with pytest .raises (ValueError ):
409
430
da .interp (x = [0.5 , 1.5 ], method = method ) # type: ignore[arg-type]
410
431
@@ -922,7 +943,10 @@ def test_interp1d_bounds_error() -> None:
922
943
(("x" , np .array ([0 , 0.5 , 1 , 2 ]), dict (unit = "s" )), False ),
923
944
],
924
945
)
925
- def test_coord_attrs (x , expect_same_attrs : bool ) -> None :
946
+ def test_coord_attrs (
947
+ x ,
948
+ expect_same_attrs : bool ,
949
+ ) -> None :
926
950
base_attrs = dict (foo = "bar" )
927
951
ds = xr .Dataset (
928
952
data_vars = dict (a = 2 * np .arange (5 )),
0 commit comments