2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -62,6 +62,8 @@ Documentation

Internal Changes
~~~~~~~~~~~~~~~~
- Avoid stacking when grouping by a chunked array. This can be a large performance improvement.
By `Deepak Cherian <https://github.com/dcherian>`_.

.. _whats-new.2025.03.1:

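For reference, a minimal sketch of the workflow this changelog entry describes: a groupby reduction where the grouping variable is dask-backed, so the eager stack/unstack step is skipped and the reduction is dispatched to flox. This is a sketch only; it assumes dask and flox are installed and an xarray version that supports grouping by chunked arrays, and the labels= argument to UniqueGrouper (needed because group labels cannot be discovered from a lazy array) is my understanding of that API.

import dask.array as da
import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

# Object to reduce, plus a chunked (dask-backed) label array along "y".
obj = xr.DataArray(da.ones((4, 6), chunks=(2, 3)), dims=("x", "y"))
obj = obj.assign_coords(label=("y", da.from_array(np.tile([0, 1, 2], 2), chunks=3)))

# Grouping by the chunked "label" coordinate: the expected labels must be
# supplied up front, and the reduction runs through flox without stacking obj.
result = obj.groupby(label=UniqueGrouper(labels=[0, 1, 2])).mean()
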
28 changes: 19 additions & 9 deletions xarray/core/groupby.py
@@ -661,18 +661,26 @@ def __init__(
# specification for the groupby operation
# TODO: handle obj having variables that are not present on any of the groupers
# simple broadcasting fails for ExtensionArrays.
# FIXME: Skip this stacking when grouping by a dask array, it's useless in that case.
(self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d(
group=self.encoded.codes, obj=obj
)
(self._group_dim,) = self.group1d.dims
codes = self.encoded.codes
self._by_chunked = is_chunked_array(codes._variable._data)
if not self._by_chunked:
(self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = (
_ensure_1d(group=codes, obj=obj)
)
(self._group_dim,) = self.group1d.dims
else:
self.group1d = None
# This transpose preserves dim order behaviour
self._obj = obj.transpose(..., *codes.dims)
self._stacked_dim = None
self._inserted_dims = []
self._group_dim = None

# cached attributes
self._groups = None
self._dims = None
self._sizes = None
self._len = len(self.encoded.full_index)
self._by_chunked = is_chunked_array(self.encoded.codes.data)

@property
def sizes(self) -> Mapping[Hashable, int]:
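
As context for the is_chunked_array check above: a rough user-level analogue is to look at whether the variable reports chunks at all. This is only a sketch assuming dask-backed data; the internal helper also covers other chunk managers.

import dask.array as da
import numpy as np
import xarray as xr

eager = xr.DataArray(np.arange(4), dims="x")
lazy = xr.DataArray(da.arange(4, chunks=2), dims="x")

# DataArray.chunks is None for in-memory data and a tuple of block sizes for
# chunked data, mirroring the branch taken in __init__ above.
print(eager.chunks is None)  # True  -> the _ensure_1d stacking path runs
print(lazy.chunks is None)   # False -> stacking is skipped; obj is only transposed
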
@@ -817,6 +825,7 @@ def __getitem__(self, key: GroupKey) -> T_Xarray:
"""
Get DataArray or Dataset corresponding to a particular group label.
"""
self._raise_if_by_is_chunked()
return self._obj.isel({self._group_dim: self.groups[key]})

def __len__(self) -> int:
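
The new _raise_if_by_is_chunked() guard (also added to the dims properties further down) exists because selecting a single group needs materialized integer indices. A hedged sketch of the user-facing effect follows; the exact exception type and message depend on the xarray version, and it assumes dask is installed.

import dask.array as da
import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

ds = xr.Dataset({"a": ("x", da.ones(4, chunks=2))})
ds = ds.assign_coords(label=("x", da.from_array(np.array([0, 0, 1, 1]), chunks=2)))
gb = ds.groupby(label=UniqueGrouper(labels=[0, 1]))

try:
    gb[0]  # pulling out one group's data requires eager group indices
except Exception as err:
    print(type(err).__name__, ":", err)
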
Expand Down Expand Up @@ -1331,9 +1340,6 @@ def quantile(
"Sample quantiles in statistical packages,"
The American Statistician, 50(4), pp. 361-365, 1996
"""
if dim is None:
dim = (self._group_dim,)

# Dataset.quantile does this, do it for flox to ensure same output.
q = np.asarray(q, dtype=np.float64)

@@ -1348,6 +1354,8 @@
)
return result
else:
if dim is None:
dim = (self._group_dim,)
return self.map(
self._obj.__class__.quantile,
shortcut=False,
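
On the quantile change above: the default dim is now computed only in the map-based branch, since self._group_dim is None when grouping by a chunked array and the flox path determines the reduction dimension itself. A small eager example of the default-dim behaviour (a sketch; whether it runs through flox or the map fallback depends on whether flox is installed):

import numpy as np
import xarray as xr

arr = xr.DataArray(
    np.arange(6.0), dims="x", coords={"label": ("x", [0, 0, 0, 1, 1, 1])}
)
# With no dim passed, quantile reduces over the grouped dimension "x",
# returning one value per "label" group.
print(arr.groupby("label").quantile(0.5))
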
@@ -1491,6 +1499,7 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):

@property
def dims(self) -> tuple[Hashable, ...]:
self._raise_if_by_is_chunked()
if self._dims is None:
index = self.encoded.group_indices[0]
self._dims = self._obj.isel({self._group_dim: index}).dims
@@ -1702,6 +1711,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):

@property
def dims(self) -> Frozen[Hashable, int]:
self._raise_if_by_is_chunked()
if self._dims is None:
index = self.encoded.group_indices[0]
self._dims = self._obj.isel({self._group_dim: index}).dims