diff --git a/include/bh_python/transform.hpp b/include/bh_python/transform.hpp index 17f8e483e..2d05a4be0 100644 --- a/include/bh_python/transform.hpp +++ b/include/bh_python/transform.hpp @@ -58,39 +58,10 @@ struct func_transform { return std::make_tuple(ptr, src); } - // If we made it to this point, we probably have a C++ pybind object or an - // invalid object. The following is based on the std::function conversion in - // pybind11/functional.hpp - if(!py::isinstance(src)) - throw py::type_error("Only ctypes double(double) and C++ functions allowed " - "(must be function)"); - - auto func = py::reinterpret_borrow(src); - - if(auto cfunc = func.cpp_function()) { - auto c = py::reinterpret_borrow( - PyCFunction_GET_SELF(cfunc.ptr())); - - auto rec = c.get_pointer(); - - if(rec && rec->is_stateless - && py::detail::same_type( - typeid(raw_t*), - *reinterpret_cast(rec->data[1]))) { - struct capture { - raw_t* f; - }; - return std::make_tuple((reinterpret_cast(&rec->data))->f, - src); - } - - // Note that each error is slightly different just to help with debugging - throw py::type_error("Only ctypes double(double) and C++ functions allowed " - "(must be stateless)"); - } + // If we made it to this point, we probably have an invalid object. throw py::type_error("Only ctypes double(double) and C++ functions allowed " - "(must be cpp function)"); + "(must be a stateless cpp function)"); } func_transform(py::object f, py::object i, py::object c, py::str n) diff --git a/src/boost_histogram/axis/transform.py b/src/boost_histogram/axis/transform.py index 20dca3693..8afd396a3 100644 --- a/src/boost_histogram/axis/transform.py +++ b/src/boost_histogram/axis/transform.py @@ -1,10 +1,12 @@ from __future__ import annotations import copy +import ctypes from typing import Any, ClassVar, TypeVar import boost_histogram +from .. import _core from .._core import axis as ca from .._utils import register @@ -12,6 +14,8 @@ __all__ = ["AxisTransform", "Function", "Pow", "log", "sqrt"] +LIB = ctypes.CDLL(_core.__file__) + def __dir__() -> list[str]: return __all__ @@ -150,7 +154,8 @@ def _produce(self, bins: int, start: float, stop: float) -> Any: def _internal_conversion(name: str) -> Any: - return getattr(ca.transform, name) + ftype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double) + return ctypes.cast(getattr(LIB, name), ftype) sqrt = Function("_sqrt_fn", "_sq_fn", convert=_internal_conversion, name="sqrt") diff --git a/src/register_transforms.cpp b/src/register_transforms.cpp index 7abe758b1..a7a37d841 100644 --- a/src/register_transforms.cpp +++ b/src/register_transforms.cpp @@ -32,10 +32,10 @@ py::class_ register_transform(py::module& mod, Args&&... args) { } extern "C" { -double _log_fn(double v) { return std::log(v); } -double _exp_fn(double v) { return std::exp(v); } -double _sqrt_fn(double v) { return std::sqrt(v); } -double _sq_fn(double v) { return v * v; } +PYBIND11_EXPORT double _log_fn(double v) { return std::log(v); } +PYBIND11_EXPORT double _exp_fn(double v) { return std::exp(v); } +PYBIND11_EXPORT double _sqrt_fn(double v) { return std::sqrt(v); } +PYBIND11_EXPORT double _sq_fn(double v) { return v * v; } } void register_transforms(py::module& mod) {