diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index be697ddb5..747614d97 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -274,7 +274,9 @@ async def connection_creator(): return connection_creator return None - async def _re_auth_connection(self, connection, auth, force): + async def _re_auth_connection(self, connection, auth, force, unprepared): + if unprepared and not force: + return if auth: # Assert session auth is supported by the protocol. # The Bolt implementation will try as hard as it can to make the @@ -312,7 +314,14 @@ async def _re_auth_connection(self, connection, auth, force): await connection.send_all() await connection.fetch_all() - async def _acquire(self, address, auth, deadline, liveness_check_timeout): + async def _acquire( + self, + address, + auth, + deadline, + liveness_check_timeout, + unprepared=False, + ): """ Acquire a connection to a given address from the pool. @@ -362,7 +371,7 @@ async def health_check(connection_, deadline_): ) try: await self._re_auth_connection( - connection, auth, force_auth + connection, auth, force_auth, unprepared ) except ConfigurationError: if auth: @@ -419,6 +428,7 @@ async def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): """ @@ -431,6 +441,9 @@ async def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param unprepared: If True, no messages will be pipelined on the + connection. Meant to be used if no work is to be executed on the + connection. :param database_callback: """ ... @@ -651,6 +664,7 @@ async def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): # The access_mode and database is not needed for a direct connection, @@ -663,7 +677,7 @@ async def acquire( ) deadline = Deadline.from_timeout_or_deadline(timeout) return await self._acquire( - self.address, auth, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout, unprepared ) @@ -1132,6 +1146,7 @@ async def acquire( bookmarks, auth: AcquisitionAuth | None, liveness_check_timeout, + unprepared=False, database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: @@ -1201,6 +1216,7 @@ async def wrapped_database_callback(new_database): auth, deadline, liveness_check_timeout, + unprepared, ) except (ServiceUnavailable, SessionExpired): await self.deactivate(address=address) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index dd7324abe..3d5b64a27 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -180,14 +180,16 @@ async def _result_error(self, error): async def _get_server_info(self): assert not self._connection - await self._connect(READ_ACCESS, liveness_check_timeout=0) + await self._connect( + READ_ACCESS, liveness_check_timeout=0, unprepared=True + ) server_info = self._connection.server_info await self._disconnect() return server_info async def _verify_authentication(self): assert not self._connection - await self._connect(READ_ACCESS, force_auth=True) + await self._connect(READ_ACCESS, force_auth=True, unprepared=True) await self._disconnect() @AsyncNonConcurrentMethodChecker._non_concurrent_method diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index c04c11659..48a148916 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -271,7 +271,9 @@ def connection_creator(): return connection_creator return None - def _re_auth_connection(self, connection, auth, force): + def _re_auth_connection(self, connection, auth, force, unprepared): + if unprepared and not force: + return if auth: # Assert session auth is supported by the protocol. # The Bolt implementation will try as hard as it can to make the @@ -309,7 +311,14 @@ def _re_auth_connection(self, connection, auth, force): connection.send_all() connection.fetch_all() - def _acquire(self, address, auth, deadline, liveness_check_timeout): + def _acquire( + self, + address, + auth, + deadline, + liveness_check_timeout, + unprepared=False, + ): """ Acquire a connection to a given address from the pool. @@ -359,7 +368,7 @@ def health_check(connection_, deadline_): ) try: self._re_auth_connection( - connection, auth, force_auth + connection, auth, force_auth, unprepared ) except ConfigurationError: if auth: @@ -416,6 +425,7 @@ def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): """ @@ -428,6 +438,9 @@ def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param unprepared: If True, no messages will be pipelined on the + connection. Meant to be used if no work is to be executed on the + connection. :param database_callback: """ ... @@ -648,6 +661,7 @@ def acquire( bookmarks, auth: AcquisitionAuth, liveness_check_timeout, + unprepared=False, database_callback=None, ): # The access_mode and database is not needed for a direct connection, @@ -660,7 +674,7 @@ def acquire( ) deadline = Deadline.from_timeout_or_deadline(timeout) return self._acquire( - self.address, auth, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout, unprepared ) @@ -1129,6 +1143,7 @@ def acquire( bookmarks, auth: AcquisitionAuth | None, liveness_check_timeout, + unprepared=False, database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: @@ -1198,6 +1213,7 @@ def wrapped_database_callback(new_database): auth, deadline, liveness_check_timeout, + unprepared, ) except (ServiceUnavailable, SessionExpired): self.deactivate(address=address) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 910fe328d..84d43fb38 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -180,14 +180,16 @@ def _result_error(self, error): def _get_server_info(self): assert not self._connection - self._connect(READ_ACCESS, liveness_check_timeout=0) + self._connect( + READ_ACCESS, liveness_check_timeout=0, unprepared=True + ) server_info = self._connection.server_info self._disconnect() return server_info def _verify_authentication(self): assert not self._connection - self._connect(READ_ACCESS, force_auth=True) + self._connect(READ_ACCESS, force_auth=True, unprepared=True) self._disconnect() @NonConcurrentMethodChecker._non_concurrent_method diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 47f6f813d..344bdb44b 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -70,10 +70,11 @@ async def acquire( bookmarks, auth, liveness_check_timeout, + unprepared=False, database_callback=None, ): return await self._acquire( - self.address, auth, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout, unprepared ) @@ -277,3 +278,28 @@ async def test_liveness_check( cx1.reset.reset_mock() await pool.release(cx1) cx1.reset.assert_not_called() + + +@pytest.mark.parametrize("unprepared", (True, False, None)) +@mark_async_test +async def test_reauth(async_fake_connection_generator, unprepared): + async with AsyncFakeBoltPool( + async_fake_connection_generator, + ("127.0.0.1", 7687), + ) as pool: + address = neo4j.Address(("127.0.0.1", 7687)) + # pre-populate pool + cx = await pool._acquire(address, None, Deadline(3), None) + await pool.release(cx) + cx.reset_mock() + + kwargs = {} + if unprepared is not None: + kwargs["unprepared"] = unprepared + cx = await pool._acquire(address, None, Deadline(3), None, **kwargs) + if unprepared: + cx.re_auth.assert_not_called() + else: + cx.re_auth.assert_called_once() + + await pool.release(cx) diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 06ba70461..743c7573a 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -877,3 +877,65 @@ async def resolve_db(): cache_spy.set.assert_called_once_with(key, resolved_db) assert session._pinned_database assert config.database == resolved_db + + +@pytest.mark.parametrize( + "method", ("_get_server_info", "_verify_authentication") +) +@mark_async_test +async def test_check_connections_are_unprepared_connection( + async_fake_pool, + method, +): + config = SessionConfig() + async with AsyncSession(async_fake_pool, config) as session: + await getattr(session, method)() + assert len(async_fake_pool.acquired_connection_mocks) == 1 + async_fake_pool.acquire.assert_awaited_once() + unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is True + + +async def _explicit_transaction(session: AsyncSession): + async with await session.begin_transaction(): + pass + + +async def _autocommit_transaction(session: AsyncSession): + await session.run("RETURN 1") + + +async def _tx_func_read(session: AsyncSession): + async def work(tx: AsyncManagedTransaction): + pass + + await session.execute_read(work) + + +async def _tx_func_write(session: AsyncSession): + async def work(tx: AsyncManagedTransaction): + pass + + await session.execute_write(work) + + +@pytest.mark.parametrize( + "method", + ( + _explicit_transaction, + _autocommit_transaction, + _tx_func_read, + _tx_func_write, + ), +) +@mark_async_test +async def test_work_connections_are_prepared_connection( + async_fake_pool, method +): + config = SessionConfig() + async with AsyncSession(async_fake_pool, config) as session: + await method(session) + assert len(async_fake_pool.acquired_connection_mocks) == 1 + async_fake_pool.acquire.assert_awaited_once() + unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is False or unprepared is None diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 96acf37e4..42e4d6c9e 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -70,10 +70,11 @@ def acquire( bookmarks, auth, liveness_check_timeout, + unprepared=False, database_callback=None, ): return self._acquire( - self.address, auth, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout, unprepared ) @@ -277,3 +278,28 @@ def test_liveness_check( cx1.reset.reset_mock() pool.release(cx1) cx1.reset.assert_not_called() + + +@pytest.mark.parametrize("unprepared", (True, False, None)) +@mark_sync_test +def test_reauth(fake_connection_generator, unprepared): + with FakeBoltPool( + fake_connection_generator, + ("127.0.0.1", 7687), + ) as pool: + address = neo4j.Address(("127.0.0.1", 7687)) + # pre-populate pool + cx = pool._acquire(address, None, Deadline(3), None) + pool.release(cx) + cx.reset_mock() + + kwargs = {} + if unprepared is not None: + kwargs["unprepared"] = unprepared + cx = pool._acquire(address, None, Deadline(3), None, **kwargs) + if unprepared: + cx.re_auth.assert_not_called() + else: + cx.re_auth.assert_called_once() + + pool.release(cx) diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 94543cde2..4897c086c 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -877,3 +877,65 @@ def resolve_db(): cache_spy.set.assert_called_once_with(key, resolved_db) assert session._pinned_database assert config.database == resolved_db + + +@pytest.mark.parametrize( + "method", ("_get_server_info", "_verify_authentication") +) +@mark_sync_test +def test_check_connections_are_unprepared_connection( + fake_pool, + method, +): + config = SessionConfig() + with Session(fake_pool, config) as session: + getattr(session, method)() + assert len(fake_pool.acquired_connection_mocks) == 1 + fake_pool.acquire.assert_called_once() + unprepared = fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is True + + +def _explicit_transaction(session: Session): + with session.begin_transaction(): + pass + + +def _autocommit_transaction(session: Session): + session.run("RETURN 1") + + +def _tx_func_read(session: Session): + def work(tx: ManagedTransaction): + pass + + session.execute_read(work) + + +def _tx_func_write(session: Session): + def work(tx: ManagedTransaction): + pass + + session.execute_write(work) + + +@pytest.mark.parametrize( + "method", + ( + _explicit_transaction, + _autocommit_transaction, + _tx_func_read, + _tx_func_write, + ), +) +@mark_sync_test +def test_work_connections_are_prepared_connection( + fake_pool, method +): + config = SessionConfig() + with Session(fake_pool, config) as session: + method(session) + assert len(fake_pool.acquired_connection_mocks) == 1 + fake_pool.acquire.assert_called_once() + unprepared = fake_pool.acquire.call_args.kwargs.get("unprepared") + assert unprepared is False or unprepared is None