From 55e6b2486830278d778b73afe8e676bc8da9e01c Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 14 Jul 2021 14:53:44 +0200 Subject: [PATCH] 2nd round of migrating integration tests to testkit --- tests/integration/test_bolt_driver.py | 206 +------------------------- tests/unit/work/_fake_connection.py | 83 +++++++++++ tests/unit/work/test_result.py | 4 + tests/unit/work/test_session.py | 88 +++++------ tests/unit/work/test_transaction.py | 59 ++++++++ 5 files changed, 183 insertions(+), 257 deletions(-) create mode 100644 tests/unit/work/_fake_connection.py create mode 100644 tests/unit/work/test_transaction.py diff --git a/tests/integration/test_bolt_driver.py b/tests/integration/test_bolt_driver.py index 2326191b0..5b2e76a0c 100644 --- a/tests/integration/test_bolt_driver.py +++ b/tests/integration/test_bolt_driver.py @@ -60,198 +60,6 @@ def test_encrypted_set_to_false_by_default(bolt_driver): assert bolt_driver.encrypted is False -def test_bolt_driver_fetch_size_config_case_on_close_result_consume(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_fetch_size_config_case_on_close_result_consume - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - result = session.run("UNWIND [1,2,3,4] AS x RETURN x") - # Check the expected result with logging manually - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_fetch_size_config_case_normal(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_fetch_size_config_case_normal - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - expected = [] - result = session.run("UNWIND [1,2,3,4] AS x RETURN x") - for record in result: - expected.append(record["x"]) - - assert expected == [1, 2, 3, 4] - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_fetch_size_config_run_consume_run(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_fetch_size_config_run_consume_run - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - expected = [] - result1 = session.run("UNWIND [1,2,3,4] AS x RETURN x") - result1.consume() - result2 = session.run("UNWIND [5,6,7,8] AS x RETURN x") - - for record in result2: - expected.append(record["x"]) - - result_summary = result2.consume() - assert isinstance(result_summary, ResultSummary) - - assert expected == [5, 6, 7, 8] - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_fetch_size_config_run_run(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_fetch_size_config_run_run - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - expected = [] - result1 = session.run("UNWIND [1,2,3,4] AS x RETURN x") - result2 = session.run("UNWIND [5,6,7,8] AS x RETURN x") - - for record in result2: - expected.append(record["x"]) - - result_summary = result2.consume() - assert isinstance(result_summary, ResultSummary) - - assert expected == [5, 6, 7, 8] - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_read_transaction_fetch_size_config_normal_case(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_read_transaction_fetch_size_config_normal_case - @unit_of_work(timeout=3, metadata={"foo": "bar"}) - def unwind(transaction): - assert isinstance(transaction, Transaction) - values = [] - result = transaction.run("UNWIND [1,2,3,4] AS x RETURN x") - assert isinstance(result, Result) - for record in result: - values.append(record["x"]) - return values - - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - expected = session.read_transaction(unwind) - - assert expected == [1, 2, 3, 4] - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_multiple_results_case_1(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_multiple_results_case_1 - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - transaction = session.begin_transaction(timeout=3, metadata={"foo": "bar"}) - result1 = transaction.run("UNWIND [1,2,3,4] AS x RETURN x") - values1 = [] - for ix in result1: - values1.append(ix["x"]) - transaction.commit() - assert values1 == [1, 2, 3, 4] - - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_multiple_results_case_2(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_multiple_results_case_2 - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - transaction = session.begin_transaction(timeout=3, metadata={"foo": "bar"}) - result1 = transaction.run("UNWIND [1,2,3,4] AS x RETURN x") - result2 = transaction.run("UNWIND [5,6,7,8] AS x RETURN x") - values1 = [] - values2 = [] - for ix in result2: - values2.append(ix["x"]) - for ix in result1: - values1.append(ix["x"]) - transaction.commit() - assert values2 == [5, 6, 7, 8] - assert values1 == [1, 2, 3, 4] - - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_multiple_results_case_3(bolt_uri, auth): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_multiple_results_case_3 - try: - with GraphDatabase.driver(bolt_uri, auth=auth, user_agent="test") as driver: - assert isinstance(driver, BoltDriver) - with driver.session(fetch_size=2, default_access_mode=READ_ACCESS) as session: - transaction = session.begin_transaction(timeout=3, metadata={"foo": "bar"}) - values1 = [] - values2 = [] - result1 = transaction.run("UNWIND [1,2,3,4] AS x RETURN x") - result1_iter = iter(result1) - values1.append(next(result1_iter)["x"]) - result2 = transaction.run("UNWIND [5,6,7,8] AS x RETURN x") - for ix in result2: - values2.append(ix["x"]) - transaction.commit() # Should discard the rest of records in result1 and result2 -> then set mode discard. - assert values2 == [5, 6, 7, 8] - assert result2._closed is True - assert values1 == [1, ] - - try: - values1.append(next(result1_iter)["x"]) - except StopIteration as e: - # Bolt 4.0 - assert values1 == [1, ] - assert result1._closed is True - else: - # Bolt 3 only have PULL ALL and no qid so it will behave like autocommit r1=session.run rs2=session.run - values1.append(next(result1_iter)["x"]) - assert values1 == [1, 2, 3] - assert result1._closed is False - - assert result1._closed is True - - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - -def test_bolt_driver_case_pull_no_records(driver): - # python -m pytest tests/integration/test_bolt_driver.py -s -v -k test_bolt_driver_case_pull_no_records - try: - with driver.session(default_access_mode=WRITE_ACCESS) as session: - with session.begin_transaction() as tx: - result = tx.run("CREATE (a:Thing {uuid:$uuid})", uuid=123) - except ServiceUnavailable as error: - if isinstance(error.__cause__, BoltHandshakeError): - pytest.skip(error.args[0]) - - @fixture def server_info(driver): """ Simple fixture to provide quick and easy access to a @@ -262,18 +70,8 @@ def server_info(driver): yield summary.server -def test_server_address(server_info): - assert server_info.address == ("127.0.0.1", 7687) - - -def test_server_protocol_version(server_info): - assert server_info.protocol_version >= (3, 0) - - -def test_server_agent(server_info): - assert server_info.agent.startswith("Neo4j/") - - +# TODO: this test will stay asy python is currently the only driver exposing the +# connection id. So this might change in the future. def test_server_connection_id(server_info): cid = server_info.connection_id assert cid.startswith("bolt-") and cid[5:].isdigit() diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py new file mode 100644 index 000000000..de0c49925 --- /dev/null +++ b/tests/unit/work/_fake_connection.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from unittest.mock import NonCallableMagicMock + +import pytest + +from neo4j import ServerInfo + + +class FakeConnection(NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + cb({}) + else: + cb() + self.callbacks.append(callback) + return parent.__getattr__(name)(*args, **kwargs) + + return func + + if name in ("run", "commit", "pull", "rollback", "discard"): + return build_message_handler(name) + return parent.__getattr__(name) + + def defunct(self): + return False + + +@pytest.fixture +def fake_connection(): + return FakeConnection() diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index d7f49ece4..56f92f524 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -189,6 +189,8 @@ def _fetch_and_compare_all_records(result, key, expected_records, method, received_records.append([record.data().get(key, None)]) if limit is not None and len(received_records) == limit: break + if limit is None: + assert result._closed elif method == "next": iter_ = iter(result) n = len(expected_records) if limit is None else limit @@ -197,6 +199,7 @@ def _fetch_and_compare_all_records(result, key, expected_records, method, if limit is None: with pytest.raises(StopIteration): received_records.append([next(iter_).get(key, None)]) + assert result._closed elif method == "new iter": n = len(expected_records) if limit is None else limit for _ in range(n): @@ -204,6 +207,7 @@ def _fetch_and_compare_all_records(result, key, expected_records, method, if limit is None: with pytest.raises(StopIteration): received_records.append([next(iter(result)).get(key, None)]) + assert result._closed else: raise ValueError() assert received_records == expected_records diff --git a/tests/unit/work/test_session.py b/tests/unit/work/test_session.py index d69819b5e..3cffeebb2 100644 --- a/tests/unit/work/test_session.py +++ b/tests/unit/work/test_session.py @@ -18,66 +18,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import pytest -from unittest.mock import NonCallableMagicMock from neo4j import ( - ServerInfo, Session, SessionConfig, + Transaction, + unit_of_work, ) - -class FakeConnection(NonCallableMagicMock): - callbacks = [] - server_info = ServerInfo("127.0.0.1", (4, 3)) - - def fetch_message(self, *args, **kwargs): - if self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_message")(*args, **kwargs) - - def fetch_all(self, *args, **kwargs): - while self.callbacks: - cb = self.callbacks.pop(0) - cb() - return super().__getattr__("fetch_all")(*args, **kwargs) - - def __getattr__(self, name): - parent = super() - - def build_message_handler(name): - def func(*args, **kwargs): - def callback(): - for cb_name, param_count in ( - ("on_success", 1), - ("on_summary", 0) - ): - cb = kwargs.get(cb_name, None) - if callable(cb): - try: - param_count = \ - len(inspect.signature(cb).parameters) - except ValueError: - # e.g. built-in method as cb - pass - if param_count == 1: - cb({}) - else: - cb() - self.callbacks.append(callback) - return parent.__getattr__(name)(*args, **kwargs) - - return func - - if name in ("run", "commit", "pull", "rollback", "discard"): - return build_message_handler(name) - return parent.__getattr__(name) - - def defunct(self): - return False +from ._fake_connection import FakeConnection @pytest.fixture() @@ -215,3 +165,35 @@ def test_session_run_wrong_types(pool, query, error_type): with Session(pool, SessionConfig()) as session: with pytest.raises(error_type): session.run(query) + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) +def test_tx_function_argument_type(pool, tx_type): + def work(tx): + assert isinstance(tx, Transaction) + + with Session(pool, SessionConfig()) as session: + getattr(session, tx_type)(work) + + +@pytest.mark.parametrize("tx_type", ("write_transaction", "read_transaction")) +@pytest.mark.parametrize("decorator_kwargs", ( + {}, + {"timeout": 5}, + {"metadata": {"foo": "bar"}}, + {"timeout": 5, "metadata": {"foo": "bar"}}, + +)) +def test_decorated_tx_function_argument_type(pool, tx_type, decorator_kwargs): + @unit_of_work(**decorator_kwargs) + def work(tx): + assert isinstance(tx, Transaction) + + with Session(pool, SessionConfig()) as session: + getattr(session, tx_type)(work) + + +def test_session_tx_type(pool): + with Session(pool, SessionConfig()) as session: + tx = session.begin_transaction() + assert isinstance(tx, Transaction) diff --git a/tests/unit/work/test_transaction.py b/tests/unit/work/test_transaction.py new file mode 100644 index 000000000..116c72b27 --- /dev/null +++ b/tests/unit/work/test_transaction.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import pytest + +from neo4j import ( + Transaction, +) + +from ._fake_connection import fake_connection + + +def test_transaction_context_calls_commit(mocker, fake_connection): + on_closed = MagicMock() + on_network_error = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_network_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with tx as tx_: + assert tx is tx_ + pass + mock_commit.assert_called_once_with() + assert mock_rollback.call_count == 0 + + +def test_transaction_context_calls_rollback_on_error(mocker, fake_connection): + class OopsError(RuntimeError): + pass + + on_closed = MagicMock() + on_network_error = MagicMock() + tx = Transaction(fake_connection, 2, on_closed, on_network_error) + mock_commit = mocker.patch.object(tx, "commit", wraps=tx.commit) + mock_rollback = mocker.patch.object(tx, "rollback", wraps=tx.rollback) + with pytest.raises(OopsError): + with tx as tx_: + assert tx is tx_ + raise OopsError + assert mock_commit.call_count == 0 + mock_rollback.assert_called_once_with()