Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 2 additions & 31 deletions include/bh_python/transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,10 @@ struct func_transform {
return std::make_tuple(ptr, src);
}

// If we made it to this point, we probably have a C++ pybind object or an
// invalid object. The following is based on the std::function conversion in
// pybind11/functional.hpp
if(!py::isinstance<py::function>(src))
throw py::type_error("Only ctypes double(double) and C++ functions allowed "
"(must be function)");

auto func = py::reinterpret_borrow<py::function>(src);

if(auto cfunc = func.cpp_function()) {
auto c = py::reinterpret_borrow<py::capsule>(
PyCFunction_GET_SELF(cfunc.ptr()));

auto rec = c.get_pointer<py::detail::function_record>();

if(rec && rec->is_stateless
&& py::detail::same_type(
typeid(raw_t*),
*reinterpret_cast<const std::type_info*>(rec->data[1]))) {
struct capture {
raw_t* f;
};
return std::make_tuple((reinterpret_cast<capture*>(&rec->data))->f,
src);
}

// Note that each error is slightly different just to help with debugging
throw py::type_error("Only ctypes double(double) and C++ functions allowed "
"(must be stateless)");
}
// If we made it to this point, we probably have an invalid object.

throw py::type_error("Only ctypes double(double) and C++ functions allowed "
"(must be cpp function)");
"(must be a stateless cpp function)");
}

func_transform(py::object f, py::object i, py::object c, py::str n)
Expand Down
7 changes: 6 additions & 1 deletion src/boost_histogram/axis/transform.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

import copy
import ctypes
from typing import Any, ClassVar, TypeVar

import boost_histogram

from .. import _core
from .._core import axis as ca
from .._utils import register

T = TypeVar("T", bound="AxisTransform")

__all__ = ["AxisTransform", "Function", "Pow", "log", "sqrt"]

LIB = ctypes.CDLL(_core.__file__)


def __dir__() -> list[str]:
return __all__
Expand Down Expand Up @@ -150,7 +154,8 @@ def _produce(self, bins: int, start: float, stop: float) -> Any:


def _internal_conversion(name: str) -> Any:
return getattr(ca.transform, name)
ftype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
return ctypes.cast(getattr(LIB, name), ftype)


sqrt = Function("_sqrt_fn", "_sq_fn", convert=_internal_conversion, name="sqrt")
Expand Down
8 changes: 4 additions & 4 deletions src/register_transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ py::class_<T> register_transform(py::module& mod, Args&&... args) {
}

extern "C" {
double _log_fn(double v) { return std::log(v); }
double _exp_fn(double v) { return std::exp(v); }
double _sqrt_fn(double v) { return std::sqrt(v); }
double _sq_fn(double v) { return v * v; }
PYBIND11_EXPORT double _log_fn(double v) { return std::log(v); }
PYBIND11_EXPORT double _exp_fn(double v) { return std::exp(v); }
PYBIND11_EXPORT double _sqrt_fn(double v) { return std::sqrt(v); }
PYBIND11_EXPORT double _sq_fn(double v) { return v * v; }
}

void register_transforms(py::module& mod) {
Expand Down
Loading