Skip to content

Commit 8ddf463

Browse files
committed
fix: slicing supports under/overflow
1 parent 8fead07 commit 8ddf463

File tree

3 files changed

+56
-14
lines changed

3 files changed

+56
-14
lines changed

docs/CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
## UPCOMING
44

5-
No changes yet.
5+
* Fix "picking" on a flow bin [#576][]
6+
7+
[#576]: https://github.com/scikit-hep/boost-histogram/pull/576
8+
69

710
## Version 1.0
811

src/boost_histogram/_internal/hist.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -726,18 +726,22 @@ def __getitem__( # noqa: C901
726726
except RuntimeError:
727727
pass
728728

729-
integrations = set()
730-
slices = []
729+
integrations: Set[int] = set()
730+
slices: List[_core.algorithm.reduce_command] = []
731+
pick_each: Dict[int, int] = dict()
731732

732733
# Compute needed slices and projections
733734
for i, ind in enumerate(indexes):
734735
if hasattr(ind, "__index__"):
735-
ind = slice(ind.__index__(), ind.__index__() + 1, sum) # type: ignore
736-
736+
pick_each[i] = ind.__index__() + ( # type: ignore
737+
1 if self.axes[i].traits.underflow else 0
738+
)
739+
continue
737740
elif not isinstance(ind, slice):
738741
raise IndexError(
739742
"Must be a slice, an integer, or follow the locator protocol."
740743
)
744+
741745
# If the dictionary brackets are forgotten, it's easy to put a slice
742746
# into a slice - adding a nicer error message in that case
743747
if any(isinstance(v, slice) for v in (ind.start, ind.stop, ind.step)):
@@ -778,15 +782,25 @@ def __getitem__( # noqa: C901
778782
logger.debug("Reduce with %s", slices)
779783
reduced = self._hist.reduce(*slices)
780784

781-
if not integrations:
782-
return self._new_hist(reduced)
783-
projections = [i for i in range(self.ndim) if i not in integrations]
784-
785-
return (
786-
self._new_hist(reduced.project(*projections))
787-
if projections
788-
else reduced.sum(flow=True)
789-
)
785+
if pick_each:
786+
my_slice = tuple(
787+
pick_each.get(i, slice(None)) for i in range(reduced.rank())
788+
)
789+
logger.debug("Slices: %s", my_slice)
790+
axes = [
791+
reduced.axis(i) for i in range(reduced.rank()) if i not in pick_each
792+
]
793+
logger.debug("Axes: %s", axes)
794+
new_reduced = reduced.__class__(axes)
795+
new_reduced.view(flow=True)[...] = reduced.view(flow=True)[my_slice]
796+
reduced = new_reduced
797+
integrations = {i - sum(j <= i for j in pick_each) for i in integrations}
798+
799+
if integrations:
800+
projections = [i for i in range(reduced.rank()) if i not in integrations]
801+
reduced = reduced.project(*projections)
802+
803+
return self._new_hist(reduced) if reduced.rank() > 0 else reduced.sum(flow=True)
790804

791805
def __setitem__(
792806
self, index: IndexingExpr, value: Union[ArrayLike, Accumulator]

tests/test_histogram_indexing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33
from numpy.testing import assert_array_equal
4+
from pytest import approx
45

56
import boost_histogram as bh
67

@@ -351,6 +352,30 @@ def test_pick_int_category():
351352
assert_array_equal(h[:, :, bh.loc(7)].view(), 0)
352353

353354

355+
@pytest.mark.parametrize(
356+
"ax",
357+
[bh.axis.Regular(3, 0, 1), bh.axis.Variable([0, 0.3, 0.6, 1])],
358+
ids=["regular", "variable"],
359+
)
360+
def test_pick_flowbin(ax):
361+
w = 1e-2 # e.g. a cross section for a process
362+
x = [-0.1, -0.1, 0.1, 0.1, 0.1]
363+
y = [-0.1, 0.1, -0.1, -0.1, 0.1]
364+
365+
h = bh.Histogram(
366+
ax,
367+
ax,
368+
storage=bh.storage.Weight(),
369+
)
370+
h.fill(x, y, weight=w)
371+
372+
uf_slice = h[bh.tag.underflow, ...]
373+
assert uf_slice.values(flow=True) == approx(np.array([1, 1, 0, 0, 0]) * w)
374+
375+
uf_slice = h[..., bh.tag.underflow]
376+
assert uf_slice.values(flow=True) == approx(np.array([1, 2, 0, 0, 0]) * w)
377+
378+
354379
def test_axes_tuple():
355380
h = bh.Histogram(bh.axis.Regular(10, 0, 1))
356381
assert isinstance(h.axes[:1], bh._internal.axestuple.AxesTuple)

0 commit comments

Comments
 (0)