From 0c472257dc7196c9aadce29143a7116da9205fd8 Mon Sep 17 00:00:00 2001 From: Boris Dalstein Date: Tue, 8 Oct 2024 17:34:07 +0200 Subject: [PATCH 1/2] Fix #5399: iterator increment operator does not skip first item --- include/pybind11/pytypes.h | 16 +++++++++++----- tests/test_pytypes.cpp | 5 +++++ tests/test_pytypes.py | 6 ++++++ 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 7aafab6dcc..e763ce3efc 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1470,27 +1470,26 @@ class iterator : public object { PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) iterator &operator++() { + init(); advance(); return *this; } iterator operator++(int) { auto rv = *this; + init(); advance(); return rv; } // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference operator*() const { - if (m_ptr && !value.ptr()) { - auto &self = const_cast(*this); - self.advance(); - } + init(); return value; } pointer operator->() const { - operator*(); + init(); return &value; } @@ -1513,6 +1512,13 @@ class iterator : public object { friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); } private: + void init() const { + if (m_ptr && !value.ptr()) { + auto &self = const_cast(*this); + self.advance(); + } + } + void advance() { value = reinterpret_steal(PyIter_Next(m_ptr)); if (value.ptr() == nullptr && PyErr_Occurred()) { diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 19f65ce7eb..759f7ff500 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -150,6 +150,11 @@ TEST_SUBMODULE(pytypes, m) { m.def("get_iterator", [] { return py::iterator(); }); // test_iterable m.def("get_iterable", [] { return py::iterable(); }); + m.def("get_second_item_from_iterable", [](const py::iterable &iter) { + py::iterator it = iter.begin(); + ++it; + return *it; + }); m.def("get_frozenset_from_iterable", [](const py::iterable &iter) { return py::frozenset(iter); }); m.def("get_list_from_iterable", [](const py::iterable &iter) { return py::list(iter); }); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 39d0b619b8..f1273c7657 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -54,6 +54,12 @@ def test_iterable(doc): assert doc(m.get_iterable) == "get_iterable() -> Iterable" +def test_get_second_item_from_iterable(): + lins = [1, 2] + i = m.get_second_item_from_iterable(lins) + assert i == 2 + + def test_float(doc): assert doc(m.get_float) == "get_float() -> float" From 5a7af3f1e7f2c3a6c2a77664eb3a9da28b550198 Mon Sep 17 00:00:00 2001 From: Boris Dalstein Date: Tue, 8 Oct 2024 18:59:46 +0200 Subject: [PATCH 2/2] Fix postfix increment operator: init() must be called before copying *this --- include/pybind11/pytypes.h | 6 +++++- tests/test_pytypes.cpp | 7 +++++++ tests/test_pytypes.py | 7 +++---- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index e763ce3efc..027e36098b 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1476,8 +1476,12 @@ class iterator : public object { } iterator operator++(int) { - auto rv = *this; + // Note: We must call init() first so that rv.value is + // the same as this->value just before calling advance(). + // Otherwise, dereferencing the returned iterator may call + // advance() again and return the 3rd item instead of the 1st. init(); + auto rv = *this; advance(); return rv; } diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 759f7ff500..8df4cdd3f6 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -150,7 +150,14 @@ TEST_SUBMODULE(pytypes, m) { m.def("get_iterator", [] { return py::iterator(); }); // test_iterable m.def("get_iterable", [] { return py::iterable(); }); + m.def("get_first_item_from_iterable", [](const py::iterable &iter) { + // This tests the postfix increment operator + py::iterator it = iter.begin(); + py::iterator it2 = it++; + return *it2; + }); m.def("get_second_item_from_iterable", [](const py::iterable &iter) { + // This tests the prefix increment operator py::iterator it = iter.begin(); ++it; return *it; diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index f1273c7657..9fd24b34f1 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -52,10 +52,9 @@ def test_from_iterable(pytype, from_iter_func): def test_iterable(doc): assert doc(m.get_iterable) == "get_iterable() -> Iterable" - - -def test_get_second_item_from_iterable(): - lins = [1, 2] + lins = [1, 2, 3] + i = m.get_first_item_from_iterable(lins) + assert i == 1 i = m.get_second_item_from_iterable(lins) assert i == 2