Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ Deprecations

Bug fixes
~~~~~~~~~

- Fix :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` when the limit is bigger than the chunksize (:issue:`9939`).
By `Joseph Nowak <https://github.com/josephnowak>`_.
- Fix issues related to Pandas v3 ("us" vs. "ns" for python datetime, copy on write) and handling of 0d-numpy arrays in datetime/timedelta decoding (:pull:`9953`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Remove dask-expr from CI runs, add "pyarrow" dask dependency to windows CI runs, fix related tests (:issue:`9962`, :pull:`9971`).
Expand Down
50 changes: 13 additions & 37 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import math
from functools import partial

from xarray.core import dtypes, nputils

Expand Down Expand Up @@ -92,31 +91,6 @@ def _dtype_push(a, axis, dtype=None):
return _push(a, axis=axis)


def _reset_cumsum(a, axis, dtype=None):
import numpy as np

cumsum = np.cumsum(a, axis=axis)
reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
return cumsum - reset_points


def _last_reset_cumsum(a, axis, keepdims=None):
import numpy as np

# Take the last cumulative sum taking into account the reset
# This is useful for blelloch method
return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])


def _combine_reset_cumsum(a, b, axis):
import numpy as np

# It is going to sum the previous result until the first
# non nan value
bitmask = np.cumprod(b != 0, axis=axis)
return np.where(bitmask, b + a, b)


def push(array, n, axis, method="blelloch"):
"""
Dask-aware bottleneck.push
Expand Down Expand Up @@ -145,16 +119,18 @@ def push(array, n, axis, method="blelloch"):
)

if n is not None and 0 < n < array.shape[axis] - 1:
valid_positions = da.reductions.cumreduction(
func=_reset_cumsum,
binop=partial(_combine_reset_cumsum, axis=axis),
ident=0,
x=da.isnan(array, dtype=int),
axis=axis,
dtype=int,
method=method,
preop=_last_reset_cumsum,
)
pushed_array = da.where(valid_positions <= n, pushed_array, np.nan)
# The idea is to calculate a cumulative sum of a bitmask
# created from the isnan method, but every time a False is found the sum
# must be restarted, and the final result indicates the amount of contiguous
# nan values found in the original array on every position
nan_bitmask = da.isnan(array, dtype=int)
cumsum_nan = nan_bitmask.cumsum(axis=axis, method=method)
valid_positions = da.where(nan_bitmask == 0, cumsum_nan, np.nan)
valid_positions = push(valid_positions, None, axis, method=method)
# All the NaNs at the beginning are converted to 0
valid_positions = da.nan_to_num(valid_positions)
valid_positions = cumsum_nan - valid_positions
valid_positions = valid_positions <= n
pushed_array = da.where(valid_positions, pushed_array, np.nan)

return pushed_array
45 changes: 27 additions & 18 deletions xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,31 +1025,40 @@ def test_least_squares(use_dask, skipna):
@requires_dask
@requires_bottleneck
@pytest.mark.parametrize("method", ["sequential", "blelloch"])
def test_push_dask(method):
@pytest.mark.parametrize(
"arr",
[
[np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6],
[
np.nan,
np.nan,
np.nan,
2,
np.nan,
np.nan,
np.nan,
9,
np.nan,
np.nan,
np.nan,
np.nan,
],
],
)
def test_push_dask(method, arr):
import bottleneck
import dask.array
import dask.array as da

array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6])
arr = np.array(arr)
chunks = list(range(1, 11)) + [(1, 2, 3, 2, 2, 1, 1)]

for n in [None, 1, 2, 3, 4, 5, 11]:
expected = bottleneck.push(array, axis=0, n=n)
for c in range(1, 11):
expected = bottleneck.push(arr, axis=0, n=n)
for c in chunks:
with raise_if_dask_computes():
actual = push(
dask.array.from_array(array, chunks=c), axis=0, n=n, method=method
)
actual = push(da.from_array(arr, chunks=c), axis=0, n=n, method=method)
np.testing.assert_equal(actual, expected)

# some chunks of size-1 with NaN
with raise_if_dask_computes():
actual = push(
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)),
axis=0,
n=n,
method=method,
)
np.testing.assert_equal(actual, expected)


def test_extension_array_equality(categorical1, int1):
int_duck_array = PandasExtensionArray(int1)
Expand Down
Loading