Skip to content

Commit d0b8c3f

Browse files
committed
roundtrip test via reference passed to aliased class method
Probably the test is failing, because it passes the arguments by value instead of by reference.
1 parent 612a597 commit d0b8c3f

File tree

3 files changed

+175
-1
lines changed

3 files changed

+175
-1
lines changed

tests/test_class_sh_basic.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <pybind11/smart_holder.h>
44

5+
#include <cstdint>
56
#include <memory>
67
#include <string>
78
#include <vector>
@@ -74,7 +75,7 @@ std::string pass_udcp(std::unique_ptr<atyp const, sddc> obj) { return "pass_udcp
7475

7576
// Helpers for testing.
7677
std::string get_mtxt(atyp const &obj) { return obj.mtxt; }
77-
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::ptrdiff_t>(&obj); }
78+
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::uintptr_t>(&obj); }
7879

7980
std::unique_ptr<atyp> unique_ptr_roundtrip(std::unique_ptr<atyp> obj) { return obj; }
8081
const std::unique_ptr<atyp> &unique_ptr_cref_roundtrip(const std::unique_ptr<atyp> &obj) {

tests/test_class_sh_with_alias.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <pybind11/smart_holder.h>
44

5+
#include <cstdint>
56
#include <memory>
67

78
namespace pybind11_tests {
@@ -73,14 +74,92 @@ void wrap(py::module_ m, const char *py_class_name) {
7374
m.def("AddInCppUniquePtr", AddInCppUniquePtr<SerNo>, py::arg("obj"), py::arg("other_val"));
7475
}
7576

77+
struct Passenger {
78+
std::string mtxt;
79+
// on construction: store pointer as an id
80+
Passenger() : mtxt(id() + "_") {}
81+
Passenger(const Passenger &other) { mtxt = other.mtxt + "Copy->" + id(); }
82+
Passenger(Passenger &&other) { mtxt = other.mtxt + "Move->" + id(); }
83+
std::string id() const { return std::to_string(reinterpret_cast<uintptr_t>(this)); }
84+
};
85+
struct ConsumerBase {
86+
ConsumerBase() = default;
87+
ConsumerBase(const ConsumerBase &) = default;
88+
ConsumerBase(ConsumerBase &&) = default;
89+
virtual ~ConsumerBase() = default;
90+
virtual void pass_uq_cref(const std::unique_ptr<Passenger> &obj) { modify(*obj); };
91+
virtual void pass_valu(Passenger obj) { modify(obj); };
92+
virtual void pass_lref(Passenger &obj) { modify(obj); };
93+
virtual void pass_cref(const Passenger &obj) { modify(const_cast<Passenger &>(obj)); };
94+
void modify(Passenger &obj) {
95+
// when base virtual method is called: append obj pointer again (should be same as before)
96+
obj.mtxt.append("_");
97+
obj.mtxt.append(std::to_string(reinterpret_cast<uintptr_t>(&obj)));
98+
}
99+
};
100+
struct ConsumerBaseAlias : ConsumerBase {
101+
using ConsumerBase::ConsumerBase;
102+
void pass_uq_cref(const std::unique_ptr<Passenger> &obj) override {
103+
PYBIND11_OVERRIDE(void, ConsumerBase, pass_uq_cref, obj);
104+
}
105+
void pass_valu(Passenger obj) override {
106+
PYBIND11_OVERRIDE(void, ConsumerBase, pass_valu, obj);
107+
}
108+
void pass_lref(Passenger &obj) override {
109+
PYBIND11_OVERRIDE(void, ConsumerBase, pass_lref, obj);
110+
}
111+
void pass_cref(const Passenger &obj) override {
112+
PYBIND11_OVERRIDE(void, ConsumerBase, pass_cref, obj);
113+
}
114+
};
115+
116+
// check roundtrip of Passenger send to ConsumerBaseAlias
117+
// TODO: Find template magic to avoid code duplication
118+
std::string check_roundtrip_uq_cref(ConsumerBase &consumer) {
119+
std::unique_ptr<Passenger> obj(new Passenger());
120+
consumer.pass_uq_cref(obj);
121+
return obj->mtxt;
122+
}
123+
std::string check_roundtrip_valu(ConsumerBase &consumer) {
124+
Passenger obj;
125+
consumer.pass_valu(obj);
126+
return obj.mtxt;
127+
}
128+
std::string check_roundtrip_lref(ConsumerBase &consumer) {
129+
Passenger obj;
130+
consumer.pass_lref(obj);
131+
return obj.mtxt;
132+
}
133+
std::string check_roundtrip_cref(ConsumerBase &consumer) {
134+
Passenger obj;
135+
consumer.pass_cref(obj);
136+
return obj.mtxt;
137+
}
138+
76139
} // namespace test_class_sh_with_alias
77140
} // namespace pybind11_tests
78141

79142
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Abase<0>)
80143
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Abase<1>)
144+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Passenger)
145+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::ConsumerBase)
81146

82147
TEST_SUBMODULE(class_sh_with_alias, m) {
83148
using namespace pybind11_tests::test_class_sh_with_alias;
84149
wrap<0>(m, "Abase0");
85150
wrap<1>(m, "Abase1");
151+
152+
py::classh<Passenger>(m, "Passenger").def_readwrite("mtxt", &Passenger::mtxt);
153+
154+
py::classh<ConsumerBase, ConsumerBaseAlias>(m, "ConsumerBase")
155+
.def(py::init<>())
156+
.def("pass_uq_cref", &ConsumerBase::pass_uq_cref)
157+
.def("pass_valu", &ConsumerBase::pass_valu)
158+
.def("pass_lref", &ConsumerBase::pass_lref)
159+
.def("pass_cref", &ConsumerBase::pass_cref);
160+
161+
m.def("check_roundtrip_uq_cref", check_roundtrip_uq_cref);
162+
m.def("check_roundtrip_valu", check_roundtrip_valu);
163+
m.def("check_roundtrip_lref", check_roundtrip_lref);
164+
m.def("check_roundtrip_cref", check_roundtrip_cref);
86165
}

tests/test_class_sh_with_alias.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
# -*- coding: utf-8 -*-
2+
import re
23
import pytest
4+
import env # noqa: F401
35

46
from pybind11_tests import class_sh_with_alias as m
57

68

9+
def check_regex(expected, actual):
10+
result = re.match(expected + "$", actual)
11+
if result is None:
12+
pytest.fail("expected: '{}' != actual: '{}'".format(expected, actual))
13+
14+
715
class PyDrvd0(m.Abase0):
816
def __init__(self, val):
917
super(PyDrvd0, self).__init__(val)
@@ -56,3 +64,89 @@ def test_drvd1_add_in_cpp_unique_ptr():
5664
drvd = PyDrvd1(25)
5765
assert m.AddInCppUniquePtr(drvd, 83) == ((25 * 10 + 3) * 200 + 83) * 100 + 13
5866
return # Comment out for manual leak checking (use `top` command).
67+
68+
69+
class PyConsumer1(m.ConsumerBase):
70+
def __init__(self):
71+
m.ConsumerBase.__init__(self)
72+
73+
def pass_uq_cref(self, obj):
74+
obj.mtxt = obj.mtxt + "pass_uq_cref"
75+
76+
def pass_valu(self, obj):
77+
obj.mtxt = obj.mtxt + "pass_valu"
78+
79+
def pass_lref(self, obj):
80+
obj.mtxt = obj.mtxt + "pass_lref"
81+
82+
def pass_cref(self, obj):
83+
obj.mtxt = obj.mtxt + "pass_cref"
84+
85+
86+
class PyConsumer2(m.ConsumerBase):
87+
"""This one, additionally to PyConsumer1 calls the base methods.
88+
This results in a second call to the trampoline override dispatcher.
89+
Hence arguments have travelled a long way back and forth between C++
90+
and Python: C++ -> Python (call #1) -> C++ (call #2)."""
91+
92+
def __init__(self):
93+
m.ConsumerBase.__init__(self)
94+
95+
def pass_uq_cref(self, obj):
96+
obj.mtxt = obj.mtxt + "pass_uq_cref"
97+
m.ConsumerBase.pass_uq_cref(self, obj)
98+
99+
def pass_valu(self, obj):
100+
obj.mtxt = obj.mtxt + "pass_valu"
101+
m.ConsumerBase.pass_valu(self, obj)
102+
103+
def pass_lref(self, obj):
104+
obj.mtxt = obj.mtxt + "pass_lref"
105+
m.ConsumerBase.pass_lref(self, obj)
106+
107+
def pass_cref(self, obj):
108+
obj.mtxt = obj.mtxt + "pass_cref"
109+
m.ConsumerBase.pass_cref(self, obj)
110+
111+
112+
# roundtrip tests, creating an object in C++ that is passed by reference
113+
# to a virtual method of a class derived in Python. Thus:
114+
# C++ -> Python -> C++
115+
@pytest.mark.parametrize(
116+
"f, expected",
117+
[
118+
(m.check_roundtrip_uq_cref, "([0-9]+)_pass_uq_cref"),
119+
(m.check_roundtrip_valu, "([0-9]+)_"), # modification not passed back to C++
120+
(m.check_roundtrip_lref, "([0-9]+)_pass_lref"),
121+
pytest.param(
122+
m.check_roundtrip_cref,
123+
"([0-9]+)_pass_cref",
124+
marks=pytest.mark.skipif("env.PYPY"),
125+
),
126+
],
127+
)
128+
def test_unique_ptr_consumer1_roundtrip(f, expected):
129+
c = PyConsumer1()
130+
check_regex(expected, f(c))
131+
132+
133+
@pytest.mark.parametrize(
134+
"f, expected",
135+
[
136+
pytest.param( # cannot (yet) load unowned const unique_ptr& (for 2nd call)
137+
m.check_roundtrip_uq_cref,
138+
"([0-9]+)_pass_uq_cref_\\1",
139+
marks=pytest.mark.xfail,
140+
),
141+
(m.check_roundtrip_valu, "([0-9]+)_"), # modification not passed back to C++
142+
(m.check_roundtrip_lref, "([0-9]+)_pass_lref_\\1"),
143+
pytest.param( # PYPY always copies the argument instead of passing the reference
144+
m.check_roundtrip_cref,
145+
"([0-9]+)_pass_cref_\\1",
146+
marks=pytest.mark.skipif("env.PYPY"),
147+
),
148+
],
149+
)
150+
def test_unique_ptr_consumer2_roundtrip(f, expected):
151+
c = PyConsumer2()
152+
check_regex(expected, f(c))

0 commit comments

Comments
 (0)