diff --git a/CHANGELOG.md b/CHANGELOG.md index 44d301b1..a950cccc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -187,6 +187,9 @@ See also https://github.com/neo4j/neo4j-python-driver/wiki for a full changelog. - `ResultSummary.gql_status_objects` - `neo4j.GqlStatusObject` - (`neo4j.exceptions.GqlError`, `neo4j.exceptions.GqlErrorClassification`) +- On failed liveness check (s. `liveness_check_timeout` configuration option), the driver will no longer remove the + remote from the cached routing tables, but only close the connection under test. + This aligns the driver with the other official Neo4j drivers. ## Version 5.28 diff --git a/src/neo4j/_addressing.py b/src/neo4j/_addressing.py index 3593048d..98ea0161 100644 --- a/src/neo4j/_addressing.py +++ b/src/neo4j/_addressing.py @@ -306,6 +306,9 @@ def port_number(self) -> int: pass raise error_cls(f"Unknown port value {self[1]!r}") + def __reduce__(self): + return Address, (tuple(self),) + class IPv4Address(Address): """ @@ -352,12 +355,15 @@ def _host_name(self) -> str: def _unresolved(self) -> Address: return super().__new__(Address, (self._host_name, *self[1:])) - def __new__(cls, iterable, *, host_name: str) -> ResolvedAddress: + def __new__(cls, iterable, host_name: str) -> ResolvedAddress: new = super().__new__(cls, iterable) new = t.cast(ResolvedAddress, new) new._unresolved_host_name = host_name return new + def __reduce__(self): + return ResolvedAddress, (tuple(self), self._unresolved_host_name) + class ResolvedIPv4Address(IPv4Address, ResolvedAddress): pass diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index f2e6aa55..f6f0ba45 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -126,6 +126,9 @@ class AsyncBolt: _closed = False _defunct = False + # Flag if the connection is currently performing a liveness check. + _liveness_check = False + #: The pool of which this connection is a member pool = None @@ -758,6 +761,13 @@ async def reset(self, dehydration_hooks=None, hydration_hooks=None): type understood by packstream and are free to return anything. """ + async def liveness_check(self): + self._liveness_check = True + try: + await self.reset() + finally: + self._liveness_check = False + @abc.abstractmethod def goodbye(self, dehydration_hooks=None, hydration_hooks=None): """ @@ -934,7 +944,11 @@ async def _set_defunct(self, message, error=None, silent=False): # remove the connection from the pool, nor to try to close the # connection again. await self.close() - if self.pool and not self._get_server_state_manager().failed(): + if ( + not self._liveness_check + and self.pool + and not self._get_server_state_manager().failed() + ): await self.pool.deactivate(address=self.unresolved_address) # Iterate through the outstanding responses, and if any correspond diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index d61f6f04..c0f7a7fd 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -355,7 +355,7 @@ async def health_check(connection_, deadline_): "[#%04X] _: liveness check", connection_.local_port, ) - await connection_.reset() + await connection_.liveness_check() except (OSError, ServiceUnavailable, SessionExpired): return False return True diff --git a/src/neo4j/_routing.py b/src/neo4j/_routing.py index 86f141d8..135c4c12 100644 --- a/src/neo4j/_routing.py +++ b/src/neo4j/_routing.py @@ -193,3 +193,14 @@ def update(self, new_routing_table): def servers(self): return set(self.routers) | set(self.writers) | set(self.readers) + + def __eq__(self, other): + if not isinstance(other, RoutingTable): + return NotImplemented + return ( + self.database == other.database + and self.routers == other.routers + and self.readers == other.readers + and self.writers == other.writers + and self.ttl == other.ttl + ) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 7b5b5a86..55f9fbd0 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -126,6 +126,9 @@ class Bolt: _closed = False _defunct = False + # Flag if the connection is currently performing a liveness check. + _liveness_check = False + #: The pool of which this connection is a member pool = None @@ -758,6 +761,13 @@ def reset(self, dehydration_hooks=None, hydration_hooks=None): type understood by packstream and are free to return anything. """ + def liveness_check(self): + self._liveness_check = True + try: + self.reset() + finally: + self._liveness_check = False + @abc.abstractmethod def goodbye(self, dehydration_hooks=None, hydration_hooks=None): """ @@ -934,7 +944,11 @@ def _set_defunct(self, message, error=None, silent=False): # remove the connection from the pool, nor to try to close the # connection again. self.close() - if self.pool and not self._get_server_state_manager().failed(): + if ( + not self._liveness_check + and self.pool + and not self._get_server_state_manager().failed() + ): self.pool.deactivate(address=self.unresolved_address) # Iterate through the outstanding responses, and if any correspond diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index a94af11f..a296db11 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -352,7 +352,7 @@ def health_check(connection_, deadline_): "[#%04X] _: liveness check", connection_.local_port, ) - connection_.reset() + connection_.liveness_check() except (OSError, ServiceUnavailable, SessionExpired): return False return True diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index ea6046b7..275de122 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -11,9 +11,7 @@ "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_should_echo_all_timezone_ids'": "test_subtest_skips.dt_conversion", "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id'": - "test_subtest_skips.tz_id", - "stub\\.routing\\.test_routing_v[0-9x]+\\.RoutingV[0-9x]+\\.test_should_drop_connections_failing_liveness_check": - "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83" + "test_subtest_skips.tz_id" }, "features": { "Feature:API:BookmarkManager": true, diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 22e22dfb..c5d3b29c 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs): self.attach_mock( mock.AsyncMock(spec=AsyncAuthManager), "auth_manager" ) + self.attach_mock(mock.AsyncMock(), "liveness_check") self.unresolved_address = next(iter(args), "localhost") self.callbacks = [] diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index 0596c050..90a84ae3 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -110,6 +110,10 @@ async def send_message(self, tag, *fields): self._outbox.append_message(tag, fields, None) await self._outbox.flush() + def assert_no_more_messages(self): + assert self._messages + assert not self.recv_buffer + class AsyncFakeSocketPair: def __init__(self, address, packer_cls=None, unpacker_cls=None): diff --git a/tests/unit/async_/io/test_class_bolt_any.py b/tests/unit/async_/io/test_class_bolt_any.py new file mode 100644 index 00000000..0a9e298a --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt_any.py @@ -0,0 +1,104 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +import neo4j +from neo4j._async.io._bolt3 import AsyncBolt3 +from neo4j._async.io._bolt4 import ( + AsyncBolt4x0, + AsyncBolt4x1, + AsyncBolt4x2, + AsyncBolt4x3, +) +from neo4j._async.io._bolt5 import ( + AsyncBolt5x0, + AsyncBolt5x1, + AsyncBolt5x2, + AsyncBolt5x3, + AsyncBolt5x4, + AsyncBolt5x5, + AsyncBolt5x6, + AsyncBolt5x7, + AsyncBolt5x8, +) +from neo4j.exceptions import ServiceUnavailable + +from ...._async_compat import mark_async_test + + +@pytest.fixture( + params=[ + AsyncBolt3, + AsyncBolt4x0, + AsyncBolt4x1, + AsyncBolt4x2, + AsyncBolt4x3, + AsyncBolt5x0, + AsyncBolt5x1, + AsyncBolt5x2, + AsyncBolt5x3, + AsyncBolt5x4, + AsyncBolt5x5, + AsyncBolt5x6, + AsyncBolt5x7, + AsyncBolt5x8, + ] +) +def bolt_cls(request): + return request.param + + +@mark_async_test +async def test_liveness_check_calls_reset(bolt_cls, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = bolt_cls(address, sockets.client, 0) + + await sockets.server.send_message(b"\x70", {}) + await connection.liveness_check() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x0f" + assert len(fields) == 0 + sockets.server.assert_no_more_messages() + + +@mark_async_test +async def test_failed_liveness_check_does_not_call_pool( + bolt_cls, fake_socket_pair, mocker +): + def broken_recv_into(*args, **kwargs): + raise OSError("nope") + + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=bolt_cls.PACKER_CLS, + unpacker_cls=bolt_cls.UNPACKER_CLS, + ) + connection = bolt_cls(address, sockets.client, 0) + pool_mock = mocker.AsyncMock() + connection.pool = pool_mock + sockets.client.recv_into = broken_recv_into + + with pytest.raises(ServiceUnavailable): + await connection.liveness_check() + + assert not pool_mock.method_calls diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 83a5abf0..c54258eb 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -261,7 +261,7 @@ async def test_liveness_check( else: cx1.is_idle_for.assert_not_called() await pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() # simulate after timeout cx1.is_idle_for.return_value = True @@ -271,13 +271,13 @@ async def test_liveness_check( assert cx2 is cx1 if effective_timeout is not None: cx1.is_idle_for.assert_called_once_with(effective_timeout) - cx1.reset.assert_awaited_once() + cx1.liveness_check.assert_awaited_once() else: cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() - cx1.reset.reset_mock() + cx1.liveness_check.assert_not_called() + cx1.liveness_check.reset_mock() await pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() @pytest.mark.parametrize("unprepared", (True, False, None)) diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index d10abcde..d74ab768 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -17,6 +17,7 @@ import contextlib import inspect import sys +from copy import deepcopy import pytest @@ -389,7 +390,7 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( READER1_ADDRESS, None, Deadline(30), liveness_timeout ) assert cx1.unresolved_address == READER1_ADDRESS - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) @@ -406,13 +407,13 @@ async def test_acquire_performs_liveness_check_on_existing_connection( # make sure we assume the right state assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() cx1.is_idle_for.return_value = True # release the connection await pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() # then acquire it again and assert the liveness check was performed cx2 = await pool._acquire( @@ -420,7 +421,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection( ) assert cx1 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) - cx2.reset.assert_awaited_once() + cx2.liveness_check.assert_awaited_once() @pytest.mark.parametrize( @@ -443,15 +444,15 @@ def liveness_side_effect(*args, **kwargs): # make sure we assume the right state assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() cx1.is_idle_for.return_value = True # simulate cx1 failing liveness check - cx1.reset.side_effect = liveness_side_effect + cx1.liveness_check.side_effect = liveness_side_effect # release the connection await pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() # then acquire it again and assert the liveness check was performed cx2 = await pool._acquire( @@ -460,7 +461,7 @@ def liveness_side_effect(*args, **kwargs): assert cx1 is not cx2 assert cx1.unresolved_address == cx2.unresolved_address cx1.is_idle_for.assert_called_once_with(liveness_timeout) - cx2.reset.assert_not_called() + cx2.liveness_check.assert_not_called() assert cx1 not in pool.connections[cx1.unresolved_address] assert cx2 in pool.connections[cx1.unresolved_address] @@ -491,18 +492,18 @@ def liveness_side_effect(*args, **kwargs): assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() cx1.is_idle_for.return_value = True cx2.is_idle_for.return_value = True # simulate cx1 failing liveness check - cx1.reset.side_effect = liveness_side_effect + cx1.liveness_check.side_effect = liveness_side_effect # release the connection await pool.release(cx1) await pool.release(cx2) - cx1.reset.assert_not_called() - cx2.reset.assert_not_called() + cx1.liveness_check.assert_not_called() + cx2.liveness_check.assert_not_called() # then acquire it again and assert the liveness check was performed cx3 = await pool._acquire( @@ -510,13 +511,38 @@ def liveness_side_effect(*args, **kwargs): ) assert cx3 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) - cx1.reset.assert_awaited_once() + cx1.liveness_check.assert_awaited_once() cx3.is_idle_for.assert_called_once_with(liveness_timeout) - cx3.reset.assert_awaited_once() + cx3.liveness_check.assert_awaited_once() assert cx1 not in pool.connections[cx1.unresolved_address] assert cx3 in pool.connections[cx1.unresolved_address] +@pytest.mark.parametrize( + "liveness_error", (OSError, ServiceUnavailable, SessionExpired) +) +@mark_async_test +async def test_failed_liveness_check_does_not_alter_routing_table( + opener, liveness_error +): + def liveness_side_effect(*args, **kwargs): + raise liveness_error("liveness check failed") + + liveness_timeout = 1 + pool = _simple_pool(opener) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + # simulate cx1 failing next liveness check + cx1.liveness_check.side_effect = liveness_side_effect + await pool.release(cx1) + rts = deepcopy(pool.routing_tables) + + cx2 = await pool.acquire( + READ_ACCESS, 30, TEST_DB1, None, None, liveness_timeout + ) + assert cx2 is not cx1 + assert rts == pool.routing_tables + + @mark_async_test async def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx): diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index df261886..fc60a4a2 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs): self.attach_mock( mock.MagicMock(spec=AuthManager), "auth_manager" ) + self.attach_mock(mock.MagicMock(), "liveness_check") self.unresolved_address = next(iter(args), "localhost") self.callbacks = [] diff --git a/tests/unit/sync/io/conftest.py b/tests/unit/sync/io/conftest.py index 17f99637..34c65dbe 100644 --- a/tests/unit/sync/io/conftest.py +++ b/tests/unit/sync/io/conftest.py @@ -110,6 +110,10 @@ def send_message(self, tag, *fields): self._outbox.append_message(tag, fields, None) self._outbox.flush() + def assert_no_more_messages(self): + assert self._messages + assert not self.recv_buffer + class FakeSocketPair: def __init__(self, address, packer_cls=None, unpacker_cls=None): diff --git a/tests/unit/sync/io/test_class_bolt_any.py b/tests/unit/sync/io/test_class_bolt_any.py new file mode 100644 index 00000000..5b80e3fd --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt_any.py @@ -0,0 +1,104 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +import neo4j +from neo4j._sync.io._bolt3 import Bolt3 +from neo4j._sync.io._bolt4 import ( + Bolt4x0, + Bolt4x1, + Bolt4x2, + Bolt4x3, +) +from neo4j._sync.io._bolt5 import ( + Bolt5x0, + Bolt5x1, + Bolt5x2, + Bolt5x3, + Bolt5x4, + Bolt5x5, + Bolt5x6, + Bolt5x7, + Bolt5x8, +) +from neo4j.exceptions import ServiceUnavailable + +from ...._async_compat import mark_sync_test + + +@pytest.fixture( + params=[ + Bolt3, + Bolt4x0, + Bolt4x1, + Bolt4x2, + Bolt4x3, + Bolt5x0, + Bolt5x1, + Bolt5x2, + Bolt5x3, + Bolt5x4, + Bolt5x5, + Bolt5x6, + Bolt5x7, + Bolt5x8, + ] +) +def bolt_cls(request): + return request.param + + +@mark_sync_test +def test_liveness_check_calls_reset(bolt_cls, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = bolt_cls(address, sockets.client, 0) + + sockets.server.send_message(b"\x70", {}) + connection.liveness_check() + tag, fields = sockets.server.pop_message() + assert tag == b"\x0f" + assert len(fields) == 0 + sockets.server.assert_no_more_messages() + + +@mark_sync_test +def test_failed_liveness_check_does_not_call_pool( + bolt_cls, fake_socket_pair, mocker +): + def broken_recv_into(*args, **kwargs): + raise OSError("nope") + + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=bolt_cls.PACKER_CLS, + unpacker_cls=bolt_cls.UNPACKER_CLS, + ) + connection = bolt_cls(address, sockets.client, 0) + pool_mock = mocker.MagicMock() + connection.pool = pool_mock + sockets.client.recv_into = broken_recv_into + + with pytest.raises(ServiceUnavailable): + connection.liveness_check() + + assert not pool_mock.method_calls diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 6c8934f4..f83505a2 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -261,7 +261,7 @@ def test_liveness_check( else: cx1.is_idle_for.assert_not_called() pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() # simulate after timeout cx1.is_idle_for.return_value = True @@ -271,13 +271,13 @@ def test_liveness_check( assert cx2 is cx1 if effective_timeout is not None: cx1.is_idle_for.assert_called_once_with(effective_timeout) - cx1.reset.assert_called_once() + cx1.liveness_check.assert_called_once() else: cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() - cx1.reset.reset_mock() + cx1.liveness_check.assert_not_called() + cx1.liveness_check.reset_mock() pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() @pytest.mark.parametrize("unprepared", (True, False, None)) diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 70c8401b..d2e45567 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -17,6 +17,7 @@ import contextlib import inspect import sys +from copy import deepcopy import pytest @@ -389,7 +390,7 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection( READER1_ADDRESS, None, Deadline(30), liveness_timeout ) assert cx1.unresolved_address == READER1_ADDRESS - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() @pytest.mark.parametrize("liveness_timeout", (0, 1, 2)) @@ -406,13 +407,13 @@ def test_acquire_performs_liveness_check_on_existing_connection( # make sure we assume the right state assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() cx1.is_idle_for.return_value = True # release the connection pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() # then acquire it again and assert the liveness check was performed cx2 = pool._acquire( @@ -420,7 +421,7 @@ def test_acquire_performs_liveness_check_on_existing_connection( ) assert cx1 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) - cx2.reset.assert_called_once() + cx2.liveness_check.assert_called_once() @pytest.mark.parametrize( @@ -443,15 +444,15 @@ def liveness_side_effect(*args, **kwargs): # make sure we assume the right state assert cx1.unresolved_address == READER1_ADDRESS cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() cx1.is_idle_for.return_value = True # simulate cx1 failing liveness check - cx1.reset.side_effect = liveness_side_effect + cx1.liveness_check.side_effect = liveness_side_effect # release the connection pool.release(cx1) - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() # then acquire it again and assert the liveness check was performed cx2 = pool._acquire( @@ -460,7 +461,7 @@ def liveness_side_effect(*args, **kwargs): assert cx1 is not cx2 assert cx1.unresolved_address == cx2.unresolved_address cx1.is_idle_for.assert_called_once_with(liveness_timeout) - cx2.reset.assert_not_called() + cx2.liveness_check.assert_not_called() assert cx1 not in pool.connections[cx1.unresolved_address] assert cx2 in pool.connections[cx1.unresolved_address] @@ -491,18 +492,18 @@ def liveness_side_effect(*args, **kwargs): assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() + cx1.liveness_check.assert_not_called() cx1.is_idle_for.return_value = True cx2.is_idle_for.return_value = True # simulate cx1 failing liveness check - cx1.reset.side_effect = liveness_side_effect + cx1.liveness_check.side_effect = liveness_side_effect # release the connection pool.release(cx1) pool.release(cx2) - cx1.reset.assert_not_called() - cx2.reset.assert_not_called() + cx1.liveness_check.assert_not_called() + cx2.liveness_check.assert_not_called() # then acquire it again and assert the liveness check was performed cx3 = pool._acquire( @@ -510,13 +511,38 @@ def liveness_side_effect(*args, **kwargs): ) assert cx3 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) - cx1.reset.assert_called_once() + cx1.liveness_check.assert_called_once() cx3.is_idle_for.assert_called_once_with(liveness_timeout) - cx3.reset.assert_called_once() + cx3.liveness_check.assert_called_once() assert cx1 not in pool.connections[cx1.unresolved_address] assert cx3 in pool.connections[cx1.unresolved_address] +@pytest.mark.parametrize( + "liveness_error", (OSError, ServiceUnavailable, SessionExpired) +) +@mark_sync_test +def test_failed_liveness_check_does_not_alter_routing_table( + opener, liveness_error +): + def liveness_side_effect(*args, **kwargs): + raise liveness_error("liveness check failed") + + liveness_timeout = 1 + pool = _simple_pool(opener) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + # simulate cx1 failing next liveness check + cx1.liveness_check.side_effect = liveness_side_effect + pool.release(cx1) + rts = deepcopy(pool.routing_tables) + + cx2 = pool.acquire( + READ_ACCESS, 30, TEST_DB1, None, None, liveness_timeout + ) + assert cx2 is not cx1 + assert rts == pool.routing_tables + + @mark_sync_test def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx):