diff --git a/CHANGELOG.md b/CHANGELOG.md index 3447e0a8c..29f98993f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,13 @@ does not offer the `commit`, `rollback`, `close`, and `closed` methods. Those methods would have caused a hard to interpreted error previously. Hence, they have been removed. +- Deprecated Nodes' and Relationships' `id` property (`int`) in favor of + `element_id` (`str`). + This also affects `Graph` objects as `graph.nodes[...]` and + `graph.relationships[...]` now prefer strings over integers. +- `ServerInfo.connection_id` has been deprecated and will be removed in a + future release. There is no replacement as this is considered internal + information. ## Version 4.4 diff --git a/neo4j/api.py b/neo4j/api.py index b9d4de07c..567d02a02 100644 --- a/neo4j/api.py +++ b/neo4j/api.py @@ -308,6 +308,8 @@ def agent(self): return self._metadata.get("server") @property + @deprecated("The connection id is considered internal information " + "and will no longer be exposed in future versions.") def connection_id(self): """ Unique identifier for the remote server connection. """ diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py index 62b1d5bcd..d8e8582d8 100644 --- a/neo4j/graph/__init__.py +++ b/neo4j/graph/__init__.py @@ -31,7 +31,10 @@ from collections.abc import Mapping -from ..meta import deprecated +from ..meta import ( + deprecated, + deprecation_warn, +) class Graph: @@ -261,6 +264,20 @@ def __init__(self, entity_dict): self._entity_dict = entity_dict def __getitem__(self, e_id): + # TODO: 6.0 - remove this compatibility shim + if isinstance(e_id, (int, float, complex)): + deprecation_warn( + "Accessing entities by an integer id is deprecated, " + "use the new style element_id (str) instead" + ) + if isinstance(e_id, float) and int(e_id) == e_id: + # Non-int floats would always fail for legacy IDs + e_id = int(e_id) + elif isinstance(e_id, complex) and int(e_id.real) == e_id: + # complex numbers with imaginary parts or non-integer real + # parts would always fail for legacy IDs + e_id = int(e_id.real) + e_id = str(e_id) return self._entity_dict[e_id] def __len__(self): diff --git a/neo4j/time/hydration.py b/neo4j/time/hydration.py index 072752e23..0212a1af9 100644 --- a/neo4j/time/hydration.py +++ b/neo4j/time/hydration.py @@ -136,7 +136,7 @@ def dehydrate_datetime(value): """ Dehydrator for `datetime` values. :param value: - :type value: datetime + :type value: datetime or DateTime :return: """ diff --git a/testkit/build.py b/testkit/build.py index 5de76f147..bb7f07080 100644 --- a/testkit/build.py +++ b/testkit/build.py @@ -19,18 +19,22 @@ """ -Executed in Go driver container. +Executed in driver container. Responsible for building driver and test backend.
""" import subprocess +import sys def run(args, env=None): - subprocess.run(args, universal_newlines=True, stderr=subprocess.STDOUT, - check=True, env=env) + subprocess.run(args, universal_newlines=True, stdout=sys.stdout, + stderr=sys.stderr, check=True, env=env) if __name__ == "__main__": run(["python", "setup.py", "build"]) + run(["python", "-m", "pip", "install", "-U", "pip"]) + run(["python", "-m", "pip", "install", "-Ur", + "testkitbackend/requirements.txt"]) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 9d2fdb1d9..743379956 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -18,12 +18,15 @@ import json from os import path +import re +import warnings import neo4j from neo4j._async_compat.util import AsyncUtil from .. import ( fromtestkit, + test_subtest_skips, totestkit, ) from ..exceptions import MarkdAsDriverException @@ -48,11 +51,38 @@ def load_config(): SKIPPED_TESTS, FEATURES = load_config() +def _get_skip_reason(test_name): + for skip_pattern, reason in SKIPPED_TESTS.items(): + if skip_pattern[0] == skip_pattern[-1] == "'": + match = skip_pattern[1:-1] == test_name + else: + match = re.match(skip_pattern, test_name) + if match: + return reason + + async def StartTest(backend, data): - if data["testName"] in SKIPPED_TESTS: - await backend.send_response("SkipTest", { - "reason": SKIPPED_TESTS[data["testName"]] - }) + test_name = data["testName"] + reason = _get_skip_reason(test_name) + if reason is not None: + if reason.startswith("test_subtest_skips."): + await backend.send_response("RunSubTests", {}) + else: + await backend.send_response("SkipTest", {"reason": reason}) + else: + await backend.send_response("RunTest", {}) + + +async def StartSubTest(backend, data): + test_name = data["testName"] + subtest_args = data["subtestArguments"] + subtest_args.mark_all_as_read(recursive=True) + reason = _get_skip_reason(test_name) + assert reason and reason.startswith("test_subtest_skips.") or print(reason) + func = getattr(test_subtest_skips, reason[19:]) + reason = func(**subtest_args) + if reason is not None: + await backend.send_response("SkipTest", {"reason": reason}) else: await backend.send_response("RunTest", {}) @@ -412,6 +442,17 @@ async def ResultSingle(backend, data): )) +async def ResultSingleOptional(backend, data): + result = backend.results[data["resultId"]] + with warnings.catch_warnings(record=True) as warning_list: + record = await result.single(strict=False) + if record: + record = totestkit.record(record) + await backend.send_response("RecordOptional", { + "record": record, "warnings": list(map(str, warning_list)) + }) + + async def ResultPeek(backend, data): result = backend.results[data["resultId"]] record = await result.peek() diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 3d28211af..a5ad3c752 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -18,12 +18,15 @@ import json from os import path +import re +import warnings import neo4j from neo4j._async_compat.util import Util from .. 
import ( fromtestkit, + test_subtest_skips, totestkit, ) from ..exceptions import MarkdAsDriverException @@ -48,11 +51,38 @@ def load_config(): SKIPPED_TESTS, FEATURES = load_config() +def _get_skip_reason(test_name): + for skip_pattern, reason in SKIPPED_TESTS.items(): + if skip_pattern[0] == skip_pattern[-1] == "'": + match = skip_pattern[1:-1] == test_name + else: + match = re.match(skip_pattern, test_name) + if match: + return reason + + def StartTest(backend, data): - if data["testName"] in SKIPPED_TESTS: - backend.send_response("SkipTest", { - "reason": SKIPPED_TESTS[data["testName"]] - }) + test_name = data["testName"] + reason = _get_skip_reason(test_name) + if reason is not None: + if reason.startswith("test_subtest_skips."): + backend.send_response("RunSubTests", {}) + else: + backend.send_response("SkipTest", {"reason": reason}) + else: + backend.send_response("RunTest", {}) + + +def StartSubTest(backend, data): + test_name = data["testName"] + subtest_args = data["subtestArguments"] + subtest_args.mark_all_as_read(recursive=True) + reason = _get_skip_reason(test_name) + assert reason and reason.startswith("test_subtest_skips.") or print(reason) + func = getattr(test_subtest_skips, reason[19:]) + reason = func(**subtest_args) + if reason is not None: + backend.send_response("SkipTest", {"reason": reason}) else: backend.send_response("RunTest", {}) @@ -412,6 +442,17 @@ def ResultSingle(backend, data): )) +def ResultSingleOptional(backend, data): + result = backend.results[data["resultId"]] + with warnings.catch_warnings(record=True) as warning_list: + record = result.single(strict=False) + if record: + record = totestkit.record(record) + backend.send_response("RecordOptional", { + "record": record, "warnings": list(map(str, warning_list)) + }) + + def ResultPeek(backend, data): result = backend.results[data["resultId"]] record = result.peek() diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index 39c74aed6..6fe6472c4 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -16,7 +16,21 @@ # limitations under the License. 
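
The skip table consulted above now mixes two kinds of keys: a test name wrapped in single quotes is compared verbatim, while every other key is treated as a regular expression matched against the start of the test name via re.match. A standalone sketch of that lookup, with shortened, illustrative SKIPPED_TESTS entries rather than the real test_config.json contents:

    import re

    SKIPPED_TESTS = {
        # exact-match entry: note the surrounding single quotes
        "'stub.session_run_parameters.test_empty_query'":
            "Driver rejects empty queries before sending them to the server",
        # regex entry: one pattern covers every TestAuthorizationV* class
        r"stub\.authorization\..*\.test_should_retry_on_auth_expired.*":
            "Flaky: test requires the driver to contact servers in a "
            "specific order",
    }

    def get_skip_reason(test_name):
        for pattern, reason in SKIPPED_TESTS.items():
            if pattern[0] == pattern[-1] == "'":
                if pattern[1:-1] == test_name:      # literal comparison
                    return reason
            elif re.match(pattern, test_name):      # regex comparison
                return reason
        return None                                 # not skipped: run the test
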
+from datetime import timedelta + +import pytz + from neo4j import Query +from neo4j.spatial import ( + CartesianPoint, + WGS84Point, +) +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) def to_cypher_and_params(data): @@ -54,24 +68,81 @@ def to_query_and_params(data): def to_param(m): """ Converts testkit parameter format to driver (python) parameter """ - value = m["data"]["value"] + data = m["data"] name = m["name"] if name == "CypherNull": + if data["value"] is not None: + raise ValueError("CypherNull should be None") return None if name == "CypherString": - return str(value) + return str(data["value"]) if name == "CypherBool": - return bool(value) + return bool(data["value"]) if name == "CypherInt": - return int(value) + return int(data["value"]) if name == "CypherFloat": - return float(value) + return float(data["value"]) if name == "CypherString": - return str(value) + return str(data["value"]) if name == "CypherBytes": - return bytearray([int(byte, 16) for byte in value.split()]) + return bytearray([int(byte, 16) for byte in data["value"].split()]) if name == "CypherList": - return [to_param(v) for v in value] + return [to_param(v) for v in data["value"]] if name == "CypherMap": - return {k: to_param(value[k]) for k in value} - raise Exception("Unknown param type " + name) + return {k: to_param(data["value"][k]) for k in data["value"]} + if name == "CypherPoint": + coords = [data["x"], data["y"]] + if data.get("z") is not None: + coords.append(data["z"]) + if data["system"] == "cartesian": + return CartesianPoint(coords) + if data["system"] == "wgs84": + return WGS84Point(coords) + raise ValueError("Unknown point system: {}".format(data["system"])) + if name == "CypherDate": + return Date(data["year"], data["month"], data["day"]) + if name == "CypherTime": + tz = None + utc_offset_s = data.get("utc_offset_s") + if utc_offset_s is not None: + utc_offset_m = utc_offset_s // 60 + if utc_offset_m * 60 != utc_offset_s: + raise ValueError("the used timezone library only supports " + "UTC offsets by minutes") + tz = pytz.FixedOffset(utc_offset_m) + return Time(data["hour"], data["minute"], data["second"], + data["nanosecond"], tzinfo=tz) + if name == "CypherDateTime": + datetime = DateTime( + data["year"], data["month"], data["day"], + data["hour"], data["minute"], data["second"], data["nanosecond"] + ) + utc_offset_s = data["utc_offset_s"] + timezone_id = data["timezone_id"] + if timezone_id is not None: + utc_offset = timedelta(seconds=utc_offset_s) + tz = pytz.timezone(timezone_id) + localized_datetime = tz.localize(datetime, is_dst=False) + if localized_datetime.utcoffset() == utc_offset: + return localized_datetime + localized_datetime = tz.localize(datetime, is_dst=True) + if localized_datetime.utcoffset() == utc_offset: + return localized_datetime + raise ValueError( + "cannot localize datetime %s to timezone %s with UTC " + "offset %s" % (datetime, timezone_id, utc_offset) + ) + elif utc_offset_s is not None: + utc_offset_m = utc_offset_s // 60 + if utc_offset_m * 60 != utc_offset_s: + raise ValueError("the used timezone library only supports " + "UTC offsets by minutes") + tz = pytz.FixedOffset(utc_offset_m) + return tz.localize(datetime) + return datetime + if name == "CypherDuration": + return Duration( + months=data["months"], days=data["days"], + seconds=data["seconds"], nanoseconds=data["nanoseconds"] + ) + raise ValueError("Unknown param type " + name) diff --git a/testkitbackend/requirements.txt b/testkitbackend/requirements.txt new file mode 100644 
index 000000000..3c8d7e782 --- /dev/null +++ b/testkitbackend/requirements.txt @@ -0,0 +1 @@ +-r ../requirements.txt diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 6bdd34756..9b38714ff 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -1,27 +1,19 @@ { "skips": { - "stub.retry.test_retry_clustering.TestRetryClustering.test_retry_ForbiddenOnReadOnlyDatabase_ChangingWriter": + "stub\\.retry\\.test_retry_clustering\\.TestRetryClustering\\.test_retry_ForbiddenOnReadOnlyDatabase_ChangingWriter": "Test makes assumptions about how verify_connectivity is implemented", - "stub.authorization.test_authorization.TestAuthorizationV5x0.test_should_retry_on_auth_expired_on_begin_using_tx_function": + "stub\\.authorization\\.test_authorization\\.TestAuthorizationV[0-9x]+\\.test_should_retry_on_auth_expired_on_begin_using_tx_function": "Flaky: test requires the driver to contact servers in a specific order", - "stub.authorization.test_authorization.TestAuthorizationV4x3.test_should_retry_on_auth_expired_on_begin_using_tx_function": + "stub\\.authorization\\.test_authorization\\.TestAuthorizationV[0-9x]+\\.test_should_fail_on_token_expired_on_begin_using_tx_function": "Flaky: test requires the driver to contact servers in a specific order", - "stub.authorization.test_authorization.TestAuthorizationV3.test_should_retry_on_auth_expired_on_begin_using_tx_function": - "Flaky: test requires the driver to contact servers in a specific order", - "stub.authorization.test_authorization.TestAuthorizationV4x1.test_should_retry_on_auth_expired_on_begin_using_tx_function": - "Flaky: test requires the driver to contact servers in a specific order", - "stub.authorization.test_authorization.TestAuthorizationV5x0.test_should_fail_on_token_expired_on_begin_using_tx_function": - "Flaky: test requires the driver to contact servers in a specific order", - "stub.authorization.test_authorization.TestAuthorizationV4x3.test_should_fail_on_token_expired_on_begin_using_tx_function": - "Flaky: test requires the driver to contact servers in a specific order", - "stub.authorization.test_authorization.TestAuthorizationV3.test_should_fail_on_token_expired_on_begin_using_tx_function": - "Flaky: test requires the driver to contact servers in a specific order", - "stub.authorization.test_authorization.TestAuthorizationV4x1.test_should_fail_on_token_expired_on_begin_using_tx_function": - "Flaky: test requires the driver to contact servers in a specific order", - "stub.session_run_parameters.test_session_run_parameters.TestSessionRunParameters.test_empty_query": + "'stub.session_run_parameters.test_session_run_parameters.TestSessionRunParameters.test_empty_query'": "Driver rejects empty queries before sending it to the server", - "stub.server_side_routing.test_server_side_routing.TestServerSideRouting.test_direct_connection_with_url_params": - "Driver emits deprecation warning. Behavior will be unified in 6.0." + "'stub.server_side_routing.test_server_side_routing.TestServerSideRouting.test_direct_connection_with_url_params'": + "Driver emits deprecation warning. 
Behavior will be unified in 6.0.", + "neo4j.datatypes.test_temporal_types.TestDataTypes.test_should_echo_all_timezone_ids": + "test_subtest_skips.tz_id", + "neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id": + "test_subtest_skips.tz_id" }, "features": { "Feature:API:ConnectionAcquisitionTimeout": true, @@ -32,8 +24,11 @@ "Feature:API:Result.List": true, "Feature:API:Result.Peek": true, "Feature:API:Result.Single": true, + "Feature:API:Result.SingleOptional": true, "Feature:API:SSLConfig": true, "Feature:API:SSLSchemes": true, + "Feature:API:Type.Spatial": true, + "Feature:API:Type.Temporal": true, "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, diff --git a/testkitbackend/test_subtest_skips.py b/testkitbackend/test_subtest_skips.py new file mode 100644 index 000000000..a92ef70f1 --- /dev/null +++ b/testkitbackend/test_subtest_skips.py @@ -0,0 +1,53 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Functions to decide whether to run a subtest or not. + +They take the subtest parameters as arguments and return + - a string with describing the reason why the subtest should be skipped + - None if the subtest should be run +""" + + +def tz_id(**params): + # We could do this automatically, but with an explicit black list we + # make sure we know what we test and what we don't. 
+ # if params["tz_id"] not in pytz.common_timezones_set: + # return ( + # "timezone id %s is not supported by the system" % params["tz_id"] + # ) + + if params["tz_id"] in { + "SystemV/AST4", + "SystemV/AST4ADT", + "SystemV/CST6", + "SystemV/CST6CDT", + "SystemV/EST5", + "SystemV/EST5EDT", + "SystemV/HST10", + "SystemV/MST7", + "SystemV/MST7MDT", + "SystemV/PST8", + "SystemV/PST8PDT", + "SystemV/YST9", + "SystemV/YST9YDT", + }: + return ( + "timezone id %s is not supported by the system" % params["tz_id"] + ) diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index d4f5a084a..6d8591a62 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -23,6 +23,16 @@ Path, Relationship, ) +from neo4j.spatial import ( + CartesianPoint, + WGS84Point, +) +from neo4j.time import ( + Date, + DateTime, + Duration, + Time, +) def record(rec): @@ -88,5 +98,77 @@ def to(name, val): "relationships": field(list(v.relationships)), } return {"name": "Path", "data": path} + if isinstance(v, CartesianPoint): + return { + "name": "CypherPoint", + "data": { + "system": "cartesian", + "x": v.x, + "y": v.y, + "z": getattr(v, "z", None) + }, + } + if isinstance(v, WGS84Point): + return { + "name": "CypherPoint", + "data": { + "system": "wgs84", + "x": v.x, + "y": v.y, + "z": getattr(v, "z", None) + }, + } + if isinstance(v, Date): + return { + "name": "CypherDate", + "data": { + "year": v.year, + "month": v.month, + "day": v.day + } + } + if isinstance(v, Time): + data = { + "hour": v.hour, + "minute": v.minute, + "second": v.second, + "nanosecond": v.nanosecond + } + if v.tzinfo is not None: + data["utc_offset_s"] = v.tzinfo.utcoffset(v).total_seconds() + return { + "name": "CypherTime", + "data": data + } + if isinstance(v, DateTime): + data = { + "year": v.year, + "month": v.month, + "day": v.day, + "hour": v.hour, + "minute": v.minute, + "second": v.second, + "nanosecond": v.nanosecond + } + if v.tzinfo is not None: + data["utc_offset_s"] = v.tzinfo.utcoffset(v).total_seconds() + for attr in ("zone", "key"): + timezone_id = getattr(v.tzinfo, attr, None) + if isinstance(timezone_id, str): + data["timezone_id"] = timezone_id + return { + "name": "CypherDateTime", + "data": data, + } + if isinstance(v, Duration): + return { + "name": "CypherDuration", + "data": { + "months": v.months, + "days": v.days, + "seconds": v.seconds, + "nanoseconds": v.nanoseconds + }, + } - raise Exception("Unhandled type:" + str(type(v))) + raise ValueError("Unhandled type:" + str(type(v))) diff --git a/tests/integration/async_/test_custom_ssl_context.py b/tests/integration/async_/test_custom_ssl_context.py index ed8a9fad3..a13a823c3 100644 --- a/tests/integration/async_/test_custom_ssl_context.py +++ b/tests/integration/async_/test_custom_ssl_context.py @@ -25,7 +25,10 @@ @mark_async_test -async def test_custom_ssl_context_is_wraps_connection(target, auth, mocker): +async def test_custom_ssl_context_wraps_connection(target, auth, mocker): + # Test that the driver calls either `.wrap_socket` or `.wrap_bio` on the + # provided custom SSL context. 
+ class NoNeedToGoFurtherException(Exception): pass @@ -35,6 +38,7 @@ def wrap_fail(*_, **__): fake_ssl_context = mocker.create_autospec(SSLContext) fake_ssl_context.wrap_socket.side_effect = wrap_fail fake_ssl_context.wrap_bio.side_effect = wrap_fail + driver = AsyncGraphDatabase.neo4j_driver( target, auth=auth, ssl_context=fake_ssl_context ) @@ -42,5 +46,6 @@ def wrap_fail(*_, **__): async with driver.session() as session: with pytest.raises(NoNeedToGoFurtherException): await session.run("RETURN 1") + assert (fake_ssl_context.wrap_socket.call_count + fake_ssl_context.wrap_bio.call_count) == 1 diff --git a/tests/integration/sync/test_custom_ssl_context.py b/tests/integration/sync/test_custom_ssl_context.py index 0135d034a..91f491441 100644 --- a/tests/integration/sync/test_custom_ssl_context.py +++ b/tests/integration/sync/test_custom_ssl_context.py @@ -25,7 +25,10 @@ @mark_sync_test -def test_custom_ssl_context_is_wraps_connection(target, auth, mocker): +def test_custom_ssl_context_wraps_connection(target, auth, mocker): + # Test that the driver calls either `.wrap_socket` or `.wrap_bio` on the + # provided custom SSL context. + class NoNeedToGoFurtherException(Exception): pass @@ -35,6 +38,7 @@ def wrap_fail(*_, **__): fake_ssl_context = mocker.create_autospec(SSLContext) fake_ssl_context.wrap_socket.side_effect = wrap_fail fake_ssl_context.wrap_bio.side_effect = wrap_fail + driver = GraphDatabase.neo4j_driver( target, auth=auth, ssl_context=fake_ssl_context ) @@ -42,5 +46,6 @@ def wrap_fail(*_, **__): with driver.session() as session: with pytest.raises(NoNeedToGoFurtherException): session.run("RETURN 1") + assert (fake_ssl_context.wrap_socket.call_count + fake_ssl_context.wrap_bio.call_count) == 1 diff --git a/tests/integration/test_autocommit.py b/tests/integration/test_autocommit.py deleted file mode 100644 index 960cbe3f5..000000000 --- a/tests/integration/test_autocommit.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from neo4j import Query - - -# TODO: this test will stay until a uniform behavior for `.single()` across the -# drivers has been specified and tests are created in testkit -def test_result_single_record_value(session): - record = session.run(Query("RETURN $x"), x=1).single() - assert record.value() == 1 diff --git a/tests/integration/test_bolt_driver.py b/tests/integration/test_bolt_driver.py index 346b82cd9..6697dfa2f 100644 --- a/tests/integration/test_bolt_driver.py +++ b/tests/integration/test_bolt_driver.py @@ -14,27 +14,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
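
Both the async and sync variants of test_custom_ssl_context_wraps_connection rely on the same trick: autospec the SSLContext so it keeps its API surface, make both wrap entry points raise a sentinel exception so no real TLS handshake is attempted, and finally assert that exactly one of them was called. A minimal sketch of that pattern using unittest.mock directly (the tests themselves go through pytest-mock's mocker fixture):

    from ssl import SSLContext
    from unittest import mock

    class StopHandshake(Exception):
        """Sentinel raised instead of performing a real TLS handshake."""

    def make_fake_ssl_context():
        fake_ctx = mock.create_autospec(SSLContext)
        # Whichever wrapping entry point the driver picks, abort immediately.
        fake_ctx.wrap_socket.side_effect = StopHandshake
        fake_ctx.wrap_bio.side_effect = StopHandshake
        return fake_ctx

    def assert_wrapped_exactly_once(fake_ctx):
        assert (fake_ctx.wrap_socket.call_count
                + fake_ctx.wrap_bio.call_count) == 1
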
- - +import pytest from pytest import fixture -# TODO: this test will stay until a uniform behavior for `.single()` across the -# drivers has been specified and tests are created in testkit -def test_normal_use_case(bolt_driver): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_normal_use_case - session = bolt_driver.session() - value = session.run("RETURN 1").single().value() - assert value == 1 - - -# TODO: this test will stay until a uniform behavior for `.encrypted` across the -# drivers has been specified and tests are created in testkit -def test_encrypted_set_to_false_by_default(bolt_driver): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_encrypted_set_to_false_by_default - assert bolt_driver.encrypted is False - - @fixture def server_info(driver): """ Simple fixture to provide quick and easy access to a @@ -45,8 +28,10 @@ def server_info(driver): yield summary.server -# TODO: this test will stay asy python is currently the only driver exposing the -# connection id. So this might change in the future. +# TODO: 6.0 - +# This test will stay as python is currently the only driver exposing +# the connection id. This will be removed in 6.0 def test_server_connection_id(server_info): - cid = server_info.connection_id + with pytest.warns(DeprecationWarning): + cid = server_info.connection_id assert cid.startswith("bolt-") and cid[5:].isdigit() diff --git a/tests/integration/test_readme.py b/tests/integration/test_readme.py index fde466390..1a13a4cb4 100644 --- a/tests/integration/test_readme.py +++ b/tests/integration/test_readme.py @@ -31,23 +31,28 @@ def test_should_run_readme(uri, auth): from neo4j import GraphDatabase - try: - driver = GraphDatabase.driver(uri, auth=auth) - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) + driver = GraphDatabase.driver(uri, auth=auth) + + def add_friend(tx, name, friend_name): + tx.run("MERGE (a:Person {name: $name}) " + "MERGE (a)-[:KNOWS]->(friend:Person {name: $friend_name})", + name=name, friend_name=friend_name) def print_friends(tx, name): - for record in tx.run("MATCH (a:Person)-[:KNOWS]->(friend) " - "WHERE a.name = $name " - "RETURN friend.name", name=name): + for record in tx.run( + "MATCH (a:Person)-[:KNOWS]->(friend) WHERE a.name = $name " + "RETURN friend.name ORDER BY friend.name", name=name): print(record["friend.name"]) with driver.session() as session: session.run("MATCH (a) DETACH DELETE a") - session.run("CREATE (a:Person {name:'Alice'})-[:KNOWS]->({name:'Bob'})") - session.read_transaction(print_friends, "Alice") + + session.write_transaction(add_friend, "Arthur", "Guinevere") + session.write_transaction(add_friend, "Arthur", "Lancelot") + session.write_transaction(add_friend, "Arthur", "Merlin") + session.read_transaction(print_friends, "Arthur") + + session.run("MATCH (a) DETACH DELETE a") driver.close() - assert len(names) == 1 - assert "Bob" in names + assert names == {"Guinevere", "Lancelot", "Merlin"} diff --git a/tests/integration/test_result_graph.py b/tests/integration/test_result_graph.py deleted file mode 100644 index 15ea7a37d..000000000 --- a/tests/integration/test_result_graph.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest - -from neo4j.graph import Graph - - -def test_result_graph_instance(session): - # python -m pytest tests/integration/test_result_graph.py -s -v -k test_result_graph_instance - result = session.run("RETURN 1") - graph = result.graph() - - assert isinstance(graph, Graph) - - -def test_result_graph_case_1(session): - # python -m pytest tests/integration/test_result_graph.py -s -v -k test_result_graph_case_1 - result = session.run("CREATE (n1:Person:LabelTest1 {name:'Alice'})-[r1:KNOWS {since:1999}]->(n2:Person:LabelTest2 {name:'Bob'}) RETURN n1, r1, n2") - graph = result.graph() - assert isinstance(graph, Graph) - - node_view = graph.nodes - relationships_view = graph.relationships - - for node in node_view: - name = node["name"] - if name == "Alice": - assert node.labels == frozenset(["Person", "LabelTest1"]) - elif name == "Bob": - assert node.labels == frozenset(["Person", "LabelTest2"]) - else: - pytest.fail("should only contain 2 nodes, Alice and Bob. {}".format(name)) - - for relationship in relationships_view: - since = relationship["since"] - assert since == 1999 - assert relationship.type == "KNOWS" diff --git a/tests/integration/test_spatial_types.py b/tests/integration/test_spatial_types.py deleted file mode 100644 index 71ed4cd5a..000000000 --- a/tests/integration/test_spatial_types.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
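
The integration-level graph test deleted here is superseded by the unit test test_result_graph added to tests/unit/{async_,sync}/work/test_result.py further down, which also pins down the new lookup rules: graph.nodes and graph.relationships are keyed by element_id strings, and numeric keys only keep working through the deprecation shim in EntitySetView.__getitem__. A small illustrative helper showing the expected behaviour (the element_id "0" is hypothetical; real values are assigned by the server):

    import warnings

    from neo4j.graph import Graph

    def check_legacy_node_access(graph: Graph):
        node = graph.nodes["0"]        # new style: element_id is a str
        # Legacy style: ints (and int-valued floats) are coerced to str,
        # but a DeprecationWarning is emitted first.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            assert graph.nodes[0] is node
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)
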
- - -import pytest - -from neo4j.spatial import ( - CartesianPoint, - WGS84Point, -) - - -def test_cartesian_point_input(cypher_eval): - x, y = cypher_eval("CYPHER runtime=interpreted " - "WITH $point AS point " - "RETURN [point.x, point.y]", - point=CartesianPoint((1.23, 4.56))) - assert x == 1.23 - assert y == 4.56 - - -def test_cartesian_3d_point_input(cypher_eval): - x, y, z = cypher_eval("CYPHER runtime=interpreted " - "WITH $point AS point " - "RETURN [point.x, point.y, point.z]", - point=CartesianPoint((1.23, 4.56, 7.89))) - assert x == 1.23 - assert y == 4.56 - assert z == 7.89 - - -def test_wgs84_point_input(cypher_eval): - lat, long = cypher_eval("CYPHER runtime=interpreted " - "WITH $point AS point " - "RETURN [point.latitude, point.longitude]", - point=WGS84Point((1.23, 4.56))) - assert long == 1.23 - assert lat == 4.56 - - -def test_wgs84_3d_point_input(cypher_eval): - lat, long, height = cypher_eval("CYPHER runtime=interpreted " - "WITH $point AS point " - "RETURN [point.latitude, point.longitude, " - "point.height]", - point=WGS84Point((1.23, 4.56, 7.89))) - assert long == 1.23 - assert lat == 4.56 - assert height == 7.89 - - -def test_point_array_input(cypher_eval): - data = [WGS84Point((1.23, 4.56)), WGS84Point((9.87, 6.54))] - value = cypher_eval("CREATE (a {x:$x}) RETURN a.x", x=data) - assert value == data - - -def test_cartesian_point_output(cypher_eval): - value = cypher_eval("RETURN point({x:3, y:4})") - assert isinstance(value, CartesianPoint) - assert value.x == 3.0 - assert value.y == 4.0 - with pytest.raises(AttributeError): - _ = value.z - - -def test_cartesian_3d_point_output(cypher_eval): - value = cypher_eval("RETURN point({x:3, y:4, z:5})") - assert isinstance(value, CartesianPoint) - assert value.x == 3.0 - assert value.y == 4.0 - assert value.z == 5.0 - - -def test_wgs84_point_output(cypher_eval): - value = cypher_eval("RETURN point({latitude:3, longitude:4})") - assert isinstance(value, WGS84Point) - assert value.latitude == 3.0 - assert value.y == 3.0 - assert value.longitude == 4.0 - assert value.x == 4.0 - with pytest.raises(AttributeError): - _ = value.height - with pytest.raises(AttributeError): - _ = value.z - - -def test_wgs84_3d_point_output(cypher_eval): - value = cypher_eval("RETURN point({latitude:3, longitude:4, height:5})") - assert isinstance(value, WGS84Point) - assert value.latitude == 3.0 - assert value.y == 3.0 - assert value.longitude == 4.0 - assert value.x == 4.0 - assert value.height == 5.0 - assert value.z == 5.0 diff --git a/tests/integration/test_temporal_types.py b/tests/integration/test_temporal_types.py deleted file mode 100644 index b3f3be995..000000000 --- a/tests/integration/test_temporal_types.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [http://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
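
These spatial and temporal integration tests are removed because testkit's datatypes suite now covers the same round-trips; parameters enter the backend through fromtestkit.to_param, whose trickiest case is re-attaching a named timezone to a wall-clock DateTime when testkit also supplies the expected UTC offset. A condensed sketch of the strategy used there (assuming pytz, as the backend does): localize with is_dst=False, then is_dst=True, and keep whichever result reproduces the supplied offset.

    from datetime import timedelta

    import pytz

    from neo4j.time import DateTime

    def localize_named(dt: DateTime, timezone_id: str, utc_offset_s: int):
        """Resolve an ambiguous wall-clock DateTime against a known UTC offset."""
        expected_offset = timedelta(seconds=utc_offset_s)
        tz = pytz.timezone(timezone_id)
        for is_dst in (False, True):
            localized = tz.localize(dt, is_dst=is_dst)
            if localized.utcoffset() == expected_offset:
                return localized
        raise ValueError("%s cannot have UTC offset %s in %s"
                         % (dt, expected_offset, timezone_id))
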
- - -import datetime - -import pytest -from pytz import ( - FixedOffset, - timezone, - utc, -) - -from neo4j.exceptions import CypherTypeError -from neo4j.time import ( - Date, - DateTime, - Duration, - Time, -) - - -def test_native_date_input(cypher_eval): - from datetime import date - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.year, x.month, x.day]", - x=date(1976, 6, 13)) - year, month, day = result - assert year == 1976 - assert month == 6 - assert day == 13 - - -def test_date_input(cypher_eval): - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.year, x.month, x.day]", - x=Date(1976, 6, 13)) - year, month, day = result - assert year == 1976 - assert month == 6 - assert day == 13 - - -def test_date_array_input(cypher_eval): - data = [DateTime.now().date(), Date(1976, 6, 13)] - value = cypher_eval("CREATE (a {x:$x}) RETURN a.x", x=data) - assert value == data - - -def test_native_time_input(cypher_eval): - from datetime import time - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.hour, x.minute, x.second, x.nanosecond]", - x=time(12, 34, 56, 789012)) - hour, minute, second, nanosecond = result - assert hour == 12 - assert minute == 34 - assert second == 56 - assert nanosecond == 789012000 - - -def test_whole_second_time_input(cypher_eval): - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.hour, x.minute, x.second]", - x=Time(12, 34, 56)) - hour, minute, second = result - assert hour == 12 - assert minute == 34 - assert second == 56 - - -def test_nanosecond_resolution_time_input(cypher_eval): - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.hour, x.minute, x.second, x.nanosecond]", - x=Time(12, 34, 56, 789012345)) - hour, minute, second, nanosecond = result - assert hour == 12 - assert minute == 34 - assert second == 56 - assert nanosecond == 789012345 - - -def test_time_with_numeric_time_offset_input(cypher_eval): - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.hour, x.minute, x.second, " - " x.nanosecond, x.offset]", - x=Time(12, 34, 56, 789012345, tzinfo=FixedOffset(90))) - hour, minute, second, nanosecond, offset = result - assert hour == 12 - assert minute == 34 - assert second == 56 - assert nanosecond == 789012345 - assert offset == "+01:30" - - -def test_time_array_input(cypher_eval): - data = [Time(12, 34, 56), Time(10, 0, 0)] - value = cypher_eval("CREATE (a {x:$x}) RETURN a.x", x=data) - assert value == data - - -def test_native_datetime_input(cypher_eval): - from datetime import datetime - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.year, x.month, x.day, " - " x.hour, x.minute, x.second, x.nanosecond]", - x=datetime(1976, 6, 13, 12, 34, 56, 789012)) - year, month, day, hour, minute, second, nanosecond = result - assert year == 1976 - assert month == 6 - assert day == 13 - assert hour == 12 - assert minute == 34 - assert second == 56 - assert nanosecond == 789012000 - - -def test_whole_second_datetime_input(cypher_eval): - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.year, x.month, x.day, " - " x.hour, x.minute, x.second]", - x=DateTime(1976, 6, 13, 12, 34, 56)) - year, month, day, hour, minute, second = result - assert year == 1976 - assert month == 6 - assert day == 13 - assert hour == 12 - assert minute == 34 - assert second == 56 - - -def test_nanosecond_resolution_datetime_input(cypher_eval): - result = cypher_eval("CYPHER 
runtime=interpreted WITH $x AS x " - "RETURN [x.year, x.month, x.day, " - " x.hour, x.minute, x.second, x.nanosecond]", - x=DateTime(1976, 6, 13, 12, 34, 56, 789012345)) - year, month, day, hour, minute, second, nanosecond = result - assert year == 1976 - assert month == 6 - assert day == 13 - assert hour == 12 - assert minute == 34 - assert second == 56 - assert nanosecond == 789012345 - - -def test_datetime_with_numeric_time_offset_input(cypher_eval): - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.year, x.month, x.day, " - " x.hour, x.minute, x.second, " - " x.nanosecond, x.offset]", - x=DateTime(1976, 6, 13, 12, 34, 56, 789012345, - tzinfo=FixedOffset(90))) - year, month, day, hour, minute, second, nanosecond, offset = result - assert year == 1976 - assert month == 6 - assert day == 13 - assert hour == 12 - assert minute == 34 - assert second == 56 - assert nanosecond == 789012345 - assert offset == "+01:30" - - -def test_datetime_with_named_time_zone_input(cypher_eval): - dt = DateTime(1976, 6, 13, 12, 34, 56.789012345) - input_value = timezone("US/Pacific").localize(dt) - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.year, x.month, x.day, " - " x.hour, x.minute, x.second, " - " x.nanosecond, x.timezone]", - x=input_value) - year, month, day, hour, minute, second, nanosecond, tz = result - assert year == input_value.year - assert month == input_value.month - assert day == input_value.day - assert hour == input_value.hour - assert minute == input_value.minute - assert second == int(input_value.second) - assert nanosecond == int(1000000000 * input_value.second % 1000000000) - assert tz == input_value.tzinfo.zone - - -def test_datetime_array_input(cypher_eval): - data = [DateTime(2018, 4, 6, 13, 4, 42, 516120), DateTime(1976, 6, 13)] - value = cypher_eval("CREATE (a {x:$x}) RETURN a.x", x=data) - assert value == data - - -def test_duration_input(cypher_eval): - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.months, x.days, x.seconds, " - " x.microsecondsOfSecond]", - x=Duration(years=1, months=2, days=3, hours=4, - minutes=5, seconds=6.789012)) - months, days, seconds, microseconds = result - assert months == 14 - assert days == 3 - assert seconds == 14706 - assert microseconds == 789012 - - -def test_duration_array_input(cypher_eval): - data = [Duration(1, 2, 3, 4, 5, 6), Duration(9, 8, 7, 6, 5, 4)] - value = cypher_eval("CREATE (a {x:$x}) RETURN a.x", x=data) - assert value == data - - -def test_timedelta_input(cypher_eval): - from datetime import timedelta - result = cypher_eval("CYPHER runtime=interpreted WITH $x AS x " - "RETURN [x.months, x.days, x.seconds, " - " x.microsecondsOfSecond]", - x=timedelta(days=3, hours=4, minutes=5, - seconds=6.789012)) - months, days, seconds, microseconds = result - assert months == 0 - assert days == 3 - assert seconds == 14706 - assert microseconds == 789012 - - -def test_mixed_array_input(cypher_eval): - data = [Date(1976, 6, 13), Duration(9, 8, 7, 6, 5, 4)] - with pytest.raises(CypherTypeError): - _ = cypher_eval("CREATE (a {x:$x}) RETURN a.x", x=data) - - -def test_date_output(cypher_eval): - value = cypher_eval("RETURN date('1976-06-13')") - assert isinstance(value, Date) - assert value == Date(1976, 6, 13) - - -def test_whole_second_time_output(cypher_eval): - value = cypher_eval("RETURN time('12:34:56')") - assert isinstance(value, Time) - assert value == Time(12, 34, 56, tzinfo=FixedOffset(0)) - - -def 
test_nanosecond_resolution_time_output(cypher_eval): - value = cypher_eval("RETURN time('12:34:56.789012345')") - assert isinstance(value, Time) - assert value == Time(12, 34, 56, 789012345, tzinfo=FixedOffset(0)) - - -def test_time_with_numeric_time_offset_output(cypher_eval): - value = cypher_eval("RETURN time('12:34:56.789012345+0130')") - assert isinstance(value, Time) - assert value == Time(12, 34, 56, 789012345, tzinfo=FixedOffset(90)) - - -def test_whole_second_localtime_output(cypher_eval): - value = cypher_eval("RETURN localtime('12:34:56')") - assert isinstance(value, Time) - assert value == Time(12, 34, 56) - - -def test_nanosecond_resolution_localtime_output(cypher_eval): - value = cypher_eval("RETURN localtime('12:34:56.789012345')") - assert isinstance(value, Time) - assert value == Time(12, 34, 56, 789012345) - - -def test_whole_second_datetime_output(cypher_eval): - value = cypher_eval("RETURN datetime('1976-06-13T12:34:56')") - assert isinstance(value, DateTime) - assert value == DateTime(1976, 6, 13, 12, 34, 56, tzinfo=utc) - - -def test_nanosecond_resolution_datetime_output(cypher_eval): - value = cypher_eval("RETURN datetime('1976-06-13T12:34:56.789012345')") - assert isinstance(value, DateTime) - assert value == DateTime(1976, 6, 13, 12, 34, 56, 789012345, tzinfo=utc) - - -def test_datetime_with_numeric_time_offset_output(cypher_eval): - value = cypher_eval("RETURN " - "datetime('1976-06-13T12:34:56.789012345+01:30')") - assert isinstance(value, DateTime) - assert value == DateTime(1976, 6, 13, 12, 34, 56, 789012345, - tzinfo=FixedOffset(90)) - - -def test_datetime_with_named_time_zone_output(cypher_eval): - value = cypher_eval("RETURN datetime('1976-06-13T12:34:56.789012345" - "[Europe/London]')") - assert isinstance(value, DateTime) - dt = DateTime(1976, 6, 13, 12, 34, 56, 789012345) - assert value == timezone("Europe/London").localize(dt) - - -def test_whole_second_localdatetime_output(cypher_eval): - value = cypher_eval("RETURN localdatetime('1976-06-13T12:34:56')") - assert isinstance(value, DateTime) - assert value == DateTime(1976, 6, 13, 12, 34, 56) - - -def test_nanosecond_resolution_localdatetime_output(cypher_eval): - value = cypher_eval("RETURN " - "localdatetime('1976-06-13T12:34:56.789012345')") - assert isinstance(value, DateTime) - assert value == DateTime(1976, 6, 13, 12, 34, 56, 789012345) - - -def test_duration_output(cypher_eval): - value = cypher_eval("RETURN duration('P1Y2M3DT4H5M6.789S')") - assert isinstance(value, Duration) - assert value == Duration(years=1, months=2, days=3, hours=4, - minutes=5, seconds=6.789) - - -def test_nanosecond_resolution_duration_output(cypher_eval): - value = cypher_eval("RETURN duration('P1Y2M3DT4H5M6.789123456S')") - assert isinstance(value, Duration) - assert value == Duration(years=1, months=2, days=3, hours=4, - minutes=5, seconds=6, nanoseconds=789123456) - - -def test_datetime_parameter_case1(session): - # python -m pytest tests/integration/test_temporal_types.py -s -v -k test_datetime_parameter_case1 - dt1 = session.run("RETURN datetime('2019-10-30T07:54:02.129790001+00:00')").single().value() - assert isinstance(dt1, DateTime) - - dt2 = session.run("RETURN $date_time", date_time=dt1).single().value() - assert isinstance(dt2, DateTime) - - assert dt1 == dt2 - - -def test_datetime_parameter_case2(session): - # python -m pytest tests/integration/test_temporal_types.py -s -v -k test_datetime_parameter_case2 - dt1 = session.run("RETURN datetime('2019-10-30T07:54:02.129790999[UTC]')").single().value() - assert 
isinstance(dt1, DateTime) - assert dt1.iso_format() == "2019-10-30T07:54:02.129790999+00:00" - - dt2 = session.run("RETURN $date_time", date_time=dt1).single().value() - assert isinstance(dt2, DateTime) - - assert dt1 == dt2 - - -def test_datetime_parameter_case3(session): - # python -m pytest tests/integration/test_temporal_types.py -s -v -k test_datetime_parameter_case1 - dt1 = session.run("RETURN datetime('2019-10-30T07:54:02.129790+00:00')").single().value() - assert isinstance(dt1, DateTime) - - dt2 = session.run("RETURN $date_time", date_time=dt1).single().value() - assert isinstance(dt2, DateTime) - - assert dt1 == dt2 - - -def test_time_parameter_case1(session): - # python -m pytest tests/integration/test_temporal_types.py -s -v -k test_time_parameter_case1 - t1 = session.run("RETURN time('07:54:02.129790001+00:00')").single().value() - assert isinstance(t1, Time) - - t2 = session.run("RETURN $time", time=t1).single().value() - assert isinstance(t2, Time) - - assert t1 == t2 - - -def test_time_parameter_case2(session): - # python -m pytest tests/integration/test_temporal_types.py -s -v -k test_time_parameter_case2 - t1 = session.run("RETURN time('07:54:02.129790999+00:00')").single().value() - assert isinstance(t1, Time) - assert t1.iso_format() == "07:54:02.129790999+00:00" - time_zone_delta = t1.utc_offset() - assert isinstance(time_zone_delta, datetime.timedelta) - assert time_zone_delta == datetime.timedelta(0) - - t2 = session.run("RETURN $time", time=t1).single().value() - assert isinstance(t2, Time) - - assert t1 == t2 - - -def test_time_parameter_case3(session): - # python -m pytest tests/integration/test_temporal_types.py -s -v -k test_time_parameter_case3 - t1 = session.run("RETURN time('07:54:02.129790+00:00')").single().value() - assert isinstance(t1, Time) - - t2 = session.run("RETURN $time", time=t1).single().value() - assert isinstance(t2, Time) - - assert t1 == t2 diff --git a/tests/unit/async_/work/_fake_connection.py b/tests/unit/async_/work/_fake_connection.py index 2ba962ad3..72419b542 100644 --- a/tests/unit/async_/work/_fake_connection.py +++ b/tests/unit/async_/work/_fake_connection.py @@ -110,3 +110,81 @@ async def callback(): @pytest.fixture def async_fake_connection(async_fake_connection_generator): return async_fake_connection_generator() + + +@pytest.fixture +def async_scripted_connection_generator(async_fake_connection_generator): + class AsyncScriptedConnection(async_fake_connection_generator): + _script = [] + _script_pos = 0 + + def set_script(self, callbacks): + """Set a scripted sequence of callbacks. + + :param callbacks: The callbacks. They should be a list of 2-tuples. + `("name_of_message", {"callback_name": arguments})`. E.g., + ``` + [ + ("run", {"on_success": ({},), "on_summary": None}), + ("pull", { + "on_success": None, + "on_summary": None, + "on_records": + }) + ] + ``` + Note that arguments can be `None`. In this case, ScriptedConnection + will make a guess on best-suited default arguments. 
+ """ + self._script = callbacks + self._script_pos = 0 + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + try: + expected_message, scripted_callbacks = \ + self._script[self._script_pos] + except IndexError: + pytest.fail("End of scripted connection reached.") + assert name == expected_message + self._script_pos += 1 + + def func(*args, **kwargs): + async def callback(): + for cb_name, default_cb_args in ( + ("on_ignored", ({},)), + ("on_failure", ({},)), + ("on_records", ([],)), + ("on_success", ({},)), + ("on_summary", ()), + ): + cb = kwargs.get(cb_name, None) + if (not callable(cb) + or cb_name not in scripted_callbacks): + continue + cb_args = scripted_callbacks[cb_name] + if cb_args is None: + cb_args = default_cb_args + res = cb(*cb_args) + try: + await res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return AsyncScriptedConnection + + +@pytest.fixture +def async_scripted_connection(async_scripted_connection_generator): + return async_scripted_connection_generator() diff --git a/tests/unit/async_/work/conftest.py b/tests/unit/async_/work/conftest.py index 6224f9c67..3b60f3efd 100644 --- a/tests/unit/async_/work/conftest.py +++ b/tests/unit/async_/work/conftest.py @@ -1,4 +1,6 @@ from ._fake_connection import ( async_fake_connection, async_fake_connection_generator, + async_scripted_connection, + async_scripted_connection_generator, ) diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index f4043b5e6..efdc2395a 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -37,9 +37,10 @@ Node, Relationship, ) -from neo4j.exceptions import ( - ResultConsumedError, - ResultNotSingleError, +from neo4j.exceptions import ResultNotSingleError +from neo4j.graph import ( + EntitySetView, + Graph, ) from neo4j.packstream import Structure @@ -575,6 +576,94 @@ async def test_data(num_records): assert record.data.called_once_with("hello", "world") +@pytest.mark.parametrize("records", ( + Records(["n"], []), + Records(["n"], [[42], [69], [420], [1337]]), + Records(["n1", "r", "n2"], [ + [ + # Node + Structure(b"N", 0, ["Person", "LabelTest1"], {"name": "Alice"}), + # Relationship + Structure(b"R", 0, 0, 1, "KNOWS", {"since": 1999}), + # Node + Structure(b"N", 1, ["Person", "LabelTest2"], {"name": "Bob"}), + ] + ]), +)) +@mark_async_test +async def test_result_graph(records, async_scripted_connection): + async_scripted_connection.set_script(( + ("run", {"on_success": ({"fields": records.fields},), + "on_summary": None}), + ("pull", { + "on_records": (records.records,), + "on_success": None, + "on_summary": None + }), + )) + result = AsyncResult(async_scripted_connection, DataHydrator(), 1, noop, + noop) + await result._run("CYPHER", {}, None, None, "r", None) + graph = await result.graph() + assert isinstance(graph, Graph) + if records.fields == ("n",): + assert len(graph.relationships) == 0 + assert len(graph.nodes) == 0 + else: + # EntitySetView is a little broken. It's a weird mixture of set, dict, + # and iterable. 
Let's just test the underlying raw dict + assert isinstance(graph.nodes, EntitySetView) + nodes = graph.nodes + + assert set(nodes._entity_dict) == {"0", "1"} + for key in ( + "0", 0, 0.0, + # I pray to god that no-one actually accessed nodes with complex + # numbers, but theoretically it would have worked with the legacy + # number IDs + 0+0j, + ): + if not isinstance(key, str): + with pytest.warns(DeprecationWarning, match="element_id"): + alice = nodes[key] + else: + alice = nodes[key] + assert isinstance(alice, Node) + isinstance(alice.labels, frozenset) + assert alice.labels == {"Person", "LabelTest1"} + assert set(alice.keys()) == {"name"} + assert alice["name"] == "Alice" + + for key in ("1", 1, 1.0, 1+0j): + if not isinstance(key, str): + with pytest.warns(DeprecationWarning, match="element_id"): + bob = nodes[key] + else: + bob = nodes[key] + assert isinstance(bob, Node) + isinstance(bob.labels, frozenset) + assert bob.labels == {"Person", "LabelTest2"} + assert set(bob.keys()) == {"name"} + assert bob["name"] == "Bob" + + assert isinstance(graph.relationships, EntitySetView) + rels = graph.relationships + + assert set(rels._entity_dict) == {"0"} + + for key in ("0", 0, 0.0, 0+0j): + if not isinstance(key, str): + with pytest.warns(DeprecationWarning, match="element_id"): + rel = rels[key] + else: + rel = rels[key] + assert isinstance(rel, Relationship) + assert rel.nodes == (alice, bob) + assert rel.type == "KNOWS" + assert set(rel.keys()) == {"since"} + assert rel["since"] == 1999 + + @pytest.mark.parametrize( ("keys", "values", "types", "instances"), ( diff --git a/tests/unit/common/spatial/test_cartesian_point.py b/tests/unit/common/spatial/test_cartesian_point.py index c33a90a14..259b177b9 100644 --- a/tests/unit/common/spatial/test_cartesian_point.py +++ b/tests/unit/common/spatial/test_cartesian_point.py @@ -20,6 +20,8 @@ import struct from unittest import TestCase +import pytest + from neo4j.data import DataDehydrator from neo4j.packstream import Packer from neo4j.spatial import CartesianPoint @@ -27,16 +29,26 @@ class CartesianPointTestCase(TestCase): - def test_alias(self): + def test_alias_3d(self): x, y, z = 3.2, 4.0, -1.2 p = CartesianPoint((x, y, z)) - self.assert_(hasattr(p, "x")) + self.assertTrue(hasattr(p, "x")) self.assertEqual(p.x, x) - self.assert_(hasattr(p, "y")) + self.assertTrue(hasattr(p, "y")) self.assertEqual(p.y, y) - self.assert_(hasattr(p, "z")) + self.assertTrue(hasattr(p, "z")) self.assertEqual(p.z, z) + def test_alias_2d(self): + x, y = 3.2, 4.0 + p = CartesianPoint((x, y)) + self.assertTrue(hasattr(p, "x")) + self.assertEqual(p.x, x) + self.assertTrue(hasattr(p, "y")) + self.assertEqual(p.y, y) + with self.assertRaises(AttributeError): + p.z + def test_dehydration_3d(self): coordinates = (1, -2, 3.1) p = CartesianPoint(coordinates) diff --git a/tests/unit/common/spatial/test_wgs84_point.py b/tests/unit/common/spatial/test_wgs84_point.py index 0dee1913f..6c378d0ba 100644 --- a/tests/unit/common/spatial/test_wgs84_point.py +++ b/tests/unit/common/spatial/test_wgs84_point.py @@ -27,15 +27,43 @@ class WGS84PointTestCase(TestCase): - def test_alias(self): + def test_alias_3d(self): x, y, z = 3.2, 4.0, -1.2 p = WGS84Point((x, y, z)) - self.assert_(hasattr(p, "longitude")) + + self.assertTrue(hasattr(p, "longitude")) self.assertEqual(p.longitude, x) - self.assert_(hasattr(p, "latitude")) + self.assertTrue(hasattr(p, "x")) + self.assertEqual(p.x, x) + + self.assertTrue(hasattr(p, "latitude")) self.assertEqual(p.latitude, y) - self.assert_(hasattr(p, 
"height")) + self.assertTrue(hasattr(p, "y")) + self.assertEqual(p.y, y) + + self.assertTrue(hasattr(p, "height")) self.assertEqual(p.height, z) + self.assertTrue(hasattr(p, "z")) + self.assertEqual(p.z, z) + + def test_alias_2d(self): + x, y = 3.2, 4.0 + p = WGS84Point((x, y)) + + self.assertTrue(hasattr(p, "longitude")) + self.assertEqual(p.longitude, x) + self.assertTrue(hasattr(p, "x")) + self.assertEqual(p.x, x) + + self.assertTrue(hasattr(p, "latitude")) + self.assertEqual(p.latitude, y) + self.assertTrue(hasattr(p, "y")) + self.assertEqual(p.y, y) + + with self.assertRaises(AttributeError): + p.height + with self.assertRaises(AttributeError): + p.z def test_dehydration_3d(self): coordinates = (1, -2, 3.1) diff --git a/tests/unit/sync/work/_fake_connection.py b/tests/unit/sync/work/_fake_connection.py index 557c333b4..18049af3e 100644 --- a/tests/unit/sync/work/_fake_connection.py +++ b/tests/unit/sync/work/_fake_connection.py @@ -110,3 +110,81 @@ def callback(): @pytest.fixture def fake_connection(fake_connection_generator): return fake_connection_generator() + + +@pytest.fixture +def scripted_connection_generator(fake_connection_generator): + class ScriptedConnection(fake_connection_generator): + _script = [] + _script_pos = 0 + + def set_script(self, callbacks): + """Set a scripted sequence of callbacks. + + :param callbacks: The callbacks. They should be a list of 2-tuples. + `("name_of_message", {"callback_name": arguments})`. E.g., + ``` + [ + ("run", {"on_success": ({},), "on_summary": None}), + ("pull", { + "on_success": None, + "on_summary": None, + "on_records": + }) + ] + ``` + Note that arguments can be `None`. In this case, ScriptedConnection + will make a guess on best-suited default arguments. + """ + self._script = callbacks + self._script_pos = 0 + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + try: + expected_message, scripted_callbacks = \ + self._script[self._script_pos] + except IndexError: + pytest.fail("End of scripted connection reached.") + assert name == expected_message + self._script_pos += 1 + + def func(*args, **kwargs): + def callback(): + for cb_name, default_cb_args in ( + ("on_ignored", ({},)), + ("on_failure", ({},)), + ("on_records", ([],)), + ("on_success", ({},)), + ("on_summary", ()), + ): + cb = kwargs.get(cb_name, None) + if (not callable(cb) + or cb_name not in scripted_callbacks): + continue + cb_args = scripted_callbacks[cb_name] + if cb_args is None: + cb_args = default_cb_args + res = cb(*cb_args) + try: + res # maybe the callback is async + except TypeError: + pass # or maybe it wasn't ;) + + self.callbacks.append(callback) + + return func + + method_mock = parent.__getattr__(name) + if name in ("run", "commit", "pull", "rollback", "discard"): + method_mock.side_effect = build_message_handler(name) + return method_mock + + return ScriptedConnection + + +@pytest.fixture +def scripted_connection(scripted_connection_generator): + return scripted_connection_generator() diff --git a/tests/unit/sync/work/conftest.py b/tests/unit/sync/work/conftest.py index 6302829c2..066a23d36 100644 --- a/tests/unit/sync/work/conftest.py +++ b/tests/unit/sync/work/conftest.py @@ -1,4 +1,6 @@ from ._fake_connection import ( fake_connection, fake_connection_generator, + scripted_connection, + scripted_connection_generator, ) diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index d21d121a9..b8ccb695d 100644 --- a/tests/unit/sync/work/test_result.py +++ 
b/tests/unit/sync/work/test_result.py @@ -37,9 +37,10 @@ Node, Relationship, ) -from neo4j.exceptions import ( - ResultConsumedError, - ResultNotSingleError, +from neo4j.exceptions import ResultNotSingleError +from neo4j.graph import ( + EntitySetView, + Graph, ) from neo4j.packstream import Structure @@ -575,6 +576,94 @@ def test_data(num_records): assert record.data.called_once_with("hello", "world") +@pytest.mark.parametrize("records", ( + Records(["n"], []), + Records(["n"], [[42], [69], [420], [1337]]), + Records(["n1", "r", "n2"], [ + [ + # Node + Structure(b"N", 0, ["Person", "LabelTest1"], {"name": "Alice"}), + # Relationship + Structure(b"R", 0, 0, 1, "KNOWS", {"since": 1999}), + # Node + Structure(b"N", 1, ["Person", "LabelTest2"], {"name": "Bob"}), + ] + ]), +)) +@mark_sync_test +def test_result_graph(records, scripted_connection): + scripted_connection.set_script(( + ("run", {"on_success": ({"fields": records.fields},), + "on_summary": None}), + ("pull", { + "on_records": (records.records,), + "on_success": None, + "on_summary": None + }), + )) + result = Result(scripted_connection, DataHydrator(), 1, noop, + noop) + result._run("CYPHER", {}, None, None, "r", None) + graph = result.graph() + assert isinstance(graph, Graph) + if records.fields == ("n",): + assert len(graph.relationships) == 0 + assert len(graph.nodes) == 0 + else: + # EntitySetView is a little broken. It's a weird mixture of set, dict, + # and iterable. Let's just test the underlying raw dict + assert isinstance(graph.nodes, EntitySetView) + nodes = graph.nodes + + assert set(nodes._entity_dict) == {"0", "1"} + for key in ( + "0", 0, 0.0, + # I pray to god that no-one actually accessed nodes with complex + # numbers, but theoretically it would have worked with the legacy + # number IDs + 0+0j, + ): + if not isinstance(key, str): + with pytest.warns(DeprecationWarning, match="element_id"): + alice = nodes[key] + else: + alice = nodes[key] + assert isinstance(alice, Node) + isinstance(alice.labels, frozenset) + assert alice.labels == {"Person", "LabelTest1"} + assert set(alice.keys()) == {"name"} + assert alice["name"] == "Alice" + + for key in ("1", 1, 1.0, 1+0j): + if not isinstance(key, str): + with pytest.warns(DeprecationWarning, match="element_id"): + bob = nodes[key] + else: + bob = nodes[key] + assert isinstance(bob, Node) + isinstance(bob.labels, frozenset) + assert bob.labels == {"Person", "LabelTest2"} + assert set(bob.keys()) == {"name"} + assert bob["name"] == "Bob" + + assert isinstance(graph.relationships, EntitySetView) + rels = graph.relationships + + assert set(rels._entity_dict) == {"0"} + + for key in ("0", 0, 0.0, 0+0j): + if not isinstance(key, str): + with pytest.warns(DeprecationWarning, match="element_id"): + rel = rels[key] + else: + rel = rels[key] + assert isinstance(rel, Relationship) + assert rel.nodes == (alice, bob) + assert rel.type == "KNOWS" + assert set(rel.keys()) == {"since"} + assert rel["since"] == 1999 + + @pytest.mark.parametrize( ("keys", "values", "types", "instances"), (
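
The scripted_connection fixtures introduced above replay a fixed Bolt exchange: each script entry names the message the code under test is expected to send ("run", "pull", ...) and the callback arguments to fire in response, with None standing for a sensible default. As a usage illustration (not part of this change), another test in the same module could look like the sketch below, reusing Result, DataHydrator, Records, noop and mark_sync_test, which test_result.py already has in scope:

    @mark_sync_test
    def test_scripted_single_value(scripted_connection):
        records = Records(["n"], [[42]])
        scripted_connection.set_script((
            ("run", {"on_success": ({"fields": records.fields},),
                     "on_summary": None}),
            ("pull", {"on_records": (records.records,),
                      "on_success": None,    # None -> fixture default ({},)
                      "on_summary": None}),
        ))
        result = Result(scripted_connection, DataHydrator(), 1, noop, noop)
        result._run("RETURN 42 AS n", {}, None, None, "r", None)
        assert [record["n"] for record in result] == [42]
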