Skip to content

Commit 7fe70ec

Browse files
committed
Revamp decorator system, now could be useful in other packages
1 parent 917c288 commit 7fe70ec

File tree

8 files changed

+380
-197
lines changed

8 files changed

+380
-197
lines changed

boost_histogram/_internal/axis.py

Lines changed: 222 additions & 116 deletions
Large diffs are not rendered by default.

boost_histogram/_internal/axis_transform.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
from .._core import axis as ca
66

7-
from .utils import register
7+
from .utils import register, set_family, MAIN_FAMILY, set_module
88

99

10+
@set_module("boost_histogram.axis.transform")
1011
class AxisTransform(object):
1112
__slots__ = ()
1213

@@ -17,19 +18,25 @@ def _produce(self, bins, start, stop, metadata):
1718
return self.__class__._type(bins, start, stop, metadata)
1819

1920

20-
@register(ca.transform.log)
21+
@set_family(MAIN_FAMILY)
22+
@register({ca.transform.log})
23+
@set_module("boost_histogram.axis.transform")
2124
class Log(ca.transform.log, AxisTransform):
2225
__slots__ = ()
2326
_type = ca.regular_log
2427

2528

26-
@register(ca.transform.sqrt)
29+
@register({ca.transform.sqrt})
30+
@set_family(MAIN_FAMILY)
31+
@set_module("boost_histogram.axis.transform")
2732
class Sqrt(ca.transform.sqrt, AxisTransform):
2833
__slots__ = ()
2934
_type = ca.regular_sqrt
3035

3136

32-
@register(ca.transform.pow)
37+
@register({ca.transform.pow})
38+
@set_family(MAIN_FAMILY)
39+
@set_module("boost_histogram.axis.transform")
3340
class Pow(ca.transform.pow, AxisTransform):
3441
__slots__ = ()
3542
_type = ca.regular_pow

boost_histogram/_internal/hist.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .axistuple import AxesTuple
99
from .sig_tools import inject_signature
1010
from .storage import Double, Storage
11-
from .utils import cast
11+
from .utils import cast, register, set_family, MAIN_FAMILY, CPP_FAMILY, set_module
1212

1313
import warnings
1414
import numpy as np
@@ -53,7 +53,7 @@ def _expand_ellipsis(indexes, rank):
5353
raise IndexError("an index can only have a single ellipsis ('...')")
5454

5555

56-
def _compute_commonindex(hist, index, expand_ellipsis):
56+
def _compute_commonindex(self, hist, index, expand_ellipsis):
5757
# Normalize -> h[i] == h[i,]
5858
if not isinstance(index, tuple):
5959
index = (index,)
@@ -70,7 +70,7 @@ def _compute_commonindex(hist, index, expand_ellipsis):
7070
# Allow [bh.loc(...)] to work
7171
for i in range(len(indexes)):
7272
if callable(indexes[i]):
73-
indexes[i] = indexes[i](cast(hist.axis(i), Axis, cpp=False))
73+
indexes[i] = indexes[i](cast(self, hist.axis(i), Axis))
7474
elif hasattr(indexes[i], "flow"):
7575
if indexes[i].flow == 1:
7676
indexes[i] = hist.axis(i).size
@@ -84,6 +84,9 @@ def _compute_commonindex(hist, index, expand_ellipsis):
8484
return indexes
8585

8686

87+
# We currently do not cast *to* a histogram, but this is consistent
88+
# and needed to be able to cast *from* a histogram method.
89+
@register(_histograms)
8790
class BaseHistogram(object):
8891
@inject_signature("self, *axes, storage=Double()", locals={"Double": Double})
8992
def __init__(self, *axes, **kwargs):
@@ -200,17 +203,17 @@ def _axis(self, i):
200203
"""
201204
Get N-th axis.
202205
"""
203-
return cast(self._hist.axis(i), Axis, self._cpp_module)
206+
return cast(self, self._hist.axis(i), Axis)
204207

205208
@property
206209
def _storage_type(self):
207-
return cast(
208-
self._hist._storage_type, Storage, cpp=self._cpp_module, is_class=True
209-
)
210+
return cast(self, self._hist._storage_type, Storage)
210211

211212

213+
# C++ version of histogram
214+
@set_family(CPP_FAMILY)
215+
@set_module("boost_histogram.cpp")
212216
class histogram(BaseHistogram):
213-
_cpp_module = True
214217
axis = BaseHistogram._axis
215218

216219
def rank(self):
@@ -259,9 +262,9 @@ def _project(self, *args):
259262
return self.__class__(self._hist.project(*args))
260263

261264

265+
@set_family(MAIN_FAMILY)
266+
@set_module("boost_histogram")
262267
class Histogram(BaseHistogram):
263-
_cpp_module = False
264-
265268
@inject_signature("self, *axes, storage=Double()", locals={"Double": Double})
266269
def __init__(self, *args, **kwargs):
267270
super(Histogram, self).__init__(*args, **kwargs)
@@ -339,7 +342,7 @@ def size(self):
339342

340343
def __getitem__(self, index):
341344

342-
indexes = _compute_commonindex(self._hist, index, expand_ellipsis=True)
345+
indexes = _compute_commonindex(self, self._hist, index, expand_ellipsis=True)
343346

344347
# If this is (now) all integers, return the bin contents
345348
try:
@@ -411,19 +414,22 @@ def __getitem__(self, index):
411414
)
412415

413416
def __setitem__(self, index, value):
414-
indexes = _compute_commonindex(self._hist, index, expand_ellipsis=False)
417+
indexes = _compute_commonindex(self, self._hist, index, expand_ellipsis=False)
415418
self._hist._at_set(value, *indexes)
416419

417420
def reduce(self, *args):
418421
"""
419-
Reduce based on one or more reduce_option's. If you are operating on most or all of your axis, consider slicing with [] notation.
422+
Reduce based on one or more reduce_option's. If you are operating on most
423+
or all of your axis, consider slicing with [] notation.
420424
"""
421425

422426
return self.__class__(self._hist.reduce(*args))
423427

424428
def project(self, *args):
425429
"""
426-
Project to a single axis or several axes on a multidiminsional histogram. Provided a list of axis numbers, this will produce the histogram over those axes only. Flow bins are used if available.
430+
Project to a single axis or several axes on a multidiminsional histogram.
431+
Provided a list of axis numbers, this will produce the histogram over
432+
those axes only. Flow bins are used if available.
427433
"""
428434

429435
return self.__class__(self._hist.project(*args))

boost_histogram/_internal/storage.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
del absolute_import, division, print_function
44

55
from .._core import storage as store
6-
from .utils import register
6+
from .utils import register, set_family, MAIN_FAMILY, set_module
77

88
# Simple mixin to provide a common base class for types
99
# and nice reprs
@@ -12,36 +12,50 @@ def __repr__(self):
1212
return "{self.__class__.__name__}()".format(self=self)
1313

1414

15-
@register(store.int)
15+
@register({store.int})
16+
@set_family(MAIN_FAMILY)
17+
@set_module("boost_histogram.storage")
1618
class Int(store.int, Storage):
1719
pass
1820

1921

20-
@register(store.double)
22+
@register({store.double})
23+
@set_family(MAIN_FAMILY)
24+
@set_module("boost_histogram.storage")
2125
class Double(store.double, Storage):
2226
pass
2327

2428

25-
@register(store.atomic_int)
29+
@register({store.atomic_int})
30+
@set_family(MAIN_FAMILY)
31+
@set_module("boost_histogram.storage")
2632
class AtomicInt(store.atomic_int, Storage):
2733
pass
2834

2935

30-
@register(store.unlimited)
36+
@register({store.unlimited})
37+
@set_family(MAIN_FAMILY)
38+
@set_module("boost_histogram.storage")
3139
class Unlimited(store.unlimited, Storage):
3240
pass
3341

3442

35-
@register(store.weight)
43+
@register({store.weight})
44+
@set_family(MAIN_FAMILY)
45+
@set_module("boost_histogram.storage")
3646
class Weight(store.weight, Storage):
3747
pass
3848

3949

40-
@register(store.mean)
50+
@register({store.mean})
51+
@set_family(MAIN_FAMILY)
52+
@set_module("boost_histogram.storage")
4153
class Mean(store.mean, Storage):
4254
pass
4355

4456

45-
@register(store.weighted_mean)
57+
@register({store.weighted_mean})
58+
@set_family(MAIN_FAMILY)
59+
@set_module("boost_histogram.storage")
4660
class WeightedMean(store.weighted_mean, Storage):
4761
pass

boost_histogram/_internal/utils.py

Lines changed: 102 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,76 +2,151 @@
22

33
del absolute_import, division, print_function
44

5+
# Custom families (other packages can define custom families)
6+
MAIN_FAMILY = object() # This family will be used as a fallback
7+
CPP_FAMILY = object()
58

6-
def cpp_module(cls):
9+
# These are not exported because custom user classes do not need to
10+
# add to the original families, they should make their own.
11+
12+
13+
def set_module(name):
14+
"""
15+
Set the __module__ attribute on a class. Very
16+
similar to numpy.core.overrides.set_module.
17+
"""
18+
19+
def add_module(cls):
20+
cls.__module__ = name
21+
return cls
22+
23+
return add_module
24+
25+
26+
def set_family(family):
727
"""
8-
Simple decorator to declare a class to be in the cpp
9-
module.
28+
Decorator to set the family of a class.
1029
"""
11-
cls._cpp_module = True
12-
return cls
1330

31+
def add_family(cls):
32+
cls._family = family
33+
return cls
34+
35+
return add_family
1436

15-
def register(*args):
37+
38+
def register(cpp_types=None):
1639
"""
1740
Decorator to register a C++ type to a Python class.
1841
Each class given will be added to a lookup list "_types"
19-
that cast knows about.
42+
that cast knows about. It should also part of a "family",
43+
and any class in a family will cast to the same family.
44+
See set_family.
45+
46+
For example, internally this call:
47+
48+
ax = hist._axis(0)
49+
50+
which will get a raw C++ object and need to cast it to a Python
51+
wrapped object. There will be at least two candidates (users
52+
could add more): MAIN_FAMILY and CPP_FAMILY. Cast will use the
53+
parent class's family to return the correct family. If the
54+
requested family is not found, then the regular family is the
55+
fallback.
2056
2157
This decorator, like other decorators in boost-histogram,
2258
is safe for pickling since it does not replace the
2359
original class.
60+
61+
If nothing or an empty set is passed, this will ensure that this
62+
class is not selected during the cast process. This can be
63+
used for simple renamed classes that inject warnings, etc.
2464
"""
2565

2666
def add_registration(cls):
27-
if len(args) == 0:
67+
if cpp_types is None or len(cpp_types) == 0:
2868
cls._types = set()
2969
return cls
3070

3171
if not hasattr(cls, "_types"):
3272
cls._types = set()
3373

34-
for cpp_type in args:
74+
for cpp_type in cpp_types:
3575
if cpp_type in cls._types:
3676
raise TypeError("You are trying to register {} again".format(cpp_type))
3777

3878
cls._types.add(cpp_type)
39-
if not hasattr(cls, "_cpp_module"):
40-
cls._cpp_module = False
41-
return cls
79+
80+
return cls
4281

4382
return add_registration
4483

4584

46-
def cast(cpp_object, parent_class, cpp=False, is_class=False):
85+
def _cast_make_object(canidate_class, cpp_object, is_class):
86+
"Make an object for cast"
87+
if is_class:
88+
return canidate_class
89+
elif hasattr(canidate_class, "_convert_cpp"):
90+
return canidate_class._convert_cpp(cpp_object)
91+
else:
92+
return canidate_class(cpp_object)
93+
94+
95+
def cast(self, cpp_object, parent_class):
4796
"""
4897
This converts a C++ object into a Python object.
49-
This takes the parent Python class, and an optional
50-
base parameter, which will only return classes that
51-
are in the base module.
52-
53-
Can also return the class directly with find_class=True.
98+
This takes the parent object, the C++ object,
99+
the Python class. If a class is passed in instead of
100+
an object, this will return a class instead. The parent
101+
object (self) can be either a registered class or an
102+
instance of a registered class.
54103
55104
If a class does not support direction conversion in
56105
the constructor, it should have _convert_cpp class
57106
method instead.
58107
59-
cpp setting must match the register setting.
108+
Example:
109+
110+
cast(self, hist.cpp_axis(), Axis)
111+
# -> returns Regular(...) if regular axis, etc.
112+
113+
If self is None, just use the MAIN_FAMILY.
60114
"""
61-
cpp_class = cpp_object if is_class else cpp_object.__class__
115+
if self is None:
116+
family = MAIN_FAMILY
117+
else:
118+
family = self._family
119+
120+
# Convert objects to classes, and remember if we did so
121+
if isinstance(cpp_object, type):
122+
is_class = True
123+
cpp_class = cpp_object
124+
else:
125+
is_class = False
126+
cpp_class = cpp_object.__class__
127+
128+
# Remember the fallback class if a class in the same family does not exist
129+
fallback_class = None
62130

63131
for canidate_class in _walk_subclasses(parent_class):
132+
# If a class was registered with this c++ type
64133
if (
65134
hasattr(canidate_class, "_types")
66135
and cpp_class in canidate_class._types
67-
and canidate_class._cpp_module == cpp
136+
and hasattr(canidate_class, "_family")
68137
):
69-
if is_class:
70-
return canidate_class
71-
elif hasattr(canidate_class, "_convert_cpp"):
72-
return canidate_class._convert_cpp(cpp_object)
73-
else:
74-
return canidate_class(cpp_object)
138+
# Return immediately if the family is right
139+
if canidate_class._family is family:
140+
return _cast_make_object(canidate_class, cpp_object, is_class)
141+
142+
# Or remember the class if it was from the main family
143+
if canidate_class._family is MAIN_FAMILY:
144+
fallback_class = canidate_class
145+
146+
# If no perfect match was registered, return the main family
147+
if fallback_class is not None:
148+
return _cast_make_object(fallback_class, cpp_object, is_class)
149+
75150
raise TypeError(
76151
"No conversion to {0} from {1} found.".format(parent_class.__name__, cpp_object)
77152
)

0 commit comments

Comments
 (0)