diff --git a/boltstub/bolt_protocol.py b/boltstub/bolt_protocol.py index 13c8b2df..88db291d 100644 --- a/boltstub/bolt_protocol.py +++ b/boltstub/bolt_protocol.py @@ -552,3 +552,12 @@ class Bolt5x5Protocol(Bolt5x4Protocol): equivalent_versions = set() server_agent = "Neo4j/5.21.0" + + +class Bolt5x6Protocol(Bolt5x5Protocol): + protocol_version = (5, 6) + version_aliases = set() + # allow the server to negotiate other bolt versions + equivalent_versions = set() + + server_agent = "Neo4j/5.24.0" diff --git a/nutkit/protocol/feature.py b/nutkit/protocol/feature.py index 68504b5f..39a66836 100644 --- a/nutkit/protocol/feature.py +++ b/nutkit/protocol/feature.py @@ -120,6 +120,8 @@ class Feature(Enum): BOLT_5_4 = "Feature:Bolt:5.4" # The driver supports Bolt protocol version 5.5 BOLT_5_5 = "Feature:Bolt:5.5" + # The driver supports Bolt protocol version 5.6 + BOLT_5_6 = "Feature:Bolt:5.6" # The driver supports patching DateTimes to use UTC for Bolt 4.3 and 4.4 BOLT_PATCH_UTC = "Feature:Bolt:Patch:UTC" # The driver supports impersonation diff --git a/nutkit/protocol/responses.py b/nutkit/protocol/responses.py index 226e8333..ef228bd9 100644 --- a/nutkit/protocol/responses.py +++ b/nutkit/protocol/responses.py @@ -761,12 +761,28 @@ class DriverError(BaseError): """ def __init__(self, id=None, errorType=None, msg="", code="", - retryable=None): + retryable=None, gqlStatus=None, statusDescription=None, + cause=None, diagnosticRecord=None, classification=None): self.id = id self.errorType = errorType self.msg = msg self.code = code self.retryable = retryable + assert isinstance(gqlStatus, (str, type(None))) + self.gql_status = gqlStatus + assert isinstance(statusDescription, (str, type(None))) + self.status_description = statusDescription + if cause is not None: + assert isinstance(cause, DriverError) + # Don't bother giving cause errors IDs. + # They might not even represent re-throwable exceptions in the + # driver. + assert cause.id is None + self.cause = cause + assert isinstance(diagnosticRecord, (dict, type(None))) + self.diagnostic_record = diagnosticRecord + assert isinstance(classification, (str, type(None))) + self.classification = classification def __str__(self): return f"DriverError(type={self.errorType}, msg={self.msg!r})" diff --git a/tests/stub/errors/__init__.py b/tests/stub/errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stub/errors/scripts/error.script b/tests/stub/errors/scripts/error.script new file mode 100644 index 00000000..5c2bd25e --- /dev/null +++ b/tests/stub/errors/scripts/error.script @@ -0,0 +1,19 @@ +!: BOLT #BOLT_VERSION# + +A: HELLO {"{}": "*"} +A: LOGON {"{}": "*"} +*: RESET +C: RUN {"U": "*"} {"{}": "*"} {"{}": "*"} +S: FAILURE #ERROR# +# Allow driver to pipeline a PULL or DISCARD after RUN +{? + {{ + C: PULL {"[n]": {"Z": "*"}} + S: IGNORED + ---- + C: DISCARD {"[n]": {"Z": "*"}} + S: IGNORED + }} +?} ++: RESET +A: GOODBYE diff --git a/tests/stub/errors/test_errors.py b/tests/stub/errors/test_errors.py new file mode 100644 index 00000000..640df60f --- /dev/null +++ b/tests/stub/errors/test_errors.py @@ -0,0 +1,238 @@ +import json +from abc import ( + ABC, + abstractmethod, +) +from contextlib import contextmanager + +import nutkit.protocol as types +from nutkit.frontend import Driver +from tests.shared import TestkitTestCase +from tests.stub.shared import StubServer + + +class _ErrorTestCase(TestkitTestCase, ABC): + @property + @abstractmethod + def bolt_version(self) -> str: + pass + + @contextmanager + def server(self, script, vars_=None): + if vars_ is None: + vars_ = {} + vars_.update({"#BOLT_VERSION#": self.bolt_version}) + server = StubServer(9001) + server.start(path=self.script_path(script), + vars_=vars_) + try: + yield server + except Exception: + server.reset() + raise + + server.done() + + @contextmanager + def driver(self, server): + auth = types.AuthorizationToken("bearer", credentials="foo") + uri = f"bolt://{server.address}" + driver = Driver(self._backend, uri, auth) + try: + yield driver + finally: + driver.close() + + @contextmanager + def session(self, driver): + session = driver.session("w") + try: + yield session + finally: + session.close() + + def get_error(self, error_data): + def run(session_): + session_.run("RETURN 1").consume() + + vars_ = {"#ERROR#": json.dumps(error_data)} + with self.server("error.script", vars_=vars_) as server: + with self.driver(server) as driver: + with self.session(driver) as session: + with self.assertRaises(types.DriverError) as exc: + run(session) + return exc.exception + + +class TestError5x5(_ErrorTestCase): + required_features = ( + types.Feature.BOLT_5_5, + ) + + bolt_version = "5.5" + + def test_error(self): + for (error_code, retryable) in ( + ("Neo.ClientError.User.Uncool", False), + ("Neo.TransientError.Oopsie.OhSnap", True), + ): + with self.subTest(code=error_code): + error_message = "Sever ain't cool with this!" + error_data = {"code": error_code, "message": error_message} + + error = self.get_error(error_data) + + self.assertEqual(error.code, error_code) + self.assertEqual(error.msg, error_message) + self.assertEqual(error.retryable, retryable) + if self.driver_supports_features(types.Feature.BOLT_5_6): + self.assertEqual(error.gql_status, "50N42") + expected_desc = ( + "error: " + "general processing exception - unknown error. " + f"{error_message}" + ) + self.assertEqual(error.status_description, expected_desc) + self.assertIsNone(error.cause) + self.assertEqual( + error.diagnostic_record, + types.as_cypher_type(DEFAULT_DIAG_REC).value + ) + # TODO: TBD + # self.assertEqual(error.classification, "UNKNOWN") + + +DEFAULT_DIAG_REC = { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", +} + + +class TestError5x6(_ErrorTestCase): + required_features = ( + types.Feature.BOLT_5_6, + ) + + bolt_version = "5.6" + + def test_error(self): + error_status = "01N00" + error_message = "Sever ain't cool with this, John Doe!" + error_explanation = "cool class - mediocre subclass" + error_code = "Neo.ClientError.User.Uncool" + diagnostic_record = { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", + "_classification": "TBD", # TODO + "_status_parameters": { + "userName": "John Doe", + }, + } + error_data = { + "gql_status": error_status, + "status_message": error_message, + "status_explanation": error_explanation, + "neo4j_code": error_code, + "diagnostic_record": diagnostic_record, + } + + error = self.get_error(error_data) + + self.assertEqual(error.code, error_code) + self.assertEqual(error.msg, error_message) + # TODO: what part of the error is used to determine retryability? + # self.assertEqual(error.retryable, retryable) + self.assertEqual(error.gql_status, error_status) + self.assertEqual( + error.status_description, + f"error: {error_explanation}. {error_message}" + ) + self.assertIsNone(error.cause) + self.assertEqual(error.diagnostic_record, + types.as_cypher_type(diagnostic_record).value) + # TODO: TBD + # self.assertEqual(error.classification, "UNKNOWN") + + # TODO: test driver fills in default values in diag. rec. + + def test_nested_error(self): + error_status = "01ABC" + error_code = "Neo.ClientError.Bar.Baz" + cause_status = "01N00" + cause_message = "Sever ain't cool with this, John Doe!" + cause_explanation = "cool class - mediocre subclass" + cause_code = "Neo.ClientError.User.Uncool" + diagnostic_record = { + "CURRENT_SCHEMA": "/", + "OPERATION": "", + "OPERATION_CODE": "0", + "_classification": "TBD", # TODO + "_status_parameters": { + "userName": "John Doe", + }, + } + error_data = { + "gql_status": error_status, + "status_message": "msg", + "status_explanation": "explanation", + "neo4j_code": error_code, + "diagnostic_record": DEFAULT_DIAG_REC, + "cause": { + "gql_status": cause_status, + "status_message": cause_message, + "status_explanation": cause_explanation, + "neo4j_code": cause_code, + "diagnostic_record": diagnostic_record, + }, + } + + error = self.get_error(error_data) + + self.assertIsInstance(error, types.DriverError) + self.assertEqual(error.code, error_code) + + cause = error.cause + self.assertIsNotNone(cause) + self.assertEqual(cause.code, cause_code) + self.assertEqual(cause.msg, cause_message) + # TODO: self.assertEqual(cause.retryable, ?) + self.assertEqual(cause.gql_status, cause_status) + self.assertEqual(cause.status_description, + f"error: {cause_explanation}. {cause_message}") + self.assertIsNone(cause.cause) + self.assertEqual(cause.diagnostic_record, + types.as_cypher_type(diagnostic_record).value) + # TODO: TBD + # self.assertEqual(cause.classification, "UNKNOWN") + + def test_deeply_nested_error(self): + def make_status(i_): + return f"01N{i_:02d}" + + error_data = { + "gql_status": make_status(0), + "status_message": "msg", + "status_explanation": "explanation", + "neo4j_code": "Neo.ClientError.Bar.Baz0", + "diagnostic_record": DEFAULT_DIAG_REC, + } + parent_data = error_data + for i in range(1, 10): + parent_data["cause"] = { + "gql_status": make_status(i), + "status_message": f"msg{i}", + "status_explanation": f"explanation{i}", + "neo4j_code": f"Neo.ClientError.Bar.Baz{i}", + "diagnostic_record": DEFAULT_DIAG_REC, + } + parent_data = parent_data["cause"] + + error = self.get_error(error_data) + + for i in range(10): + self.assertIsInstance(error, types.DriverError) + self.assertEqual(error.code, f"Neo.ClientError.Bar.Baz{i}") + self.assertEqual(error.gql_status, make_status(i)) + error = error.cause diff --git a/tests/stub/versions/test_versions.py b/tests/stub/versions/test_versions.py index 0101af44..5dba5bf2 100644 --- a/tests/stub/versions/test_versions.py +++ b/tests/stub/versions/test_versions.py @@ -157,9 +157,13 @@ def test_supports_bolt5x4(self): def test_supports_bolt5x5(self): self._run("5x5") + @driver_feature(types.Feature.BOLT_5_6) + def test_supports_bolt5x6(self): + self._run("5x6") + def test_server_version(self): for version in ( - "5x5", "5x4", "5x3", "5x2", "5x1", "5x0", + "5x6", "5x5", "5x4", "5x3", "5x2", "5x1", "5x0", "4x4", "4x3", "4x2", "4x1", "3" ): if not self.driver_supports_bolt(version): @@ -169,7 +173,7 @@ def test_server_version(self): def test_server_agent(self): for version in ( - "5x5", "5x4", "5x3", "5x2", "5x1", "5x0", + "5x6", "5x5", "5x4", "5x3", "5x2", "5x1", "5x0", "4x4", "4x3", "4x2", "4x1", "3" ): for agent, reject in ( @@ -204,7 +208,7 @@ def test_server_address_in_summary(self): if get_driver_name() in ["javascript", "dotnet"]: self.skipTest("Backend doesn't support server address in summary") for version in ( - "5x5", "5x4", "5x3", "5x2", "5x1", "5x0", + "5x6", "5x5", "5x4", "5x3", "5x2", "5x1", "5x0", "4x4", "4x3", "4x2", "4x1", "3" ): if not self.driver_supports_bolt(version): @@ -331,6 +335,15 @@ def test_should_reject_server_using_verify_connectivity_bolt_5x5(self): version="5.5", script="v5x1_and_up_optional_hello.script" ) + @driver_feature(types.Feature.BOLT_5_6) + def test_should_reject_server_using_verify_connectivity_bolt_5x6(self): + # TODO remove this block once fixed + if get_driver_name() in ["dotnet", "go", "javascript"]: + self.skipTest("Driver does not check server agent string") + self._test_should_reject_server_using_verify_connectivity( + version="5.6", script="v5x1_and_up_optional_hello.script" + ) + def _test_should_reject_server_using_verify_connectivity( self, version, script ):