Skip to content

Commit f02517e

Browse files
committed
Replace apply_custom_transform() implementation using make_caster<std::function<raw_t>>
1 parent 4c57618 commit f02517e

File tree

2 files changed

+8
-21
lines changed

2 files changed

+8
-21
lines changed

tests/test_callbacks.cpp

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,16 @@ namespace boost_histogram { // See PR #5580
2020
double custom_transform_double(double value) { return value * 3; }
2121
int custom_transform_int(int value) { return value; }
2222

23-
// Derived from
23+
// Originally derived from
2424
// https://github.com/scikit-hep/boost-histogram/blob/460ef90905d6a8a9e6dd3beddfe7b4b49b364579/include/bh_python/transform.hpp#L68-L85
2525
double apply_custom_transform(const py::object &src, double value) {
2626
using raw_t = double(double);
2727

28-
auto func = py::reinterpret_borrow<py::function>(src);
29-
30-
if (auto cfunc = func.cpp_function()) {
31-
auto c = py::reinterpret_borrow<py::capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
32-
33-
auto *rec = c.get_pointer<py::detail::function_record>();
34-
35-
if (rec && rec->is_stateless
36-
&& py::detail::same_type(typeid(raw_t *),
37-
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
38-
struct capture {
39-
raw_t *f;
40-
};
41-
auto *cap = reinterpret_cast<capture *>(&rec->data);
42-
return (*cap->f)(value);
43-
}
44-
return -200;
28+
py::detail::make_caster<std::function<raw_t>> func_caster;
29+
if (!func_caster.load(src, /*convert*/ false)) {
30+
return -100;
4531
}
46-
return -100;
32+
return static_cast<std::function<raw_t> &>(func_caster)(value);
4733
}
4834

4935
} // namespace boost_histogram

tests/test_callbacks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,9 @@ def test_boost_histogram_apply_custom_transform():
241241
cti = m.boost_histogram_custom_transform_int
242242
apply = m.boost_histogram_apply_custom_transform
243243
assert apply(ctd, 5) == 15
244-
assert apply(cti, 0) == -200
244+
with pytest.raises(TypeError):
245+
assert apply(cti, 0)
245246
assert apply(None, 0) == -100
246-
assert apply(lambda value: value, 0) == -100
247+
assert apply(lambda value: value * 10, 4) == 40
247248
assert apply({}, 0) == -100
248249
assert apply("", 0) == -100

0 commit comments

Comments
 (0)