Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/neo4j/_async_compat/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,14 @@ async def _acquire_non_blocking(self, me):
# Hence, we flag this task as cancelled again, so that the next
# `await` will raise the CancelledError.
asyncio.current_task().cancel()
if task.done() and task.exception() is None:
self._owner = me
self._count = 1
return True
if task.done():
exception = task.exception()
if exception is None:
self._owner = me
self._count = 1
return True
else:
raise exception
task.cancel()
return False

Expand Down
63 changes: 63 additions & 0 deletions tests/unit/mixed/async_compat/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,69 @@ async def waiter():
assert lock.locked() # waiter_non_blocking still owns it!


@pytest.mark.asyncio
async def test_async_r_lock_acquire_non_blocking_exception(mocker):
lock = AsyncRLock()
exc = RuntimeError("it broke!")

assert not lock.locked()

# Not sure asyncio.Lock.acquire is even fallible, but should it be or ever
# become, our AsyncRLock should re-raise the exception.
acquire_mock = mocker.patch("asyncio.Lock.acquire", autospec=True,
side_effect=asyncio.Lock.acquire)
awaits = 0

async def blocker():
nonlocal awaits

assert awaits == 0
awaits += 1
assert await lock.acquire()

assert awaits == 1
awaits += 1
await asyncio.sleep(0)

assert awaits == 3
awaits += 1
await asyncio.sleep(0)

assert awaits == 5
awaits += 1
await asyncio.sleep(0)

assert awaits == 7
lock.release()

async def waiter_non_blocking():
nonlocal awaits
nonlocal acquire_mock

assert awaits == 2
acquire_mock.side_effect = exc
coro = lock.acquire(blocking=False)
fut = asyncio.ensure_future(coro)
assert not fut.done()
awaits += 1
await asyncio.sleep(0)

assert awaits == 4
assert not fut.done()
awaits += 1
await asyncio.sleep(0)

assert awaits == 6
assert fut.done()
assert fut.exception() is exc
awaits += 1


assert not lock.locked()
await asyncio.gather(blocker(), waiter_non_blocking())
assert not lock.locked()


@pytest.mark.parametrize("waits", range(1, 10))
@pytest.mark.asyncio
async def test_async_r_lock_acquire_cancellation(waits):
Expand Down