From 2667282d095d8b86f4fd0c42b6b9b604bdcd27a6 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 13 Feb 2025 11:57:14 +0100 Subject: [PATCH 1/2] Fix: re-auth interfering with connectivity checks Connections might have pipelined messages when being picked up from the pool (e.g., re-auth). When unflushed those, they will be `RESET` when returned to the pool. This results in pipelineing `LOGON`, `LOGOFF`, and `RESET` where the `RESET` will skip the queue on the server side. Depending on the timing, this will leave the connection in an undesirable state (e.g., unauthenticated). This PR circumvents this situation by acquiring an unprepared connection from the pool when performing connectivity checks. Such connections will not have any work pipelined. --- src/neo4j/_async/io/_pool.py | 24 ++++++++++++++++++++---- src/neo4j/_async/work/session.py | 6 ++++-- src/neo4j/_sync/io/_pool.py | 24 ++++++++++++++++++++---- src/neo4j/_sync/work/session.py | 6 ++++-- tests/unit/async_/io/test_direct.py | 3 ++- tests/unit/sync/io/test_direct.py | 3 ++- 6 files changed, 52 insertions(+), 14 deletions(-) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index be697ddb5..747614d97 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -274,7 +274,9 @@ async def connection_creator(): return connection_creator return None - async def _re_auth_connection(self, connection, auth, force): + async def _re_auth_connection(self, connection, auth, force, unprepared): + if unprepared and not force: + return if auth: # Assert session auth is supported by the protocol. # The Bolt implementation will try as hard as it can to make the @@ -312,7 +314,14 @@ async def _re_auth_connection(self, connection, auth, force): await connection.send_all() await connection.fetch_all() - async def _acquire(self, address, auth, deadline, liveness_check_timeout): + async def _acquire( + self, + address, + auth, + deadline, + liveness_check_timeout, + unprepared=False, + ): """ Acquire a connection to a given address from the pool. @@ -362,7 +371,7 @@ async def health_check(connection_, deadline_): ) try: await self._re_auth_connection( - connection, auth, force_auth + connection, auth, force_auth, unprepared ) except ConfigurationError: if auth: @@ -419,6 +428,7 @@ async def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): """ @@ -431,6 +441,9 @@ async def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param unprepared: If True, no messages will be pipelined on the + connection. Meant to be used if no work is to be executed on the + connection. :param database_callback: """ ... @@ -651,6 +664,7 @@ async def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): # The access_mode and database is not needed for a direct connection, @@ -663,7 +677,7 @@ async def acquire( ) deadline = Deadline.from_timeout_or_deadline(timeout) return await self._acquire( - self.address, auth, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout, unprepared ) @@ -1132,6 +1146,7 @@ async def acquire( bookmarks, auth: AcquisitionAuth | None, liveness_check_timeout, + unprepared=False, database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: @@ -1201,6 +1216,7 @@ async def wrapped_database_callback(new_database): auth, deadline, liveness_check_timeout, + unprepared, ) except (ServiceUnavailable, SessionExpired): await self.deactivate(address=address) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index dd7324abe..3d5b64a27 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -180,14 +180,16 @@ async def _result_error(self, error): async def _get_server_info(self): assert not self._connection - await self._connect(READ_ACCESS, liveness_check_timeout=0) + await self._connect( + READ_ACCESS, liveness_check_timeout=0, unprepared=True + ) server_info = self._connection.server_info await self._disconnect() return server_info async def _verify_authentication(self): assert not self._connection - await self._connect(READ_ACCESS, force_auth=True) + await self._connect(READ_ACCESS, force_auth=True, unprepared=True) await self._disconnect() @AsyncNonConcurrentMethodChecker._non_concurrent_method diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index c04c11659..48a148916 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -271,7 +271,9 @@ def connection_creator(): return connection_creator return None - def _re_auth_connection(self, connection, auth, force): + def _re_auth_connection(self, connection, auth, force, unprepared): + if unprepared and not force: + return if auth: # Assert session auth is supported by the protocol. # The Bolt implementation will try as hard as it can to make the @@ -309,7 +311,14 @@ def _re_auth_connection(self, connection, auth, force): connection.send_all() connection.fetch_all() - def _acquire(self, address, auth, deadline, liveness_check_timeout): + def _acquire( + self, + address, + auth, + deadline, + liveness_check_timeout, + unprepared=False, + ): """ Acquire a connection to a given address from the pool. @@ -359,7 +368,7 @@ def health_check(connection_, deadline_): ) try: self._re_auth_connection( - connection, auth, force_auth + connection, auth, force_auth, unprepared ) except ConfigurationError: if auth: @@ -416,6 +425,7 @@ def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): """ @@ -428,6 +438,9 @@ def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param unprepared: If True, no messages will be pipelined on the + connection. Meant to be used if no work is to be executed on the + connection. :param database_callback: """ ... @@ -648,6 +661,7 @@ def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): # The access_mode and database is not needed for a direct connection, @@ -660,7 +674,7 @@ def acquire( ) deadline = Deadline.from_timeout_or_deadline(timeout) return self._acquire( - self.address, auth, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout, unprepared ) @@ -1129,6 +1143,7 @@ def acquire( bookmarks, auth: AcquisitionAuth | None, liveness_check_timeout, + unprepared=False, database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: @@ -1198,6 +1213,7 @@ def wrapped_database_callback(new_database): auth, deadline, liveness_check_timeout, + unprepared, ) except (ServiceUnavailable, SessionExpired): self.deactivate(address=address) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 910fe328d..84d43fb38 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -180,14 +180,16 @@ def _result_error(self, error): def _get_server_info(self): assert not self._connection - self._connect(READ_ACCESS, liveness_check_timeout=0) + self._connect( + READ_ACCESS, liveness_check_timeout=0, unprepared=True + ) server_info = self._connection.server_info self._disconnect() return server_info def _verify_authentication(self): assert not self._connection - self._connect(READ_ACCESS, force_auth=True) + self._connect(READ_ACCESS, force_auth=True, unprepared=True) self._disconnect() @NonConcurrentMethodChecker._non_concurrent_method diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 47f6f813d..7455824e0 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -70,10 +70,11 @@ async def acquire( bookmarks, auth, liveness_check_timeout, + unprepared=False, database_callback=None, ): return await self._acquire( - self.address, auth, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout, unprepared ) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 96acf37e4..b661cf4c2 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -70,10 +70,11 @@ def acquire( bookmarks, auth, liveness_check_timeout, + unprepared=False, database_callback=None, ): return self._acquire( - self.address, auth, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout, unprepared ) From 2e0c940e553b3b43d00a4fdbec4046ff93f5c961 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 26 Feb 2025 15:37:42 +0100 Subject: [PATCH 2/2] Add unit tests --- tests/unit/async_/io/test_direct.py | 25 +++++++++++ tests/unit/async_/work/test_session.py | 62 ++++++++++++++++++++++++++ tests/unit/sync/io/test_direct.py | 25 +++++++++++ tests/unit/sync/work/test_session.py | 62 ++++++++++++++++++++++++++ 4 files changed, 174 insertions(+) diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 7455824e0..344bdb44b 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -278,3 +278,28 @@ async def test_liveness_check( cx1.reset.reset_mock() await pool.release(cx1) cx1.reset.assert_not_called() + + +@pytest.mark.parametrize("unprepared", (True, False, None)) +@mark_async_test +async def test_reauth(async_fake_connection_generator, unprepared): + async with AsyncFakeBoltPool( + async_fake_connection_generator, + ("127.0.0.1", 7687), + ) as pool: + address = neo4j.Address(("127.0.0.1", 7687)) + # pre-populate pool + cx = await pool._acquire(address, None, Deadline(3), None) + await pool.release(cx) + cx.reset_mock() + + kwargs = {} + if unprepared is not None: + kwargs["unprepared"] = unprepared + cx = await pool._acquire(address, None, Deadline(3), None, **kwargs) + if unprepared: + cx.re_auth.assert_not_called() + else: + cx.re_auth.assert_called_once() + + await pool.release(cx) diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 06ba70461..743c7573a 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -877,3 +877,65 @@ async def resolve_db(): cache_spy.set.assert_called_once_with(key, resolved_db) assert session._pinned_database assert config.database == resolved_db + + +@pytest.mark.parametrize( + "method", ("_get_server_info", "_verify_authentication") +) +@mark_async_test +async def test_check_connections_are_unprepared_connection( + async_fake_pool, + method, +): + config = SessionConfig() + async with AsyncSession(async_fake_pool, config) as session: + await getattr(session, method)() + assert len(async_fake_pool.acquired_connection_mocks) == 1 + async_fake_pool.acquire.assert_awaited_once() + unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is True + + +async def _explicit_transaction(session: AsyncSession): + async with await session.begin_transaction(): + pass + + +async def _autocommit_transaction(session: AsyncSession): + await session.run("RETURN 1") + + +async def _tx_func_read(session: AsyncSession): + async def work(tx: AsyncManagedTransaction): + pass + + await session.execute_read(work) + + +async def _tx_func_write(session: AsyncSession): + async def work(tx: AsyncManagedTransaction): + pass + + await session.execute_write(work) + + +@pytest.mark.parametrize( + "method", + ( + _explicit_transaction, + _autocommit_transaction, + _tx_func_read, + _tx_func_write, + ), +) +@mark_async_test +async def test_work_connections_are_prepared_connection( + async_fake_pool, method +): + config = SessionConfig() + async with AsyncSession(async_fake_pool, config) as session: + await method(session) + assert len(async_fake_pool.acquired_connection_mocks) == 1 + async_fake_pool.acquire.assert_awaited_once() + unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is False or unprepared is None diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index b661cf4c2..42e4d6c9e 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -278,3 +278,28 @@ def test_liveness_check( cx1.reset.reset_mock() pool.release(cx1) cx1.reset.assert_not_called() + + +@pytest.mark.parametrize("unprepared", (True, False, None)) +@mark_sync_test +def test_reauth(fake_connection_generator, unprepared): + with FakeBoltPool( + fake_connection_generator, + ("127.0.0.1", 7687), + ) as pool: + address = neo4j.Address(("127.0.0.1", 7687)) + # pre-populate pool + cx = pool._acquire(address, None, Deadline(3), None) + pool.release(cx) + cx.reset_mock() + + kwargs = {} + if unprepared is not None: + kwargs["unprepared"] = unprepared + cx = pool._acquire(address, None, Deadline(3), None, **kwargs) + if unprepared: + cx.re_auth.assert_not_called() + else: + cx.re_auth.assert_called_once() + + pool.release(cx) diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 94543cde2..4897c086c 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -877,3 +877,65 @@ def resolve_db(): cache_spy.set.assert_called_once_with(key, resolved_db) assert session._pinned_database assert config.database == resolved_db + + +@pytest.mark.parametrize( + "method", ("_get_server_info", "_verify_authentication") +) +@mark_sync_test +def test_check_connections_are_unprepared_connection( + fake_pool, + method, +): + config = SessionConfig() + with Session(fake_pool, config) as session: + getattr(session, method)() + assert len(fake_pool.acquired_connection_mocks) == 1 + fake_pool.acquire.assert_called_once() + unprepared = fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is True + + +def _explicit_transaction(session: Session): + with session.begin_transaction(): + pass + + +def _autocommit_transaction(session: Session): + session.run("RETURN 1") + + +def _tx_func_read(session: Session): + def work(tx: ManagedTransaction): + pass + + session.execute_read(work) + + +def _tx_func_write(session: Session): + def work(tx: ManagedTransaction): + pass + + session.execute_write(work) + + +@pytest.mark.parametrize( + "method", + ( + _explicit_transaction, + _autocommit_transaction, + _tx_func_read, + _tx_func_write, + ), +) +@mark_sync_test +def test_work_connections_are_prepared_connection( + fake_pool, method +): + config = SessionConfig() + with Session(fake_pool, config) as session: + method(session) + assert len(fake_pool.acquired_connection_mocks) == 1 + fake_pool.acquire.assert_called_once() + unprepared = fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is False or unprepared is None