From cb16e2a4636b8b3bb10edf0329d854397de95900 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 10 Apr 2024 10:05:06 +0100 Subject: [PATCH 01/15] Initial minimal working Cubed example for "map-reduce" --- cubed-example.ipynb | 1630 +++++++++++++++++++++++++++++++++++++++++++ flox/core.py | 134 +++- flox/xrutils.py | 20 +- tests/__init__.py | 1 + tests/test_core.py | 40 ++ 5 files changed, 1822 insertions(+), 3 deletions(-) create mode 100644 cubed-example.ipynb diff --git a/cubed-example.ipynb b/cubed-example.ipynb new file mode 100644 index 00000000..57fc6b3f --- /dev/null +++ b/cubed-example.ipynb @@ -0,0 +1,1630 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8c33b843-dbaf-4320-a4ae-868b732a1171", + "metadata": {}, + "outputs": [], + "source": [ + "# based on https://flox.readthedocs.io/en/latest/user-stories/climatology-hourly.html\n", + "# but with smaller data sizes so it can be run locally\n", + "\n", + "import cubed\n", + "import cubed.array_api as xp\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "\n", + "import flox.xarray" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "674c8844-f411-4a1a-b055-894b954639a4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
+       "[xarray Dataset HTML repr omitted; duplicates the text/plain repr below]
" + ], + "text/plain": [ + " Size: 363MB\n", + "Dimensions: (time: 8760, latitude: 72, longitude: 144)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 70kB 2021-01-01 ... 2021-12-31T23:00:00\n", + "Dimensions without coordinates: latitude, longitude\n", + "Data variables:\n", + " tp (time, latitude, longitude) float32 363MB cubed.Array" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "spec = cubed.Spec(allowed_mem=\"2GB\")\n", + "ds = xr.Dataset(\n", + " {\n", + " \"tp\": (\n", + " (\"time\", \"latitude\", \"longitude\"),\n", + " xp.ones((8760, 72, 144), chunks=(744, 5, 144), dtype=np.float32, spec=spec),\n", + " )\n", + " },\n", + " coords={\"time\": pd.date_range(\"2021-01-01\", \"2021-12-31 23:59\", freq=\"h\")},\n", + ")\n", + "ds" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3137d5f5-0706-46a5-8c63-be9c7e420229", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
+       "[xarray DataArray HTML repr omitted; duplicates the text/plain repr below]
" + ], + "text/plain": [ + " Size: 995kB\n", + "cubed.Array\n", + "Coordinates:\n", + " * hour (hour) int64 192B 0 1 2 3 4 5 6 7 8 ... 15 16 17 18 19 20 21 22 23\n", + "Dimensions without coordinates: latitude, longitude" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hourly = flox.xarray.xarray_reduce(ds.tp, ds.time.dt.hour, func=\"mean\")\n", + "hourly" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "676d3f12-13f6-4b8f-ae74-5a9017dd69fd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
+       "[xarray DataArray HTML repr omitted; a (24, 72, 144) float32 array of all 1.0, duplicating the text/plain repr below]
" + ], + "text/plain": [ + " Size: 995kB\n", + "array([[[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + "...\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]],\n", + "\n", + " [[1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " ...,\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.],\n", + " [1., 1., 1., ..., 1., 1., 1.]]], dtype=float32)\n", + "Coordinates:\n", + " * hour (hour) int64 192B 0 1 2 3 4 5 6 7 8 ... 15 16 17 18 19 20 21 22 23\n", + "Dimensions without coordinates: latitude, longitude" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hourly.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b59dcfc-5723-400c-977a-9cfc38bfa303", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/flox/core.py b/flox/core.py index 76563416..b0fbd855 100644 --- a/flox/core.py +++ b/flox/core.py @@ -38,7 +38,9 @@ ) from .cache import memoize from .xrutils import ( + is_chunked_array, is_duck_array, + is_duck_cubed_array, is_duck_dask_array, isnull, module_available, @@ -1718,6 +1720,112 @@ def dask_groupby_agg( return (result, groups) +def cubed_groupby_agg( + array: DaskArray, + by: T_By, + agg: Aggregation, + expected_groups: pd.Index | None, + axis: T_Axes = (), + fill_value: Any = None, + method: T_Method = "map-reduce", + reindex: bool = False, + engine: T_Engine = "numpy", + sort: bool = True, + chunks_cohorts=None, +) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]: + import cubed + import cubed.core.groupby + + # I think _tree_reduce expects this + assert isinstance(axis, Sequence) + assert all(ax >= 0 for ax in axis) + + inds = tuple(range(array.ndim)) + + by_input = by + + # Unifying chunks is necessary for argreductions. 
+ # We need to rechunk before zipping up with the index + # let's always do it anyway + if not is_chunked_array(by): + # chunk numpy arrays like the input array + chunks = tuple(array.chunks[ax] if by.shape[ax] != 1 else (1,) for ax in range(-by.ndim, 0)) + + by = cubed.from_array(by, chunks=chunks, spec=array.spec) + _, (array, by) = cubed.core.unify_chunks(array, inds, by, inds[-by.ndim :]) + + # Cubed's groupby_reduction handles the generation of "intermediates", and the + # "map-reduce" combination step, so we don't have to do that here. + # Only the equivalent of "_simple_combine" is supported, there is no + # support for "_grouped_combine". + labels_are_unknown = is_chunked_array(by_input) and expected_groups is None + do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown + + assert do_simple_combine + assert method == "map-reduce" + assert len(axis) == 1 # one axis/grouping + + def _groupby_func(a, by, axis, intermediate_dtype, num_groups): + blockwise_method = partial( + _get_chunk_reduction(agg.reduction_type), + func=agg.chunk, + fill_value=agg.fill_value["intermediate"], + dtype=agg.dtype["intermediate"], + reindex=reindex, + user_dtype=agg.dtype["user"], + axis=axis, + expected_groups=expected_groups if reindex else None, + engine=engine, + sort=sort, + ) + out = blockwise_method(a, by) + # Convert dict to one that cubed understands, dropping groups since they are + # known, and the same for every block. + return {f"f{idx}": intermediate for idx, intermediate in enumerate(out["intermediates"])} + + def _groupby_combine(a, axis, dummy_axis, dtype, keepdims): + # this is similar to _simple_combine, except the dummy axis and concatenation is handled by cubed + # only combine over the dummy axis, to preserve grouping along 'axis' + dtype = dict(dtype) + out = {} + for idx, combine in enumerate(agg.simple_combine): + field = f"f{idx}" + out[field] = combine(a[field], dtype=dtype[field], axis=dummy_axis, keepdims=keepdims) + return out + + def _groupby_aggregate(a): + # this is similar to _finalize_results, but not as comprehensive + arrs = tuple(v for v in a.values()) + if agg.finalize is None: + assert len(arrs) == 1 + out = arrs[0] + else: + out = agg.finalize(*arrs, **agg.finalize_kwargs) + out = out.astype(agg.dtype["final"], copy=False) + return out + + # convert list of dtypes to a structured dtype for cubed + intermediate_dtype = [(f"f{i}", dtype) for i, dtype in enumerate(agg.dtype["intermediate"])] + dtype = agg.dtype["final"] + num_groups = len(expected_groups) + + result = cubed.core.groupby.groupby_reduction( + array, + by, + func=_groupby_func, + combine_func=_groupby_combine, + aggegrate_func=_groupby_aggregate, + axis=axis, + intermediate_dtype=intermediate_dtype, + dtype=dtype, + num_groups=num_groups, + ) + + groups = (expected_groups.to_numpy(),) + + return (result, groups) + + def _collapse_blocks_along_axes(reduced: DaskArray, axis: T_Axes, group_chunks) -> DaskArray: import dask.array from dask.highlevelgraph import HighLevelGraph @@ -2240,6 +2348,7 @@ def groupby_reduce( nax = len(axis_) has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_) + has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_) if _is_first_last_reduction(func): if has_dask and nax != 1: @@ -2302,7 +2411,30 @@ def groupby_reduce( kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine groups: tuple[np.ndarray | DaskArray, ...] 
- if not has_dask: + if has_cubed: + if method is None: + method = "map-reduce" + + if method != "map-reduce": + raise NotImplementedError( + "Reduction for Cubed arrays is only implemented for method 'map-reduce'." + ) + + partial_agg = partial(cubed_groupby_agg, **kwargs) + + result, groups = partial_agg( + array, + by_, + expected_groups=expected_, + agg=agg, + reindex=reindex, + method=method, + sort=sort, + ) + + return (result, groups) + + elif not has_dask: results = _reduce_blockwise( array, by_, agg, expected_groups=expected_, reindex=reindex, sort=sort, **kwargs ) diff --git a/flox/xrutils.py b/flox/xrutils.py index e474fd40..fe9a5c85 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -37,11 +37,18 @@ def is_duck_array(value: Any) -> bool: hasattr(value, "ndim") and hasattr(value, "shape") and hasattr(value, "dtype") - and hasattr(value, "__array_function__") - and hasattr(value, "__array_ufunc__") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) ) +def is_chunked_array(x) -> bool: + """True if dask or cubed""" + return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) + + def is_dask_collection(x): try: import dask @@ -56,6 +63,15 @@ def is_duck_dask_array(x): return is_duck_array(x) and is_dask_collection(x) +def is_duck_cubed_array(x): + try: + import cubed + + return is_duck_array(x) and isinstance(x, cubed.Array) + except ImportError: + return False + + class ReprObject: """Object that prints as the given value, for use with sentinel values.""" diff --git a/tests/__init__.py b/tests/__init__.py index 2ebeb5ad..615b4a91 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -46,6 +46,7 @@ def LooseVersion(vstring): has_cftime, requires_cftime = _importorskip("cftime") +has_cubed, requires_cubed = _importorskip("cubed") has_dask, requires_dask = _importorskip("dask") has_numba, requires_numba = _importorskip("numba") has_numbagg, requires_numbagg = _importorskip("numbagg") diff --git a/tests/test_core.py b/tests/test_core.py index 19c96758..aca5cdd4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -36,8 +36,10 @@ SCIPY_STATS_FUNCS, assert_equal, assert_equal_tuple, + has_cubed, has_dask, raise_if_dask_computes, + requires_cubed, requires_dask, ) @@ -59,6 +61,9 @@ def dask_array_ones(*args): return None + +if has_cubed: + import cubed DEFAULT_QUANTILE = 0.9 @@ -477,6 +482,41 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp assert_equal(expected, actual) +@requires_cubed +@pytest.mark.parametrize("reindex", [True]) +@pytest.mark.parametrize("func", ["mean", "sum"]) +@pytest.mark.parametrize("add_nan", [False]) +@pytest.mark.parametrize("dtype", (float,)) +@pytest.mark.parametrize( + "shape, array_chunks, group_chunks", + [ + ((12,), (3,), 3), # form 1 + ], +) +def test_groupby_agg_cubed(func, shape, array_chunks, group_chunks, add_nan, dtype, engine, reindex): + """Tests groupby_reduce with cubed arrays against groupby_reduce with numpy arrays""" + + array = cubed.array_api.ones(shape, chunks=array_chunks) + + labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0]) + if add_nan: + labels = labels.astype(float) + labels[:3] = np.nan # entire block is NaN when group_chunks=3 + labels[-2:] = np.nan + + kwargs = dict( + func=func, expected_groups=[0, 1, 2], reindex=reindex + ) + + expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs) + actual, _ = groupby_reduce(array.compute(), labels, engine=engine, 
**kwargs) + assert_equal(actual, expected) + + # TODO: raise_if_cubed_computes + actual, _ = groupby_reduce(array, labels, engine=engine, **kwargs) + assert_equal(expected, actual) + + def test_numpy_reduce_axis_subset(engine): # TODO: add NaNs by = labels2d From ccae0d67b58665b5fe5a99847d530fd6653782ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Apr 2024 14:11:44 +0000 Subject: [PATCH 02/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- cubed-example.ipynb | 1559 +------------------------------------------ tests/test_core.py | 11 +- 2 files changed, 14 insertions(+), 1556 deletions(-) diff --git a/cubed-example.ipynb b/cubed-example.ipynb index 57fc6b3f..b2641735 100644 --- a/cubed-example.ipynb +++ b/cubed-example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "8c33b843-dbaf-4320-a4ae-868b732a1171", "metadata": {}, "outputs": [], @@ -21,543 +21,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "674c8844-f411-4a1a-b055-894b954639a4", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
-       "[stripped xarray Dataset HTML repr omitted]
" - ], - "text/plain": [ - " Size: 363MB\n", - "Dimensions: (time: 8760, latitude: 72, longitude: 144)\n", - "Coordinates:\n", - " * time (time) datetime64[ns] 70kB 2021-01-01 ... 2021-12-31T23:00:00\n", - "Dimensions without coordinates: latitude, longitude\n", - "Data variables:\n", - " tp (time, latitude, longitude) float32 363MB cubed.Array" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "spec = cubed.Spec(allowed_mem=\"2GB\")\n", "ds = xr.Dataset(\n", @@ -574,506 +41,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "3137d5f5-0706-46a5-8c63-be9c7e420229", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
-       "[stripped xarray DataArray HTML repr omitted]
" - ], - "text/plain": [ - " Size: 995kB\n", - "cubed.Array\n", - "Coordinates:\n", - " * hour (hour) int64 192B 0 1 2 3 4 5 6 7 8 ... 15 16 17 18 19 20 21 22 23\n", - "Dimensions without coordinates: latitude, longitude" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "hourly = flox.xarray.xarray_reduce(ds.tp, ds.time.dt.hour, func=\"mean\")\n", "hourly" @@ -1081,518 +52,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "676d3f12-13f6-4b8f-ae74-5a9017dd69fd", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
-       "[stripped xarray DataArray HTML repr (all 1.0) omitted]
" - ], - "text/plain": [ - " Size: 995kB\n", - "array([[[1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " ...,\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " ...,\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " ...,\n", - "...\n", - " ...,\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " ...,\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.]],\n", - "\n", - " [[1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " ...,\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.],\n", - " [1., 1., 1., ..., 1., 1., 1.]]], dtype=float32)\n", - "Coordinates:\n", - " * hour (hour) int64 192B 0 1 2 3 4 5 6 7 8 ... 15 16 17 18 19 20 21 22 23\n", - "Dimensions without coordinates: latitude, longitude" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "hourly.compute()" ] @@ -1607,11 +70,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -1621,8 +79,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.8" + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/tests/test_core.py b/tests/test_core.py index aca5cdd4..48175ad6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -61,7 +61,8 @@ def dask_array_ones(*args): return None - + + if has_cubed: import cubed @@ -493,7 +494,9 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp ((12,), (3,), 3), # form 1 ], ) -def test_groupby_agg_cubed(func, shape, array_chunks, group_chunks, add_nan, dtype, engine, reindex): +def test_groupby_agg_cubed( + func, shape, array_chunks, group_chunks, add_nan, dtype, engine, reindex +): """Tests groupby_reduce with cubed arrays against groupby_reduce with numpy arrays""" array = cubed.array_api.ones(shape, chunks=array_chunks) @@ -504,9 +507,7 @@ def test_groupby_agg_cubed(func, shape, array_chunks, group_chunks, add_nan, dty labels[:3] = np.nan # entire block is NaN when group_chunks=3 labels[-2:] = np.nan - kwargs = dict( - func=func, expected_groups=[0, 1, 2], reindex=reindex - ) + kwargs = dict(func=func, expected_groups=[0, 1, 2], reindex=reindex) expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs) actual, _ = groupby_reduce(array.compute(), labels, engine=engine, **kwargs) From 3375f2864ceda7267c3ba6610dc4952704921f57 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 12 Apr 2024 12:32:23 +0100 Subject: [PATCH 03/15] Fix misspelled `aggegrate_func` --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/flox/core.py b/flox/core.py index b0fbd855..0fb9ddb9 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1814,7 +1814,7 @@ def _groupby_aggregate(a): by, func=_groupby_func, combine_func=_groupby_combine, - aggegrate_func=_groupby_aggregate, + aggregate_func=_groupby_aggregate, axis=axis, intermediate_dtype=intermediate_dtype, dtype=dtype, From abdc032d004b554e06e5e9743e547ac0307d5f6b Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 15 Apr 2024 12:02:25 +0100 Subject: [PATCH 04/15] Update flox/core.py Co-authored-by: Deepak Cherian --- flox/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flox/core.py b/flox/core.py index 0fb9ddb9..367dcefa 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1763,6 +1763,7 @@ def cubed_groupby_agg( assert do_simple_combine assert method == "map-reduce" + assert reindex is True assert len(axis) == 1 # one axis/grouping def _groupby_func(a, by, axis, intermediate_dtype, num_groups): From c740817574049bbe4ae7dec94b595fc559841cde Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 12 Apr 2024 15:11:20 +0100 Subject: [PATCH 05/15] Expand to ALL_FUNCS --- flox/core.py | 2 +- tests/test_core.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 367dcefa..cbdea562 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1791,7 +1791,7 @@ def _groupby_combine(a, axis, dummy_axis, dtype, keepdims): out = {} for idx, combine in enumerate(agg.simple_combine): field = f"f{idx}" - out[field] = combine(a[field], dtype=dtype[field], axis=dummy_axis, keepdims=keepdims) + out[field] = combine(a[field], axis=dummy_axis, keepdims=keepdims) return out def _groupby_aggregate(a): diff --git a/tests/test_core.py b/tests/test_core.py index 48175ad6..17c94742 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -485,7 +485,7 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp @requires_cubed @pytest.mark.parametrize("reindex", [True]) -@pytest.mark.parametrize("func", ["mean", "sum"]) +@pytest.mark.parametrize("func", ALL_FUNCS) @pytest.mark.parametrize("add_nan", [False]) @pytest.mark.parametrize("dtype", (float,)) @pytest.mark.parametrize( @@ -499,6 +499,12 @@ def test_groupby_agg_cubed( ): """Tests groupby_reduce with cubed arrays against groupby_reduce with numpy arrays""" + if func in ["first", "last"] or func in BLOCKWISE_FUNCS: + pytest.skip() + + if "arg" in func and (engine in ["flox", "numbagg"] or reindex): + pytest.skip() + array = cubed.array_api.ones(shape, chunks=array_chunks) labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0]) From bed1bcf699245360ba47fa2363c321160a73ea46 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 15 Apr 2024 11:36:47 +0100 Subject: [PATCH 06/15] Use `_finalize_results` directly --- flox/core.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/flox/core.py b/flox/core.py index cbdea562..5ed3447e 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1795,15 +1795,10 @@ def _groupby_combine(a, axis, dummy_axis, dtype, keepdims): return out def _groupby_aggregate(a): - # this is similar to _finalize_results, but not as comprehensive - arrs = tuple(v for v in a.values()) - if agg.finalize is None: - assert len(arrs) == 1 - out = arrs[0] - else: - out = agg.finalize(*arrs, **agg.finalize_kwargs) - out = out.astype(agg.dtype["final"], copy=False) - return out + # Convert cubed dict to one that _finalize_results works with + results = {"groups": expected_groups, "intermediates": a.values()} + out = 
_finalize_results(results, agg, axis, expected_groups, fill_value, reindex) + return out[agg.name] # convert list of dtypes to a structured dtype for cubed intermediate_dtype = [(f"f{i}", dtype) for i, dtype in enumerate(agg.dtype["intermediate"])] From ca04f8a82e4035fa3c84b5742eea1d1d5b59548a Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 15 Apr 2024 11:48:26 +0100 Subject: [PATCH 07/15] Add test for nan values --- tests/test_core.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 17c94742..c84b81e2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -486,7 +486,7 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp @requires_cubed @pytest.mark.parametrize("reindex", [True]) @pytest.mark.parametrize("func", ALL_FUNCS) -@pytest.mark.parametrize("add_nan", [False]) +@pytest.mark.parametrize("add_nan", [False, True]) @pytest.mark.parametrize("dtype", (float,)) @pytest.mark.parametrize( "shape, array_chunks, group_chunks", @@ -513,7 +513,12 @@ def test_groupby_agg_cubed( labels[:3] = np.nan # entire block is NaN when group_chunks=3 labels[-2:] = np.nan - kwargs = dict(func=func, expected_groups=[0, 1, 2], reindex=reindex) + kwargs = dict( + func=func, + expected_groups=[0, 1, 2], + fill_value=False if func in ["all", "any"] else 123, + reindex=reindex, + ) expected, _ = groupby_reduce(array.compute(), labels, engine="numpy", **kwargs) actual, _ = groupby_reduce(array.compute(), labels, engine=engine, **kwargs) From 4dd215837a899eadcf6dc6c7ae90ac2dd4a51da7 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 15 Apr 2024 11:51:42 +0100 Subject: [PATCH 08/15] Removed unused dtype from test --- tests/test_core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index c84b81e2..806662f9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -487,7 +487,6 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp @pytest.mark.parametrize("reindex", [True]) @pytest.mark.parametrize("func", ALL_FUNCS) @pytest.mark.parametrize("add_nan", [False, True]) -@pytest.mark.parametrize("dtype", (float,)) @pytest.mark.parametrize( "shape, array_chunks, group_chunks", [ @@ -495,7 +494,7 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp ], ) def test_groupby_agg_cubed( - func, shape, array_chunks, group_chunks, add_nan, dtype, engine, reindex + func, shape, array_chunks, group_chunks, add_nan, engine, reindex ): """Tests groupby_reduce with cubed arrays against groupby_reduce with numpy arrays""" From dafaddbf532686087891af2d39244295590148dd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Apr 2024 11:07:27 +0000 Subject: [PATCH 09/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 806662f9..890f8656 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -493,9 +493,7 @@ def test_groupby_agg_dask(func, shape, array_chunks, group_chunks, add_nan, dtyp ((12,), (3,), 3), # form 1 ], ) -def test_groupby_agg_cubed( - func, shape, array_chunks, group_chunks, add_nan, engine, reindex -): +def test_groupby_agg_cubed(func, shape, array_chunks, group_chunks, add_nan, engine, reindex): """Tests groupby_reduce with cubed 
arrays against groupby_reduce with numpy arrays""" if func in ["first", "last"] or func in BLOCKWISE_FUNCS: From ee0e597819b8c6ddf4272f6a10155bec68d4e07e Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 18 Apr 2024 12:47:42 +0100 Subject: [PATCH 10/15] Move example notebook to a gist https://gist.github.com/tomwhite/2d637d2581b44468da5b7e29c30c0c49 --- cubed-example.ipynb | 87 --------------------------------------------- 1 file changed, 87 deletions(-) delete mode 100644 cubed-example.ipynb diff --git a/cubed-example.ipynb b/cubed-example.ipynb deleted file mode 100644 index b2641735..00000000 --- a/cubed-example.ipynb +++ /dev/null @@ -1,87 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "8c33b843-dbaf-4320-a4ae-868b732a1171", - "metadata": {}, - "outputs": [], - "source": [ - "# based on https://flox.readthedocs.io/en/latest/user-stories/climatology-hourly.html\n", - "# but with smaller data sizes so it can be run locally\n", - "\n", - "import cubed\n", - "import cubed.array_api as xp\n", - "import numpy as np\n", - "import pandas as pd\n", - "import xarray as xr\n", - "\n", - "import flox.xarray" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "674c8844-f411-4a1a-b055-894b954639a4", - "metadata": {}, - "outputs": [], - "source": [ - "spec = cubed.Spec(allowed_mem=\"2GB\")\n", - "ds = xr.Dataset(\n", - " {\n", - " \"tp\": (\n", - " (\"time\", \"latitude\", \"longitude\"),\n", - " xp.ones((8760, 72, 144), chunks=(744, 5, 144), dtype=np.float32, spec=spec),\n", - " )\n", - " },\n", - " coords={\"time\": pd.date_range(\"2021-01-01\", \"2021-12-31 23:59\", freq=\"h\")},\n", - ")\n", - "ds" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3137d5f5-0706-46a5-8c63-be9c7e420229", - "metadata": {}, - "outputs": [], - "source": [ - "hourly = flox.xarray.xarray_reduce(ds.tp, ds.time.dt.hour, func=\"mean\")\n", - "hourly" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "676d3f12-13f6-4b8f-ae74-5a9017dd69fd", - "metadata": {}, - "outputs": [], - "source": [ - "hourly.compute()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8b59dcfc-5723-400c-977a-9cfc38bfa303", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From ef6b1fa46c5ad2e48e4d59540a9013be25b9996f Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 18 Apr 2024 12:51:18 +0100 Subject: [PATCH 11/15] Add CubedArray type --- flox/core.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/flox/core.py b/flox/core.py index 5ed3447e..90795549 100644 --- a/flox/core.py +++ b/flox/core.py @@ -68,7 +68,9 @@ import dask.array.Array as DaskArray from dask.typing import Graph - T_DuckArray = Union[np.ndarray, DaskArray] # Any ? + import cubed.Array as CubedArray + + T_DuckArray = Union[np.ndarray, DaskArray, CubedArray] # Any ? T_By = T_DuckArray T_Bys = tuple[T_By, ...] 
T_ExpectIndex = pd.Index @@ -97,7 +99,7 @@ IntermediateDict = dict[Union[str, Callable], Any] -FinalResultsDict = dict[str, Union["DaskArray", np.ndarray]] +FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]] FactorProps = namedtuple("FactorProps", "offset_group nan_sentinel nanmask") # This dummy axis is inserted using np.expand_dims @@ -1721,7 +1723,7 @@ def dask_groupby_agg( def cubed_groupby_agg( - array: DaskArray, + array: CubedArray, by: T_By, agg: Aggregation, expected_groups: pd.Index | None, @@ -1732,7 +1734,7 @@ def cubed_groupby_agg( engine: T_Engine = "numpy", sort: bool = True, chunks_cohorts=None, -) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]: +) -> tuple[CubedArray, tuple[np.ndarray | CubedArray]]: import cubed import cubed.core.groupby From b0f7c680e5d690f2d8daf44f8b504a72fb9e20e7 Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 18 Apr 2024 12:58:16 +0100 Subject: [PATCH 12/15] Add Cubed to CI --- ci/environment.yml | 1 + pyproject.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/ci/environment.yml b/ci/environment.yml index b07bfa6d..9685a90b 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -6,6 +6,7 @@ dependencies: - cachey - cftime - codecov + - cubed>=0.14.2 - dask-core - pandas - numpy>=1.22 diff --git a/pyproject.toml b/pyproject.toml index 27fdbf26..d9f99cd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,6 +121,7 @@ module=[ "asv_runner.*", "cachey", "cftime", + "cubed.*", "dask.*", "importlib_metadata", "numba", From 1018620f25d7a740ee3a310e808970443c329ba7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:03:08 +0000 Subject: [PATCH 13/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flox/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flox/core.py b/flox/core.py index 90795549..441b5aad 100644 --- a/flox/core.py +++ b/flox/core.py @@ -65,11 +65,10 @@ except (ModuleNotFoundError, ImportError): Unpack: Any # type: ignore[no-redef] + import cubed.Array as CubedArray import dask.array.Array as DaskArray from dask.typing import Graph - import cubed.Array as CubedArray - T_DuckArray = Union[np.ndarray, DaskArray, CubedArray] # Any ? T_By = T_DuckArray T_Bys = tuple[T_By, ...] From 2e80dc2e58ab42ec80049bb5b0d1ef156d031c50 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 19 Apr 2024 10:07:02 +0100 Subject: [PATCH 14/15] Make mypy happy --- flox/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flox/core.py b/flox/core.py index 441b5aad..6780c295 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1764,6 +1764,7 @@ def cubed_groupby_agg( assert do_simple_combine assert method == "map-reduce" + assert expected_groups is not None assert reindex is True assert len(axis) == 1 # one axis/grouping From 5b27ebcb018cdb251b31ceb7c345f34e6a14e3ca Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 19 Apr 2024 10:38:59 +0100 Subject: [PATCH 15/15] Make mypy happy (again) --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 6780c295..30589c25 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1777,7 +1777,7 @@ def _groupby_func(a, by, axis, intermediate_dtype, num_groups): reindex=reindex, user_dtype=agg.dtype["user"], axis=axis, - expected_groups=expected_groups if reindex else None, + expected_groups=expected_groups, engine=engine, sort=sort, )
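
---

A minimal usage sketch of the code path added by this series (not part of the patches themselves; it mirrors `test_groupby_agg_cubed` from patch 01 and assumes `cubed` and this branch of `flox` are installed):

    import cubed
    import numpy as np

    from flox import groupby_reduce

    # A small chunked cubed array, as in the new test.
    array = cubed.array_api.ones((12,), chunks=(3,))
    labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])

    # groupby_reduce detects the cubed array via is_duck_cubed_array and
    # dispatches to cubed_groupby_agg; this series only supports
    # method="map-reduce" with known expected_groups and reindex=True.
    result, groups = groupby_reduce(
        array, labels, func="sum", expected_groups=[0, 1, 2], reindex=True
    )
    print(result.compute())  # sums of ones per group: [3. 4. 5.]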