diff --git a/include/pybind11/detail/type_caster_base.h b/include/pybind11/detail/type_caster_base.h index 476646ee8f..aaf191a60e 100644 --- a/include/pybind11/detail/type_caster_base.h +++ b/include/pybind11/detail/type_caster_base.h @@ -115,6 +115,40 @@ inline void all_type_info_add_base_most_derived_first(std::vector & bases.push_back(addl_base); } +inline void all_type_info_check_for_divergence(const std::vector &bases) { + using sz_t = std::size_t; + sz_t n = bases.size(); + if (n < 3) { + return; + } + std::vector cluster_ids; + cluster_ids.reserve(n); + for (sz_t ci = 0; ci < n; ci++) { + cluster_ids.push_back(ci); + } + for (sz_t i = 0; i < n - 1; i++) { + if (cluster_ids[i] != i) { + continue; + } + for (sz_t j = i + 1; j < n; j++) { + if (PyType_IsSubtype(bases[i]->type, bases[j]->type) != 0) { + sz_t k = cluster_ids[j]; + if (k == j) { + cluster_ids[j] = i; + } else { + PyErr_Format( + PyExc_TypeError, + "bases include diverging derived types: base=%s, derived1=%s, derived2=%s", + bases[j]->type->tp_name, + bases[k]->type->tp_name, + bases[i]->type->tp_name); + throw error_already_set(); + } + } + } + } +} + // Populates a just-created cache entry. PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector &bases) { assert(bases.empty()); @@ -168,6 +202,7 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector 1) { + std::vector matching_bases; for (auto *base : bases) { if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { - this_.load_value( - reinterpret_cast(src.ptr())->get_value_and_holder(base)); - return true; + matching_bases.push_back(base); } } + if (!matching_bases.empty()) { + if (matching_bases.size() > 1) { + matching_bases.push_back(const_cast(typeinfo)); + all_type_info_check_for_divergence(matching_bases); + } + this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder( + matching_bases[0])); + return true; + } } // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type diff --git a/tests/test_python_multiple_inheritance.cpp b/tests/test_python_multiple_inheritance.cpp index 6899171585..15b1f3360e 100644 --- a/tests/test_python_multiple_inheritance.cpp +++ b/tests/test_python_multiple_inheritance.cpp @@ -26,6 +26,18 @@ struct CppDrvd : CppBase { int drvd_value; }; +struct CppDrvd2 : CppBase { + explicit CppDrvd2(int value) : CppBase(value), drvd2_value(value * 5) {} + int get_drvd2_value() const { return drvd2_value; } + void reset_drvd2_value(int new_value) { drvd2_value = new_value; } + + int get_base_value_from_drvd2() const { return get_base_value(); } + void reset_base_value_from_drvd2(int new_value) { reset_base_value(new_value); } + +private: + int drvd2_value; +}; + } // namespace test_python_multiple_inheritance TEST_SUBMODULE(python_multiple_inheritance, m) { @@ -42,4 +54,13 @@ TEST_SUBMODULE(python_multiple_inheritance, m) { .def("reset_drvd_value", &CppDrvd::reset_drvd_value) .def("get_base_value_from_drvd", &CppDrvd::get_base_value_from_drvd) .def("reset_base_value_from_drvd", &CppDrvd::reset_base_value_from_drvd); + + py::class_(m, "CppDrvd2") + .def(py::init()) + .def("get_drvd2_value", &CppDrvd2::get_drvd2_value) + .def("reset_drvd2_value", &CppDrvd2::reset_drvd2_value) + .def("get_base_value_from_drvd2", &CppDrvd2::get_base_value_from_drvd2) + .def("reset_base_value_from_drvd2", &CppDrvd2::reset_base_value_from_drvd2); + + m.def("pass_CppBase", [](const CppBase *) {}); } diff --git a/tests/test_python_multiple_inheritance.py b/tests/test_python_multiple_inheritance.py index 3bddd67dfb..6b941cb8d5 100644 --- a/tests/test_python_multiple_inheritance.py +++ b/tests/test_python_multiple_inheritance.py @@ -1,6 +1,8 @@ # Adapted from: # https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python/python_multiple_inheritance_test.py +import pytest + from pybind11_tests import python_multiple_inheritance as m @@ -12,6 +14,28 @@ class PPCC(PC, m.CppDrvd): pass +class PPPCCC(PPCC, m.CppDrvd2): + pass + + +class PC1(m.CppDrvd): + pass + + +class PC2(m.CppDrvd2): + pass + + +class PCD(PC1, PC2): + pass + + +class PCDI(PC1, PC2): + def __init__(self): + PC1.__init__(self, 11) + PC2.__init__(self, 12) + + def test_PC(): d = PC(11) assert d.get_base_value() == 11 @@ -33,3 +57,27 @@ def test_PPCC(): d.reset_base_value_from_drvd(30) assert d.get_base_value() == 30 assert d.get_base_value_from_drvd() == 30 + + +def NOtest_PPPCCC(): + # terminate called after throwing an instance of 'pybind11::error_already_set' + # what(): TypeError: bases include diverging derived types: + # base=pybind11_tests.python_multiple_inheritance.CppBase, + # derived1=pybind11_tests.python_multiple_inheritance.CppDrvd, + # derived2=pybind11_tests.python_multiple_inheritance.CppDrvd2 + PPPCCC(11) + + +def test_PCD(): + # This escapes all_type_info_check_for_divergence() because CppBase does not appear in bases. + with pytest.raises( + TypeError, + match=r"CppDrvd2\.__init__\(\) must be called when overriding __init__$", + ): + PCD(11) + + +def test_PCDI(): + obj = PCDI() + with pytest.raises(TypeError, match="^bases include diverging derived types: "): + m.pass_CppBase(obj)