From a815986e6733469a814c6784329b9d7f37cf214e Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 4 Jul 2024 17:35:24 +0200 Subject: [PATCH 01/23] First API draft * Use properties and setters for attributes on errors * Add deprecation and preview warnings * Adjust TestKit backend --- src/neo4j/_async/io/_bolt.py | 5 +- src/neo4j/_async/io/_bolt5.py | 98 ++++- src/neo4j/_async/io/_common.py | 15 +- src/neo4j/_sync/io/_bolt.py | 5 +- src/neo4j/_sync/io/_bolt5.py | 98 ++++- src/neo4j/_sync/io/_common.py | 15 +- src/neo4j/exceptions.py | 503 ++++++++++++++++++++++-- testkitbackend/_async/backend.py | 36 +- testkitbackend/_sync/backend.py | 36 +- testkitbackend/totestkit.py | 49 +++ tests/unit/async_/io/test_class_bolt.py | 12 +- tests/unit/common/work/test_summary.py | 2 + tests/unit/sync/io/test_class_bolt.py | 12 +- 13 files changed, 764 insertions(+), 122 deletions(-) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 87c784569..fc1edbe75 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -458,7 +458,10 @@ async def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 6): + if protocol_version == (5, 7): + from ._bolt5 import AsyncBolt5x7 + bolt_cls = AsyncBolt5x7 + elif protocol_version == (5, 6): from ._bolt5 import AsyncBolt5x6 bolt_cls = AsyncBolt5x6 elif protocol_version == (5, 5): diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 370319712..b50513217 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -982,7 +982,7 @@ def begin( dehydration_hooks=dehydration_hooks, ) - DEFAULT_DIAGNOSTIC_RECORD = ( + DEFAULT_STATUS_DIAGNOSTIC_RECORD = ( ("OPERATION", ""), ("OPERATION_CODE", "0"), ("CURRENT_SCHEMA", "/"), @@ -1062,15 +1062,107 @@ def enrich(metadata_): if not isinstance(diag_record, dict): log.info( "[#%04X] _: Server supplied an " - "invalid diagnostic record (%r).", + "invalid status diagnostic record (%r).", self.local_port, diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) await AsyncUtil.callback(wrapped_handler, metadata) return handler + + +class AsyncBolt5x7(AsyncBolt5x6): + PROTOCOL_VERSION = Version(5, 7) + + DEFAULT_ERROR_DIAGNOSTIC_RECORD = ( + AsyncBolt5x5.DEFAULT_STATUS_DIAGNOSTIC_RECORD + ) + + def _enrich_error_diagnostic_record(self, metadata): + if not isinstance(metadata, dict): + return + diag_record = metadata.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info( + "[#%04X] _: Server supplied an " + "invalid error diagnostic record (%r).", + self.local_port, + diag_record, + ) + return + for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + + async def _process_message(self, tag, fields): + """Process at most one message from the server, if available. + + :returns: 2-tuple of number of detail messages and number of summary + messages fetched + """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + + if details: + # Do not log any data + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) + await self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug( + "[#%04X] S: SUCCESS %r", self.local_port, summary_metadata + ) + self._server_state_manager.transition( + response.message, summary_metadata + ) + await response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7e": + log.debug("[#%04X] S: IGNORED", self.local_port) + await response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7f": + log.debug( + "[#%04X] S: FAILURE %r", self.local_port, summary_metadata + ) + self._server_state_manager.state = self.bolt_states.FAILED + self._enrich_error_diagnostic_record(summary_metadata) + try: + await response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + await self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + await self.pool.on_write_failure( + address=self.unresolved_address, + database=self.last_database, + ) + raise + except Neo4jError as e: + if self.pool: + await self.pool.on_neo4j_error(e, self) + raise + else: + sig_int = ord(summary_signature) + raise BoltProtocolError( + f"Unexpected response message with signature {sig_int:02X}", + self.unresolved_address, + ) + + return len(details), 1 diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index 2aa55f616..53c687acf 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -21,6 +21,7 @@ from ..._async_compat.util import AsyncUtil from ..._exceptions import SocketDeadlineExceededError +from ...api import Version from ...exceptions import ( Neo4jError, ServiceUnavailable, @@ -29,6 +30,8 @@ ) +GQL_ERROR_AWARE_PROTOCOL = Version(5, 7) + log = logging.getLogger("neo4j.io") @@ -248,7 +251,7 @@ async def on_failure(self, metadata): await AsyncUtil.callback(handler, metadata) handler = self.handlers.get("on_summary") await AsyncUtil.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) async def on_ignored(self, metadata=None): """Handle an IGNORED message been received.""" @@ -257,6 +260,12 @@ async def on_ignored(self, metadata=None): handler = self.handlers.get("on_summary") await AsyncUtil.callback(handler) + def _hydrate_error(self, metadata): + if self.connection.PROTOCOL_VERSION >= GQL_ERROR_AWARE_PROTOCOL: + return Neo4jError._hydrate_gql(**metadata) + else: + return Neo4jError.hydrate(**metadata) + class InitResponse(Response): async def on_failure(self, metadata): @@ -271,7 +280,7 @@ async def on_failure(self, metadata): "message", "Connection initialisation failed due to an unknown error", ) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class LogonResponse(InitResponse): @@ -283,7 +292,7 @@ async def on_failure(self, metadata): await AsyncUtil.callback(handler, metadata) handler = self.handlers.get("on_summary") await AsyncUtil.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class ResetResponse(Response): diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 3aa2f020d..1c0e8bf8c 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -458,7 +458,10 @@ def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 6): + if protocol_version == (5, 7): + from ._bolt5 import Bolt5x7 + bolt_cls = Bolt5x7 + elif protocol_version == (5, 6): from ._bolt5 import Bolt5x6 bolt_cls = Bolt5x6 elif protocol_version == (5, 5): diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 7692d1772..52a19dece 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -982,7 +982,7 @@ def begin( dehydration_hooks=dehydration_hooks, ) - DEFAULT_DIAGNOSTIC_RECORD = ( + DEFAULT_STATUS_DIAGNOSTIC_RECORD = ( ("OPERATION", ""), ("OPERATION_CODE", "0"), ("CURRENT_SCHEMA", "/"), @@ -1062,15 +1062,107 @@ def enrich(metadata_): if not isinstance(diag_record, dict): log.info( "[#%04X] _: Server supplied an " - "invalid diagnostic record (%r).", + "invalid status diagnostic record (%r).", self.local_port, diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) Util.callback(wrapped_handler, metadata) return handler + + +class Bolt5x7(Bolt5x6): + PROTOCOL_VERSION = Version(5, 7) + + DEFAULT_ERROR_DIAGNOSTIC_RECORD = ( + Bolt5x5.DEFAULT_STATUS_DIAGNOSTIC_RECORD + ) + + def _enrich_error_diagnostic_record(self, metadata): + if not isinstance(metadata, dict): + return + diag_record = metadata.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info( + "[#%04X] _: Server supplied an " + "invalid error diagnostic record (%r).", + self.local_port, + diag_record, + ) + return + for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + + def _process_message(self, tag, fields): + """Process at most one message from the server, if available. + + :returns: 2-tuple of number of detail messages and number of summary + messages fetched + """ + details = [] + summary_signature = summary_metadata = None + if tag == b"\x71": # RECORD + details = fields + elif fields: + summary_signature = tag + summary_metadata = fields[0] + else: + summary_signature = tag + + if details: + # Do not log any data + log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) + self.responses[0].on_records(details) + + if summary_signature is None: + return len(details), 0 + + response = self.responses.popleft() + response.complete = True + if summary_signature == b"\x70": + log.debug( + "[#%04X] S: SUCCESS %r", self.local_port, summary_metadata + ) + self._server_state_manager.transition( + response.message, summary_metadata + ) + response.on_success(summary_metadata or {}) + elif summary_signature == b"\x7e": + log.debug("[#%04X] S: IGNORED", self.local_port) + response.on_ignored(summary_metadata or {}) + elif summary_signature == b"\x7f": + log.debug( + "[#%04X] S: FAILURE %r", self.local_port, summary_metadata + ) + self._server_state_manager.state = self.bolt_states.FAILED + self._enrich_error_diagnostic_record(summary_metadata) + try: + response.on_failure(summary_metadata or {}) + except (ServiceUnavailable, DatabaseUnavailable): + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + raise + except (NotALeader, ForbiddenOnReadOnlyDatabase): + if self.pool: + self.pool.on_write_failure( + address=self.unresolved_address, + database=self.last_database, + ) + raise + except Neo4jError as e: + if self.pool: + self.pool.on_neo4j_error(e, self) + raise + else: + sig_int = ord(summary_signature) + raise BoltProtocolError( + f"Unexpected response message with signature {sig_int:02X}", + self.unresolved_address, + ) + + return len(details), 1 diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index 09fba0d8d..f5a440a66 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -21,6 +21,7 @@ from ..._async_compat.util import Util from ..._exceptions import SocketDeadlineExceededError +from ...api import Version from ...exceptions import ( Neo4jError, ServiceUnavailable, @@ -29,6 +30,8 @@ ) +GQL_ERROR_AWARE_PROTOCOL = Version(5, 7) + log = logging.getLogger("neo4j.io") @@ -248,7 +251,7 @@ def on_failure(self, metadata): Util.callback(handler, metadata) handler = self.handlers.get("on_summary") Util.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) def on_ignored(self, metadata=None): """Handle an IGNORED message been received.""" @@ -257,6 +260,12 @@ def on_ignored(self, metadata=None): handler = self.handlers.get("on_summary") Util.callback(handler) + def _hydrate_error(self, metadata): + if self.connection.PROTOCOL_VERSION >= GQL_ERROR_AWARE_PROTOCOL: + return Neo4jError._hydrate_gql(**metadata) + else: + return Neo4jError.hydrate(**metadata) + class InitResponse(Response): def on_failure(self, metadata): @@ -271,7 +280,7 @@ def on_failure(self, metadata): "message", "Connection initialisation failed due to an unknown error", ) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class LogonResponse(InitResponse): @@ -283,7 +292,7 @@ def on_failure(self, metadata): Util.callback(handler, metadata) handler = self.handlers.get("on_summary") Util.callback(handler) - raise Neo4jError.hydrate(**metadata) + raise self._hydrate_error(metadata) class ResetResponse(Response): diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index a317ea097..a47d5a4bf 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -59,9 +59,50 @@ from __future__ import annotations +import sys import typing as t +from copy import deepcopy -from ._meta import deprecated +from ._meta import ( + deprecated, + preview, +) + + +__all__ = [ + "AuthConfigurationError", + "AuthError", + "BrokenRecordError", + "CertificateConfigurationError", + "ClientError", + "ConfigurationError", + "ConstraintError", + "CypherSyntaxError", + "CypherTypeError", + "DatabaseError", + "DatabaseUnavailable", + "DriverError", + "Forbidden", + "ForbiddenOnReadOnlyDatabase", + "IncompleteCommit", + "Neo4jError", + "NotALeader", + "ReadServiceUnavailable", + "ResultConsumedError", + "ResultError", + "ResultFailedError", + "ResultNotSingleError", + "RoutingServiceUnavailable", + "ServiceUnavailable", + "SessionError", + "SessionExpired", + "TokenExpired", + "TransactionError", + "TransactionNestingError", + "TransientError", + "UnsupportedServerProduct", + "WriteServiceUnavailable", +] if t.TYPE_CHECKING: @@ -168,57 +209,107 @@ } +_UNKNOWN_GQL_STATUS = "50N42" +_UNKNOWN_GQL_EXPLANATION = "general processing exception - unknown error" +_UNKNOWN_GQL_CLASSIFICATION = "UNKNOWN" # TODO: define final value +_UNKNOWN_GQL_DIAGNOSTIC_RECORD = ( + ( + "OPERATION", + "", + ), + ( + "OPERATION_CODE", + "0", + ), + ("CURRENT_SCHEMA", "/"), +) + + +def _make_gql_description(explanation: str, message: str | None = None) -> str: + if message is None: + return f"error: {explanation}" + + return f"error: {explanation}. {message}" + + # Neo4jError class Neo4jError(Exception): """Raised when the Cypher engine returns an error to the client.""" - #: (str or None) The error message returned by the server. - message = None - #: (str or None) The error code returned by the server. - #: There are many Neo4j status codes, see - #: `status codes `_. - code = None - classification = None - category = None - title = None + _neo4j_code: str + _message: str + _classification: str + _category: str + _title: str #: (dict) Any additional information returned by the server. - metadata = None + _metadata: dict[str, t.Any] + + _gql_status: str + _gql_explanation: str # internal use only + _gql_description: str + _gql_status_description: str + _gql_classification: str + _status_diagnostic_record: dict[str, t.Any] # original, internal only + _diagnostic_record: dict[str, t.Any] # copy to be used externally _retryable = False @classmethod - def hydrate( + def _hydrate( cls, + *, + neo4j_code: str | None = None, message: str | None = None, - code: str | None = None, - **metadata: t.Any, - ) -> Neo4jError: + gql_status: str | None = None, + explanation: str | None = None, + diagnostic_record: dict[str, t.Any] | None = None, + cause: Neo4jError | None = None, + ) -> te.Self: + neo4j_code = neo4j_code or "Neo.DatabaseError.General.UnknownError" message = message or "An unknown error occurred" - code = code or "Neo.DatabaseError.General.UnknownError" + try: - _, classification, category, title = code.split(".") + _, classification, category, title = neo4j_code.split(".") except ValueError: classification = CLASSIFICATION_DATABASE category = "General" title = "UnknownError" else: classification_override, code_override = ERROR_REWRITE_MAP.get( - code, (None, None) + neo4j_code, (None, None) ) if classification_override is not None: classification = classification_override if code_override is not None: - code = code_override + neo4j_code = code_override - error_class = cls._extract_error_class(classification, code) + error_class = cls._extract_error_class(classification, neo4j_code) + + if explanation is not None: + gql_description = _make_gql_description(explanation, message) + inst = error_class(gql_description) + inst._gql_description = gql_description + else: + inst = error_class(message) + + inst._neo4j_code = neo4j_code + inst._message = message + inst._classification = classification + inst._category = category + inst._title = title + if gql_status: + inst._gql_status = gql_status + if explanation: + inst._gql_explanation = explanation + if diagnostic_record is not None: + inst._status_diagnostic_record = diagnostic_record + if cause: + inst.__cause__ = cause + else: + current_exc = sys.exc_info()[1] + if current_exc is not None: + inst.__context__ = current_exc - inst = error_class(message) - inst.message = message - inst.code = code - inst.classification = classification - inst.category = category - inst.title = title - inst.metadata = metadata return inst @classmethod @@ -241,6 +332,257 @@ def _extract_error_class(cls, classification, code): else: return cls + @classmethod + def hydrate( + cls, + code: str | None = None, + message: str | None = None, + **metadata: t.Any, + ) -> te.Self: + inst = cls._hydrate(neo4j_code=code, message=message) + inst._metadata = metadata + return inst + + @classmethod + def _hydrate_gql(cls, **metadata: t.Any) -> te.Self: + gql_status = metadata.pop("gql_status", None) + if not isinstance(gql_status, str): + gql_status = None + status_explanation = metadata.pop("status_explanation", None) + if not isinstance(status_explanation, str): + status_explanation = None + message = metadata.pop("status_message", None) + if not isinstance(message, str): + message = None + neo4j_code = metadata.pop("neo4j_code", None) + if not isinstance(neo4j_code, str): + neo4j_code = None + diagnostic_record = metadata.pop("diagnostic_record", None) + if not isinstance(diagnostic_record, dict): + diagnostic_record = None + cause = metadata.pop("cause", None) + if not isinstance(cause, dict): + cause = None + else: + cause = cls._hydrate_gql(**cause) + + inst = cls._hydrate( + neo4j_code=neo4j_code, + message=message, + gql_status=gql_status, + explanation=status_explanation, + diagnostic_record=diagnostic_record, + cause=cause, + ) + inst._metadata = metadata + + return inst + + @property + def message(self) -> str: + """ + TODO. + + #: (str or None) The error message returned by the server. + """ + return self._message + + @message.setter + @deprecated("Altering the message of a Neo4jError is deprecated.") + def message(self, value: str) -> None: + self._message = value + + # TODO: 6.0 - Remove this alias + @property + @deprecated( + "The code of a Neo4jError is deprecated. Use neo4j_code instead." + ) + def code(self) -> str: + """ + The neo4j error code returned by the server. + + .. deprecated:: 5.xx + Use :attr:`.neo4j_code` instead. + """ + return self._neo4j_code + + # TODO: 6.0 - Remove this and all other deprecated setters + @code.setter + @deprecated("Altering the code of a Neo4jError is deprecated.") + def code(self, value: str) -> None: + self._neo4j_code = value + + @property + def neo4j_code(self) -> str: + """ + The error code returned by the server. + + There are many Neo4j status codes, see + `status codes `_. + + .. versionadded: 5.xx + """ + return self._neo4j_code + + @property + @deprecated("classification of Neo4jError is deprecated.") + def classification(self) -> str: + # Undocumented, has been there before + # TODO 6.0: Remove this property + return self._classification + + @classification.setter + @deprecated("classification of Neo4jError is deprecated.") + def classification(self, value: str) -> None: + self._classification = value + + @property + @deprecated("category of Neo4jError is deprecated.") + def category(self) -> str: + # Undocumented, has been there before + # TODO 6.0: Remove this property + return self._category + + @category.setter + @deprecated("category of Neo4jError is deprecated.") + def category(self, value: str) -> None: + self._category = value + + @property + @deprecated("title of Neo4jError is deprecated.") + def title(self) -> str: + # Undocumented, has been there before + # TODO 6.0: Remove this property + return self._title + + @title.setter + @deprecated("title of Neo4jError is deprecated.") + def title(self, value: str) -> None: + self._title = value + + @property + def metadata(self) -> dict[str, t.Any]: + # Undocumented, might be useful for debugging + return self._metadata + + @metadata.setter + @deprecated("Altering the metadata of Neo4jError is deprecated.") + def metadata(self, value: dict[str, t.Any]) -> None: + # TODO 6.0: Remove this property + self._metadata = value + + @property + @preview("GQLSTATUS support is a preview feature.") + def gql_status(self) -> str: + """ + The GQLSTATUS returned from the server. + + The status code ``50N42`` (unknown error) is a special code that the + driver will use for polyfilling (when connected to an old, + non-GQL-aware server). + Further, it may be used by servers during the transition-phase to + GQLSTATUS-awareness. + + This means this code is not guaranteed to be stable and may change in + future versions. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + if hasattr(self, "_gql_status"): + return self._gql_status + + self._gql_status = _UNKNOWN_GQL_STATUS + return self._gql_status + + def _get_explanation(self) -> str: + if hasattr(self, "_gql_explanation"): + return self._gql_explanation + + self._gql_explanation = _UNKNOWN_GQL_EXPLANATION + return self._gql_explanation + + @property + @preview("GQLSTATUS support is a preview feature.") + def gql_status_description(self) -> str: + """ + A description of the GQLSTATUS returned from the server. + + This description is meant for human consumption and debugging purposes. + Don't rely on it in a programmatic way. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + if hasattr(self, "_gql_status_description"): + return self._gql_status_description + + self._gql_status_description = _make_gql_description( + self._get_explanation(), self._message + ) + return self._gql_status_description + + @property + @preview("GQLSTATUS support is a preview feature.") + def gql_classification(self) -> str: + """ + TODO. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + # TODO + if hasattr(self, "_gql_classification"): + return self._gql_classification + + diag_record = self._get_status_diagnostic_record() + classification = diag_record.get("_classification") + if not isinstance(classification, str): + self._classification = _UNKNOWN_GQL_CLASSIFICATION + else: + self._classification = classification + return self._classification + + def _get_status_diagnostic_record(self) -> dict[str, t.Any]: + if hasattr(self, "_status_diagnostic_record"): + return self._status_diagnostic_record + + self._status_diagnostic_record = dict(_UNKNOWN_GQL_DIAGNOSTIC_RECORD) + return self._status_diagnostic_record + + @property + @preview("GQLSTATUS support is a preview feature.") + def diagnostic_record(self) -> dict[str, t.Any]: + """ + Further information about the GQLSTATUS for diagnostic purposes. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + if hasattr(self, "_diagnostic_record"): + return self._diagnostic_record + + self._diagnostic_record = deepcopy( + self._get_status_diagnostic_record() + ) + return self._diagnostic_record + # TODO: 6.0 - Remove this alias @deprecated( "Neo4jError.is_retriable is deprecated and will be removed in a " @@ -277,7 +619,9 @@ def is_retryable(self) -> bool: return self._retryable def _unauthenticates_all_connections(self) -> bool: - return self.code == "Neo.ClientError.Security.AuthorizationExpired" + return ( + self._neo4j_code == "Neo.ClientError.Security.AuthorizationExpired" + ) # TODO: 6.0 - Remove this alias invalidates_all_connections = deprecated( @@ -289,9 +633,10 @@ def _unauthenticates_all_connections(self) -> bool: def _is_fatal_during_discovery(self) -> bool: # checks if the code is an error that is caused by the client. In this # case the driver should fail fast during discovery. - if not isinstance(self.code, str): + code = self._neo4j_code + if not isinstance(code, str): return False - if self.code in { + if code in { "Neo.ClientError.Database.DatabaseNotFound", "Neo.ClientError.Transaction.InvalidBookmark", "Neo.ClientError.Transaction.InvalidBookmarkMixture", @@ -301,14 +646,14 @@ def _is_fatal_during_discovery(self) -> bool: }: return True return ( - self.code.startswith("Neo.ClientError.Security.") - and self.code != "Neo.ClientError.Security.AuthorizationExpired" + code.startswith("Neo.ClientError.Security.") + and code != "Neo.ClientError.Security.AuthorizationExpired" ) def _has_security_code(self) -> bool: - if self.code is None: + if self._neo4j_code is None: return False - return self.code.startswith("Neo.ClientError.Security.") + return self._neo4j_code.startswith("Neo.ClientError.Security.") # TODO: 6.0 - Remove this alias is_fatal_during_discovery = deprecated( @@ -318,8 +663,18 @@ def _has_security_code(self) -> bool: )(_is_fatal_during_discovery) def __str__(self): - if self.code or self.message: - return f"{{code: {self.code}}} {{message: {self.message}}}" + code = self._neo4j_code + message = self._message + if code or message: + return f"{{neo4j_code: {code}}} {{message: {message}}}" + # TODO: 6.0 - User gql status and status_description instead + # something like: + # return ( + # f"{{gql_status: {self.gql_status}}} " + # f"{{neo4j_code: {self.neo4j_code}}} " + # f"{{gql_status_description: {self.gql_status_description}}} " + # f"{{diagnostic_record: {self.diagnostic_record}}}" + # ) return super().__str__() @@ -436,6 +791,9 @@ class ForbiddenOnReadOnlyDatabase(TransientError): class DriverError(Exception): """Raised when the Driver raises an error.""" + _diagnostic_record: dict[str, t.Any] + _gql_description: str + def is_retryable(self) -> bool: """ Whether the error is retryable. @@ -451,6 +809,79 @@ def is_retryable(self) -> bool: """ return False + @property + @preview("GQLSTATUS support is a preview feature.") + def gql_status(self) -> str: + """ + The GQLSTATUS of this error. + + .. seealso:: :attr:`.Neo4jError.gql_status` + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + return _UNKNOWN_GQL_STATUS + + @property + @preview("GQLSTATUS support is a preview feature.") + def gql_status_description(self) -> str: + """ + A description of the GQLSTATUS. + + This description is meant for human consumption and debugging purposes. + Don't rely on it in a programmatic way. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + if hasattr(self, "_gql_description"): + return self._gql_description + + self._gql_description = _make_gql_description(_UNKNOWN_GQL_EXPLANATION) + return self._gql_description + + @property + @preview("GQLSTATUS support is a preview feature.") + def gql_classification(self) -> str: + """ + TODO. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + return _UNKNOWN_GQL_CLASSIFICATION + + @property + @preview("GQLSTATUS support is a preview feature.") + def diagnostic_record(self) -> dict[str, t.Any]: + """ + Further information about the GQLSTATUS for diagnostic purposes. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ + if hasattr(self, "_diagnostic_record"): + return self._diagnostic_record + + self._diagnostic_record = dict(_UNKNOWN_GQL_DIAGNOSTIC_RECORD) + return self._diagnostic_record + # DriverError > SessionError class SessionError(DriverError): diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index a66184953..dcb73a169 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -34,6 +34,7 @@ UnsupportedServerProduct, ) +from .. import totestkit from .._driver_logger import ( buffer_handler, log, @@ -133,43 +134,16 @@ def _exc_stems_from_driver(exc): return True return None - @staticmethod - def _exc_msg(exc, max_depth=10): - if isinstance(exc, Neo4jError) and exc.message is not None: - return str(exc.message) - - depth = 0 - res = str(exc) - while getattr(exc, "__cause__", None) is not None: - depth += 1 - if depth >= max_depth: - break - res += f"\nCaused by: {exc.__cause__!r}" - exc = exc.__cause__ - return res - async def write_driver_exc(self, exc): - log.debug(traceback.format_exc()) + log.debug(exc.args) + log.debug("".join(traceback.format_exception(exc))) key = self.next_key() self.errors[key] = exc - payload = {"id": key, "msg": ""} - - if isinstance(exc, MarkdAsDriverError): - wrapped_exc = exc.wrapped_exc - payload["errorType"] = str(type(wrapped_exc)) - if wrapped_exc.args: - payload["msg"] = self._exc_msg(wrapped_exc.args[0]) - payload["retryable"] = False - else: - payload["errorType"] = str(type(exc)) - payload["msg"] = self._exc_msg(exc) - if isinstance(exc, Neo4jError): - payload["code"] = exc.code - payload["retryable"] = getattr(exc, "is_retryable", bool)() + data = totestkit.driver_exc(exc, id_=key) - await self.send_response("DriverError", payload) + await self.send_response(data["name"], data["data"]) async def _process(self, request): # Process a received request. diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index a3c23891d..ee8a3edc5 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -34,6 +34,7 @@ UnsupportedServerProduct, ) +from .. import totestkit from .._driver_logger import ( buffer_handler, log, @@ -133,43 +134,16 @@ def _exc_stems_from_driver(exc): return True return None - @staticmethod - def _exc_msg(exc, max_depth=10): - if isinstance(exc, Neo4jError) and exc.message is not None: - return str(exc.message) - - depth = 0 - res = str(exc) - while getattr(exc, "__cause__", None) is not None: - depth += 1 - if depth >= max_depth: - break - res += f"\nCaused by: {exc.__cause__!r}" - exc = exc.__cause__ - return res - def write_driver_exc(self, exc): - log.debug(traceback.format_exc()) + log.debug(exc.args) + log.debug("".join(traceback.format_exception(exc))) key = self.next_key() self.errors[key] = exc - payload = {"id": key, "msg": ""} - - if isinstance(exc, MarkdAsDriverError): - wrapped_exc = exc.wrapped_exc - payload["errorType"] = str(type(wrapped_exc)) - if wrapped_exc.args: - payload["msg"] = self._exc_msg(wrapped_exc.args[0]) - payload["retryable"] = False - else: - payload["errorType"] = str(type(exc)) - payload["msg"] = self._exc_msg(exc) - if isinstance(exc, Neo4jError): - payload["code"] = exc.code - payload["retryable"] = getattr(exc, "is_retryable", bool)() + data = totestkit.driver_exc(exc, id_=key) - self.send_response("DriverError", payload) + self.send_response(data["name"], data["data"]) def _process(self, request): # Process a received request. diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index f6960ce61..a3e6c158b 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -19,6 +19,7 @@ import math import neo4j +from neo4j.exceptions import Neo4jError from neo4j.graph import ( Node, Path, @@ -36,6 +37,7 @@ ) from ._warning_check import warning_check +from .exceptions import MarkdAsDriverError def record(rec): @@ -293,3 +295,50 @@ def to(name, val): def auth_token(auth): return {"name": "AuthorizationToken", "data": vars(auth)} + + +def _exc_msg(exc, max_depth=10): + if isinstance(exc, Neo4jError) and exc.message is not None: + return str(exc.message) + + depth = 0 + res = str(exc) + while getattr(exc, "__cause__", None) is not None: + depth += 1 + if depth >= max_depth: + break + res += f"\nCaused by: {exc.__cause__!r}" + exc = exc.__cause__ + return res + + +def driver_exc(exc, id_=None): + payload = {"msg": ""} + if id_ is not None: + payload["id"] = id_ + payload["retryable"] = getattr(exc, "is_retryable", bool)() + if isinstance(exc, MarkdAsDriverError): + wrapped_exc = exc.wrapped_exc + payload["errorType"] = str(type(wrapped_exc)) + if wrapped_exc.args: + payload["msg"] = _exc_msg(wrapped_exc.args[0]) + else: + payload["errorType"] = str(type(exc)) + payload["msg"] = _exc_msg(exc) + if isinstance(exc, Neo4jError): + payload["code"] = exc.neo4j_code + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["gqlStatus"] = exc.gql_status + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["statusDescription"] = exc.gql_status_description + if exc.__cause__ is not None: + payload["cause"] = driver_exc(exc.__cause__) + + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["diagnosticRecord"] = { + k: field(v) for k, v in exc.diagnostic_record.items() + } + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["classification"] = exc.gql_classification + + return {"name": "DriverError", "data": payload} diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 5bbc50e85..b0ddbc968 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), } # fmt: on @@ -68,7 +68,8 @@ def test_class_method_protocol_handlers(): ((5, 4), 1), ((5, 5), 1), ((5, 6), 1), - ((5, 7), 0), + ((5, 7), 1), + ((5, 8), 0), ((6, 0), 0), ], ) @@ -91,7 +92,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -141,6 +142,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 4), "neo4j._async.io._bolt5.AsyncBolt5x4"), ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), + ((5, 7), "neo4j._async.io._bolt5.AsyncBolt5x7"), ), ) @mark_async_test @@ -179,7 +181,7 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 7), + (5, 8), (6, 0), ), ) @@ -187,7 +189,7 @@ async def test_version_negotiation( async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" ) address = ("localhost", 7687) diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 58adc9c96..74f2059a3 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -889,6 +889,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 4), "t_first"), ((5, 5), "t_first"), ((5, 6), "t_first"), + ((5, 7), "t_first"), ), ) def test_summary_result_available_after( @@ -925,6 +926,7 @@ def test_summary_result_available_after( ((5, 4), "t_last"), ((5, 5), "t_last"), ((5, 6), "t_last"), + ((5, 7), "t_last"), ), ) def test_summary_result_consumed_after( diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f82d441dd..f3b063037 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), } # fmt: on @@ -68,7 +68,8 @@ def test_class_method_protocol_handlers(): ((5, 4), 1), ((5, 5), 1), ((5, 6), 1), - ((5, 7), 0), + ((5, 7), 1), + ((5, 8), 0), ((6, 0), 0), ], ) @@ -91,7 +92,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -141,6 +142,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 4), "neo4j._sync.io._bolt5.Bolt5x4"), ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), + ((5, 7), "neo4j._sync.io._bolt5.Bolt5x7"), ), ) @mark_sync_test @@ -179,7 +181,7 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 7), + (5, 8), (6, 0), ), ) @@ -187,7 +189,7 @@ def test_version_negotiation( def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7')" ) address = ("localhost", 7687) From 3a4f3d000bf046127dcf668e72ed97785e27d150 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 25 Sep 2024 10:50:45 +0200 Subject: [PATCH 02/23] Update to reflect design changes --- CHANGELOG.md | 2 +- src/neo4j/_async/io/_bolt.py | 4 + src/neo4j/_async/io/_bolt5.py | 3 +- src/neo4j/_async/io/_common.py | 2 +- src/neo4j/_async/io/_pool.py | 3 + src/neo4j/_async/work/session.py | 1 + src/neo4j/_sync/io/_bolt.py | 4 + src/neo4j/_sync/io/_bolt5.py | 3 +- src/neo4j/_sync/io/_common.py | 2 +- src/neo4j/_sync/io/_pool.py | 3 + src/neo4j/_sync/work/session.py | 1 + src/neo4j/exceptions.py | 815 ++++++++++------- testkitbackend/_async/backend.py | 49 +- testkitbackend/_sync/backend.py | 49 +- testkitbackend/test_config.json | 1 + testkitbackend/totestkit.py | 81 +- tests/iter_util.py | 18 +- tests/unit/async_/fixtures/fake_connection.py | 2 +- tests/unit/async_/io/test_class_bolt3.py | 2 +- tests/unit/async_/io/test_class_bolt4x0.py | 2 +- tests/unit/async_/io/test_class_bolt4x1.py | 2 +- tests/unit/async_/io/test_class_bolt4x2.py | 2 +- tests/unit/async_/io/test_class_bolt4x3.py | 2 +- tests/unit/async_/io/test_class_bolt4x4.py | 2 +- tests/unit/async_/io/test_class_bolt5x0.py | 2 +- tests/unit/async_/io/test_class_bolt5x1.py | 2 +- tests/unit/async_/io/test_class_bolt5x2.py | 2 +- tests/unit/async_/io/test_class_bolt5x3.py | 2 +- tests/unit/async_/io/test_class_bolt5x4.py | 2 +- tests/unit/async_/io/test_class_bolt5x5.py | 2 +- tests/unit/async_/io/test_class_bolt5x6.py | 2 +- tests/unit/async_/io/test_class_bolt5x7.py | 851 ++++++++++++++++++ tests/unit/async_/io/test_neo4j_pool.py | 38 +- tests/unit/async_/test_auth_management.py | 2 +- tests/unit/async_/work/test_transaction.py | 7 +- tests/unit/common/test_exceptions.py | 20 +- tests/unit/common/test_record.py | 1 + tests/unit/sync/fixtures/fake_connection.py | 2 +- tests/unit/sync/io/test_class_bolt3.py | 2 +- tests/unit/sync/io/test_class_bolt4x0.py | 2 +- tests/unit/sync/io/test_class_bolt4x1.py | 2 +- tests/unit/sync/io/test_class_bolt4x2.py | 2 +- tests/unit/sync/io/test_class_bolt4x3.py | 2 +- tests/unit/sync/io/test_class_bolt4x4.py | 2 +- tests/unit/sync/io/test_class_bolt5x0.py | 2 +- tests/unit/sync/io/test_class_bolt5x1.py | 2 +- tests/unit/sync/io/test_class_bolt5x2.py | 2 +- tests/unit/sync/io/test_class_bolt5x3.py | 2 +- tests/unit/sync/io/test_class_bolt5x4.py | 2 +- tests/unit/sync/io/test_class_bolt5x5.py | 2 +- tests/unit/sync/io/test_class_bolt5x6.py | 2 +- tests/unit/sync/io/test_class_bolt5x7.py | 851 ++++++++++++++++++ tests/unit/sync/io/test_neo4j_pool.py | 38 +- tests/unit/sync/test_auth_management.py | 2 +- tests/unit/sync/work/test_transaction.py | 7 +- 55 files changed, 2439 insertions(+), 475 deletions(-) create mode 100644 tests/unit/async_/io/test_class_bolt5x7.py create mode 100644 tests/unit/sync/io/test_class_bolt5x7.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 569cb6755..d47af7a0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE -- No breaking or major changes. +- TODO: list the deprecations (making error properties read-only) ## Version 5.24 diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index fc1edbe75..002c046ee 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -217,6 +217,7 @@ def _to_auth_dict(cls, auth): try: return vars(auth) except (KeyError, TypeError) as e: + # TODO: 6.0 - change this to be a DriverError (or subclass) raise AuthError( f"Cannot determine auth details from {auth!r}" ) from e @@ -306,6 +307,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x4, AsyncBolt5x5, AsyncBolt5x6, + AsyncBolt5x7, ) handlers = { @@ -322,6 +324,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x4.PROTOCOL_VERSION: AsyncBolt5x4, AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, + AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7, } if protocol_version is None: @@ -509,6 +512,7 @@ async def open( await AsyncBoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() + # TODO: 6.0 - raise public DriverError subclass instead raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index b50513217..3932e9b5c 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -1009,7 +1009,7 @@ def enrich(metadata_): diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) @@ -1097,6 +1097,7 @@ def _enrich_error_diagnostic_record(self, metadata): return for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) + self._enrich_error_diagnostic_record(metadata.get("cause")) async def _process_message(self, tag, fields): """Process at most one message from the server, if available. diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index 53c687acf..2b733507c 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -264,7 +264,7 @@ def _hydrate_error(self, metadata): if self.connection.PROTOCOL_VERSION >= GQL_ERROR_AWARE_PROTOCOL: return Neo4jError._hydrate_gql(**metadata) else: - return Neo4jError.hydrate(**metadata) + return Neo4jError._hydrate_neo4j(**metadata) class InitResponse(Response): diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 0653ff71b..7a520abe7 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -341,6 +341,7 @@ async def health_check(connection_, deadline_): or not await self.cond.wait(timeout) ): log.debug("[#0000] _: acquisition timed out") + # TODO: 6.0 - change this to be a DriverError (or subclass) raise ClientError( "failed to obtain a connection from the pool within " f"{deadline.original_timeout!r}s (timeout)" @@ -1055,8 +1056,10 @@ async def acquire( liveness_check_timeout, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: + # TODO: 6.0 - change this to be a ValueError raise ClientError(f"Non valid 'access_mode'; {access_mode}") if not timeout: + # TODO: 6.0 - change this to be a ValueError raise ClientError( f"'timeout' must be a float larger than 0; {timeout}" ) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 4c70c760b..3aa4966c9 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -294,6 +294,7 @@ async def run( raise TypeError("query must be a string or a Query instance") if self._transaction: + # TODO: 6.0 - change this to be a TransactionError raise ClientError( "Explicit Transaction must be handled explicitly" ) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 1c0e8bf8c..4503c6184 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -217,6 +217,7 @@ def _to_auth_dict(cls, auth): try: return vars(auth) except (KeyError, TypeError) as e: + # TODO: 6.0 - change this to be a DriverError (or subclass) raise AuthError( f"Cannot determine auth details from {auth!r}" ) from e @@ -306,6 +307,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x4, Bolt5x5, Bolt5x6, + Bolt5x7, ) handlers = { @@ -322,6 +324,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x4.PROTOCOL_VERSION: Bolt5x4, Bolt5x5.PROTOCOL_VERSION: Bolt5x5, Bolt5x6.PROTOCOL_VERSION: Bolt5x6, + Bolt5x7.PROTOCOL_VERSION: Bolt5x7, } if protocol_version is None: @@ -509,6 +512,7 @@ def open( BoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() + # TODO: 6.0 - raise public DriverError subclass instead raise BoltHandshakeError( "The neo4j server does not support communication with this " "driver. This driver has support for Bolt protocols " diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 52a19dece..e20a1f62a 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -1009,7 +1009,7 @@ def enrich(metadata_): diag_record, ) continue - for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + for key, value in self.DEFAULT_STATUS_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) enrich(metadata) @@ -1097,6 +1097,7 @@ def _enrich_error_diagnostic_record(self, metadata): return for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: diag_record.setdefault(key, value) + self._enrich_error_diagnostic_record(metadata.get("cause")) def _process_message(self, tag, fields): """Process at most one message from the server, if available. diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index f5a440a66..a48701751 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -264,7 +264,7 @@ def _hydrate_error(self, metadata): if self.connection.PROTOCOL_VERSION >= GQL_ERROR_AWARE_PROTOCOL: return Neo4jError._hydrate_gql(**metadata) else: - return Neo4jError.hydrate(**metadata) + return Neo4jError._hydrate_neo4j(**metadata) class InitResponse(Response): diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 94fd1d06d..1570e745c 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -338,6 +338,7 @@ def health_check(connection_, deadline_): or not self.cond.wait(timeout) ): log.debug("[#0000] _: acquisition timed out") + # TODO: 6.0 - change this to be a DriverError (or subclass) raise ClientError( "failed to obtain a connection from the pool within " f"{deadline.original_timeout!r}s (timeout)" @@ -1052,8 +1053,10 @@ def acquire( liveness_check_timeout, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: + # TODO: 6.0 - change this to be a ValueError raise ClientError(f"Non valid 'access_mode'; {access_mode}") if not timeout: + # TODO: 6.0 - change this to be a ValueError raise ClientError( f"'timeout' must be a float larger than 0; {timeout}" ) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index cee241c05..1bc415147 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -294,6 +294,7 @@ def run( raise TypeError("query must be a string or a Query instance") if self._transaction: + # TODO: 6.0 - change this to be a TransactionError raise ClientError( "Explicit Transaction must be handled explicitly" ) diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index a47d5a4bf..3e0b12570 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -59,13 +59,13 @@ from __future__ import annotations -import sys import typing as t -from copy import deepcopy +from copy import deepcopy as _deepcopy +from enum import Enum as _Enum from ._meta import ( deprecated, - preview, + preview as _preview, ) @@ -84,6 +84,8 @@ "DriverError", "Forbidden", "ForbiddenOnReadOnlyDatabase", + "GQLError", + "GQLErrorClassification", "IncompleteCommit", "Neo4jError", "NotALeader", @@ -106,6 +108,8 @@ if t.TYPE_CHECKING: + from collections.abc import Mapping + import typing_extensions as te from ._async.work import ( @@ -129,6 +133,7 @@ ] _TResult = t.Union[AsyncResult, Result] _TSession = t.Union[AsyncSession, Session] + _T = t.TypeVar("_T") else: _TTransaction = t.Union[ "AsyncManagedTransaction", @@ -209,65 +214,424 @@ } -_UNKNOWN_GQL_STATUS = "50N42" -_UNKNOWN_GQL_EXPLANATION = "general processing exception - unknown error" -_UNKNOWN_GQL_CLASSIFICATION = "UNKNOWN" # TODO: define final value -_UNKNOWN_GQL_DIAGNOSTIC_RECORD = ( - ( - "OPERATION", - "", - ), - ( - "OPERATION_CODE", - "0", - ), +_UNKNOWN_NEO4J_CODE: te.Final[str] = "Neo.DatabaseError.General.UnknownError" +# TODO: 6.0 - Make _UNKNOWN_GQL_MESSAGE the default message +_UNKNOWN_MESSAGE: te.Final[str] = "An unknown error occurred" +_UNKNOWN_GQL_STATUS: te.Final[str] = "50N42" +_UNKNOWN_GQL_DESCRIPTION: te.Final[str] = ( + "general processing exception - unknown error" +) +# FIXME: _UNKNOWN_GQL_MESSAGE needs final format +_UNKNOWN_GQL_MESSAGE: te.Final[str] = ( + f"{_UNKNOWN_GQL_STATUS}: {_UNKNOWN_GQL_DESCRIPTION}. {_UNKNOWN_MESSAGE}" +) +_UNKNOWN_GQL_DIAGNOSTIC_RECORD: te.Final[tuple[tuple[str, t.Any], ...]] = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), ("CURRENT_SCHEMA", "/"), ) -def _make_gql_description(explanation: str, message: str | None = None) -> str: - if message is None: - return f"error: {explanation}" +class GQLErrorClassification(str, _Enum): + """ + Server-side notification category. - return f"error: {explanation}. {message}" + Inherits from :class:`str` and :class:`enum.Enum`. + Hence, can also be compared to its string value:: + >>> GQLErrorClassification.CLIENT_ERROR == "CLIENT_ERROR" + True + >>> GQLErrorClassification.DATABASE_ERROR == "DATABASE_ERROR" + True + >>> GQLErrorClassification.TRANSIENT_ERROR == "TRANSIENT_ERROR" + True -# Neo4jError -class Neo4jError(Exception): - """Raised when the Cypher engine returns an error to the client.""" + .. seealso:: :attr:`.GQLError.classification` - _neo4j_code: str - _message: str - _classification: str - _category: str - _title: str - #: (dict) Any additional information returned by the server. - _metadata: dict[str, t.Any] + .. versionadded:: 5.xx + """ + + CLIENT_ERROR = "CLIENT_ERROR" + DATABASE_ERROR = "DATABASE_ERROR" + TRANSIENT_ERROR = "TRANSIENT_ERROR" + #: Used when the server provides a Classification which the driver is + #: unaware of. + #: This can happen when connecting to a server newer than the driver or + #: before GQL errors were introduced. + UNKNOWN = "UNKNOWN" + + +class GQLError(Exception): + """ + The GQL compliant data of an error. + + This error isn't raised by the driver as it. + Instead, only subclasses are raised. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded: 5.xx + """ _gql_status: str - _gql_explanation: str # internal use only - _gql_description: str + # TODO: 6.0 - make message always str + _message: str | None _gql_status_description: str - _gql_classification: str + _gql_raw_classification: str | None + _gql_classification: GQLErrorClassification _status_diagnostic_record: dict[str, t.Any] # original, internal only _diagnostic_record: dict[str, t.Any] # copy to be used externally + __cause__: GQLError | None + + @staticmethod + def _hydrate_cause(**metadata: t.Any) -> GQLError: + meta_extractor = _MetaExtractor(metadata) + gql_status = meta_extractor.str_value("gql_status") + description = meta_extractor.str_value("description") + message = meta_extractor.str_value("message") + diagnostic_record = meta_extractor.map_value("diagnostic_record") + cause_map = meta_extractor.map_value("cause") + if cause_map is not None: + cause = GQLError._hydrate_cause(**cause_map) + else: + cause = None + inst = GQLError() + inst._init_gql( + gql_status=gql_status, + message=message, + description=description, + diagnostic_record=diagnostic_record, + cause=cause, + ) + return inst + + def _init_gql( + self, + *, + gql_status: str | None = None, + message: str | None = None, + description: str | None = None, + diagnostic_record: dict[str, t.Any] | None = None, + cause: GQLError | None = None, + ) -> None: + if gql_status is None or message is None or description is None: + self._gql_status = _UNKNOWN_GQL_STATUS + self._message = _UNKNOWN_GQL_MESSAGE + self._gql_status_description = f"error: {_UNKNOWN_GQL_DESCRIPTION}" + else: + self._gql_status = gql_status + self._message = message + self._gql_status_description = description + if diagnostic_record is not None: + self._status_diagnostic_record = diagnostic_record + super().__setattr__("__cause__", cause) + + def _set_unknown_gql(self): + self._gql_status = _UNKNOWN_GQL_STATUS + self._message = _UNKNOWN_GQL_MESSAGE + self._gql_status_description = f"error: {_UNKNOWN_GQL_DESCRIPTION}" + + @staticmethod + def _format_message_details( + gql_status: str | None, + description: str | None, + details: str | None, + ): + if gql_status is None: + gql_status = _UNKNOWN_GQL_STATUS + if description is None: + description = _UNKNOWN_GQL_DESCRIPTION + if details is None: + return f"{gql_status}: {description}" + return f"{gql_status}: {description}. {details}" + + def __setattr__(self, key, value): + if key == "__cause__": + raise AttributeError( + "Cannot set __cause__ on GQLError or `raise ... from ...`." + ) + super().__setattr__(key, value) + + @property + def _gql_status_no_preview(self) -> str: + if hasattr(self, "_gql_status"): + return self._gql_status + + self._set_unknown_gql() + return self._gql_status + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_status(self) -> str: + """ + The GQLSTATUS returned from the server. + + The status code ``50N42`` (unknown error) is a special code that the + driver will use for polyfilling (when connected to an old, + non-GQL-aware server). + Further, it may be used by servers during the transition-phase to + GQLSTATUS-awareness. + + This means this code is not guaranteed to be stable and may change in + future versions. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + """ + return self._gql_status_no_preview + + @property + def _message_no_preview(self) -> str | None: + if hasattr(self, "_message"): + return self._message + + self._set_unknown_gql() + return self._message + + @property + @_preview("GQLSTATUS support is a preview feature.") + def message(self) -> str | None: + """ + TODO. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + """ + return self._message_no_preview + + @property + def _gql_status_description_no_preview(self) -> str: + if hasattr(self, "_gql_status_description"): + return self._gql_status_description + + self._set_unknown_gql() + return self._gql_status_description + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_status_description(self) -> str: + """ + A description of the GQLSTATUS returned from the server. + + This description is meant for human consumption and debugging purposes. + Don't rely on it in a programmatic way. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + """ + return self._gql_status_description_no_preview + + @property + def _gql_raw_classification_no_preview(self) -> str | None: + if hasattr(self, "_gql_raw_classification"): + return self._gql_raw_classification + + diag_record = self._get_status_diagnostic_record() + classification = diag_record.get("_classification") + if not isinstance(classification, str): + self._gql_raw_classification = None + else: + self._gql_raw_classification = classification + return self._gql_raw_classification + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_raw_classification(self) -> str | None: + """ + TODO. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + """ + return self._gql_raw_classification_no_preview + + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_classification(self) -> GQLErrorClassification: + """ + TODO. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + """ + if hasattr(self, "_gql_classification"): + return self._gql_classification + + classification = self._gql_raw_classification_no_preview + if not ( + isinstance(classification, str) + and classification + in t.cast(t.Iterable[str], iter(GQLErrorClassification)) + ): + self._gql_classification = GQLErrorClassification.UNKNOWN + else: + self._gql_classification = GQLErrorClassification(classification) + return self._gql_classification + + def _get_status_diagnostic_record(self) -> dict[str, t.Any]: + if hasattr(self, "_status_diagnostic_record"): + return self._status_diagnostic_record + + self._status_diagnostic_record = dict(_UNKNOWN_GQL_DIAGNOSTIC_RECORD) + return self._status_diagnostic_record + + @property + @_preview("GQLSTATUS support is a preview feature.") + def diagnostic_record(self) -> Mapping[str, t.Any]: + """ + Further information about the GQLSTATUS for diagnostic purposes. + + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + """ + if hasattr(self, "_diagnostic_record"): + return self._diagnostic_record + + self._diagnostic_record = _deepcopy( + self._get_status_diagnostic_record() + ) + return self._diagnostic_record + + @_preview("GQLSTATUS support is a preview feature.") + def __str__(self): + return ( + f"{{gql_status: {self._gql_status_no_preview}}} " + f"{{gql_status_description: " + f"{self._gql_status_description_no_preview}}} " + f"{{message: {self._message_no_preview}}} " + f"{{diagnostic_record: {self.diagnostic_record}}} " + f"{{raw_classification: " + f"{self._gql_raw_classification_no_preview}}}" + ) + + +# Neo4jError +class Neo4jError(GQLError): + """Raised when the Cypher engine returns an error to the client.""" + + _neo4j_code: str | None + _classification: str | None + _category: str | None + _title: str | None + #: (dict) Any additional information returned by the server. + _metadata: dict[str, t.Any] | None _retryable = False + def __init__(self, *args) -> None: + Exception.__init__(self, *args) + self._neo4j_code = None + self._classification = None + self._category = None + self._title = None + self._metadata = None + self._message = None + + # TODO: 6.0 - do this instead to get rid of all optional attributes + # self._neo4j_code = _UNKNOWN_NEO4J_CODE + # _, self._classification, self._category, self._title = ( + # self._neo4j_code.split(".") + # ) + # self._metadata = {} + # self._init_gql(message=_UNKNOWN_MESSAGE) + + # TODO: 6.0 - Remove this alias @classmethod - def _hydrate( + @deprecated( + "Neo4jError.hydrate is deprecated and will be " + "removed in a future version. It is an internal method and not meant " + "for external use." + ) + def hydrate( cls, - *, - neo4j_code: str | None = None, + code: str | None = None, message: str | None = None, - gql_status: str | None = None, - explanation: str | None = None, - diagnostic_record: dict[str, t.Any] | None = None, - cause: Neo4jError | None = None, - ) -> te.Self: - neo4j_code = neo4j_code or "Neo.DatabaseError.General.UnknownError" - message = message or "An unknown error occurred" + **metadata: t.Any, + ) -> Neo4jError: + # backward compatibility: make falsy values None + code = code or None + message = message or None + return cls._hydrate_neo4j(code=code, message=message, **metadata) + @classmethod + def _hydrate_neo4j(cls, **metadata: t.Any) -> Neo4jError: + meta_extractor = _MetaExtractor(metadata) + code = meta_extractor.str_value("code") + if not code: + code = _UNKNOWN_NEO4J_CODE + message = meta_extractor.str_value("message") + if not message: + message = _UNKNOWN_MESSAGE + inst = cls._basic_hydrate( + neo4j_code=code, + message=message, + ) + inst._init_gql( + gql_status=_UNKNOWN_GQL_STATUS, + message=message, + description=(f"error: {_UNKNOWN_GQL_DESCRIPTION}. {message}"), + ) + inst._metadata = meta_extractor.rest() + return inst + + @classmethod + def _hydrate_gql(cls, **metadata: t.Any) -> Neo4jError: + meta_extractor = _MetaExtractor(metadata) + gql_status = meta_extractor.str_value("gql_status") + status_description = meta_extractor.str_value("description") + message = meta_extractor.str_value("message", _UNKNOWN_GQL_MESSAGE) + neo4j_code = meta_extractor.str_value( + "neo4j_code", + _UNKNOWN_NEO4J_CODE, + ) + diagnostic_record = meta_extractor.map_value("diagnostic_record") + cause_map = meta_extractor.map_value("cause") + if cause_map is not None: + cause = cls._hydrate_cause(**cause_map) + else: + cause = None + + inst = cls._basic_hydrate( + neo4j_code=neo4j_code, + message=message, + ) + inst._init_gql( + gql_status=gql_status, + message=message, + description=status_description, + diagnostic_record=diagnostic_record, + cause=cause, + ) + inst._metadata = meta_extractor.rest() + + return inst + + @classmethod + def _basic_hydrate( + cls, + *, + neo4j_code: str, + message: str, + # gql_status: str | None = None, + # status_description: str | None = None, + # diagnostic_record: dict[str, t.Any] | None = None, + # cause: GQLError | None = None, + ) -> Neo4jError: try: _, classification, category, title = neo4j_code.split(".") except ValueError: @@ -283,37 +647,21 @@ def _hydrate( if code_override is not None: neo4j_code = code_override - error_class = cls._extract_error_class(classification, neo4j_code) - - if explanation is not None: - gql_description = _make_gql_description(explanation, message) - inst = error_class(gql_description) - inst._gql_description = gql_description - else: - inst = error_class(message) + error_class: type[Neo4jError] = cls._extract_error_class( + classification, neo4j_code + ) + inst = error_class(message) inst._neo4j_code = neo4j_code - inst._message = message inst._classification = classification inst._category = category inst._title = title - if gql_status: - inst._gql_status = gql_status - if explanation: - inst._gql_explanation = explanation - if diagnostic_record is not None: - inst._status_diagnostic_record = diagnostic_record - if cause: - inst.__cause__ = cause - else: - current_exc = sys.exc_info()[1] - if current_exc is not None: - inst.__context__ = current_exc + inst._message = message return inst @classmethod - def _extract_error_class(cls, classification, code): + def _extract_error_class(cls, classification, code) -> type[Neo4jError]: if classification == CLASSIFICATION_CLIENT: try: return client_errors[code] @@ -332,54 +680,8 @@ def _extract_error_class(cls, classification, code): else: return cls - @classmethod - def hydrate( - cls, - code: str | None = None, - message: str | None = None, - **metadata: t.Any, - ) -> te.Self: - inst = cls._hydrate(neo4j_code=code, message=message) - inst._metadata = metadata - return inst - - @classmethod - def _hydrate_gql(cls, **metadata: t.Any) -> te.Self: - gql_status = metadata.pop("gql_status", None) - if not isinstance(gql_status, str): - gql_status = None - status_explanation = metadata.pop("status_explanation", None) - if not isinstance(status_explanation, str): - status_explanation = None - message = metadata.pop("status_message", None) - if not isinstance(message, str): - message = None - neo4j_code = metadata.pop("neo4j_code", None) - if not isinstance(neo4j_code, str): - neo4j_code = None - diagnostic_record = metadata.pop("diagnostic_record", None) - if not isinstance(diagnostic_record, dict): - diagnostic_record = None - cause = metadata.pop("cause", None) - if not isinstance(cause, dict): - cause = None - else: - cause = cls._hydrate_gql(**cause) - - inst = cls._hydrate( - neo4j_code=neo4j_code, - message=message, - gql_status=gql_status, - explanation=status_explanation, - diagnostic_record=diagnostic_record, - cause=cause, - ) - inst._metadata = metadata - - return inst - @property - def message(self) -> str: + def message(self) -> str | None: """ TODO. @@ -392,12 +694,8 @@ def message(self) -> str: def message(self, value: str) -> None: self._message = value - # TODO: 6.0 - Remove this alias @property - @deprecated( - "The code of a Neo4jError is deprecated. Use neo4j_code instead." - ) - def code(self) -> str: + def code(self) -> str | None: """ The neo4j error code returned by the server. @@ -413,20 +711,7 @@ def code(self, value: str) -> None: self._neo4j_code = value @property - def neo4j_code(self) -> str: - """ - The error code returned by the server. - - There are many Neo4j status codes, see - `status codes `_. - - .. versionadded: 5.xx - """ - return self._neo4j_code - - @property - @deprecated("classification of Neo4jError is deprecated.") - def classification(self) -> str: + def classification(self) -> str | None: # Undocumented, has been there before # TODO 6.0: Remove this property return self._classification @@ -437,31 +722,30 @@ def classification(self, value: str) -> None: self._classification = value @property - @deprecated("category of Neo4jError is deprecated.") - def category(self) -> str: + def category(self) -> str | None: # Undocumented, has been there before # TODO 6.0: Remove this property return self._category @category.setter - @deprecated("category of Neo4jError is deprecated.") + @deprecated("Altering the category of Neo4jError is deprecated.") def category(self, value: str) -> None: self._category = value @property - @deprecated("title of Neo4jError is deprecated.") - def title(self) -> str: + # @deprecated("title of Neo4jError is deprecated.") + def title(self) -> str | None: # Undocumented, has been there before # TODO 6.0: Remove this property return self._title @title.setter - @deprecated("title of Neo4jError is deprecated.") + @deprecated("Altering the title of Neo4jError is deprecated.") def title(self, value: str) -> None: self._title = value @property - def metadata(self) -> dict[str, t.Any]: + def metadata(self) -> dict[str, t.Any] | None: # Undocumented, might be useful for debugging return self._metadata @@ -471,118 +755,6 @@ def metadata(self, value: dict[str, t.Any]) -> None: # TODO 6.0: Remove this property self._metadata = value - @property - @preview("GQLSTATUS support is a preview feature.") - def gql_status(self) -> str: - """ - The GQLSTATUS returned from the server. - - The status code ``50N42`` (unknown error) is a special code that the - driver will use for polyfilling (when connected to an old, - non-GQL-aware server). - Further, it may be used by servers during the transition-phase to - GQLSTATUS-awareness. - - This means this code is not guaranteed to be stable and may change in - future versions. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - if hasattr(self, "_gql_status"): - return self._gql_status - - self._gql_status = _UNKNOWN_GQL_STATUS - return self._gql_status - - def _get_explanation(self) -> str: - if hasattr(self, "_gql_explanation"): - return self._gql_explanation - - self._gql_explanation = _UNKNOWN_GQL_EXPLANATION - return self._gql_explanation - - @property - @preview("GQLSTATUS support is a preview feature.") - def gql_status_description(self) -> str: - """ - A description of the GQLSTATUS returned from the server. - - This description is meant for human consumption and debugging purposes. - Don't rely on it in a programmatic way. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - if hasattr(self, "_gql_status_description"): - return self._gql_status_description - - self._gql_status_description = _make_gql_description( - self._get_explanation(), self._message - ) - return self._gql_status_description - - @property - @preview("GQLSTATUS support is a preview feature.") - def gql_classification(self) -> str: - """ - TODO. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - # TODO - if hasattr(self, "_gql_classification"): - return self._gql_classification - - diag_record = self._get_status_diagnostic_record() - classification = diag_record.get("_classification") - if not isinstance(classification, str): - self._classification = _UNKNOWN_GQL_CLASSIFICATION - else: - self._classification = classification - return self._classification - - def _get_status_diagnostic_record(self) -> dict[str, t.Any]: - if hasattr(self, "_status_diagnostic_record"): - return self._status_diagnostic_record - - self._status_diagnostic_record = dict(_UNKNOWN_GQL_DIAGNOSTIC_RECORD) - return self._status_diagnostic_record - - @property - @preview("GQLSTATUS support is a preview feature.") - def diagnostic_record(self) -> dict[str, t.Any]: - """ - Further information about the GQLSTATUS for diagnostic purposes. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - if hasattr(self, "_diagnostic_record"): - return self._diagnostic_record - - self._diagnostic_record = deepcopy( - self._get_status_diagnostic_record() - ) - return self._diagnostic_record - # TODO: 6.0 - Remove this alias @deprecated( "Neo4jError.is_retriable is deprecated and will be removed in a " @@ -666,7 +838,7 @@ def __str__(self): code = self._neo4j_code message = self._message if code or message: - return f"{{neo4j_code: {code}}} {{message: {message}}}" + return f"{{code: {code}}} {{message: {message}}}" # TODO: 6.0 - User gql status and status_description instead # something like: # return ( @@ -675,7 +847,45 @@ def __str__(self): # f"{{gql_status_description: {self.gql_status_description}}} " # f"{{diagnostic_record: {self.diagnostic_record}}}" # ) - return super().__str__() + return Exception.__str__(self) + + +class _MetaExtractor: + def __init__(self, metadata: dict[str, t.Any]): + self._metadata = metadata + + def rest(self) -> dict[str, t.Any]: + return self._metadata + + @t.overload + def str_value(self, key: str) -> str | None: ... + + @t.overload + def str_value(self, key: str, default: _T) -> str | _T: ... + + def str_value( + self, key: str, default: _T | None = None + ) -> str | _T | None: + res = self._metadata.pop(key, default) + if not isinstance(res, str): + res = default + return res + + @t.overload + def map_value(self, key: str) -> dict[str, t.Any] | None: ... + + @t.overload + def map_value(self, key: str, default: _T) -> dict[str, t.Any] | _T: ... + + def map_value( + self, key: str, default: _T | None = None + ) -> dict[str, t.Any] | _T | None: + res = self._metadata.pop(key, default) + if not ( + isinstance(res, dict) and all(isinstance(k, str) for k in res) + ): + res = default + return res # Neo4jError > ClientError @@ -788,12 +998,9 @@ class ForbiddenOnReadOnlyDatabase(TransientError): # DriverError -class DriverError(Exception): +class DriverError(GQLError): """Raised when the Driver raises an error.""" - _diagnostic_record: dict[str, t.Any] - _gql_description: str - def is_retryable(self) -> bool: """ Whether the error is retryable. @@ -809,78 +1016,8 @@ def is_retryable(self) -> bool: """ return False - @property - @preview("GQLSTATUS support is a preview feature.") - def gql_status(self) -> str: - """ - The GQLSTATUS of this error. - - .. seealso:: :attr:`.Neo4jError.gql_status` - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - return _UNKNOWN_GQL_STATUS - - @property - @preview("GQLSTATUS support is a preview feature.") - def gql_status_description(self) -> str: - """ - A description of the GQLSTATUS. - - This description is meant for human consumption and debugging purposes. - Don't rely on it in a programmatic way. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - if hasattr(self, "_gql_description"): - return self._gql_description - - self._gql_description = _make_gql_description(_UNKNOWN_GQL_EXPLANATION) - return self._gql_description - - @property - @preview("GQLSTATUS support is a preview feature.") - def gql_classification(self) -> str: - """ - TODO. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - return _UNKNOWN_GQL_CLASSIFICATION - - @property - @preview("GQLSTATUS support is a preview feature.") - def diagnostic_record(self) -> dict[str, t.Any]: - """ - Further information about the GQLSTATUS for diagnostic purposes. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. versionadded: 5.xx - """ - if hasattr(self, "_diagnostic_record"): - return self._diagnostic_record - - self._diagnostic_record = dict(_UNKNOWN_GQL_DIAGNOSTIC_RECORD) - return self._diagnostic_record + def __str__(self): + return Exception.__str__(self) # DriverError > SessionError @@ -962,6 +1099,13 @@ class SessionExpired(DriverError): its original parameters. """ + def __init__(self, *args): + super().__init__(*args) + self._init_gql( + gql_status="08000", + description="error: connection exception", + ) + def is_retryable(self) -> bool: return True @@ -975,6 +1119,13 @@ class ServiceUnavailable(DriverError): failure of a database service that the driver is unable to route around. """ + def __init__(self, *args): + super().__init__(*args) + self._init_gql( + gql_status="08000", + description="error: connection exception", + ) + def is_retryable(self) -> bool: return True @@ -1005,6 +1156,16 @@ class IncompleteCommit(ServiceUnavailable): successfully or not. """ + def __init__(self, *args): + super().__init__(*args) + self._init_gql( + gql_status="08007", + description=( + "error: connection exception - " + "transaction resolution unknown" + ), + ) + def is_retryable(self) -> bool: return False diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index dcb73a169..ff2308076 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -125,6 +125,17 @@ async def process_request(self): @staticmethod def _exc_stems_from_driver(exc): + if isinstance( + exc, + ( + Neo4jError, + DriverError, + UnsupportedServerProduct, + BoltError, + MarkdAsDriverError, + ), + ): + return True stack = traceback.extract_tb(exc.__traceback__) for frame in stack[-1:1:-1]: p = Path(frame.filename) @@ -134,16 +145,30 @@ def _exc_stems_from_driver(exc): return True return None - async def write_driver_exc(self, exc): + def _serialize_driver_exc(self, exc): log.debug(exc.args) log.debug("".join(traceback.format_exception(exc))) key = self.next_key() self.errors[key] = exc - data = totestkit.driver_exc(exc, id_=key) + return totestkit.driver_exc(exc, id_=key) + + @staticmethod + def _serialize_backend_error(exc): + tb = "".join(traceback.format_exception(exc)) + log.error(tb) + return {"name": "BackendError", "data": {"msg": tb}} - await self.send_response(data["name"], data["data"]) + def _serialize_exc(self, exc): + try: + if isinstance(exc, requests.FrontendError): + return {"name": "FrontendError", "data": {"msg": str(exc)}} + if self._exc_stems_from_driver(exc): + return self._serialize_driver_exc(exc) + except Exception as e: + return self._serialize_backend_error(e) + return self._serialize_backend_error(exc) async def _process(self, request): # Process a received request. @@ -164,23 +189,9 @@ async def _process(self, request): f"Backend does not support some properties of the {name} " f"request: {', '.join(unused_keys)}" ) - except ( - Neo4jError, - DriverError, - UnsupportedServerProduct, - BoltError, - MarkdAsDriverError, - ) as e: - await self.write_driver_exc(e) - except requests.FrontendError as e: - await self.send_response("FrontendError", {"msg": str(e)}) except Exception as e: - if self._exc_stems_from_driver(e): - await self.write_driver_exc(e) - else: - tb = traceback.format_exc() - log.error(tb) - await self.send_response("BackendError", {"msg": tb}) + data = self._serialize_exc(e) + await self.send_response(data["name"], data["data"]) async def send_response(self, name, data): """Send a response to backend.""" diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index ee8a3edc5..b192c83c6 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -125,6 +125,17 @@ def process_request(self): @staticmethod def _exc_stems_from_driver(exc): + if isinstance( + exc, + ( + Neo4jError, + DriverError, + UnsupportedServerProduct, + BoltError, + MarkdAsDriverError, + ), + ): + return True stack = traceback.extract_tb(exc.__traceback__) for frame in stack[-1:1:-1]: p = Path(frame.filename) @@ -134,16 +145,30 @@ def _exc_stems_from_driver(exc): return True return None - def write_driver_exc(self, exc): + def _serialize_driver_exc(self, exc): log.debug(exc.args) log.debug("".join(traceback.format_exception(exc))) key = self.next_key() self.errors[key] = exc - data = totestkit.driver_exc(exc, id_=key) + return totestkit.driver_exc(exc, id_=key) + + @staticmethod + def _serialize_backend_error(exc): + tb = "".join(traceback.format_exception(exc)) + log.error(tb) + return {"name": "BackendError", "data": {"msg": tb}} - self.send_response(data["name"], data["data"]) + def _serialize_exc(self, exc): + try: + if isinstance(exc, requests.FrontendError): + return {"name": "FrontendError", "data": {"msg": str(exc)}} + if self._exc_stems_from_driver(exc): + return self._serialize_driver_exc(exc) + except Exception as e: + return self._serialize_backend_error(e) + return self._serialize_backend_error(exc) def _process(self, request): # Process a received request. @@ -164,23 +189,9 @@ def _process(self, request): f"Backend does not support some properties of the {name} " f"request: {', '.join(unused_keys)}" ) - except ( - Neo4jError, - DriverError, - UnsupportedServerProduct, - BoltError, - MarkdAsDriverError, - ) as e: - self.write_driver_exc(e) - except requests.FrontendError as e: - self.send_response("FrontendError", {"msg": str(e)}) except Exception as e: - if self._exc_stems_from_driver(e): - self.write_driver_exc(e) - else: - tb = traceback.format_exc() - log.error(tb) - self.send_response("BackendError", {"msg": tb}) + data = self._serialize_exc(e) + self.send_response(data["name"], data["data"]) def send_response(self, name, data): """Send a response to backend.""" diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index e4d8b14b5..bca7f0caa 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -58,6 +58,7 @@ "Feature:Bolt:5.4": true, "Feature:Bolt:5.5": true, "Feature:Bolt:5.6": true, + "Feature:Bolt:5.7": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index a3e6c158b..5b67b0494 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -19,7 +19,10 @@ import math import neo4j -from neo4j.exceptions import Neo4jError +from neo4j.exceptions import ( + GQLError, + Neo4jError, +) from neo4j.graph import ( Node, Path, @@ -297,23 +300,8 @@ def auth_token(auth): return {"name": "AuthorizationToken", "data": vars(auth)} -def _exc_msg(exc, max_depth=10): - if isinstance(exc, Neo4jError) and exc.message is not None: - return str(exc.message) - - depth = 0 - res = str(exc) - while getattr(exc, "__cause__", None) is not None: - depth += 1 - if depth >= max_depth: - break - res += f"\nCaused by: {exc.__cause__!r}" - exc = exc.__cause__ - return res - - def driver_exc(exc, id_=None): - payload = {"msg": ""} + payload = {} if id_ is not None: payload["id"] = id_ payload["retryable"] = getattr(exc, "is_retryable", bool)() @@ -326,19 +314,66 @@ def driver_exc(exc, id_=None): payload["errorType"] = str(type(exc)) payload["msg"] = _exc_msg(exc) if isinstance(exc, Neo4jError): - payload["code"] = exc.neo4j_code + payload["code"] = exc.code + if isinstance(exc, GQLError): with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["gqlStatus"] = exc.gql_status with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["statusDescription"] = exc.gql_status_description - if exc.__cause__ is not None: - payload["cause"] = driver_exc(exc.__cause__) - + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["rawClassification"] = exc.gql_raw_classification + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["classification"] = exc.gql_classification with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["diagnosticRecord"] = { k: field(v) for k, v in exc.diagnostic_record.items() } - with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): - payload["classification"] = exc.gql_classification + if exc.__cause__ is not None: + payload["cause"] = driver_exc_cause(exc.__cause__) return {"name": "DriverError", "data": payload} + + +# def _exc_msg(exc, max_depth=10): +# if isinstance(exc, Neo4jError) and exc.message is not None: +# return str(exc.message) +# +# depth = 0 +# res = str(exc) +# while getattr(exc, "__cause__", None) is not None: +# depth += 1 +# if depth >= max_depth: +# break +# res += f"\n Caused by: {exc.__cause__!r}" +# exc = exc.__cause__ +# return res + + +def _exc_msg(exc): + if isinstance(exc, Neo4jError) and exc.message is not None: + return str(exc.message) + return str(exc) + + +def driver_exc_cause(exc): + if not isinstance(exc, GQLError): + raise TypeError("Expected GQLError as cause") + payload = {} + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["msg"] = exc.message + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["gqlStatus"] = exc.gql_status + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["statusDescription"] = exc.gql_status_description + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["diagnosticRecord"] = { + k: field(v) for k, v in exc.diagnostic_record.items() + } + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["classification"] = exc.gql_classification + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["rawClassification"] = exc.gql_raw_classification + if exc.__cause__ is not None: + payload["cause"] = driver_exc_cause(exc.__cause__) + + return {"name": "GqlError", "data": payload} diff --git a/tests/iter_util.py b/tests/iter_util.py index b63f47667..69f5ba007 100644 --- a/tests/iter_util.py +++ b/tests/iter_util.py @@ -29,7 +29,8 @@ def powerset( iterable: t.Iterable[_T], - limit: int | None = None, + lower_limit: int | None = None, + upper_limit: int | None = None, ) -> t.Iterable[tuple[_T, ...]]: """ Build the powerset of an iterable. @@ -39,12 +40,19 @@ def powerset( >>> tuple(powerset([1, 2, 3])) ((), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)) - >>> tuple(powerset([1, 2, 3], limit=2)) + >>> tuple(powerset([1, 2, 3], upper_limit=2)) ((), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3)) + >>> tuple(powerset([1, 2, 3], lower_limit=2)) + ((1, 2), (1, 3), (2, 3), (1, 2, 3)) + :return: The powerset of the iterable. """ s = list(iterable) - if limit is None: - limit = len(s) - return chain.from_iterable(combinations(s, r) for r in range(limit + 1)) + if upper_limit is None: + upper_limit = len(s) + if lower_limit is None: + lower_limit = 0 + return chain.from_iterable( + combinations(s, r) for r in range(lower_limit, upper_limit + 1) + ) diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 06d0270fe..9bf967791 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -206,7 +206,7 @@ async def callback(): cb_args = default_cb_args res = cb(*cb_args) if cb_name == "on_failure": - error = Neo4jError.hydrate(**cb_args[0]) + error = Neo4jError._hydrate_gql(**cb_args[0]) # suppress in case the callback is not async with suppress(TypeError): await res diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index b6ab70036..e2f56ff90 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -532,7 +532,7 @@ def raises_if_db(db): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index e981bee18..fa555fd1f 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -623,7 +623,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 870cd7fb8..e7ca17e04 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -645,7 +645,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index e7bf1f0ad..bffb44245 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -645,7 +645,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index e3ca11d21..1f249feb2 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -674,7 +674,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index a60fc2c8a..695ac7c93 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -634,7 +634,7 @@ async def test_tx_timeout( {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index 90e4a8fe5..d1f09dcce 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -698,7 +698,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 5118dcc48..003263aab 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -752,7 +752,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 1d7bd4752..345c9a521 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -789,7 +789,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index d0d8ee4d5..c70a3df4f 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -676,7 +676,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py index 72e07e9ef..7ff21e090 100644 --- a/tests/unit/async_/io/test_class_bolt5x4.py +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -681,7 +681,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index a7a6000d9..77d748de3 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -690,7 +690,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py index 067bd47d8..a51065724 100644 --- a/tests/unit/async_/io/test_class_bolt5x6.py +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -690,7 +690,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/async_/io/test_class_bolt5x7.py b/tests/unit/async_/io/test_class_bolt5x7.py new file mode 100644 index 000000000..bb9babbac --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x7.py @@ -0,0 +1,851 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x7 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x7.UNPACKER_CLS) + connection = AsyncBolt5x7( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + _tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + connection = AsyncBolt5x7(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + connection = AsyncBolt5x7(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_async_test +async def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + connection = AsyncBolt5x7(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + await sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + await connection.send_all() + with pytest.raises(Neo4jError): + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "gql_status": f"BAR{i + 1:02}", + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + if r is not ...: + metadata["diagnostic_record"] = r + current_root = current_root["cause"] + return metadata diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 6fbd2de88..c0be16adf 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -535,11 +535,13 @@ async def test_passes_pool_config_to_connection(mocker): "error", ( ServiceUnavailable(), - Neo4jError.hydrate( - "message", "Neo.ClientError.Statement.EntityNotFound" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Statement.EntityNotFound", + message="message", ), - Neo4jError.hydrate( - "message", "Neo.ClientError.Security.AuthorizationExpired" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Security.AuthorizationExpired", + message="message", ), ), ) @@ -578,20 +580,20 @@ async def test_discovery_is_retried(custom_routing_opener, error): @pytest.mark.parametrize( "error", map( - lambda args: Neo4jError.hydrate(*args), + lambda args: Neo4jError._hydrate_neo4j(code=args[0], message=args[1]), ( - ("message", "Neo.ClientError.Database.DatabaseNotFound"), - ("message", "Neo.ClientError.Transaction.InvalidBookmark"), - ("message", "Neo.ClientError.Transaction.InvalidBookmarkMixture"), - ("message", "Neo.ClientError.Statement.TypeError"), - ("message", "Neo.ClientError.Statement.ArgumentError"), - ("message", "Neo.ClientError.Request.Invalid"), - ("message", "Neo.ClientError.Security.AuthenticationRateLimit"), - ("message", "Neo.ClientError.Security.CredentialsExpired"), - ("message", "Neo.ClientError.Security.Forbidden"), - ("message", "Neo.ClientError.Security.TokenExpired"), - ("message", "Neo.ClientError.Security.Unauthorized"), - ("message", "Neo.ClientError.Security.MadeUpError"), + ("Neo.ClientError.Database.DatabaseNotFound", "message"), + ("Neo.ClientError.Transaction.InvalidBookmark", "message"), + ("Neo.ClientError.Transaction.InvalidBookmarkMixture", "message"), + ("Neo.ClientError.Statement.TypeError", "message"), + ("Neo.ClientError.Statement.ArgumentError", "message"), + ("Neo.ClientError.Request.Invalid", "message"), + ("Neo.ClientError.Security.AuthenticationRateLimit", "message"), + ("Neo.ClientError.Security.CredentialsExpired", "message"), + ("Neo.ClientError.Security.Forbidden", "message"), + ("Neo.ClientError.Security.TokenExpired", "message"), + ("Neo.ClientError.Security.Unauthorized", "message"), + ("Neo.ClientError.Security.MadeUpError", "message"), ), ), ) @@ -627,7 +629,7 @@ async def test_fast_failing_discovery(custom_routing_opener, error): @pytest.mark.parametrize( ("error", "marks_unauthenticated", "fetches_new"), ( - (Neo4jError.hydrate("message", args[0]), *args[1:]) + (Neo4jError._hydrate_neo4j(code=args[0], message="message"), *args[1:]) for args in ( ("Neo.ClientError.Database.DatabaseNotFound", False, False), ("Neo.ClientError.Statement.TypeError", False, False), diff --git a/tests/unit/async_/test_auth_management.py b/tests/unit/async_/test_auth_management.py index dd731a9c5..57420d3a2 100644 --- a/tests/unit/async_/test_auth_management.py +++ b/tests/unit/async_/test_auth_management.py @@ -59,7 +59,7 @@ "Neo.ClientError.Security.Unauthorized", } SAMPLE_ERRORS = [ - Neo4jError.hydrate(code=code) + Neo4jError._hydrate_neo4j(code=code) for code in { "Neo.ClientError.Security.AuthenticationRateLimit", "Neo.ClientError.Security.AuthorizationExpired", diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 4e30ff582..81fba42b6 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -315,7 +315,12 @@ async def test_server_error_propagates(async_scripted_connection, error): ( "pull", { - "on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_failure": ( + { + "neo4j_code": "Neo.ClientError.Made.Up", + "gql_status": "50N42", + }, + ), "on_summary": None, }, ) diff --git a/tests/unit/common/test_exceptions.py b/tests/unit/common/test_exceptions.py index 8cc192cdb..882e75141 100644 --- a/tests/unit/common/test_exceptions.py +++ b/tests/unit/common/test_exceptions.py @@ -149,7 +149,7 @@ def test_serviceunavailable_raised_from_bolt_protocol_error_with_explicit_style( def test_neo4jerror_hydrate_with_no_args(): - error = Neo4jError.hydrate() + error = Neo4jError._hydrate_neo4j() assert isinstance(error, DatabaseError) assert error.classification == CLASSIFICATION_DATABASE @@ -160,8 +160,10 @@ def test_neo4jerror_hydrate_with_no_args(): assert error.code == "Neo.DatabaseError.General.UnknownError" -def test_neo4jerror_hydrate_with_message_and_code_rubish(): - error = Neo4jError.hydrate(message="Test error message", code="ASDF_asdf") +def test_neo4jerror_hydrate_with_message_and_code_rubbish(): + error = Neo4jError._hydrate_neo4j( + message="Test error message", code="ASDF_asdf" + ) assert isinstance(error, DatabaseError) assert error.classification == CLASSIFICATION_DATABASE @@ -173,7 +175,7 @@ def test_neo4jerror_hydrate_with_message_and_code_rubish(): def test_neo4jerror_hydrate_with_message_and_code_database(): - error = Neo4jError.hydrate( + error = Neo4jError._hydrate_neo4j( message="Test error message", code="Neo.DatabaseError.General.UnknownError", ) @@ -188,7 +190,7 @@ def test_neo4jerror_hydrate_with_message_and_code_database(): def test_neo4jerror_hydrate_with_message_and_code_transient(): - error = Neo4jError.hydrate( + error = Neo4jError._hydrate_neo4j( message="Test error message", code="Neo.TransientError.General.TestError", ) @@ -203,7 +205,7 @@ def test_neo4jerror_hydrate_with_message_and_code_transient(): def test_neo4jerror_hydrate_with_message_and_code_client(): - error = Neo4jError.hydrate( + error = Neo4jError._hydrate_neo4j( message="Test error message", code=f"Neo.{CLASSIFICATION_CLIENT}.General.TestError", ) @@ -254,7 +256,7 @@ def test_neo4jerror_hydrate_with_message_and_code_client(): ) def test_error_rewrite(code, expected_cls, expected_code): message = "Test error message" - error = Neo4jError.hydrate(message=message, code=code) + error = Neo4jError._hydrate_neo4j(message=message, code=code) expected_retryable = expected_cls is TransientError assert error.__class__ is expected_cls @@ -303,12 +305,12 @@ def test_error_rewrite(code, expected_cls, expected_code): "{code: Neo.ClientError.General.UnknownError} " "{message: An unknown error occurred}", ), - ), + )[1:2], ) def test_neo4j_error_from_server_as_str( code, message, expected_cls, expected_str ): - error = Neo4jError.hydrate(code=code, message=message) + error = Neo4jError._hydrate_neo4j(code=code, message=message) assert type(error) is expected_cls assert str(error) == expected_str diff --git a/tests/unit/common/test_record.py b/tests/unit/common/test_record.py index 0c5156542..b3c38301a 100644 --- a/tests/unit/common/test_record.py +++ b/tests/unit/common/test_record.py @@ -591,5 +591,6 @@ class TestError(Exception): with pytest.raises(BrokenRecordError) as raised: accessor(r) exc_value = raised.value + assert exc_value.__cause__ is not None assert exc_value.__cause__ is exc assert list(traceback.walk_tb(exc_value.__cause__.__traceback__)) == frames diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 08f63df62..8785badb6 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -206,7 +206,7 @@ def callback(): cb_args = default_cb_args res = cb(*cb_args) if cb_name == "on_failure": - error = Neo4jError.hydrate(**cb_args[0]) + error = Neo4jError._hydrate_gql(**cb_args[0]) # suppress in case the callback is not async with suppress(TypeError): res diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index d6b1b3c13..ba80ce81b 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -532,7 +532,7 @@ def raises_if_db(db): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index 619902f6f..a0ad36e8b 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -623,7 +623,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index d89bcd5dd..c4b0208a8 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -645,7 +645,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index bdcee8c6e..b6ac961a5 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -645,7 +645,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 333eccece..c5da8700f 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -674,7 +674,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 5eb4cbb73..164372b00 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -634,7 +634,7 @@ def test_tx_timeout( {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 4b3677c8a..6f26b97a9 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -698,7 +698,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index cd2219fd7..dfe638a90 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -752,7 +752,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 017df4217..5dc09be8a 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -789,7 +789,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index 533006cbd..af8527106 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -676,7 +676,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py index 345850917..5773d1f61 100644 --- a/tests/unit/sync/io/test_class_bolt5x4.py +++ b/tests/unit/sync/io/test_class_bolt5x4.py @@ -681,7 +681,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION_CODE": "0"}, {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py index 90d5fd21e..361a9c14d 100644 --- a/tests/unit/sync/io/test_class_bolt5x5.py +++ b/tests/unit/sync/io/test_class_bolt5x5.py @@ -690,7 +690,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py index fcde8a7c8..15f378720 100644 --- a/tests/unit/sync/io/test_class_bolt5x6.py +++ b/tests/unit/sync/io/test_class_bolt5x6.py @@ -690,7 +690,7 @@ def test_tracks_last_database(fake_socket_pair, actions): {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, ), - limit=3, + upper_limit=3, ), ) @pytest.mark.parametrize("method", ("pull", "discard")) diff --git a/tests/unit/sync/io/test_class_bolt5x7.py b/tests/unit/sync/io/test_class_bolt5x7.py new file mode 100644 index 000000000..6aaf8502f --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x7.py @@ -0,0 +1,851 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x7 +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x7( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, + socket, + PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x7.UNPACKER_CLS) + connection = Bolt5x7( + address, + socket, + PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x7( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + _tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + connection = Bolt5x7(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + connection = Bolt5x7(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_sync_test +def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + connection = Bolt5x7(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + connection.send_all() + with pytest.raises(Neo4jError): + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "gql_status": f"BAR{i + 1:02}", + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + if r is not ...: + metadata["diagnostic_record"] = r + current_root = current_root["cause"] + return metadata diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index edc777d0b..89b4d16b3 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -535,11 +535,13 @@ def test_passes_pool_config_to_connection(mocker): "error", ( ServiceUnavailable(), - Neo4jError.hydrate( - "message", "Neo.ClientError.Statement.EntityNotFound" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Statement.EntityNotFound", + message="message", ), - Neo4jError.hydrate( - "message", "Neo.ClientError.Security.AuthorizationExpired" + Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Security.AuthorizationExpired", + message="message", ), ), ) @@ -578,20 +580,20 @@ def test_discovery_is_retried(custom_routing_opener, error): @pytest.mark.parametrize( "error", map( - lambda args: Neo4jError.hydrate(*args), + lambda args: Neo4jError._hydrate_neo4j(code=args[0], message=args[1]), ( - ("message", "Neo.ClientError.Database.DatabaseNotFound"), - ("message", "Neo.ClientError.Transaction.InvalidBookmark"), - ("message", "Neo.ClientError.Transaction.InvalidBookmarkMixture"), - ("message", "Neo.ClientError.Statement.TypeError"), - ("message", "Neo.ClientError.Statement.ArgumentError"), - ("message", "Neo.ClientError.Request.Invalid"), - ("message", "Neo.ClientError.Security.AuthenticationRateLimit"), - ("message", "Neo.ClientError.Security.CredentialsExpired"), - ("message", "Neo.ClientError.Security.Forbidden"), - ("message", "Neo.ClientError.Security.TokenExpired"), - ("message", "Neo.ClientError.Security.Unauthorized"), - ("message", "Neo.ClientError.Security.MadeUpError"), + ("Neo.ClientError.Database.DatabaseNotFound", "message"), + ("Neo.ClientError.Transaction.InvalidBookmark", "message"), + ("Neo.ClientError.Transaction.InvalidBookmarkMixture", "message"), + ("Neo.ClientError.Statement.TypeError", "message"), + ("Neo.ClientError.Statement.ArgumentError", "message"), + ("Neo.ClientError.Request.Invalid", "message"), + ("Neo.ClientError.Security.AuthenticationRateLimit", "message"), + ("Neo.ClientError.Security.CredentialsExpired", "message"), + ("Neo.ClientError.Security.Forbidden", "message"), + ("Neo.ClientError.Security.TokenExpired", "message"), + ("Neo.ClientError.Security.Unauthorized", "message"), + ("Neo.ClientError.Security.MadeUpError", "message"), ), ), ) @@ -627,7 +629,7 @@ def test_fast_failing_discovery(custom_routing_opener, error): @pytest.mark.parametrize( ("error", "marks_unauthenticated", "fetches_new"), ( - (Neo4jError.hydrate("message", args[0]), *args[1:]) + (Neo4jError._hydrate_neo4j(code=args[0], message="message"), *args[1:]) for args in ( ("Neo.ClientError.Database.DatabaseNotFound", False, False), ("Neo.ClientError.Statement.TypeError", False, False), diff --git a/tests/unit/sync/test_auth_management.py b/tests/unit/sync/test_auth_management.py index da90a9409..34eb4d727 100644 --- a/tests/unit/sync/test_auth_management.py +++ b/tests/unit/sync/test_auth_management.py @@ -59,7 +59,7 @@ "Neo.ClientError.Security.Unauthorized", } SAMPLE_ERRORS = [ - Neo4jError.hydrate(code=code) + Neo4jError._hydrate_neo4j(code=code) for code in { "Neo.ClientError.Security.AuthenticationRateLimit", "Neo.ClientError.Security.AuthorizationExpired", diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 97f31409b..53aeba1e5 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -315,7 +315,12 @@ def test_server_error_propagates(scripted_connection, error): ( "pull", { - "on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_failure": ( + { + "neo4j_code": "Neo.ClientError.Made.Up", + "gql_status": "50N42", + }, + ), "on_summary": None, }, ) From c1b49737641cf06e791685768a135f70f37dbc08 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 25 Sep 2024 12:05:58 +0200 Subject: [PATCH 03/23] Add docs, clean-up to-dos and commented out code --- CHANGELOG.md | 4 +- docs/source/api.rst | 11 +++ src/neo4j/_work/summary.py | 2 +- src/neo4j/exceptions.py | 145 +++++++++++++++++------------------- testkitbackend/totestkit.py | 23 +----- 5 files changed, 86 insertions(+), 99 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d47af7a0b..bc020e69b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. ## NEXT RELEASE -- TODO: list the deprecations (making error properties read-only) +- Deprecated setting attributes on `Neo4jError` like `message` and `code`. +- Deprecated undocumented method `Neo4jError.hydrate`. + It's internal and should not be used by client code. ## Version 5.24 diff --git a/docs/source/api.rst b/docs/source/api.rst index a623387c1..35041a855 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1989,6 +1989,17 @@ Errors ****** +GQL Errors +========== +.. autoexception:: neo4j.exceptions.GqlError() + :show-inheritance: + :members: gql_status, message, gql_status_description, gql_raw_classification, gql_classification, diagnostic_record, __cause__ + +.. autoclass:: neo4j.exceptions.GqlErrorClassification() + :show-inheritance: + :members: + + Neo4j Errors ============ diff --git a/src/neo4j/_work/summary.py b/src/neo4j/_work/summary.py index f605e3c7e..54849832f 100644 --- a/src/neo4j/_work/summary.py +++ b/src/neo4j/_work/summary.py @@ -767,7 +767,7 @@ def gql_status(self) -> str: .. note:: This means these codes are not guaranteed to be stable and may - change in future versions. + change in future versions of the driver or the server. """ if hasattr(self, "_gql_status"): return self._gql_status diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 3e0b12570..0727ea86d 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -84,8 +84,8 @@ "DriverError", "Forbidden", "ForbiddenOnReadOnlyDatabase", - "GQLError", - "GQLErrorClassification", + "GqlError", + "GqlErrorClassification", "IncompleteCommit", "Neo4jError", "NotALeader", @@ -232,21 +232,26 @@ ) -class GQLErrorClassification(str, _Enum): +class GqlErrorClassification(str, _Enum): """ - Server-side notification category. + Server-side GQL error category. Inherits from :class:`str` and :class:`enum.Enum`. Hence, can also be compared to its string value:: - >>> GQLErrorClassification.CLIENT_ERROR == "CLIENT_ERROR" + >>> GqlErrorClassification.CLIENT_ERROR == "CLIENT_ERROR" True - >>> GQLErrorClassification.DATABASE_ERROR == "DATABASE_ERROR" + >>> GqlErrorClassification.DATABASE_ERROR == "DATABASE_ERROR" True - >>> GQLErrorClassification.TRANSIENT_ERROR == "TRANSIENT_ERROR" + >>> GqlErrorClassification.TRANSIENT_ERROR == "TRANSIENT_ERROR" True - .. seealso:: :attr:`.GQLError.classification` + **This is a preview**. + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. seealso:: :attr:`.GqlError.gql_classification` .. versionadded:: 5.xx """ @@ -261,12 +266,13 @@ class GQLErrorClassification(str, _Enum): UNKNOWN = "UNKNOWN" -class GQLError(Exception): +class GqlError(Exception): """ The GQL compliant data of an error. - This error isn't raised by the driver as it. + This error isn't raised by the driver as is. Instead, only subclasses are raised. + Further, it is used as the :attr:`__cause__` of GqlError subclasses. **This is a preview**. It might be changed without following the deprecation policy. @@ -281,13 +287,13 @@ class GQLError(Exception): _message: str | None _gql_status_description: str _gql_raw_classification: str | None - _gql_classification: GQLErrorClassification + _gql_classification: GqlErrorClassification _status_diagnostic_record: dict[str, t.Any] # original, internal only _diagnostic_record: dict[str, t.Any] # copy to be used externally - __cause__: GQLError | None + __cause__: GqlError | None @staticmethod - def _hydrate_cause(**metadata: t.Any) -> GQLError: + def _hydrate_cause(**metadata: t.Any) -> GqlError: meta_extractor = _MetaExtractor(metadata) gql_status = meta_extractor.str_value("gql_status") description = meta_extractor.str_value("description") @@ -295,10 +301,10 @@ def _hydrate_cause(**metadata: t.Any) -> GQLError: diagnostic_record = meta_extractor.map_value("diagnostic_record") cause_map = meta_extractor.map_value("cause") if cause_map is not None: - cause = GQLError._hydrate_cause(**cause_map) + cause = GqlError._hydrate_cause(**cause_map) else: cause = None - inst = GQLError() + inst = GqlError() inst._init_gql( gql_status=gql_status, message=message, @@ -315,7 +321,7 @@ def _init_gql( message: str | None = None, description: str | None = None, diagnostic_record: dict[str, t.Any] | None = None, - cause: GQLError | None = None, + cause: GqlError | None = None, ) -> None: if gql_status is None or message is None or description is None: self._gql_status = _UNKNOWN_GQL_STATUS @@ -351,7 +357,7 @@ def _format_message_details( def __setattr__(self, key, value): if key == "__cause__": raise AttributeError( - "Cannot set __cause__ on GQLError or `raise ... from ...`." + "Cannot set __cause__ on GqlError or `raise ... from ...`." ) super().__setattr__(key, value) @@ -375,13 +381,9 @@ def gql_status(self) -> str: Further, it may be used by servers during the transition-phase to GQLSTATUS-awareness. - This means this code is not guaranteed to be stable and may change in - future versions. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + .. note:: + This means that the code ``50N42`` is not guaranteed to be stable + and may change in future versions of the driver or the server. """ return self._gql_status_no_preview @@ -397,12 +399,15 @@ def _message_no_preview(self) -> str | None: @_preview("GQLSTATUS support is a preview feature.") def message(self) -> str | None: """ - TODO. + The error message returned by the server. - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + It is a string representation of the error that occurred. + + This message is meant for human consumption and debugging purposes. + Don't rely on it in a programmatic way. + + This value is never :data:`None` unless the subclass in question + states otherwise. """ return self._message_no_preview @@ -420,13 +425,10 @@ def gql_status_description(self) -> str: """ A description of the GQLSTATUS returned from the server. + It describes the error that occurred in detail. + This description is meant for human consumption and debugging purposes. Don't rely on it in a programmatic way. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features """ return self._gql_status_description_no_preview @@ -447,25 +449,32 @@ def _gql_raw_classification_no_preview(self) -> str | None: @_preview("GQLSTATUS support is a preview feature.") def gql_raw_classification(self) -> str | None: """ - TODO. + Vendor specific classification of the error. - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + This is a convenience accessor for ``_classification`` in the + diagnostic record. + :data:`None` is returned if the classification is not available + or not a string. """ return self._gql_raw_classification_no_preview @property @_preview("GQLSTATUS support is a preview feature.") - def gql_classification(self) -> GQLErrorClassification: + def gql_classification(self) -> GqlErrorClassification: """ - TODO. + Vendor specific classification of the error. + + This is the parsed version of :attr:`.gql_raw_classification`. + :attr:`GqlErrorClassification.UNKNOWN` is returned if the + classification is missing, invalid, or has an unknown value + (e.g., a newer version of the server introduced a new value). **This is a preview**. It might be changed without following the deprecation policy. See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. seealso:: :attr:`.gql_raw_classification` """ if hasattr(self, "_gql_classification"): return self._gql_classification @@ -474,11 +483,11 @@ def gql_classification(self) -> GQLErrorClassification: if not ( isinstance(classification, str) and classification - in t.cast(t.Iterable[str], iter(GQLErrorClassification)) + in t.cast(t.Iterable[str], iter(GqlErrorClassification)) ): - self._gql_classification = GQLErrorClassification.UNKNOWN + self._gql_classification = GqlErrorClassification.UNKNOWN else: - self._gql_classification = GQLErrorClassification(classification) + self._gql_classification = GqlErrorClassification(classification) return self._gql_classification def _get_status_diagnostic_record(self) -> dict[str, t.Any]: @@ -491,14 +500,7 @@ def _get_status_diagnostic_record(self) -> dict[str, t.Any]: @property @_preview("GQLSTATUS support is a preview feature.") def diagnostic_record(self) -> Mapping[str, t.Any]: - """ - Further information about the GQLSTATUS for diagnostic purposes. - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - """ + """Further information about the GQLSTATUS for diagnostic purposes.""" if hasattr(self, "_diagnostic_record"): return self._diagnostic_record @@ -521,7 +523,7 @@ def __str__(self): # Neo4jError -class Neo4jError(GQLError): +class Neo4jError(GqlError): """Raised when the Cypher engine returns an error to the client.""" _neo4j_code: str | None @@ -548,7 +550,7 @@ def __init__(self, *args) -> None: # self._neo4j_code.split(".") # ) # self._metadata = {} - # self._init_gql(message=_UNKNOWN_MESSAGE) + # self._init_gql() # TODO: 6.0 - Remove this alias @classmethod @@ -584,7 +586,7 @@ def _hydrate_neo4j(cls, **metadata: t.Any) -> Neo4jError: inst._init_gql( gql_status=_UNKNOWN_GQL_STATUS, message=message, - description=(f"error: {_UNKNOWN_GQL_DESCRIPTION}. {message}"), + description=f"error: {_UNKNOWN_GQL_DESCRIPTION}. {message}", ) inst._metadata = meta_extractor.rest() return inst @@ -622,16 +624,7 @@ def _hydrate_gql(cls, **metadata: t.Any) -> Neo4jError: return inst @classmethod - def _basic_hydrate( - cls, - *, - neo4j_code: str, - message: str, - # gql_status: str | None = None, - # status_description: str | None = None, - # diagnostic_record: dict[str, t.Any] | None = None, - # cause: GQLError | None = None, - ) -> Neo4jError: + def _basic_hydrate(cls, *, neo4j_code: str, message: str) -> Neo4jError: try: _, classification, category, title = neo4j_code.split(".") except ValueError: @@ -683,9 +676,9 @@ def _extract_error_class(cls, classification, code) -> type[Neo4jError]: @property def message(self) -> str | None: """ - TODO. + The error message returned by the server. - #: (str or None) The error message returned by the server. + This value is only :data:`None` for locally created errors. """ return self._message @@ -699,8 +692,8 @@ def code(self) -> str | None: """ The neo4j error code returned by the server. - .. deprecated:: 5.xx - Use :attr:`.neo4j_code` instead. + For example, "Neo.ClientError.Security.AuthorizationExpired". + This value is only :data:`None` for locally created errors. """ return self._neo4j_code @@ -712,19 +705,17 @@ def code(self, value: str) -> None: @property def classification(self) -> str | None: - # Undocumented, has been there before - # TODO 6.0: Remove this property + # Undocumented, will likely be removed with support for neo4j codes return self._classification @classification.setter - @deprecated("classification of Neo4jError is deprecated.") + @deprecated("Altering the classification of Neo4jError is deprecated.") def classification(self, value: str) -> None: self._classification = value @property def category(self) -> str | None: - # Undocumented, has been there before - # TODO 6.0: Remove this property + # Undocumented, will likely be removed with support for neo4j codes return self._category @category.setter @@ -733,10 +724,8 @@ def category(self, value: str) -> None: self._category = value @property - # @deprecated("title of Neo4jError is deprecated.") def title(self) -> str | None: - # Undocumented, has been there before - # TODO 6.0: Remove this property + # Undocumented, will likely be removed with support for neo4j codes return self._title @title.setter @@ -998,7 +987,7 @@ class ForbiddenOnReadOnlyDatabase(TransientError): # DriverError -class DriverError(GQLError): +class DriverError(GqlError): """Raised when the Driver raises an error.""" def is_retryable(self) -> bool: diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 5b67b0494..e62c43f81 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -20,7 +20,7 @@ import neo4j from neo4j.exceptions import ( - GQLError, + GqlError, Neo4jError, ) from neo4j.graph import ( @@ -315,7 +315,7 @@ def driver_exc(exc, id_=None): payload["msg"] = _exc_msg(exc) if isinstance(exc, Neo4jError): payload["code"] = exc.code - if isinstance(exc, GQLError): + if isinstance(exc, GqlError): with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["gqlStatus"] = exc.gql_status with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): @@ -334,21 +334,6 @@ def driver_exc(exc, id_=None): return {"name": "DriverError", "data": payload} -# def _exc_msg(exc, max_depth=10): -# if isinstance(exc, Neo4jError) and exc.message is not None: -# return str(exc.message) -# -# depth = 0 -# res = str(exc) -# while getattr(exc, "__cause__", None) is not None: -# depth += 1 -# if depth >= max_depth: -# break -# res += f"\n Caused by: {exc.__cause__!r}" -# exc = exc.__cause__ -# return res - - def _exc_msg(exc): if isinstance(exc, Neo4jError) and exc.message is not None: return str(exc.message) @@ -356,8 +341,8 @@ def _exc_msg(exc): def driver_exc_cause(exc): - if not isinstance(exc, GQLError): - raise TypeError("Expected GQLError as cause") + if not isinstance(exc, GqlError): + raise TypeError("Expected GqlError as cause") payload = {} with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["msg"] = exc.message From ea89c7ab6e0a654fa9f0123796dc24a6b882bef2 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 25 Sep 2024 12:25:26 +0200 Subject: [PATCH 04/23] Improve resilience in diag. record enrichment for erroneous types --- src/neo4j/_async/io/_bolt5.py | 6 +++--- src/neo4j/_sync/io/_bolt5.py | 6 +++--- tests/unit/async_/io/test_class_bolt5x7.py | 5 ++--- tests/unit/sync/io/test_class_bolt5x7.py | 5 ++--- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 3932e9b5c..ad53a7f08 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -1094,9 +1094,9 @@ def _enrich_error_diagnostic_record(self, metadata): self.local_port, diag_record, ) - return - for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: - diag_record.setdefault(key, value) + else: + for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) self._enrich_error_diagnostic_record(metadata.get("cause")) async def _process_message(self, tag, fields): diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index e20a1f62a..d755a4b2d 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -1094,9 +1094,9 @@ def _enrich_error_diagnostic_record(self, metadata): self.local_port, diag_record, ) - return - for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: - diag_record.setdefault(key, value) + else: + for key, value in self.DEFAULT_ERROR_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) self._enrich_error_diagnostic_record(metadata.get("cause")) def _process_message(self, tag, fields): diff --git a/tests/unit/async_/io/test_class_bolt5x7.py b/tests/unit/async_/io/test_class_bolt5x7.py index bb9babbac..97a8b4eaa 100644 --- a/tests/unit/async_/io/test_class_bolt5x7.py +++ b/tests/unit/async_/io/test_class_bolt5x7.py @@ -841,11 +841,10 @@ def _build_error_hierarchy_metadata(diag_records_metadata): current_root = metadata for i, r in enumerate(diag_records_metadata[1:]): current_root["cause"] = { - "gql_status": f"BAR{i + 1:02}", "description": f"error cause nr. {i + 1}", "message": f"cause message {i + 1}", } - if r is not ...: - metadata["diagnostic_record"] = r current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r return metadata diff --git a/tests/unit/sync/io/test_class_bolt5x7.py b/tests/unit/sync/io/test_class_bolt5x7.py index 6aaf8502f..cf999cc65 100644 --- a/tests/unit/sync/io/test_class_bolt5x7.py +++ b/tests/unit/sync/io/test_class_bolt5x7.py @@ -841,11 +841,10 @@ def _build_error_hierarchy_metadata(diag_records_metadata): current_root = metadata for i, r in enumerate(diag_records_metadata[1:]): current_root["cause"] = { - "gql_status": f"BAR{i + 1:02}", "description": f"error cause nr. {i + 1}", "message": f"cause message {i + 1}", } - if r is not ...: - metadata["diagnostic_record"] = r current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r return metadata From a0632e3b6d006918dedd8c12561371f2a21aa20c Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 27 Sep 2024 10:49:36 +0200 Subject: [PATCH 05/23] Fix struct hydration in FAILURE responses --- src/neo4j/_async/io/_bolt.py | 10 ++++++ src/neo4j/_async/io/_bolt3.py | 30 ++++++++++++++++++ src/neo4j/_async/io/_bolt4.py | 45 +++++++++++++++++++++++++++ src/neo4j/_async/io/_bolt5.py | 58 +++++++++++++++++++++++++++++++++++ src/neo4j/_sync/io/_bolt.py | 10 ++++++ src/neo4j/_sync/io/_bolt3.py | 30 ++++++++++++++++++ src/neo4j/_sync/io/_bolt4.py | 45 +++++++++++++++++++++++++++ src/neo4j/_sync/io/_bolt5.py | 58 +++++++++++++++++++++++++++++++++++ 8 files changed, 286 insertions(+) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 002c046ee..339c065f8 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -916,6 +916,16 @@ def goodbye(self, dehydration_hooks=None, hydration_hooks=None): def new_hydration_scope(self): return self.hydration_handler.new_hydration_scope() + def _default_hydration_hooks(self, dehydration_hooks, hydration_hooks): + if dehydration_hooks is not None and hydration_hooks is not None: + return dehydration_hooks, hydration_hooks + hydration_scope = self.new_hydration_scope() + if dehydration_hooks is None: + dehydration_hooks = hydration_scope.dehydration_hooks + if hydration_hooks is None: + hydration_hooks = hydration_scope.hydration_hooks + return dehydration_hooks, hydration_hooks + def _append( self, signature, fields=(), response=None, dehydration_hooks=None ): diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 1617f0d79..08e75abb2 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -215,6 +215,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -275,6 +278,9 @@ async def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -337,6 +343,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -379,6 +388,9 @@ def discard( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: DISCARD_ALL", self.local_port) self._append( b"\x2f", @@ -396,6 +408,9 @@ def pull( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: PULL_ALL", self.local_port) self._append( b"\x3f", @@ -435,6 +450,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -464,6 +482,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -475,6 +496,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -490,6 +514,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -499,6 +526,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 9be53e75d..202d55707 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -131,6 +131,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -184,6 +187,9 @@ async def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -244,6 +250,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -292,6 +301,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -311,6 +323,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -379,6 +397,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -390,6 +411,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -405,6 +429,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -414,6 +441,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -547,6 +577,9 @@ async def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} log.debug( @@ -576,6 +609,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -635,6 +671,9 @@ async def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -683,6 +722,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -744,6 +786,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index ad53a7f08..04426e12d 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -136,6 +136,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -200,6 +203,9 @@ async def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -248,6 +254,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -298,6 +307,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -317,6 +329,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -381,6 +399,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -392,6 +413,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -407,6 +431,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -416,6 +443,9 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -580,6 +610,9 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -620,6 +653,9 @@ def on_success(metadata): check_supported_server_product(self.server_info.agent) def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) logged_auth_dict = dict(self.auth_dict) if "credentials" in logged_auth_dict: logged_auth_dict["credentials"] = "*******" @@ -632,6 +668,9 @@ def logon(self, dehydration_hooks=None, hydration_hooks=None): ) def logoff(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: LOGOFF", self.local_port) self._append( b"\x6b", @@ -658,6 +697,10 @@ def get_base_headers(self): return headers async def hello(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -709,6 +752,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -773,6 +819,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -838,6 +887,9 @@ def telemetry( "telemetry.enabled", False ): return + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) api_raw = int(api) log.debug( "[#%04X] C: TELEMETRY %i # (%r)", self.local_port, api_raw, api @@ -877,6 +929,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -941,6 +996,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 4503c6184..f1176ba02 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -916,6 +916,16 @@ def goodbye(self, dehydration_hooks=None, hydration_hooks=None): def new_hydration_scope(self): return self.hydration_handler.new_hydration_scope() + def _default_hydration_hooks(self, dehydration_hooks, hydration_hooks): + if dehydration_hooks is not None and hydration_hooks is not None: + return dehydration_hooks, hydration_hooks + hydration_scope = self.new_hydration_scope() + if dehydration_hooks is None: + dehydration_hooks = hydration_scope.dehydration_hooks + if hydration_hooks is None: + hydration_hooks = hydration_scope.hydration_hooks + return dehydration_hooks, hydration_hooks + def _append( self, signature, fields=(), response=None, dehydration_hooks=None ): diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 2847781e7..e3cfd1429 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -215,6 +215,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -275,6 +278,9 @@ def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -337,6 +343,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -379,6 +388,9 @@ def discard( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: DISCARD_ALL", self.local_port) self._append( b"\x2f", @@ -396,6 +408,9 @@ def pull( **handlers, ): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: PULL_ALL", self.local_port) self._append( b"\x3f", @@ -435,6 +450,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -464,6 +482,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -475,6 +496,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -490,6 +514,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -499,6 +526,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 023d16d8d..69bb6dd6e 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -131,6 +131,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) headers = self.get_base_headers() headers.update(self.auth_dict) logged_headers = dict(headers) @@ -184,6 +187,9 @@ def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) metadata = {} records = [] @@ -244,6 +250,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -292,6 +301,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -311,6 +323,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -379,6 +397,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -390,6 +411,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -405,6 +429,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -414,6 +441,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -547,6 +577,9 @@ def route( f"{self.PROTOCOL_VERSION!r}. Trying to impersonate " f"{imp_user!r}." ) + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} log.debug( @@ -576,6 +609,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -635,6 +671,9 @@ def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -683,6 +722,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -744,6 +786,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index d755a4b2d..9ea52c78d 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -136,6 +136,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -200,6 +203,9 @@ def route( dehydration_hooks=None, hydration_hooks=None, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) routing_context = self.routing_context or {} db_context = {} if database is not None: @@ -248,6 +254,9 @@ def run( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -298,6 +307,9 @@ def discard( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -317,6 +329,9 @@ def pull( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {"n": n} if qid != -1: extra["qid"] = qid @@ -347,6 +362,9 @@ def begin( or notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -381,6 +399,9 @@ def begin( ) def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: COMMIT", self.local_port) self._append( b"\x12", @@ -392,6 +413,9 @@ def commit(self, dehydration_hooks=None, hydration_hooks=None, **handlers): def rollback( self, dehydration_hooks=None, hydration_hooks=None, **handlers ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append( b"\x13", @@ -407,6 +431,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: RESET", self.local_port) response = ResetResponse(self, "reset", hydration_hooks) self._append( @@ -416,6 +443,9 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() def goodbye(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: GOODBYE", self.local_port) self._append(b"\x02", (), dehydration_hooks=dehydration_hooks) @@ -580,6 +610,9 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): or self.notifications_disabled_classifications is not None ): self.assert_notification_filtering_support() + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) @@ -620,6 +653,9 @@ def on_success(metadata): check_supported_server_product(self.server_info.agent) def logon(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) logged_auth_dict = dict(self.auth_dict) if "credentials" in logged_auth_dict: logged_auth_dict["credentials"] = "*******" @@ -632,6 +668,9 @@ def logon(self, dehydration_hooks=None, hydration_hooks=None): ) def logoff(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) log.debug("[#%04X] C: LOGOFF", self.local_port) self._append( b"\x6b", @@ -658,6 +697,10 @@ def get_base_headers(self): return headers def hello(self, dehydration_hooks=None, hydration_hooks=None): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) + def on_success(metadata): self.configuration_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) @@ -709,6 +752,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -773,6 +819,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified @@ -838,6 +887,9 @@ def telemetry( "telemetry.enabled", False ): return + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) api_raw = int(api) log.debug( "[#%04X] C: TELEMETRY %i # (%r)", self.local_port, api_raw, api @@ -877,6 +929,9 @@ def run( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) if not parameters: parameters = {} extra = {} @@ -941,6 +996,9 @@ def begin( hydration_hooks=None, **handlers, ): + dehydration_hooks, hydration_hooks = self._default_hydration_hooks( + dehydration_hooks, hydration_hooks + ) extra = {} if mode in {READ_ACCESS, "r"}: # It will default to mode "w" if nothing is specified From ba3c2a21ed7f456a2a559f4799dcec07f7bc742f Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 27 Sep 2024 11:35:22 +0200 Subject: [PATCH 06/23] Update GQL error fallback description and message --- src/neo4j/exceptions.py | 41 ++++++++++++++--------------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 0727ea86d..04dff1d8c 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -219,11 +219,11 @@ _UNKNOWN_MESSAGE: te.Final[str] = "An unknown error occurred" _UNKNOWN_GQL_STATUS: te.Final[str] = "50N42" _UNKNOWN_GQL_DESCRIPTION: te.Final[str] = ( - "general processing exception - unknown error" + "error: general processing exception - unexpected error" ) -# FIXME: _UNKNOWN_GQL_MESSAGE needs final format _UNKNOWN_GQL_MESSAGE: te.Final[str] = ( - f"{_UNKNOWN_GQL_STATUS}: {_UNKNOWN_GQL_DESCRIPTION}. {_UNKNOWN_MESSAGE}" + f"{_UNKNOWN_GQL_STATUS}: " + "Unexpected error has occurred. See debug log for details." ) _UNKNOWN_GQL_DIAGNOSTIC_RECORD: te.Final[tuple[tuple[str, t.Any], ...]] = ( ("OPERATION", ""), @@ -326,7 +326,7 @@ def _init_gql( if gql_status is None or message is None or description is None: self._gql_status = _UNKNOWN_GQL_STATUS self._message = _UNKNOWN_GQL_MESSAGE - self._gql_status_description = f"error: {_UNKNOWN_GQL_DESCRIPTION}" + self._gql_status_description = _UNKNOWN_GQL_DESCRIPTION else: self._gql_status = gql_status self._message = message @@ -338,21 +338,7 @@ def _init_gql( def _set_unknown_gql(self): self._gql_status = _UNKNOWN_GQL_STATUS self._message = _UNKNOWN_GQL_MESSAGE - self._gql_status_description = f"error: {_UNKNOWN_GQL_DESCRIPTION}" - - @staticmethod - def _format_message_details( - gql_status: str | None, - description: str | None, - details: str | None, - ): - if gql_status is None: - gql_status = _UNKNOWN_GQL_STATUS - if description is None: - description = _UNKNOWN_GQL_DESCRIPTION - if details is None: - return f"{gql_status}: {description}" - return f"{gql_status}: {description}. {details}" + self._gql_status_description = _UNKNOWN_GQL_DESCRIPTION def __setattr__(self, key, value): if key == "__cause__": @@ -573,12 +559,8 @@ def hydrate( @classmethod def _hydrate_neo4j(cls, **metadata: t.Any) -> Neo4jError: meta_extractor = _MetaExtractor(metadata) - code = meta_extractor.str_value("code") - if not code: - code = _UNKNOWN_NEO4J_CODE - message = meta_extractor.str_value("message") - if not message: - message = _UNKNOWN_MESSAGE + code = meta_extractor.str_value("code") or _UNKNOWN_NEO4J_CODE + message = meta_extractor.str_value("message") or _UNKNOWN_MESSAGE inst = cls._basic_hydrate( neo4j_code=code, message=message, @@ -586,7 +568,7 @@ def _hydrate_neo4j(cls, **metadata: t.Any) -> Neo4jError: inst._init_gql( gql_status=_UNKNOWN_GQL_STATUS, message=message, - description=f"error: {_UNKNOWN_GQL_DESCRIPTION}. {message}", + description=f"{_UNKNOWN_GQL_DESCRIPTION}. {message}", ) inst._metadata = meta_extractor.rest() return inst @@ -596,7 +578,12 @@ def _hydrate_gql(cls, **metadata: t.Any) -> Neo4jError: meta_extractor = _MetaExtractor(metadata) gql_status = meta_extractor.str_value("gql_status") status_description = meta_extractor.str_value("description") - message = meta_extractor.str_value("message", _UNKNOWN_GQL_MESSAGE) + message = meta_extractor.str_value("message") + if gql_status is None or status_description is None or message is None: + gql_status = _UNKNOWN_GQL_STATUS + # TODO: 6.0 - Make this fall back to _UNKNOWN_GQL_MESSAGE + message = _UNKNOWN_MESSAGE + status_description = _UNKNOWN_GQL_DESCRIPTION neo4j_code = meta_extractor.str_value( "neo4j_code", _UNKNOWN_NEO4J_CODE, From 21090f55c27ee78aabc78e1199823ed8b960fc55 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 2 Oct 2024 15:49:00 +0200 Subject: [PATCH 07/23] Small clean-ups --- src/neo4j/exceptions.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 04dff1d8c..81441eda5 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -341,11 +341,25 @@ def _set_unknown_gql(self): self._gql_status_description = _UNKNOWN_GQL_DESCRIPTION def __setattr__(self, key, value): - if key == "__cause__": - raise AttributeError( - "Cannot set __cause__ on GqlError or `raise ... from ...`." - ) - super().__setattr__(key, value) + if key != "__cause__" or getattr(self, "__cause__", None) is None: + super().__setattr__(key, value) + # If the GqlError already has a cause, which might have been set by the + # server, we don't want to overwrite it with for instance + # `raise some_gql_error from some_local_error`. Therefore, we traverse + # the cause chain and append the local cause. + root = self.__cause__ + seen_errors = {id(self), id(root)} + while True: + cause = getattr(root, "__cause__", None) + if cause is None: + root.__cause__ = value + return + root = cause + if id(root) in seen_errors: + # Circular cause chain -> we have no choice but to either + # overwrite the cause or ignore the new one. + return + seen_errors.add(id(root)) @property def _gql_status_no_preview(self) -> str: @@ -541,9 +555,8 @@ def __init__(self, *args) -> None: # TODO: 6.0 - Remove this alias @classmethod @deprecated( - "Neo4jError.hydrate is deprecated and will be " - "removed in a future version. It is an internal method and not meant " - "for external use." + "Neo4jError.hydrate is deprecated and will be removed in a future " + "version. It is an internal method and not meant for external use." ) def hydrate( cls, @@ -728,7 +741,6 @@ def metadata(self) -> dict[str, t.Any] | None: @metadata.setter @deprecated("Altering the metadata of Neo4jError is deprecated.") def metadata(self, value: dict[str, t.Any]) -> None: - # TODO 6.0: Remove this property self._metadata = value # TODO: 6.0 - Remove this alias @@ -815,7 +827,7 @@ def __str__(self): message = self._message if code or message: return f"{{code: {code}}} {{message: {message}}}" - # TODO: 6.0 - User gql status and status_description instead + # TODO: 6.0 - Use gql status and status_description instead # something like: # return ( # f"{{gql_status: {self.gql_status}}} " From ba3c97adbba82331f22e2dd085ead289b75a1081 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 3 Oct 2024 12:33:25 +0200 Subject: [PATCH 08/23] Fix GqlError cause handling + add tests for it --- src/neo4j/exceptions.py | 82 ++++++++------ tests/unit/common/test_exceptions.py | 157 +++++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 34 deletions(-) diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 81441eda5..6210e520e 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -290,7 +290,10 @@ class GqlError(Exception): _gql_classification: GqlErrorClassification _status_diagnostic_record: dict[str, t.Any] # original, internal only _diagnostic_record: dict[str, t.Any] # copy to be used externally - __cause__: GqlError | None + _gql_cause: GqlError | None + + def __init__(self, *args): + super().__init__(*args) @staticmethod def _hydrate_cause(**metadata: t.Any) -> GqlError: @@ -333,34 +336,54 @@ def _init_gql( self._gql_status_description = description if diagnostic_record is not None: self._status_diagnostic_record = diagnostic_record - super().__setattr__("__cause__", cause) + self._gql_cause = cause def _set_unknown_gql(self): self._gql_status = _UNKNOWN_GQL_STATUS self._message = _UNKNOWN_GQL_MESSAGE self._gql_status_description = _UNKNOWN_GQL_DESCRIPTION - def __setattr__(self, key, value): - if key != "__cause__" or getattr(self, "__cause__", None) is None: - super().__setattr__(key, value) - # If the GqlError already has a cause, which might have been set by the - # server, we don't want to overwrite it with for instance - # `raise some_gql_error from some_local_error`. Therefore, we traverse - # the cause chain and append the local cause. - root = self.__cause__ + def __getattribute__(self, item): + if item != "__cause__": + return super().__getattribute__(item) + gql_cause = self._get_attr_or_none("_gql_cause") + if gql_cause is None: + # No GQL cause, no magic needed + return super().__getattribute__(item) + local_cause = self._get_attr_or_none("__cause__") + if local_cause is None: + # We have a GQL cause but no local cause + # => set the GQL cause as the local cause + self.__cause__ = gql_cause + self.__suppress_context__ = True + self._gql_cause = None + return super().__getattribute__(item) + # We have both a GQL cause and a local cause + # => traverse the cause chain and append the local cause. + root = gql_cause seen_errors = {id(self), id(root)} while True: cause = getattr(root, "__cause__", None) if cause is None: - root.__cause__ = value - return + root.__cause__ = local_cause + root.__suppress_context__ = True + self.__cause__ = gql_cause + self.__suppress_context__ = True + self._gql_cause = None + return gql_cause root = cause if id(root) in seen_errors: # Circular cause chain -> we have no choice but to either # overwrite the cause or ignore the new one. - return + return local_cause seen_errors.add(id(root)) + def _get_attr_or_none(self, item): + try: + return super().__getattribute__(item) + except AttributeError: + return None + @property def _gql_status_no_preview(self) -> str: if hasattr(self, "_gql_status"): @@ -459,23 +482,7 @@ def gql_raw_classification(self) -> str | None: return self._gql_raw_classification_no_preview @property - @_preview("GQLSTATUS support is a preview feature.") - def gql_classification(self) -> GqlErrorClassification: - """ - Vendor specific classification of the error. - - This is the parsed version of :attr:`.gql_raw_classification`. - :attr:`GqlErrorClassification.UNKNOWN` is returned if the - classification is missing, invalid, or has an unknown value - (e.g., a newer version of the server introduced a new value). - - **This is a preview**. - It might be changed without following the deprecation policy. - See also - https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - - .. seealso:: :attr:`.gql_raw_classification` - """ + def _gql_classification_no_preview(self) -> GqlErrorClassification: if hasattr(self, "_gql_classification"): return self._gql_classification @@ -490,6 +497,11 @@ def gql_classification(self) -> GqlErrorClassification: self._gql_classification = GqlErrorClassification(classification) return self._gql_classification + @property + @_preview("GQLSTATUS support is a preview feature.") + def gql_classification(self) -> GqlErrorClassification: + return self._gql_classification_no_preview + def _get_status_diagnostic_record(self) -> dict[str, t.Any]: if hasattr(self, "_status_diagnostic_record"): return self._status_diagnostic_record @@ -498,9 +510,7 @@ def _get_status_diagnostic_record(self) -> dict[str, t.Any]: return self._status_diagnostic_record @property - @_preview("GQLSTATUS support is a preview feature.") - def diagnostic_record(self) -> Mapping[str, t.Any]: - """Further information about the GQLSTATUS for diagnostic purposes.""" + def _diagnostic_record_no_preview(self) -> Mapping[str, t.Any]: if hasattr(self, "_diagnostic_record"): return self._diagnostic_record @@ -509,14 +519,18 @@ def diagnostic_record(self) -> Mapping[str, t.Any]: ) return self._diagnostic_record + @property @_preview("GQLSTATUS support is a preview feature.") + def diagnostic_record(self) -> Mapping[str, t.Any]: + return self._diagnostic_record_no_preview + def __str__(self): return ( f"{{gql_status: {self._gql_status_no_preview}}} " f"{{gql_status_description: " f"{self._gql_status_description_no_preview}}} " f"{{message: {self._message_no_preview}}} " - f"{{diagnostic_record: {self.diagnostic_record}}} " + f"{{diagnostic_record: {self._diagnostic_record_no_preview}}} " f"{{raw_classification: " f"{self._gql_raw_classification_no_preview}}}" ) diff --git a/tests/unit/common/test_exceptions.py b/tests/unit/common/test_exceptions.py index 882e75141..a433bc030 100644 --- a/tests/unit/common/test_exceptions.py +++ b/tests/unit/common/test_exceptions.py @@ -14,6 +14,11 @@ # limitations under the License. +from __future__ import annotations + +import contextlib +import traceback + import pytest from neo4j._exceptions import ( @@ -28,6 +33,7 @@ CLASSIFICATION_TRANSIENT, ClientError, DatabaseError, + GqlError, Neo4jError, ServiceUnavailable, TransientError, @@ -322,3 +328,154 @@ def test_neo4j_error_from_code_as_str(cls): assert type(error) is cls assert str(error) == "Generated somewhere in the driver" + + +def _make_test_gql_error( + identifier: str, + cause: GqlError | None = None, +) -> GqlError: + error = GqlError(identifier) + error._init_gql( + gql_status=f"{identifier[:5].upper():<05}", + description=f"error: $h!t went down - {identifier}", + message=identifier, + cause=cause, + ) + return error + + +def _set_error_cause(exc, cause, method="set") -> None: + if method == "set": + exc.__cause__ = cause + elif method == "raise": + with contextlib.suppress(exc.__class__): + raise exc from cause + else: + raise ValueError(f"Invalid cause set method {method!r}") + + +_CYCLIC_CAUSE_MARKER = object() + + +def _assert_error_chain( + exc: BaseException, + expected: list[object], +) -> None: + assert isinstance(exc, BaseException) + + collection_root: BaseException | None = exc + actual_chain: list[object] = [exc] + actual_chain_ids = [id(exc)] + while collection_root is not None: + cause = getattr(collection_root, "__cause__", None) + if id(cause) in actual_chain_ids: + actual_chain.append(_CYCLIC_CAUSE_MARKER) + actual_chain_ids.append(id(_CYCLIC_CAUSE_MARKER)) + break + actual_chain.append(cause) + actual_chain_ids.append(id(cause)) + collection_root = cause + + assert actual_chain_ids == list(map(id, expected)) + + expected_lines = [ + str(exc) + for exc in expected + if exc is not None and exc is not _CYCLIC_CAUSE_MARKER + ] + expected_lines.reverse() + exc_fmt = traceback.format_exception(type(exc), exc, exc.__traceback__) + for line in exc_fmt: + if not expected_lines: + break + if expected_lines[0] in line: + expected_lines.pop(0) + if expected_lines: + traceback_fmt = "".join(exc_fmt) + pytest.fail( + f"Expected lines not found: {expected_lines} in traceback:\n" + f"{traceback_fmt}" + ) + + +def test_cause_chain_extension_no_cause() -> None: + root = _make_test_gql_error("root") + + _assert_error_chain(root, [root, None]) + + +def test_cause_chain_extension_only_gql_cause() -> None: + root_cause = _make_test_gql_error("rootCause") + root = _make_test_gql_error("root", cause=root_cause) + + _assert_error_chain(root, [root, root_cause, None]) + + +@pytest.mark.parametrize("local_cause_method", ("raise", "set")) +def test_cause_chain_extension_only_local_cause(local_cause_method) -> None: + root_cause = ClientError("rootCause") + root = _make_test_gql_error("root") + _set_error_cause(root, root_cause, local_cause_method) + + _assert_error_chain(root, [root, root_cause, None]) + + +@pytest.mark.parametrize("local_cause_method", ("raise", "set")) +def test_cause_chain_extension_multiple_causes(local_cause_method) -> None: + root4_cause2 = _make_test_gql_error("r4c2") + root4_cause1 = _make_test_gql_error("r4c1", cause=root4_cause2) + root4 = _make_test_gql_error("root4", cause=root4_cause1) + root3 = ClientError("root3") + _set_error_cause(root3, root4, local_cause_method) + root2_cause3 = _make_test_gql_error("r2c3") + root2_cause2 = _make_test_gql_error("r2c2", cause=root2_cause3) + root2_cause1 = _make_test_gql_error("r2c1", cause=root2_cause2) + root2 = _make_test_gql_error("root2", cause=root2_cause1) + _set_error_cause(root2, root3, local_cause_method) + root1_cause2 = _make_test_gql_error("r1c2") + root1_cause1 = _make_test_gql_error("r1c1", cause=root1_cause2) + root1 = _make_test_gql_error("root1", cause=root1_cause1) + _set_error_cause(root1, root2, local_cause_method) + + _assert_error_chain( + root1, [ + root1, root1_cause1, root1_cause2, + root2, root2_cause1, root2_cause2, root2_cause3, + root3, + root4, root4_cause1, root4_cause2, + None, + ], + ) # fmt: skip + + +@pytest.mark.parametrize("local_cause_method", ("raise", "set")) +def test_cause_chain_extension_circular_local_causes( + local_cause_method, +) -> None: + root6 = ClientError("root6") + root5 = _make_test_gql_error("root5") + _set_error_cause(root5, root6, local_cause_method) + root4_cause = _make_test_gql_error("r4c") + root4 = _make_test_gql_error("root4", cause=root4_cause) + _set_error_cause(root4, root5, local_cause_method) + root3 = ClientError("root3") + _set_error_cause(root3, root4, local_cause_method) + root2 = _make_test_gql_error("root2") + _set_error_cause(root2, root3, local_cause_method) + root1 = ClientError("root1") + _set_error_cause(root1, root2, local_cause_method) + _set_error_cause(root6, root1, local_cause_method) + + _assert_error_chain( + root1, + [ + root1, + root2, + root3, + root4, + root4_cause, + root5, + root6, + _CYCLIC_CAUSE_MARKER, + ], + ) From 2c1ffc13c8daa8e1aa97c646cb46efb948ff603e Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 3 Oct 2024 15:43:30 +0200 Subject: [PATCH 09/23] More unit tests --- src/neo4j/exceptions.py | 3 - tests/unit/common/test_exceptions.py | 317 +++++++++++++++++++++++++-- 2 files changed, 293 insertions(+), 27 deletions(-) diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 6210e520e..f225f2a52 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -292,9 +292,6 @@ class GqlError(Exception): _diagnostic_record: dict[str, t.Any] # copy to be used externally _gql_cause: GqlError | None - def __init__(self, *args): - super().__init__(*args) - @staticmethod def _hydrate_cause(**metadata: t.Any) -> GqlError: meta_extractor = _MetaExtractor(metadata) diff --git a/tests/unit/common/test_exceptions.py b/tests/unit/common/test_exceptions.py index a433bc030..dfb444a01 100644 --- a/tests/unit/common/test_exceptions.py +++ b/tests/unit/common/test_exceptions.py @@ -17,10 +17,13 @@ from __future__ import annotations import contextlib +import re import traceback import pytest +import neo4j.exceptions +from neo4j import PreviewWarning from neo4j._exceptions import ( BoltError, BoltHandshakeError, @@ -154,6 +157,36 @@ def test_serviceunavailable_raised_from_bolt_protocol_error_with_explicit_style( assert e.value.__cause__ is error +def _assert_default_gql_error_attrs_from_neo4j_error(error: GqlError) -> None: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_status == "50N42" + if error.message: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_status_description == ( + "error: general processing exception - unexpected error. " + f"{error.message}" + ) + else: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_status_description == ( + "error: general processing exception - unexpected error" + ) + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert ( + error.gql_classification + == neo4j.exceptions.GqlErrorClassification.UNKNOWN + ) + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.gql_raw_classification is None + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + assert error.diagnostic_record == { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", + } + assert error.__cause__ is None + + def test_neo4jerror_hydrate_with_no_args(): error = Neo4jError._hydrate_neo4j() @@ -164,6 +197,7 @@ def test_neo4jerror_hydrate_with_no_args(): assert error.metadata == {} assert error.message == "An unknown error occurred" assert error.code == "Neo.DatabaseError.General.UnknownError" + _assert_default_gql_error_attrs_from_neo4j_error(error) def test_neo4jerror_hydrate_with_message_and_code_rubbish(): @@ -178,6 +212,7 @@ def test_neo4jerror_hydrate_with_message_and_code_rubbish(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == "ASDF_asdf" + _assert_default_gql_error_attrs_from_neo4j_error(error) def test_neo4jerror_hydrate_with_message_and_code_database(): @@ -193,6 +228,7 @@ def test_neo4jerror_hydrate_with_message_and_code_database(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == "Neo.DatabaseError.General.UnknownError" + _assert_default_gql_error_attrs_from_neo4j_error(error) def test_neo4jerror_hydrate_with_message_and_code_transient(): @@ -208,6 +244,7 @@ def test_neo4jerror_hydrate_with_message_and_code_transient(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == f"Neo.{CLASSIFICATION_TRANSIENT}.General.TestError" + _assert_default_gql_error_attrs_from_neo4j_error(error) def test_neo4jerror_hydrate_with_message_and_code_client(): @@ -223,6 +260,7 @@ def test_neo4jerror_hydrate_with_message_and_code_client(): assert error.metadata == {} assert error.message == "Test error message" assert error.code == f"Neo.{CLASSIFICATION_CLIENT}.General.TestError" + _assert_default_gql_error_attrs_from_neo4j_error(error) @pytest.mark.parametrize( @@ -260,9 +298,20 @@ def test_neo4jerror_hydrate_with_message_and_code_client(): ), ), ) -def test_error_rewrite(code, expected_cls, expected_code): +@pytest.mark.parametrize("mode", ("neo4j", "gql")) +def test_error_rewrite(code, expected_cls, expected_code, mode): message = "Test error message" - error = Neo4jError._hydrate_neo4j(message=message, code=code) + if mode == "neo4j": + error = Neo4jError._hydrate_neo4j(message=message, code=code) + elif mode == "gql": + error = Neo4jError._hydrate_gql( + gql_status="12345", + description="error: things - they hit the fan", + message=message, + neo4j_code=code, + ) + else: + raise ValueError(f"Invalid mode {mode!r}") expected_retryable = expected_cls is TransientError assert error.__class__ is expected_cls @@ -274,49 +323,95 @@ def test_error_rewrite(code, expected_cls, expected_code): @pytest.mark.parametrize( - ("code", "message", "expected_cls", "expected_str"), + ("code", "message", "expected_cls", "expected_str", "mode"), ( - ( - "Neo.ClientError.General.UnknownError", - "Test error message", - ClientError, - "{code: Neo.ClientError.General.UnknownError} " - "{message: Test error message}", - ), - ( - None, - "Test error message", - DatabaseError, - "{code: Neo.DatabaseError.General.UnknownError} " - "{message: Test error message}", + # values that behave the same in both modes + *( + ( + *x, + mode, + ) + for mode in ("neo4j", "gql") + for x in ( + ( + "Neo.ClientError.General.UnknownError", + "Test error message", + ClientError, + ( + "{code: Neo.ClientError.General.UnknownError} " + "{message: Test error message}" + ), + ), + ( + None, + "Test error message", + DatabaseError, + ( + "{code: Neo.DatabaseError.General.UnknownError} " + "{message: Test error message}" + ), + ), + ( + "Neo.ClientError.General.UnknownError", + None, + ClientError, + ( + "{code: Neo.ClientError.General.UnknownError} " + "{message: An unknown error occurred}" + ), + ), + ) ), + # neo4j error specific behavior ( "", "Test error message", DatabaseError, - "{code: Neo.DatabaseError.General.UnknownError} " - "{message: Test error message}", + ( + "{code: Neo.DatabaseError.General.UnknownError} " + "{message: Test error message}" + ), + "neo4j", ), ( "Neo.ClientError.General.UnknownError", - None, + "", ClientError, "{code: Neo.ClientError.General.UnknownError} " "{message: An unknown error occurred}", + "neo4j", + ), + # gql error specific behavior + ( + "", + "Test error message", + DatabaseError, + "{code: } {message: Test error message}", + "gql", ), ( "Neo.ClientError.General.UnknownError", "", ClientError, - "{code: Neo.ClientError.General.UnknownError} " - "{message: An unknown error occurred}", + "{code: Neo.ClientError.General.UnknownError} {message: }", + "gql", ), - )[1:2], + ), ) def test_neo4j_error_from_server_as_str( - code, message, expected_cls, expected_str + code, message, expected_cls, expected_str, mode ): - error = Neo4jError._hydrate_neo4j(code=code, message=message) + if mode == "neo4j": + error = Neo4jError._hydrate_neo4j(code=code, message=message) + elif mode == "gql": + error = Neo4jError._hydrate_gql( + gql_status="12345", + description="error: things - they hit the fan", + neo4j_code=code, + message=message, + ) + else: + raise ValueError(f"Invalid mode {mode!r}") assert type(error) is expected_cls assert str(error) == expected_str @@ -479,3 +574,177 @@ def test_cause_chain_extension_circular_local_causes( _CYCLIC_CAUSE_MARKER, ], ) + + +_DEFAULT_GQL_ERROR_ATTRIBUTES = { + "code": "Neo.DatabaseError.General.UnknownError", + "classification": "DatabaseError", + "category": "General", + "title": "UnknownError", + "message": "An unknown error occurred", + "gql_status": "50N42", + "gql_status_description": ( + "error: general processing exception - unexpected error" + ), + "gql_classification": neo4j.exceptions.GqlErrorClassification.UNKNOWN, + "gql_raw_classification": None, + "diagnostic_record": { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", + }, + "__cause__": None, +} + + +@pytest.mark.parametrize( + ("metadata", "attributes"), + ( + # all default values + ( + {}, + _DEFAULT_GQL_ERROR_ATTRIBUTES, + ), + # example from ADR + ( + { + "gql_status": "01N00", + "message": "01EXAMPLE you have failed something", + "description": "client error - example error. Message", + "neo4j_code": "Neo.Example.Failure.Code", + "diagnostic_record": { + "CURRENT_SCHEMA": "", + "OPERATION": "", + "OPERATION_CODE": "", + "_classification": "CLIENT_ERROR", + "_status_parameters": {}, + }, + }, + { + "code": "Neo.Example.Failure.Code", + "classification": "Example", + "category": "Failure", + "title": "Code", + "message": "01EXAMPLE you have failed something", + "gql_status": "01N00", + "gql_status_description": ( + "client error - example error. Message" + ), + "gql_classification": ( + neo4j.exceptions.GqlErrorClassification.CLIENT_ERROR + ), + "gql_raw_classification": "CLIENT_ERROR", + "diagnostic_record": { + "CURRENT_SCHEMA": "", + "OPERATION": "", + "OPERATION_CODE": "", + "_classification": "CLIENT_ERROR", + "_status_parameters": {}, + }, + "__cause__": None, + }, + ), + # garbage diagnostic record + ( + { + "diagnostic_record": { + "CURRENT_SCHEMA": 1.5, + "OPERATION": False, + "_classification": ["whelp", None], + "_🤡": "🎈", + "foo": {"bar": "baz"}, + }, + }, + { + **_DEFAULT_GQL_ERROR_ATTRIBUTES, + "gql_classification": ( + neo4j.exceptions.GqlErrorClassification.UNKNOWN + ), + "gql_raw_classification": None, + "diagnostic_record": { + "CURRENT_SCHEMA": 1.5, + "OPERATION": False, + "_classification": ["whelp", None], + "_🤡": "🎈", + "foo": {"bar": "baz"}, + }, + }, + ), + ( + { + "diagnostic_record": { + "_classification": "SOME_FUTURE_CLASSIFICATION", + }, + }, + { + **_DEFAULT_GQL_ERROR_ATTRIBUTES, + "gql_classification": ( + neo4j.exceptions.GqlErrorClassification.UNKNOWN + ), + "gql_raw_classification": "SOME_FUTURE_CLASSIFICATION", + "diagnostic_record": { + "_classification": "SOME_FUTURE_CLASSIFICATION", + }, + }, + ), + ), +) +def test_gql_hydration(metadata, attributes): + # TODO: test causes + error = Neo4jError._hydrate_gql(**metadata) + + preview_attrs = { + "gql_status", + "gql_status_description", + "gql_classification", + "gql_raw_classification", + "diagnostic_record", + } + + for attr in ( + "code", + "classification", + "category", + "title", + "message", + "gql_status", + "gql_status_description", + "gql_classification", + "gql_raw_classification", + "diagnostic_record", + "__cause__", + ): + expected_value = attributes[attr] + if attr in preview_attrs: + with pytest.warns(PreviewWarning, match="GQLSTATUS"): + actual_value = getattr(error, attr) + else: + actual_value = getattr(error, attr) + assert actual_value == expected_value + + +@pytest.mark.parametrize( + "attr", + ( + "code", + "classification", + "category", + "title", + "message", + "metadata", + ), +) +def test_deprecated_setter(attr): + obj = object() + error = Neo4jError() + + with pytest.warns( + DeprecationWarning, + match=re.compile( + rf".*\baltering\b.*\b{attr}\b.*", + flags=re.IGNORECASE, + ), + ): + setattr(error, attr, obj) + + assert getattr(error, attr) is obj From acd01d416f51829d2a8306d01a93f3d101662f3a Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 3 Oct 2024 16:05:39 +0200 Subject: [PATCH 10/23] `.. versionadded:: 5.xx` -> `.. versionadded:: 5.26` --- src/neo4j/exceptions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index f225f2a52..07a23f70e 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -253,7 +253,7 @@ class GqlErrorClassification(str, _Enum): .. seealso:: :attr:`.GqlError.gql_classification` - .. versionadded:: 5.xx + .. versionadded:: 5.26 """ CLIENT_ERROR = "CLIENT_ERROR" @@ -279,7 +279,7 @@ class GqlError(Exception): See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features - .. versionadded: 5.xx + .. versionadded: 5.26 """ _gql_status: str From 2d2aa6ea4515ffcbdc841e30aebf890b346f6c17 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 4 Oct 2024 16:21:15 +0200 Subject: [PATCH 11/23] TestKit backend: only sent GQLError causes to TestKit --- testkitbackend/totestkit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index e62c43f81..96e277db5 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -328,8 +328,9 @@ def driver_exc(exc, id_=None): payload["diagnosticRecord"] = { k: field(v) for k, v in exc.diagnostic_record.items() } - if exc.__cause__ is not None: - payload["cause"] = driver_exc_cause(exc.__cause__) + cause = exc, "__cause__", None + if isinstance(cause, GqlError): + payload["cause"] = driver_exc_cause(cause) return {"name": "DriverError", "data": payload} From 205a167c4f675bb217c32a3beeae29beeef8ba87 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 4 Oct 2024 18:05:11 +0200 Subject: [PATCH 12/23] Fix hydration/dehydration hooks confusion --- src/neo4j/_async/io/_bolt5.py | 2 +- src/neo4j/_sync/io/_bolt5.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 04426e12d..06336193f 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -227,7 +227,7 @@ async def route( response=Response( self, "route", hydration_hooks, on_success=metadata.update ), - dehydration_hooks=hydration_hooks, + dehydration_hooks=dehydration_hooks, ) await self.send_all() await self.fetch_all() diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 9ea52c78d..4138a9d5d 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -227,7 +227,7 @@ def route( response=Response( self, "route", hydration_hooks, on_success=metadata.update ), - dehydration_hooks=hydration_hooks, + dehydration_hooks=dehydration_hooks, ) self.send_all() self.fetch_all() From 48b395b86476f289ebe4ad2c595b62e0ffd28a1c Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 4 Oct 2024 18:43:46 +0200 Subject: [PATCH 13/23] Loosen constraint on utc_patch handling --- src/neo4j/_codec/hydration/v1/hydration_handler.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/neo4j/_codec/hydration/v1/hydration_handler.py b/src/neo4j/_codec/hydration/v1/hydration_handler.py index 8b3c99777..f5ce1bbd5 100644 --- a/src/neo4j/_codec/hydration/v1/hydration_handler.py +++ b/src/neo4j/_codec/hydration/v1/hydration_handler.py @@ -154,7 +154,6 @@ def hydrate_path(self, nodes, relationships, sequence): class HydrationHandler(HydrationHandlerABC): def __init__(self): super().__init__() - self._created_scope = False self.struct_hydration_functions = { **self.struct_hydration_functions, b"X": spatial.hydrate_point, @@ -201,8 +200,6 @@ def __init__(self): def patch_utc(self): from ..v2 import temporal as temporal_v2 - assert not self._created_scope - del self.struct_hydration_functions[b"F"] del self.struct_hydration_functions[b"f"] self.struct_hydration_functions.update( @@ -226,5 +223,4 @@ def patch_utc(self): ) def new_hydration_scope(self): - self._created_scope = True return HydrationScope(self, _GraphHydrator()) From 6ef1812d728614f9bd5235e2cd3f3e7dd9e8349a Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 4 Oct 2024 19:12:31 +0200 Subject: [PATCH 14/23] TestKit backend: fix error traceback formatting --- testkitbackend/_async/backend.py | 12 ++++++++++-- testkitbackend/_sync/backend.py | 12 ++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index ff2308076..76f7fb420 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -145,9 +145,17 @@ def _exc_stems_from_driver(exc): return True return None + @staticmethod + def _get_tb(exc): + return "".join( + traceback.format_exception( + type(exc), exc, getattr(exc, "__traceback__", None) + ) + ) + def _serialize_driver_exc(self, exc): log.debug(exc.args) - log.debug("".join(traceback.format_exception(exc))) + log.debug(self._get_tb(exc)) key = self.next_key() self.errors[key] = exc @@ -156,7 +164,7 @@ def _serialize_driver_exc(self, exc): @staticmethod def _serialize_backend_error(exc): - tb = "".join(traceback.format_exception(exc)) + tb = AsyncBackend._get_tb(exc) log.error(tb) return {"name": "BackendError", "data": {"msg": tb}} diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index b192c83c6..a5464703f 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -145,9 +145,17 @@ def _exc_stems_from_driver(exc): return True return None + @staticmethod + def _get_tb(exc): + return "".join( + traceback.format_exception( + type(exc), exc, getattr(exc, "__traceback__", None) + ) + ) + def _serialize_driver_exc(self, exc): log.debug(exc.args) - log.debug("".join(traceback.format_exception(exc))) + log.debug(self._get_tb(exc)) key = self.next_key() self.errors[key] = exc @@ -156,7 +164,7 @@ def _serialize_driver_exc(self, exc): @staticmethod def _serialize_backend_error(exc): - tb = "".join(traceback.format_exception(exc)) + tb = Backend._get_tb(exc) log.error(tb) return {"name": "BackendError", "data": {"msg": tb}} From 29017689d11aeec79b964fd6a73f59cb2caea1ec Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 4 Oct 2024 19:15:58 +0200 Subject: [PATCH 15/23] TestKit backend: fix GqlError cause serialization --- testkitbackend/totestkit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 96e277db5..16eb7b21d 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -328,7 +328,7 @@ def driver_exc(exc, id_=None): payload["diagnosticRecord"] = { k: field(v) for k, v in exc.diagnostic_record.items() } - cause = exc, "__cause__", None + cause = getattr(exc, "__cause__", None) if isinstance(cause, GqlError): payload["cause"] = driver_exc_cause(cause) From 44f356a4e1886d0b4a319b7eeac957359ce66102 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 4 Oct 2024 19:27:40 +0200 Subject: [PATCH 16/23] TestKit backend: improve general error cause serialization --- testkitbackend/totestkit.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 16eb7b21d..9f0181f73 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -328,9 +328,9 @@ def driver_exc(exc, id_=None): payload["diagnosticRecord"] = { k: field(v) for k, v in exc.diagnostic_record.items() } - cause = getattr(exc, "__cause__", None) - if isinstance(cause, GqlError): - payload["cause"] = driver_exc_cause(cause) + cause = driver_exc_cause(getattr(exc, "__cause__", None)) + if cause is not None: + payload["cause"] = cause return {"name": "DriverError", "data": payload} @@ -342,8 +342,10 @@ def _exc_msg(exc): def driver_exc_cause(exc): + if exc is None: + return None if not isinstance(exc, GqlError): - raise TypeError("Expected GqlError as cause") + return driver_exc_cause(getattr(exc, "__cause__", None)) payload = {} with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["msg"] = exc.message @@ -359,7 +361,8 @@ def driver_exc_cause(exc): payload["classification"] = exc.gql_classification with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["rawClassification"] = exc.gql_raw_classification - if exc.__cause__ is not None: - payload["cause"] = driver_exc_cause(exc.__cause__) + cause = getattr(exc, "__cause__", None) + if cause is not None: + payload["cause"] = driver_exc_cause(cause) return {"name": "GqlError", "data": payload} From d74a1e7fdc8049347933fe65e410fe5d0d14154f Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 7 Oct 2024 09:39:24 +0200 Subject: [PATCH 17/23] TestKit backend: fix warning assertion --- testkitbackend/totestkit.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 9f0181f73..8d6a4222f 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -347,7 +347,10 @@ def driver_exc_cause(exc): if not isinstance(exc, GqlError): return driver_exc_cause(getattr(exc, "__cause__", None)) payload = {} - with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + if not isinstance(exc, Neo4jError): + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + payload["msg"] = exc.message + else: payload["msg"] = exc.message with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["gqlStatus"] = exc.gql_status From 129de58c697b69bb1f5d5417c55e50b1921bc777 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 7 Oct 2024 10:19:24 +0200 Subject: [PATCH 18/23] TestKit backend: restore error msg formatting --- testkitbackend/totestkit.py | 38 ++++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 8d6a4222f..968d8fa1e 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -335,23 +335,39 @@ def driver_exc(exc, id_=None): return {"name": "DriverError", "data": payload} -def _exc_msg(exc): +def _exc_msg(exc, max_depth=10): if isinstance(exc, Neo4jError) and exc.message is not None: return str(exc.message) - return str(exc) + depth = 0 + if isinstance(exc, GqlError): + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + res = exc.message + else: + res = str(exc) + while getattr(exc, "__cause__", None) is not None: + if isinstance(exc.__cause__, GqlError): + # Not including GqlError in the chain as they will be serialized + # separately in the `cause` field. + break + depth += 1 + if depth >= max_depth: + break + res += f"\nCaused by: {exc.__cause__!r}" + exc = exc.__cause__ + return res -def driver_exc_cause(exc): + +def driver_exc_cause(exc, max_depth=10): if exc is None: return None + if max_depth <= 0: + return None if not isinstance(exc, GqlError): - return driver_exc_cause(getattr(exc, "__cause__", None)) - payload = {} - if not isinstance(exc, Neo4jError): - with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): - payload["msg"] = exc.message - else: - payload["msg"] = exc.message + return driver_exc_cause( + getattr(exc, "__cause__", None), max_depth=max_depth - 1 + ) + payload = {"msg": _exc_msg(exc)} with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): payload["gqlStatus"] = exc.gql_status with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): @@ -366,6 +382,6 @@ def driver_exc_cause(exc): payload["rawClassification"] = exc.gql_raw_classification cause = getattr(exc, "__cause__", None) if cause is not None: - payload["cause"] = driver_exc_cause(cause) + payload["cause"] = driver_exc_cause(cause, max_depth=max_depth - 1) return {"name": "GqlError", "data": payload} From b4d47a07524af9871815bd95891e603df5728308 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 7 Oct 2024 11:01:29 +0200 Subject: [PATCH 19/23] TestKit backend: further error message serialization fixes --- testkitbackend/totestkit.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 968d8fa1e..4a89d4654 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -20,6 +20,7 @@ import neo4j from neo4j.exceptions import ( + DriverError, GqlError, Neo4jError, ) @@ -341,8 +342,13 @@ def _exc_msg(exc, max_depth=10): depth = 0 if isinstance(exc, GqlError): - with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): - res = exc.message + if isinstance(exc, Neo4jError) and exc.message is not None: + res = str(exc.message) + elif isinstance(exc, DriverError): + res = str(exc) + else: + with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): + res = exc.message else: res = str(exc) while getattr(exc, "__cause__", None) is not None: From 9dc131dcfc4bad3035c2e83d1befa5c1a52c58a6 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 7 Oct 2024 16:02:01 +0200 Subject: [PATCH 20/23] TestKit backend: I swear, I'll get error serialization right... eventually --- testkitbackend/totestkit.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 4a89d4654..5e6ec227b 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -20,7 +20,6 @@ import neo4j from neo4j.exceptions import ( - DriverError, GqlError, Neo4jError, ) @@ -342,10 +341,8 @@ def _exc_msg(exc, max_depth=10): depth = 0 if isinstance(exc, GqlError): - if isinstance(exc, Neo4jError) and exc.message is not None: - res = str(exc.message) - elif isinstance(exc, DriverError): - res = str(exc) + if isinstance(exc, Neo4jError): + res = str(exc.message) if exc.message is not None else str(exc) else: with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): res = exc.message From e0bca52d974c3adb2f90413032ce3382f4a1a0ae Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 7 Oct 2024 17:02:39 +0200 Subject: [PATCH 21/23] TestKit backend: surely this time the error formatting is right... right?? --- testkitbackend/totestkit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 5e6ec227b..baab4bdea 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -345,7 +345,9 @@ def _exc_msg(exc, max_depth=10): res = str(exc.message) if exc.message is not None else str(exc) else: with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): - res = exc.message + msg = exc.message + if exc.args: + res = f"{msg} - {exc!s}" else: res = str(exc) while getattr(exc, "__cause__", None) is not None: From 65805660347d1c111cd1529b835d8a06a3683c0f Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 8 Oct 2024 09:08:25 +0200 Subject: [PATCH 22/23] TestKit backend: are you feeling it now, Mr. Krabs? --- testkitbackend/totestkit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index baab4bdea..87d591cf9 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -346,8 +346,7 @@ def _exc_msg(exc, max_depth=10): else: with warning_check(neo4j.PreviewWarning, r".*\bGQLSTATUS\b.*"): msg = exc.message - if exc.args: - res = f"{msg} - {exc!s}" + res = f"{msg} - {exc!s}" if exc.args else msg else: res = str(exc) while getattr(exc, "__cause__", None) is not None: From ed3908f7e50df2eedbfdd319bb296eabad693334 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 8 Oct 2024 11:02:59 +0200 Subject: [PATCH 23/23] =?UTF-8?q?=F0=9F=98=AD=F0=9F=99=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- testkitbackend/totestkit.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index 87d591cf9..eac8aec3b 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -22,6 +22,7 @@ from neo4j.exceptions import ( GqlError, Neo4jError, + ResultFailedError, ) from neo4j.graph import ( Node, @@ -350,9 +351,16 @@ def _exc_msg(exc, max_depth=10): else: res = str(exc) while getattr(exc, "__cause__", None) is not None: - if isinstance(exc.__cause__, GqlError): + if ( # Not including GqlError in the chain as they will be serialized # separately in the `cause` field. + isinstance(exc.__cause__, GqlError) + # Special case for ResultFailedError: + # Always serialize the cause in the message to please TestKit. + # Else, the cause's class name will get lost (can't be serialized + # as a field in of an error cause). + and not isinstance(exc, ResultFailedError) + ): break depth += 1 if depth >= max_depth: