Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 18 additions & 24 deletions tests/integration/mixed/test_async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_can_create_async_driver_outside_of_loop(uri, auth):
counter = 0
was_full = False

async def return_1(tx):
async def return_1(tx: neo4j.AsyncManagedTransaction) -> None:
nonlocal counter, was_full
res = await tx.run("RETURN 1")

Expand All @@ -58,30 +58,24 @@ async def return_1(tx):
await res.consume()
counter -= 1

async def run(driver: neo4j.AsyncDriver):
async with driver:
sessions = []
try:
for i in range(pool_size * 4):
sessions.append(driver.session())
work_loads = (session.execute_read(return_1)
for session in sessions)
await asyncio.gather(*work_loads)
finally:
cancelled = None
for session in sessions:
if not cancelled:
try:
await session.close()
except asyncio.CancelledError as e:
session.cancel()
cancelled = e
else:
session.cancel()
await driver.close()
if cancelled:
raise cancelled
async def session_handler(session: neo4j.AsyncSession) -> None:
nonlocal was_full
try:
async with session:
await session.execute_read(return_1)
except BaseException:
# if we failed, no need to make return_1 stall any longer
was_full = True
raise

async def run(driver_: neo4j.AsyncDriver):
async with driver_:
work_loads = (session_handler(driver_.session())
for _ in range(pool_size * 4))
res = await asyncio.gather(*work_loads, return_exceptions=True)
for r in res:
if isinstance(r, Exception):
raise r

driver = neo4j.AsyncGraphDatabase.driver(
uri, auth=auth, max_connection_pool_size=pool_size
Expand Down