6 changes: 5 additions & 1 deletion asv_bench/benchmarks/cohorts.py
@@ -47,6 +47,9 @@ def time_find_group_cohorts(self):
except AttributeError:
pass

def track_num_cohorts(self):
return len(self.chunks_cohorts())

def time_graph_construct(self):
flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)

@@ -60,10 +63,11 @@ def track_num_tasks_optimized(self):
def track_num_layers(self):
return len(self.result.dask.layers)

track_num_cohorts.unit = "cohorts" # type: ignore[attr-defined] # Lazy
track_num_tasks.unit = "tasks" # type: ignore[attr-defined] # Lazy
track_num_tasks_optimized.unit = "tasks" # type: ignore[attr-defined] # Lazy
track_num_layers.unit = "layers" # type: ignore[attr-defined] # Lazy
for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers]:
for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers, track_num_cohorts]:
f.repeat = 1 # type: ignore[attr-defined] # Lazy
f.rounds = 1 # type: ignore[attr-defined] # Lazy
f.number = 1 # type: ignore[attr-defined] # Lazy
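For context, here is a minimal sketch of how asv `track_*` benchmarks like the new `track_num_cohorts` work: the method returns the quantity to record, and attributes such as `unit`, `repeat`, `rounds`, and `number` tell asv how to run and label it. The class name and data below are hypothetical, not flox's actual benchmark suite.

```python
# Minimal asv "track" benchmark sketch; names and data here are hypothetical.
class CohortsSketch:
    def setup(self):
        # stand-in for the real chunks_cohorts() computation
        self.cohorts = {(0, 1): ["a", "b"], (2,): ["c"]}

    def track_num_cohorts(self):
        # asv records the returned number for every benchmarked commit
        return len(self.cohorts)

    # attributes set on the method control how asv runs and labels it
    track_num_cohorts.unit = "cohorts"
    track_num_cohorts.repeat = 1
    track_num_cohorts.rounds = 1
    track_num_cohorts.number = 1
```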
32 changes: 24 additions & 8 deletions flox/core.py
@@ -403,7 +403,7 @@ def find_group_cohorts(
# Invert the label_chunks mapping so we know which labels occur together.
def invert(x) -> tuple[np.ndarray, ...]:
arr = label_chunks[x]
return tuple(arr)
return tuple(arr.tolist())

chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
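A toy illustration of the `invert` + `tlz.groupby` idiom above (hypothetical `label_chunks`, not flox's real mapping): labels that occur in exactly the same set of chunks map to the same tuple key, which is what the later hunk calls "exact cohorts". `.tolist()` makes the keys tuples of plain Python ints rather than NumPy scalars.

```python
import numpy as np
import toolz as tlz

# Hypothetical label -> chunk-index mapping
label_chunks = {
    "a": np.array([0, 1]),
    "b": np.array([0, 1]),  # occurs in the same chunks as "a"
    "c": np.array([2]),
}

def invert(label):
    # tuple of plain Python ints, usable as a dict key
    return tuple(label_chunks[label].tolist())

chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
print(chunks_cohorts)  # {(0, 1): ['a', 'b'], (2,): ['c']}
```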

@@ -477,22 +477,37 @@ def invert(x) -> tuple[np.ndarray, ...]:
containment.nnz / math.prod(containment.shape)
)
)
# Use a threshold to force some merging. We do not use the filtered
# containment matrix for estimating "sparsity" because it is a bit
# hard to reason about.

# Next we loop over the groups and merge those that are sufficiently similar.
# Use a threshold on containment to always force some merging.
# Note that we do not use the filtered containment matrix for estimating "sparsity"
# because it is a bit hard to reason about.
MIN_CONTAINMENT = 0.75 # arbitrary
mask = containment.data < MIN_CONTAINMENT

# Now we also know "exact cohorts" -- cohorts whose constituent groups
# occur in exactly the same chunks. We need only examine one member of each cohort;
# skip the others by looping over the exact cohorts first and zeroing out their rows.
repeated = np.concatenate([v[1:] for v in chunks_cohorts.values()]).astype(int)
repeated_idx = np.searchsorted(present_labels, repeated)
for i in repeated_idx:
mask[containment.indptr[i] : containment.indptr[i + 1]] = True
containment.data[mask] = 0
containment.eliminate_zeros()

# Iterate over labels, beginning with those with most chunks
# Figure out all the labels we need to loop over later
n_overlapping_labels = containment.astype(bool).sum(axis=1)
order = np.argsort(n_overlapping_labels, kind="stable")[::-1]
# Order is such that we iterate over labels, beginning with those with most overlaps
# Also filter out any "exact" cohorts
order = order[n_overlapping_labels[order] > 0]

logger.debug("find_group_cohorts: merging cohorts")
order = np.argsort(containment.sum(axis=LABEL_AXIS), kind="stable")[::-1]
merged_cohorts = {}
merged_keys = set()
# TODO: we can optimize this to loop over chunk_cohorts instead
# by zeroing out rows that are already in a cohort
for rowidx in order:
if present_labels[rowidx] in merged_keys:
continue
cohidx = containment.indices[
slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
]
@@ -507,6 +522,7 @@ def invert(x) -> tuple[np.ndarray, ...]:

actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
expected_ngroups = present_labels.size
assert len(merged_keys) == actual_ngroups
assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)

# sort by first label in cohort
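A condensed sketch of the sparse-matrix bookkeeping introduced above, using a toy containment matrix rather than flox's real one: entries below `MIN_CONTAINMENT` are masked, the rows of duplicate members of exact cohorts are blanked via the CSR `indptr` offsets, and the surviving labels are ordered by how many other labels they still overlap with. The matrix values and the choice of which label is a "repeated" exact-cohort member are made up for illustration.

```python
import numpy as np
from scipy import sparse

MIN_CONTAINMENT = 0.75
# Toy 3-label containment matrix: labels 0 and 1 overlap strongly, label 2 is mostly independent.
containment = sparse.csr_matrix(
    np.array(
        [
            [1.0, 0.9, 0.1],
            [0.9, 1.0, 0.0],
            [0.1, 0.0, 1.0],
        ]
    )
)

# 1. Mask weak containment so the merge loop only considers strong overlaps.
mask = containment.data < MIN_CONTAINMENT

# 2. Pretend labels 0 and 1 form an exact cohort: blank out the duplicate
#    member's row (label 1) by addressing its nonzeros through indptr.
repeated_idx = np.array([1])
for i in repeated_idx:
    mask[containment.indptr[i] : containment.indptr[i + 1]] = True

containment.data[mask] = 0
containment.eliminate_zeros()

# 3. Visit labels with the most remaining overlaps first; drop fully-zeroed rows.
n_overlapping = np.asarray(containment.astype(bool).sum(axis=1)).ravel()
order = np.argsort(n_overlapping, kind="stable")[::-1]
order = order[n_overlapping[order] > 0]
print(order)  # [0 2] -- label 1 is handled as part of label 0's exact cohort
```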