diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index d161371f0..e8923e33e 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -284,6 +284,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x3, AsyncBolt5x4, AsyncBolt5x5, + AsyncBolt5x6, ) handlers = { @@ -299,6 +300,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x3.PROTOCOL_VERSION: AsyncBolt5x3, AsyncBolt5x4.PROTOCOL_VERSION: AsyncBolt5x4, AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, + AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, } if protocol_version is None: @@ -413,7 +415,10 @@ async def open( # Carry out Bolt subclass imports locally to avoid circular dependency # issues. - if protocol_version == (5, 5): + if protocol_version == (5, 6): + from ._bolt5 import AsyncBolt5x6 + bolt_cls = AsyncBolt5x6 + elif protocol_version == (5, 5): from ._bolt5 import AsyncBolt5x5 bolt_cls = AsyncBolt5x5 elif protocol_version == (5, 4): diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index cd015f6e8..20e5ef7bd 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -685,6 +685,7 @@ def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, Response(self, "telemetry", hydration_hooks, **handlers), dehydration_hooks=dehydration_hooks) + class AsyncBolt5x5(AsyncBolt5x4): PROTOCOL_VERSION = Version(5, 5) @@ -783,7 +784,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, ("CURRENT_SCHEMA", "/"), ) - def _make_enrich_diagnostic_record_handler(self, wrapped_handler=None): + def _make_enrich_statuses_handler(self, wrapped_handler=None): async def handler(metadata): def enrich(metadata_): if not isinstance(metadata_, dict): @@ -794,6 +795,7 @@ def enrich(metadata_): for status in statuses: if not isinstance(status, dict): continue + status["description"] = status.get("status_description") diag_record = status.setdefault("diagnostic_record", {}) if not isinstance(diag_record, dict): log.info("[#%04X] _: Server supplied an " @@ -810,14 +812,44 @@ def enrich(metadata_): def discard(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().discard(n, qid, dehydration_hooks, hydration_hooks, **handlers) def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().pull(n, qid, dehydration_hooks, hydration_hooks, **handlers) + + +class AsyncBolt5x6(AsyncBolt5x5): + + PROTOCOL_VERSION = Version(5, 6) + + def _make_enrich_statuses_handler(self, wrapped_handler=None): + async def handler(metadata): + def enrich(metadata_): + if not isinstance(metadata_, dict): + return + statuses = metadata_.get("statuses") + if not isinstance(statuses, list): + return + for status in statuses: + if not isinstance(status, dict): + continue + diag_record = status.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info("[#%04X] _: Server supplied an " + "invalid diagnostic record (%r).", + self.local_port, diag_record) + continue + for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + + enrich(metadata) + await AsyncUtil.callback(wrapped_handler, metadata) + + return handler diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 92258c2c1..1c6763900 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -284,6 +284,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x3, Bolt5x4, Bolt5x5, + Bolt5x6, ) handlers = { @@ -299,6 +300,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x3.PROTOCOL_VERSION: Bolt5x3, Bolt5x4.PROTOCOL_VERSION: Bolt5x4, Bolt5x5.PROTOCOL_VERSION: Bolt5x5, + Bolt5x6.PROTOCOL_VERSION: Bolt5x6, } if protocol_version is None: @@ -413,7 +415,10 @@ def open( # Carry out Bolt subclass imports locally to avoid circular dependency # issues. - if protocol_version == (5, 5): + if protocol_version == (5, 6): + from ._bolt5 import Bolt5x6 + bolt_cls = Bolt5x6 + elif protocol_version == (5, 5): from ._bolt5 import Bolt5x5 bolt_cls = Bolt5x5 elif protocol_version == (5, 4): diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 12740a6c5..86924b8a3 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -685,6 +685,7 @@ def telemetry(self, api: TelemetryAPI, dehydration_hooks=None, Response(self, "telemetry", hydration_hooks, **handlers), dehydration_hooks=dehydration_hooks) + class Bolt5x5(Bolt5x4): PROTOCOL_VERSION = Version(5, 5) @@ -783,7 +784,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, ("CURRENT_SCHEMA", "/"), ) - def _make_enrich_diagnostic_record_handler(self, wrapped_handler=None): + def _make_enrich_statuses_handler(self, wrapped_handler=None): def handler(metadata): def enrich(metadata_): if not isinstance(metadata_, dict): @@ -794,6 +795,7 @@ def enrich(metadata_): for status in statuses: if not isinstance(status, dict): continue + status["description"] = status.get("status_description") diag_record = status.setdefault("diagnostic_record", {}) if not isinstance(diag_record, dict): log.info("[#%04X] _: Server supplied an " @@ -810,14 +812,44 @@ def enrich(metadata_): def discard(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().discard(n, qid, dehydration_hooks, hydration_hooks, **handlers) def pull(self, n=-1, qid=-1, dehydration_hooks=None, hydration_hooks=None, **handlers): - handlers["on_success"] = self._make_enrich_diagnostic_record_handler( + handlers["on_success"] = self._make_enrich_statuses_handler( wrapped_handler=handlers.get("on_success") ) super().pull(n, qid, dehydration_hooks, hydration_hooks, **handlers) + + +class Bolt5x6(Bolt5x5): + + PROTOCOL_VERSION = Version(5, 6) + + def _make_enrich_statuses_handler(self, wrapped_handler=None): + def handler(metadata): + def enrich(metadata_): + if not isinstance(metadata_, dict): + return + statuses = metadata_.get("statuses") + if not isinstance(statuses, list): + return + for status in statuses: + if not isinstance(status, dict): + continue + diag_record = status.setdefault("diagnostic_record", {}) + if not isinstance(diag_record, dict): + log.info("[#%04X] _: Server supplied an " + "invalid diagnostic record (%r).", + self.local_port, diag_record) + continue + for key, value in self.DEFAULT_DIAGNOSTIC_RECORD: + diag_record.setdefault(key, value) + + enrich(metadata) + Util.callback(wrapped_handler, metadata) + + return handler diff --git a/src/neo4j/_work/summary.py b/src/neo4j/_work/summary.py index 3b795d3ba..3029c53f5 100644 --- a/src/neo4j/_work/summary.py +++ b/src/neo4j/_work/summary.py @@ -163,7 +163,7 @@ def _set_notifications(self): for notification_key, status_key in ( ("title", "title"), ("code", "neo4j_code"), - ("description", "status_description"), + ("description", "description"), ): value = status.get(status_key) if not isinstance(value, str) or not value: diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 7e16d0cfc..e4d8b14b5 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -57,6 +57,7 @@ "Feature:Bolt:5.3": true, "Feature:Bolt:5.4": true, "Feature:Bolt:5.5": true, + "Feature:Bolt:5.6": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 7d1e5217b..686b5a393 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -37,7 +37,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } protocol_handlers = AsyncBolt.protocol_handlers() @@ -65,7 +65,8 @@ def test_class_method_protocol_handlers(): ((5, 3), 1), ((5, 4), 1), ((5, 5), 1), - ((5, 6), 0), + ((5, 6), 1), + ((5, 7), 0), ((6, 0), 0), ] ) @@ -85,7 +86,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): # [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() - assert (b"\x00\x05\x05\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) @@ -134,6 +135,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 3), "neo4j._async.io._bolt5.AsyncBolt5x3"), ((5, 4), "neo4j._async.io._bolt5.AsyncBolt5x4"), ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), + ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), ), ) @mark_async_test @@ -166,14 +168,14 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 6), + (5, 7), (6, 0), )) @mark_async_test async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) address = ("localhost", 7687) diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index 3da6dee7c..413bf9abf 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -615,7 +615,7 @@ async def test_tracks_last_database(fake_socket_pair, actions): ) @pytest.mark.parametrize("method", ("pull", "discard")) @mark_async_test -async def test_enriches_diagnostic_record( +async def test_enriches_statuses( sent_diag_records, method, fake_socket_pair, @@ -628,7 +628,9 @@ async def test_enriches_diagnostic_record( sent_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description" } for r in sent_diag_records ] } @@ -654,7 +656,9 @@ def extend_diag_record(r): expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] expected_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "the status description" } for r in expected_diag_records ] } diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py new file mode 100644 index 000000000..84521bfb9 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -0,0 +1,666 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x6 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, socket, AsyncPoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6( + address, socket, AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + + +@pytest.mark.parametrize(("method", "args", "extra_idx"), ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), +)) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2) +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2) +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, + cls_dis_clss, method_dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6( + address, socket, AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss + ) + method = getattr(connection, method) + + method(*args, notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize("dis_clss", + (None, [], ["HINT"], ["HINT", "DEPRECATION"])) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss + ) + + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0") + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + limit=3, + ) +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS) + connection = AsyncBolt5x6(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 6e86436e5..bb3610716 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -190,6 +190,7 @@ def test_statuses_and_notifications_dont_mix(summary_args_kwargs) -> None: raw_status = { "gql_status": "12345", "status_description": "cool description", + "description": "cool notification description", "neo4j_code": "Neo.Foo.Bar.Baz", "title": "nice title", "diagnostic_record": raw_diag_rec, @@ -314,6 +315,7 @@ def make_raw_status( "gql_status": gql_status, "status_description": "note: successful completion - " f"custom stuff {i}", + "description": f"notification description {i}", "neo4j_code": f"Neo.Foo.Bar.{type_}-{i}", "title": f"Some cool title which defo is dope! {i}", "diagnostic_record": { @@ -606,15 +608,18 @@ def test_status( ) -> None: args, kwargs = summary_args_kwargs default_position = SummaryInputPosition(line=1337, column=42, offset=420) - default_description = "some nice description goes here" + default_status_description = "some nice description goes here" + default_description = "some nice notification description here" default_severity = "WARNING" default_classification = "HINT" default_code = "Neo.Cool.Legacy.Code" default_title = "Cool Title" default_gql_status = "12345" + raw_status: t.Dict[str, t.Any] = { "gql_status": default_gql_status, - "status_description": default_description, + "status_description": default_status_description, + "description": default_description, "neo4j_code": default_code, "title": default_title, "diagnostic_record": { @@ -661,7 +666,7 @@ def test_status( == expectation_overwrite.get("gql_status", default_gql_status)) assert (status.status_description == expectation_overwrite.get("status_description", - default_description)) + default_status_description)) assert (status.position == expectation_overwrite.get("position", default_position)) assert (status.raw_classification @@ -837,6 +842,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 3), "t_first"), ((5, 4), "t_first"), ((5, 5), "t_first"), + ((5, 6), "t_first"), )) def test_summary_result_available_after( summary_args_kwargs, exists, bolt_version, meta_name @@ -869,6 +875,7 @@ def test_summary_result_available_after( ((5, 3), "t_last"), ((5, 4), "t_last"), ((5, 5), "t_last"), + ((5, 6), "t_last"), )) def test_summary_result_consumed_after( summary_args_kwargs, exists, bolt_version, meta_name @@ -1438,9 +1445,9 @@ def test_no_notification_from_status(raw_status, summary_args_kwargs) -> None: ("FOOBAR", None, ..., -1, 1.6, False, [], {})) ), - # copies status_description to description + # copies description to description ( - {"status_description": "something completely different 👀"}, {}, + {"description": "something completely different 👀"}, {}, {"description": "something completely different 👀"} ), @@ -1537,15 +1544,17 @@ def test_notification_from_status( summary_args_kwargs ) -> None: default_status = "03BAZ" - default_description = "note: successful completion - custom stuff" + default_status_description = "note: successful completion - custom stuff" default_code = "Neo.Foo.Bar.Baz" default_title = "Some cool title which defo is dope!" default_severity = "INFORMATION" default_classification = "HINT" + default_description = "nice message" default_position = SummaryInputPosition(line=1337, column=42, offset=420) raw_status_obj: t.Dict[str, t.Any] = { "gql_status": default_status, - "status_description": default_description, + "status_description": default_status_description, + "description": default_description, "neo4j_code": default_code, "title": default_title, "diagnostic_record": { @@ -1767,7 +1776,7 @@ def test_broken_diagnostic_record(in_status, summary_args_kwargs) -> None: ("status_overwrite", "diagnostic_record_overwrite"), ( *( - ({"status_description": value}, {}) + ({"description": value}, {}) for value in t.cast(t.Iterable[t.Any], ("", None, ..., 1, False, [], {})) ), @@ -1818,7 +1827,8 @@ def test_no_notification_from_broken_status( status_overwrite, diagnostic_record_overwrite, summary_args_kwargs ) -> None: default_status = "03BAZ" - default_description = "note: successful completion - custom stuff" + default_status_description = "note: successful completion - custom stuff" + default_description = "some description" default_code = "Neo.Foo.Bar.Baz" default_title = "Some cool title which defo is dope!" default_severity = "INFORMATION" @@ -1826,7 +1836,8 @@ def test_no_notification_from_broken_status( default_position = SummaryInputPosition(line=1337, column=42, offset=420) raw_status_obj: t.Dict[str, t.Any] = { "gql_status": default_status, - "status_description": default_description, + "description": default_description, + "status_description": default_status_description, "neo4j_code": default_code, "title": default_title, "diagnostic_record": { diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index d07f673c1..b0d000cf9 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -37,7 +37,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), } protocol_handlers = Bolt.protocol_handlers() @@ -65,7 +65,8 @@ def test_class_method_protocol_handlers(): ((5, 3), 1), ((5, 4), 1), ((5, 5), 1), - ((5, 6), 0), + ((5, 6), 1), + ((5, 7), 0), ((6, 0), 0), ] ) @@ -85,7 +86,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): # [bolt-version-bump] search tag when changing bolt version support def test_class_method_get_handshake(): handshake = Bolt.get_handshake() - assert (b"\x00\x05\x05\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x06\x06\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) @@ -134,6 +135,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 3), "neo4j._sync.io._bolt5.Bolt5x3"), ((5, 4), "neo4j._sync.io._bolt5.Bolt5x4"), ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), + ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), ), ) @mark_sync_test @@ -166,14 +168,14 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 6), + (5, 7), (6, 0), )) @mark_sync_test def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = ( "('3.0', '4.1', '4.2', '4.3', '4.4', " - "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5')" + "'5.0', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6')" ) address = ("localhost", 7687) diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py index 550cf8546..fe30b7e9a 100644 --- a/tests/unit/sync/io/test_class_bolt5x5.py +++ b/tests/unit/sync/io/test_class_bolt5x5.py @@ -615,7 +615,7 @@ def test_tracks_last_database(fake_socket_pair, actions): ) @pytest.mark.parametrize("method", ("pull", "discard")) @mark_sync_test -def test_enriches_diagnostic_record( +def test_enriches_statuses( sent_diag_records, method, fake_socket_pair, @@ -628,7 +628,9 @@ def test_enriches_diagnostic_record( sent_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description" } for r in sent_diag_records ] } @@ -654,7 +656,9 @@ def extend_diag_record(r): expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] expected_metadata = { "statuses": [ - {"diagnostic_record": r} if r is not ... else {} + {"status_description": "the status description", "description": "the status description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "the status description" } for r in expected_diag_records ] } diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py new file mode 100644 index 000000000..0504f0731 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x6.py @@ -0,0 +1,666 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x6 + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x6(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6( + address, socket, PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled + ) + if serv_enabled: + connection.configuration_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + + +@pytest.mark.parametrize(("method", "args", "extra_idx"), ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), +)) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2) +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2) +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, method, args, extra_idx, cls_min_sev, method_min_sev, + cls_dis_clss, method_dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6( + address, socket, PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss + ) + method = getattr(connection, method) + + method(*args, notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize("dis_clss", + (None, [], ["HINT"], ["HINT", "DEPRECATION"])) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss + ) + + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x6( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ) +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0") + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds") + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds") + ) + ) +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + limit=3, + ) +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS) + connection = Bolt5x6(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + {"status_description": "the status description", "description": "description", "diagnostic_record": r} + if r is not ... + else { "status_description": "the status description", "description": "description" } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata