Skip to content

Commit ebf6b2b

Browse files
committed
Add py::local() for module-local type bindings
This commit adds a `py::local` attribute that lets you confine a registered type to the module (more technically, the shared object) in which it is defined, by registering it with: py::class_<C>(m, "C", py::local()) This will allow the C++ class `C` to be registered in different modules with independent sets of class definitions. On the Python side, two such types will be completely distinct; on the C++ side, the C++ resolves to a different Python types in each module.
1 parent 60526d4 commit ebf6b2b

File tree

8 files changed

+199
-42
lines changed

8 files changed

+199
-42
lines changed

include/pybind11/attr.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ struct metaclass {
6464
explicit metaclass(handle value) : value(value) { }
6565
};
6666

67+
/// Annotation that marks a class as local to the module:
68+
struct local { bool value; local(bool v = true) : value(v) { } };
69+
6770
/// Annotation to mark enums as an arithmetic type
6871
struct arithmetic { };
6972

@@ -196,7 +199,7 @@ struct function_record {
196199
/// Special data structure which (temporarily) holds metadata about a bound class
197200
struct type_record {
198201
PYBIND11_NOINLINE type_record()
199-
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false) { }
202+
: multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), local(false) { }
200203

201204
/// Handle to the parent scope
202205
handle scope;
@@ -243,6 +246,9 @@ struct type_record {
243246
/// Is the default (unique_ptr) holder type used?
244247
bool default_holder : 1;
245248

249+
/// Is the class definition local to the module shared object?
250+
bool local : 1;
251+
246252
PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) {
247253
auto base_info = detail::get_type_info(base, false);
248254
if (!base_info) {
@@ -408,6 +414,10 @@ struct process_attribute<metaclass> : process_attribute_default<metaclass> {
408414
static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; }
409415
};
410416

417+
template <>
418+
struct process_attribute<local> : process_attribute_default<local> {
419+
static void init(const local &l, type_record *r) { r->local = l.value; }
420+
};
411421

412422
/// Process an 'arithmetic' attribute for enums (does nothing here)
413423
template <>

include/pybind11/cast.h

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ PYBIND11_NOINLINE inline internals &get_internals() {
111111
return *internals_ptr;
112112
}
113113

114+
// Works like internals.registered_types_cpp, but for module-local registered types:
115+
inline type_map<void *> &registered_local_types_cpp() {
116+
static type_map<void *> locals{};
117+
return locals;
118+
}
119+
114120
/// A life support system for temporary objects created by `type_caster::load()`.
115121
/// Adding a patient will keep it alive up until the enclosing function returns.
116122
class loader_life_support {
@@ -198,7 +204,7 @@ PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vecto
198204
// registered types
199205
if (i + 1 == check.size()) {
200206
// When we're at the end, we can pop off the current element to avoid growing
201-
// `check` when adding just one base (which is typical--.e. when there is no
207+
// `check` when adding just one base (which is typical--i.e. when there is no
202208
// multiple inheritance)
203209
check.pop_back();
204210
i--;
@@ -242,13 +248,18 @@ PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) {
242248
return bases.front();
243249
}
244250

245-
PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_info &tp,
251+
/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr.
252+
PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp,
246253
bool throw_if_missing = false) {
254+
std::type_index type_idx(tp);
247255
auto &types = get_internals().registered_types_cpp;
248-
249-
auto it = types.find(std::type_index(tp));
256+
auto it = types.find(type_idx);
250257
if (it != types.end())
251258
return (detail::type_info *) it->second;
259+
auto &locals = registered_local_types_cpp();
260+
it = locals.find(type_idx);
261+
if (it != locals.end())
262+
return (detail::type_info *) it->second;
252263
if (throw_if_missing) {
253264
std::string tname = tp.name();
254265
detail::clean_type_id(tname);
@@ -706,10 +717,8 @@ class type_caster_generic {
706717
// with .second = nullptr. (p.first = nullptr is not an error: it becomes None).
707718
PYBIND11_NOINLINE static std::pair<const void *, const type_info *> src_and_type(
708719
const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) {
709-
auto &internals = get_internals();
710-
auto it = internals.registered_types_cpp.find(std::type_index(cast_type));
711-
if (it != internals.registered_types_cpp.end())
712-
return {src, (const type_info *) it->second};
720+
if (auto *tpi = get_type_info(cast_type))
721+
return {src, const_cast<const type_info *>(tpi)};
713722

714723
// Not found, set error:
715724
std::string tname = rtti_type ? rtti_type->name() : cast_type.name();
@@ -787,7 +796,6 @@ template <typename type> class type_caster_base : public type_caster_generic {
787796
template <typename T = itype, enable_if_t<std::is_polymorphic<T>::value, int> = 0>
788797
static std::pair<const void *, const type_info *> src_and_type(const itype *src) {
789798
const void *vsrc = src;
790-
auto &internals = get_internals();
791799
auto &cast_type = typeid(itype);
792800
const std::type_info *instance_type = nullptr;
793801
if (vsrc) {
@@ -796,9 +804,8 @@ template <typename type> class type_caster_base : public type_caster_generic {
796804
// This is a base pointer to a derived type; if it is a pybind11-registered type, we
797805
// can get the correct derived pointer (which may be != base pointer) by a
798806
// dynamic_cast to most derived type:
799-
auto it = internals.registered_types_cpp.find(std::type_index(*instance_type));
800-
if (it != internals.registered_types_cpp.end())
801-
return {dynamic_cast<const void *>(src), (const type_info *) it->second};
807+
if (auto *tpi = get_type_info(*instance_type))
808+
return {dynamic_cast<const void *>(src), const_cast<const type_info *>(tpi)};
802809
}
803810
}
804811
// Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so

include/pybind11/pybind11.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,10 @@ class generic_type : public object {
844844
auto tindex = std::type_index(*rec.type);
845845
tinfo->direct_conversions = &internals.direct_conversions[tindex];
846846
tinfo->default_holder = rec.default_holder;
847-
internals.registered_types_cpp[tindex] = tinfo;
847+
if (rec.local)
848+
registered_local_types_cpp()[tindex] = tinfo;
849+
else
850+
internals.registered_types_cpp[tindex] = tinfo;
848851
internals.registered_types_py[(PyTypeObject *) m_ptr] = { tinfo };
849852

850853
if (rec.bases.size() > 1 || rec.multiple_inheritance) {
@@ -978,7 +981,7 @@ class class_ : public detail::generic_type {
978981
generic_type::initialize(record);
979982

980983
if (has_alias) {
981-
auto &instances = get_internals().registered_types_cpp;
984+
auto &instances = record.local ? registered_local_types_cpp() : get_internals().registered_types_cpp;
982985
instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))];
983986
}
984987
}
@@ -1427,7 +1430,7 @@ iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) {
14271430
typedef detail::iterator_state<Iterator, Sentinel, false, Policy> state;
14281431

14291432
if (!detail::get_type_info(typeid(state), false)) {
1430-
class_<state>(handle(), "iterator")
1433+
class_<state>(handle(), "iterator", pybind11::local())
14311434
.def("__iter__", [](state &s) -> state& { return s; })
14321435
.def("__next__", [](state &s) -> ValueType {
14331436
if (!s.first_or_done)
@@ -1456,7 +1459,7 @@ iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) {
14561459
typedef detail::iterator_state<Iterator, Sentinel, true, Policy> state;
14571460

14581461
if (!detail::get_type_info(typeid(state), false)) {
1459-
class_<state>(handle(), "iterator")
1462+
class_<state>(handle(), "iterator", pybind11::local())
14601463
.def("__iter__", [](state &s) -> state& { return s; })
14611464
.def("__next__", [](state &s) -> KeyType {
14621465
if (!s.first_or_done)

tests/CMakeLists.txt

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ set(PYBIND11_TEST_FILES
4343
test_methods_and_attributes.cpp
4444
test_modules.cpp
4545
test_multiple_inheritance.cpp
46+
test_local_bindings.cpp
4647
test_numpy_array.cpp
4748
test_numpy_dtypes.cpp
4849
test_numpy_vectorize.cpp
@@ -120,36 +121,49 @@ function(pybind11_enable_warnings target_name)
120121
endif()
121122
endfunction()
122123

124+
set(test_targets pybind11_tests)
125+
# If test_local_bindings is being built in pybind11_tests.so we'll also build
126+
# a separate module pybind11_local_bindings.so for the test:
127+
list(FIND PYBIND11_TEST_FILES test_local_bindings.cpp test_local_i)
128+
if (test_local_i GREATER -1)
129+
list(APPEND test_targets pybind11_local_bindings)
130+
endif()
123131

124-
# Create the binding library
125-
pybind11_add_module(pybind11_tests THIN_LTO pybind11_tests.cpp
126-
${PYBIND11_TEST_FILES} ${PYBIND11_HEADERS})
132+
set(testdir ${CMAKE_CURRENT_SOURCE_DIR})
133+
foreach(tgt ${test_targets})
134+
set(test_files ${PYBIND11_TEST_FILES})
135+
if(NOT tgt STREQUAL "pybind11_tests")
136+
set(test_files "")
137+
endif()
127138

128-
pybind11_enable_warnings(pybind11_tests)
139+
# Create the binding library
140+
pybind11_add_module(${tgt} THIN_LTO ${tgt}.cpp
141+
${test_files} ${PYBIND11_HEADERS})
129142

130-
if(MSVC)
131-
target_compile_options(pybind11_tests PRIVATE /utf-8)
132-
endif()
143+
pybind11_enable_warnings(${tgt})
133144

134-
if(EIGEN3_FOUND)
135-
if (PYBIND11_EIGEN_VIA_TARGET)
136-
target_link_libraries(pybind11_tests PRIVATE Eigen3::Eigen)
137-
else()
138-
target_include_directories(pybind11_tests PRIVATE ${EIGEN3_INCLUDE_DIR})
145+
if(MSVC)
146+
target_compile_options(${tgt} PRIVATE /utf-8)
139147
endif()
140-
target_compile_definitions(pybind11_tests PRIVATE -DPYBIND11_TEST_EIGEN)
141-
endif()
142148

143-
set(testdir ${CMAKE_CURRENT_SOURCE_DIR})
149+
if(EIGEN3_FOUND)
150+
if (PYBIND11_EIGEN_VIA_TARGET)
151+
target_link_libraries(${tgt} PRIVATE Eigen3::Eigen)
152+
else()
153+
target_include_directories(${tgt} PRIVATE ${EIGEN3_INCLUDE_DIR})
154+
endif()
155+
target_compile_definitions(${tgt} PRIVATE -DPYBIND11_TEST_EIGEN)
156+
endif()
144157

145-
# Always write the output file directly into the 'tests' directory (even on MSVC)
146-
if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY)
147-
set_target_properties(pybind11_tests PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${testdir})
148-
foreach(config ${CMAKE_CONFIGURATION_TYPES})
149-
string(TOUPPER ${config} config)
150-
set_target_properties(pybind11_tests PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} ${testdir})
151-
endforeach()
152-
endif()
158+
# Always write the output file directly into the 'tests' directory (even on MSVC)
159+
if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY)
160+
set_target_properties(${tgt} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${testdir})
161+
foreach(config ${CMAKE_CONFIGURATION_TYPES})
162+
string(TOUPPER ${config} config)
163+
set_target_properties(${tgt} PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} ${testdir})
164+
endforeach()
165+
endif()
166+
endforeach()
153167

154168
# Make sure pytest is found or produce a fatal error
155169
if(NOT PYBIND11_PYTEST_FOUND)
@@ -173,7 +187,7 @@ endif()
173187

174188
# A single command to compile and run the tests
175189
add_custom_target(pytest COMMAND ${PYTHON_EXECUTABLE} -m pytest ${PYBIND11_PYTEST_FILES}
176-
DEPENDS pybind11_tests WORKING_DIRECTORY ${testdir} ${PYBIND11_USES_TERMINAL})
190+
DEPENDS ${test_targets} WORKING_DIRECTORY ${testdir} ${PYBIND11_USES_TERMINAL})
177191

178192
if(PYBIND11_TEST_OVERRIDE)
179193
add_custom_command(TARGET pytest POST_BUILD
@@ -189,7 +203,7 @@ if (NOT PROJECT_NAME STREQUAL "pybind11")
189203
return()
190204
endif()
191205

192-
# Add a post-build comment to show the .so size and, if a previous size, compare it:
206+
# Add a post-build comment to show the primary test suite .so size and, if a previous size, compare it:
193207
add_custom_command(TARGET pybind11_tests POST_BUILD
194208
COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/libsize.py
195209
$<TARGET_FILE:pybind11_tests> ${CMAKE_CURRENT_BINARY_DIR}/sosize-$<TARGET_FILE_NAME:pybind11_tests>.txt)

tests/local_bindings.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
#include "pybind11_tests.h"
3+
4+
/// Simple class used to test py::local:
5+
template <int> class LocalBase {
6+
public:
7+
LocalBase(int i) : i(i) { }
8+
int i = -1;
9+
};
10+
11+
/// Registered with py::local in both main and secondary modules:
12+
using LocalType = LocalBase<0>;
13+
/// Registered without py::local in both modules:
14+
using NonLocalType = LocalBase<1>;
15+
/// A second non-local type (for stl_bind tests):
16+
using NonLocal2 = LocalBase<2>;
17+
18+
// Simple bindings (used with the above):
19+
template <typename T, int Adjust, typename... Args>
20+
py::class_<T> bind_local(Args && ...args) {
21+
return py::class_<T>(std::forward<Args>(args)...)
22+
.def(py::init<int>())
23+
.def("get", [](T &i) { return i.i + Adjust; });
24+
};

tests/pybind11_local_bindings.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
tests/pybind11_local_bindings.cpp -- counterpart to test_local_bindings.cpp
3+
4+
Copyright (c) 2017 Jason Rhinelander <[email protected]>
5+
6+
All rights reserved. Use of this source code is governed by a
7+
BSD-style license that can be found in the LICENSE file.
8+
*/
9+
10+
#include "pybind11_tests.h"
11+
#include "local_bindings.h"
12+
#include <pybind11/stl_bind.h>
13+
14+
PYBIND11_MODULE(pybind11_local_bindings, m) {
15+
m.doc() = "pybind11 local bindings test module";
16+
17+
// Local to both:
18+
bind_local<LocalType, 1>(m, "LocalType", py::local())
19+
.def("get2", [](LocalType &t) { return t.i + 2; })
20+
;
21+
22+
// Can only be called with our python type:
23+
m.def("local_value", [](LocalType &l) { return l.i; });
24+
25+
// This registration will fail (global registration when LocalFail is already registered
26+
// globally in the main test module):
27+
m.def("register_nonlocal", [m]() {
28+
bind_local<NonLocalType, 0>(m, "NonLocalType");
29+
});
30+
}

tests/test_local_bindings.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
tests/test_local_bindings.cpp -- tests the py::local class feature which makes a class binding
3+
local to the module in which it is defined.
4+
5+
Copyright (c) 2017 Jason Rhinelander <[email protected]>
6+
7+
All rights reserved. Use of this source code is governed by a
8+
BSD-style license that can be found in the LICENSE file.
9+
*/
10+
11+
#include "pybind11_tests.h"
12+
#include "local_bindings.h"
13+
#include <pybind11/stl_bind.h>
14+
15+
TEST_SUBMODULE(local_bindings, m) {
16+
17+
// Register a class with py::local:
18+
bind_local<LocalType, -1>(m, "LocalType", py::local())
19+
.def("get3", [](LocalType &t) { return t.i + 3; })
20+
;
21+
22+
m.def("local_value", [](LocalType &l) { return l.i; });
23+
24+
// The main pybind11 test module is loaded first, so this registration will succeed (the second
25+
// one, in pybind11_local_bindings.cpp, is designed to fail):
26+
bind_local<NonLocalType, 0>(m, "NonLocalType")
27+
.def(py::init<int>())
28+
.def("get", [](LocalType &i) { return i.i; })
29+
;
30+
}

tests/test_local_bindings.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
3+
from pybind11_tests import local_bindings as m1
4+
5+
6+
def test_local_bindings():
7+
# Make sure we can load the second module with the conflicting (but local) definition:
8+
import pybind11_local_bindings as m2
9+
10+
i1 = m1.LocalType(5)
11+
12+
assert i1.get() == 4
13+
assert i1.get3() == 8
14+
15+
i2 = m2.LocalType(10)
16+
assert i2.get() == 11
17+
assert i2.get2() == 12
18+
19+
assert not hasattr(i1, 'get2')
20+
assert not hasattr(i2, 'get3')
21+
22+
assert m1.local_value(i1) == 5
23+
assert m2.local_value(i2) == 10
24+
25+
with pytest.raises(TypeError) as excinfo:
26+
m1.local_value(i2)
27+
assert "incompatible function arguments" in str(excinfo.value)
28+
29+
with pytest.raises(TypeError) as excinfo:
30+
m2.local_value(i1)
31+
assert "incompatible function arguments" in str(excinfo.value)
32+
33+
34+
def test_nonlocal_failure():
35+
import pybind11_local_bindings as m2
36+
37+
with pytest.raises(RuntimeError) as excinfo:
38+
m2.register_nonlocal()
39+
assert str(excinfo.value) == 'generic_type: type "NonLocalType" is already registered!'

0 commit comments

Comments
 (0)