diff --git a/neo4j/work/result.py b/neo4j/work/result.py index 6465a65dc..e0c37159d 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -24,6 +24,7 @@ from neo4j.data import DataDehydrator from neo4j.exceptions import ( + Neo4jError, ServiceUnavailable, SessionExpired, ) @@ -39,17 +40,16 @@ class _ConnectionErrorHandler: The error will be re-raised after the callback. """ - def __init__(self, connection, on_network_error): + def __init__(self, connection, on_error): """ :param connection the connection object to warp :type connection Bolt - :param on_network_error the function to be called when a method of - connection raises of of the caught errors. The callback takes the - error as argument. - :type on_network_error callable + :param on_error the function to be called when a method of + connection raises of of the caught errors. + :type on_error callable """ self._connection = connection - self._on_network_error = on_network_error + self._on_error = on_error def __getattr__(self, item): connection_attr = getattr(self._connection, item) @@ -60,9 +60,9 @@ def outer(func): def inner(*args, **kwargs): try: func(*args, **kwargs) - finally: - if self._connection.defunct(): - self._on_network_error() + except (Neo4jError, ServiceUnavailable, SessionExpired) as exc: + self._on_error(exc) + raise return inner return outer(connection_attr) @@ -75,8 +75,8 @@ class Result: """ def __init__(self, connection, hydrant, fetch_size, on_closed, - on_network_error): - self._connection = _ConnectionErrorHandler(connection, on_network_error) + on_error): + self._connection = _ConnectionErrorHandler(connection, on_error) self._hydrant = hydrant self._on_closed = on_closed self._metadata = None diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index 18715fdbf..15f08b2ed 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -137,7 +137,7 @@ def _result_closed(self): self._autoResult = None self._disconnect() - def _result_network_error(self): + def _result_error(self, _): if self._autoResult: self._autoResult = None self._disconnect() @@ -227,7 +227,7 @@ def run(self, query, parameters=None, **kwparameters): self._autoResult = Result( cx, hydrant, self._config.fetch_size, self._result_closed, - self._result_network_error + self._result_error ) self._autoResult._run( query, parameters, self._config.database, @@ -261,7 +261,7 @@ def _transaction_closed_handler(self): self._transaction = None self._disconnect() - def _transaction_network_error_handler(self): + def _transaction_error_handler(self, _): if self._transaction: self._transaction = None self._disconnect() @@ -272,7 +272,7 @@ def _open_transaction(self, *, access_mode, database, metadata=None, self._transaction = Transaction( self._connection, self._config.fetch_size, self._transaction_closed_handler, - self._transaction_network_error_handler + self._transaction_error_handler ) self._transaction._begin(database, self._bookmarks, access_mode, metadata, timeout) diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index 3f3978d8f..ab01c86de 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -22,7 +22,8 @@ from neo4j.work.result import Result from neo4j.data import DataHydrator from neo4j.exceptions import ( - IncompleteCommit, + ServiceUnavailable, + SessionExpired, TransactionError, ) @@ -38,14 +39,15 @@ class Transaction: """ - def __init__(self, connection, fetch_size, on_closed, on_network_error): + def __init__(self, connection, fetch_size, on_closed, on_error): self._connection = connection self._bookmark = None self._results = [] self._closed = False + self._last_error = None self._fetch_size = fetch_size self._on_closed = on_closed - self._on_network_error = on_network_error + self._on_error = on_error def __enter__(self): return self @@ -65,9 +67,11 @@ def _begin(self, database, bookmarks, access_mode, metadata, timeout): def _result_on_closed_handler(self): pass - def _result_on_network_error_handler(self): - self._closed = True - self._on_network_error() + def _result_on_error_handler(self, exc): + self._last_error = exc + if isinstance(exc, (ServiceUnavailable, SessionExpired)): + self._closed = True + self._on_error(exc) def _consume_results(self): for result in self._results: @@ -111,7 +115,10 @@ def run(self, query, parameters=None, **kwparameters): raise ValueError("Query object is only supported for session.run") if self._closed: - raise TransactionError("Transaction closed") + raise TransactionError(self, "Transaction closed") + if self._last_error: + raise TransactionError(self, + "Transaction failed") from self._last_error if (self._results and self._connection.supports_multiple_results is False): @@ -123,7 +130,7 @@ def run(self, query, parameters=None, **kwparameters): result = Result( self._connection, DataHydrator(), self._fetch_size, self._result_on_closed_handler, - self._result_on_network_error_handler + self._result_on_error_handler ) self._results.append(result) @@ -137,7 +144,11 @@ def commit(self): :raise TransactionError: if the transaction is already closed """ if self._closed: - raise TransactionError("Transaction closed") + raise TransactionError(self, "Transaction closed") + if self._last_error: + raise TransactionError(self, + "Transaction failed") from self._last_error + metadata = {} try: self._consume_results() # DISCARD pending records then do a commit. @@ -157,7 +168,8 @@ def rollback(self): :raise TransactionError: if the transaction is already closed """ if self._closed: - raise TransactionError("Transaction closed") + raise TransactionError(self, "Transaction closed") + metadata = {} try: if not self._connection.is_reset: diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index d4195da2f..6d2a6cec8 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -1,7 +1,5 @@ { "skips": { - "neo4j.txrun.TestTxRun.test_should_not_run_valid_query_in_invalid_tx": - "Driver allows to run queries in broken transaction", "stub.routing.test_routing_v4x1.RoutingV4x1.test_should_retry_write_until_success_with_leader_change_using_tx_function": "Driver closes connection to router if DNS resolved name not in routing table", "stub.routing.test_routing_v3.RoutingV3.test_should_retry_write_until_success_with_leader_change_using_tx_function":