|
30 | 30 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
31 | 31 | import builtins
|
32 | 32 | import operator
|
| 33 | +from distutils.version import LooseVersion |
33 | 34 | from typing import Union
|
34 | 35 |
|
35 | 36 | import numpy as np
|
@@ -96,3 +97,99 @@ def __array_function__(self, *args, **kwargs):
|
96 | 97 |
|
97 | 98 |
|
98 | 99 | IS_NEP18_ACTIVE = _is_nep18_active()
|
| 100 | + |
| 101 | + |
| 102 | +if LooseVersion(np.__version__) >= "1.20.0": |
| 103 | + sliding_window_view = np.lib.stride_tricks.sliding_window_view |
| 104 | +else: |
| 105 | + from numpy.core.numeric import normalize_axis_tuple # type: ignore |
| 106 | + from numpy.lib.stride_tricks import as_strided |
| 107 | + |
| 108 | + # copied from numpy.lib.stride_tricks |
| 109 | + def sliding_window_view( |
| 110 | + x, window_shape, axis=None, *, subok=False, writeable=False |
| 111 | + ): |
| 112 | + """ |
| 113 | + Create a sliding window view into the array with the given window shape. |
| 114 | +
|
| 115 | + Also known as rolling or moving window, the window slides across all |
| 116 | + dimensions of the array and extracts subsets of the array at all window |
| 117 | + positions. |
| 118 | +
|
| 119 | + .. versionadded:: 1.20.0 |
| 120 | +
|
| 121 | + Parameters |
| 122 | + ---------- |
| 123 | + x : array_like |
| 124 | + Array to create the sliding window view from. |
| 125 | + window_shape : int or tuple of int |
| 126 | + Size of window over each axis that takes part in the sliding window. |
| 127 | + If `axis` is not present, must have same length as the number of input |
| 128 | + array dimensions. Single integers `i` are treated as if they were the |
| 129 | + tuple `(i,)`. |
| 130 | + axis : int or tuple of int, optional |
| 131 | + Axis or axes along which the sliding window is applied. |
| 132 | + By default, the sliding window is applied to all axes and |
| 133 | + `window_shape[i]` will refer to axis `i` of `x`. |
| 134 | + If `axis` is given as a `tuple of int`, `window_shape[i]` will refer to |
| 135 | + the axis `axis[i]` of `x`. |
| 136 | + Single integers `i` are treated as if they were the tuple `(i,)`. |
| 137 | + subok : bool, optional |
| 138 | + If True, sub-classes will be passed-through, otherwise the returned |
| 139 | + array will be forced to be a base-class array (default). |
| 140 | + writeable : bool, optional |
| 141 | + When true, allow writing to the returned view. The default is false, |
| 142 | + as this should be used with caution: the returned view contains the |
| 143 | + same memory location multiple times, so writing to one location will |
| 144 | + cause others to change. |
| 145 | +
|
| 146 | + Returns |
| 147 | + ------- |
| 148 | + view : ndarray |
| 149 | + Sliding window view of the array. The sliding window dimensions are |
| 150 | + inserted at the end, and the original dimensions are trimmed as |
| 151 | + required by the size of the sliding window. |
| 152 | + That is, ``view.shape = x_shape_trimmed + window_shape``, where |
| 153 | + ``x_shape_trimmed`` is ``x.shape`` with every entry reduced by one less |
| 154 | + than the corresponding window size. |
| 155 | + """ |
| 156 | + window_shape = ( |
| 157 | + tuple(window_shape) if np.iterable(window_shape) else (window_shape,) |
| 158 | + ) |
| 159 | + # first convert input to array, possibly keeping subclass |
| 160 | + x = np.array(x, copy=False, subok=subok) |
| 161 | + |
| 162 | + window_shape_array = np.array(window_shape) |
| 163 | + if np.any(window_shape_array < 0): |
| 164 | + raise ValueError("`window_shape` cannot contain negative values") |
| 165 | + |
| 166 | + if axis is None: |
| 167 | + axis = tuple(range(x.ndim)) |
| 168 | + if len(window_shape) != len(axis): |
| 169 | + raise ValueError( |
| 170 | + f"Since axis is `None`, must provide " |
| 171 | + f"window_shape for all dimensions of `x`; " |
| 172 | + f"got {len(window_shape)} window_shape elements " |
| 173 | + f"and `x.ndim` is {x.ndim}." |
| 174 | + ) |
| 175 | + else: |
| 176 | + axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True) |
| 177 | + if len(window_shape) != len(axis): |
| 178 | + raise ValueError( |
| 179 | + f"Must provide matching length window_shape and " |
| 180 | + f"axis; got {len(window_shape)} window_shape " |
| 181 | + f"elements and {len(axis)} axes elements." |
| 182 | + ) |
| 183 | + |
| 184 | + out_strides = x.strides + tuple(x.strides[ax] for ax in axis) |
| 185 | + |
| 186 | + # note: same axis can be windowed repeatedly |
| 187 | + x_shape_trimmed = list(x.shape) |
| 188 | + for ax, dim in zip(axis, window_shape): |
| 189 | + if x_shape_trimmed[ax] < dim: |
| 190 | + raise ValueError("window shape cannot be larger than input array shape") |
| 191 | + x_shape_trimmed[ax] -= dim - 1 |
| 192 | + out_shape = tuple(x_shape_trimmed) + window_shape |
| 193 | + return as_strided( |
| 194 | + x, strides=out_strides, shape=out_shape, subok=subok, writeable=writeable |
| 195 | + ) |
0 commit comments