Skip to content

Commit 887b40d

Browse files
authored
feat: support selecting on categorical axes (experimental) (#577)
* feat: support selecting on categorical axes * feat: support bh.loc in an list index expression * fix: add more tests and fix issues found * fix: clarify as experimental in 1.1.0
1 parent e063b13 commit 887b40d

File tree

3 files changed

+189
-17
lines changed

3 files changed

+189
-17
lines changed

docs/usage/indexing.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,23 @@ Boost-histogram specific details
1010
--------------------------------
1111

1212
Boost-histogram implements ``bh.loc``, ``builtins.sum``, ``bh.rebin``, ``bh.underflow``, and ``bh.overflow`` from the UHI spec. A ``bh.tag.at`` locator is provided as well, which simulates the Boost.Histogram C++ ``.at()`` indexing using the UHI locator protocol.
13+
14+
Boost-histogram allows "picking" using lists, similar to NumPy. If you select
15+
with multiple lists, boost-histogram instead selects per-axis, rather than
16+
group-selecting and reducing to a single axis, like NumPy does. You can use
17+
``bh.loc(...)`` inside these lists.
18+
19+
Example::
20+
21+
h = bh.histogram(
22+
bh.axis.Regular(10, 0, 1),
23+
bh.axis.StrCategory(["a", "b", "c"]),
24+
bh.axis.IntCategory([5, 6, 7]),
25+
)
26+
27+
minihist = h[:, [bh.loc("a"), bh.loc("c")], [0, 2]]
28+
29+
# Produces a 3D histgoram with Regular(10, 0, 1) x StrCategory(["a", "c"]) x IntCategory([5, 7])
30+
31+
32+
This feature is considered experimental in boost-histogram 1.1.0. Removed bins are not added to the overflow bin currently.

src/boost_histogram/_internal/hist.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections.abc
12
import copy
23
import logging
34
import threading
@@ -55,8 +56,10 @@
5556

5657
CppAxis = NewType("CppAxis", object)
5758

58-
InnerIndexing = Union[SupportsIndex, Callable[[Axis], int], slice, "ellipsis"]
59-
IndexingWithMapping = Union[InnerIndexing, Mapping[int, InnerIndexing]]
59+
SimpleIndexing = Union[SupportsIndex, slice]
60+
InnerIndexing = Union[SimpleIndexing, Callable[[Axis], int], "ellipsis"]
61+
FullInnerIndexing = Union[InnerIndexing, List[InnerIndexing]]
62+
IndexingWithMapping = Union[FullInnerIndexing, Mapping[int, FullInnerIndexing]]
6063
IndexingExpr = Union[IndexingWithMapping, Tuple[IndexingWithMapping, ...]]
6164

6265
T = TypeVar("T")
@@ -582,6 +585,26 @@ def __repr__(self) -> str:
582585
ret += f" ({outer} with flow)"
583586
return ret
584587

588+
def _compute_uhi_index(self, index: InnerIndexing, axis: int) -> SimpleIndexing:
589+
"""
590+
Converts an expression that contains UHI locators to one that does not.
591+
"""
592+
# Support sum and rebin directly
593+
if index is sum or hasattr(index, "factor"): # type: ignore
594+
index = slice(None, None, index)
595+
596+
# General locators
597+
# Note that MyPy doesn't like these very much - the fix
598+
# will be to properly set input types
599+
elif callable(index):
600+
index = index(self.axes[axis])
601+
elif isinstance(index, SupportsIndex):
602+
if abs(int(index)) >= self._hist.axis(axis).size:
603+
raise IndexError("histogram index is out of range")
604+
index %= self._hist.axis(axis).size
605+
606+
return index # type: ignore
607+
585608
def _compute_commonindex(
586609
self, index: IndexingExpr
587610
) -> List[Union[SupportsIndex, slice, Mapping[int, Union[SupportsIndex, slice]]]]:
@@ -613,18 +636,11 @@ def _compute_commonindex(
613636

614637
# Allow [bh.loc(...)] to work
615638
for i in range(len(indexes)):
616-
# Support sum and rebin directly
617-
if indexes[i] is sum or hasattr(indexes[i], "factor"):
618-
indexes[i] = slice(None, None, indexes[i])
619-
# General locators
620-
# Note that MyPy doesn't like these very much - the fix
621-
# will be to properly set input types
622-
elif callable(indexes[i]):
623-
indexes[i] = indexes[i](self.axes[i])
624-
elif hasattr(indexes[i], "__index__"):
625-
if abs(indexes[i]) >= hist.axis(i).size:
626-
raise IndexError("histogram index is out of range")
627-
indexes[i] %= hist.axis(i).size
639+
# Support list of UHI indexers
640+
if isinstance(indexes[i], list):
641+
indexes[i] = [self._compute_uhi_index(index, i) for index in indexes[i]]
642+
else:
643+
indexes[i] = self._compute_uhi_index(indexes[i], i)
628644

629645
return indexes
630646

@@ -729,6 +745,7 @@ def __getitem__( # noqa: C901
729745
integrations: Set[int] = set()
730746
slices: List[_core.algorithm.reduce_command] = []
731747
pick_each: Dict[int, int] = dict()
748+
pick_set: Dict[int, List[int]] = dict()
732749

733750
# Compute needed slices and projections
734751
for i, ind in enumerate(indexes):
@@ -737,6 +754,9 @@ def __getitem__( # noqa: C901
737754
1 if self.axes[i].traits.underflow else 0
738755
)
739756
continue
757+
elif isinstance(ind, collections.abc.Sequence):
758+
pick_set[i] = list(ind)
759+
continue
740760
elif not isinstance(ind, slice):
741761
raise IndexError(
742762
"Must be a slice, an integer, or follow the locator protocol."
@@ -782,17 +802,45 @@ def __getitem__( # noqa: C901
782802
logger.debug("Reduce with %s", slices)
783803
reduced = self._hist.reduce(*slices)
784804

805+
if pick_set:
806+
warnings.warn(
807+
"List indexing selection is experimental. Removed bins are not placed in overflow."
808+
)
809+
logger.debug("Slices for picking sets: %s", pick_set)
810+
axes = [reduced.axis(i) for i in range(reduced.rank())]
811+
reduced_view = reduced.view(flow=True)
812+
for i in pick_set:
813+
selection = copy.copy(pick_set[i])
814+
ax = reduced.axis(i)
815+
if ax.traits_ordered:
816+
raise RuntimeError(
817+
f"Axis {i} is not a categorical axis, cannot pick with list"
818+
)
819+
820+
if ax.traits_overflow and ax.size not in pick_set[i]:
821+
selection.append(ax.size)
822+
823+
new_axis = axes[i].__class__([axes[i].value(j) for j in pick_set[i]])
824+
new_axis.metadata = axes[i].metadata
825+
axes[i] = new_axis
826+
reduced_view = np.take(reduced_view, selection, axis=i)
827+
828+
logger.debug("Axes: %s", axes)
829+
new_reduced = reduced.__class__(axes)
830+
new_reduced.view(flow=True)[...] = reduced_view
831+
reduced = new_reduced
832+
785833
if pick_each:
786-
my_slice = tuple(
834+
tuple_slice = tuple(
787835
pick_each.get(i, slice(None)) for i in range(reduced.rank())
788836
)
789-
logger.debug("Slices: %s", my_slice)
837+
logger.debug("Slices for pick each: %s", tuple_slice)
790838
axes = [
791839
reduced.axis(i) for i in range(reduced.rank()) if i not in pick_each
792840
]
793841
logger.debug("Axes: %s", axes)
794842
new_reduced = reduced.__class__(axes)
795-
new_reduced.view(flow=True)[...] = reduced.view(flow=True)[my_slice]
843+
new_reduced.view(flow=True)[...] = reduced.view(flow=True)[tuple_slice]
796844
reduced = new_reduced
797845
integrations = {i - sum(j <= i for j in pick_each) for i in integrations}
798846

tests/test_internal_histogram.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33
from numpy.testing import assert_allclose, assert_array_equal
4+
from pytest import approx
45

56
import boost_histogram as bh
67

@@ -258,3 +259,106 @@ def test_int_cat_hist():
258259

259260
with pytest.raises(RuntimeError):
260261
h.fill(0.5)
262+
263+
264+
@pytest.mark.filterwarnings("ignore:List indexing selection is experimental")
265+
def test_int_cat_hist_pick_several():
266+
h = bh.Histogram(
267+
bh.axis.IntCategory([1, 2, 7], __dict__={"xval": 5}), storage=bh.storage.Int64()
268+
)
269+
270+
h.fill(1, weight=8)
271+
h.fill(2, weight=7)
272+
h.fill(7, weight=6)
273+
274+
assert h.view() == approx(np.array([8, 7, 6]))
275+
assert h.sum() == 21
276+
277+
assert h[[0, 2]].values() == approx(np.array([8, 6]))
278+
assert h[[2, 0]].values() == approx(np.array([6, 8]))
279+
assert h[[1]].values() == approx(np.array([7]))
280+
281+
assert h[[bh.loc(1), bh.loc(7)]].values() == approx(np.array([8, 6]))
282+
assert h[[bh.loc(7), bh.loc(1)]].values() == approx(np.array([6, 8]))
283+
assert h[[bh.loc(2)]].values() == approx(np.array([7]))
284+
285+
assert tuple(h[[0, 2]].axes[0]) == (1, 7)
286+
assert tuple(h[[2, 0]].axes[0]) == (7, 1)
287+
assert tuple(h[[1]].axes[0]) == (2,)
288+
289+
assert h.axes[0].xval == 5
290+
assert h[[0]].axes[0].xval == 5
291+
assert h[[0, 1, 2]].axes[0].xval == 5
292+
293+
294+
@pytest.mark.filterwarnings("ignore:List indexing selection is experimental")
295+
def test_str_cat_pick_several():
296+
h = bh.Histogram(bh.axis.StrCategory(["a", "b", "c"]))
297+
298+
h.fill(["a", "a", "a", "b", "b", "c"], weight=0.25)
299+
300+
assert h[[0, 1, 2]].values() == approx(np.array([0.75, 0.5, 0.25]))
301+
assert h[[2, 1, 0]].values() == approx(np.array([0.25, 0.5, 0.75]))
302+
assert h[[1, 0]].values() == approx(np.array([0.5, 0.75]))
303+
304+
assert h[[bh.loc("a"), bh.loc("b"), bh.loc("c")]].values() == approx(
305+
np.array([0.75, 0.5, 0.25])
306+
)
307+
assert h[[bh.loc("c"), bh.loc("b"), bh.loc("a")]].values() == approx(
308+
np.array([0.25, 0.5, 0.75])
309+
)
310+
assert h[[bh.loc("b"), bh.loc("a")]].values() == approx(np.array([0.5, 0.75]))
311+
312+
assert tuple(h[[1, 0]].axes[0]) == ("b", "a")
313+
314+
315+
@pytest.mark.filterwarnings("ignore:List indexing selection is experimental")
316+
def test_pick_invalid():
317+
h = bh.Histogram(bh.axis.Regular(10, 0, 1))
318+
with pytest.raises(RuntimeError):
319+
h[[0, 1]]
320+
321+
h = bh.Histogram(bh.axis.Integer(0, 10))
322+
with pytest.raises(RuntimeError):
323+
h[[0, 1]]
324+
325+
326+
@pytest.mark.filterwarnings("ignore:List indexing selection is experimental")
327+
def test_str_cat_pick_dual():
328+
h = bh.Histogram(
329+
bh.axis.StrCategory(["a", "b", "c"]), bh.axis.StrCategory(["d", "e", "f"])
330+
)
331+
vals = np.arange(9).reshape(3, 3)
332+
h.values()[...] = vals
333+
334+
assert h[[0], [0]].values() == approx(vals[[0]][:, [0]])
335+
assert h[[1], [2]].values() == approx(vals[[1]][:, [2]])
336+
assert h[[1], [0, 1]].values() == approx(vals[[1]][:, [0, 1]])
337+
assert h[[0, 1], [1]].values() == approx(vals[[0, 1]][:, [1]])
338+
assert h[[0, 1], [0, 1]].values() == approx(vals[[0, 1]][:, [0, 1]])
339+
assert h[[0, 1], [2, 1]].values() == approx(vals[[0, 1]][:, [2, 1]])
340+
341+
342+
@pytest.mark.filterwarnings("ignore:List indexing selection is experimental")
343+
def test_pick_multiaxis():
344+
h = bh.Histogram(
345+
bh.axis.StrCategory(["a", "b", "c"]),
346+
bh.axis.IntCategory([-5, 0, 10]),
347+
bh.axis.Regular(5, 0, 1),
348+
bh.axis.StrCategory(["d", "e", "f"]),
349+
storage=bh.storage.Int64(),
350+
)
351+
352+
h.fill("b", -5, 0.65, "f")
353+
h.fill("b", -5, 0.65, "e")
354+
355+
mini = h[[bh.loc("b"), 2], [1, bh.loc(-5)], sum, bh.loc("f")]
356+
357+
assert mini.ndim == 2
358+
assert tuple(mini.axes[0]) == ("b", "c")
359+
assert tuple(mini.axes[1]) == (0, -5)
360+
361+
assert h[[1, 2], [0, 1], sum, bh.loc("f")].sum() == 1
362+
assert h[[1, 2], [1, 0], sum, bh.loc("f")].sum() == 1
363+
364+
assert mini.values() == approx(np.array(((0, 1), (0, 0))))

0 commit comments

Comments
 (0)