From 8229c7f81919b441500093401fa9f5d06c746836 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 7 Jul 2023 13:08:23 +0200 Subject: [PATCH] AsyncRLock: don't swallow Lock.acquire errors --- src/neo4j/_async_compat/concurrency.py | 12 ++-- .../mixed/async_compat/test_concurrency.py | 63 +++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/neo4j/_async_compat/concurrency.py b/src/neo4j/_async_compat/concurrency.py index e0c1b78f6..3c4a5eae7 100644 --- a/src/neo4j/_async_compat/concurrency.py +++ b/src/neo4j/_async_compat/concurrency.py @@ -100,10 +100,14 @@ async def _acquire_non_blocking(self, me): # Hence, we flag this task as cancelled again, so that the next # `await` will raise the CancelledError. asyncio.current_task().cancel() - if task.done() and task.exception() is None: - self._owner = me - self._count = 1 - return True + if task.done(): + exception = task.exception() + if exception is None: + self._owner = me + self._count = 1 + return True + else: + raise exception task.cancel() return False diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py index cc2d340c5..ffe08652e 100644 --- a/tests/unit/mixed/async_compat/test_concurrency.py +++ b/tests/unit/mixed/async_compat/test_concurrency.py @@ -168,6 +168,69 @@ async def waiter(): assert lock.locked() # waiter_non_blocking still owns it! +@pytest.mark.asyncio +async def test_async_r_lock_acquire_non_blocking_exception(mocker): + lock = AsyncRLock() + exc = RuntimeError("it broke!") + + assert not lock.locked() + + # Not sure asyncio.Lock.acquire is even fallible, but should it be or ever + # become, our AsyncRLock should re-raise the exception. + acquire_mock = mocker.patch("asyncio.Lock.acquire", autospec=True, + side_effect=asyncio.Lock.acquire) + awaits = 0 + + async def blocker(): + nonlocal awaits + + assert awaits == 0 + awaits += 1 + assert await lock.acquire() + + assert awaits == 1 + awaits += 1 + await asyncio.sleep(0) + + assert awaits == 3 + awaits += 1 + await asyncio.sleep(0) + + assert awaits == 5 + awaits += 1 + await asyncio.sleep(0) + + assert awaits == 7 + lock.release() + + async def waiter_non_blocking(): + nonlocal awaits + nonlocal acquire_mock + + assert awaits == 2 + acquire_mock.side_effect = exc + coro = lock.acquire(blocking=False) + fut = asyncio.ensure_future(coro) + assert not fut.done() + awaits += 1 + await asyncio.sleep(0) + + assert awaits == 4 + assert not fut.done() + awaits += 1 + await asyncio.sleep(0) + + assert awaits == 6 + assert fut.done() + assert fut.exception() is exc + awaits += 1 + + + assert not lock.locked() + await asyncio.gather(blocker(), waiter_non_blocking()) + assert not lock.locked() + + @pytest.mark.parametrize("waits", range(1, 10)) @pytest.mark.asyncio async def test_async_r_lock_acquire_cancellation(waits):