From 2ce3b6cd77b28fab82effa5a104a0c609ddfd7db Mon Sep 17 00:00:00 2001 From: Dean Moldovan Date: Tue, 12 Sep 2017 11:04:08 +0200 Subject: [PATCH] Fix casting of classes with mixed polymorphic inheritance A polymorphic class can inherit from a non-polymorphic base. --- docs/changelog.rst | 3 +++ include/pybind11/attr.h | 27 ++++++++++++++++++++++-- include/pybind11/detail/internals.h | 4 +++- include/pybind11/pybind11.h | 10 ++++++++- tests/test_class.cpp | 32 +++++++++++++++++++++++++++++ tests/test_class.py | 24 ++++++++++++++++++++++ 6 files changed, 96 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 8b7047df4c..cd15645af3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -15,6 +15,9 @@ v2.3.0 (Not yet released) for non-MSVC compilers). `#934 `_. +* Fixed casting of polymorphic classes which inherit from non-polymorphic bases. + `#1084 `_. + v2.2.1 (September 14, 2017) ----------------------------------------------------- diff --git a/include/pybind11/attr.h b/include/pybind11/attr.h index dce875a6b9..c9d63e6b37 100644 --- a/include/pybind11/attr.h +++ b/include/pybind11/attr.h @@ -200,7 +200,8 @@ struct function_record { /// Special data structure which (temporarily) holds metadata about a bound class struct type_record { PYBIND11_NOINLINE type_record() - : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { } + : multiple_inheritance(false), polymorphic(false), dynamic_attr(false), + buffer_protocol(false), module_local(false) { } /// Handle to the parent scope handle scope; @@ -238,6 +239,9 @@ struct type_record { /// Multiple inheritance marker bool multiple_inheritance : 1; + /// Type is polymorphic in C++ + bool polymorphic : 1; + /// Does the class manage a __dict__? bool dynamic_attr : 1; @@ -250,6 +254,7 @@ struct type_record { /// Is the class definition local to the module shared object? bool module_local : 1; + /// Add a base as a template argument -- allows casting to base for non-simple types PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { auto base_info = detail::get_type_info(base, false); if (!base_info) { @@ -276,6 +281,24 @@ struct type_record { if (caster) base_info->implicit_casts.emplace_back(type, caster); } + + /// Add a base as a runtime argument -- only for simple types + PYBIND11_NOINLINE void add_base(handle base) { + if (!base || !PyType_Check(base.ptr())) + pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + "is trying to register a non-type object as a base"); + + auto base_ptr = (PyTypeObject *) base.ptr(); + auto base_info = detail::get_type_info(base_ptr); + if (polymorphic != base_info->polymorphic) { + pybind11_fail("generic_type: type \"" + std::string(name) + "\" is polymorphic, " + "but its base \"" + std::string(base_ptr->tp_name) + "\" is not. " + "In this case, the base must be specified as a template argument: " + "py::class_(...) instead of py::class_(..., base)."); + } + + bases.append(base); + } }; inline function_call::function_call(function_record &f, handle p) : @@ -392,7 +415,7 @@ template <> struct process_attribute : process_attribute_default { /// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that) template struct process_attribute::value>> : process_attribute_default { - static void init(const handle &h, type_record *r) { r->bases.append(h); } + static void init(const handle &h, type_record *r) { r->add_base(h); } }; /// Process a parent class attribute (deprecated, does not support multiple inheritance) diff --git a/include/pybind11/detail/internals.h b/include/pybind11/detail/internals.h index 213cbaeb21..8f778c9917 100644 --- a/include/pybind11/detail/internals.h +++ b/include/pybind11/detail/internals.h @@ -104,6 +104,8 @@ struct type_info { bool simple_type : 1; /* True if there is no multiple inheritance in this type's inheritance tree */ bool simple_ancestors : 1; + /* Type is polymorphic in C++ */ + bool polymorphic : 1; /* for base vs derived holder_type checks */ bool default_holder : 1; /* true if this is a type registered with py::module_local */ @@ -111,7 +113,7 @@ struct type_info { }; /// Tracks the `internals` and `type_info` ABI version independent of the main library version -#define PYBIND11_INTERNALS_VERSION 1 +#define PYBIND11_INTERNALS_VERSION 2 #if defined(WITH_THREAD) # define PYBIND11_INTERNALS_KIND "" diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 59af022927..e6450d274e 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -898,6 +898,7 @@ class generic_type : public object { tinfo->dealloc = rec.dealloc; tinfo->simple_type = true; tinfo->simple_ancestors = true; + tinfo->polymorphic = rec.polymorphic; tinfo->default_holder = rec.default_holder; tinfo->module_local = rec.module_local; @@ -916,7 +917,12 @@ class generic_type : public object { } else if (rec.bases.size() == 1) { auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr()); - tinfo->simple_ancestors = parent_tinfo->simple_ancestors; + if (tinfo->polymorphic == parent_tinfo->polymorphic) { + tinfo->simple_ancestors = parent_tinfo->simple_ancestors; + } else { + mark_parents_nonsimple(tinfo->type); + tinfo->simple_ancestors = false; + } } if (rec.module_local) { @@ -927,6 +933,7 @@ class generic_type : public object { } /// Helper function which tags all parents of a type using mult. inheritance + /// or a polymorphic type which inherits from a non-polymorphic base void mark_parents_nonsimple(PyTypeObject *value) { auto t = reinterpret_borrow(value->tp_bases); for (handle h : t) { @@ -1045,6 +1052,7 @@ class class_ : public detail::generic_type { record.holder_size = sizeof(holder_type); record.init_instance = init_instance; record.dealloc = dealloc; + record.polymorphic = std::is_polymorphic::value; record.default_holder = std::is_same>::value; set_operator_new(&record); diff --git a/tests/test_class.cpp b/tests/test_class.cpp index 2221906170..23037c3a20 100644 --- a/tests/test_class.cpp +++ b/tests/test_class.cpp @@ -11,6 +11,20 @@ #include "constructor_stats.h" #include "local_bindings.h" +// test_mixed_polymorphic_inheritance (MSVC can't link this if defined at function scope) +struct NonPolymorphicBase { + std::int64_t a, b; +}; + +struct PolymorphicDerived : NonPolymorphicBase { + PolymorphicDerived() : NonPolymorphicBase{1, 2} { } + virtual ~PolymorphicDerived() { } +}; + +struct LocalPolymorphicDerived : NonPolymorphicBase { + virtual ~LocalPolymorphicDerived() = default; +}; + TEST_SUBMODULE(class_, m) { // test_instance struct NoConstructor { @@ -81,6 +95,24 @@ TEST_SUBMODULE(class_, m) { m.def("pet_name_species", [](const Pet &pet) { return pet.name() + " is a " + pet.species(); }); m.def("dog_bark", [](const Dog &dog) { return dog.bark(); }); + // test_mixed_polymorphic_inheritance + py::class_(m, "NonPolymorphicBase") + .def_readwrite("a", &NonPolymorphicBase::a) + .def_readwrite("b", &NonPolymorphicBase::b); + + py::class_(m, "PolymorphicDerived") + .def(py::init<>()); + + m.def("call_with_nonpolymorphic_base", [](const NonPolymorphicBase &x) { return x.b; }); + m.def("call_with_polymorphic_derived", [](const PolymorphicDerived &x) { return x.b; }); + + m.def("register_mixed_polymorphic_base_at_runtime", []() { + auto module = py::module::import("pybind11_tests").attr("class_"); + auto base = module.attr("NonPolymorphicBase"); + // Expected to throw + py::class_(module, "LocalPolymorphicDerived", base); + }); + // test_automatic_upcasting struct BaseClass { virtual ~BaseClass() {} }; struct DerivedClass1 : BaseClass { }; diff --git a/tests/test_class.py b/tests/test_class.py index 412d6798e9..dfbcedc47c 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -76,6 +76,30 @@ def test_inheritance(msg): assert "No constructor defined!" in str(excinfo.value) +def test_mixed_polymorphic_inheritance(): + """A polymorphic class can inherit members from a non-polymorphic base""" + import re + + class PolymorphicDerived(m.PolymorphicDerived): + def __init__(self): + m.PolymorphicDerived.__init__(self) + + for x in m.PolymorphicDerived(), PolymorphicDerived(): + assert (x.a, x.b) == (1, 2) + x.a = 11 + x.b = 22 + assert (x.a, x.b) == (11, 22) + assert m.call_with_nonpolymorphic_base(x) == 22 + assert m.call_with_polymorphic_derived(x) == 22 + + with pytest.raises(RuntimeError) as excinfo: + m.register_mixed_polymorphic_base_at_runtime() + assert re.match('generic_type: type ".*LocalPolymorphicDerived" is polymorphic, ' + 'but its base ".*NonPolymorphicBase" is not', str(excinfo.value)) + assert ('In this case, the base must be specified as a template argument: ' + 'py::class_(...) instead of py::class_(..., base).') in str(excinfo.value) + + def test_automatic_upcasting(): assert type(m.return_class_1()).__name__ == "DerivedClass1" assert type(m.return_class_2()).__name__ == "DerivedClass2"