From 9af317335d548d080624caa91b3851ee63dffd53 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 29 Nov 2022 14:37:22 +0100 Subject: [PATCH 1/2] Improve error reporting on routing discovery Routing drivers (`neo4j[+s[sc]]://` scheme) retry fetching a routing table on many different errors that are considered not retryable in the context of transactions. This is to overall improve the driver's stability when connecting to clusters. However, this poses the risk of hiding user-input errors (e.g., selecting a database name that is invalid or doesn't exist). Hence, the driver blacklists a handful of selected error codes upon which the discovery process is terminated prematurely, raising the raw error to the user. We expand the list to include more errors: * `Neo.ClientError.Statement.TypeError`, e.g., when trying to impersonate an integer. * `Neo.ClientError.Statement.ArgumentError`, e.g., when trying to impersonate without the required permissions. * `Neo.ClientError.Request.Invalid`, e.g., when trying to select an integer database. --- src/neo4j/exceptions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index c49a7cdf6..c3ca5b5e6 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -240,7 +240,10 @@ def is_fatal_during_discovery(self) -> bool: return False if self.code in ("Neo.ClientError.Database.DatabaseNotFound", "Neo.ClientError.Transaction.InvalidBookmark", - "Neo.ClientError.Transaction.InvalidBookmarkMixture"): + "Neo.ClientError.Transaction.InvalidBookmarkMixture", + "Neo.ClientError.Statement.TypeError", + "Neo.ClientError.Statement.ArgumentError", + "Neo.ClientError.Request.Invalid"): return True if (self.code.startswith("Neo.ClientError.Security.") and self.code != "Neo.ClientError.Security." From 009cbde1ea0e21d4f3b75718fe432e6bb6ee944c Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 30 Nov 2022 10:47:02 +0100 Subject: [PATCH 2/2] Add tests for discovery retries and fast failure --- CHANGELOG.md | 6 + src/neo4j/__init__.py | 2 +- src/neo4j/_async/driver.py | 4 +- src/neo4j/_async/io/_bolt3.py | 2 +- src/neo4j/_async/io/_bolt4.py | 2 +- src/neo4j/_async/io/_bolt5.py | 2 +- src/neo4j/_async/io/_pool.py | 2 +- src/neo4j/_async/work/result.py | 2 +- src/neo4j/_async/work/session.py | 4 +- src/neo4j/_sync/driver.py | 4 +- src/neo4j/_sync/io/_bolt3.py | 2 +- src/neo4j/_sync/io/_bolt4.py | 2 +- src/neo4j/_sync/io/_bolt5.py | 2 +- src/neo4j/_sync/io/_pool.py | 2 +- src/neo4j/_sync/work/result.py | 2 +- src/neo4j/_sync/work/session.py | 4 +- src/neo4j/api.py | 2 +- src/neo4j/exceptions.py | 20 ++- src/neo4j/spatial/__init__.py | 6 +- tests/unit/async_/io/test_neo4j_pool.py | 184 ++++++++++++++++++------ tests/unit/sync/io/test_neo4j_pool.py | 184 ++++++++++++++++++------ 21 files changed, 328 insertions(+), 112 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee952fcba..f6ea9be2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for more details. +## Version 5.4 +- Undocumented helper methods `Neo4jError.is_fatal_during_discovery` and + `Neo4jError.invalidates_all_connections` have been deprecated and will be + removed without replacement in version 6.0. 
+ + ## Version 5.3 - Python 3.11 support added - Removed undocumented, unused `neo4j.data.map_type` diff --git a/src/neo4j/__init__.py b/src/neo4j/__init__.py index f6a1a7ad9..3483e537f 100644 --- a/src/neo4j/__init__.py +++ b/src/neo4j/__init__.py @@ -150,7 +150,7 @@ def __getattr__(name): - # TODO 6.0 - remove this + # TODO: 6.0 - remove this if name in ( "log", "Config", "PoolConfig", "SessionConfig", "WorkspaceConfig" ): diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 06f116a24..953ec7d20 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -131,7 +131,7 @@ def driver(cls, uri, *, auth=None, **config) -> AsyncDriver: driver_type, security_type, parsed = parse_neo4j_uri(uri) - # TODO: 6.0 remove "trust" config option + # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): if config["trust"] not in ( TRUST_ALL_CERTIFICATES, @@ -166,7 +166,7 @@ def driver(cls, uri, *, auth=None, **config) -> AsyncDriver: or "ssl_context" in config.keys())): from ..exceptions import ConfigurationError - # TODO: 6.0 remove "trust" from error message + # TODO: 6.0 - remove "trust" from error message raise ConfigurationError( 'The config settings "encrypted", "trust", ' '"trusted_certificates", and "ssl_context" can only be ' diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 2662f5267..2f81f5261 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -377,7 +377,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e.invalidates_all_connections(): + if self.pool and e._invalidates_all_connections(): await self.pool.mark_all_stale() raise else: diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 1bd9457ba..70c3d6ad1 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -333,7 +333,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e.invalidates_all_connections(): + if self.pool and e._invalidates_all_connections(): await self.pool.mark_all_stale() raise else: diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 01d81ce9b..24953bb5b 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -324,7 +324,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e.invalidates_all_connections(): + if self.pool and e._invalidates_all_connections(): await self.pool.mark_all_stale() raise else: diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index deaa27c37..d1c0f7c9d 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -538,7 +538,7 @@ async def fetch_routing_table( # checks if the code is an error that is caused by the client. In # this case there is no sense in trying to fetch a RT from another # router. Hence, the driver should fail fast during discovery. 
- if e.is_fatal_during_discovery(): + if e._is_fatal_during_discovery(): raise except (ServiceUnavailable, SessionExpired): pass diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index ffcf18679..1cbb9270b 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -611,7 +611,7 @@ async def to_df( :: - res = await tx.run("UNWIND range(1, 10) AS n RETURN n, n+1 as m") + res = await tx.run("UNWIND range(1, 10) AS n RETURN n, n+1 AS m") df = await res.to_df() for instance will return a DataFrame with two columns: ``n`` and ``m`` diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index c4c706acd..e7e8820b1 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -591,7 +591,7 @@ async def get_two_tx(tx): READ_ACCESS, transaction_function, *args, **kwargs ) - # TODO 6.0: Remove this method + # TODO: 6.0 - Remove this method @deprecated("read_transaction has been renamed to execute_read") async def read_transaction( self, @@ -673,7 +673,7 @@ async def create_node_tx(tx, name): WRITE_ACCESS, transaction_function, *args, **kwargs ) - # TODO 6.0: Remove this method + # TODO: 6.0 - Remove this method @deprecated("write_transaction has been renamed to execute_write") async def write_transaction( self, diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 191922f95..22c4d11ed 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -128,7 +128,7 @@ def driver(cls, uri, *, auth=None, **config) -> Driver: driver_type, security_type, parsed = parse_neo4j_uri(uri) - # TODO: 6.0 remove "trust" config option + # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): if config["trust"] not in ( TRUST_ALL_CERTIFICATES, @@ -163,7 +163,7 @@ def driver(cls, uri, *, auth=None, **config) -> Driver: or "ssl_context" in config.keys())): from ..exceptions import ConfigurationError - # TODO: 6.0 remove "trust" from error message + # TODO: 6.0 - remove "trust" from error message raise ConfigurationError( 'The config settings "encrypted", "trust", ' '"trusted_certificates", and "ssl_context" can only be ' diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 5d5c250b1..249f3ecfc 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -377,7 +377,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e.invalidates_all_connections(): + if self.pool and e._invalidates_all_connections(): self.pool.mark_all_stale() raise else: diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 48017a6dd..3ae411388 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -333,7 +333,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e.invalidates_all_connections(): + if self.pool and e._invalidates_all_connections(): self.pool.mark_all_stale() raise else: diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index fd9e754be..1be6de95d 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -324,7 +324,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e.invalidates_all_connections(): + if self.pool and e._invalidates_all_connections(): 
self.pool.mark_all_stale() raise else: diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index bec22e370..0d0faad17 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -538,7 +538,7 @@ def fetch_routing_table( # checks if the code is an error that is caused by the client. In # this case there is no sense in trying to fetch a RT from another # router. Hence, the driver should fail fast during discovery. - if e.is_fatal_during_discovery(): + if e._is_fatal_during_discovery(): raise except (ServiceUnavailable, SessionExpired): pass diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 0f467d6b2..1077116df 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -611,7 +611,7 @@ def to_df( :: - res = tx.run("UNWIND range(1, 10) AS n RETURN n, n+1 as m") + res = tx.run("UNWIND range(1, 10) AS n RETURN n, n+1 AS m") df = res.to_df() for instance will return a DataFrame with two columns: ``n`` and ``m`` diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 3fa1c4e2a..41da7973a 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -591,7 +591,7 @@ def get_two_tx(tx): READ_ACCESS, transaction_function, *args, **kwargs ) - # TODO 6.0: Remove this method + # TODO: 6.0 - Remove this method @deprecated("read_transaction has been renamed to execute_read") def read_transaction( self, @@ -673,7 +673,7 @@ def create_node_tx(tx, name): WRITE_ACCESS, transaction_function, *args, **kwargs ) - # TODO 6.0: Remove this method + # TODO: 6.0 - Remove this method @deprecated("write_transaction has been renamed to execute_write") def write_transaction( self, diff --git a/src/neo4j/api.py b/src/neo4j/api.py index b1435b54a..e91aed6e7 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -175,7 +175,7 @@ def custom_auth( return Auth(scheme, principal, credentials, realm, **parameters) -# TODO 6.0 - remove this class +# TODO: 6.0 - remove this class class Bookmark: """A Bookmark object contains an immutable list of bookmark string values. diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index c3ca5b5e6..a0ce9ed08 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -199,7 +199,7 @@ def _extract_error_class(cls, classification, code): else: return cls - # TODO 6.0: Remove this alias + # TODO: 6.0 - Remove this alias @deprecated( "Neo4jError.is_retriable is deprecated and will be removed in a " "future version. Please use Neo4jError.is_retryable instead." @@ -230,10 +230,17 @@ def is_retryable(self) -> bool: """ return False - def invalidates_all_connections(self): + def _invalidates_all_connections(self) -> bool: return self.code == "Neo.ClientError.Security.AuthorizationExpired" - def is_fatal_during_discovery(self) -> bool: + # TODO: 6.0 - Remove this alias + invalidates_all_connections = deprecated( + "Neo4jError.invalidates_all_connections is deprecated and will be " + "removed in a future version. It is an internal method and not meant " + "for external use." + )(_invalidates_all_connections) + + def _is_fatal_during_discovery(self) -> bool: # checks if the code is an error that is caused by the client. In this # case the driver should fail fast during discovery. 
if not isinstance(self.code, str): @@ -251,6 +258,13 @@ def is_fatal_during_discovery(self) -> bool: return True return False + # TODO: 6.0 - Remove this alias + is_fatal_during_discovery = deprecated( + "Neo4jError.is_fatal_during_discovery is deprecated and will be " + "removed in a future version. It is an internal method and not meant " + "for external use." + )(_is_fatal_during_discovery) + def __str__(self): if self.code or self.message: return "{{code: {code}}} {{message: {message}}}".format( diff --git a/src/neo4j/spatial/__init__.py b/src/neo4j/spatial/__init__.py index f5ec1478e..2659a540e 100644 --- a/src/neo4j/spatial/__init__.py +++ b/src/neo4j/spatial/__init__.py @@ -42,7 +42,7 @@ ) -# TODO: 6.0 remove +# TODO: 6.0 - remove @deprecated( "hydrate_point is considered an internal function and will be removed in " "a future version" @@ -56,7 +56,7 @@ def hydrate_point(srid, *coordinates): return _hydration.hydrate_point(srid, *coordinates) -# TODO: 6.0 remove +# TODO: 6.0 - remove @deprecated( "hydrate_point is considered an internal function and will be removed in " "a future version" @@ -72,7 +72,7 @@ def dehydrate_point(value): return _hydration.dehydrate_point(value) -# TODO: 6.0 remove +# TODO: 6.0 - remove @deprecated( "point_type is considered an internal function and will be removed in " "a future version" diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index e4dec78ba..3bf510dfe 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -34,6 +34,7 @@ from neo4j._deadline import Deadline from neo4j.addressing import ResolvedAddress from neo4j.exceptions import ( + Neo4jError, ServiceUnavailable, SessionExpired, ) @@ -41,40 +42,62 @@ from ...._async_compat import mark_async_test -ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") -READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") -WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") +ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host") +ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") +WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") @pytest.fixture -def opener(async_fake_connection_generator, mocker): - async def open_(addr, timeout): - connection = async_fake_connection_generator() - connection.addr = addr - connection.timeout = timeout - route_mock = mocker.AsyncMock() - route_mock.return_value = [{ - "ttl": 1000, - "servers": [ - {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, - {"addresses": [str(READER_ADDRESS)], "role": "READ"}, - {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, - ], - }] - connection.attach_mock(route_mock, "route") - opener_.connections.append(connection) - return connection - - opener_ = mocker.AsyncMock() - opener_.connections = [] - opener_.side_effect = open_ - return opener_ +def routing_failure_opener(async_fake_connection_generator, mocker): + def make_opener(failures=None): + def routing_side_effect(*args, **kwargs): + nonlocal failures + res = next(failures, None) + if res is None: + return [{ + "ttl": 1000, + "servers": [ + {"addresses": [str(ROUTER1_ADDRESS), + str(ROUTER2_ADDRESS), + str(ROUTER3_ADDRESS)], + "role": "ROUTE"}, + {"addresses": [str(READER_ADDRESS)], "role": "READ"}, + {"addresses": [str(WRITER_ADDRESS)], "role": 
"WRITE"}, + ], + }] + raise res + + async def open_(addr, timeout): + connection = async_fake_connection_generator() + connection.addr = addr + connection.timeout = timeout + route_mock = mocker.AsyncMock() + + route_mock.side_effect = routing_side_effect + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + return connection + + failures = iter(failures or []) + opener_ = mocker.AsyncMock() + opener_.connections = [] + opener_.side_effect = open_ + return opener_ + + return make_opener + + +@pytest.fixture +def opener(routing_failure_opener): + return routing_failure_opener() @mark_async_test async def test_acquires_new_routing_table_if_deleted(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) await pool.release(cx) @@ -90,7 +113,7 @@ async def test_acquires_new_routing_table_if_deleted(opener): @mark_async_test async def test_acquires_new_routing_table_if_stale(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) await pool.release(cx) @@ -107,7 +130,7 @@ async def test_acquires_new_routing_table_if_stale(opener): @mark_async_test async def test_removes_old_routing_table(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None) await pool.release(cx) @@ -131,7 +154,7 @@ async def test_removes_old_routing_table(opener): @mark_async_test async def test_chooses_right_connection_type(opener, type_): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, @@ -147,7 +170,7 @@ async def test_chooses_right_connection_type(opener, type_): @mark_async_test async def test_reuses_connection(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) await pool.release(cx1) @@ -167,7 +190,7 @@ async def break_connection(): return await res pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) await pool.release(cx1) @@ -194,7 +217,7 @@ async def break_connection(): @mark_async_test async def test_does_not_close_stale_connections_in_use(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) assert cx1 in pool.connections[cx1.addr] @@ -225,7 +248,7 @@ async def test_does_not_close_stale_connections_in_use(opener): @mark_async_test async def test_release_resets_connections(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx1.is_reset_mock.return_value = False @@ -238,7 +261,7 @@ async def 
test_release_resets_connections(opener): @mark_async_test async def test_release_does_not_resets_closed_connections(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx1.closed.return_value = True @@ -253,7 +276,7 @@ async def test_release_does_not_resets_closed_connections(opener): @mark_async_test async def test_release_does_not_resets_defunct_connections(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx1.defunct.return_value = True @@ -271,7 +294,7 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( opener, liveness_timeout ): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) assert cx1.addr == READER_ADDRESS @@ -284,7 +307,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection( opener, liveness_timeout ): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) @@ -318,7 +341,7 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) @@ -357,7 +380,7 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) @@ -406,7 +429,7 @@ async def close_side_effect(): # create pool with 2 idle connections pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) @@ -430,7 +453,7 @@ async def close_side_effect(): @mark_async_test async def test_failing_opener_leaves_connections_in_use_alone(opener): pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) @@ -445,7 +468,7 @@ async def test__acquire_new_later_with_room(opener): config = PoolConfig() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) @@ -460,7 +483,7 @@ async def test__acquire_new_later_without_room(opener): config = PoolConfig() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) _ = await 
pool.acquire(READ_ACCESS, 30, "test_db", None, None) # pool is full now @@ -468,3 +491,78 @@ async def test__acquire_new_later_without_room(opener): creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None + + +@pytest.mark.parametrize("error", ( + ServiceUnavailable(), + Neo4jError.hydrate("message", "Neo.ClientError.Statement.EntityNotFound"), + Neo4jError.hydrate("message", + "Neo.ClientError.Security.AuthorizationExpired"), +)) +@mark_async_test +async def test_discovery_is_retried(routing_failure_opener, error): + opener = routing_failure_opener([ + None, # first call to router for seeding the RT with more routers + error, # will be retried + ]) + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), + ResolvedAddress(("1.2.3.1", 9999), host_name="host") + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + await pool.release(cx1) + pool.routing_tables.get("test_db").ttl = 0 + + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + await pool.release(cx2) + assert pool.routing_tables.get("test_db") + + assert cx1 is cx2 + + # initial router + # reader + # failed router + # successful router + # same reader again + assert len(opener.connections) == 4 + + +@pytest.mark.parametrize("error", map( + lambda args: Neo4jError.hydrate(*args), ( + ("message", "Neo.ClientError.Database.DatabaseNotFound"), + ("message", "Neo.ClientError.Transaction.InvalidBookmark"), + ("message", "Neo.ClientError.Transaction.InvalidBookmarkMixture"), + ("message", "Neo.ClientError.Statement.TypeError"), + ("message", "Neo.ClientError.Statement.ArgumentError"), + ("message", "Neo.ClientError.Request.Invalid"), + ("message", "Neo.ClientError.Security.AuthenticationRateLimit"), + ("message", "Neo.ClientError.Security.CredentialsExpired"), + ("message", "Neo.ClientError.Security.Forbidden"), + ("message", "Neo.ClientError.Security.TokenExpired"), + ("message", "Neo.ClientError.Security.Unauthorized"), + ("message", "Neo.ClientError.Security.MadeUpError"), + ) +)) +@mark_async_test +async def test_fast_failing_discovery(routing_failure_opener, error): + opener = routing_failure_opener([ + None, # first call to router for seeding the RT with more routers + error, # will be retried + ]) + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), + ResolvedAddress(("1.2.3.1", 9999), host_name="host") + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + await pool.release(cx1) + pool.routing_tables.get("test_db").ttl = 0 + + with pytest.raises(error.__class__) as exc: + await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + + assert exc.value is error + + # initial router + # reader + # failed router + assert len(opener.connections) == 3 diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 17507c716..0c13a78d5 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -34,6 +34,7 @@ from neo4j._sync.io import Neo4jPool from neo4j.addressing import ResolvedAddress from neo4j.exceptions import ( + Neo4jError, ServiceUnavailable, SessionExpired, ) @@ -41,40 +42,62 @@ from ...._async_compat import mark_sync_test -ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") -READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") -WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") +ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), 
host_name="host") +ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host") +WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") @pytest.fixture -def opener(fake_connection_generator, mocker): - def open_(addr, timeout): - connection = fake_connection_generator() - connection.addr = addr - connection.timeout = timeout - route_mock = mocker.Mock() - route_mock.return_value = [{ - "ttl": 1000, - "servers": [ - {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, - {"addresses": [str(READER_ADDRESS)], "role": "READ"}, - {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, - ], - }] - connection.attach_mock(route_mock, "route") - opener_.connections.append(connection) - return connection - - opener_ = mocker.Mock() - opener_.connections = [] - opener_.side_effect = open_ - return opener_ +def routing_failure_opener(fake_connection_generator, mocker): + def make_opener(failures=None): + def routing_side_effect(*args, **kwargs): + nonlocal failures + res = next(failures, None) + if res is None: + return [{ + "ttl": 1000, + "servers": [ + {"addresses": [str(ROUTER1_ADDRESS), + str(ROUTER2_ADDRESS), + str(ROUTER3_ADDRESS)], + "role": "ROUTE"}, + {"addresses": [str(READER_ADDRESS)], "role": "READ"}, + {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, + ], + }] + raise res + + def open_(addr, timeout): + connection = fake_connection_generator() + connection.addr = addr + connection.timeout = timeout + route_mock = mocker.Mock() + + route_mock.side_effect = routing_side_effect + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + return connection + + failures = iter(failures or []) + opener_ = mocker.Mock() + opener_.connections = [] + opener_.side_effect = open_ + return opener_ + + return make_opener + + +@pytest.fixture +def opener(routing_failure_opener): + return routing_failure_opener() @mark_sync_test def test_acquires_new_routing_table_if_deleted(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) pool.release(cx) @@ -90,7 +113,7 @@ def test_acquires_new_routing_table_if_deleted(opener): @mark_sync_test def test_acquires_new_routing_table_if_stale(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) pool.release(cx) @@ -107,7 +130,7 @@ def test_acquires_new_routing_table_if_stale(opener): @mark_sync_test def test_removes_old_routing_table(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None) pool.release(cx) @@ -131,7 +154,7 @@ def test_removes_old_routing_table(opener): @mark_sync_test def test_chooses_right_connection_type(opener, type_): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, @@ -147,7 +170,7 @@ def test_chooses_right_connection_type(opener, type_): @mark_sync_test def test_reuses_connection(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), 
ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) pool.release(cx1) @@ -167,7 +190,7 @@ def break_connection(): return res pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) pool.release(cx1) @@ -194,7 +217,7 @@ def break_connection(): @mark_sync_test def test_does_not_close_stale_connections_in_use(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) assert cx1 in pool.connections[cx1.addr] @@ -225,7 +248,7 @@ def test_does_not_close_stale_connections_in_use(opener): @mark_sync_test def test_release_resets_connections(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx1.is_reset_mock.return_value = False @@ -238,7 +261,7 @@ def test_release_resets_connections(opener): @mark_sync_test def test_release_does_not_resets_closed_connections(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx1.closed.return_value = True @@ -253,7 +276,7 @@ def test_release_does_not_resets_closed_connections(opener): @mark_sync_test def test_release_does_not_resets_defunct_connections(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx1.defunct.return_value = True @@ -271,7 +294,7 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection( opener, liveness_timeout ): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) assert cx1.addr == READER_ADDRESS @@ -284,7 +307,7 @@ def test_acquire_performs_liveness_check_on_existing_connection( opener, liveness_timeout ): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) @@ -318,7 +341,7 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) @@ -357,7 +380,7 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout = 1 pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) @@ -406,7 +429,7 @@ def close_side_effect(): # create pool with 2 idle connections pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) cx2 = 
pool.acquire(READ_ACCESS, 30, "test_db", None, None) @@ -430,7 +453,7 @@ def close_side_effect(): @mark_sync_test def test_failing_opener_leaves_connections_in_use_alone(opener): pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) @@ -445,7 +468,7 @@ def test__acquire_new_later_with_room(opener): config = PoolConfig() config.max_connection_pool_size = 1 pool = Neo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) @@ -460,7 +483,7 @@ def test__acquire_new_later_without_room(opener): config = PoolConfig() config.max_connection_pool_size = 1 pool = Neo4jPool( - opener, config, WorkspaceConfig(), ROUTER_ADDRESS + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None) # pool is full now @@ -468,3 +491,78 @@ def test__acquire_new_later_without_room(opener): creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None + + +@pytest.mark.parametrize("error", ( + ServiceUnavailable(), + Neo4jError.hydrate("message", "Neo.ClientError.Statement.EntityNotFound"), + Neo4jError.hydrate("message", + "Neo.ClientError.Security.AuthorizationExpired"), +)) +@mark_sync_test +def test_discovery_is_retried(routing_failure_opener, error): + opener = routing_failure_opener([ + None, # first call to router for seeding the RT with more routers + error, # will be retried + ]) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), + ResolvedAddress(("1.2.3.1", 9999), host_name="host") + ) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool.release(cx1) + pool.routing_tables.get("test_db").ttl = 0 + + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool.release(cx2) + assert pool.routing_tables.get("test_db") + + assert cx1 is cx2 + + # initial router + # reader + # failed router + # successful router + # same reader again + assert len(opener.connections) == 4 + + +@pytest.mark.parametrize("error", map( + lambda args: Neo4jError.hydrate(*args), ( + ("message", "Neo.ClientError.Database.DatabaseNotFound"), + ("message", "Neo.ClientError.Transaction.InvalidBookmark"), + ("message", "Neo.ClientError.Transaction.InvalidBookmarkMixture"), + ("message", "Neo.ClientError.Statement.TypeError"), + ("message", "Neo.ClientError.Statement.ArgumentError"), + ("message", "Neo.ClientError.Request.Invalid"), + ("message", "Neo.ClientError.Security.AuthenticationRateLimit"), + ("message", "Neo.ClientError.Security.CredentialsExpired"), + ("message", "Neo.ClientError.Security.Forbidden"), + ("message", "Neo.ClientError.Security.TokenExpired"), + ("message", "Neo.ClientError.Security.Unauthorized"), + ("message", "Neo.ClientError.Security.MadeUpError"), + ) +)) +@mark_sync_test +def test_fast_failing_discovery(routing_failure_opener, error): + opener = routing_failure_opener([ + None, # first call to router for seeding the RT with more routers + error, # will be retried + ]) + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), + ResolvedAddress(("1.2.3.1", 9999), host_name="host") + ) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool.release(cx1) + pool.routing_tables.get("test_db").ttl = 0 + + with 
pytest.raises(error.__class__) as exc: + pool.acquire(READ_ACCESS, 30, "test_db", None, None) + + assert exc.value is error + + # initial router + # reader + # failed router + assert len(opener.connections) == 3
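
Note (not part of the patch): a minimal sketch of what the expanded fast-fail list behind `Neo4jError._is_fatal_during_discovery` means in practice, assuming the patched driver is installed. Both `Neo4jError.hydrate` and the helper are internal APIs; they are used here only to make the classification concrete, mirroring what the new unit tests exercise.

    from neo4j.exceptions import Neo4jError

    for code in (
        # fatal during discovery: the raw error is raised to the caller
        "Neo.ClientError.Statement.TypeError",
        "Neo.ClientError.Statement.ArgumentError",
        "Neo.ClientError.Request.Invalid",
        # not fatal: the driver keeps trying other routers
        "Neo.ClientError.Statement.EntityNotFound",
    ):
        error = Neo4jError.hydrate("message", code)
        print(code, error._is_fatal_during_discovery())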
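
A second, user-facing sketch (also not part of the patch): the impersonation example from the first commit message, seen from application code. The URI and credentials are placeholders for a real cluster; the point is that the server's `Neo.ClientError.Statement.TypeError` now surfaces immediately instead of being swallowed by routing-discovery retries until the acquisition timeout expires.

    import neo4j
    from neo4j.exceptions import ClientError

    with neo4j.GraphDatabase.driver(
        "neo4j://localhost:7687",    # placeholder cluster address
        auth=("neo4j", "password"),  # placeholder credentials
    ) as driver:
        try:
            # impersonated_user should be a str; an int makes the ROUTE
            # request fail while the routing table is being fetched
            with driver.session(impersonated_user=1234) as session:
                session.run("RETURN 1").consume()
        except ClientError as error:
            print(error.code)  # Neo.ClientError.Statement.TypeError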