|
19 | 19 | from flox.xrutils import notnull
|
20 | 20 |
|
21 | 21 | from . import assert_equal
|
22 |
| -from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays |
| 22 | +from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays |
23 | 23 | from .strategies import chunks as chunks_strategy
|
24 | 24 |
|
25 | 25 | dask.config.set(scheduler="sync")
|
@@ -233,3 +233,25 @@ def test_first_last_useless(data, func):
|
233 | 233 | actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
|
234 | 234 | expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
|
235 | 235 | assert_equal(actual, expected)
|
| 236 | + |
| 237 | + |
| 238 | +@given( |
| 239 | + func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]), |
| 240 | + engine=st.sampled_from(["numpy", "flox"]), |
| 241 | + array_dtype=st.none() | array_dtypes, |
| 242 | + dtype=st.none() | array_dtypes, |
| 243 | +) |
| 244 | +def test_agg_dtype_specified(func, array_dtype, dtype, engine): |
| 245 | + # regression test for GH388 |
| 246 | + counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype) |
| 247 | + group = np.array([1, 1, 1, 2, 2]) |
| 248 | + actual, _ = groupby_reduce( |
| 249 | + counts, |
| 250 | + group, |
| 251 | + expected_groups=(np.array([1, 2]),), |
| 252 | + func=func, |
| 253 | + dtype=dtype, |
| 254 | + engine=engine, |
| 255 | + ) |
| 256 | + expected = getattr(np, func)(counts, keepdims=True, dtype=dtype) |
| 257 | + assert actual.dtype == expected.dtype |
0 commit comments