From d5cb68c7183107785f1b0f59adafa5d2e420cd69 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 1 May 2024 22:28:21 -0600 Subject: [PATCH 1/2] Manually fuse reindexing intermediates with blockwise reduction for cohorts. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` | Change | Before [627bf2b6]
| After [9d710529] | Ratio | Benchmark (Parameter) | |----------|----------------------------|---------------------------------------------|---------|-------------------------------------------------| | - | 3.39±0.02ms | 2.98±0.01ms | 0.88 | cohorts.PerfectMonthly.time_graph_construct | | - | 20 | 17 | 0.85 | cohorts.PerfectMonthly.track_num_layers | | - | 23.0±0.07ms | 19.0±0.1ms | 0.83 | cohorts.ERA5Google.time_graph_construct | | - | 4878 | 3978 | 0.82 | cohorts.ERA5Google.track_num_tasks | | - | 179±0.8ms | 147±0.5ms | 0.82 | cohorts.OISST.time_graph_construct | | - | 159 | 128 | 0.81 | cohorts.ERA5Google.track_num_layers | | - | 936 | 762 | 0.81 | cohorts.PerfectMonthly.track_num_tasks | | - | 1221 | 978 | 0.8 | cohorts.OISST.track_num_layers | | - | 4929 | 3834 | 0.78 | cohorts.ERA5DayOfYear.track_num_tasks | | - | 351 | 274 | 0.78 | cohorts.NWMMidwest.track_num_layers | | - | 4562 | 3468 | 0.76 | cohorts.ERA5DayOfYear.track_num_tasks_optimized | | - | 164±1ms | 118±0.4ms | 0.72 | cohorts.ERA5DayOfYear.time_graph_construct | | - | 1100 | 735 | 0.67 | cohorts.ERA5DayOfYear.track_num_layers | | - | 3930 | 2605 | 0.66 | cohorts.NWMMidwest.track_num_tasks | | - | 3715 | 2409 | 0.65 | cohorts.NWMMidwest.track_num_tasks_optimized | | - | 28952 | 18798 | 0.65 | cohorts.OISST.track_num_tasks | | - | 27010 | 16858 | 0.62 | cohorts.OISST.track_num_tasks_optimized | ``` --- asv_bench/benchmarks/cohorts.py | 12 ++++++------ flox/core.py | 34 +++++++++++++++++++++------------ tests/test_core.py | 8 ++++++-- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py index 8cb80119..da12dfdf 100644 --- a/asv_bench/benchmarks/cohorts.py +++ b/asv_bench/benchmarks/cohorts.py @@ -14,8 +14,8 @@ def setup(self, *args, **kwargs): raise NotImplementedError @cached_property - def dask(self): - return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0].dask + def result(self): + return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0] def containment(self): asfloat = self.bitmask().astype(float) @@ -52,14 +52,14 @@ def time_graph_construct(self): flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis) def track_num_tasks(self): - return len(self.dask.to_dict()) + return len(self.result.dask.to_dict()) def track_num_tasks_optimized(self): - (opt,) = dask.optimize(self.dask) - return len(opt.to_dict()) + (opt,) = dask.optimize(self.result) + return len(opt.dask.to_dict()) def track_num_layers(self): - return len(self.dask.layers) + return len(self.result.dask.layers) track_num_tasks.unit = "tasks" # type: ignore[attr-defined] # Lazy track_num_tasks_optimized.unit = "tasks" # type: ignore[attr-defined] # Lazy diff --git a/flox/core.py b/flox/core.py index 25a3495f..9431f1c7 100644 --- a/flox/core.py +++ b/flox/core.py @@ -17,6 +17,7 @@ Callable, Literal, TypedDict, + TypeVar, Union, overload, ) @@ -96,6 +97,7 @@ T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"] T_IsBins = Union[bool | Sequence[bool]] +T = TypeVar("T") IntermediateDict = dict[Union[str, Callable], Any] FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]] @@ -140,6 +142,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups): return result +def identity(x: T) -> T: + return x + + def _issorted(arr: np.ndarray) -> bool: return bool((arr[:-1] <= arr[1:]).all()) @@ -1438,8 +1444,11 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple: def subset_to_blocks( - array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None -) -> DaskArray: + array: DaskArray, + flatblocks: Sequence[int], + blkshape: tuple[int] | None = None, + reindexer=identity, +) -> Graph: """ Advanced indexing of .blocks such that we always get a regular array back. @@ -1464,20 +1473,21 @@ def subset_to_blocks( index = _normalize_indexes(array, flatblocks, blkshape) if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index): - return array + return dask.array.map_blocks(reindexer, array, meta=array._meta) # These rest is copied from dask.array.core.py with slight modifications index = normalize_index(index, array.numblocks) index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index) - name = "blocks-" + tokenize(array, index) + name = "groupby-cohort-" + tokenize(array, index) new_keys = array._key_array[index] squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index) chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed)) keys = itertools.product(*(range(len(c)) for c in chunks)) - layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys} + layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys} + graph = HighLevelGraph.from_collections(name, layer, dependencies=[array]) return dask.array.Array(graph, name, chunks, meta=array) @@ -1651,26 +1661,26 @@ def dask_groupby_agg( elif method == "cohorts": assert chunks_cohorts + block_shape = array.blocks.shape[-len(axis) :] + reduced_ = [] groups_ = [] for blks, cohort in chunks_cohorts.items(): - index = pd.Index(cohort) - subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :]) - reindexed = dask.array.map_blocks( - reindex_intermediates, subset, agg, index, meta=subset._meta - ) + cohort_index = pd.Index(cohort) + reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index) + reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer) # now that we have reindexed, we can set reindex=True explicitlly reduced_.append( tree_reduce( reindexed, combine=partial(combine, agg=agg, reindex=True), - aggregate=partial(aggregate, expected_groups=index, reindex=True), + aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True), ) ) # This is done because pandas promotes to 64-bit types when an Index is created # So we use the index to generate the return value for consistency with "map-reduce" # This is important on windows - groups_.append(index.values) + groups_.append(cohort_index.values) reduced = dask.array.concatenate(reduced_, axis=-1) groups = (np.concatenate(groups_),) diff --git a/tests/test_core.py b/tests/test_core.py index 1ec2db78..75c4d9ff 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1465,14 +1465,18 @@ def test_normalize_block_indexing_2d(flatblocks, expected): @requires_dask def test_subset_block_passthrough(): + from flox.core import identity + # full slice pass through array = dask.array.ones((5,), chunks=(1,)) + expected = dask.array.map_blocks(identity, array) subset = subset_to_blocks(array, np.arange(5)) - assert subset.name == array.name + assert subset.name == expected.name array = dask.array.ones((5, 5), chunks=1) + expected = dask.array.map_blocks(identity, array) subset = subset_to_blocks(array, np.arange(25)) - assert subset.name == array.name + assert subset.name == expected.name @requires_dask From e2b1a789c9865fcf5c25ce35e85edc8de353e5e8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 1 May 2024 23:04:27 -0600 Subject: [PATCH 2/2] fix typing --- flox/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/core.py b/flox/core.py index 9431f1c7..308e05d9 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1448,7 +1448,7 @@ def subset_to_blocks( flatblocks: Sequence[int], blkshape: tuple[int] | None = None, reindexer=identity, -) -> Graph: +) -> DaskArray: """ Advanced indexing of .blocks such that we always get a regular array back.