11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Hashable , Iterable , Sequence
3+ from typing import TYPE_CHECKING , Any , Hashable , Iterable , Sequence , Union
44
55import numpy as np
66import pandas as pd
1919from .xrutils import _contains_cftime_datetimes , _to_pytimedelta , datetime_to_numeric
2020
2121if TYPE_CHECKING :
22- from xarray import DataArray , Dataset , Resample
22+ from xarray .core .resample import Resample
23+ from xarray .core .types import T_DataArray , T_Dataset
24+
25+ Dims = Union [str , Iterable [Hashable ], None ]
2326
2427
2528def _get_input_core_dims (group_names , dim , ds , grouper_dims ):
@@ -51,13 +54,13 @@ def lookup_order(dimension):
5154
5255
5356def xarray_reduce (
54- obj : Dataset | DataArray ,
55- * by : DataArray | Iterable [ str ] | Iterable [ DataArray ] ,
57+ obj : T_Dataset | T_DataArray ,
58+ * by : T_DataArray | Hashable ,
5659 func : str | Aggregation ,
5760 expected_groups = None ,
5861 isbin : bool | Sequence [bool ] = False ,
5962 sort : bool = True ,
60- dim : Hashable = None ,
63+ dim : Dims | ellipsis = None ,
6164 split_out : int = 1 ,
6265 fill_value = None ,
6366 method : str = "map-reduce" ,
@@ -203,8 +206,11 @@ def xarray_reduce(
203206 if keep_attrs is None :
204207 keep_attrs = True
205208
206- if isinstance (isbin , bool ):
207- isbin = (isbin ,) * nby
209+ if isinstance (isbin , Sequence ):
210+ isbins = isbin
211+ else :
212+ isbins = (isbin ,) * nby
213+
208214 if expected_groups is None :
209215 expected_groups = (None ,) * nby
210216 if isinstance (expected_groups , (np .ndarray , list )): # TODO: test for list
@@ -217,78 +223,86 @@ def xarray_reduce(
217223 raise NotImplementedError
218224
219225 # eventually drop the variables we are grouping by
220- maybe_drop = [b for b in by if isinstance (b , str )]
226+ maybe_drop = [b for b in by if isinstance (b , Hashable )]
221227 unindexed_dims = tuple (
222228 b
223- for b , isbin_ in zip (by , isbin )
224- if isinstance (b , str ) and not isbin_ and b in obj .dims and b not in obj .indexes
229+ for b , isbin_ in zip (by , isbins )
230+ if isinstance (b , Hashable ) and not isbin_ and b in obj .dims and b not in obj .indexes
225231 )
226232
227- by : tuple [ DataArray ] = tuple (obj [g ] if isinstance (g , str ) else g for g in by ) # type: ignore
233+ by_da = tuple (obj [g ] if isinstance (g , Hashable ) else g for g in by )
228234
229235 grouper_dims = []
230- for g in by :
236+ for g in by_da :
231237 for d in g .dims :
232238 if d not in grouper_dims :
233239 grouper_dims .append (d )
234240
235- if isinstance (obj , xr .DataArray ):
236- ds = obj ._to_temp_dataset ()
237- else :
241+ if isinstance (obj , xr .Dataset ):
238242 ds = obj
243+ else :
244+ ds = obj ._to_temp_dataset ()
239245
240246 ds = ds .drop_vars ([var for var in maybe_drop if var in ds .variables ])
241247
242248 if dim is Ellipsis :
243249 if nby > 1 :
244250 raise NotImplementedError ("Multiple by are not allowed when dim is Ellipsis." )
245- dim = tuple (obj .dims )
246- if by [0 ].name in ds .dims and not isbin [0 ]:
247- dim = tuple (d for d in dim if d != by [0 ].name )
251+ name_ = by_da [0 ].name
252+ if name_ in ds .dims and not isbins [0 ]:
253+ dim_tuple = tuple (d for d in obj .dims if d != name_ )
254+ else :
255+ dim_tuple = tuple (obj .dims )
248256 elif dim is not None :
249- dim = _atleast_1d (dim )
257+ dim_tuple = _atleast_1d (dim )
250258 else :
251- dim = tuple ()
259+ dim_tuple = tuple ()
252260
253261 # broadcast all variables against each other along all dimensions in `by` variables
254262 # don't exclude `dim` because it need not be a dimension in any of the `by` variables!
255263 # in the case where dim is Ellipsis, and by.ndim < obj.ndim
256264 # then we also broadcast `by` to all `obj.dims`
257265 # TODO: avoid this broadcasting
258- exclude_dims = tuple (d for d in ds .dims if d not in grouper_dims and d not in dim )
259- ds , * by = xr .broadcast (ds , * by , exclude = exclude_dims )
266+ exclude_dims = tuple (d for d in ds .dims if d not in grouper_dims and d not in dim_tuple )
267+ ds_broad , * by_broad = xr .broadcast (ds , * by_da , exclude = exclude_dims )
260268
261- if not dim :
262- dim = tuple (by [0 ].dims )
269+ # all members of by_broad have the same dimensions
270+ # so we just pull by_broad[0].dims if dim is None
271+ if not dim_tuple :
272+ dim_tuple = tuple (by_broad [0 ].dims )
263273
264- if any (d not in grouper_dims and d not in obj .dims for d in dim ):
274+ if any (d not in grouper_dims and d not in obj .dims for d in dim_tuple ):
265275 raise ValueError (f"Cannot reduce over absent dimensions { dim } ." )
266276
267- dims_not_in_groupers = tuple (d for d in dim if d not in grouper_dims )
268- if dims_not_in_groupers == tuple (dim ) and not any (isbin ):
277+ dims_not_in_groupers = tuple (d for d in dim_tuple if d not in grouper_dims )
278+ if dims_not_in_groupers == tuple (dim_tuple ) and not any (isbins ):
269279 # reducing along a dimension along which groups do not vary
270280 # This is really just a normal reduction.
271281 # This is not right when binning so we exclude.
272- if skipna and isinstance (func , str ):
273- dsfunc = func [3 :]
282+ if isinstance (func , str ):
283+ dsfunc = func [3 :] if skipna else func
274284 else :
275- dsfunc = func
285+ raise NotImplementedError (
286+ "func must be a string when reducing along a dimension not present in `by`"
287+ )
276288 # TODO: skipna needs test
277- result = getattr (ds , dsfunc )(dim = dim , skipna = skipna )
289+ result = getattr (ds_broad , dsfunc )(dim = dim_tuple , skipna = skipna )
278290 if isinstance (obj , xr .DataArray ):
279291 return obj ._from_temp_dataset (result )
280292 else :
281293 return result
282294
283- axis = tuple (range (- len (dim ), 0 ))
284- group_names = tuple (g .name if not binned else f"{ g .name } _bins" for g , binned in zip (by , isbin ))
285-
286- group_shape = [None ] * len (by )
287- expected_groups = list (expected_groups )
295+ axis = tuple (range (- len (dim_tuple ), 0 ))
288296
289297 # Set expected_groups and convert to index since we need coords, sizes
290298 # for output xarray objects
291- for idx , (b , expect , isbin_ ) in enumerate (zip (by , expected_groups , isbin )):
299+ expected_groups = list (expected_groups )
300+ group_names : tuple [Any , ...] = ()
301+ group_sizes : dict [Any , int ] = {}
302+ for idx , (b_ , expect , isbin_ ) in enumerate (zip (by_broad , expected_groups , isbins )):
303+ group_name = b_ .name if not isbin_ else f"{ b_ .name } _bins"
304+ group_names += (group_name ,)
305+
292306 if isbin_ and isinstance (expect , int ):
293307 raise NotImplementedError (
294308 "flox does not support binning into an integer number of bins yet."
@@ -297,13 +311,21 @@ def xarray_reduce(
297311 if isbin_ :
298312 raise ValueError (
299313 f"Please provided bin edges for group variable { idx } "
300- f"named { group_names [ idx ] } in expected_groups."
314+ f"named { group_name } in expected_groups."
301315 )
302- expected_groups [idx ] = _get_expected_groups (b .data , sort = sort , raise_if_dask = True )
303-
304- expected_groups = _convert_expected_groups_to_index (expected_groups , isbin , sort = sort )
305- group_shape = tuple (len (e ) for e in expected_groups )
306- group_sizes = dict (zip (group_names , group_shape ))
316+ expect_ = _get_expected_groups (b_ .data , sort = sort , raise_if_dask = True )
317+ else :
318+ expect_ = expect
319+ expect_index = _convert_expected_groups_to_index ((expect_ ,), (isbin_ ,), sort = sort )[0 ]
320+
321+ # The if-check is for type hinting mainly, it narrows down the return
322+ # type of _convert_expected_groups_to_index to pure pd.Index:
323+ if expect_index is not None :
324+ expected_groups [idx ] = expect_index
325+ group_sizes [group_name ] = len (expect_index )
326+ else :
327+ # This will never be reached
328+ raise ValueError ("expect_index cannot be None" )
307329
308330 def wrapper (array , * by , func , skipna , ** kwargs ):
309331 # Handle skipna here because I need to know dtype to make a good default choice.
@@ -349,20 +371,20 @@ def wrapper(array, *by, func, skipna, **kwargs):
349371 if isinstance (obj , xr .Dataset ):
350372 # broadcasting means the group dim gets added to ds, so we check the original obj
351373 for k , v in obj .data_vars .items ():
352- is_missing_dim = not (any (d in v .dims for d in dim ))
374+ is_missing_dim = not (any (d in v .dims for d in dim_tuple ))
353375 if is_missing_dim :
354376 missing_dim [k ] = v
355377
356- input_core_dims = _get_input_core_dims (group_names , dim , ds , grouper_dims )
378+ input_core_dims = _get_input_core_dims (group_names , dim_tuple , ds_broad , grouper_dims )
357379 input_core_dims += [input_core_dims [- 1 ]] * (nby - 1 )
358380
359381 actual = xr .apply_ufunc (
360382 wrapper ,
361- ds .drop_vars (tuple (missing_dim )).transpose (..., * grouper_dims ),
362- * by ,
383+ ds_broad .drop_vars (tuple (missing_dim )).transpose (..., * grouper_dims ),
384+ * by_broad ,
363385 input_core_dims = input_core_dims ,
364386 # for xarray's test_groupby_duplicate_coordinate_labels
365- exclude_dims = set (dim ),
387+ exclude_dims = set (dim_tuple ),
366388 output_core_dims = [group_names ],
367389 dask = "allowed" ,
368390 dask_gufunc_kwargs = dict (output_sizes = group_sizes ),
@@ -379,27 +401,27 @@ def wrapper(array, *by, func, skipna, **kwargs):
379401 "engine" : engine ,
380402 "reindex" : reindex ,
381403 "expected_groups" : tuple (expected_groups ),
382- "isbin" : isbin ,
404+ "isbin" : isbins ,
383405 "finalize_kwargs" : finalize_kwargs ,
384406 },
385407 )
386408
387409 # restore non-dim coord variables without the core dimension
388410 # TODO: shouldn't apply_ufunc handle this?
389- for var in set (ds .variables ) - set (ds .dims ):
390- if all (d not in ds [var ].dims for d in dim ):
391- actual [var ] = ds [var ]
411+ for var in set (ds_broad .variables ) - set (ds_broad .dims ):
412+ if all (d not in ds_broad [var ].dims for d in dim_tuple ):
413+ actual [var ] = ds_broad [var ]
392414
393- for name , expect , by_ in zip (group_names , expected_groups , by ):
415+ for name , expect , by_ in zip (group_names , expected_groups , by_broad ):
394416 # Can't remove this till xarray handles IntervalIndex
395417 if isinstance (expect , pd .IntervalIndex ):
396418 expect = expect .to_numpy ()
397419 if isinstance (actual , xr .Dataset ) and name in actual :
398420 actual = actual .drop_vars (name )
399421 # When grouping by MultiIndex, expect is a pd.Index wrapping
400422 # an object array of tuples
401- if name in ds .indexes and isinstance (ds .indexes [name ], pd .MultiIndex ):
402- levelnames = ds .indexes [name ].names
423+ if name in ds_broad .indexes and isinstance (ds_broad .indexes [name ], pd .MultiIndex ):
424+ levelnames = ds_broad .indexes [name ].names
403425 expect = pd .MultiIndex .from_tuples (expect .values , names = levelnames )
404426 actual [name ] = expect
405427 if Version (xr .__version__ ) > Version ("2022.03.0" ):
@@ -414,18 +436,17 @@ def wrapper(array, *by, func, skipna, **kwargs):
414436
415437 if nby == 1 :
416438 for var in actual :
417- if isinstance (obj , xr .DataArray ):
418- template = obj
419- else :
439+ if isinstance (obj , xr .Dataset ):
420440 template = obj [var ]
441+ else :
442+ template = obj
443+
421444 if actual [var ].ndim > 1 :
422- actual [var ] = _restore_dim_order (actual [var ], template , by [0 ])
445+ actual [var ] = _restore_dim_order (actual [var ], template , by_broad [0 ])
423446
424447 if missing_dim :
425448 for k , v in missing_dim .items ():
426- missing_group_dims = {
427- dim : size for dim , size in group_sizes .items () if dim not in v .dims
428- }
449+ missing_group_dims = {d : size for d , size in group_sizes .items () if d not in v .dims }
429450 # The expand_dims is for backward compat with xarray's questionable behaviour
430451 if missing_group_dims :
431452 actual [k ] = v .expand_dims (missing_group_dims ).variable
@@ -439,9 +460,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
439460
440461
441462def rechunk_for_cohorts (
442- obj : DataArray | Dataset ,
463+ obj : T_DataArray | T_Dataset ,
443464 dim : str ,
444- labels : DataArray ,
465+ labels : T_DataArray ,
445466 force_new_chunk_at ,
446467 chunksize : int | None = None ,
447468 ignore_old_chunks : bool = False ,
@@ -486,7 +507,7 @@ def rechunk_for_cohorts(
486507 )
487508
488509
489- def rechunk_for_blockwise (obj : DataArray | Dataset , dim : str , labels : DataArray ):
510+ def rechunk_for_blockwise (obj : T_DataArray | T_Dataset , dim : str , labels : T_DataArray ):
490511 """
491512 Rechunks array so that group boundaries line up with chunk boundaries, allowing
491512 embarrassingly parallel group reductions.
0 commit comments