Skip to content

Commit e07f758

Browse files
aldanorjagerman
authored andcommitted
Implicit conversions to bool + np.bool_ conversion (#925)
This adds support for implicit conversions to bool from Python types with `__bool__` (Python 3) or `__nonzero__` (Python 2) attributes, and adds direct (i.e. non-converting) support for numpy bools.
1 parent a03408c commit e07f758

File tree

4 files changed

+92
-2
lines changed

4 files changed

+92
-2
lines changed

include/pybind11/cast.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,11 +1049,37 @@ template <> class type_caster<std::nullptr_t> : public void_caster<std::nullptr_
10491049

10501050
template <> class type_caster<bool> {
10511051
public:
1052-
bool load(handle src, bool) {
1052+
bool load(handle src, bool convert) {
10531053
if (!src) return false;
10541054
else if (src.ptr() == Py_True) { value = true; return true; }
10551055
else if (src.ptr() == Py_False) { value = false; return true; }
1056-
else return false;
1056+
else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) {
1057+
// (allow non-implicit conversion for numpy booleans)
1058+
1059+
Py_ssize_t res = -1;
1060+
if (src.is_none()) {
1061+
res = 0; // None is implicitly converted to False
1062+
}
1063+
#if defined(PYPY_VERSION)
1064+
// On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists
1065+
else if (hasattr(src, PYBIND11_BOOL_ATTR)) {
1066+
res = PyObject_IsTrue(src.ptr());
1067+
}
1068+
#else
1069+
// Alternate approach for CPython: this does the same as the above, but optimized
1070+
// using the CPython API so as to avoid an unneeded attribute lookup.
1071+
else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) {
1072+
if (PYBIND11_NB_BOOL(tp_as_number)) {
1073+
res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr());
1074+
}
1075+
}
1076+
#endif
1077+
if (res == 0 || res == 1) {
1078+
value = (bool) res;
1079+
return true;
1080+
}
1081+
}
1082+
return false;
10571083
}
10581084
static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) {
10591085
return handle(src ? Py_True : Py_False).inc_ref();

include/pybind11/common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,11 @@
152152
#define PYBIND11_SLICE_OBJECT PyObject
153153
#define PYBIND11_FROM_STRING PyUnicode_FromString
154154
#define PYBIND11_STR_TYPE ::pybind11::str
155+
#define PYBIND11_BOOL_ATTR "__bool__"
156+
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool)
155157
#define PYBIND11_PLUGIN_IMPL(name) \
156158
extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
159+
157160
#else
158161
#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_)
159162
#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check
@@ -171,6 +174,8 @@
171174
#define PYBIND11_SLICE_OBJECT PySliceObject
172175
#define PYBIND11_FROM_STRING PyString_FromString
173176
#define PYBIND11_STR_TYPE ::pybind11::bytes
177+
#define PYBIND11_BOOL_ATTR "__nonzero__"
178+
#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero)
174179
#define PYBIND11_PLUGIN_IMPL(name) \
175180
static PyObject *pybind11_init_wrapper(); \
176181
extern "C" PYBIND11_EXPORT void init##name() { \

tests/test_builtin_casters.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ TEST_SUBMODULE(builtin_casters, m) {
116116
m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile
117117
m.def("cast_nullptr_t", []() { return std::nullptr_t{}; });
118118

119+
// test_bool_caster
120+
m.def("bool_passthrough", [](bool arg) { return arg; });
121+
m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert());
122+
119123
// test_reference_wrapper
120124
m.def("refwrap_builtin", [](std::reference_wrapper<int> p) { return 10 * p.get(); });
121125
m.def("refwrap_usertype", [](std::reference_wrapper<UserType> p) { return p.get().value(); });

tests/test_builtin_casters.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,58 @@ def test_complex_cast():
265265
"""std::complex casts"""
266266
assert m.complex_cast(1) == "1.0"
267267
assert m.complex_cast(2j) == "(0.0, 2.0)"
268+
269+
270+
def test_bool_caster():
271+
"""Test bool caster implicit conversions."""
272+
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
273+
274+
def require_implicit(v):
275+
pytest.raises(TypeError, noconvert, v)
276+
277+
def cant_convert(v):
278+
pytest.raises(TypeError, convert, v)
279+
280+
# straight up bool
281+
assert convert(True) is True
282+
assert convert(False) is False
283+
assert noconvert(True) is True
284+
assert noconvert(False) is False
285+
286+
# None requires implicit conversion
287+
require_implicit(None)
288+
assert convert(None) is False
289+
290+
class A(object):
291+
def __init__(self, x):
292+
self.x = x
293+
294+
def __nonzero__(self):
295+
return self.x
296+
297+
def __bool__(self):
298+
return self.x
299+
300+
class B(object):
301+
pass
302+
303+
# Arbitrary objects are not accepted
304+
cant_convert(object())
305+
cant_convert(B())
306+
307+
# Objects with __nonzero__ / __bool__ defined can be converted
308+
require_implicit(A(True))
309+
assert convert(A(True)) is True
310+
assert convert(A(False)) is False
311+
312+
313+
@pytest.requires_numpy
314+
def test_numpy_bool():
315+
import numpy as np
316+
convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert
317+
318+
# np.bool_ is not considered implicit
319+
assert convert(np.bool_(True)) is True
320+
assert convert(np.bool_(False)) is False
321+
assert noconvert(np.bool_(True)) is True
322+
assert noconvert(np.bool_(False)) is False

0 commit comments

Comments
 (0)