|
3 | 3 | # This file is part of Iris and is released under the BSD license. |
4 | 4 | # See LICENSE in the root of the repository for full licensing details. |
5 | 5 | """Miscellaneous utility functions.""" |
| 6 | +from __future__ import annotations |
6 | 7 |
|
7 | 8 | from abc import ABCMeta, abstractmethod |
8 | 9 | from collections.abc import Hashable, Iterable |
@@ -282,7 +283,12 @@ def guess_coord_axis(coord): |
282 | 283 | return axis |
283 | 284 |
|
284 | 285 |
|
285 | | -def rolling_window(a, window=1, step=1, axis=-1): |
| 286 | +def rolling_window( |
| 287 | + a: np.ndarray | da.Array, |
| 288 | + window: int = 1, |
| 289 | + step: int = 1, |
| 290 | + axis: int = -1, |
| 291 | +) -> np.ndarray | da.Array: |
286 | 292 | """Make an ndarray with a rolling window of the last dimension. |
287 | 293 |
|
288 | 294 | Parameters |
@@ -323,34 +329,33 @@ def rolling_window(a, window=1, step=1, axis=-1): |
323 | 329 | See more at :doc:`/userguide/real_and_lazy_data`. |
324 | 330 |
|
325 | 331 | """ |
326 | | - # NOTE: The implementation of this function originates from |
327 | | - # https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011 |
328 | 332 | if window < 1: |
329 | 333 | raise ValueError("`window` must be at least 1.") |
330 | 334 | if window > a.shape[axis]: |
331 | 335 | raise ValueError("`window` is too long.") |
332 | 336 | if step < 1: |
333 | 337 | raise ValueError("`step` must be at least 1.") |
334 | 338 | axis = axis % a.ndim |
335 | | - num_windows = (a.shape[axis] - window + step) // step |
336 | | - shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :] |
337 | | - strides = ( |
338 | | - a.strides[:axis] |
339 | | - + (step * a.strides[axis], a.strides[axis]) |
340 | | - + a.strides[axis + 1 :] |
| 339 | + array_module = da if isinstance(a, da.Array) else np |
| 340 | + steps = tuple( |
| 341 | + slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim) |
341 | 342 | ) |
342 | | - rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) |
343 | | - if ma.isMaskedArray(a): |
344 | | - mask = ma.getmaskarray(a) |
345 | | - strides = ( |
346 | | - mask.strides[:axis] |
347 | | - + (step * mask.strides[axis], mask.strides[axis]) |
348 | | - + mask.strides[axis + 1 :] |
349 | | - ) |
350 | | - rw = ma.array( |
351 | | - rw, |
352 | | - mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides), |
| 343 | + |
| 344 | + def _rolling_window(array): |
| 345 | + return array_module.moveaxis( |
| 346 | + array_module.lib.stride_tricks.sliding_window_view( |
| 347 | + array, |
| 348 | + window_shape=window, |
| 349 | + axis=axis, |
| 350 | + )[steps], |
| 351 | + -1, |
| 352 | + axis + 1, |
353 | 353 | ) |
| 354 | + |
| 355 | + rw = _rolling_window(a) |
| 356 | + if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray): |
| 357 | + mask = _rolling_window(array_module.ma.getmaskarray(a)) |
| 358 | + rw = array_module.ma.masked_array(rw, mask) |
354 | 359 | return rw |
355 | 360 |
|
356 | 361 |
|
|
0 commit comments