1
1
import warnings
2
2
from functools import partial
3
- from typing import Any , Callable , Dict , Sequence
3
+ from numbers import Number
4
+ from typing import Any , Callable , Dict , Hashable , Sequence , Union
4
5
5
6
import numpy as np
6
7
import pandas as pd
7
8
8
9
from . import utils
9
- from .common import _contains_datetime_like_objects
10
+ from .common import _contains_datetime_like_objects , ones_like
10
11
from .computation import apply_ufunc
11
12
from .duck_array_ops import dask_array_type
12
13
from .utils import OrderedSet , is_scalar
13
14
from .variable import Variable , broadcast_variables
14
15
15
16
17
def _get_nan_block_lengths(obj, dim: Hashable, index: Variable):
    """
    Return an object where each NaN element in 'obj' is replaced by the
    length of the gap the element is in.

    Parameters
    ----------
    obj
        Object with a dimension ``dim``; must support ``notnull``.
    dim : Hashable
        Name of the dimension along which gaps are measured.
    index
        1D coordinate values along ``dim``; gap lengths are expressed in
        the units of these values.

    Returns
    -------
    Object shaped like ``obj`` holding 0 at valid positions and the gap
    length at NaN positions.
    """

    # make variable so that we get broadcasting for free
    index = Variable([dim], index)

    # algorithm from https://github.com/pydata/xarray/pull/3302#discussion_r324707072
    # coordinate value broadcast to the shape of obj
    arange = ones_like(obj) * index
    valid = obj.notnull()
    # coordinate at valid positions, NaN inside gaps
    valid_arange = arange.where(valid)
    # coordinate of the most recent valid element; leading NaNs are
    # measured from the first coordinate value
    cumulative_nans = valid_arange.ffill(dim=dim).fillna(index[0])

    # Trailing NaN blocks have no valid element after them, so measure them
    # from the last valid coordinate *of each 1D slice along dim* to the
    # final coordinate value. Reducing over all dimensions here would use
    # the global last valid coordinate and report wrong lengths for
    # multi-dimensional input.
    trailing_gap = index[-1] - valid_arange.max(dim=[dim])

    nan_block_lengths = (
        # at the first valid element after a gap, diff equals the gap length
        cumulative_nans.diff(dim=dim, label="upper")
        # diff drops the first position along dim; restore it
        .reindex({dim: obj[dim]})
        .where(valid)
        # spread each gap length backwards over the NaNs of that gap
        .bfill(dim=dim)
        # valid elements are not part of any gap
        .where(~valid, 0)
        .fillna(trailing_gap)
    )

    return nan_block_lengths
42
+
43
+
16
44
class BaseInterpolator :
17
45
"""Generic interpolator class for normalizing interpolation methods
18
46
"""
@@ -178,7 +206,7 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):
178
206
return ds
179
207
180
208
181
- def get_clean_interp_index (arr , dim , use_coordinate = True ):
209
+ def get_clean_interp_index (arr , dim : Hashable , use_coordinate : Union [ str , bool ] = True ):
182
210
"""get index to use for x values in interpolation.
183
211
184
212
If use_coordinate is True, the coordinate that shares the name of the
@@ -195,23 +223,33 @@ def get_clean_interp_index(arr, dim, use_coordinate=True):
195
223
index = arr .coords [use_coordinate ]
196
224
if index .ndim != 1 :
197
225
raise ValueError (
198
- "Coordinates used for interpolation must be 1D, "
199
- "%s is %dD." % ( use_coordinate , index .ndim )
226
+ f "Coordinates used for interpolation must be 1D, "
227
+ f" { use_coordinate } is { index .ndim } D."
200
228
)
229
+ index = index .to_index ()
230
+
231
+ # TODO: index.name is None for multiindexes
232
+ # set name for nice error messages below
233
+ if isinstance (index , pd .MultiIndex ):
234
+ index .name = dim
235
+
236
+ if not index .is_monotonic :
237
+ raise ValueError (f"Index { index .name !r} must be monotonically increasing" )
238
+
239
+ if not index .is_unique :
240
+ raise ValueError (f"Index { index .name !r} has duplicate values" )
201
241
202
242
# raise if index cannot be cast to a float (e.g. MultiIndex)
203
243
try :
204
244
index = index .values .astype (np .float64 )
205
245
except (TypeError , ValueError ):
206
246
# pandas raises a TypeError
207
- # xarray/nuppy raise a ValueError
247
+ # xarray/numpy raise a ValueError
208
248
raise TypeError (
209
- "Index must be castable to float64 to support"
210
- "interpolation, got: %s" % type (index )
249
+ f "Index { index . name !r } must be castable to float64 to support "
250
+ f "interpolation, got { type (index ). __name__ } ."
211
251
)
212
- # check index sorting now so we can skip it later
213
- if not (np .diff (index ) > 0 ).all ():
214
- raise ValueError ("Index must be monotonicly increasing" )
252
+
215
253
else :
216
254
axis = arr .get_axis_num (dim )
217
255
index = np .arange (arr .shape [axis ], dtype = np .float64 )
@@ -220,7 +258,13 @@ def get_clean_interp_index(arr, dim, use_coordinate=True):
220
258
221
259
222
260
def interp_na (
223
- self , dim = None , use_coordinate = True , method = "linear" , limit = None , ** kwargs
261
+ self ,
262
+ dim : Hashable = None ,
263
+ use_coordinate : Union [bool , str ] = True ,
264
+ method : str = "linear" ,
265
+ limit : int = None ,
266
+ max_gap : Union [int , float , str , pd .Timedelta , np .timedelta64 ] = None ,
267
+ ** kwargs ,
224
268
):
225
269
"""Interpolate values according to different methods.
226
270
"""
@@ -230,6 +274,40 @@ def interp_na(
230
274
if limit is not None :
231
275
valids = _get_valid_fill_mask (self , dim , limit )
232
276
277
+ if max_gap is not None :
278
+ max_type = type (max_gap ).__name__
279
+ if not is_scalar (max_gap ):
280
+ raise ValueError ("max_gap must be a scalar." )
281
+
282
+ if (
283
+ dim in self .indexes
284
+ and isinstance (self .indexes [dim ], pd .DatetimeIndex )
285
+ and use_coordinate
286
+ ):
287
+ if not isinstance (max_gap , (np .timedelta64 , pd .Timedelta , str )):
288
+ raise TypeError (
289
+ f"Underlying index is DatetimeIndex. Expected max_gap of type str, pandas.Timedelta or numpy.timedelta64 but received { max_type } "
290
+ )
291
+
292
+ if isinstance (max_gap , str ):
293
+ try :
294
+ max_gap = pd .to_timedelta (max_gap )
295
+ except ValueError :
296
+ raise ValueError (
297
+ f"Could not convert { max_gap !r} to timedelta64 using pandas.to_timedelta"
298
+ )
299
+
300
+ if isinstance (max_gap , pd .Timedelta ):
301
+ max_gap = np .timedelta64 (max_gap .value , "ns" )
302
+
303
+ max_gap = np .timedelta64 (max_gap , "ns" ).astype (np .float64 )
304
+
305
+ if not use_coordinate :
306
+ if not isinstance (max_gap , (Number , np .number )):
307
+ raise TypeError (
308
+ f"Expected integer or floating point max_gap since use_coordinate=False. Received { max_type } ."
309
+ )
310
+
233
311
# method
234
312
index = get_clean_interp_index (self , dim , use_coordinate = use_coordinate )
235
313
interp_class , kwargs = _get_interpolator (method , ** kwargs )
@@ -253,6 +331,14 @@ def interp_na(
253
331
if limit is not None :
254
332
arr = arr .where (valids )
255
333
334
+ if max_gap is not None :
335
+ if dim not in self .coords :
336
+ raise NotImplementedError (
337
+ "max_gap not implemented for unlabeled coordinates yet."
338
+ )
339
+ nan_block_lengths = _get_nan_block_lengths (self , dim , index )
340
+ arr = arr .where (nan_block_lengths <= max_gap )
341
+
256
342
return arr
257
343
258
344
0 commit comments