From ec280497229a021c76acce85ed1ada5c7cb86ca6 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 10 Aug 2023 10:28:13 +0200 Subject: [PATCH 1/4] Optimize driver.execute_query by pipelining `BEGIN` --- src/neo4j/_async/driver.py | 7 +++--- src/neo4j/_async/work/session.py | 20 +++++++++++++--- src/neo4j/_async/work/transaction.py | 4 +++- src/neo4j/_sync/driver.py | 7 +++--- src/neo4j/_sync/work/session.py | 20 +++++++++++++--- src/neo4j/_sync/work/transaction.py | 4 +++- src/neo4j/_util/__init__.py | 22 +++++++++++++++++ src/neo4j/_util/_context_bool.py | 36 ++++++++++++++++++++++++++++ testkitbackend/test_config.json | 1 + 9 files changed, 107 insertions(+), 14 deletions(-) create mode 100644 src/neo4j/_util/__init__.py create mode 100644 src/neo4j/_util/_context_bool.py diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 705573706..0d0af7ad0 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -860,9 +860,10 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: else: raise ValueError("Invalid routing control value: %r" % routing_) - return await executor( - _work, query_, parameters, result_transformer_ - ) + with session._pipelined_begin: + return await executor( + _work, query_, parameters, result_transformer_ + ) @property def execute_query_bookmark_manager(self) -> AsyncBookmarkManager: diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 41227ea0c..0897bcfba 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -32,6 +32,7 @@ deprecated, PreviewWarning, ) +from ..._util import ContextBool from ..._work import Query from ...api import ( Bookmarks, @@ -100,6 +101,10 @@ class AsyncSession(AsyncWorkspace): # The state this session is in. _state_failed = False + _config: SessionConfig + _bookmark_manager: t.Optional[Bookmarks] + _pipelined_begin: ContextBool + def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) if session_config.auth is not None: @@ -115,6 +120,7 @@ def __init__(self, pool, session_config): self._config = session_config self._initialize_bookmarks(session_config.bookmarks) self._bookmark_manager = session_config.bookmark_manager + self._pipelined_begin = ContextBool() async def __aenter__(self) -> AsyncSession: return self @@ -421,6 +427,7 @@ async def _open_transaction( bookmarks, access_mode, metadata, timeout, self._config.notifications_min_severity, self._config.notifications_disabled_categories, + pipelined=self._pipelined_begin ) async def begin_transaction( @@ -480,9 +487,15 @@ async def begin_transaction( return t.cast(AsyncTransaction, self._transaction) + async def _run_transaction( - self, access_mode, transaction_function, *args, **kwargs - ): + self, + access_mode: str, + transaction_function: t.Callable[ + te.Concatenate[AsyncManagedTransaction, _P], t.Awaitable[_R] + ], + *args: _P.args, **kwargs: _P.kwargs + ) -> _R: self._check_state() if not callable(transaction_function): raise TypeError("Unit of work is not callable") @@ -498,7 +511,7 @@ async def _run_transaction( errors = [] - t0 = -1 # Timer + t0: float = -1 # Timer while True: try: @@ -507,6 +520,7 @@ async def _run_transaction( access_mode=access_mode, metadata=metadata, timeout=timeout ) + assert isinstance(self._transaction, AsyncManagedTransaction) tx = self._transaction try: result = await transaction_function(tx, *args, **kwargs) diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index 91fa1bf6f..811493356 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -74,6 +74,7 @@ async def _exit(self, exception_type, exception_value, traceback): async def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout, notifications_min_severity, notifications_disabled_categories, + pipelined=False, ): self._database = database self._connection.begin( @@ -83,7 +84,8 @@ async def _begin( notifications_disabled_categories=notifications_disabled_categories ) await self._error_handling_connection.send_all() - await self._error_handling_connection.fetch_all() + if not pipelined: + await self._error_handling_connection.fetch_all() async def _result_on_closed_handler(self): pass diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 3b3d80a32..71a34554d 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -859,9 +859,10 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: else: raise ValueError("Invalid routing control value: %r" % routing_) - return executor( - _work, query_, parameters, result_transformer_ - ) + with session._pipelined_begin: + return executor( + _work, query_, parameters, result_transformer_ + ) @property def execute_query_bookmark_manager(self) -> BookmarkManager: diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 8f38739bd..b4b321f6a 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -32,6 +32,7 @@ deprecated, PreviewWarning, ) +from ..._util import ContextBool from ..._work import Query from ...api import ( Bookmarks, @@ -100,6 +101,10 @@ class Session(Workspace): # The state this session is in. _state_failed = False + _config: SessionConfig + _bookmark_manager: t.Optional[Bookmarks] + _pipelined_begin: ContextBool + def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) if session_config.auth is not None: @@ -115,6 +120,7 @@ def __init__(self, pool, session_config): self._config = session_config self._initialize_bookmarks(session_config.bookmarks) self._bookmark_manager = session_config.bookmark_manager + self._pipelined_begin = ContextBool() def __enter__(self) -> Session: return self @@ -421,6 +427,7 @@ def _open_transaction( bookmarks, access_mode, metadata, timeout, self._config.notifications_min_severity, self._config.notifications_disabled_categories, + pipelined=self._pipelined_begin ) def begin_transaction( @@ -480,9 +487,15 @@ def begin_transaction( return t.cast(Transaction, self._transaction) + def _run_transaction( - self, access_mode, transaction_function, *args, **kwargs - ): + self, + access_mode: str, + transaction_function: t.Callable[ + te.Concatenate[ManagedTransaction, _P], t.Union[_R] + ], + *args: _P.args, **kwargs: _P.kwargs + ) -> _R: self._check_state() if not callable(transaction_function): raise TypeError("Unit of work is not callable") @@ -498,7 +511,7 @@ def _run_transaction( errors = [] - t0 = -1 # Timer + t0: float = -1 # Timer while True: try: @@ -507,6 +520,7 @@ def _run_transaction( access_mode=access_mode, metadata=metadata, timeout=timeout ) + assert isinstance(self._transaction, ManagedTransaction) tx = self._transaction try: result = transaction_function(tx, *args, **kwargs) diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index 0f314c50e..1222169be 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -74,6 +74,7 @@ def _exit(self, exception_type, exception_value, traceback): def _begin( self, database, imp_user, bookmarks, access_mode, metadata, timeout, notifications_min_severity, notifications_disabled_categories, + pipelined=False, ): self._database = database self._connection.begin( @@ -83,7 +84,8 @@ def _begin( notifications_disabled_categories=notifications_disabled_categories ) self._error_handling_connection.send_all() - self._error_handling_connection.fetch_all() + if not pipelined: + self._error_handling_connection.fetch_all() def _result_on_closed_handler(self): pass diff --git a/src/neo4j/_util/__init__.py b/src/neo4j/_util/__init__.py new file mode 100644 index 000000000..1144eb930 --- /dev/null +++ b/src/neo4j/_util/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._context_bool import ContextBool + + +__all__ = ["ContextBool"] diff --git a/src/neo4j/_util/_context_bool.py b/src/neo4j/_util/_context_bool.py new file mode 100644 index 000000000..4f36e35ae --- /dev/null +++ b/src/neo4j/_util/_context_bool.py @@ -0,0 +1,36 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + + +__all__ = ["ContextBool"] + + +class ContextBool: + def __init__(self) -> None: + self._value = False + + def __bool__(self) -> bool: + return self._value + + def __enter__(self) -> None: + self._value = True + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self._value = False diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index f4000e89a..353115de0 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -65,6 +65,7 @@ "Optimization:AuthPipelining": true, "Optimization:ConnectionReuse": true, "Optimization:EagerTransactionBegin": true, + "Optimization:ExecuteQueryPipelining": true, "Optimization:ImplicitDefaultArguments": true, "Optimization:MinimalBookmarksSet": true, "Optimization:MinimalResets": true, From 57c9d5ed3487dcfa44ef80d06e230cc7946606ed Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 10 Aug 2023 12:32:01 +0200 Subject: [PATCH 2/4] Fix and extend unit tests --- tests/unit/async_/test_driver.py | 107 +++++++++------------ tests/unit/async_/work/test_transaction.py | 46 +++++++++ tests/unit/sync/test_driver.py | 107 +++++++++------------ tests/unit/sync/work/test_transaction.py | 46 +++++++++ 4 files changed, 186 insertions(+), 120 deletions(-) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index ea8026567..052c93550 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -62,6 +62,15 @@ ) +@pytest.fixture +def session_cls_mock(mocker): + session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + session_cls_mock.return_value.attach_mock(mocker.NonCallableMagicMock(), + "_pipelined_begin") + yield session_cls_mock + + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @@ -223,7 +232,8 @@ async def test_driver_opens_write_session_by_default(uri, fake_pool, mocker): mocker.ANY, mocker.ANY, mocker.ANY, - mocker.ANY + mocker.ANY, + mocker.ANY, ) await driver.close() @@ -296,12 +306,10 @@ async def test_get_server_info_parameters_are_experimental( @mark_async_test -async def test_with_builtin_bookmark_manager(mocker) -> None: +async def test_with_builtin_bookmark_manager(session_cls_mock) -> None: bmm = AsyncGraphDatabase.bookmark_manager() # could be one line, but want to make sure the type checker assigns # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) driver = AsyncGraphDatabase.driver("bolt://localhost") async with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -310,7 +318,9 @@ async def test_with_builtin_bookmark_manager(mocker) -> None: @AsyncTestDecorators.mark_async_only_test -async def test_with_custom_inherited_async_bookmark_manager(mocker) -> None: +async def test_with_custom_inherited_async_bookmark_manager( + session_cls_mock +) -> None: class BMM(AsyncBookmarkManager): async def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -325,10 +335,6 @@ async def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) driver = AsyncGraphDatabase.driver("bolt://localhost") async with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -337,7 +343,9 @@ async def forget(self, databases: t.Iterable[str]) -> None: @mark_async_test -async def test_with_custom_inherited_sync_bookmark_manager(mocker) -> None: +async def test_with_custom_inherited_sync_bookmark_manager( + session_cls_mock +) -> None: class BMM(BookmarkManager): def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -352,10 +360,6 @@ def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) driver = AsyncGraphDatabase.driver("bolt://localhost") async with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -364,7 +368,9 @@ def forget(self, databases: t.Iterable[str]) -> None: @AsyncTestDecorators.mark_async_only_test -async def test_with_custom_ducktype_async_bookmark_manager(mocker) -> None: +async def test_with_custom_ducktype_async_bookmark_manager( + session_cls_mock +) -> None: class BMM: async def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -379,10 +385,6 @@ async def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) driver = AsyncGraphDatabase.driver("bolt://localhost") async with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -391,7 +393,9 @@ async def forget(self, databases: t.Iterable[str]) -> None: @mark_async_test -async def test_with_custom_ducktype_sync_bookmark_manager(mocker) -> None: +async def test_with_custom_ducktype_sync_bookmark_manager( + session_cls_mock +) -> None: class BMM: def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -406,10 +410,6 @@ def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) driver = AsyncGraphDatabase.driver("bolt://localhost") async with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -468,7 +468,6 @@ def forget(self, databases: t.Iterable[str]) -> None: async def test_driver_factory_with_notification_filters( uri: str, mocker, - fake_pool, min_sev: t.Optional[_T_NotificationMinimumSeverity], dis_cats: t.Optional[t.Iterable[_T_NotificationDisabledCategory]], ) -> None: @@ -551,6 +550,7 @@ async def test_driver_factory_with_notification_filters( @mark_async_test async def test_session_factory_with_notification_filter( uri: str, + session_cls_mock, mocker, min_sev: t.Optional[_T_NotificationMinimumSeverity], dis_cats: t.Optional[t.Iterable[_T_NotificationDisabledCategory]], @@ -559,8 +559,6 @@ async def test_session_factory_with_notification_filter( pool_mock: t.Any = mocker.AsyncMock(spec=pool_cls) mocker.patch.object(pool_cls, "open", return_value=pool_mock) pool_mock.address = mocker.Mock() - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) async with AsyncGraphDatabase.driver(uri, auth=None) as driver: if min_sev is ...: @@ -630,11 +628,10 @@ async def test_execute_query_work(mocker) -> None: @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_query( - mocker, query: str, positional: bool + query: str, positional: bool, session_cls_mock, mocker ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) + async with driver as driver: if positional: res = await driver.execute_query(query) @@ -658,12 +655,11 @@ async def test_execute_query_query( @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_parameters( - mocker, parameters: t.Optional[t.Dict[str, t.Any]], - positional: bool + parameters: t.Optional[t.Dict[str, t.Any]], positional: bool, + session_cls_mock, mocker ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) + async with driver as driver: if parameters is Ellipsis: parameters = None @@ -690,11 +686,10 @@ async def test_execute_query_parameters( )) @mark_async_test async def test_execute_query_keyword_parameters( - mocker, parameters: t.Optional[t.Dict[str, t.Any]], + parameters: t.Optional[t.Dict[str, t.Any]], session_cls_mock, mocker ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) + async with driver as driver: if parameters is None: res = await driver.execute_query("") @@ -763,11 +758,11 @@ async def test_execute_query_parameter_precedence( kw_params: t.Dict[str, t.Any], expected_params: t.Dict[str, t.Any], positional: bool, + session_cls_mock, mocker ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) + async with driver as driver: if params is None: res = await driver.execute_query("", **kw_params) @@ -802,12 +797,11 @@ async def test_execute_query_parameter_precedence( @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_routing_control( - mocker, session_executor: str, positional: bool, - routing_mode: t.Union[neo4j.RoutingControl, te.Literal["r", "w"], None] + session_executor: str, positional: bool, + routing_mode: t.Union[neo4j.RoutingControl, te.Literal["r", "w"], None], + session_cls_mock, mocker ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) async with driver as driver: if routing_mode is None: res = await driver.execute_query("") @@ -834,11 +828,9 @@ async def test_execute_query_routing_control( @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_database( - mocker, database: t.Optional[str], positional: bool + database: t.Optional[str], positional: bool, session_cls_mock ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) async with driver as driver: if database is Ellipsis: database = None @@ -858,11 +850,9 @@ async def test_execute_query_database( @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_impersonated_user( - mocker, impersonated_user: t.Optional[str], positional: bool + impersonated_user: t.Optional[str], positional: bool, session_cls_mock ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) async with driver as driver: if impersonated_user is Ellipsis: impersonated_user = None @@ -886,12 +876,11 @@ async def test_execute_query_impersonated_user( @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_bookmark_manager( - mocker, positional: bool, - bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] + positional: bool, + bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None], + session_cls_mock ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) async with driver as driver: if bookmark_manager is Ellipsis: bookmark_manager = driver.execute_query_bookmark_manager @@ -915,12 +904,12 @@ async def test_execute_query_bookmark_manager( @pytest.mark.parametrize("positional", (True, False)) @mark_async_test async def test_execute_query_result_transformer( - mocker, positional: bool, - result_transformer: t.Callable[[AsyncResult], t.Awaitable[SomeClass]] + positional: bool, + result_transformer: t.Callable[[AsyncResult], t.Awaitable[SomeClass]], + session_cls_mock, + mocker ) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) res: t.Any async with driver as driver: if result_transformer is Ellipsis: @@ -952,10 +941,8 @@ async def test_execute_query_result_transformer( @mark_async_test -async def test_supports_session_auth(mocker) -> None: +async def test_supports_session_auth(session_cls_mock) -> None: driver = AsyncGraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._async.driver.AsyncSession", - autospec=True) async with driver as driver: res = await driver.supports_session_auth() diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 14ff98abc..53a0c4418 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -23,6 +23,7 @@ from neo4j import ( AsyncTransaction, + NotificationMinimumSeverity, Query, ) @@ -229,3 +230,48 @@ async def test_transaction_no_rollback_on_defunct_connections( async_fake_connection.is_reset_mock.assert_not_called() async_fake_connection.reset.assert_not_called() async_fake_connection.rollback.assert_not_called() + + +@pytest.mark.parametrize("pipeline", (True, False)) +@mark_async_test +async def test_transaction_begin_pipelining( + async_fake_connection, pipeline +) -> None: + tx = AsyncTransaction( + async_fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None, lambda *args, **kwargs: None + ) + database = "db" + imp_user = None + bookmarks = ["bookmark1", "bookmark2"] + access_mode = "r" + metadata = {"key": "value"} + timeout = 42 + notifications_min_severity = NotificationMinimumSeverity.INFORMATION + notifications_disabled_categories = ["cat1", "cat2"] + + await tx._begin( + database, imp_user, bookmarks, access_mode, metadata, timeout, + notifications_min_severity, notifications_disabled_categories, + pipelined=pipeline + ) + expected_calls = [ + ( + "begin", + { + "db": database, + "imp_user": imp_user, + "bookmarks": bookmarks, + "mode": access_mode, + "metadata": metadata, + "timeout": timeout, + "notifications_min_severity": notifications_min_severity, + "notifications_disabled_categories": + notifications_disabled_categories, + } + ), + ("send_all",), + ] + if not pipeline: + expected_calls.append(("fetch_all",)) + assert async_fake_connection.method_calls == expected_calls diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 7a60f0f98..47112637a 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -61,6 +61,15 @@ ) +@pytest.fixture +def session_cls_mock(mocker): + session_cls_mock = mocker.patch("neo4j._sync.driver.Session", + autospec=True) + session_cls_mock.return_value.attach_mock(mocker.NonCallableMagicMock(), + "_pipelined_begin") + yield session_cls_mock + + @pytest.mark.parametrize("protocol", ("bolt://", "bolt+s://", "bolt+ssc://")) @pytest.mark.parametrize("host", ("localhost", "127.0.0.1", "[::1]", "[0:0:0:0:0:0:0:1]")) @@ -222,7 +231,8 @@ def test_driver_opens_write_session_by_default(uri, fake_pool, mocker): mocker.ANY, mocker.ANY, mocker.ANY, - mocker.ANY + mocker.ANY, + mocker.ANY, ) driver.close() @@ -295,12 +305,10 @@ def test_get_server_info_parameters_are_experimental( @mark_sync_test -def test_with_builtin_bookmark_manager(mocker) -> None: +def test_with_builtin_bookmark_manager(session_cls_mock) -> None: bmm = GraphDatabase.bookmark_manager() # could be one line, but want to make sure the type checker assigns # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) driver = GraphDatabase.driver("bolt://localhost") with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -309,7 +317,9 @@ def test_with_builtin_bookmark_manager(mocker) -> None: @TestDecorators.mark_async_only_test -def test_with_custom_inherited_async_bookmark_manager(mocker) -> None: +def test_with_custom_inherited_async_bookmark_manager( + session_cls_mock +) -> None: class BMM(BookmarkManager): def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -324,10 +334,6 @@ def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) driver = GraphDatabase.driver("bolt://localhost") with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -336,7 +342,9 @@ def forget(self, databases: t.Iterable[str]) -> None: @mark_sync_test -def test_with_custom_inherited_sync_bookmark_manager(mocker) -> None: +def test_with_custom_inherited_sync_bookmark_manager( + session_cls_mock +) -> None: class BMM(BookmarkManager): def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -351,10 +359,6 @@ def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) driver = GraphDatabase.driver("bolt://localhost") with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -363,7 +367,9 @@ def forget(self, databases: t.Iterable[str]) -> None: @TestDecorators.mark_async_only_test -def test_with_custom_ducktype_async_bookmark_manager(mocker) -> None: +def test_with_custom_ducktype_async_bookmark_manager( + session_cls_mock +) -> None: class BMM: def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -378,10 +384,6 @@ def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) driver = GraphDatabase.driver("bolt://localhost") with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -390,7 +392,9 @@ def forget(self, databases: t.Iterable[str]) -> None: @mark_sync_test -def test_with_custom_ducktype_sync_bookmark_manager(mocker) -> None: +def test_with_custom_ducktype_sync_bookmark_manager( + session_cls_mock +) -> None: class BMM: def update_bookmarks( self, previous_bookmarks: t.Iterable[str], @@ -405,10 +409,6 @@ def forget(self, databases: t.Iterable[str]) -> None: ... bmm = BMM() - # could be one line, but want to make sure the type checker assigns - # bmm whatever type AsyncGraphDatabase.bookmark_manager() returns - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) driver = GraphDatabase.driver("bolt://localhost") with driver as driver: _ = driver.session(bookmark_manager=bmm) @@ -467,7 +467,6 @@ def forget(self, databases: t.Iterable[str]) -> None: def test_driver_factory_with_notification_filters( uri: str, mocker, - fake_pool, min_sev: t.Optional[_T_NotificationMinimumSeverity], dis_cats: t.Optional[t.Iterable[_T_NotificationDisabledCategory]], ) -> None: @@ -550,6 +549,7 @@ def test_driver_factory_with_notification_filters( @mark_sync_test def test_session_factory_with_notification_filter( uri: str, + session_cls_mock, mocker, min_sev: t.Optional[_T_NotificationMinimumSeverity], dis_cats: t.Optional[t.Iterable[_T_NotificationDisabledCategory]], @@ -558,8 +558,6 @@ def test_session_factory_with_notification_filter( pool_mock: t.Any = mocker.MagicMock(spec=pool_cls) mocker.patch.object(pool_cls, "open", return_value=pool_mock) pool_mock.address = mocker.Mock() - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) with GraphDatabase.driver(uri, auth=None) as driver: if min_sev is ...: @@ -629,11 +627,10 @@ def test_execute_query_work(mocker) -> None: @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_query( - mocker, query: str, positional: bool + query: str, positional: bool, session_cls_mock, mocker ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) + with driver as driver: if positional: res = driver.execute_query(query) @@ -657,12 +654,11 @@ def test_execute_query_query( @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_parameters( - mocker, parameters: t.Optional[t.Dict[str, t.Any]], - positional: bool + parameters: t.Optional[t.Dict[str, t.Any]], positional: bool, + session_cls_mock, mocker ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) + with driver as driver: if parameters is Ellipsis: parameters = None @@ -689,11 +685,10 @@ def test_execute_query_parameters( )) @mark_sync_test def test_execute_query_keyword_parameters( - mocker, parameters: t.Optional[t.Dict[str, t.Any]], + parameters: t.Optional[t.Dict[str, t.Any]], session_cls_mock, mocker ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) + with driver as driver: if parameters is None: res = driver.execute_query("") @@ -762,11 +757,11 @@ def test_execute_query_parameter_precedence( kw_params: t.Dict[str, t.Any], expected_params: t.Dict[str, t.Any], positional: bool, + session_cls_mock, mocker ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) + with driver as driver: if params is None: res = driver.execute_query("", **kw_params) @@ -801,12 +796,11 @@ def test_execute_query_parameter_precedence( @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_routing_control( - mocker, session_executor: str, positional: bool, - routing_mode: t.Union[neo4j.RoutingControl, te.Literal["r", "w"], None] + session_executor: str, positional: bool, + routing_mode: t.Union[neo4j.RoutingControl, te.Literal["r", "w"], None], + session_cls_mock, mocker ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) with driver as driver: if routing_mode is None: res = driver.execute_query("") @@ -833,11 +827,9 @@ def test_execute_query_routing_control( @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_database( - mocker, database: t.Optional[str], positional: bool + database: t.Optional[str], positional: bool, session_cls_mock ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) with driver as driver: if database is Ellipsis: database = None @@ -857,11 +849,9 @@ def test_execute_query_database( @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_impersonated_user( - mocker, impersonated_user: t.Optional[str], positional: bool + impersonated_user: t.Optional[str], positional: bool, session_cls_mock ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) with driver as driver: if impersonated_user is Ellipsis: impersonated_user = None @@ -885,12 +875,11 @@ def test_execute_query_impersonated_user( @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_bookmark_manager( - mocker, positional: bool, - bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] + positional: bool, + bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None], + session_cls_mock ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) with driver as driver: if bookmark_manager is Ellipsis: bookmark_manager = driver.execute_query_bookmark_manager @@ -914,12 +903,12 @@ def test_execute_query_bookmark_manager( @pytest.mark.parametrize("positional", (True, False)) @mark_sync_test def test_execute_query_result_transformer( - mocker, positional: bool, - result_transformer: t.Callable[[Result], t.Union[SomeClass]] + positional: bool, + result_transformer: t.Callable[[Result], t.Union[SomeClass]], + session_cls_mock, + mocker ) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) res: t.Any with driver as driver: if result_transformer is Ellipsis: @@ -951,10 +940,8 @@ def test_execute_query_result_transformer( @mark_sync_test -def test_supports_session_auth(mocker) -> None: +def test_supports_session_auth(session_cls_mock) -> None: driver = GraphDatabase.driver("bolt://localhost") - session_cls_mock = mocker.patch("neo4j._sync.driver.Session", - autospec=True) with driver as driver: res = driver.supports_session_auth() diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index d74ddccc1..5a262b9c4 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -22,6 +22,7 @@ import pytest from neo4j import ( + NotificationMinimumSeverity, Query, Transaction, ) @@ -229,3 +230,48 @@ def test_transaction_no_rollback_on_defunct_connections( fake_connection.is_reset_mock.assert_not_called() fake_connection.reset.assert_not_called() fake_connection.rollback.assert_not_called() + + +@pytest.mark.parametrize("pipeline", (True, False)) +@mark_sync_test +def test_transaction_begin_pipelining( + fake_connection, pipeline +) -> None: + tx = Transaction( + fake_connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None, lambda *args, **kwargs: None + ) + database = "db" + imp_user = None + bookmarks = ["bookmark1", "bookmark2"] + access_mode = "r" + metadata = {"key": "value"} + timeout = 42 + notifications_min_severity = NotificationMinimumSeverity.INFORMATION + notifications_disabled_categories = ["cat1", "cat2"] + + tx._begin( + database, imp_user, bookmarks, access_mode, metadata, timeout, + notifications_min_severity, notifications_disabled_categories, + pipelined=pipeline + ) + expected_calls = [ + ( + "begin", + { + "db": database, + "imp_user": imp_user, + "bookmarks": bookmarks, + "mode": access_mode, + "metadata": metadata, + "timeout": timeout, + "notifications_min_severity": notifications_min_severity, + "notifications_disabled_categories": + notifications_disabled_categories, + } + ), + ("send_all",), + ] + if not pipeline: + expected_calls.append(("fetch_all",)) + assert fake_connection.method_calls == expected_calls From 5b02a57e8fac656c6a5e0fd38b7b0dde9db32c5c Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 17 Aug 2023 14:58:57 +0200 Subject: [PATCH 3/4] Reduce network flushing on pipelined transactions Signed-off-by: Grant Lodge <6323995+thelonelyvulpes@users.noreply.github.com> --- src/neo4j/_async/work/transaction.py | 2 +- src/neo4j/_sync/work/transaction.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index 811493356..23c54fd21 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -83,8 +83,8 @@ async def _begin( notifications_min_severity=notifications_min_severity, notifications_disabled_categories=notifications_disabled_categories ) - await self._error_handling_connection.send_all() if not pipelined: + await self._error_handling_connection.send_all() await self._error_handling_connection.fetch_all() async def _result_on_closed_handler(self): diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index 1222169be..26884429d 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -83,8 +83,8 @@ def _begin( notifications_min_severity=notifications_min_severity, notifications_disabled_categories=notifications_disabled_categories ) - self._error_handling_connection.send_all() if not pipelined: + self._error_handling_connection.send_all() self._error_handling_connection.fetch_all() def _result_on_closed_handler(self): From 55bc8f7a518ac44cd046587ae2392fb3565d2541 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Thu, 17 Aug 2023 17:20:32 +0200 Subject: [PATCH 4/4] Adjust unit tests to deferred send Signed-off-by: Grant Lodge <6323995+thelonelyvulpes@users.noreply.github.com> --- tests/unit/async_/work/test_transaction.py | 4 ++-- tests/unit/sync/work/test_transaction.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 53a0c4418..33238c76f 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -255,7 +255,7 @@ async def test_transaction_begin_pipelining( notifications_min_severity, notifications_disabled_categories, pipelined=pipeline ) - expected_calls = [ + expected_calls: list = [ ( "begin", { @@ -270,8 +270,8 @@ async def test_transaction_begin_pipelining( notifications_disabled_categories, } ), - ("send_all",), ] if not pipeline: + expected_calls.append(("send_all",)) expected_calls.append(("fetch_all",)) assert async_fake_connection.method_calls == expected_calls diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 5a262b9c4..6bffe7846 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -255,7 +255,7 @@ def test_transaction_begin_pipelining( notifications_min_severity, notifications_disabled_categories, pipelined=pipeline ) - expected_calls = [ + expected_calls: list = [ ( "begin", { @@ -270,8 +270,8 @@ def test_transaction_begin_pipelining( notifications_disabled_categories, } ), - ("send_all",), ] if not pipeline: + expected_calls.append(("send_all",)) expected_calls.append(("fetch_all",)) assert fake_connection.method_calls == expected_calls