Skip to content

Commit 0be96cc

Browse files
committed
fix: AxesTuple constuctor more strict
1 parent 928d066 commit 0be96cc

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

docs/CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
## Version 1.2
44

5+
### Version 1.2.2
6+
7+
#### User changes
8+
9+
* PyPy 3.8 now supported. [#677][]
10+
* The GIL is released a little more often now. [#662][]
11+
* AxesTuple does not allow construction of non-axes. [#680][]
12+
13+
[#662]: https://github.com/scikit-hep/boost-histogram/pull/662
14+
[#677]: https://github.com/scikit-hep/boost-histogram/pull/677
15+
[#680]: https://github.com/scikit-hep/boost-histogram/pull/680
16+
517
### Version 1.2.1
618

719
#### User changes

src/boost_histogram/_internal/axestuple.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Any, List, Tuple, TypeVar
2+
from typing import Any, Iterable, List, Tuple, Type, TypeVar
33

44
import numpy as np
55

@@ -45,6 +45,15 @@ class AxesTuple(tuple): # type: ignore[type-arg]
4545
__slots__ = ()
4646
_MGRIDOPTS = {"sparse": True, "indexing": "ij"}
4747

48+
def __new__(cls: Type[B], __iterable: Iterable[Axis]) -> B:
49+
self = super().__new__(cls, __iterable) # type: ignore[arg-type]
50+
for item in self:
51+
if not isinstance(item, Axis):
52+
raise TypeError(
53+
f"Only an iterable of Axis supported in AxesTuple, got {item}"
54+
)
55+
return self
56+
4857
@property
4958
def size(self) -> Tuple[int, ...]:
5059
return tuple(s.size for s in self)
@@ -93,14 +102,15 @@ def __getitem__(self, item: Any) -> Any:
93102
result = super().__getitem__(item)
94103
return self.__class__(result) if isinstance(result, tuple) else result
95104

96-
def __getattr__(self, attr: str) -> Any:
97-
return self.__class__(getattr(s, attr) for s in self)
105+
def __getattr__(self, attr: str) -> Tuple[Any, ...]:
106+
return tuple(getattr(s, attr) for s in self)
98107

99108
def __setattr__(self, attr: str, values: Any) -> None:
100109
try:
101110
return super().__setattr__(attr, values)
102111
except AttributeError:
103-
self.__class__(s.__setattr__(attr, v) for s, v in zip_strict(self, values))
112+
for s, v in zip_strict(self, values):
113+
s.__setattr__(attr, v)
104114

105115
value.__doc__ = Axis.value.__doc__
106116
index.__doc__ = Axis.index.__doc__

tests/test_axes_object.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
import numpy as np
2+
import pytest
23

34
import boost_histogram as bh
45

56

6-
def test_axes_all_at_once():
7+
@pytest.fixture
8+
def h():
9+
return bh.Histogram(
10+
bh.axis.Regular(10, 0, 10, metadata=2),
11+
bh.axis.Integer(0, 5, metadata="hi"),
12+
bh.axis.StrCategory(["HI", "HO"]),
13+
)
14+
15+
16+
def test_axes_basics(h):
717
h = bh.Histogram(
818
bh.axis.Regular(10, 0, 10, metadata=2),
919
bh.axis.Integer(0, 5, metadata="hi"),
@@ -22,6 +32,8 @@ def test_axes_all_at_once():
2232

2333
assert h.axes.metadata == (None, 3, "bye")
2434

35+
36+
def test_axes_centers(h):
2537
centers = h.axes.centers
2638
answers = np.ogrid[0.5:10, 0.5:5, 0.5:2]
2739
full_answers = np.mgrid[0.5:10, 0.5:5, 0.5:2]
@@ -33,6 +45,8 @@ def test_axes_all_at_once():
3345
np.testing.assert_allclose(centers.flatten()[i], answers[i].flatten())
3446
np.testing.assert_allclose(h.axes[i].centers, answers[i].ravel())
3547

48+
49+
def test_axes_edges(h):
3650
edges = h.axes.edges
3751
answers = np.ogrid[0:11, 0:6, 0:3]
3852
full_answers = np.mgrid[0:11, 0:6, 0:3]
@@ -44,6 +58,8 @@ def test_axes_all_at_once():
4458
np.testing.assert_allclose(edges.ravel()[i], answers[i].ravel())
4559
np.testing.assert_allclose(h.axes[i].edges, answers[i].ravel())
4660

61+
62+
def test_axes_widths(h):
4763
widths = h.axes.widths
4864
answers = np.ogrid[1:1:10j, 1:1:5j, 1:1:2j]
4965
full_answers = np.mgrid[1:1:10j, 1:1:5j, 1:1:2j]
@@ -54,3 +70,12 @@ def test_axes_all_at_once():
5470
np.testing.assert_allclose(widths.T[i], answers[i].T)
5571
np.testing.assert_allclose(widths.ravel()[i], answers[i].ravel())
5672
np.testing.assert_allclose(h.axes[i].widths, answers[i].ravel())
73+
74+
75+
def test_axis_misconstuct():
76+
inp = [bh.axis.Regular(12, 0, 1)]
77+
ok = bh.axis.AxesTuple(inp)
78+
assert ok[0] == inp[0]
79+
80+
with pytest.raises(TypeError):
81+
bh.axis.AxesTuple(inp[0])

0 commit comments

Comments
 (0)