From f4d6b99b11ef1c1d317a7084e166cca7080fda63 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 2 Jun 2022 08:22:37 +0200 Subject: [PATCH 1/6] Handle asyncio.CancelledError gracefully --- CHANGELOG.md | 7 +- docs/source/api.rst | 35 +-- docs/source/async_api.rst | 73 +++++- neo4j/_async/io/_bolt.py | 44 ++-- neo4j/_async/io/_common.py | 11 +- neo4j/_async/io/_pool.py | 113 ++++++--- neo4j/_async/work/session.py | 100 +++++++- neo4j/_async/work/transaction.py | 55 ++++- neo4j/_async/work/workspace.py | 13 ++ neo4j/_async_compat/concurrency.py | 203 ++++++++++++++-- neo4j/_async_compat/network/_bolt_socket.py | 38 ++- neo4j/_async_compat/util.py | 16 ++ neo4j/_sync/io/_bolt.py | 44 ++-- neo4j/_sync/io/_common.py | 11 +- neo4j/_sync/io/_pool.py | 91 ++++++-- neo4j/_sync/work/session.py | 100 +++++++- neo4j/_sync/work/transaction.py | 55 ++++- neo4j/_sync/work/workspace.py | 13 ++ neo4j/exceptions.py | 17 +- tests/conftest.py | 9 + .../async_/test_custom_ssl_context.py | 1 + tests/integration/conftest.py | 2 +- tests/integration/mixed/__init__.py | 16 ++ .../mixed/test_async_cancellation.py | 217 ++++++++++++++++++ .../sync/test_custom_ssl_context.py | 1 + .../mixed/async_compat/test_concurrency.py | 30 +++ 26 files changed, 1146 insertions(+), 169 deletions(-) create mode 100644 tests/integration/mixed/__init__.py create mode 100644 tests/integration/mixed/test_async_cancellation.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 12caefe0a..90fabcf62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,8 +87,9 @@ they have been removed. - Deprecated Nodes' and Relationships' `id` property (`int`) in favor of `element_id` (`str`). - This also affects `Graph` objects as `graph.nodes[...]` and - `graph.relationships[...]` now prefers strings over integers. + This also affects `Graph` objects as indexing `graph.nodes[...]` and + `graph.relationships[...]` with integers has been deprecated in favor of + indexing them with strings. 
- `ServerInfo.connection_id` has been deprecated and will be removed in a future release. There is no replacement as this is considered internal information. @@ -118,6 +119,8 @@ be used by client code. `Record` should be imported directly from `neo4j` instead. `neo4j.data.DataHydrator` and `neo4j.data.DataDeydrator` have been removed without replacement. +- Introduced `neo4j.exceptions.SessionError` that is raised when trying to + execute work on a closed or otherwise terminated session. ## Version 4.4 diff --git a/docs/source/api.rst b/docs/source/api.rst index 76b354e8f..247f843dc 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -511,6 +511,8 @@ To construct a :class:`neo4j.Session` use the :meth:`neo4j.Driver.session` metho Sessions will often be created and destroyed using a *with block context*. +This is the recommended approach as it takes care of closing the session +properly even when an exception is raised. .. code-block:: python @@ -536,6 +538,8 @@ Session .. automethod:: close + .. automethod:: closed + .. automethod:: run .. automethod:: last_bookmarks @@ -643,7 +647,7 @@ context of the impersonated user. For this, the user for which the .. Note:: The server or all servers of the cluster need to support impersonation when. - Otherwise, the driver will raise :py:exc:`.ConfigurationError` + Otherwise, the driver will raise :exc:`.ConfigurationError` as soon as it encounters a server that does not. @@ -708,7 +712,7 @@ Neo4j supports three kinds of transaction: + :ref:`explicit-transactions-ref` + :ref:`managed-transactions-ref` -Each has pros and cons but if in doubt, use a managed transaction with a `transaction function`. +Each has pros and cons but if in doubt, use a managed transaction with a *transaction function*. .. 
_auto-commit-transactions-ref: @@ -716,7 +720,7 @@ Each has pros and cons but if in doubt, use a managed transaction with a `transa Auto-commit Transactions ======================== Auto-commit transactions are the simplest form of transaction, available via -:py:meth:`neo4j.Session.run`. These are easy to use but support only one +:meth:`neo4j.Session.run`. These are easy to use but support only one statement per transaction and are not automatically retried on failure. Auto-commit transactions are also the only way to run ``PERIODIC COMMIT`` @@ -756,7 +760,7 @@ Example: Explicit Transactions ===================== -Explicit transactions support multiple statements and must be created with an explicit :py:meth:`neo4j.Session.begin_transaction` call. +Explicit transactions support multiple statements and must be created with an explicit :meth:`neo4j.Session.begin_transaction` call. This creates a new :class:`neo4j.Transaction` object that can be used to run Cypher. @@ -766,16 +770,16 @@ It also gives applications the ability to directly control ``commit`` and ``roll .. automethod:: run - .. automethod:: close - - .. automethod:: closed - .. automethod:: commit .. automethod:: rollback + .. automethod:: close + + .. automethod:: closed + Closing an explicit transaction can either happen automatically at the end of a ``with`` block, -or can be explicitly controlled through the :py:meth:`neo4j.Transaction.commit`, :py:meth:`neo4j.Transaction.rollback` or :py:meth:`neo4j.Transaction.close` methods. +or can be explicitly controlled through the :meth:`neo4j.Transaction.commit`, :meth:`neo4j.Transaction.rollback` or :meth:`neo4j.Transaction.close` methods. Explicit transactions are most useful for applications that need to distribute Cypher execution across multiple functions for the same transaction. 
@@ -811,8 +815,8 @@ Managed Transactions (`transaction functions`) ============================================== Transaction functions are the most powerful form of transaction, providing access mode override and retry capabilities. -+ :py:meth:`neo4j.Session.write_transaction` -+ :py:meth:`neo4j.Session.read_transaction` ++ :meth:`neo4j.Session.write_transaction` ++ :meth:`neo4j.Session.read_transaction` These allow a function object representing the transactional unit of work to be passed as a parameter. This function is called one or more times, within a configurable time limit, until it succeeds. @@ -912,8 +916,8 @@ Record .. autoclass:: neo4j.Record() A :class:`neo4j.Record` is an immutable ordered collection of key-value - pairs. It is generally closer to a :py:class:`namedtuple` than to an - :py:class:`OrderedDict` inasmuch as iteration of the collection will + pairs. It is generally closer to a :class:`namedtuple` than to an + :class:`OrderedDict` inasmuch as iteration of the collection will yield values rather than keys. .. describe:: Record(iterable) @@ -1313,6 +1317,8 @@ Client-side errors * :class:`neo4j.exceptions.DriverError` + * :class:`neo4j.exceptions.SessionError` + * :class:`neo4j.exceptions.TransactionError` * :class:`neo4j.exceptions.TransactionNestingError` @@ -1347,6 +1353,9 @@ Client-side errors .. autoclass:: neo4j.exceptions.DriverError :members: is_retryable +.. autoclass:: neo4j.exceptions.SessionError + :show-inheritance: + .. autoclass:: neo4j.exceptions.TransactionError :show-inheritance: diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 2929f4579..fbe21dbd8 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -287,6 +287,8 @@ To construct a :class:`neo4j.AsyncSession` use the :meth:`neo4j.AsyncDriver.sess Sessions will often be created and destroyed using a *with block context*. 
+This is the recommended approach as it takes care of closing the session +properly even when an exception is raised. .. code-block:: python @@ -313,6 +315,10 @@ AsyncSession .. automethod:: close + .. automethod:: cancel + + .. automethod:: closed + .. automethod:: run .. automethod:: last_bookmarks @@ -346,7 +352,7 @@ Neo4j supports three kinds of async transaction: + :ref:`async-explicit-transactions-ref` + :ref:`async-managed-transactions-ref` -Each has pros and cons but if in doubt, use a managed transaction with a `transaction function`. +Each has pros and cons but if in doubt, use a managed transaction with a *transaction function*. .. _async-auto-commit-transactions-ref: @@ -354,7 +360,7 @@ Each has pros and cons but if in doubt, use a managed transaction with a `transa Auto-commit Transactions ======================== Auto-commit transactions are the simplest form of transaction, available via -:py:meth:`neo4j.Session.run`. These are easy to use but support only one +:meth:`neo4j.Session.run`. These are easy to use but support only one statement per transaction and are not automatically retried on failure. Auto-commit transactions are also the only way to run ``PERIODIC COMMIT`` @@ -398,7 +404,7 @@ Example: Explicit Async Transactions =========================== -Explicit transactions support multiple statements and must be created with an explicit :py:meth:`neo4j.AsyncSession.begin_transaction` call. +Explicit transactions support multiple statements and must be created with an explicit :meth:`neo4j.AsyncSession.begin_transaction` call. This creates a new :class:`neo4j.AsyncTransaction` object that can be used to run Cypher. @@ -408,16 +414,18 @@ It also gives applications the ability to directly control ``commit`` and ``roll .. automethod:: run - .. automethod:: close - - .. automethod:: closed - .. automethod:: commit .. automethod:: rollback + .. automethod:: close + + .. automethod:: cancel + + .. 
automethod:: closed + Closing an explicit transaction can either happen automatically at the end of a ``async with`` block, -or can be explicitly controlled through the :py:meth:`neo4j.AsyncTransaction.commit`, :py:meth:`neo4j.AsyncTransaction.rollback` or :py:meth:`neo4j.AsyncTransaction.close` methods. +or can be explicitly controlled through the :meth:`neo4j.AsyncTransaction.commit`, :meth:`neo4j.AsyncTransaction.rollback`, :meth:`neo4j.AsyncTransaction.close` or :meth:`neo4j.AsyncTransaction.cancel` methods. Explicit transactions are most useful for applications that need to distribute Cypher execution across multiple functions for the same transaction. @@ -456,8 +464,8 @@ Managed Async Transactions (`transaction functions`) ==================================================== Transaction functions are the most powerful form of transaction, providing access mode override and retry capabilities. -+ :py:meth:`neo4j.AsyncSession.write_transaction` -+ :py:meth:`neo4j.AsyncSession.read_transaction` ++ :meth:`neo4j.AsyncSession.write_transaction` ++ :meth:`neo4j.AsyncSession.read_transaction` These allow a function object representing the transactional unit of work to be passed as a parameter. This function is called one or more times, within a configurable time limit, until it succeeds. @@ -531,3 +539,48 @@ A :class:`neo4j.AsyncResult` is attached to an active connection, through a :cla .. automethod:: closed See https://neo4j.com/docs/python-manual/current/cypher-workflow/#python-driver-type-mapping for more about type mapping. + + + +****************** +Async Cancellation +****************** + +Async Python provides a mechanism for cancelling futures +(:meth:`asyncio.Future.cancel`). The driver and its components can handle this. +However, generally, it's not advised to rely on cancellation as it forces the +driver to close affected connections to avoid leaving them in an undefined +state. This makes the driver less efficient. 
+ +The easiest way to make sure your application code's interaction with the driver +is playing nicely with cancellation is to always use the async context manager +provided by :class:`neo4j.AsyncSession` like so: :: + + async with driver.session() as session: + ... # do what you need to do with the session + +If, for whatever reason, you need to handle the session manually, you can do +it like so: :: + + session = driver.session() + try: + ... # do what you need to do with the session + except asyncio.CancelledError: + session.cancel() + raise + finally: + # this becomes a no-op if the session has been cancelled before + await session.close() + +As mentioned above, any cancellation of I/O work will cause the driver to close +the affected connection. This will kill any :class:`neo4j.AsyncTransaction` and +:class:`neo4j.AsyncResult` objects that are attached to that connection. Hence, +after catching a :class:`asyncio.CancelledError`, you should not try to use +transactions or results created earlier. They are likely to not be valid +anymore. + +Furthermore, there is no guarantee as to whether a piece of ongoing work got +successfully executed on the server side or not, when a cancellation happens: +``await transaction.commit()`` and other methods can throw +:exc:`asyncio.CancelledError` but still have managed to complete from the +server's perspective. 
diff --git a/neo4j/_async/io/_bolt.py b/neo4j/_async/io/_bolt.py index 7f31f32dd..c09ce42b8 100644 --- a/neo4j/_async/io/_bolt.py +++ b/neo4j/_async/io/_bolt.py @@ -18,6 +18,7 @@ import abc import asyncio +import socket from collections import deque from logging import getLogger from time import perf_counter @@ -370,8 +371,9 @@ def time_remaining(): await connection.hello() finally: connection.socket.set_deadline(None) - except Exception: - await connection.close_non_blocking() + except Exception as e: + log.debug("[#%04X] C: %r", s.getsockname()[1], e) + connection.kill() raise return connection @@ -678,15 +680,20 @@ async def _set_defunct_write(self, error=None, silent=False): async def _set_defunct(self, message, error=None, silent=False): from ._pool import AsyncBoltPool direct_driver = isinstance(self.pool, AsyncBoltPool) + user_cancelled = isinstance(error, asyncio.CancelledError) if error: - log.debug("[#%04X] %r", self.socket.getsockname()[1], error) - log.error(message) + log.debug("[#%04X] %r", self.local_port, error) + if not user_cancelled: + log.error(message) # We were attempting to receive data but the connection # has unexpectedly terminated. So, we need to close the # connection from the client side, and remove the address # from the connection pool. self._defunct = True + if user_cancelled: + self.kill() + raise error # cancellation error should not be re-written if not self._closing: # If we fail while closing the connection, there is no need to # remove the connection from the pool, nor to try to close the @@ -694,6 +701,7 @@ async def _set_defunct(self, message, error=None, silent=False): await self.close() if self.pool: await self.pool.deactivate(address=self.unresolved_address) + # Iterate through the outstanding responses, and if any correspond # to COMMIT requests then raise an error to signal that we are # unable to confirm that the COMMIT completed successfully. 
@@ -736,8 +744,9 @@ async def close(self): self.goodbye() try: await self._send_all() - except (OSError, BoltError, DriverError): - pass + except (OSError, BoltError, DriverError) as exc: + log.debug("[#%04X] ignoring failed close %r", + self.local_port, exc) log.debug("[#%04X] C: ", self.local_port) try: await self.socket.close() @@ -746,18 +755,19 @@ async def close(self): finally: self._closed = True - async def close_non_blocking(self): - """Set the socket to non-blocking and close it. - - This will try to send the `GOODBYE` message (given the socket is not - marked as defunct). However, should the write operation require - blocking (e.g., a full network buffer), then the socket will be closed - immediately (without `GOODBYE` message). - """ - if self._closed or self._closing: + def kill(self): + """Close the socket most violently. No flush, no goodbye, no mercy.""" + if self._closed: return - self.socket.settimeout(0) - await self.close() + log.debug("[#%04X] C: ", self.local_port) + self._closing = True + try: + self.socket.kill() + except OSError as exc: + log.debug("[#%04X] ignoring failed kill %r", + self.local_port, exc) + finally: + self._closed = True def closed(self): return self._closed diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py index 98744453a..74db2c31b 100644 --- a/neo4j/_async/io/_common.py +++ b/neo4j/_async/io/_common.py @@ -18,7 +18,6 @@ import asyncio import logging -import socket from struct import pack as struct_pack from ..._async_compat.util import AsyncUtil @@ -64,8 +63,9 @@ async def _buffer_one_chunk(self): if chunk_size == 0: # chunk_size was the end marker for the message return - - except (OSError, socket.timeout, SocketDeadlineExceeded) as error: + except ( + OSError, SocketDeadlineExceeded, asyncio.CancelledError + ) as error: self._broken = True await AsyncUtil.callback(self.on_error, error) raise @@ -139,7 +139,7 @@ async def flush(self): if data: try: await self.socket.sendall(data) - except OSError as 
error: + except (OSError, asyncio.CancelledError) as error: await self.on_error(error) return False self._clear() @@ -186,7 +186,8 @@ def outer_async(coroutine_func): async def inner(*args, **kwargs): try: await coroutine_func(*args, **kwargs) - except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + except (Neo4jError, ServiceUnavailable, SessionExpired, + asyncio.CancelledError) as exc: await AsyncUtil.callback(self.__on_error, exc) raise return inner diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 0a3bfac58..8da20d663 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -17,6 +17,7 @@ import abc +import asyncio import logging from collections import ( defaultdict, @@ -28,9 +29,11 @@ from ..._async_compat.concurrency import ( AsyncCondition, + AsyncCooperativeRLock, AsyncRLock, ) from ..._async_compat.network import AsyncNetworkUtil +from ..._async_compat.util import AsyncUtil from ..._conf import ( PoolConfig, WorkspaceConfig, @@ -78,7 +81,7 @@ def __init__(self, opener, pool_config, workspace_config): self.workspace_config = workspace_config self.connections = defaultdict(deque) self.connections_reservations = defaultdict(lambda: 0) - self.lock = AsyncRLock() + self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) async def __aenter__(self): @@ -88,7 +91,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): await self.close() async def _acquire_from_pool(self, address): - async with self.lock: + with self.lock: for connection in list(self.connections.get(address, [])): if connection.in_use: continue @@ -118,7 +121,7 @@ async def _acquire_from_pool_checked( connection.stale(), connection.in_use ) await connection.close() - async with self.lock: + with self.lock: try: self.connections.get(address, []).remove(connection) except ValueError: @@ -131,7 +134,7 @@ async def _acquire_from_pool_checked( else: return connection - async def _acquire_new_later(self, address, deadline): + def 
_acquire_new_later(self, address, deadline): async def connection_creator(): released_reservation = False try: @@ -144,20 +147,20 @@ async def connection_creator(): raise connection.pool = self connection.in_use = True - async with self.lock: + with self.lock: self.connections_reservations[address] -= 1 released_reservation = True self.connections[address].append(connection) return connection finally: if not released_reservation: - async with self.lock: + with self.lock: self.connections_reservations[address] -= 1 max_pool_size = self.pool_config.max_connection_pool_size infinite_pool_size = (max_pool_size < 0 or max_pool_size == float("inf")) - async with self.lock: + with self.lock: connections = self.connections[address] pool_size = (len(connections) + self.connections_reservations[address]) @@ -184,6 +187,8 @@ async def health_check(connection_, deadline_): if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): try: + log.debug("[#%04X] C: ", + connection_.local_port) await connection_.reset() except (OSError, ServiceUnavailable, SessionExpired): return False @@ -197,10 +202,8 @@ async def health_check(connection_, deadline_): if connection: return connection # all connections in pool are in-use - async with self.lock: - connection_creator = await self._acquire_new_later( - address, deadline - ) + with self.lock: + connection_creator = self._acquire_new_later(address, deadline) if connection_creator: break @@ -232,21 +235,55 @@ async def acquire( :param liveness_check_timeout: """ + def kill_and_release(self, *connections): + """ Release connections back into the pool after closing them. + + This method is thread safe. 
+ """ + for connection in connections: + if not (connection.defunct() + or connection.closed()): + log.debug( + "[#%04X] C: killing connection on release", + connection.local_port + ) + connection.kill() + with self.lock: + for connection in connections: + connection.in_use = False + self.cond.notify_all() + async def release(self, *connections): - """ Release a connection back into the pool. + """ Release connections back into the pool. + This method is thread safe. """ - async with self.lock: + cancelled = None + for connection in connections: + if not (connection.defunct() + or connection.closed() + or connection.is_reset): + if cancelled is not None: + log.debug( + "[#%04X] C: released unclean connection", + connection.local_port + ) + connection.kill() + continue + try: + log.debug( + "[#%04X] C: released unclean connection", + connection.local_port + ) + await connection.reset() + except (Neo4jError, DriverError, BoltError) as e: + log.debug("Failed to reset connection on release: %r", e) + except asyncio.CancelledError as e: + log.debug("Cancelled reset connection on release: %r", e) + cancelled = e + connection.kill() + with self.lock: for connection in connections: - if not (connection.defunct() - or connection.closed() - or connection.is_reset): - try: - await connection.reset() - except (Neo4jError, DriverError, BoltError) as e: - log.debug( - "Failed to reset connection on release: %s", e - ) connection.in_use = False self.cond.notify_all() @@ -262,16 +299,33 @@ def in_use_connection_count(self, address): return sum(1 if connection.in_use else 0 for connection in connections) async def mark_all_stale(self): - async with self.lock: + with self.lock: for address in self.connections: for connection in self.connections[address]: connection.set_stale() + @classmethod + async def _close_connections(cls, connections): + cancelled = None + for connection in connections: + if cancelled is not None: + connection.kill() + continue + try: + await connection.close() 
+ except asyncio.CancelledError as e: + # We've got cancelled: not time to gracefully close these + # connections. Time to burn down the place. + cancelled = e + connection.kill() + if cancelled is not None: + raise cancelled + async def deactivate(self, address): """ Deactivate an address from the connection pool, if present, closing all idle connection to that address """ - async with self.lock: + with self.lock: try: connections = self.connections[address] except KeyError: # already removed from the connection pool @@ -284,11 +338,11 @@ async def deactivate(self, address): # again. for conn in closable_connections: connections.remove(conn) - for conn in closable_connections: - await conn.close() if not self.connections[address]: del self.connections[address] + await self._close_connections(closable_connections) + def on_write_failure(self, address): raise WriteServiceUnavailable( "No write service available for pool {}".format(self) @@ -298,11 +352,14 @@ async def close(self): """ Close all connections and empty the pool. This method is thread safe. """ + log.debug("[#0000] C: close") try: - async with self.lock: + connections = [] + with self.lock: for address in list(self.connections): for connection in self.connections.pop(address, ()): - await connection.close() + connections.append(connection) + await self._close_connections(connections) except TypeError: pass diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index ee4b73acf..cc3de0dd4 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -14,13 +14,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - +import asyncio +from functools import wraps from logging import getLogger from random import random from time import perf_counter from ..._async_compat import async_sleep +from ..._async_compat.util import AsyncUtil from ..._conf import SessionConfig from ..._meta import ( deprecated, @@ -36,6 +37,7 @@ DriverError, Neo4jError, ServiceUnavailable, + SessionError, SessionExpired, TransactionError, ) @@ -88,12 +90,17 @@ def __init__(self, pool, session_config): super().__init__(pool, session_config) assert isinstance(session_config, SessionConfig) self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) + self._cancelled = False async def __aenter__(self): return self async def __aexit__(self, exception_type, exception_value, traceback): if exception_type: + if issubclass(exception_type, asyncio.CancelledError): + self._handle_cancellation(message="__aexit__") + self._closed = True + return self._state_failed = True await self.close() @@ -115,19 +122,45 @@ def _prepare_bookmarks(self, bookmarks): async def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode - await super()._connect(access_mode, **access_kwargs) + try: + await super()._connect(access_mode, **access_kwargs) + except asyncio.CancelledError: + self._handle_cancellation(message="_connect") + raise + + async def _disconnect(self, sync=False): + try: + return await super()._disconnect(sync=sync) + except asyncio.CancelledError: + self._handle_cancellation(message="_disconnect") + raise def _collect_bookmark(self, bookmark): if bookmark: self._bookmarks = bookmark, + def _handle_cancellation(self, message="General"): + self._cancelled = True + self._transaction = None + self._auto_result = None + connection = self._connection + self._connection = None + if connection: + log.debug("[#%04X] %s cancellation clean-up", + connection.local_port, message) + self._pool.kill_and_release(connection) + else: + log.debug("[#0000] %s 
cancellation clean-up", message) + async def _result_closed(self): if self._auto_result: self._collect_bookmark(self._auto_result._bookmark) self._auto_result = None await self._disconnect() - async def _result_error(self, _): + async def _result_error(self, error): + if isinstance(error, asyncio.CancelledError): + return self._handle_cancellation(message="_result_error") if self._auto_result: self._auto_result = None await self._disconnect() @@ -145,6 +178,7 @@ async def close(self): This will release any borrowed resources, such as connections, and will roll back any outstanding transactions. """ + # if self._closed or self._cancelled: if self._closed: return if self._connection: @@ -183,6 +217,29 @@ async def close(self): self._state_failed = False self._closed = True + if AsyncUtil.is_async_code: + def cancel(self): + """Cancel this session. + + If the session is already closed, this method does nothing. + Else, it will if present, forcefully close the connection the + session holds. This will violently kill all work in flight. + + The primary purpose of this function is to handle + :class:`asyncio.CancelledError`. + + :: + + session = await driver.session() + try: + ... # do some work + except asyncio.CancelledError: + session.cancel() + raise + + """ + self._handle_cancellation(message="manual cancel") + async def run(self, query, parameters=None, **kwargs): """Run a Cypher query within an auto-commit transaction. @@ -208,7 +265,10 @@ async def run(self, query, parameters=None, **kwargs): :param kwargs: additional keyword parameters :returns: a new :class:`neo4j.AsyncResult` object :rtype: AsyncResult + + :raises SessionError: if the session has been closed. 
""" + self._check_state() if not query: raise ValueError("Cannot run an empty query") if not isinstance(query, (str, Query)): @@ -319,11 +379,16 @@ async def _transaction_closed_handler(self): self._transaction = None await self._disconnect() - async def _transaction_error_handler(self, _): + async def _transaction_error_handler(self, error): if self._transaction: self._transaction = None await self._disconnect() + def _transaction_cancel_handler(self): + return self._handle_cancellation( + message="_transaction_cancel_handler" + ) + async def _open_transaction( self, *, tx_cls, access_mode, metadata=None, timeout=None ): @@ -331,7 +396,8 @@ async def _open_transaction( self._transaction = tx_cls( self._connection, self._config.fetch_size, self._transaction_closed_handler, - self._transaction_error_handler + self._transaction_error_handler, + self._transaction_cancel_handler ) await self._transaction._begin( self._config.database, self._config.impersonated_user, @@ -363,8 +429,10 @@ async def begin_transaction(self, metadata=None, timeout=None): :returns: A new transaction instance. :rtype: AsyncTransaction - :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open. + :raises TransactionError: if a transaction is already open. + :raises SessionError: if the session has been closed. 
""" + self._check_state() # TODO: Implement TransactionConfig consumption if self._auto_result: @@ -384,6 +452,7 @@ async def begin_transaction(self, metadata=None, timeout=None): async def _run_transaction( self, access_mode, transaction_function, *args, **kwargs ): + self._check_state() if not callable(transaction_function): raise TypeError("Unit of work is not callable") @@ -410,6 +479,13 @@ async def _run_transaction( tx = self._transaction try: result = await transaction_function(tx, *args, **kwargs) + except asyncio.CancelledError: + # if cancellation callback has not been called yet: + if self._transaction is not None: + self._handle_cancellation( + message="transaction function" + ) + raise except Exception: await tx._close() raise @@ -431,7 +507,11 @@ async def _run_transaction( delay = next(retry_delay) log.warning("Transaction failed and will be retried in {}s ({})" "".format(delay, "; ".join(errors[-1].args))) - await async_sleep(delay) + try: + await async_sleep(delay) + except asyncio.CancelledError: + log.debug("[#0000] Retry cancelled") + raise if errors: raise errors[-1] @@ -488,6 +568,8 @@ async def get_two_tx(tx): :param args: arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` :return: a result as returned by the given unit of work + + :raises SessionError: if the session has been closed. """ return await self._run_transaction( READ_ACCESS, transaction_function, *args, **kwargs @@ -523,6 +605,8 @@ async def create_node_tx(tx, name): :param args: key word arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` :return: a result as returned by the given unit of work + + :raises SessionError: if the session has been closed. 
""" return await self._run_transaction( WRITE_ACCESS, transaction_function, *args, **kwargs diff --git a/neo4j/_async/work/transaction.py b/neo4j/_async/work/transaction.py index 57000d125..a6aadc942 100644 --- a/neo4j/_async/work/transaction.py +++ b/neo4j/_async/work/transaction.py @@ -16,6 +16,7 @@ # limitations under the License. +import asyncio from functools import wraps from ..._async_compat.util import AsyncUtil @@ -29,7 +30,8 @@ class _AsyncTransactionBase: - def __init__(self, connection, fetch_size, on_closed, on_error): + def __init__(self, connection, fetch_size, on_closed, on_error, + on_cancel): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( connection, self._error_handler @@ -41,6 +43,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._fetch_size = fetch_size self._on_closed = on_closed self._on_error = on_error + self._on_cancel = on_cancel async def _enter(self): return self @@ -51,6 +54,9 @@ async def _exit(self, exception_type, exception_value, traceback): success = not bool(exception_type) if success: await self._commit() + elif issubclass(exception_type, asyncio.CancelledError): + self._cancel() + return await self._close() async def _begin( @@ -68,6 +74,9 @@ async def _result_on_closed_handler(self): async def _error_handler(self, exc): self._last_error = exc + if isinstance(exc, asyncio.CancelledError): + self._cancel() + return await AsyncUtil.callback(self._on_error, exc) async def _consume_results(self): @@ -150,6 +159,9 @@ async def _commit(self): await self._connection.send_all() await self._connection.fetch_all() self._bookmark = metadata.get("bookmark") + except asyncio.CancelledError: + self._on_cancel() + raise finally: self._closed_flag = True await AsyncUtil.callback(self._on_closed) @@ -174,6 +186,9 @@ async def _rollback(self): self._connection.rollback(on_success=metadata.update) await self._connection.send_all() await self._connection.fetch_all() + except 
asyncio.CancelledError: + self._on_cancel() + raise finally: self._closed_flag = True await AsyncUtil.callback(self._on_closed) @@ -185,10 +200,39 @@ async def _close(self): return await self._rollback() + if AsyncUtil.is_async_code: + def _cancel(self): + """Cancel this transaction. + + If the transaction is already closed, this method does nothing. + Else, it will close the connection without ROLLBACK or COMMIT in + a non-blocking manner. + + The primary purpose of this function is to handle + :class:`asyncio.CancelledError`. + + :: + + tx = await session.begin_transaction() + try: + ... # do some work + except asyncio.CancelledError: + tx.cancel() + raise + + """ + if self._closed_flag: + return + try: + self._on_cancel() + finally: + self._closed_flag = True + def _closed(self): - """Indicator to show whether the transaction has been closed. + """Indicate whether the transaction has been closed or cancelled. - :return: :const:`True` if closed, :const:`False` otherwise. + :return: + :const:`True` if closed or cancelled, :const:`False` otherwise. :rtype: bool """ return self._closed_flag @@ -229,6 +273,11 @@ async def close(self): def closed(self): return self._closed() + if AsyncUtil.is_async_code: + @wraps(_AsyncTransactionBase._cancel) + def cancel(self): + return self._cancel() + class AsyncManagedTransaction(_AsyncTransactionBase): """Transaction object provided to transaction functions. diff --git a/neo4j/_async/work/workspace.py b/neo4j/_async/work/workspace.py index 9c589db57..b7e219736 100644 --- a/neo4j/_async/work/workspace.py +++ b/neo4j/_async/work/workspace.py @@ -26,6 +26,7 @@ ) from ...exceptions import ( ServiceUnavailable, + SessionError, SessionExpired, ) from ..io import AsyncNeo4jPool @@ -130,3 +131,15 @@ async def close(self): return await self._disconnect(sync=True) self._closed = True + + def closed(self): + """Indicate whether the session has been closed. + + :return: :const:`True` if closed, :const:`False` otherwise. 
+ :rtype: bool + """ + return self._closed + + def _check_state(self): + if self._closed: + raise SessionError(self, "Session closed") diff --git a/neo4j/_async_compat/concurrency.py b/neo4j/_async_compat/concurrency.py index feeb2371f..4e5150dc9 100644 --- a/neo4j/_async_compat/concurrency.py +++ b/neo4j/_async_compat/concurrency.py @@ -21,12 +21,18 @@ import re import threading +from neo4j._async_compat.util import AsyncUtil + __all__ = [ "AsyncCondition", + "AsyncCooperativeLock", + "AsyncCooperativeRLock", "AsyncLock", "AsyncRLock", "Condition", + "CooperativeLock", + "CooperativeRLock", "Lock", "RLock", ] @@ -79,7 +85,14 @@ async def _acquire_non_blocking(self, me): task = asyncio.ensure_future(acquire_coro) # yielding one cycle is as close to non-blocking as it gets # (at least without implementing the lock from the ground up) - await asyncio.sleep(0) + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + # This is emulating non-blocking. There is no cancelling this! + # Still, we don't want to silently swallow the cancellation. + # Hence, we flag this task as cancelled again, so that the next + # `await` will raise the CancelledError. + asyncio.current_task().cancel() if task.done() and task.exception() is None: self._owner = me self._count = 1 @@ -108,7 +121,16 @@ async def acquire(self, blocking=True, timeout=-1): await self._acquire(me) return True try: - await asyncio.wait_for(self._acquire(me), timeout) + fut = asyncio.ensure_future(self._acquire(me)) + try: + await asyncio.wait_for(fut, timeout) + except asyncio.CancelledError: + if (fut.done() + and not fut.cancelled() + and fut.exception() is None): + # too late to cancel the acquisition + self._release(me) + raise return True except asyncio.TimeoutError: return False @@ -134,6 +156,131 @@ async def __aexit__(self, t, v, tb): self.release() +class AsyncCooperativeLock: + """Lock placeholder for asyncio Python when working fully cooperatively. 
+ + This lock doesn't do anything in async Python. Its threaded counterpart, + however, is an ordinary :class:`threading.Lock`. + The AsyncCooperativeLock only works if there is no await being used + while the lock is acquired. + """ + + def __init__(self): + self._locked = False + + def __repr__(self): + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + return f"<{res[1:-1]} [{extra}]>" + + def locked(self): + """Return True if lock is acquired.""" + return self._locked + + def acquire(self): + """Acquire a lock. + + This method will raise a RuntimeError where an ordinary + (non-placeholder) lock would need to block. I.e., when the lock is + already taken. + + Returns True if the lock was successfully acquired. + """ + if self._locked: + raise RuntimeError("Cannot acquire a locked cooperative lock.") + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + else: + raise RuntimeError("Lock is not acquired.") + + __enter__ = acquire + + def __exit__(self, t, v, tb): + self.release() + + +class AsyncCooperativeRLock: + """Reentrant lock placeholder for cooperative asyncio Python. + + This lock doesn't do anything in async Python. Its threaded counterpart, + however, is an ordinary :class:`threading.RLock`. + The AsyncCooperativeRLock only works if there is no await being used + while the lock is acquired.
+ """ + + def __init__(self): + self._owner = None + self._count = 0 + + def __repr__(self): + res = super().__repr__() + if self._owner is not None: + extra = f"locked {self._count} times by owner:{self._owner}" + else: + extra = "unlocked" + return f'<{res[1:-1]} [{extra}]>' + + def locked(self): + """Return True if lock is acquired.""" + return self._owner is not None + + def acquire(self): + """Acquire a lock. + + This method will raise a RuntimeError where an ordinary + (non-placeholder) lock would need to block. I.e., when the lock is + already taken by another Task. + + Returns True if the lock was successfully acquired. + """ + me = asyncio.current_task() + if self._owner is None: + self._owner = me + self._count = 1 + return True + if self._owner is me: + self._count += 1 + return True + raise RuntimeError( + "Cannot acquire a foreign locked cooperative lock." + ) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + + When invoked on an unlocked or foreign lock, a RuntimeError is raised. + + There is no return value. + """ + me = asyncio.current_task() + if self._owner is None: + raise RuntimeError("Lock is not acquired.") + if self._owner is not me: + raise RuntimeError("Cannot release a foreign lock.") + self._count -= 1 + if not self._count: + self._owner = None + + __enter__ = acquire + + def __exit__(self, t, v, tb): + self.release() + + # copied and modified from asyncio.locks (3.7) # to add support for `.wait(timeout)` class AsyncCondition: @@ -168,7 +315,11 @@ def __init__(self, lock=None, *, loop=None): self._waiters = collections.deque() async def __aenter__(self): - await self.acquire() + if isinstance(self._lock, (AsyncCooperativeLock, + AsyncCooperativeRLock)): + self._lock.acquire() + else: + await self.acquire() # We have no use for the "as ..." clause in the with # statement for locks. 
return None @@ -183,7 +334,7 @@ def __repr__(self): extra = f'{extra}, waiters:{len(self._waiters)}' return f'<{res[1:-1]} [{extra}]>' - async def _wait(self, me=None): + async def _wait(self, timeout=None, me=None): """Wait until notified. If the calling coroutine has not acquired the lock when this @@ -197,6 +348,7 @@ async def _wait(self, me=None): if not self.locked(): raise RuntimeError('cannot wait on un-acquired lock') + cancelled = False if isinstance(self._lock, AsyncRLock): self._lock._release(me) else: @@ -205,36 +357,37 @@ async def _wait(self, me=None): fut = self._loop.create_future() self._waiters.append(fut) try: - await fut + await asyncio.wait_for(fut, timeout) return True + except asyncio.TimeoutError: + return False + except asyncio.CancelledError: + cancelled = True + raise finally: self._waiters.remove(fut) finally: # Must reacquire lock even if wait is cancelled - cancelled = False - while True: - try: - if isinstance(self._lock, AsyncRLock): - await self._lock._acquire(me) - else: - await self._lock.acquire() - break - except asyncio.CancelledError: - cancelled = True - + if isinstance(self._lock, (AsyncCooperativeLock, + AsyncCooperativeRLock)): + self._lock.acquire() + else: + while True: + try: + if isinstance(self._lock, AsyncRLock): + await self._lock._acquire(me) + else: + await self._lock.acquire() + break + except asyncio.CancelledError: + cancelled = True if cancelled: raise asyncio.CancelledError async def wait(self, timeout=None): - if not timeout: - return await self._wait() me = asyncio.current_task() - try: - await asyncio.wait_for(self._wait(me), timeout) - return True - except asyncio.TimeoutError: - return False + return await self._wait(timeout=timeout, me=me) def notify(self, n=1): """By default, wake up one coroutine waiting on this condition, if any. 
@@ -270,5 +423,5 @@ def notify_all(self): Condition = threading.Condition -Lock = threading.Lock -RLock = threading.RLock +CooperativeLock = Lock = threading.Lock +CooperativeRLock = RLock = threading.RLock diff --git a/neo4j/_async_compat/network/_bolt_socket.py b/neo4j/_async_compat/network/_bolt_socket.py index 475536bc6..33d6f76b1 100644 --- a/neo4j/_async_compat/network/_bolt_socket.py +++ b/neo4j/_async_compat/network/_bolt_socket.py @@ -49,6 +49,7 @@ DriverError, ServiceUnavailable, ) +from ..util import AsyncUtil from ._util import ( AsyncNetworkUtil, NetworkUtil, @@ -94,7 +95,14 @@ async def _wait_for_io(self, io_fut): if timeout is not None and timeout <= 0: # give the io-operation time for one loop cycle to do its thing io_fut = asyncio.create_task(io_fut) - await asyncio.sleep(0) + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + # This is emulating non-blocking. There is no cancelling this! + # Still, we don't want to silently swallow the cancellation. + # Hence, we flag this task as cancelled again, so that the next + # `await` will raise the CancelledError. 
+ asyncio.current_task().cancel() try: return await asyncio.wait_for(io_fut, timeout) except asyncio.TimeoutError as e: @@ -150,6 +158,9 @@ async def close(self): self._writer.close() await self._writer.wait_closed() + def kill(self): + self._writer.close() + @classmethod async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): """ @@ -225,6 +236,12 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): raise ServiceUnavailable( "Timed out trying to establish connection to {!r}".format( resolved_address)) + except asyncio.CancelledError: + log.debug("[#0000] C: %s", resolved_address) + log.debug("[#0000] C: %s", resolved_address) + if s: + await cls.close_socket(s) + raise except (SSLError, CertificateError) as error: local_port = s.getsockname()[1] raise BoltSecurityError( @@ -314,7 +331,7 @@ async def close_socket(cls, socket_): else: socket_.shutdown(SHUT_RDWR) socket_.close() - except OSError: + except (OSError, asyncio.CancelledError): pass @classmethod @@ -347,12 +364,22 @@ async def connect(cls, address, *, timeout, custom_resolver, ssl_context, err_str = error.__class__.__name__ if str(error): err_str += ": " + str(error) - log.debug("[#%04X] C: %s", local_port, - err_str) + log.debug("[#%04X] C: %s %s", local_port, + resolved_address, err_str) if s: await cls.close_socket(s) errors.append(error) failed_addresses.append(resolved_address) + except asyncio.CancelledError: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError, TypeError): + local_port = 0 + log.debug("[#%04X] C: %s", local_port, + resolved_address) + if s: + await cls.close_socket(s) + raise except Exception: if s: await cls.close_socket(s) @@ -430,6 +457,9 @@ def close(self): self._socket.shutdown(SHUT_RDWR) self._socket.close() + def kill(self): + self._socket.close() + @classmethod def _connect(cls, resolved_address, timeout, keep_alive): """ diff --git a/neo4j/_async_compat/util.py b/neo4j/_async_compat/util.py index 
fd510d02b..f8c116556 100644 --- a/neo4j/_async_compat/util.py +++ b/neo4j/_async_compat/util.py @@ -16,7 +16,9 @@ # limitations under the License. +import asyncio import inspect +from functools import wraps from .._meta import experimental @@ -51,6 +53,16 @@ async def callback(cb, *args, **kwargs): experimental_async = experimental + @staticmethod + def shielded(coro_function): + assert asyncio.iscoroutinefunction(coro_function) + + @wraps(coro_function) + async def shielded_function(*args, **kwargs): + return await asyncio.shield(coro_function(*args, **kwargs)) + + return shielded_function + is_async_code = True @@ -70,4 +82,8 @@ def f_(f): return f return f_ + @staticmethod + def shielded(coro_function): + return coro_function + is_async_code = False diff --git a/neo4j/_sync/io/_bolt.py b/neo4j/_sync/io/_bolt.py index b7f9ecd88..b04820a06 100644 --- a/neo4j/_sync/io/_bolt.py +++ b/neo4j/_sync/io/_bolt.py @@ -18,6 +18,7 @@ import abc import asyncio +import socket from collections import deque from logging import getLogger from time import perf_counter @@ -370,8 +371,9 @@ def time_remaining(): connection.hello() finally: connection.socket.set_deadline(None) - except Exception: - connection.close_non_blocking() + except Exception as e: + log.debug("[#%04X] C: %r", s.getsockname()[1], e) + connection.kill() raise return connection @@ -678,15 +680,20 @@ def _set_defunct_write(self, error=None, silent=False): def _set_defunct(self, message, error=None, silent=False): from ._pool import BoltPool direct_driver = isinstance(self.pool, BoltPool) + user_cancelled = isinstance(error, asyncio.CancelledError) if error: - log.debug("[#%04X] %r", self.socket.getsockname()[1], error) - log.error(message) + log.debug("[#%04X] %r", self.local_port, error) + if not user_cancelled: + log.error(message) # We were attempting to receive data but the connection # has unexpectedly terminated. 
So, we need to close the # connection from the client side, and remove the address # from the connection pool. self._defunct = True + if user_cancelled: + self.kill() + raise error # cancellation error should not be re-written if not self._closing: # If we fail while closing the connection, there is no need to # remove the connection from the pool, nor to try to close the @@ -694,6 +701,7 @@ def _set_defunct(self, message, error=None, silent=False): self.close() if self.pool: self.pool.deactivate(address=self.unresolved_address) + # Iterate through the outstanding responses, and if any correspond # to COMMIT requests then raise an error to signal that we are # unable to confirm that the COMMIT completed successfully. @@ -736,8 +744,9 @@ def close(self): self.goodbye() try: self._send_all() - except (OSError, BoltError, DriverError): - pass + except (OSError, BoltError, DriverError) as exc: + log.debug("[#%04X] ignoring failed close %r", + self.local_port, exc) log.debug("[#%04X] C: ", self.local_port) try: self.socket.close() @@ -746,18 +755,19 @@ def close(self): finally: self._closed = True - def close_non_blocking(self): - """Set the socket to non-blocking and close it. - - This will try to send the `GOODBYE` message (given the socket is not - marked as defunct). However, should the write operation require - blocking (e.g., a full network buffer), then the socket will be closed - immediately (without `GOODBYE` message). - """ - if self._closed or self._closing: + def kill(self): + """Close the socket most violently. 
No flush, no goodbye, no mercy.""" + if self._closed: return - self.socket.settimeout(0) - self.close() + log.debug("[#%04X] C: ", self.local_port) + self._closing = True + try: + self.socket.kill() + except OSError as exc: + log.debug("[#%04X] ignoring failed kill %r", + self.local_port, exc) + finally: + self._closed = True def closed(self): return self._closed diff --git a/neo4j/_sync/io/_common.py b/neo4j/_sync/io/_common.py index 1eea4ba34..0c437b91b 100644 --- a/neo4j/_sync/io/_common.py +++ b/neo4j/_sync/io/_common.py @@ -18,7 +18,6 @@ import asyncio import logging -import socket from struct import pack as struct_pack from ..._async_compat.util import Util @@ -64,8 +63,9 @@ def _buffer_one_chunk(self): if chunk_size == 0: # chunk_size was the end marker for the message return - - except (OSError, socket.timeout, SocketDeadlineExceeded) as error: + except ( + OSError, SocketDeadlineExceeded, asyncio.CancelledError + ) as error: self._broken = True Util.callback(self.on_error, error) raise @@ -139,7 +139,7 @@ def flush(self): if data: try: self.socket.sendall(data) - except OSError as error: + except (OSError, asyncio.CancelledError) as error: self.on_error(error) return False self._clear() @@ -186,7 +186,8 @@ def outer_async(coroutine_func): def inner(*args, **kwargs): try: coroutine_func(*args, **kwargs) - except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + except (Neo4jError, ServiceUnavailable, SessionExpired, + asyncio.CancelledError) as exc: Util.callback(self.__on_error, exc) raise return inner diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index 3cd66a6d3..012587c6c 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -17,6 +17,7 @@ import abc +import asyncio import logging from collections import ( defaultdict, @@ -28,9 +29,11 @@ from ..._async_compat.concurrency import ( Condition, + CooperativeRLock, RLock, ) from ..._async_compat.network import NetworkUtil +from ..._async_compat.util import Util from 
..._conf import ( PoolConfig, WorkspaceConfig, @@ -78,7 +81,7 @@ def __init__(self, opener, pool_config, workspace_config): self.workspace_config = workspace_config self.connections = defaultdict(deque) self.connections_reservations = defaultdict(lambda: 0) - self.lock = RLock() + self.lock = CooperativeRLock() self.cond = Condition(self.lock) def __enter__(self): @@ -184,6 +187,8 @@ def health_check(connection_, deadline_): if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): try: + log.debug("[#%04X] C: ", + connection_.local_port) connection_.reset() except (OSError, ServiceUnavailable, SessionExpired): return False @@ -198,9 +203,7 @@ def health_check(connection_, deadline_): return connection # all connections in pool are in-use with self.lock: - connection_creator = self._acquire_new_later( - address, deadline - ) + connection_creator = self._acquire_new_later(address, deadline) if connection_creator: break @@ -232,21 +235,55 @@ def acquire( :param liveness_check_timeout: """ + def kill_and_release(self, *connections): + """ Release connections back into the pool after closing them. + + This method is thread safe. + """ + for connection in connections: + if not (connection.defunct() + or connection.closed()): + log.debug( + "[#%04X] C: killing connection on release", + connection.local_port + ) + connection.kill() + with self.lock: + for connection in connections: + connection.in_use = False + self.cond.notify_all() + def release(self, *connections): - """ Release a connection back into the pool. + """ Release connections back into the pool. + This method is thread safe. 
""" + cancelled = None + for connection in connections: + if not (connection.defunct() + or connection.closed() + or connection.is_reset): + if cancelled is not None: + log.debug( + "[#%04X] C: released unclean connection", + connection.local_port + ) + connection.kill() + continue + try: + log.debug( + "[#%04X] C: released unclean connection", + connection.local_port + ) + connection.reset() + except (Neo4jError, DriverError, BoltError) as e: + log.debug("Failed to reset connection on release: %r", e) + except asyncio.CancelledError as e: + log.debug("Cancelled reset connection on release: %r", e) + cancelled = e + connection.kill() with self.lock: for connection in connections: - if not (connection.defunct() - or connection.closed() - or connection.is_reset): - try: - connection.reset() - except (Neo4jError, DriverError, BoltError) as e: - log.debug( - "Failed to reset connection on release: %s", e - ) connection.in_use = False self.cond.notify_all() @@ -267,6 +304,23 @@ def mark_all_stale(self): for connection in self.connections[address]: connection.set_stale() + @classmethod + def _close_connections(cls, connections): + cancelled = None + for connection in connections: + if cancelled is not None: + connection.kill() + continue + try: + connection.close() + except asyncio.CancelledError as e: + # We've got cancelled: not time to gracefully close these + # connections. Time to burn down the place. + cancelled = e + connection.kill() + if cancelled is not None: + raise cancelled + def deactivate(self, address): """ Deactivate an address from the connection pool, if present, closing all idle connection to that address @@ -284,11 +338,11 @@ def deactivate(self, address): # again. 
for conn in closable_connections: connections.remove(conn) - for conn in closable_connections: - conn.close() if not self.connections[address]: del self.connections[address] + self._close_connections(closable_connections) + def on_write_failure(self, address): raise WriteServiceUnavailable( "No write service available for pool {}".format(self) @@ -298,11 +352,14 @@ def close(self): """ Close all connections and empty the pool. This method is thread safe. """ + log.debug("[#0000] C: close") try: + connections = [] with self.lock: for address in list(self.connections): for connection in self.connections.pop(address, ()): - connection.close() + connections.append(connection) + self._close_connections(connections) except TypeError: pass diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index 87e103d61..8d34e7024 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -14,13 +14,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - +import asyncio +from functools import wraps from logging import getLogger from random import random from time import perf_counter from ..._async_compat import sleep +from ..._async_compat.util import Util from ..._conf import SessionConfig from ..._meta import ( deprecated, @@ -36,6 +37,7 @@ DriverError, Neo4jError, ServiceUnavailable, + SessionError, SessionExpired, TransactionError, ) @@ -88,12 +90,17 @@ def __init__(self, pool, session_config): super().__init__(pool, session_config) assert isinstance(session_config, SessionConfig) self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) + self._cancelled = False def __enter__(self): return self def __exit__(self, exception_type, exception_value, traceback): if exception_type: + if issubclass(exception_type, asyncio.CancelledError): + self._handle_cancellation(message="__aexit__") + self._closed = True + return self._state_failed = True self.close() @@ -115,19 +122,45 @@ def _prepare_bookmarks(self, bookmarks): def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode - super()._connect(access_mode, **access_kwargs) + try: + super()._connect(access_mode, **access_kwargs) + except asyncio.CancelledError: + self._handle_cancellation(message="_connect") + raise + + def _disconnect(self, sync=False): + try: + return super()._disconnect(sync=sync) + except asyncio.CancelledError: + self._handle_cancellation(message="_disconnect") + raise def _collect_bookmark(self, bookmark): if bookmark: self._bookmarks = bookmark, + def _handle_cancellation(self, message="General"): + self._cancelled = True + self._transaction = None + self._auto_result = None + connection = self._connection + self._connection = None + if connection: + log.debug("[#%04X] %s cancellation clean-up", + connection.local_port, message) + self._pool.kill_and_release(connection) + else: + log.debug("[#0000] %s cancellation clean-up", message) + def _result_closed(self): if 
self._auto_result: self._collect_bookmark(self._auto_result._bookmark) self._auto_result = None self._disconnect() - def _result_error(self, _): + def _result_error(self, error): + if isinstance(error, asyncio.CancelledError): + return self._handle_cancellation(message="_result_error") if self._auto_result: self._auto_result = None self._disconnect() @@ -145,6 +178,7 @@ def close(self): This will release any borrowed resources, such as connections, and will roll back any outstanding transactions. """ + # if self._closed or self._cancelled: if self._closed: return if self._connection: @@ -183,6 +217,29 @@ def close(self): self._state_failed = False self._closed = True + if Util.is_async_code: + def cancel(self): + """Cancel this session. + + If the session is already closed, this method does nothing. + Else, it will if present, forcefully close the connection the + session holds. This will violently kill all work in flight. + + The primary purpose of this function is to handle + :class:`asyncio.CancelledError`. + + :: + + session = driver.session() + try: + ... # do some work + except asyncio.CancelledError: + session.cancel() + raise + + """ + self._handle_cancellation(message="manual cancel") + def run(self, query, parameters=None, **kwargs): """Run a Cypher query within an auto-commit transaction. @@ -208,7 +265,10 @@ def run(self, query, parameters=None, **kwargs): :param kwargs: additional keyword parameters :returns: a new :class:`neo4j.Result` object :rtype: Result + + :raises SessionError: if the session has been closed. 
""" + self._check_state() if not query: raise ValueError("Cannot run an empty query") if not isinstance(query, (str, Query)): @@ -319,11 +379,16 @@ def _transaction_closed_handler(self): self._transaction = None self._disconnect() - def _transaction_error_handler(self, _): + def _transaction_error_handler(self, error): if self._transaction: self._transaction = None self._disconnect() + def _transaction_cancel_handler(self): + return self._handle_cancellation( + message="_transaction_cancel_handler" + ) + def _open_transaction( self, *, tx_cls, access_mode, metadata=None, timeout=None ): @@ -331,7 +396,8 @@ def _open_transaction( self._transaction = tx_cls( self._connection, self._config.fetch_size, self._transaction_closed_handler, - self._transaction_error_handler + self._transaction_error_handler, + self._transaction_cancel_handler ) self._transaction._begin( self._config.database, self._config.impersonated_user, @@ -363,8 +429,10 @@ def begin_transaction(self, metadata=None, timeout=None): :returns: A new transaction instance. :rtype: Transaction - :raises TransactionError: :class:`neo4j.exceptions.TransactionError` if a transaction is already open. + :raises TransactionError: if a transaction is already open. + :raises SessionError: if the session has been closed. 
""" + self._check_state() # TODO: Implement TransactionConfig consumption if self._auto_result: @@ -384,6 +452,7 @@ def begin_transaction(self, metadata=None, timeout=None): def _run_transaction( self, access_mode, transaction_function, *args, **kwargs ): + self._check_state() if not callable(transaction_function): raise TypeError("Unit of work is not callable") @@ -410,6 +479,13 @@ def _run_transaction( tx = self._transaction try: result = transaction_function(tx, *args, **kwargs) + except asyncio.CancelledError: + # if cancellation callback has not been called yet: + if self._transaction is not None: + self._handle_cancellation( + message="transaction function" + ) + raise except Exception: tx._close() raise @@ -431,7 +507,11 @@ def _run_transaction( delay = next(retry_delay) log.warning("Transaction failed and will be retried in {}s ({})" "".format(delay, "; ".join(errors[-1].args))) - sleep(delay) + try: + sleep(delay) + except asyncio.CancelledError: + log.debug("[#0000] Retry cancelled") + raise if errors: raise errors[-1] @@ -488,6 +568,8 @@ def get_two_tx(tx): :param args: arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` :return: a result as returned by the given unit of work + + :raises SessionError: if the session has been closed. """ return self._run_transaction( READ_ACCESS, transaction_function, *args, **kwargs @@ -523,6 +605,8 @@ def create_node_tx(tx, name): :param args: key word arguments for the `transaction_function` :param kwargs: key word arguments for the `transaction_function` :return: a result as returned by the given unit of work + + :raises SessionError: if the session has been closed. 
""" return self._run_transaction( WRITE_ACCESS, transaction_function, *args, **kwargs diff --git a/neo4j/_sync/work/transaction.py b/neo4j/_sync/work/transaction.py index a834f00cf..981a40271 100644 --- a/neo4j/_sync/work/transaction.py +++ b/neo4j/_sync/work/transaction.py @@ -16,6 +16,7 @@ # limitations under the License. +import asyncio from functools import wraps from ..._async_compat.util import Util @@ -29,7 +30,8 @@ class _TransactionBase: - def __init__(self, connection, fetch_size, on_closed, on_error): + def __init__(self, connection, fetch_size, on_closed, on_error, + on_cancel): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( connection, self._error_handler @@ -41,6 +43,7 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._fetch_size = fetch_size self._on_closed = on_closed self._on_error = on_error + self._on_cancel = on_cancel def _enter(self): return self @@ -51,6 +54,9 @@ def _exit(self, exception_type, exception_value, traceback): success = not bool(exception_type) if success: self._commit() + elif issubclass(exception_type, asyncio.CancelledError): + self._cancel() + return self._close() def _begin( @@ -68,6 +74,9 @@ def _result_on_closed_handler(self): def _error_handler(self, exc): self._last_error = exc + if isinstance(exc, asyncio.CancelledError): + self._cancel() + return Util.callback(self._on_error, exc) def _consume_results(self): @@ -150,6 +159,9 @@ def _commit(self): self._connection.send_all() self._connection.fetch_all() self._bookmark = metadata.get("bookmark") + except asyncio.CancelledError: + self._on_cancel() + raise finally: self._closed_flag = True Util.callback(self._on_closed) @@ -174,6 +186,9 @@ def _rollback(self): self._connection.rollback(on_success=metadata.update) self._connection.send_all() self._connection.fetch_all() + except asyncio.CancelledError: + self._on_cancel() + raise finally: self._closed_flag = True Util.callback(self._on_closed) @@ -185,10 
+200,39 @@ def _close(self): return self._rollback() + if Util.is_async_code: + def _cancel(self): + """Cancel this transaction. + + If the transaction is already closed, this method does nothing. + Else, it will close the connection without ROLLBACK or COMMIT in + a non-blocking manner. + + The primary purpose of this function is to handle + :class:`asyncio.CancelledError`. + + :: + + tx = session.begin_transaction() + try: + ... # do some work + except asyncio.CancelledError: + tx.cancel() + raise + + """ + if self._closed_flag: + return + try: + self._on_cancel() + finally: + self._closed_flag = True + def _closed(self): - """Indicator to show whether the transaction has been closed. + """Indicate whether the transaction has been closed or cancelled. - :return: :const:`True` if closed, :const:`False` otherwise. + :return: + :const:`True` if closed or cancelled, :const:`False` otherwise. :rtype: bool """ return self._closed_flag @@ -229,6 +273,11 @@ def close(self): def closed(self): return self._closed() + if Util.is_async_code: + @wraps(_TransactionBase._cancel) + def cancel(self): + return self._cancel() + class ManagedTransaction(_TransactionBase): """Transaction object provided to transaction functions. diff --git a/neo4j/_sync/work/workspace.py b/neo4j/_sync/work/workspace.py index c10fc912e..e60f7ce61 100644 --- a/neo4j/_sync/work/workspace.py +++ b/neo4j/_sync/work/workspace.py @@ -26,6 +26,7 @@ ) from ...exceptions import ( ServiceUnavailable, + SessionError, SessionExpired, ) from ..io import Neo4jPool @@ -130,3 +131,15 @@ def close(self): return self._disconnect(sync=True) self._closed = True + + def closed(self): + """Indicate whether the session has been closed. + + :return: :const:`True` if closed, :const:`False` otherwise. 
+ :rtype: bool + """ + return self._closed + + def _check_state(self): + if self._closed: + raise SessionError(self, "Session closed") diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index 15392c5e0..101f54690 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -36,6 +36,7 @@ + ForbiddenOnReadOnlyDatabase + DriverError + + SessionError + TransactionError + TransactionNestingError + ResultError @@ -357,6 +358,16 @@ def is_retryable(self): return False +# DriverError > SessionError +class SessionError(DriverError): + """ Raised when an error occurs while using a session. + """ + + def __init__(self, session, *args, **kwargs): + super().__init__(*args, **kwargs) + self.session = session + + # DriverError > TransactionError class TransactionError(DriverError): """ Raised when an error occurs while using a transaction. @@ -367,7 +378,7 @@ def __init__(self, transaction, *args, **kwargs): self.transaction = transaction -# DriverError > TransactionNestingError +# DriverError > TransactionError > TransactionNestingError class TransactionNestingError(TransactionError): """ Raised when transactions are nested incorrectly. """ @@ -411,8 +422,8 @@ class SessionExpired(DriverError): the purpose described by its original parameters. 
""" - def __init__(self, session, *args, **kwargs): - super().__init__(session, *args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def is_retryable(self): return True diff --git a/tests/conftest.py b/tests/conftest.py index 4bcc74353..b8b87b34e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -188,3 +188,12 @@ def _(): benchmark(func, *args, **kwargs) return _wrapper + + +@pytest.fixture +def watcher(): + import sys + + from neo4j.debug import watch + with watch("neo4j", out=sys.stdout, colour=True): + yield diff --git a/tests/integration/async_/test_custom_ssl_context.py b/tests/integration/async_/test_custom_ssl_context.py index 4bc52f728..d3d50c500 100644 --- a/tests/integration/async_/test_custom_ssl_context.py +++ b/tests/integration/async_/test_custom_ssl_context.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from ssl import SSLContext import pytest diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 0d9661960..82ebe9cf7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -33,7 +33,7 @@ def run_and_rollback(tx, cypher, **parameters): raise ForcedRollback(value) def f(cypher, **parameters): - with bolt_driver.session() as session: + with driver.session() as session: try: session.write_transaction(run_and_rollback, cypher, **parameters) diff --git a/tests/integration/mixed/__init__.py b/tests/integration/mixed/__init__.py new file mode 100644 index 000000000..c42cc6fb6 --- /dev/null +++ b/tests/integration/mixed/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/mixed/test_async_cancellation.py b/tests/integration/mixed/test_async_cancellation.py new file mode 100644 index 000000000..c9046e1b9 --- /dev/null +++ b/tests/integration/mixed/test_async_cancellation.py @@ -0,0 +1,217 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import random +from functools import wraps + +import pytest + +import neo4j +from neo4j import exceptions as neo4j_exceptions + +from ..._async_compat import mark_async_test +from ...conftest import get_async_driver_no_warning + + +def _get_work(): + work_cancelled = False + + async def work(tx, i=1): + nonlocal work_cancelled + assert not work_cancelled # no retries after cancellation! 
+ try: + result = await tx.run(f"RETURN {i}") + try: + for _ in range(3): + await asyncio.sleep(0) + except asyncio.CancelledError as e: + e.during_sleep = True + raise + records = [record async for record in result] + summary = await result.consume() + assert isinstance(summary, neo4j.ResultSummary) + assert len(records) == 1 + assert list(records[0]) == [i] + except asyncio.CancelledError as e: + work_cancelled = True + raise + + return work + + +async def _do_the_read_tx_func(session_, i=1): + await session_.read_transaction(_get_work(), i=i) + + +def _with_retry(outer): + @wraps(outer) + async def inner(*args, **kwargs): + for _ in range(15): # super simple retry-mechanism + try: + return await outer(*args, **kwargs) + except (neo4j_exceptions.DriverError, + neo4j_exceptions.Neo4jError) as e: + if not e.is_retryable(): + raise + await asyncio.sleep(1.5) + return inner + + +@_with_retry +async def _do_the_read_tx_context(session_, i=1): + async with await session_.begin_transaction() as tx: + await _get_work()(tx, i=i) + + +@_with_retry +async def _do_the_read_explicit_tx(session_, i=1): + tx = await session_.begin_transaction() + try: + await _get_work()(tx, i=i) + except asyncio.CancelledError: + tx.cancel() + raise + await tx.commit() + + +@_with_retry +async def _do_the_read(session_, i=1): + try: + return await _get_work()(session_, i=i) + except asyncio.CancelledError: + session_.cancel() + raise + + +REPEATS = 1000 + + +@mark_async_test +@pytest.mark.parametrize(("i", "read_func", "waits", "cancel_count"), ( + ( + f"{i + 1:0{len(str(REPEATS))}}/{REPEATS}", + random.choice(( + _do_the_read, _do_the_read_tx_context, _do_the_read_explicit_tx, + _do_the_read_tx_func + )), + random.randint(0, 1000), + random.randint(1, 20), + ) + for i in range(REPEATS) # repeats +)) +async def test_async_cancellation( + uri, auth, mocker, read_func, waits, cancel_count, i +): + async with get_async_driver_no_warning( + uri, auth=auth, session_connection_timeout=10 + ) as 
driver: + async with driver.session() as session: + session._handle_cancellation = mocker.Mock( + wraps=session._handle_cancellation + ) + fut = asyncio.ensure_future(read_func(session)) + for _ in range(waits): + await asyncio.sleep(0) + # time for crazy abuse! + was_done = fut.done() and not fut.cancelled() + for _ in range(cancel_count): + fut.cancel() + await asyncio.sleep(0) + cancelled_error = None + if not was_done: + with pytest.raises(asyncio.CancelledError) as exc: + await fut + cancelled_error = exc.value + + else: + await fut + + bookmarks = await session.last_bookmarks() + if not waits: + assert not bookmarks + session._handle_cancellation.assert_not_called() + elif cancelled_error is not None: + assert not bookmarks + if ( + read_func is _do_the_read + and not getattr(cancelled_error, "during_sleep", False) + ): + # manually handling the session can lead to calling + # `session.cancel` twice, but that's ok, it's a noop if + # already cancelled. + assert len(session._handle_cancellation.call_args) == 2 + else: + session._handle_cancellation.assert_called_once() + else: + assert bookmarks + session._handle_cancellation.assert_not_called() + for read_func in ( + _do_the_read, _do_the_read_tx_context, + _do_the_read_explicit_tx, _do_the_read_tx_func + ): + await read_func(session, i=2) + + # test driver is still working + async with driver.session() as session: + await _do_the_read_tx_func(session, i=3) + new_bookmarks = await session.last_bookmarks() + assert new_bookmarks + assert bookmarks != new_bookmarks + + +SESSION_REPEATS = 100 +READS_PER_SESSION = 20 + + +@mark_async_test +async def test_async_cancellation_does_not_leak(uri, auth): + async with get_async_driver_no_warning( + uri, auth=auth, + session_connection_timeout=10, + # driver needs to cope with a single connection in the pool! 
+ max_connection_pool_size=1, + ) as driver: + for session_number in range(SESSION_REPEATS): + async with driver.session() as session: + for read_number in range(READS_PER_SESSION): + read_func = random.choice(( + _do_the_read, _do_the_read_tx_context, + _do_the_read_explicit_tx, _do_the_read_tx_func + )) + waits = random.randint(0, 1000) + cancel_count = random.randint(1, 20) + + fut = asyncio.ensure_future(read_func(session)) + for _ in range(waits): + await asyncio.sleep(0) + # time for crazy abuse! + was_done = fut.done() and not fut.cancelled() + for _ in range(cancel_count): + fut.cancel() + await asyncio.sleep(0) + if not was_done: + with pytest.raises(asyncio.CancelledError): + await fut + else: + await fut + await _do_the_read_tx_func(session, i=2) + + pool_connections = driver._pool.connections + for connections in pool_connections.values(): + assert len(connections) <= 1 diff --git a/tests/integration/sync/test_custom_ssl_context.py b/tests/integration/sync/test_custom_ssl_context.py index 56d09f684..7d3f01e5a 100644 --- a/tests/integration/sync/test_custom_ssl_context.py +++ b/tests/integration/sync/test_custom_ssl_context.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from ssl import SSLContext import pytest diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py index 3ddcb2f34..e31db0448 100644 --- a/tests/unit/mixed/async_compat/test_concurrency.py +++ b/tests/unit/mixed/async_compat/test_concurrency.py @@ -166,3 +166,33 @@ async def waiter(): assert not lock.locked() await asyncio.gather(blocker(), waiter_non_blocking(), waiter()) assert lock.locked() # waiter_non_blocking still owns it! 
+ + +@pytest.mark.parametrize("waits", range(1, 10)) +@pytest.mark.asyncio +async def test_async_r_lock_acquire_cancellation(waits): + lock = AsyncRLock() + + async def acquire_task(): + while True: + count = lock._count + cancellation = None + try: + print("try") + await lock.acquire(timeout=0.1) + print("acquired") + except asyncio.CancelledError as exc: + print("cancelled") + cancellation = exc + + if cancellation is not None: + assert lock._count == count + raise cancellation + assert lock._count == count + 1 + + fut = asyncio.ensure_future(acquire_task()) + for _ in range(waits): + await asyncio.sleep(0) + fut.cancel() + with pytest.raises(asyncio.CancelledError): + await fut From ad3ab5765d7269c9c30f01d66623a7bb56f5863d Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Tue, 19 Jul 2022 15:36:18 +0200 Subject: [PATCH 2/6] Patch `asyncio.wait_for` --- docs/source/async_api.rst | 31 +++++ neo4j/_async/io/_common.py | 2 +- neo4j/_async/io/_pool.py | 2 + neo4j/_async/work/session.py | 2 +- neo4j/_async_compat/concurrency.py | 14 +-- neo4j/_async_compat/network/_bolt_socket.py | 9 +- neo4j/_async_compat/shims/__init__.py | 125 ++++++++++++++++++++ neo4j/_sync/io/_common.py | 2 +- neo4j/_sync/io/_pool.py | 2 + neo4j/_sync/work/session.py | 2 +- tests/unit/mixed/async_compat/test_shims.py | 74 ++++++++++++ tests/unit/mixed/io/test_direct.py | 3 +- 12 files changed, 251 insertions(+), 17 deletions(-) create mode 100644 neo4j/_async_compat/shims/__init__.py create mode 100644 tests/unit/mixed/async_compat/test_shims.py diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index fbe21dbd8..3022ad8ec 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -313,6 +313,37 @@ AsyncSession .. autoclass:: neo4j.AsyncSession() + .. note:: + + Some asyncio utility functions (e.g., :func:`asyncio.wait_for` and + :func:`asyncio.shield`) will wrap work in a :class:`asyncio.Task`. 
+ This introduces concurrency and can lead to undefined behavior as + :class:`AsyncSession` is not concurrency-safe. + + Consider this **wrong** example:: + + async def dont_do_this(driver): + async with driver.session() as session: + await asyncio.shield(session.run("RETURN 1")) + + If ``dont_do_this`` gets cancelled while waiting for ``session.run``, + ``session.run`` itself won't get cancelled (it's shielded) so it will + continue to use the session in another Task. Concurrently, the + async context manager (``async with driver.session()``) will clean up + the session on exit. That's two Tasks handling the session concurrently. + Therefore, this yields undefined behavior. + + In this particular example, the problem could be solved by shielding + the whole coroutine ``dont_do_this`` instead of only the + ``session.run``. Like so:: + + async def thats_better(driver): + async def inner(): + async with driver.session() as session: + await session.run("RETURN 1") + + await asyncio.shield(inner()) + .. automethod:: close ..
automethod:: cancel diff --git a/neo4j/_async/io/_common.py b/neo4j/_async/io/_common.py index 74db2c31b..8f510689a 100644 --- a/neo4j/_async/io/_common.py +++ b/neo4j/_async/io/_common.py @@ -140,7 +140,7 @@ async def flush(self): try: await self.socket.sendall(data) except (OSError, asyncio.CancelledError) as error: - await self.on_error(error) + await AsyncUtil.callback(self.on_error, error) return False self._clear() return True diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 8da20d663..684351ac5 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -286,6 +286,8 @@ async def release(self, *connections): for connection in connections: connection.in_use = False self.cond.notify_all() + if cancelled is not None: + raise cancelled def in_use_connection_count(self, address): """ Count the number of connections currently in use to a given diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index cc3de0dd4..48911d306 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -61,7 +61,7 @@ class AsyncSession(AsyncWorkspace): Session creation is a lightweight operation and sessions are not safe to be used in concurrent contexts (multiple threads/coroutines). Therefore, a session should generally be short-lived, and must not - span multiple threads/coroutines. + span multiple threads/asynchronous Tasks. In general, sessions will be created and destroyed within a `with` context. 
For example:: diff --git a/neo4j/_async_compat/concurrency.py b/neo4j/_async_compat/concurrency.py index 4e5150dc9..acaa1698c 100644 --- a/neo4j/_async_compat/concurrency.py +++ b/neo4j/_async_compat/concurrency.py @@ -21,7 +21,7 @@ import re import threading -from neo4j._async_compat.util import AsyncUtil +from neo4j._async_compat.shims import wait_for __all__ = [ @@ -123,12 +123,12 @@ async def acquire(self, blocking=True, timeout=-1): try: fut = asyncio.ensure_future(self._acquire(me)) try: - await asyncio.wait_for(fut, timeout) + await wait_for(fut, timeout) except asyncio.CancelledError: - if (fut.done() - and not fut.cancelled() - and fut.exception() is None): - # too late to cancel the acquisition + already_finished = not fut.cancel() + if already_finished: + # Too late to cancel the acquisition. + # This can only happen in Python 3.7's asyncio self._release(me) raise return True @@ -357,7 +357,7 @@ async def _wait(self, timeout=None, me=None): fut = self._loop.create_future() self._waiters.append(fut) try: - await asyncio.wait_for(fut, timeout) + await wait_for(fut, timeout) return True except asyncio.TimeoutError: return False diff --git a/neo4j/_async_compat/network/_bolt_socket.py b/neo4j/_async_compat/network/_bolt_socket.py index 33d6f76b1..209ad0f89 100644 --- a/neo4j/_async_compat/network/_bolt_socket.py +++ b/neo4j/_async_compat/network/_bolt_socket.py @@ -35,7 +35,6 @@ HAS_SNI, SSLError, ) -from time import perf_counter from ... import addressing from ..._deadline import Deadline @@ -49,7 +48,7 @@ DriverError, ServiceUnavailable, ) -from ..util import AsyncUtil +from ..shims import wait_for from ._util import ( AsyncNetworkUtil, NetworkUtil, @@ -104,7 +103,7 @@ async def _wait_for_io(self, io_fut): # `await` will raise the CancelledError. 
asyncio.current_task().cancel() try: - return await asyncio.wait_for(io_fut, timeout) + return await wait_for(io_fut, timeout) except asyncio.TimeoutError as e: raise to_raise("timed out") from e @@ -187,7 +186,7 @@ async def _connect_secure(cls, resolved_address, timeout, keep_alive, ssl): "Unsupported address {!r}".format(resolved_address)) s.setblocking(False) # asyncio + blocking = no-no! log.debug("[#0000] C: %s", resolved_address) - await asyncio.wait_for( + await wait_for( loop.sock_connect(s, resolved_address), timeout ) @@ -331,7 +330,7 @@ async def close_socket(cls, socket_): else: socket_.shutdown(SHUT_RDWR) socket_.close() - except (OSError, asyncio.CancelledError): + except OSError: pass @classmethod diff --git a/neo4j/_async_compat/shims/__init__.py b/neo4j/_async_compat/shims/__init__.py new file mode 100644 index 000000000..cd716219c --- /dev/null +++ b/neo4j/_async_compat/shims/__init__.py @@ -0,0 +1,125 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +import functools +import sys + + +if sys.version_info >= (3, 8): + # copied from asyncio 3.10 with applied patch + def _release_waiter(waiter, *args): + if not waiter.done(): + waiter.set_result(None) + + async def _cancel_and_wait(fut, loop): + """Cancel the *fut* future or task and wait until it completes.""" + + waiter = loop.create_future() + cb = functools.partial(_release_waiter, waiter) + fut.add_done_callback(cb) + + try: + fut.cancel() + # We cannot wait on *fut* directly to make + # sure _cancel_and_wait itself is reliably cancellable. + await waiter + finally: + fut.remove_done_callback(cb) + + async def wait_for(fut, timeout): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. When a timeout occurs, + it cancels the task and raises TimeoutError. To avoid the task + cancellation, wrap it in shield(). + + If the wait is cancelled, the task is also cancelled. + + This function is a coroutine. + """ + loop = asyncio.get_running_loop() + + if timeout is None: + return await fut + + if timeout <= 0: + fut = asyncio.ensure_future(fut, loop=loop) + + if fut.done(): + return fut.result() + + await _cancel_and_wait(fut, loop=loop) + try: + return fut.result() + except asyncio.CancelledError as exc: + raise asyncio.TimeoutError() from exc + + waiter = loop.create_future() + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + cb = functools.partial(_release_waiter, waiter) + + fut = asyncio.ensure_future(fut, loop=loop) + fut.add_done_callback(cb) + + try: + # wait until the future completes or the timeout + try: + await waiter + except asyncio.CancelledError: + if fut.done(): + # [PATCH] + # Applied patch to not swallow the outer cancellation. + # See: https://github.com/python/cpython/pull/26097 + # and https://github.com/python/cpython/pull/28149 + + # We got cancelled, but we are already done. 
Therefore, + # we defer the cancellation until the task yields to the + # event loop the next time. + asyncio.current_task().cancel() + # [/PATCH] + return fut.result() + else: + fut.remove_done_callback(cb) + # We must ensure that the task is not running + # after wait_for() returns. + # See https://bugs.python.org/issue32751 + await _cancel_and_wait(fut, loop=loop) + raise + + if fut.done(): + return fut.result() + else: + fut.remove_done_callback(cb) + # We must ensure that the task is not running + # after wait_for() returns. + # See https://bugs.python.org/issue32751 + await _cancel_and_wait(fut, loop=loop) + # In case task cancellation failed with some + # exception, we should re-raise it + # See https://bugs.python.org/issue40607 + try: + return fut.result() + except asyncio.CancelledError as exc: + raise asyncio.TimeoutError() from exc + finally: + timeout_handle.cancel() +else: + wait_for = asyncio.wait_for diff --git a/neo4j/_sync/io/_common.py b/neo4j/_sync/io/_common.py index 0c437b91b..3aa2f24d9 100644 --- a/neo4j/_sync/io/_common.py +++ b/neo4j/_sync/io/_common.py @@ -140,7 +140,7 @@ def flush(self): try: self.socket.sendall(data) except (OSError, asyncio.CancelledError) as error: - self.on_error(error) + Util.callback(self.on_error, error) return False self._clear() return True diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index 012587c6c..3f1ff93cc 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -286,6 +286,8 @@ def release(self, *connections): for connection in connections: connection.in_use = False self.cond.notify_all() + if cancelled is not None: + raise cancelled def in_use_connection_count(self, address): """ Count the number of connections currently in use to a given diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index 8d34e7024..c0c9ec038 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -61,7 +61,7 @@ class Session(Workspace): Session creation is a 
lightweight operation and sessions are not safe to be used in concurrent contexts (multiple threads/coroutines). Therefore, a session should generally be short-lived, and must not - span multiple threads/coroutines. + span multiple threads/asynchronous Tasks. In general, sessions will be created and destroyed within a `with` context. For example:: diff --git a/tests/unit/mixed/async_compat/test_shims.py b/tests/unit/mixed/async_compat/test_shims.py new file mode 100644 index 000000000..93e0fc435 --- /dev/null +++ b/tests/unit/mixed/async_compat/test_shims.py @@ -0,0 +1,74 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +import sys + +import pytest + +from neo4j._async_compat import shims + +from ...._async_compat import mark_async_test + + +@pytest.mark.skipif( + sys.version_info < (3, 8), + reason="wait_for is only broken in Python 3.8+" +) +@mark_async_test +async def test_wait_for_shim_is_necessary_starting_from_3x8(): + # when this tests fails, the shim became superfluous + inner = asyncio.get_event_loop().create_future() + outer = asyncio.wait_for(inner, 0.1) + outer_future = asyncio.ensure_future(outer) + await asyncio.sleep(0) + inner.set_result(None) # inner is done + outer_future.cancel() # AND outer got cancelled + + # this should propagate the cancellation, but it's broken :/ + await outer_future + + +@pytest.mark.skipif( + sys.version_info >= (3, 8), + reason="wait_for is only broken in Python 3.8+" +) +@mark_async_test +async def test_wait_for_shim_is_not_necessary_prior_to_3x8(): + inner = asyncio.get_event_loop().create_future() + outer = asyncio.wait_for(inner, 0.1) + outer_future = asyncio.ensure_future(outer) + await asyncio.sleep(0) + inner.set_result(None) # inner is done + outer_future.cancel() # AND outer got cancelled + + with pytest.raises(asyncio.CancelledError): + await outer_future + + +@mark_async_test +async def test_wait_for_shim_propagates_cancellation(): + inner = asyncio.get_event_loop().create_future() + outer = shims.wait_for(inner, 0.1) + outer_future = asyncio.ensure_future(outer) + await asyncio.sleep(0) + inner.set_result(None) # inner is done + outer_future.cancel() # AND outer got cancelled + + with pytest.raises(asyncio.CancelledError): + await outer_future diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index b849c9604..be28f8104 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -32,6 +32,7 @@ import pytest +from neo4j._async_compat.shims import wait_for from neo4j._deadline import Deadline from ...async_.io.test_direct import 
AsyncFakeBoltPool @@ -127,7 +128,7 @@ async def wait(self, value=0, timeout=None): if time_left <= 0: return False try: - await asyncio.wait_for(self._cond.wait(), time_left) + await wait_for(self._cond.wait(), time_left) except asyncio.TimeoutError: return False From 05fe9467942a57a8f18e0b6439ba0d8266d8f40f Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 20 Jul 2022 13:36:08 +0200 Subject: [PATCH 3/6] Fix unit tests --- neo4j/_async_compat/shims/__init__.py | 4 +- .../mixed/test_async_cancellation.py | 10 ++-- tests/unit/async_/io/test_direct.py | 1 + tests/unit/async_/work/test_transaction.py | 24 +++++--- .../mixed/async_compat/test_concurrency.py | 18 ++---- tests/unit/mixed/async_compat/test_shims.py | 56 ++++++++++--------- tests/unit/sync/io/test_direct.py | 1 + tests/unit/sync/work/test_transaction.py | 24 +++++--- 8 files changed, 76 insertions(+), 62 deletions(-) diff --git a/neo4j/_async_compat/shims/__init__.py b/neo4j/_async_compat/shims/__init__.py index cd716219c..5d210fd11 100644 --- a/neo4j/_async_compat/shims/__init__.py +++ b/neo4j/_async_compat/shims/__init__.py @@ -91,8 +91,8 @@ async def wait_for(fut, timeout): # and https://github.com/python/cpython/pull/28149 # We got cancelled, but we are already done. Therefore, - # we defer the cancellation until the task yields to the - # event loop the next time. + # we defer the cancellation until next time the task yields + # to the event loop. 
asyncio.current_task().cancel() # [/PATCH] return fut.result() diff --git a/tests/integration/mixed/test_async_cancellation.py b/tests/integration/mixed/test_async_cancellation.py index c9046e1b9..7529e9db3 100644 --- a/tests/integration/mixed/test_async_cancellation.py +++ b/tests/integration/mixed/test_async_cancellation.py @@ -99,13 +99,13 @@ async def _do_the_read(session_, i=1): raise -REPEATS = 1000 +REPETITIONS = 1000 @mark_async_test @pytest.mark.parametrize(("i", "read_func", "waits", "cancel_count"), ( ( - f"{i + 1:0{len(str(REPEATS))}}/{REPEATS}", + f"{i + 1:0{len(str(REPETITIONS))}}/{REPETITIONS}", random.choice(( _do_the_read, _do_the_read_tx_context, _do_the_read_explicit_tx, _do_the_read_tx_func @@ -113,7 +113,7 @@ async def _do_the_read(session_, i=1): random.randint(0, 1000), random.randint(1, 20), ) - for i in range(REPEATS) # repeats + for i in range(REPETITIONS) )) async def test_async_cancellation( uri, auth, mocker, read_func, waits, cancel_count, i @@ -175,7 +175,7 @@ async def test_async_cancellation( assert bookmarks != new_bookmarks -SESSION_REPEATS = 100 +SESSION_REPETITIONS = 100 READS_PER_SESSION = 20 @@ -187,7 +187,7 @@ async def test_async_cancellation_does_not_leak(uri, auth): # driver needs to cope with a single connection in the pool! 
max_connection_pool_size=1, ) as driver: - for session_number in range(SESSION_REPEATS): + for session_number in range(SESSION_REPETITIONS): async with driver.session() as session: for read_number in range(READS_PER_SESSION): read_func = random.choice(( diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 01d37e463..ec3c40b7a 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -52,6 +52,7 @@ class AsyncQuickConnection: def __init__(self, socket): self.socket = socket self.address = socket.getpeername() + self.local_port = self.address[1] @property def is_reset(self): diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 7fa36ab76..24b79a81c 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -40,7 +40,9 @@ async def test_transaction_context_when_committing( ): on_closed = mocker.AsyncMock() on_error = mocker.AsyncMock() - tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) + on_cancel = mocker.Mock() + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error, + on_cancel) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) async with tx as tx_: @@ -70,7 +72,9 @@ async def test_transaction_context_with_explicit_rollback( ): on_closed = mocker.AsyncMock() on_error = mocker.AsyncMock() - tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) + on_cancel = mocker.Mock() + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error, + on_cancel) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) async with tx as tx_: @@ -99,7 +103,9 @@ class OopsError(RuntimeError): on_closed = MagicMock() on_error = MagicMock() - tx = AsyncTransaction(async_fake_connection, 
2, on_closed, on_error) + on_cancel = MagicMock() + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error, + on_cancel) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with pytest.raises(OopsError): @@ -117,7 +123,9 @@ class OopsError(RuntimeError): async def test_transaction_run_takes_no_query_object(async_fake_connection): on_closed = MagicMock() on_error = MagicMock() - tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error) + on_cancel = MagicMock() + tx = AsyncTransaction(async_fake_connection, 2, on_closed, on_error, + on_cancel) with pytest.raises(ValueError): await tx.run(Query("RETURN 1")) @@ -128,7 +136,7 @@ async def test_transaction_rollbacks_on_open_connections( ): tx = AsyncTransaction( async_fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) async with tx as tx_: async_fake_connection.is_reset_mock.return_value = False @@ -145,7 +153,7 @@ async def test_transaction_no_rollback_on_reset_connections( ): tx = AsyncTransaction( async_fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) async with tx as tx_: async_fake_connection.is_reset_mock.return_value = True @@ -162,7 +170,7 @@ async def test_transaction_no_rollback_on_closed_connections( ): tx = AsyncTransaction( async_fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) async with tx as tx_: async_fake_connection.closed.return_value = True @@ -181,7 +189,7 @@ async def test_transaction_no_rollback_on_defunct_connections( ): tx = AsyncTransaction( async_fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) async with tx as 
tx_: async_fake_connection.defunct.return_value = True diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py index e31db0448..bf5daa72f 100644 --- a/tests/unit/mixed/async_compat/test_concurrency.py +++ b/tests/unit/mixed/async_compat/test_concurrency.py @@ -176,19 +176,13 @@ async def test_async_r_lock_acquire_cancellation(waits): async def acquire_task(): while True: count = lock._count - cancellation = None - try: - print("try") - await lock.acquire(timeout=0.1) - print("acquired") - except asyncio.CancelledError as exc: - print("cancelled") - cancellation = exc - - if cancellation is not None: - assert lock._count == count - raise cancellation + await lock.acquire(timeout=0.1) assert lock._count == count + 1 + try: + await asyncio.sleep(0) + except asyncio.CancelledError: + raise + assert count < 50 # safety guard, we shouldn't ever get there! fut = asyncio.ensure_future(acquire_task()) for _ in range(waits): diff --git a/tests/unit/mixed/async_compat/test_shims.py b/tests/unit/mixed/async_compat/test_shims.py index 93e0fc435..f4c66f914 100644 --- a/tests/unit/mixed/async_compat/test_shims.py +++ b/tests/unit/mixed/async_compat/test_shims.py @@ -26,6 +26,21 @@ from ...._async_compat import mark_async_test +async def _check_wait_for(wait_for_, should_propagate_cancellation): + inner = asyncio.get_event_loop().create_future() + outer = wait_for_(inner, 0.1) + outer_future = asyncio.ensure_future(outer) + await asyncio.sleep(0) + inner.set_result(None) # inner is done + outer_future.cancel() # AND outer got cancelled + + if should_propagate_cancellation: + with pytest.raises(asyncio.CancelledError): + await outer_future + else: + await outer_future + + @pytest.mark.skipif( sys.version_info < (3, 8), reason="wait_for is only broken in Python 3.8+" @@ -33,15 +48,11 @@ @mark_async_test async def test_wait_for_shim_is_necessary_starting_from_3x8(): # when this tests fails, the shim became superfluous - inner 
= asyncio.get_event_loop().create_future() - outer = asyncio.wait_for(inner, 0.1) - outer_future = asyncio.ensure_future(outer) - await asyncio.sleep(0) - inner.set_result(None) # inner is done - outer_future.cancel() # AND outer got cancelled - - # this should propagate the cancellation, but it's broken :/ - await outer_future + await _check_wait_for( + asyncio.wait_for, + # this should propagate the cancellation, but it's broken :/ + should_propagate_cancellation=False + ) @pytest.mark.skipif( @@ -50,25 +61,16 @@ async def test_wait_for_shim_is_necessary_starting_from_3x8(): ) @mark_async_test async def test_wait_for_shim_is_not_necessary_prior_to_3x8(): - inner = asyncio.get_event_loop().create_future() - outer = asyncio.wait_for(inner, 0.1) - outer_future = asyncio.ensure_future(outer) - await asyncio.sleep(0) - inner.set_result(None) # inner is done - outer_future.cancel() # AND outer got cancelled - - with pytest.raises(asyncio.CancelledError): - await outer_future + await _check_wait_for( + asyncio.wait_for, + should_propagate_cancellation=True + ) @mark_async_test async def test_wait_for_shim_propagates_cancellation(): - inner = asyncio.get_event_loop().create_future() - outer = shims.wait_for(inner, 0.1) - outer_future = asyncio.ensure_future(outer) - await asyncio.sleep(0) - inner.set_result(None) # inner is done - outer_future.cancel() # AND outer got cancelled - - with pytest.raises(asyncio.CancelledError): - await outer_future + # shim should always work regardless of the Python version + await _check_wait_for( + shims.wait_for, + should_propagate_cancellation=True + ) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index cddbecef8..d24f93a34 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -52,6 +52,7 @@ class QuickConnection: def __init__(self, socket): self.socket = socket self.address = socket.getpeername() + self.local_port = self.address[1] @property def is_reset(self): 
diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 9a3440faf..34cd384d7 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -40,7 +40,9 @@ def test_transaction_context_when_committing( ): on_closed = mocker.Mock() on_error = mocker.Mock() - tx = Transaction(fake_connection, 2, on_closed, on_error) + on_cancel = mocker.Mock() + tx = Transaction(fake_connection, 2, on_closed, on_error, + on_cancel) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with tx as tx_: @@ -70,7 +72,9 @@ def test_transaction_context_with_explicit_rollback( ): on_closed = mocker.Mock() on_error = mocker.Mock() - tx = Transaction(fake_connection, 2, on_closed, on_error) + on_cancel = mocker.Mock() + tx = Transaction(fake_connection, 2, on_closed, on_error, + on_cancel) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with tx as tx_: @@ -99,7 +103,9 @@ class OopsError(RuntimeError): on_closed = MagicMock() on_error = MagicMock() - tx = Transaction(fake_connection, 2, on_closed, on_error) + on_cancel = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error, + on_cancel) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) with pytest.raises(OopsError): @@ -117,7 +123,9 @@ class OopsError(RuntimeError): def test_transaction_run_takes_no_query_object(fake_connection): on_closed = MagicMock() on_error = MagicMock() - tx = Transaction(fake_connection, 2, on_closed, on_error) + on_cancel = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_error, + on_cancel) with pytest.raises(ValueError): tx.run(Query("RETURN 1")) @@ -128,7 +136,7 @@ def test_transaction_rollbacks_on_open_connections( ): tx = 
Transaction( fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) with tx as tx_: fake_connection.is_reset_mock.return_value = False @@ -145,7 +153,7 @@ def test_transaction_no_rollback_on_reset_connections( ): tx = Transaction( fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) with tx as tx_: fake_connection.is_reset_mock.return_value = True @@ -162,7 +170,7 @@ def test_transaction_no_rollback_on_closed_connections( ): tx = Transaction( fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) with tx as tx_: fake_connection.closed.return_value = True @@ -181,7 +189,7 @@ def test_transaction_no_rollback_on_defunct_connections( ): tx = Transaction( fake_connection, 2, lambda *args, **kwargs: None, - lambda *args, **kwargs: None + lambda *args, **kwargs: None, lambda *args, **kwargs: None ) with tx as tx_: fake_connection.defunct.return_value = True From f8cc62cf94cf565ea6e2a896f66289243f9a7a28 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 20 Jul 2022 13:36:44 +0200 Subject: [PATCH 4/6] Async wait_for shim raise cancellation right away Deferring the cancellation makes it much harder to reason about and might make the Python driver behave in a surprising way when cancelled. 
--- neo4j/_async/work/session.py | 3 --- neo4j/_async_compat/shims/__init__.py | 3 +-- neo4j/_sync/work/session.py | 3 --- tests/integration/mixed/test_async_cancellation.py | 2 +- tests/unit/mixed/async_compat/test_concurrency.py | 9 +++++++-- 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/neo4j/_async/work/session.py b/neo4j/_async/work/session.py index 48911d306..bf2e508af 100644 --- a/neo4j/_async/work/session.py +++ b/neo4j/_async/work/session.py @@ -90,7 +90,6 @@ def __init__(self, pool, session_config): super().__init__(pool, session_config) assert isinstance(session_config, SessionConfig) self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) - self._cancelled = False async def __aenter__(self): return self @@ -140,7 +139,6 @@ def _collect_bookmark(self, bookmark): self._bookmarks = bookmark, def _handle_cancellation(self, message="General"): - self._cancelled = True self._transaction = None self._auto_result = None connection = self._connection @@ -178,7 +176,6 @@ async def close(self): This will release any borrowed resources, such as connections, and will roll back any outstanding transactions. """ - # if self._closed or self._cancelled: if self._closed: return if self._connection: diff --git a/neo4j/_async_compat/shims/__init__.py b/neo4j/_async_compat/shims/__init__.py index 5d210fd11..95d964001 100644 --- a/neo4j/_async_compat/shims/__init__.py +++ b/neo4j/_async_compat/shims/__init__.py @@ -93,9 +93,8 @@ async def wait_for(fut, timeout): # We got cancelled, but we are already done. Therefore, # we defer the cancellation until next time the task yields # to the event loop. 
- asyncio.current_task().cancel() + raise # [/PATCH] - return fut.result() else: fut.remove_done_callback(cb) # We must ensure that the task is not running diff --git a/neo4j/_sync/work/session.py b/neo4j/_sync/work/session.py index c0c9ec038..aaf4d3cae 100644 --- a/neo4j/_sync/work/session.py +++ b/neo4j/_sync/work/session.py @@ -90,7 +90,6 @@ def __init__(self, pool, session_config): super().__init__(pool, session_config) assert isinstance(session_config, SessionConfig) self._bookmarks = self._prepare_bookmarks(session_config.bookmarks) - self._cancelled = False def __enter__(self): return self @@ -140,7 +139,6 @@ def _collect_bookmark(self, bookmark): self._bookmarks = bookmark, def _handle_cancellation(self, message="General"): - self._cancelled = True self._transaction = None self._auto_result = None connection = self._connection @@ -178,7 +176,6 @@ def close(self): This will release any borrowed resources, such as connections, and will roll back any outstanding transactions. """ - # if self._closed or self._cancelled: if self._closed: return if self._connection: diff --git a/tests/integration/mixed/test_async_cancellation.py b/tests/integration/mixed/test_async_cancellation.py index 7529e9db3..496d7ab1e 100644 --- a/tests/integration/mixed/test_async_cancellation.py +++ b/tests/integration/mixed/test_async_cancellation.py @@ -48,7 +48,7 @@ async def work(tx, i=1): assert isinstance(summary, neo4j.ResultSummary) assert len(records) == 1 assert list(records[0]) == [i] - except asyncio.CancelledError as e: + except asyncio.CancelledError: work_cancelled = True raise diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py index bf5daa72f..cc2d340c5 100644 --- a/tests/unit/mixed/async_compat/test_concurrency.py +++ b/tests/unit/mixed/async_compat/test_concurrency.py @@ -176,9 +176,14 @@ async def test_async_r_lock_acquire_cancellation(waits): async def acquire_task(): while True: count = lock._count - 
await lock.acquire(timeout=0.1) - assert lock._count == count + 1 try: + await lock.acquire(timeout=0.1) + assert lock._count == count + 1 + except asyncio.CancelledError: + assert lock._count == count + raise + try: + # we're also ok with a deferred cancellation await asyncio.sleep(0) except asyncio.CancelledError: raise From c05f57409ceef84172c8edfcca62981a17b84751 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 21 Jul 2022 15:16:52 +0200 Subject: [PATCH 5/6] Typos and comment clarifications --- docs/source/async_api.rst | 2 +- neo4j/_async/io/_pool.py | 2 +- neo4j/_async_compat/concurrency.py | 1 + neo4j/_sync/io/_pool.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 3022ad8ec..61f2bb2c3 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -610,7 +610,7 @@ after catching a :class:`asyncio.CancelledError`, you should not try to use transactions or results created earlier. They are likely to not be valid anymore. -Furthermore, there is no a guarantee as to whether a piece of ongoing work got +Furthermore, there is no guarantee as to whether a piece of ongoing work got successfully executed on the server side or not, when a cancellation happens: ``await transaction.commit()`` and other methods can throw :exc:`asyncio.CancelledError` but still have managed to complete from the diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 684351ac5..9b9118ff6 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -316,7 +316,7 @@ async def _close_connections(cls, connections): try: await connection.close() except asyncio.CancelledError as e: - # We've got cancelled: not time to gracefully close these + # We've got cancelled: no more time to gracefully close these # connections. Time to burn down the place. 
cancelled = e connection.kill() diff --git a/neo4j/_async_compat/concurrency.py b/neo4j/_async_compat/concurrency.py index acaa1698c..3f149a668 100644 --- a/neo4j/_async_compat/concurrency.py +++ b/neo4j/_async_compat/concurrency.py @@ -129,6 +129,7 @@ async def acquire(self, blocking=True, timeout=-1): if already_finished: # Too late to cancel the acquisition. # This can only happen in Python 3.7's asyncio + # as well as in our wait_for shim. self._release(me) raise return True diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index 3f1ff93cc..61ad373f7 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -316,7 +316,7 @@ def _close_connections(cls, connections): try: connection.close() except asyncio.CancelledError as e: - # We've got cancelled: not time to gracefully close these + # We've got cancelled: no more time to gracefully close these # connections. Time to burn down the place. cancelled = e connection.kill() From 33f3c6504b9c9b5cc248823543b1b225cea3e260 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 1 Aug 2022 09:23:32 +0200 Subject: [PATCH 6/6] Update comment for `wait_for` patch --- neo4j/_async_compat/shims/__init__.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/neo4j/_async_compat/shims/__init__.py b/neo4j/_async_compat/shims/__init__.py index 95d964001..f457725ce 100644 --- a/neo4j/_async_compat/shims/__init__.py +++ b/neo4j/_async_compat/shims/__init__.py @@ -21,6 +21,15 @@ import sys +# === patch asyncio.wait_for === +# The shipped wait_for can swallow cancellation errors (starting with 3.8). +# See: https://github.com/python/cpython/pull/26097 +# and https://github.com/python/cpython/pull/28149 +# Since 3.8 and 3.9 already received their final maintenance release, there +# will be no fix for this. So this patch needs to stick around at least until +# we remove support for Python 3.9. 
+ + if sys.version_info >= (3, 8): # copied from asyncio 3.10 with applied patch def _release_waiter(waiter, *args): @@ -90,9 +99,8 @@ async def wait_for(fut, timeout): # See: https://github.com/python/cpython/pull/26097 # and https://github.com/python/cpython/pull/28149 - # We got cancelled, but we are already done. Therefore, - # we defer the cancellation until next time the task yields - # to the event loop. + # Even though the future we're waiting for is already done, + # we should not swallow the cancellation. raise # [/PATCH] else: