Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ v2.3.0 (Not yet released)
for non-MSVC compilers).
`#934 <https://github.com/pybind/pybind11/pull/934>`_.

* Fixed casting of polymorphic classes which inherit from non-polymorphic bases.
`#1084 <https://github.com/pybind/pybind11/pull/1084>`_.

v2.2.1 (September 14, 2017)
-----------------------------------------------------

Expand Down
27 changes: 25 additions & 2 deletions include/pybind11/attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ struct function_record {
/// Special data structure which (temporarily) holds metadata about a bound class
struct type_record {
PYBIND11_NOINLINE type_record()
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { }
: multiple_inheritance(false), polymorphic(false), dynamic_attr(false),
buffer_protocol(false), module_local(false) { }

/// Handle to the parent scope
handle scope;
Expand Down Expand Up @@ -238,6 +239,9 @@ struct type_record {
/// Multiple inheritance marker
bool multiple_inheritance : 1;

/// Type is polymorphic in C++
bool polymorphic : 1;

/// Does the class manage a __dict__?
bool dynamic_attr : 1;

Expand All @@ -250,6 +254,7 @@ struct type_record {
/// Is the class definition local to the module shared object?
bool module_local : 1;

/// Add a base as a template argument -- allows casting to base for non-simple types
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
auto base_info = detail::get_type_info(base, false);
if (!base_info) {
Expand All @@ -276,6 +281,24 @@ struct type_record {
if (caster)
base_info->implicit_casts.emplace_back(type, caster);
}

/// Add a base as a runtime argument -- only for simple types
PYBIND11_NOINLINE void add_base(handle base) {
if (!base || !PyType_Check(base.ptr()))
pybind11_fail("generic_type: type \"" + std::string(name) + "\" "
"is trying to register a non-type object as a base");

auto base_ptr = (PyTypeObject *) base.ptr();
auto base_info = detail::get_type_info(base_ptr);
if (polymorphic != base_info->polymorphic) {
pybind11_fail("generic_type: type \"" + std::string(name) + "\" is polymorphic, "
"but its base \"" + std::string(base_ptr->tp_name) + "\" is not. "
"In this case, the base must be specified as a template argument: "
"py::class_<T, Base>(...) instead of py::class_<T>(..., base).");
}

bases.append(base);
}
};

inline function_call::function_call(function_record &f, handle p) :
Expand Down Expand Up @@ -392,7 +415,7 @@ template <> struct process_attribute<arg_v> : process_attribute_default<arg_v> {
/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that)
template <typename T>
struct process_attribute<T, enable_if_t<is_pyobject<T>::value>> : process_attribute_default<handle> {
static void init(const handle &h, type_record *r) { r->bases.append(h); }
static void init(const handle &h, type_record *r) { r->add_base(h); }
};

/// Process a parent class attribute (deprecated, does not support multiple inheritance)
Expand Down
4 changes: 3 additions & 1 deletion include/pybind11/detail/internals.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,16 @@ struct type_info {
bool simple_type : 1;
/* True if there is no multiple inheritance in this type's inheritance tree */
bool simple_ancestors : 1;
/* Type is polymorphic in C++ */
bool polymorphic : 1;
/* for base vs derived holder_type checks */
bool default_holder : 1;
/* true if this is a type registered with py::module_local */
bool module_local : 1;
};

/// Tracks the `internals` and `type_info` ABI version independent of the main library version
#define PYBIND11_INTERNALS_VERSION 1
#define PYBIND11_INTERNALS_VERSION 2

#if defined(WITH_THREAD)
# define PYBIND11_INTERNALS_KIND ""
Expand Down
10 changes: 9 additions & 1 deletion include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ class generic_type : public object {
tinfo->dealloc = rec.dealloc;
tinfo->simple_type = true;
tinfo->simple_ancestors = true;
tinfo->polymorphic = rec.polymorphic;
tinfo->default_holder = rec.default_holder;
tinfo->module_local = rec.module_local;

Expand All @@ -916,7 +917,12 @@ class generic_type : public object {
}
else if (rec.bases.size() == 1) {
auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr());
tinfo->simple_ancestors = parent_tinfo->simple_ancestors;
if (tinfo->polymorphic == parent_tinfo->polymorphic) {
tinfo->simple_ancestors = parent_tinfo->simple_ancestors;
} else {
mark_parents_nonsimple(tinfo->type);
tinfo->simple_ancestors = false;
}
}

if (rec.module_local) {
Expand All @@ -927,6 +933,7 @@ class generic_type : public object {
}

/// Helper function which tags all parents of a type using mult. inheritance
/// or a polymorphic type which inherits from a non-polymorphic base
void mark_parents_nonsimple(PyTypeObject *value) {
auto t = reinterpret_borrow<tuple>(value->tp_bases);
for (handle h : t) {
Expand Down Expand Up @@ -1045,6 +1052,7 @@ class class_ : public detail::generic_type {
record.holder_size = sizeof(holder_type);
record.init_instance = init_instance;
record.dealloc = dealloc;
record.polymorphic = std::is_polymorphic<type>::value;
record.default_holder = std::is_same<holder_type, std::unique_ptr<type>>::value;

set_operator_new<type>(&record);
Expand Down
32 changes: 32 additions & 0 deletions tests/test_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@
#include "constructor_stats.h"
#include "local_bindings.h"

// test_mixed_polymorphic_inheritance (MSVC can't link this if defined at function scope)
struct NonPolymorphicBase {
std::int64_t a, b;
};

struct PolymorphicDerived : NonPolymorphicBase {
PolymorphicDerived() : NonPolymorphicBase{1, 2} { }
virtual ~PolymorphicDerived() { }
};

struct LocalPolymorphicDerived : NonPolymorphicBase {
virtual ~LocalPolymorphicDerived() = default;
};

TEST_SUBMODULE(class_, m) {
// test_instance
struct NoConstructor {
Expand Down Expand Up @@ -81,6 +95,24 @@ TEST_SUBMODULE(class_, m) {
m.def("pet_name_species", [](const Pet &pet) { return pet.name() + " is a " + pet.species(); });
m.def("dog_bark", [](const Dog &dog) { return dog.bark(); });

// test_mixed_polymorphic_inheritance
py::class_<NonPolymorphicBase>(m, "NonPolymorphicBase")
.def_readwrite("a", &NonPolymorphicBase::a)
.def_readwrite("b", &NonPolymorphicBase::b);

py::class_<PolymorphicDerived, NonPolymorphicBase>(m, "PolymorphicDerived")
.def(py::init<>());

m.def("call_with_nonpolymorphic_base", [](const NonPolymorphicBase &x) { return x.b; });
m.def("call_with_polymorphic_derived", [](const PolymorphicDerived &x) { return x.b; });

m.def("register_mixed_polymorphic_base_at_runtime", []() {
auto module = py::module::import("pybind11_tests").attr("class_");
auto base = module.attr("NonPolymorphicBase");
// Expected to throw
py::class_<LocalPolymorphicDerived>(module, "LocalPolymorphicDerived", base);
});

// test_automatic_upcasting
struct BaseClass { virtual ~BaseClass() {} };
struct DerivedClass1 : BaseClass { };
Expand Down
24 changes: 24 additions & 0 deletions tests/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,30 @@ def test_inheritance(msg):
assert "No constructor defined!" in str(excinfo.value)


def test_mixed_polymorphic_inheritance():
"""A polymorphic class can inherit members from a non-polymorphic base"""
import re

class PolymorphicDerived(m.PolymorphicDerived):
def __init__(self):
m.PolymorphicDerived.__init__(self)

for x in m.PolymorphicDerived(), PolymorphicDerived():
assert (x.a, x.b) == (1, 2)
x.a = 11
x.b = 22
assert (x.a, x.b) == (11, 22)
assert m.call_with_nonpolymorphic_base(x) == 22
assert m.call_with_polymorphic_derived(x) == 22

with pytest.raises(RuntimeError) as excinfo:
m.register_mixed_polymorphic_base_at_runtime()
assert re.match('generic_type: type ".*LocalPolymorphicDerived" is polymorphic, '
'but its base ".*NonPolymorphicBase" is not', str(excinfo.value))
assert ('In this case, the base must be specified as a template argument: '
'py::class_<T, Base>(...) instead of py::class_<T>(..., base).') in str(excinfo.value)


def test_automatic_upcasting():
assert type(m.return_class_1()).__name__ == "DerivedClass1"
assert type(m.return_class_2()).__name__ == "DerivedClass2"
Expand Down