diff --git a/testkit/Dockerfile b/testkit/Dockerfile index 655953ffe..74068ce46 100644 --- a/testkit/Dockerfile +++ b/testkit/Dockerfile @@ -42,13 +42,16 @@ ENV PYENV_ROOT /.pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH # Set minimum supported Python version -RUN pyenv install 3.7.12 +RUN pyenv install 3.7:latest +RUN pyenv install 3.8:latest +RUN pyenv install 3.9:latest +RUN pyenv install 3.10:latest RUN pyenv rehash -RUN pyenv global 3.7.12 +RUN pyenv global $(pyenv versions --bare --skip-aliases) # Install Latest pip for each environment # https://pip.pypa.io/en/stable/news/ RUN python -m pip install --upgrade pip # Install Python Testing Tools -RUN python -m pip install coverage tox +RUN python -m pip install coverage tox tox-factor diff --git a/testkit/integration.py b/testkit/integration.py index 590a9f449..1c8f435d7 100644 --- a/testkit/integration.py +++ b/testkit/integration.py @@ -18,5 +18,13 @@ # limitations under the License. +import subprocess + + +def run(args): + subprocess.run( + args, universal_newlines=True, stderr=subprocess.STDOUT, check=True) + + if __name__ == "__main__": - pass + run(["python", "-m", "tox", "-f", "integration"]) diff --git a/testkit/unittests.py b/testkit/unittests.py index 22e655a3f..262f1bbc0 100644 --- a/testkit/unittests.py +++ b/testkit/unittests.py @@ -27,4 +27,4 @@ def run(args): if __name__ == "__main__": - run(["python", "-m", "tox", "-c", "tox-unit.ini"]) + run(["python", "-m", "tox", "-f", "unit"]) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..6a62e1129 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,190 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from functools import wraps +from os import environ +import warnings + +import pytest +import pytest_asyncio + +from neo4j import ( + AsyncGraphDatabase, + ExperimentalWarning, + GraphDatabase, +) +from neo4j._exceptions import BoltHandshakeError +from neo4j._sync.io import Bolt +from neo4j.exceptions import ServiceUnavailable + +from . import env + + +# from neo4j.debug import watch +# +# watch("neo4j") + + +@pytest.fixture(scope="session") +def uri(): + return env.NEO4J_SERVER_URI + + +@pytest.fixture(scope="session") +def bolt_uri(uri): + if env.NEO4J_SCHEME != "bolt": + pytest.skip("Test requires bolt scheme") + return uri + + +@pytest.fixture(scope="session") +def _forced_bolt_uri(): + return f"bolt://{env.NEO4J_HOST}:{env.NEO4J_PORT}" + + +@pytest.fixture(scope="session") +def neo4j_uri(): + if env.NEO4J_SCHEME != "neo4j": + pytest.skip("Test requires neo4j scheme") + return uri + + +@pytest.fixture(scope="session") +def _forced_neo4j_uri(): + return f"neo4j://{env.NEO4J_HOST}:{env.NEO4J_PORT}" + + +@pytest.fixture(scope="session") +def auth(): + return env.NEO4J_USER, env.NEO4J_PASS + + +@pytest.fixture +def driver(uri, auth): + with GraphDatabase.driver(uri, auth=auth) as driver: + yield driver + + +@pytest.fixture +def bolt_driver(bolt_uri, auth): + with GraphDatabase.driver(bolt_uri, auth=auth) as driver: + yield driver + + +@pytest.fixture +def neo4j_driver(neo4j_uri, auth): + with GraphDatabase.driver(neo4j_uri, auth=auth) as driver: + yield driver + + +@wraps(AsyncGraphDatabase.driver) +def get_async_driver_no_warning(*args, **kwargs): + # with warnings.catch_warnings(): + # warnings.filterwarnings("ignore", "neo4j async", ExperimentalWarning) + with pytest.warns(ExperimentalWarning, match="neo4j async"): + return AsyncGraphDatabase.driver(*args, **kwargs) + + +@pytest_asyncio.fixture +async def async_driver(uri, auth): + async with get_async_driver_no_warning(uri, auth=auth) as driver: + yield driver + + +@pytest_asyncio.fixture +async def async_bolt_driver(bolt_uri, auth): + async with get_async_driver_no_warning(bolt_uri, auth=auth) as driver: + yield driver + + +@pytest_asyncio.fixture +async def async_neo4j_driver(neo4j_uri, auth): + async with get_async_driver_no_warning(neo4j_uri, auth=auth) as driver: + yield driver + + +@pytest.fixture +def _forced_bolt_driver(_forced_bolt_uri): + with GraphDatabase.driver(_forced_bolt_uri, auth=auth) as driver: + yield driver + + +@pytest.fixture +def _forced_neo4j_driver(_forced_neo4j_uri): + with GraphDatabase.driver(_forced_neo4j_uri, auth=auth) as driver: + yield driver + + +@pytest.fixture(scope="session") +def server_info(_forced_bolt_driver): + return _forced_bolt_driver.get_server_info() + + +@pytest.fixture(scope="session") +def bolt_protocol_version(server_info): + return server_info.protocol_version + + +def mark_requires_min_bolt_version(version="3.5"): + return pytest.mark.skipif( + env.NEO4J_VERSION < version, + reason=f"requires server version '{version}' or higher, " + f"found '{env.NEO4J_VERSION}'" + ) + + +def mark_requires_edition(edition): + return pytest.mark.skipif( + env.NEO4J_EDITION != edition, + reason=f"requires server edition '{edition}', " + f"found '{env.NEO4J_EDITION}'" + ) + + +@pytest.fixture +def session(driver): + with driver.session() as session: + yield session + + +@pytest.fixture +def bolt_session(bolt_driver): + with bolt_driver.session() as session: + yield session + + +@pytest.fixture +def neo4j_session(neo4j_driver): + with neo4j_driver.session() as session: + yield session + + +# async support for pytest-benchmark +# https://github.com/ionelmc/pytest-benchmark/issues/66 +@pytest_asyncio.fixture +async def aio_benchmark(benchmark, event_loop): + def _wrapper(func, *args, **kwargs): + if asyncio.iscoroutinefunction(func): + @benchmark + def _(): + return event_loop.run_until_complete(func(*args, **kwargs)) + else: + benchmark(func, *args, **kwargs) + + return _wrapper diff --git a/tests/env.py b/tests/env.py index 2bcfd0e30..ea9fb46ee 100644 --- a/tests/env.py +++ b/tests/env.py @@ -16,19 +16,86 @@ # limitations under the License. -from os import getenv +import abc +from os import environ +import sys -# Full path of a server package to be used for integration testing -NEO4J_SERVER_PACKAGE = getenv("NEO4J_SERVER_PACKAGE") +class _LazyEval(abc.ABC): + @abc.abstractmethod + def eval(self): + pass -# An existing remote server at this URI -NEO4J_SERVER_URI = getenv("NEO4J_URI") -# Name of a user for the currently running server -NEO4J_USER = getenv("NEO4J_USER") +class _LazyEvalEnv(_LazyEval): + def __init__(self, env_key, type_=str, default=...): + self.env_key = env_key + self.type_ = type_ + self.default = default -# Password for the currently running server -NEO4J_PASSWORD = getenv("NEO4J_PASSWORD") + def eval(self): + if self.default is not ...: + value = environ.get(self.env_key, default=self.default) + else: + try: + value = environ[self.env_key] + except KeyError as e: + raise Exception( + f"Missing environemnt variable {self.env_key}" + ) from e + if self.type_ is bool: + return value.lower() in ("yes", "y", "1", "on", "true") + if self.type_ is not None: + return self.type_(value) -NEOCTRL_ARGS = getenv("NEOCTRL_ARGS", "3.4.1") + +class _LazyEvalFunc(_LazyEval): + def __init__(self, func): + self.func = func + + def eval(self): + return self.func() + + +class _Module: + def __init__(self, module): + self._moudle = module + + def __getattr__(self, item): + val = getattr(self._moudle, item) + if isinstance(val, _LazyEval): + val = val.eval() + setattr(self._moudle, item, val) + return val + + +_module = _Module(sys.modules[__name__]) + +sys.modules[__name__] = _module + + +NEO4J_HOST = _LazyEvalEnv("TEST_NEO4J_HOST") +NEO4J_PORT = _LazyEvalEnv("TEST_NEO4J_PORT", int) +NEO4J_USER = _LazyEvalEnv("TEST_NEO4J_USER") +NEO4J_PASS = _LazyEvalEnv("TEST_NEO4J_PASS") +NEO4J_SCHEME = _LazyEvalEnv("TEST_NEO4J_SCHEME") +NEO4J_EDITION = _LazyEvalEnv("TEST_NEO4J_EDITION") +NEO4J_VERSION = _LazyEvalEnv("TEST_NEO4J_VERSION") +NEO4J_IS_CLUSTER = _LazyEvalEnv("TEST_NEO4J_IS_CLUSTER", bool) +NEO4J_SERVER_URI = _LazyEvalFunc( + lambda: f"{_module.NEO4J_SCHEME}://{_module.NEO4J_HOST}:" + f"{_module.NEO4J_PORT}" +) + + +__all__ = ( + "NEO4J_HOST", + "NEO4J_PORT", + "NEO4J_USER", + "NEO4J_PASS", + "NEO4J_SCHEME", + "NEO4J_EDITION", + "NEO4J_VERSION", + "NEO4J_IS_CLUSTER", + "NEO4J_SERVER_URI", +) diff --git a/tests/integration/async_/test_custom_ssl_context.py b/tests/integration/async_/test_custom_ssl_context.py index 82222a680..4bc52f728 100644 --- a/tests/integration/async_/test_custom_ssl_context.py +++ b/tests/integration/async_/test_custom_ssl_context.py @@ -25,7 +25,7 @@ @mark_async_test -async def test_custom_ssl_context_wraps_connection(target, auth, mocker): +async def test_custom_ssl_context_wraps_connection(uri, auth, mocker): # Test that the driver calls either `.wrap_socket` or `.wrap_bio` on the # provided custom SSL context. @@ -39,8 +39,8 @@ def wrap_fail(*_, **__): fake_ssl_context.wrap_socket.side_effect = wrap_fail fake_ssl_context.wrap_bio.side_effect = wrap_fail - driver = AsyncGraphDatabase.neo4j_driver( - target, auth=auth, ssl_context=fake_ssl_context + driver = AsyncGraphDatabase.driver( + uri, auth=auth, ssl_context=fake_ssl_context ) async with driver: async with driver.session() as session: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index dbcba8128..0d9661960 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -16,366 +16,29 @@ # limitations under the License. -from math import ceil -from os import getenv -from os.path import ( - dirname, - join, -) -from threading import RLock - import pytest -from neo4j import GraphDatabase -from neo4j._exceptions import BoltHandshakeError -from neo4j._sync.io import Bolt -from neo4j.exceptions import ServiceUnavailable - - -# import logging -# log = logging.getLogger("neo4j") -# -# from neo4j.debug import watch -# watch("neo4j") - -NEO4J_RELEASES = getenv("NEO4J_RELEASES", "snapshot-enterprise 3.5-enterprise").split() -NEO4J_HOST = "localhost" -NEO4J_PORTS = { - "bolt": 17601, - "http": 17401, - "https": 17301, -} -NEO4J_CORES = 3 -NEO4J_REPLICAS = 2 -NEO4J_USER = "neo4j" -NEO4J_PASSWORD = "pass" -NEO4J_AUTH = (NEO4J_USER, NEO4J_PASSWORD) -NEO4J_LOCK = RLock() -NEO4J_SERVICE = None -NEO4J_DEBUG = getenv("NEO4J_DEBUG", "") - -# TODO: re-enable when Docker is feasible -# from boltkit.server import Neo4jService, Neo4jClusterService - - -class Machine(object): - - def __init__(self, address): - self.address = address - - -class Neo4jService(object): - - run = d = join(dirname(__file__), ".run") - - edition = "enterprise" - - def __init__(self, name=None, image=None, auth=None, - n_cores=None, n_replicas=None, - bolt_port=None, http_port=None, debug_port=None, - debug_suspend=None, dir_spec=None, config=None): - from boltkit.legacy.controller import ( - _install, - create_controller, - ) - assert image.endswith("-enterprise") - release = image[:-11] - if release == "snapshot": - release = "4.0" - self.home = _install("enterprise", release, self.run, verbose=1) - self.auth = NEO4J_AUTH - self.controller = create_controller(self.home) - self.controller.set_initial_password(NEO4J_PASSWORD) - self.info = None - - def start(self, timeout=None): - self.info = self.controller.start(timeout=timeout) - - def stop(self, timeout=None): - from shutil import rmtree - self.controller.stop() - rmtree(self.home) - - def machines(self): - parsed = self.info.bolt_uri - return [Machine((parsed.hostname, parsed.port or 7687))] - - @property - def addresses(self): - return [machine.address for machine in self.machines()] - - def cores(self): - return self.machines() - - -class Neo4jClusterService(Neo4jService): - - @classmethod - def _port_range(cls, base_port, count): - if base_port is None: - return [None] * count - else: - return range(base_port, base_port + count) - - -def _ping_range(port_range): - count = 0 - for port in port_range: - count += 1 if Bolt.ping((NEO4J_HOST, port)) else 0 - return count - - -def get_existing_service(): - core_port_range = Neo4jClusterService._port_range(NEO4J_PORTS["bolt"], NEO4J_CORES) - replica_port_range = Neo4jClusterService._port_range(ceil(core_port_range.stop / 10) * 10 + 1, - NEO4J_REPLICAS) - core_count = _ping_range(core_port_range) - replica_count = _ping_range(replica_port_range) - if core_count == 0 and replica_count == 0: - return None - elif core_count == NEO4J_CORES and replica_count == NEO4J_REPLICAS: - return ExistingService(NEO4J_HOST, core_port_range, replica_port_range) - else: - raise OSError("Some ports required by the test service are already in use") - - -class ExistingMachine: - - def __init__(self, address): - self.address = address - - -class ExistingService: - - def __init__(self, host, core_port_range, replica_port_range): - self._cores = [ExistingMachine((host, port)) for port in core_port_range] - self._replicas = [ExistingMachine((host, port)) for port in replica_port_range] - - @property - def addresses(self): - return [machine.address for machine in self.cores()] - - @property - def auth(self): - return NEO4J_AUTH - - def cores(self): - return self._cores - - def replicas(self): - return self._replicas - - def start(self, timeout=None): - pass - - def stop(self, timeout=None): - pass - - -existing_service = get_existing_service() - -if existing_service: - NEO4J_RELEASES = ["existing"] - - -@pytest.fixture(scope="session", params=NEO4J_RELEASES) -def service(request): - global NEO4J_SERVICE - if NEO4J_DEBUG: - from neo4j.debug import watch - watch("neo4j", "boltkit") - with NEO4J_LOCK: - assert NEO4J_SERVICE is None - NEO4J_SERVICE = existing_service - if existing_service: - NEO4J_SERVICE = existing_service - else: - NEO4J_SERVICE = Neo4jService(auth=NEO4J_AUTH, image=request.param, n_cores=NEO4J_CORES, n_replicas=NEO4J_REPLICAS) - NEO4J_SERVICE.start(timeout=300) - yield NEO4J_SERVICE - if NEO4J_SERVICE is not None: - NEO4J_SERVICE.stop(timeout=300) - NEO4J_SERVICE = None - -@pytest.fixture(scope="session") -def addresses(service): - try: - machines = service.cores() - except AttributeError: - machines = list(service.machines.values()) - return [machine.address for machine in machines] +class ForcedRollback(Exception): + def __init__(self, return_value): + super().__init__() + self.return_value = return_value -# @fixture(scope="session") -# def readonly_addresses(service): -# try: -# machines = service.replicas() -# except AttributeError: -# machines = [] -# return [machine.address for machine in machines] - - -@pytest.fixture(scope="session") -def address(addresses): - try: - return addresses[0] - except IndexError: - return None - - -# @fixture(scope="session") -# def readonly_address(readonly_addresses): -# try: -# return readonly_addresses[0] -# except IndexError: -# return None - - -@pytest.fixture(scope="session") -def targets(addresses): - return " ".join("{}:{}".format(address[0], address[1]) for address in addresses) - - -# @fixture(scope="session") -# def readonly_targets(addresses): -# return " ".join("{}:{}".format(address[0], address[1]) for address in readonly_addresses) - - -@pytest.fixture(scope="session") -def target(address): - return "{}:{}".format(address[0], address[1]) - - -# @fixture(scope="session") -# def readonly_target(readonly_address): -# if readonly_address: -# return "{}:{}".format(readonly_address[0], readonly_address[1]) -# else: -# return None - - -@pytest.fixture(scope="session") -def bolt_uri(service, target): - return "bolt://" + target - - -@pytest.fixture(scope="session") -def neo4j_uri(service, target): - return "neo4j://" + target - - -@pytest.fixture(scope="session") -def uri(bolt_uri): - return bolt_uri - - -# @fixture(scope="session") -# def readonly_bolt_uri(service, readonly_target): -# if readonly_target: -# return "bolt://" + readonly_target -# else: -# return None - - -@pytest.fixture(scope="session") -def auth(): - return NEO4J_AUTH - - -@pytest.fixture(scope="session") -def bolt_driver(target, auth): - try: - driver = GraphDatabase.bolt_driver(target, auth=auth) - try: - yield driver - finally: - driver.close() - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -@pytest.fixture(scope="session") -def neo4j_driver(target, auth): - try: - driver = GraphDatabase.neo4j_driver(target, auth=auth) - driver._pool.update_routing_table(database=None, imp_user=None, - bookmarks=None) - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - elif error.args[0] == "Unable to retrieve routing information": - pytest.skip(error.args[0]) - else: - raise - else: - try: - yield driver - finally: - driver.close() - - -@pytest.fixture(scope="session") -def server_info(bolt_driver): - with bolt_driver.session() as session: - summary = session.run("RETURN 1").consume() - return summary.server - - -@pytest.fixture(scope="session") -def bolt_protocol_version(server_info): - return server_info.protocol_version - - -@pytest.fixture(scope="session") -def requires_bolt_4x(bolt_protocol_version): - if bolt_protocol_version < (4, 0): - pytest.skip("Requires Bolt 4.0 or above") - - -@pytest.fixture(scope="session") -def driver(neo4j_driver): - return neo4j_driver - - -@pytest.fixture() -def session(bolt_driver): - session = bolt_driver.session() - try: - yield session - finally: - session.close() - - -@pytest.fixture() -def protocol_version(session): - result = session.run("RETURN 1 AS x") - yield session._connection.server_info.protocol_version - result.consume() - - -@pytest.fixture() -def cypher_eval(bolt_driver): - +@pytest.fixture +def cypher_eval(driver): def run_and_rollback(tx, cypher, **parameters): result = tx.run(cypher, **parameters) value = result.single().value() - tx._success = False # This is not a recommended pattern - return value + raise ForcedRollback(value) def f(cypher, **parameters): with bolt_driver.session() as session: - return session.write_transaction(run_and_rollback, cypher, **parameters) + try: + session.write_transaction(run_and_rollback, cypher, + **parameters) + raise RuntimeError("Expected rollback") + except ForcedRollback as e: + return e.return_value return f - - -def pytest_sessionfinish(session, exitstatus): - """ Called after the entire session to ensure Neo4j is shut down. - """ - global NEO4J_SERVICE - with NEO4J_LOCK: - if NEO4J_SERVICE is not None: - NEO4J_SERVICE.stop(timeout=300) - NEO4J_SERVICE = None diff --git a/tests/integration/examples/test_basic_auth_example.py b/tests/integration/examples/test_basic_auth_example.py index 93fed59b1..69c7626bd 100644 --- a/tests/integration/examples/test_basic_auth_example.py +++ b/tests/integration/examples/test_basic_auth_example.py @@ -20,7 +20,8 @@ from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_bearer_auth_example.py b/tests/integration/examples/test_bearer_auth_example.py index 3e8a8b9ef..bb6988df6 100644 --- a/tests/integration/examples/test_bearer_auth_example.py +++ b/tests/integration/examples/test_bearer_auth_example.py @@ -17,7 +17,8 @@ import neo4j -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_config_connection_pool_example.py b/tests/integration/examples/test_config_connection_pool_example.py index 4a68e1be9..7bdfc39a1 100644 --- a/tests/integration/examples/test_config_connection_pool_example.py +++ b/tests/integration/examples/test_config_connection_pool_example.py @@ -20,7 +20,8 @@ from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_config_connection_timeout_example.py b/tests/integration/examples/test_config_connection_timeout_example.py index 2ca7945e8..7e67ec5d3 100644 --- a/tests/integration/examples/test_config_connection_timeout_example.py +++ b/tests/integration/examples/test_config_connection_timeout_example.py @@ -20,7 +20,8 @@ from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_config_max_retry_time_example.py b/tests/integration/examples/test_config_max_retry_time_example.py index c0bfb6fa5..1efcbcb10 100644 --- a/tests/integration/examples/test_config_max_retry_time_example.py +++ b/tests/integration/examples/test_config_max_retry_time_example.py @@ -20,7 +20,8 @@ from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_config_secure_example.py b/tests/integration/examples/test_config_secure_example.py index b10f03619..3860ab5b3 100644 --- a/tests/integration/examples/test_config_secure_example.py +++ b/tests/integration/examples/test_config_secure_example.py @@ -20,7 +20,8 @@ from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off @@ -39,7 +40,7 @@ class ConfigSecureExample(DriverSetupExample): def __init__(self, uri, auth): # trusted_certificates: # neo4j.TrustSystemCAs() - # (default) trust certificates from system store) + # (default) trust certificates from system store # neo4j.TrustAll() # trust all certificates # neo4j.TrustCustomCAs("", ...) diff --git a/tests/integration/examples/test_config_trust_example.py b/tests/integration/examples/test_config_trust_example.py index 189b4e18e..40ef9b9f9 100644 --- a/tests/integration/examples/test_config_trust_example.py +++ b/tests/integration/examples/test_config_trust_example.py @@ -18,7 +18,7 @@ import pytest -from tests.integration.examples import DriverSetupExample +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_config_unencrypted_example.py b/tests/integration/examples/test_config_unencrypted_example.py index 34688b1cf..9e1f07ab5 100644 --- a/tests/integration/examples/test_config_unencrypted_example.py +++ b/tests/integration/examples/test_config_unencrypted_example.py @@ -20,7 +20,8 @@ from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_custom_auth_example.py b/tests/integration/examples/test_custom_auth_example.py index 3d792734d..399ff94f1 100644 --- a/tests/integration/examples/test_custom_auth_example.py +++ b/tests/integration/examples/test_custom_auth_example.py @@ -20,7 +20,8 @@ from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable -from tests.integration.examples import DriverSetupExample + +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_cypher_error_example.py b/tests/integration/examples/test_cypher_error_example.py index c46903a01..7551ca0e9 100644 --- a/tests/integration/examples/test_cypher_error_example.py +++ b/tests/integration/examples/test_cypher_error_example.py @@ -48,9 +48,9 @@ def select_employee(tx, name): # end::cypher-error[] -def test_example(bolt_driver): +def test_example(driver): s = StringIO() with redirect_stdout(s): - example = Neo4jErrorExample(bolt_driver) + example = Neo4jErrorExample(driver) example.get_employee_number('Alice') assert s.getvalue().startswith("Invalid input") diff --git a/tests/integration/examples/test_database_selection_example.py b/tests/integration/examples/test_database_selection_example.py index 0a678400b..77cfe2b27 100644 --- a/tests/integration/examples/test_database_selection_example.py +++ b/tests/integration/examples/test_database_selection_example.py @@ -19,6 +19,11 @@ from contextlib import redirect_stdout from io import StringIO +from ...conftest import ( + mark_requires_edition, + mark_requires_min_bolt_version, +) + # isort: off # tag::database-selection-import[] @@ -58,10 +63,12 @@ def run_example_code(self): # end::database-selection[] -def test_database_selection_example(neo4j_uri, auth, requires_bolt_4x): +@mark_requires_min_bolt_version("4") +@mark_requires_edition("enterprise") +def test_database_selection_example(uri, auth): s = StringIO() with redirect_stdout(s): - example = DatabaseSelectionExample(neo4j_uri, auth[0], auth[1]) + example = DatabaseSelectionExample(uri, auth[0], auth[1]) example.run_example_code() example.close() assert s.getvalue().startswith("Hello, Example-Database") diff --git a/tests/integration/examples/test_kerberos_auth_example.py b/tests/integration/examples/test_kerberos_auth_example.py index 799bd034a..1b02e9f62 100644 --- a/tests/integration/examples/test_kerberos_auth_example.py +++ b/tests/integration/examples/test_kerberos_auth_example.py @@ -17,7 +17,7 @@ import pytest -from tests.integration.examples import DriverSetupExample +from . import DriverSetupExample # isort: off diff --git a/tests/integration/examples/test_transaction_function_example.py b/tests/integration/examples/test_transaction_function_example.py index 17bdd4234..1450f63f0 100644 --- a/tests/integration/examples/test_transaction_function_example.py +++ b/tests/integration/examples/test_transaction_function_example.py @@ -46,10 +46,17 @@ def add_person(self, name): return add_person(self.driver, name) -def test_example(bolt_driver): - eg = TransactionFunctionExample(bolt_driver) +def work(tx, query, **parameters): + res = tx.run(query, **parameters) + return [rec.values() for rec in res], res.consume() + + +def test_example(driver): + eg = TransactionFunctionExample(driver) with eg.driver.session() as session: - session.run("MATCH (_) DETACH DELETE _") + session.write_transaction(work, "MATCH (_) DETACH DELETE _") eg.add_person("Alice") - n = session.run("MATCH (a:Person) RETURN count(a)").single().value() - assert n == 1 + records, _ = session.read_transaction( + work, "MATCH (a:Person) RETURN count(a)" + ) + assert records == [[1]] diff --git a/tests/integration/examples/test_transaction_metadata_config_example.py b/tests/integration/examples/test_transaction_metadata_config_example.py index aadb2ad46..002fd8997 100644 --- a/tests/integration/examples/test_transaction_metadata_config_example.py +++ b/tests/integration/examples/test_transaction_metadata_config_example.py @@ -44,10 +44,17 @@ def add_person(self, name): return add_person(self.driver, name) -def test_example(bolt_driver): - eg = TransactionMetadataConfigExample(bolt_driver) +def work(tx, query, **parameters): + res = tx.run(query, **parameters) + return [rec.values() for rec in res], res.consume() + + +def test_example(driver): + eg = TransactionMetadataConfigExample(driver) with eg.driver.session() as session: - session.run("MATCH (_) DETACH DELETE _") + session.write_transaction(work, "MATCH (_) DETACH DELETE _") eg.add_person("Alice") - n = session.run("MATCH (a:Person) RETURN count(a)").single().value() - assert n == 1 + records, _ = session.read_transaction( + work, "MATCH (a:Person) RETURN count(a)" + ) + assert records == [[1]] diff --git a/tests/integration/examples/test_transaction_timeout_config_example.py b/tests/integration/examples/test_transaction_timeout_config_example.py index 0f7161015..8a0e8f6af 100644 --- a/tests/integration/examples/test_transaction_timeout_config_example.py +++ b/tests/integration/examples/test_transaction_timeout_config_example.py @@ -44,10 +44,17 @@ def add_person(self, name): return add_person(self.driver, name) -def test_example(bolt_driver): - eg = TransactionTimeoutConfigExample(bolt_driver) +def work(tx, query, **parameters): + res = tx.run(query, **parameters) + return [rec.values() for rec in res], res.consume() + + +def test_example(driver): + eg = TransactionTimeoutConfigExample(driver) with eg.driver.session() as session: - session.run("MATCH (_) DETACH DELETE _") + session.write_transaction(work, "MATCH (_) DETACH DELETE _") eg.add_person("Alice") - n = session.run("MATCH (a:Person) RETURN count(a)").single().value() - assert n == 1 + records, _ = session.read_transaction( + work, "MATCH (a:Person) RETURN count(a)" + ) + assert records == [[1]] diff --git a/tests/integration/sync/test_custom_ssl_context.py b/tests/integration/sync/test_custom_ssl_context.py index 0d2cb6669..56d09f684 100644 --- a/tests/integration/sync/test_custom_ssl_context.py +++ b/tests/integration/sync/test_custom_ssl_context.py @@ -25,7 +25,7 @@ @mark_sync_test -def test_custom_ssl_context_wraps_connection(target, auth, mocker): +def test_custom_ssl_context_wraps_connection(uri, auth, mocker): # Test that the driver calls either `.wrap_socket` or `.wrap_bio` on the # provided custom SSL context. @@ -39,8 +39,8 @@ def wrap_fail(*_, **__): fake_ssl_context.wrap_socket.side_effect = wrap_fail fake_ssl_context.wrap_bio.side_effect = wrap_fail - driver = GraphDatabase.neo4j_driver( - target, auth=auth, ssl_context=fake_ssl_context + driver = GraphDatabase.driver( + uri, auth=auth, ssl_context=fake_ssl_context ) with driver: with driver.session() as session: diff --git a/tests/integration/test_bolt_driver.py b/tests/integration/test_bolt_driver.py index df8b14fe1..15c00b34b 100644 --- a/tests/integration/test_bolt_driver.py +++ b/tests/integration/test_bolt_driver.py @@ -14,11 +14,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + import pytest -from pytest import fixture -@fixture +@pytest.fixture def server_info(driver): """ Simple fixture to provide quick and easy access to a :class:`.ServerInfo` object. @@ -31,7 +32,8 @@ def server_info(driver): # TODO: 6.0 - # This test will stay as python is currently the only driver exposing # the connection id. This will be removed in 6.0 -def test_server_connection_id(server_info): +def test_server_connection_id(driver): + server_info = driver.get_server_info() with pytest.warns(DeprecationWarning): cid = server_info.connection_id assert cid.startswith("bolt-") and cid[5:].isdigit() diff --git a/tests/integration/test_readme.py b/tests/integration/test_readme.py index d9e1101ee..bb6e56c95 100644 --- a/tests/integration/test_readme.py +++ b/tests/integration/test_readme.py @@ -18,7 +18,6 @@ import pytest -from neo4j._exceptions import BoltHandshakeError from neo4j.exceptions import ServiceUnavailable diff --git a/tests/performance/test_async_results.py b/tests/performance/test_async_results.py index f0bc2548f..74ce212e3 100644 --- a/tests/performance/test_async_results.py +++ b/tests/performance/test_async_results.py @@ -16,71 +16,35 @@ # limitations under the License. -import asyncio -from itertools import product +import pytest -from pytest import mark +from neo4j import GraphDatabase -from neo4j import AsyncGraphDatabase -from .tools import RemoteGraphDatabaseServer +def work(async_driver, *units_of_work): + async def runner(): + async with async_driver.session() as session: + for unit_of_work in units_of_work: + await session.read_transaction(unit_of_work) + return runner -class AsyncReadWorkload(object): +def unit_of_work_generator(record_count, record_width, value): + async def transaction_function(tx): + s = "UNWIND range(1, $record_count) AS _ RETURN {}".format( + ", ".join("$x AS x{}".format(i) for i in range(record_width))) + p = {"record_count": record_count, "x": value} + async for record in await tx.run(s, p): + assert all(x == value for x in record.values()) - server = None - driver = None - loop = None + return transaction_function - @classmethod - def setup_class(cls): - cls.server = server = RemoteGraphDatabaseServer() - server.start() - cls.loop = asyncio.new_event_loop() - asyncio.set_event_loop(cls.loop) - cls.driver = AsyncGraphDatabase.driver(server.server_uri, - auth=server.auth_token, - encrypted=server.encrypted) - @classmethod - def teardown_class(cls): - try: - cls.loop.run_until_complete(cls.driver.close()) - cls.server.stop() - finally: - cls.loop.stop() - asyncio.set_event_loop(None) - - def work(self, *units_of_work): - async def runner(): - async with self.driver.session() as session: - for unit_of_work in units_of_work: - await session.read_transaction(unit_of_work) - - def sync_runner(): - self.loop.run_until_complete(runner()) - - return sync_runner - - -class TestAsyncReadWorkload(AsyncReadWorkload): - - @staticmethod - def uow(record_count, record_width, value): - - async def _(tx): - s = "UNWIND range(1, $record_count) AS _ RETURN {}".format( - ", ".join("$x AS x{}".format(i) for i in range(record_width))) - p = {"record_count": record_count, "x": value} - async for record in await tx.run(s, p): - assert all(x == value for x in record.values()) - - return _ - - @mark.parametrize("record_count,record_width,value", product( - [1, 1000], # record count - [1, 10], # record width - [1, u'hello, world'], # value +@pytest.mark.parametrize("record_count", [1, 1000]) +@pytest.mark.parametrize("record_width", [1, 10]) +@pytest.mark.parametrize("value", [1, u'hello, world']) +def test_async_1x1(async_driver, aio_benchmark, record_count, record_width, value): + aio_benchmark(work( + async_driver, + unit_of_work_generator(record_count, record_width, value) )) - def test_1x1(self, benchmark, record_count, record_width, value): - benchmark(self.work(self.uow(record_count, record_width, value))) diff --git a/tests/performance/test_results.py b/tests/performance/test_results.py index 6a2205b0d..f7cbbd624 100644 --- a/tests/performance/test_results.py +++ b/tests/performance/test_results.py @@ -16,59 +16,32 @@ # limitations under the License. -from itertools import product - -from pytest import mark +import pytest from neo4j import GraphDatabase -from .tools import RemoteGraphDatabaseServer - - -class ReadWorkload(object): - - server = None - driver = None - - @classmethod - def setup_class(cls): - cls.server = server = RemoteGraphDatabaseServer() - server.start() - cls.driver = GraphDatabase.driver(server.server_uri, - auth=server.auth_token, - encrypted=server.encrypted) - - @classmethod - def teardown_class(cls): - cls.driver.close() - cls.server.stop() - - def work(self, *units_of_work): - def runner(): - with self.driver.session() as session: - for unit_of_work in units_of_work: - session.read_transaction(unit_of_work) - return runner +def work(driver, *units_of_work): + def runner(): + with driver.session() as session: + for unit_of_work in units_of_work: + session.read_transaction(unit_of_work) + return runner -class TestReadWorkload(ReadWorkload): - @staticmethod - def uow(record_count, record_width, value): +def unit_of_work(record_count, record_width, value): + def transaction_function(tx): + s = "UNWIND range(1, $record_count) AS _ RETURN {}".format( + ", ".join("$x AS x{}".format(i) for i in range(record_width))) + p = {"record_count": record_count, "x": value} + for record in tx.run(s, p): + assert all(x == value for x in record.values()) - def _(tx): - s = "UNWIND range(1, $record_count) AS _ RETURN {}".format( - ", ".join("$x AS x{}".format(i) for i in range(record_width))) - p = {"record_count": record_count, "x": value} - for record in tx.run(s, p): - assert all(x == value for x in record.values()) + return transaction_function - return _ - @mark.parametrize("record_count,record_width,value", product( - [1, 1000], # record count - [1, 10], # record width - [1, u'hello, world'], # value - )) - def test_1x1(self, benchmark, record_count, record_width, value): - benchmark(self.work(self.uow(record_count, record_width, value))) +@pytest.mark.parametrize("record_count", [1, 1000]) +@pytest.mark.parametrize("record_width", [1, 10]) +@pytest.mark.parametrize("value", [1, u'hello, world']) +def test_1x1(driver, benchmark, record_count, record_width, value): + benchmark(work(driver, unit_of_work(record_count, record_width, value))) diff --git a/tests/performance/tools.py b/tests/performance/tools.py deleted file mode 100644 index 27dfb886c..000000000 --- a/tests/performance/tools.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from unittest import SkipTest - -from neo4j import GraphDatabase -from neo4j.exceptions import AuthError -from tests.env import ( - NEO4J_PASSWORD, - NEO4J_SERVER_URI, - NEO4J_USER, -) - - -def is_listening(address): - from socket import create_connection - try: - s = create_connection(address) - except IOError: - return False - else: - s.close() - return True - - -class RemoteGraphDatabaseServer(object): - server_uri = NEO4J_SERVER_URI or "bolt://localhost:7687" - auth_token = (NEO4J_USER or "neo4j", NEO4J_PASSWORD) - encrypted = NEO4J_SERVER_URI is not None - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.stop() - - @classmethod - def start(cls): - with GraphDatabase.driver(cls.server_uri, auth=cls.auth_token, encrypted=cls.encrypted) as driver: - try: - with driver.session(): - print("Using existing remote server {}\n".format(cls.server_uri)) - return - except AuthError as error: - raise RuntimeError("Failed to authenticate (%s)" % error) - raise SkipTest("No remote Neo4j server available for %s" % cls.__name__) - - @classmethod - def stop(cls): - pass diff --git a/tests/requirements.txt b/tests/requirements.txt index 31d0646a1..7dbb10a0e 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,4 +1,3 @@ -git+https://github.com/neo4j-drivers/boltkit@4.2#egg=boltkit coverage>=5.5 pytest>=6.2.5 pytest-asyncio>=0.16.0 diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index f106eb6f4..10020aeee 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -121,7 +121,7 @@ async def test_bolt_connection_ping_timeout(): assert protocol_version is None -@pytest.fixture +@pytest.yield_fixture async def pool(): async with AsyncFakeBoltPool(("127.0.0.1", 7687)) as pool: yield pool diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 72acfc5b7..d8fa21d81 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -121,7 +121,7 @@ def test_bolt_connection_ping_timeout(): assert protocol_version is None -@pytest.fixture +@pytest.yield_fixture def pool(): with FakeBoltPool(("127.0.0.1", 7687)) as pool: yield pool diff --git a/tox-performance.ini b/tox-performance.ini deleted file mode 100644 index 3c7779a16..000000000 --- a/tox-performance.ini +++ /dev/null @@ -1,18 +0,0 @@ -[tox] -envlist = - py37 - py38 - py39 - py310 - -[testenv] -passenv = - NEO4J_USER - NEO4J_PASSWORD - NEO4J_URI -commands = - python setup.py develop - pip install --upgrade -r {toxinidir}/tests/requirements.txt - coverage erase - coverage run -m pytest -v {posargs} tests/performance - coverage report diff --git a/tox-unit.ini b/tox-unit.ini deleted file mode 100644 index 11307a04a..000000000 --- a/tox-unit.ini +++ /dev/null @@ -1,12 +0,0 @@ -[tox] -envlist = - py37 - -[testenv] -deps = - -r tests/requirements.txt -commands = - coverage erase - coverage run -m pytest -v {posargs} \ - tests/unit - coverage report diff --git a/tox.ini b/tox.ini index aa4ea24fd..203bfb01b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,37 +1,14 @@ [tox] -envlist = - py37 - py38 - py39 - py310 +envlist = py{37,38,39,310}-{unit,integration,performance} [testenv] passenv = - NEO4J_SERVER_PACKAGE - NEO4J_RELEASES - NEO4J_USER - NEO4J_PASSWORD - TEAMCITY_VERSION - TEAMCITY_HOST - TEAMCITY_USER - TEAMCITY_PASSWORD - JAVA_HOME - NEO4J_URI - AWS_ACCESS_KEY_ID - AWS_SECRET_ACCESS_KEY - APPDATA - LOCALAPPDATA - ProgramData - windir - ProgramFiles - ProgramFiles(x86) - ProgramW6432 -# APPDATA and below are needed for Neo4j >=4.3 service to run on Windows + TEST_NEO4J_* deps = -r tests/requirements.txt commands = coverage erase - coverage run -m pytest -v {posargs} \ - tests/unit \ - tests/integration - coverage report + unit: coverage run -m pytest -v {posargs} tests/unit + integration: coverage run -m pytest -v {posargs} tests/integration + performance: python -m pytest --benchmark-autosave -v {posargs} tests/performance + unit,integration: coverage report