Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tiledb/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def __getitem__(self, selection):

return result

def label_index(self, labels):
    """Apply Array.label_index with query parameters.

    :param labels: Name of a dimension label, or an iterable of dimension
        label names, to index the aggregation by. A bare string is treated
        as one label name instead of being split into its characters.
    :returns: A ``LabelAggregation`` built from the underlying query's
        array, the label names (as a tuple), and this object as the query.
    """
    # Function-scope import, as in the original code (presumably to avoid a
    # circular import between the two modules -- TODO confirm).
    from .multirange_indexing import LabelAggregation

    # ``tuple("l1")`` would produce ``("l", "1")``; wrap a lone name instead.
    label_names = (labels,) if isinstance(labels, str) else tuple(labels)
    return LabelAggregation(self.query.array, label_names, query=self)

@property
def multi_index(self):
"""Apply Array.multi_index with query parameters."""
Expand Down
71 changes: 71 additions & 0 deletions tiledb/multirange_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,77 @@ def _run_query(self) -> Dict[str, np.ndarray]:
return result


class LabelAggregation(MultiRangeAggregation):
    """
    Implements multi-range aggregation indexing by label.

    Each requested label is mapped to the index of the dimension it is
    defined on. When an indexing expression contains label ranges, a separate
    label query is prepared in ``_set_ranges`` and executed in ``_run_query``
    to obtain the dimension ranges before the aggregation itself is run.
    """

    def __init__(
        self,
        array: Array,
        labels: Sequence[str],
        query: Optional[AggregationProxy] = None,
    ):
        """
        :param array: Array to aggregate over; must not be sparse.
        :param labels: Names of the dimension labels to index by. At most one
            label may be given per dimension.
        :param query: Aggregation proxy forwarded to the parent indexer.
        :raises NotImplementedError: If ``array`` has a sparse schema.
        :raises TileDBError: If two labels are defined on the same dimension.
        """
        if array.schema.sparse:
            raise NotImplementedError(
                "querying sparse arrays by label is not yet implemented"
            )
        super().__init__(array, query)
        # Pending label query; stays ``None`` unless ``_set_ranges`` is given
        # label ranges.
        self.label_query: Optional[Query] = None
        # Maps dimension index -> name of the label selected for it.
        self._labels: Dict[int, str] = {}
        for label_name in labels:
            dim_label = array.schema.dim_label(label_name)
            dim_idx = dim_label.dim_index
            if dim_idx in self._labels:
                raise TileDBError(
                    f"cannot set labels `{self._labels[dim_idx]}` and "
                    f"`{label_name}` defined on the same dimension"
                )
            self._labels[dim_idx] = label_name

    def _set_ranges(self, idx):
        """
        Translate the indexing expression ``idx`` into subarray ranges.

        If ``idx`` yields no label ranges, the dimension ranges are final and
        are applied to the query immediately. Otherwise the dimension ranges
        are added to the subarray and a label query is prepared for
        ``_run_query`` to execute.
        """
        dim_ranges, label_ranges = getitem_ranges_with_labels(
            self.array, self._labels, idx
        )
        if label_ranges is None:
            with timing("add_ranges"):
                self.subarray.add_ranges(tuple(dim_ranges))
            # No label query.
            self.label_query = None
            # All ranges are finalized: set shape and subarray now.
            self._set_shape(dim_ranges)
            self.pyquery.set_subarray(self.subarray)
        else:
            label_subarray = Subarray(self.array)
            with timing("add_ranges"):
                self.subarray.add_ranges(dim_ranges=dim_ranges)
                label_subarray.add_ranges(label_ranges=label_ranges)
            self.label_query = Query(self.array)
            self.label_query.set_subarray(label_subarray)

    def _run_query(self) -> Dict[str, np.ndarray]:
        """
        Run the pending label query, if any, then the parent's query.

        :raises TileDBError: If the label query is still incomplete after
            being submitted.
        """
        # If querying by label and the label query is not yet complete, run the label
        # query and update the pyquery with the actual dimensions.
        if self.label_query is not None and not self.label_query.is_complete():
            self.label_query._submit()

            if not self.label_query.is_complete():
                raise TileDBError("failed to get dimension ranges from labels")
            label_subarray = self.label_query.subarray()
            # Check that the label query returned results for all dimensions.
            if any(
                label_subarray.num_dim_ranges(dim_idx) == 0 for dim_idx in self._labels
            ):
                # NOTE(review): clearing pyquery appears intended to signal an
                # empty result to the parent `_run_query` -- confirm that the
                # parent honours it rather than querying the full domain.
                self.pyquery = None
            else:
                # Get the ranges from the label query and set them on the
                # main subarray for the labelled dimensions. The subarray
                # fetched above is reused instead of being requested again.
                self.subarray.copy_ranges(label_subarray, self._labels.keys())
                self.pyquery.set_subarray(self.subarray)
        return super()._run_query()


class DataFrameIndexer(_BaseIndexer):
"""
Implements `.df[]` indexing to directly return a dataframe
Expand Down
74 changes: 74 additions & 0 deletions tiledb/tests/test_dimension_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,77 @@ def test_dimension_label_on_query(self):
},
),
)

@pytest.mark.skipif(
    tiledb.libtiledb.version() < (2, 15),
    reason="dimension labels requires libtiledb version 2.15 or greater",
)
def test_dimension_label_on_aggregation(self):
    """Check aggregation results when slicing through ``label_index``."""
    array_uri = self.path("aggregation_label_index")

    # Two int32 dimensions, one int64 attribute, and one increasing label
    # per dimension: "l1" (int64) on d1 and "l2" (float64) on d2.
    row_dim = tiledb.Dim("d1", domain=(0, 3), dtype=np.int32)
    col_dim = tiledb.Dim("d2", domain=(0, 2), dtype=np.int32)
    domain = tiledb.Domain(row_dim, col_dim)
    attr = tiledb.Attr("a1", dtype=np.int64)
    label_schemas = {
        0: {"l1": row_dim.create_label_schema("increasing", np.int64)},
        1: {"l2": col_dim.create_label_schema("increasing", np.float64)},
    }
    array_schema = tiledb.ArraySchema(
        domain=domain, attrs=(attr,), dim_labels=label_schemas
    )
    tiledb.Array.create(array_uri, array_schema)

    # Attribute values:
    #   [[ 10,  20,  30],
    #    [ 40,  50,  60],
    #    [ 70,  80,  90],
    #    [100, 110, 120]]
    attr_values = np.reshape(np.arange(10, 130, 10), (4, 3))
    row_labels = np.array([100, 200, 300, 400], dtype=np.int64)
    col_labels = np.array([1.0, 2.0, 3.0], dtype=np.float64)

    with tiledb.open(array_uri, "w") as writer:
        writer[:] = {"a1": attr_values, "l1": row_labels, "l2": col_labels}

    with tiledb.open(array_uri, "r") as reader:
        q = reader.query(attrs="", dims=["d1"])

        # sum, one label: rows 1 and 2 -> (40 + 50 + 60) + (70 + 80 + 90)
        assert q.agg("sum").label_index(["l1"])[200:300] == 390

        # count: 4 rows x 3 columns
        assert q.agg("count").label_index(["l1"])[100:400] == 12

        # mean of [40, 50, 60, 70, 80, 90]
        assert q.agg("mean").label_index(["l1"])[200:300] == 65.0

        # min of [40, 50, 60, 70, 80, 90]
        assert q.agg("min").label_index(["l1"])[200:300] == 40

        # max of [40, 50, 60, 70, 80, 90]
        assert q.agg("max").label_index(["l1"])[200:300] == 90

        # floating-point label on the second dimension: columns 1 and 2
        # -> (20 + 50 + 80 + 110) + (30 + 60 + 90 + 120)
        assert q.agg("sum").label_index(["l2"])[:, 2.0:3.0] == 560

        # both labels at once: rows 1-2, columns 0-1 -> 40 + 50 + 70 + 80
        assert q.agg("sum").label_index(["l1", "l2"])[200:300, 1.0:2.0] == 240

        # single label value: row 1 -> 40 + 50 + 60
        assert q.agg("sum").label_index(["l1"])[200:200] == 150

        # several aggregations in one query over rows 0-1:
        # [10, 20, 30, 40, 50, 60]
        combined = q.agg(["sum", "mean"]).label_index(["l1"])[100:200]
        assert combined["sum"] == 210
        assert combined["mean"] == 35.0
Loading