Skip to content

Commit 652a73a

Browse files
Pim SchellartPim Schellart
authored andcommitted
Add check for matching holder_type when inheriting
1 parent 5f07fac commit 652a73a

File tree

6 files changed

+85
-4
lines changed

6 files changed

+85
-4
lines changed

include/pybind11/attr.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ struct type_record {
185185
/// Does the class require its own metaclass?
186186
bool metaclass : 1;
187187

188-
PYBIND11_NOINLINE void add_base(const std::type_info *base, void *(*caster)(void *)) {
188+
PYBIND11_NOINLINE void add_base(const std::type_info *base, void *(*caster)(void *), const std::type_info* check_holder_type = nullptr) {
189189
auto base_info = detail::get_type_info(*base, false);
190190
if (!base_info) {
191191
std::string tname(base->name());
@@ -194,6 +194,13 @@ struct type_record {
194194
"\" referenced unknown base type \"" + tname + "\"");
195195
}
196196

197+
if (check_holder_type && *check_holder_type != *(base_info->void_holder_type)) {
198+
std::string tname(base->name());
199+
detail::clean_type_id(tname);
200+
pybind11_fail("generic_type: type \"" + std::string(name) +
201+
"\" and base base \"" + tname + "\" have different holder type");
202+
}
203+
197204
bases.append((PyObject *) base_info->type);
198205

199206
if (base_info->type->tp_dictoffset != 0)

include/pybind11/cast.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,31 @@
1919
NAMESPACE_BEGIN(pybind11)
2020
NAMESPACE_BEGIN(detail)
2121

22+
/** Extract type_info of holder type (generalized to hold void)
23+
* Allows for comparisons of holder types between classes
24+
* Users may want to add specializations for custom smart pointers
25+
* in order to enable checks. */
26+
template <typename... Ts>
27+
struct void_holder {
28+
static const std::type_info* type() {
29+
return nullptr;
30+
}
31+
};
32+
33+
template <typename... Ts>
34+
struct void_holder<std::unique_ptr<Ts...>> {
35+
static const std::type_info* type() {
36+
return &typeid(std::unique_ptr<void>);
37+
}
38+
};
39+
40+
template <typename... Ts>
41+
struct void_holder<std::shared_ptr<Ts...>> {
42+
static const std::type_info* type() {
43+
return &typeid(std::shared_ptr<void>);
44+
}
45+
};
46+
2247
/// Additional type information which does not fit into the PyTypeObject
2348
struct type_info {
2449
PyTypeObject *type;
@@ -32,6 +57,9 @@ struct type_info {
3257
/** A simple type never occurs as a (direct or indirect) parent
3358
* of a class that makes use of multiple inheritance */
3459
bool simple_type = true;
60+
/* holder_type generalized to hold void for base vs derived
61+
* holder_type checks */
62+
const std::type_info* void_holder_type = nullptr;
3563
};
3664

3765
PYBIND11_NOINLINE inline internals &get_internals() {

include/pybind11/pybind11.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,10 @@ class class_ : public detail::generic_type {
10161016

10171017
detail::generic_type::initialize(&record);
10181018

1019+
/* Register holder type (e.g. std::shared_ptr) */
1020+
auto tinfo = detail::get_type_info(typeid(type));
1021+
tinfo->void_holder_type = detail::void_holder<holder_type>::type();
1022+
10191023
if (has_alias) {
10201024
auto &instances = pybind11::detail::get_internals().registered_types_cpp;
10211025
instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))];
@@ -1026,7 +1030,7 @@ class class_ : public detail::generic_type {
10261030
static void add_base(detail::type_record &rec) {
10271031
rec.add_base(&typeid(Base), [](void *src) -> void * {
10281032
return static_cast<Base *>(reinterpret_cast<type *>(src));
1029-
});
1033+
}, detail::void_holder<holder_type>::type());
10301034
}
10311035

10321036
template <typename Base, detail::enable_if_t<!is_base<Base>::value, int> = 0>

tests/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ endif()
6868
pybind11_add_module(pybind11_tests pybind11_tests.cpp
6969
${PYBIND11_TEST_FILES} ${PYBIND11_HEADERS})
7070

71+
pybind11_add_module(pybind11_test_broken_import pybind11_test_broken_import.cpp
72+
${PYBIND11_HEADERS})
73+
7174
pybind11_enable_warnings(pybind11_tests)
7275

7376
if(EIGEN3_FOUND)
@@ -80,9 +83,11 @@ set(testdir ${PROJECT_SOURCE_DIR}/tests)
8083
# Always write the output file directly into the 'tests' directory (even on MSVC)
8184
if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY)
8285
set_target_properties(pybind11_tests PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${testdir})
86+
set_target_properties(pybind11_test_broken_import PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${testdir})
8387
foreach(config ${CMAKE_CONFIGURATION_TYPES})
8488
string(TOUPPER ${config} config)
8589
set_target_properties(pybind11_tests PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} ${testdir})
90+
set_target_properties(pybind11_test_broken_import PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} ${testdir})
8691
endforeach()
8792
endif()
8893

@@ -97,8 +102,8 @@ if(NOT PYBIND11_PYTEST_FOUND)
97102
endif()
98103

99104
# A single command to compile and run the tests
100-
add_custom_target(pytest COMMAND ${PYTHON_EXECUTABLE} -m pytest -rws ${PYBIND11_PYTEST_FILES}
101-
DEPENDS pybind11_tests WORKING_DIRECTORY ${testdir})
105+
add_custom_target(pytest COMMAND ${PYTHON_EXECUTABLE} -m pytest -rws ${PYBIND11_PYTEST_FILES} test_incompatible_holder.py
106+
DEPENDS pybind11_tests pybind11_test_broken_import WORKING_DIRECTORY ${testdir})
102107

103108
if(PYBIND11_TEST_OVERRIDE)
104109
add_custom_command(TARGET pytest POST_BUILD
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
tests/test_inheritance.cpp -- inheritance, automatic upcasting for polymorphic types
3+
4+
Copyright (c) 2016 Wenzel Jakob <[email protected]>
5+
6+
All rights reserved. Use of this source code is governed by a
7+
BSD-style license that can be found in the LICENSE file.
8+
*/
9+
10+
#include <pybind11/pybind11.h>
11+
12+
class Base {
13+
};
14+
15+
class Derived : public Base {
16+
};
17+
18+
namespace py = pybind11;
19+
20+
/** Note that this test cannot be part of the pybind11_tests module
21+
* because it is designed to fail on import */
22+
PYBIND11_PLUGIN(pybind11_test_broken_import) {
23+
py::module m("pybind11_test_broken_import", "test for incompatible holder type");
24+
25+
py::class_<Base>(m, "Base");
26+
py::class_<Derived, std::shared_ptr<Derived>, Base>(m, "Derived")
27+
.def(py::init<>());
28+
29+
return m.ptr();
30+
}

tests/test_incompatible_holder.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
3+
4+
def test_inheritance(msg):
5+
with pytest.raises(ImportError):
6+
from pybind11_test_broken_import import Base, Derived
7+

0 commit comments

Comments
 (0)