diff --git a/boltstub/__main__.py b/boltstub/__main__.py index 2ba4a7de..a3ed004c 100644 --- a/boltstub/__main__.py +++ b/boltstub/__main__.py @@ -146,8 +146,8 @@ def signal_handler(sig, frame): elif sigint_count == 2: print("2nd SIGINT received. Closing all connections.") service.close_all_connections_async() - if sigint_count > 3: - print("3nd SIGINT received. Hard exit.") + elif sigint_count >= 3: + print("3rd SIGINT received. Hard exit.") return exit_(130) if platform.system() == "Windows": diff --git a/boltstub/bolt_protocol.py b/boltstub/bolt_protocol.py index cf4e2107..fb70c08d 100644 --- a/boltstub/bolt_protocol.py +++ b/boltstub/bolt_protocol.py @@ -561,3 +561,12 @@ class Bolt5x6Protocol(Bolt5x5Protocol): equivalent_versions = set() server_agent = "Neo4j/5.23.0" + + +class Bolt5x7Protocol(Bolt5x6Protocol): + protocol_version = (5, 7) + version_aliases = set() + # allow the server to negotiate other bolt versions + equivalent_versions = set() + + server_agent = "Neo4j/5.24.0" diff --git a/nutkit/protocol/cypher.py b/nutkit/protocol/cypher.py index 2c273c1c..481a8d2c 100644 --- a/nutkit/protocol/cypher.py +++ b/nutkit/protocol/cypher.py @@ -496,4 +496,18 @@ def as_cypher_type(value): return CypherString(value) if isinstance(value, (bytes, bytearray)): return CypherBytes(value) + if isinstance( + value, + ( + CypherNode, + CypherRelationship, + CypherPath, + CypherPoint, + CypherDate, + CypherTime, + CypherDateTime, + CypherDuration, + ) + ): + return value raise TypeError("Unsupported type: {}".format(type(value))) diff --git a/nutkit/protocol/feature.py b/nutkit/protocol/feature.py index c7467131..02f5bfaa 100644 --- a/nutkit/protocol/feature.py +++ b/nutkit/protocol/feature.py @@ -123,6 +123,8 @@ class Feature(Enum): BOLT_5_5 = "Feature:Bolt:5.5" # The driver supports Bolt protocol version 5.6 BOLT_5_6 = "Feature:Bolt:5.6" + # The driver supports Bolt protocol version 5.7 + BOLT_5_7 = "Feature:Bolt:5.7" # The driver supports patching DateTimes to use UTC for Bolt 4.3 and 4.4 BOLT_PATCH_UTC = "Feature:Bolt:Patch:UTC" # The driver supports impersonation diff --git a/nutkit/protocol/responses.py b/nutkit/protocol/responses.py index 0fd8e625..aaec1442 100644 --- a/nutkit/protocol/responses.py +++ b/nutkit/protocol/responses.py @@ -761,20 +761,81 @@ class DriverError(BaseError): """ def __init__(self, id=None, errorType=None, msg="", code="", - retryable=None): + retryable=None, gqlStatus=None, statusDescription=None, + cause=None, diagnosticRecord=None, classification=None, + rawClassification=None): self.id = id self.errorType = errorType self.msg = msg self.code = code self.retryable = retryable + assert isinstance(gqlStatus, (str, type(None))) + self.gql_status = gqlStatus + assert isinstance(statusDescription, (str, type(None))) + self.status_description = statusDescription + if cause is not None: + assert isinstance(cause, GqlError) + self.cause = cause + assert isinstance(diagnosticRecord, (dict, type(None))) + self.diagnostic_record = diagnosticRecord + assert isinstance(classification, (str, type(None))) + self.classification = classification + assert isinstance(rawClassification, (str, type(None))) + self.raw_classification = rawClassification def __str__(self): - return f"DriverError(type={self.errorType}, msg={self.msg!r})" + return ( + f"DriverError(" + f"errorType={self.errorType!r}, " + f"msg={self.msg!r}, " + f"code={self.code!r}, " + f"retryable={self.retryable!r}, " + f"gqlStatus={self.gql_status!r}, " + f"statusDescription={self.status_description!r}, " + f"diagnosticRecord={self.diagnostic_record!r}, " + f"classification={self.classification!r}, " + f"rawClassification={self.raw_classification!r}, " + f"cause={self.cause!r})" + ) def __repr__(self): return self.__str__() +class GqlError: + """TODO.""" + + def __init__(self, msg="", gqlStatus=None, statusDescription=None, + cause=None, diagnosticRecord=None, classification=None, + rawClassification=None): + self.msg = msg + assert isinstance(gqlStatus, (str, type(None))) + self.gql_status = gqlStatus + assert isinstance(statusDescription, (str, type(None))) + self.status_description = statusDescription + if cause is not None: + assert isinstance(cause, GqlError) + self.cause = cause + assert isinstance(diagnosticRecord, (dict, type(None))) + self.diagnostic_record = diagnosticRecord + assert isinstance(classification, (str, type(None))) + self.classification = classification + assert isinstance(rawClassification, (str, type(None))) + self.raw_classification = rawClassification + + def __str__(self): + return ( + f"DriverErrorCause(" + f"msg={self.msg!r}, " + f"gqlStatus={self.gql_status!r}, " + f"statusDescription={self.status_description!r}, " + f"diagnosticRecord={self.diagnostic_record!r}, " + f"classification={self.classification!r}, " + f"rawClassification={self.raw_classification!r}, " + f"cause={self.cause!r})" + ) + + class FrontendError(BaseError): """ Error originating from client code. diff --git a/tests/stub/errors/__init__.py b/tests/stub/errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stub/errors/scripts/error.script b/tests/stub/errors/scripts/error.script new file mode 100644 index 00000000..5c2bd25e --- /dev/null +++ b/tests/stub/errors/scripts/error.script @@ -0,0 +1,19 @@ +!: BOLT #BOLT_VERSION# + +A: HELLO {"{}": "*"} +A: LOGON {"{}": "*"} +*: RESET +C: RUN {"U": "*"} {"{}": "*"} {"{}": "*"} +S: FAILURE #ERROR# +# Allow driver to pipeline a PULL or DISCARD after RUN +{? + {{ + C: PULL {"[n]": {"Z": "*"}} + S: IGNORED + ---- + C: DISCARD {"[n]": {"Z": "*"}} + S: IGNORED + }} +?} ++: RESET +A: GOODBYE diff --git a/tests/stub/errors/test_errors.py b/tests/stub/errors/test_errors.py new file mode 100644 index 00000000..34f1d46e --- /dev/null +++ b/tests/stub/errors/test_errors.py @@ -0,0 +1,344 @@ +import json +from abc import ( + ABC, + abstractmethod, +) +from contextlib import contextmanager +from copy import deepcopy + +import nutkit.protocol as types +from nutkit.frontend import Driver +from tests.shared import ( + driver_feature, + TestkitTestCase, +) +from tests.stub.shared import StubServer + + +class _ErrorTestCase(TestkitTestCase, ABC): + @property + @abstractmethod + def bolt_version(self) -> str: + pass + + @contextmanager + def server(self, script, vars_=None): + if vars_ is None: + vars_ = {} + vars_.update({"#BOLT_VERSION#": self.bolt_version}) + server = StubServer(9001) + server.start(path=self.script_path(script), + vars_=vars_) + try: + yield server + except Exception: + server.reset() + raise + + server.done() + + @contextmanager + def driver(self, server): + auth = types.AuthorizationToken("bearer", credentials="foo") + uri = f"bolt://{server.address}" + driver = Driver(self._backend, uri, auth) + try: + yield driver + finally: + driver.close() + + @contextmanager + def session(self, driver): + session = driver.session("w") + try: + yield session + finally: + session.close() + + def get_error(self, error_data): + def run(session_): + session_.run("RETURN 1").consume() + + vars_ = {"#ERROR#": json.dumps(error_data)} + with self.server("error.script", vars_=vars_) as server: + with self.driver(server) as driver: + with self.session(driver) as session: + with self.assertRaises(types.DriverError) as exc: + run(session) + return exc.exception + + +DEFAULT_DIAG_REC = { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", +} + + +class TestError5x6(_ErrorTestCase): + required_features = ( + types.Feature.BOLT_5_6, + ) + + bolt_version = "5.6" + + def test_error(self): + supports_retryable_check = self.driver_supports_features( + types.Feature.API_RETRYABLE_EXCEPTION + ) + supports_bolt_5_7 = self.driver_supports_features( + types.Feature.BOLT_5_7 + ) + + for (error_code, retryable) in ( + ("Neo.ClientError.User.Uncool", False), + ("Neo.TransientError.Oopsie.OhSnap", True), + ): + with self.subTest(code=error_code): + error_message = "Sever ain't cool with this!" + error_data = { + "code": error_code, + "message": error_message, + } + + exc = self.get_error(error_data) + + self.assertEqual(exc.code, error_code) + self.assertEqual(exc.msg, error_message) + if supports_retryable_check: + self.assertEqual(exc.retryable, retryable) + if supports_bolt_5_7: + self.assertEqual(exc.gql_status, "50N42") + self.assertEqual( + exc.status_description, + "error: " + "general processing exception - unexpected error. " + f"{error_message}", + ) + self.assertEqual( + exc.diagnostic_record, + types.as_cypher_type(DEFAULT_DIAG_REC).value, + ) + self.assertEqual(exc.raw_classification, None) + self.assertEqual(exc.classification, "UNKNOWN") + self.assertIsNone(exc.cause) + if supports_retryable_check: + self.assertEqual(exc.retryable, retryable) + + +class TestError5x7(_ErrorTestCase): + required_features = ( + types.Feature.BOLT_5_7, + ) + + bolt_version = "5.7" + + def _make_test_error_data( + self, + status=..., + description=..., + message=..., + code=..., + diagnostic_record=..., + extra_diag_rec=None, + del_diag_rec=None, + cause=None, + ): + data = {} + if status is not None: + data["gql_status"] = "01N00" if status is ... else status + if description is not None: + data["description"] = ( + "cool class - mediocre subclass" + if description is ... + else description + ) + if message is not None: + data["message"] = ( + "Sever ain't cool with this, John Doe!" + if message is ... + else message + ) + if code is not None: + data["neo4j_code"] = ( + "Neo.ClientError.User.Uncool" if code is ... else code + ) + if diagnostic_record is ...: + data["diagnostic_record"] = { + **DEFAULT_DIAG_REC, + "_classification": "CLIENT_ERROR", + "_status_parameters": {"userName": "John Doe"}, + } + elif diagnostic_record is not None: + data["diagnostic_record"] = diagnostic_record + if extra_diag_rec is not None: + data["diagnostic_record"].update(extra_diag_rec) + if del_diag_rec: + for key in del_diag_rec: + del data["diagnostic_record"][key] + if cause is not None: + data["cause"] = cause + return data + + def _assert_is_test_error(self, exc, data): + self.assertEqual(exc.gql_status, data["gql_status"]) + self.assertEqual(exc.status_description, data["description"]) + self.assertEqual(exc.msg, data["message"]) + if "neo4j_code" in data: + self.assertEqual(exc.code, data["neo4j_code"]) + else: + self.assertFalse(hasattr(exc, "code")) + expected_diag_rec = deepcopy(data.get("diagnostic_record", {})) + for k, v in DEFAULT_DIAG_REC.items(): + expected_diag_rec.setdefault(k, v) + self.assertEqual( + exc.diagnostic_record, + types.as_cypher_type(expected_diag_rec).value, + ) + expected_raw_classification = expected_diag_rec.get("_classification") + if isinstance(expected_raw_classification, str): + self.assertEqual( + exc.raw_classification, + expected_raw_classification, + ) + expected_classification = "UNKNOWN" + if expected_raw_classification in { + "CLIENT_ERROR", + "DATABASE_ERROR", + "TRANSIENT_ERROR", + }: + expected_classification = expected_raw_classification + self.assertEqual(exc.classification, expected_classification) + if "cause" in data: + self._assert_is_test_error(exc.cause, data["cause"]) + + def test_simple_gql_error(self): + error_data = self._make_test_error_data() + error = self.get_error(error_data) + self._assert_is_test_error(error, error_data) + + def test_nested_gql_error(self): + for depth in (1, 10): + with self.subTest(depth=depth): + cause = None + for i in range(depth, 0, -1): + cause = self._make_test_error_data( + status=f"01N{i:02d}", + description=f"description ({i})", + message=f"message ({i})", + code=None, + diagnostic_record={ + "CURRENT_SCHEMA": f"/{i}", + "OPERATION": f"OP{i}", + "OPERATION_CODE": f"{i}", + "_classification": f"CLIENT_ERROR{i}", + "_status_parameters": {"nestedCause": i}, + }, + cause=cause, + ) + error_data = self._make_test_error_data(cause=cause) + error = self.get_error(error_data) + self._assert_is_test_error(error, error_data) + + def test_error_classification(self): + for as_cause in (False, True): + for classification in ( + "CLIENT_ERROR", + "DATABASE_ERROR", + "TRANSIENT_ERROR", + "SECURITY_ERROR", # made up classification + ): + with self.subTest( + as_cause=as_cause, + classification=classification, + ): + error_data = self._make_test_error_data( + extra_diag_rec={"_classification": classification}, + code=None if as_cause else ..., + ) + if as_cause: + error_data = self._make_test_error_data( + cause=error_data, + ) + error = self.get_error(error_data) + self._assert_is_test_error(error, error_data) + + def test_filling_default_diagnostic_record(self): + for as_cause in (False, True): + with self.subTest(as_cause=as_cause): + error_data = self._make_test_error_data( + diagnostic_record=None, + code=None if as_cause else ..., + ) + if as_cause: + error_data = self._make_test_error_data(cause=error_data) + error = self.get_error(error_data) + self._assert_is_test_error(error, error_data) + + def test_filling_default_value_in_diagnostic_record(self): + for as_cause in (False, True): + for missing_key in ( + "CURRENT_SCHEMA", + "OPERATION", + "OPERATION_CODE", + ): + with self.subTest(as_cause=as_cause, missing_key=missing_key): + error_data = self._make_test_error_data( + del_diag_rec=[missing_key], + code=None if as_cause else ..., + ) + if as_cause: + error_data = self._make_test_error_data( + cause=error_data + ) + error = self.get_error(error_data) + self._assert_is_test_error(error, error_data) + + def test_keeps_rubbish_in_diagnostic_record(self): + use_spacial = self.driver_supports_features( + types.Feature.API_TYPE_SPATIAL + ) + for as_cause in (False, True): + with self.subTest(as_cause=as_cause): + diagnostic_record = { + "foo": "bar", + "_baz": 1.2, + "OPERATION": None, + "CURRENT_SCHEMA": {"uh": "oh!"}, + "OPERATION_CODE": False, + "_classification": 42, + "_status_parameters": [ + # stub script will interpret this as JOLT spatial point + {"@": "SRID=4326;POINT(56.21 13.43)"} + if use_spacial + else "whatever", + ], + } + error_data = self._make_test_error_data( + diagnostic_record=diagnostic_record, + code=None if as_cause else ..., + ) + if as_cause: + error_data = self._make_test_error_data(cause=error_data) + + error = self.get_error(error_data) + + if use_spacial: + diagnostic_record["_status_parameters"] = [ + types.CypherPoint("wgs84", 56.21, 13.43) + ] + self._assert_is_test_error(error, error_data) + + @driver_feature(types.Feature.API_RETRYABLE_EXCEPTION) + def test_error_retryable(self): + for (neo4j_code, retryable) in ( + ("Neo.ClientError.User.Uncool", False), + ("Neo.TransientError.Oopsie.OhSnap", True), + ): + with self.subTest(error_code=neo4j_code): + error_data = self._make_test_error_data(code=neo4j_code) + + exc = self.get_error(error_data) + + self._assert_is_test_error(exc, error_data) + self.assertEqual(exc.retryable, retryable) diff --git a/tests/stub/versions/scripts/v5x7_return_1.script b/tests/stub/versions/scripts/v5x7_return_1.script new file mode 100644 index 00000000..037d43f4 --- /dev/null +++ b/tests/stub/versions/scripts/v5x7_return_1.script @@ -0,0 +1,19 @@ +!: BOLT 5.7 + +C: HELLO {"{}": "*"} +S: SUCCESS {"server": "#SERVER_AGENT#", "connection_id": "bolt-123456789"} +A: LOGON {"{}": "*"} +*: RESET +{? + ?: TELEMETRY {"{}": "*"} + C: RUN {"U": "*"} {"{}": "*"} {"{}": "*"} + S: SUCCESS {"fields": ["n.name"]} + {{ + C: PULL {"n": {"Z": "*"}} + ---- + C: DISCARD {"n": {"Z": "*"}} + }} + S: SUCCESS {"type": "w"} +?} +*: RESET +?: GOODBYE diff --git a/tests/stub/versions/test_versions.py b/tests/stub/versions/test_versions.py index 44156408..39f4e921 100644 --- a/tests/stub/versions/test_versions.py +++ b/tests/stub/versions/test_versions.py @@ -157,9 +157,13 @@ def test_supports_bolt5x4(self): def test_supports_bolt5x6(self): self._run("5x6") + @driver_feature(types.Feature.BOLT_5_7) + def test_supports_bolt5x7(self): + self._run("5x7") + def test_server_version(self): for version in ( - "5x6", "5x4", "5x3", "5x2", "5x1", "5x0", + "5x7", "5x6", "5x4", "5x3", "5x2", "5x1", "5x0", "4x4", "4x3", "4x2", "4x1", "3" ): if not self.driver_supports_bolt(version): @@ -169,7 +173,7 @@ def test_server_version(self): def test_server_agent(self): for version in ( - "5x6", "5x4", "5x3", "5x2", "5x1", "5x0", + "5x7", "5x6", "5x4", "5x3", "5x2", "5x1", "5x0", "4x4", "4x3", "4x2", "4x1", "3" ): for agent, reject in ( @@ -204,7 +208,7 @@ def test_server_address_in_summary(self): if get_driver_name() in ["javascript", "dotnet"]: self.skipTest("Backend doesn't support server address in summary") for version in ( - "5x6", "5x4", "5x3", "5x2", "5x1", "5x0", + "5x7", "5x6", "5x4", "5x3", "5x2", "5x1", "5x0", "4x4", "4x3", "4x2", "4x1", "3" ): if not self.driver_supports_bolt(version):