Skip to content

Commit a3afc5d

Browse files
committed
Add py::local() for module-local type bindings
This commit adds a `py::local` attribute that lets you confine a registered type to the module (more technically, the shared object) in which it is defined, by registering it with: py::class_<C>(m, "C", py::local()) This will allow the C++ class `C` to be registered in different modules with independent sets of class definitions. On the Python side, two such types will be completely distinct; on the C++ side, the C++ resolves to a different Python types in each module.
1 parent a5ccb5f commit a3afc5d

File tree

9 files changed

+241
-42
lines changed

9 files changed

+241
-42
lines changed

include/pybind11/attr.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ struct metaclass {
6464
explicit metaclass(handle value) : value(value) { }
6565
};
6666

67+
/// Annotation that marks a class as local to the module:
68+
struct local { bool value; local(bool v = true) : value(v) { } };
69+
6770
/// Annotation to mark enums as an arithmetic type
6871
struct arithmetic { };
6972

@@ -196,7 +199,7 @@ struct function_record {
196199
/// Special data structure which (temporarily) holds metadata about a bound class
197200
struct type_record {
198201
PYBIND11_NOINLINE type_record()
199-
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false) { }
202+
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), local(false) { }
200203

201204
/// Handle to the parent scope
202205
handle scope;
@@ -243,6 +246,9 @@ struct type_record {
243246
/// Is the default (unique_ptr) holder type used?
244247
bool default_holder : 1;
245248

249+
/// Is the class definition local to the module shared object?
250+
bool local : 1;
251+
246252
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
247253
auto base_info = detail::get_type_info(base, false);
248254
if (!base_info) {
@@ -408,6 +414,10 @@ struct process_attribute<metaclass> : process_attribute_default<metaclass> {
408414
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
409415
};
410416

417+
template <>
418+
struct process_attribute<local> : process_attribute_default<local> {
419+
static void init(const local &l, type_record *r) { r->local = l.value; }
420+
};
411421

412422
/// Process an 'arithmetic' attribute for enums (does nothing here)
413423
template <>

include/pybind11/cast.h

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ PYBIND11_NOINLINE inline internals &get_internals() {
128128
return *internals_ptr;
129129
}
130130

131+
namespace {
132+
// Works like internals.registered_types_cpp, but for module-local registered types:
133+
inline type_map<void *> &registered_local_types_cpp() {
134+
static type_map<void *> locals{};
135+
return locals;
136+
}
137+
}
138+
131139
/// A life support system for temporary objects created by `type_caster::load()`.
132140
/// Adding a patient will keep it alive up until the enclosing function returns.
133141
class loader_life_support {
@@ -215,7 +223,7 @@ PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vecto
215223
// registered types
216224
if (i + 1 == check.size()) {
217225
// When we're at the end, we can pop off the current element to avoid growing
218-
// `check` when adding just one base (which is typical--.e. when there is no
226+
// `check` when adding just one base (which is typical--i.e. when there is no
219227
// multiple inheritance)
220228
check.pop_back();
221229
i--;
@@ -259,13 +267,18 @@ PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) {
259267
return bases.front();
260268
}
261269

262-
PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_info &tp,
270+
/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr.
271+
PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp,
263272
bool throw_if_missing = false) {
273+
std::type_index type_idx(tp);
264274
auto &types = get_internals().registered_types_cpp;
265-
266-
auto it = types.find(std::type_index(tp));
275+
auto it = types.find(type_idx);
267276
if (it != types.end())
268277
return (detail::type_info *) it->second;
278+
auto &locals = registered_local_types_cpp();
279+
it = locals.find(type_idx);
280+
if (it != locals.end())
281+
return (detail::type_info *) it->second;
269282
if (throw_if_missing) {
270283
std::string tname = tp.name();
271284
detail::clean_type_id(tname);
@@ -723,10 +736,8 @@ class type_caster_generic {
723736
// with .second = nullptr. (p.first = nullptr is not an error: it becomes None).
724737
PYBIND11_NOINLINE static std::pair<const void *, const type_info *> src_and_type(
725738
const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) {
726-
auto &internals = get_internals();
727-
auto it = internals.registered_types_cpp.find(std::type_index(cast_type));
728-
if (it != internals.registered_types_cpp.end())
729-
return {src, (const type_info *) it->second};
739+
if (auto *tpi = get_type_info(cast_type))
740+
return {src, const_cast<const type_info *>(tpi)};
730741

731742
// Not found, set error:
732743
std::string tname = rtti_type ? rtti_type->name() : cast_type.name();
@@ -804,7 +815,6 @@ template <typename type> class type_caster_base : public type_caster_generic {
804815
template <typename T = itype, enable_if_t<std::is_polymorphic<T>::value, int> = 0>
805816
static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
806817
const void *vsrc = src;
807-
auto &internals = get_internals();
808818
auto &cast_type = typeid(itype);
809819
const std::type_info *instance_type = nullptr;
810820
if (vsrc) {
@@ -813,9 +823,8 @@ template <typename type> class type_caster_base : public type_caster_generic {
813823
// This is a base pointer to a derived type; if it is a pybind11-registered type, we
814824
// can get the correct derived pointer (which may be != base pointer) by a
815825
// dynamic_cast to most derived type:
816-
auto it = internals.registered_types_cpp.find(std::type_index(*instance_type));
817-
if (it != internals.registered_types_cpp.end())
818-
return {dynamic_cast<const void *>(src), (const type_info *) it->second};
826+
if (auto *tpi = get_type_info(*instance_type))
827+
return {dynamic_cast<const void *>(src), const_cast<const type_info *>(tpi)};
819828
}
820829
}
821830
// Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so

include/pybind11/pybind11.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,10 @@ class generic_type : public object {
844844
auto tindex = std::type_index(*rec.type);
845845
tinfo->direct_conversions = &internals.direct_conversions[tindex];
846846
tinfo->default_holder = rec.default_holder;
847-
internals.registered_types_cpp[tindex] = tinfo;
847+
if (rec.local)
848+
registered_local_types_cpp()[tindex] = tinfo;
849+
else
850+
internals.registered_types_cpp[tindex] = tinfo;
848851
internals.registered_types_py[(PyTypeObject *) m_ptr] = { tinfo };
849852

850853
if (rec.bases.size() > 1 || rec.multiple_inheritance) {
@@ -978,7 +981,7 @@ class class_ : public detail::generic_type {
978981
generic_type::initialize(record);
979982

980983
if (has_alias) {
981-
auto &instances = get_internals().registered_types_cpp;
984+
auto &instances = record.local ? registered_local_types_cpp() : get_internals().registered_types_cpp;
982985
instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))];
983986
}
984987
}
@@ -1427,7 +1430,7 @@ iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) {
14271430
typedef detail::iterator_state<Iterator, Sentinel, false, Policy> state;
14281431

14291432
if (!detail::get_type_info(typeid(state), false)) {
1430-
class_<state>(handle(), "iterator")
1433+
class_<state>(handle(), "iterator", pybind11::local())
14311434
.def("__iter__", [](state &s) -> state& { return s; })
14321435
.def("__next__", [](state &s) -> ValueType {
14331436
if (!s.first_or_done)
@@ -1456,7 +1459,7 @@ iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) {
14561459
typedef detail::iterator_state<Iterator, Sentinel, true, Policy> state;
14571460

14581461
if (!detail::get_type_info(typeid(state), false)) {
1459-
class_<state>(handle(), "iterator")
1462+
class_<state>(handle(), "iterator", pybind11::local())
14601463
.def("__iter__", [](state &s) -> state& { return s; })
14611464
.def("__next__", [](state &s) -> KeyType {
14621465
if (!s.first_or_done)

tests/CMakeLists.txt

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ set(PYBIND11_TEST_FILES
4343
test_methods_and_attributes.cpp
4444
test_modules.cpp
4545
test_multiple_inheritance.cpp
46+
test_local_bindings.cpp
4647
test_numpy_array.cpp
4748
test_numpy_dtypes.cpp
4849
test_numpy_vectorize.cpp
@@ -120,36 +121,49 @@ function(pybind11_enable_warnings target_name)
120121
endif()
121122
endfunction()
122123

124+
set(test_targets pybind11_tests)
125+
# If test_local_bindings is being built in pybind11_tests.so we'll also build
126+
# a separate module pybind11_local_bindings.so for the test:
127+
list(FIND PYBIND11_TEST_FILES test_local_bindings.cpp test_local_i)
128+
if (test_local_i GREATER -1)
129+
list(APPEND test_targets pybind11_local_bindings)
130+
endif()
123131

124-
# Create the binding library
125-
pybind11_add_module(pybind11_tests THIN_LTO pybind11_tests.cpp
126-
${PYBIND11_TEST_FILES} ${PYBIND11_HEADERS})
132+
set(testdir ${CMAKE_CURRENT_SOURCE_DIR})
133+
foreach(tgt ${test_targets})
134+
set(test_files ${PYBIND11_TEST_FILES})
135+
if(NOT tgt STREQUAL "pybind11_tests")
136+
set(test_files "")
137+
endif()
127138

128-
pybind11_enable_warnings(pybind11_tests)
139+
# Create the binding library
140+
pybind11_add_module(${tgt} THIN_LTO ${tgt}.cpp
141+
${test_files} ${PYBIND11_HEADERS})
129142

130-
if(MSVC)
131-
target_compile_options(pybind11_tests PRIVATE /utf-8)
132-
endif()
143+
pybind11_enable_warnings(${tgt})
133144

134-
if(EIGEN3_FOUND)
135-
if (PYBIND11_EIGEN_VIA_TARGET)
136-
target_link_libraries(pybind11_tests PRIVATE Eigen3::Eigen)
137-
else()
138-
target_include_directories(pybind11_tests PRIVATE ${EIGEN3_INCLUDE_DIR})
145+
if(MSVC)
146+
target_compile_options(${tgt} PRIVATE /utf-8)
139147
endif()
140-
target_compile_definitions(pybind11_tests PRIVATE -DPYBIND11_TEST_EIGEN)
141-
endif()
142148

143-
set(testdir ${CMAKE_CURRENT_SOURCE_DIR})
149+
if(EIGEN3_FOUND)
150+
if (PYBIND11_EIGEN_VIA_TARGET)
151+
target_link_libraries(${tgt} PRIVATE Eigen3::Eigen)
152+
else()
153+
target_include_directories(${tgt} PRIVATE ${EIGEN3_INCLUDE_DIR})
154+
endif()
155+
target_compile_definitions(${tgt} PRIVATE -DPYBIND11_TEST_EIGEN)
156+
endif()
144157

145-
# Always write the output file directly into the 'tests' directory (even on MSVC)
146-
if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY)
147-
set_target_properties(pybind11_tests PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${testdir})
148-
foreach(config ${CMAKE_CONFIGURATION_TYPES})
149-
string(TOUPPER ${config} config)
150-
set_target_properties(pybind11_tests PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} ${testdir})
151-
endforeach()
152-
endif()
158+
# Always write the output file directly into the 'tests' directory (even on MSVC)
159+
if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY)
160+
set_target_properties(${tgt} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${testdir})
161+
foreach(config ${CMAKE_CONFIGURATION_TYPES})
162+
string(TOUPPER ${config} config)
163+
set_target_properties(${tgt} PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} ${testdir})
164+
endforeach()
165+
endif()
166+
endforeach()
153167

154168
# Make sure pytest is found or produce a fatal error
155169
if(NOT PYBIND11_PYTEST_FOUND)
@@ -173,7 +187,7 @@ endif()
173187

174188
# A single command to compile and run the tests
175189
add_custom_target(pytest COMMAND ${PYTHON_EXECUTABLE} -m pytest ${PYBIND11_PYTEST_FILES}
176-
DEPENDS pybind11_tests WORKING_DIRECTORY ${testdir} ${PYBIND11_USES_TERMINAL})
190+
DEPENDS ${test_targets} WORKING_DIRECTORY ${testdir} ${PYBIND11_USES_TERMINAL})
177191

178192
if(PYBIND11_TEST_OVERRIDE)
179193
add_custom_command(TARGET pytest POST_BUILD
@@ -189,7 +203,7 @@ if (NOT PROJECT_NAME STREQUAL "pybind11")
189203
return()
190204
endif()
191205

192-
# Add a post-build comment to show the .so size and, if a previous size, compare it:
206+
# Add a post-build comment to show the primary test suite .so size and, if a previous size, compare it:
193207
add_custom_command(TARGET pybind11_tests POST_BUILD
194208
COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/libsize.py
195209
$<TARGET_FILE:pybind11_tests> ${CMAKE_CURRENT_BINARY_DIR}/sosize-$<TARGET_FILE_NAME:pybind11_tests>.txt)

tests/local_bindings.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
#include "pybind11_tests.h"
3+
4+
/// Simple class used to test py::local:
5+
template <int> class LocalBase {
6+
public:
7+
LocalBase(int i) : i(i) { }
8+
int i = -1;
9+
};
10+
11+
/// Registered with py::local in both main and secondary modules:
12+
using LocalType = LocalBase<0>;
13+
/// Registered without py::local in both modules:
14+
using NonLocalType = LocalBase<1>;
15+
/// A second non-local type (for stl_bind tests):
16+
using NonLocal2 = LocalBase<2>;
17+
/// Tests within-module, different-compilation-unit local definition conflict:
18+
using LocalExternal = LocalBase<3>;
19+
20+
// Simple bindings (used with the above):
21+
template <typename T, int Adjust, typename... Args>
22+
py::class_<T> bind_local(Args && ...args) {
23+
return py::class_<T>(std::forward<Args>(args)...)
24+
.def(py::init<int>())
25+
.def("get", [](T &i) { return i.i + Adjust; });
26+
};

tests/pybind11_local_bindings.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
tests/pybind11_local_bindings.cpp -- counterpart to test_local_bindings.cpp
3+
4+
Copyright (c) 2017 Jason Rhinelander <[email protected]>
5+
6+
All rights reserved. Use of this source code is governed by a
7+
BSD-style license that can be found in the LICENSE file.
8+
*/
9+
10+
#include "pybind11_tests.h"
11+
#include "local_bindings.h"
12+
#include <pybind11/stl_bind.h>
13+
14+
PYBIND11_MODULE(pybind11_local_bindings, m) {
15+
m.doc() = "pybind11 local bindings test module";
16+
17+
// Definitions here are tested by importing both this module and the
18+
// pybind11_tests.local_bindings submodule from test_local_bindings.py
19+
20+
// test_local_bindings
21+
// Local to both:
22+
bind_local<LocalType, 1>(m, "LocalType", py::local())
23+
.def("get2", [](LocalType &t) { return t.i + 2; })
24+
;
25+
26+
// Can only be called with our python type:
27+
m.def("local_value", [](LocalType &l) { return l.i; });
28+
29+
// test_nonlocal_failure
30+
// This registration will fail (global registration when LocalFail is already registered
31+
// globally in the main test module):
32+
m.def("register_nonlocal", [m]() {
33+
bind_local<NonLocalType, 0>(m, "NonLocalType");
34+
});
35+
}

tests/test_class.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "pybind11_tests.h"
1111
#include "constructor_stats.h"
12+
#include "local_bindings.h"
1213

1314
TEST_SUBMODULE(class_, m) {
1415
// test_instance
@@ -184,6 +185,10 @@ TEST_SUBMODULE(class_, m) {
184185
auto def = new PyMethodDef{"f", f, METH_VARARGS, nullptr};
185186
return py::reinterpret_steal<py::object>(PyCFunction_NewEx(def, nullptr, m.ptr()));
186187
}());
188+
189+
// This test is actually part of test_local_bindings (test_duplicate_local), but we need a
190+
// definition in a different compilation unit within the same module:
191+
bind_local<LocalExternal, 17>(m, "LocalExternal", py::local());
187192
}
188193

189194
template <int N> class BreaksBase {};

tests/test_local_bindings.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
tests/test_local_bindings.cpp -- tests the py::local class feature which makes a class binding
3+
local to the module in which it is defined.
4+
5+
Copyright (c) 2017 Jason Rhinelander <[email protected]>
6+
7+
All rights reserved. Use of this source code is governed by a
8+
BSD-style license that can be found in the LICENSE file.
9+
*/
10+
11+
#include "pybind11_tests.h"
12+
#include "local_bindings.h"
13+
#include <pybind11/stl_bind.h>
14+
15+
TEST_SUBMODULE(local_bindings, m) {
16+
17+
// test_local_bindings
18+
// Register a class with py::local:
19+
bind_local<LocalType, -1>(m, "LocalType", py::local())
20+
.def("get3", [](LocalType &t) { return t.i + 3; })
21+
;
22+
23+
m.def("local_value", [](LocalType &l) { return l.i; });
24+
25+
// test_nonlocal_failure
26+
// The main pybind11 test module is loaded first, so this registration will succeed (the second
27+
// one, in pybind11_local_bindings.cpp, is designed to fail):
28+
bind_local<NonLocalType, 0>(m, "NonLocalType")
29+
.def(py::init<int>())
30+
.def("get", [](LocalType &i) { return i.i; })
31+
;
32+
33+
// test_duplicate_local
34+
// py::local declarations should be visible across compilation units that get linked together;
35+
// this tries to register a duplicate local. It depends on a definition in test_class.cpp and
36+
// should raise a runtime error from the duplicate definition attempt. If test_class isn't
37+
// available it *also* throws a runtime error (with "test_class not enabled" as value).
38+
m.def("register_local_external", [m]() {
39+
auto main = py::module::import("pybind11_tests");
40+
if (py::hasattr(main, "class_")) {
41+
bind_local<LocalExternal, 7>(m, "LocalExternal", py::local());
42+
}
43+
else throw std::runtime_error("test_class not enabled");
44+
});
45+
}

0 commit comments

Comments
 (0)