Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/neo4j/_async/io/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,9 @@ async def connection_creator():
return connection_creator
return None

async def _re_auth_connection(self, connection, auth, force):
async def _re_auth_connection(self, connection, auth, force, unprepared):
if unprepared and not force:
return
if auth:
# Assert session auth is supported by the protocol.
# The Bolt implementation will try as hard as it can to make the
Expand Down Expand Up @@ -312,7 +314,14 @@ async def _re_auth_connection(self, connection, auth, force):
await connection.send_all()
await connection.fetch_all()

async def _acquire(self, address, auth, deadline, liveness_check_timeout):
async def _acquire(
self,
address,
auth,
deadline,
liveness_check_timeout,
unprepared=False,
):
"""
Acquire a connection to a given address from the pool.

Expand Down Expand Up @@ -362,7 +371,7 @@ async def health_check(connection_, deadline_):
)
try:
await self._re_auth_connection(
connection, auth, force_auth
connection, auth, force_auth, unprepared
)
except ConfigurationError:
if auth:
Expand Down Expand Up @@ -419,6 +428,7 @@ async def acquire(
bookmarks,
auth: AcquisitionAuth,
liveness_check_timeout,
unprepared=False,
database_callback=None,
):
"""
Expand All @@ -431,6 +441,9 @@ async def acquire(
:param bookmarks:
:param auth:
:param liveness_check_timeout:
:param unprepared: If True, no messages will be pipelined on the
connection. Meant to be used if no work is to be executed on the
connection.
:param database_callback:
"""
...
Expand Down Expand Up @@ -651,6 +664,7 @@ async def acquire(
bookmarks,
auth: AcquisitionAuth,
liveness_check_timeout,
unprepared=False,
database_callback=None,
):
# The access_mode and database is not needed for a direct connection,
Expand All @@ -663,7 +677,7 @@ async def acquire(
)
deadline = Deadline.from_timeout_or_deadline(timeout)
return await self._acquire(
self.address, auth, deadline, liveness_check_timeout
self.address, auth, deadline, liveness_check_timeout, unprepared
)


Expand Down Expand Up @@ -1132,6 +1146,7 @@ async def acquire(
bookmarks,
auth: AcquisitionAuth | None,
liveness_check_timeout,
unprepared=False,
database_callback=None,
):
if access_mode not in {WRITE_ACCESS, READ_ACCESS}:
Expand Down Expand Up @@ -1201,6 +1216,7 @@ async def wrapped_database_callback(new_database):
auth,
deadline,
liveness_check_timeout,
unprepared,
)
except (ServiceUnavailable, SessionExpired):
await self.deactivate(address=address)
Expand Down
6 changes: 4 additions & 2 deletions src/neo4j/_async/work/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,16 @@ async def _result_error(self, error):

async def _get_server_info(self):
assert not self._connection
await self._connect(READ_ACCESS, liveness_check_timeout=0)
await self._connect(
READ_ACCESS, liveness_check_timeout=0, unprepared=True
)
server_info = self._connection.server_info
await self._disconnect()
return server_info

async def _verify_authentication(self):
assert not self._connection
await self._connect(READ_ACCESS, force_auth=True)
await self._connect(READ_ACCESS, force_auth=True, unprepared=True)
await self._disconnect()

@AsyncNonConcurrentMethodChecker._non_concurrent_method
Expand Down
24 changes: 20 additions & 4 deletions src/neo4j/_sync/io/_pool.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions src/neo4j/_sync/work/session.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 27 additions & 1 deletion tests/unit/async_/io/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ async def acquire(
bookmarks,
auth,
liveness_check_timeout,
unprepared=False,
database_callback=None,
):
return await self._acquire(
self.address, auth, timeout, liveness_check_timeout
self.address, auth, timeout, liveness_check_timeout, unprepared
)


Expand Down Expand Up @@ -277,3 +278,28 @@ async def test_liveness_check(
cx1.reset.reset_mock()
await pool.release(cx1)
cx1.reset.assert_not_called()


@pytest.mark.parametrize("unprepared", (True, False, None))
@mark_async_test
async def test_reauth(async_fake_connection_generator, unprepared):
async with AsyncFakeBoltPool(
async_fake_connection_generator,
("127.0.0.1", 7687),
) as pool:
address = neo4j.Address(("127.0.0.1", 7687))
# pre-populate pool
cx = await pool._acquire(address, None, Deadline(3), None)
await pool.release(cx)
cx.reset_mock()

kwargs = {}
if unprepared is not None:
kwargs["unprepared"] = unprepared
cx = await pool._acquire(address, None, Deadline(3), None, **kwargs)
if unprepared:
cx.re_auth.assert_not_called()
else:
cx.re_auth.assert_called_once()

await pool.release(cx)
62 changes: 62 additions & 0 deletions tests/unit/async_/work/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,3 +877,65 @@ async def resolve_db():
cache_spy.set.assert_called_once_with(key, resolved_db)
assert session._pinned_database
assert config.database == resolved_db


@pytest.mark.parametrize(
"method", ("_get_server_info", "_verify_authentication")
)
@mark_async_test
async def test_check_connections_are_unprepared_connection(
async_fake_pool,
method,
):
config = SessionConfig()
async with AsyncSession(async_fake_pool, config) as session:
await getattr(session, method)()
assert len(async_fake_pool.acquired_connection_mocks) == 1
async_fake_pool.acquire.assert_awaited_once()
unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared")
assert unprepared is True


async def _explicit_transaction(session: AsyncSession):
async with await session.begin_transaction():
pass


async def _autocommit_transaction(session: AsyncSession):
await session.run("RETURN 1")


async def _tx_func_read(session: AsyncSession):
async def work(tx: AsyncManagedTransaction):
pass

await session.execute_read(work)


async def _tx_func_write(session: AsyncSession):
async def work(tx: AsyncManagedTransaction):
pass

await session.execute_write(work)


@pytest.mark.parametrize(
"method",
(
_explicit_transaction,
_autocommit_transaction,
_tx_func_read,
_tx_func_write,
),
)
@mark_async_test
async def test_work_connections_are_prepared_connection(
async_fake_pool, method
):
config = SessionConfig()
async with AsyncSession(async_fake_pool, config) as session:
await method(session)
assert len(async_fake_pool.acquired_connection_mocks) == 1
async_fake_pool.acquire.assert_awaited_once()
unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared")
assert unprepared is False or unprepared is None
28 changes: 27 additions & 1 deletion tests/unit/sync/io/test_direct.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading