From 6602d64633f38518c3cadae29f02e18360918a77 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 1 Aug 2023 10:42:33 +0200 Subject: [PATCH] Fix missing lock acquisition in pool Fix some functions in the pool manipulating the routing information without holding the right lock. --- src/neo4j/_async/io/_bolt3.py | 4 +++- src/neo4j/_async/io/_bolt4.py | 4 +++- src/neo4j/_async/io/_bolt5.py | 4 +++- src/neo4j/_async/io/_pool.py | 18 ++++++++++-------- src/neo4j/_sync/io/_bolt3.py | 4 +++- src/neo4j/_sync/io/_bolt4.py | 4 +++- src/neo4j/_sync/io/_bolt5.py | 4 +++- src/neo4j/_sync/io/_pool.py | 14 ++++++++------ 8 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index f12a887f0..3e948b02f 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -407,7 +407,9 @@ async def _process_message(self, tag, fields): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address) + await self.pool.on_write_failure( + address=self.unresolved_address + ) raise except Neo4jError as e: await self.pool.on_neo4j_error(e, self) diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 12f7e1ae0..f4f433787 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -356,7 +356,9 @@ async def _process_message(self, tag, fields): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address) + await self.pool.on_write_failure( + address=self.unresolved_address + ) raise except Neo4jError as e: if self.pool: diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 5d80bc7d0..89a06f99a 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -356,7 +356,9 @@ async def _process_message(self, tag, fields): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address) + await self.pool.on_write_failure( + address=self.unresolved_address + ) raise except Neo4jError as e: if self.pool: diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 5f8468c45..e854fe987 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -451,7 +451,7 @@ async def deactivate(self, address): await self._close_connections(closable_connections) - def on_write_failure(self, address): + async def on_write_failure(self, address): raise WriteServiceUnavailable( "No write service available for pool {}".format(self) ) @@ -949,17 +949,19 @@ async def deactivate(self, address): log.debug("[#0000] _: deactivating address %r", address) # We use `discard` instead of `remove` here since the former # will not fail if the address has already been removed. - for database in self.routing_tables.keys(): - self.routing_tables[database].routers.discard(address) - self.routing_tables[database].readers.discard(address) - self.routing_tables[database].writers.discard(address) + async with self.refresh_lock: + for database in self.routing_tables.keys(): + self.routing_tables[database].routers.discard(address) + self.routing_tables[database].readers.discard(address) + self.routing_tables[database].writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) await super(AsyncNeo4jPool, self).deactivate(address) - def on_write_failure(self, address): + async def on_write_failure(self, address): """ Remove a writer address from the routing table, if present. """ log.debug("[#0000] _: removing writer %r", address) - for database in self.routing_tables.keys(): - self.routing_tables[database].writers.discard(address) + async with self.refresh_lock: + for database in self.routing_tables.keys(): + self.routing_tables[database].writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 00d8dcb58..06fccadbd 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -407,7 +407,9 @@ def _process_message(self, tag, fields): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address) + self.pool.on_write_failure( + address=self.unresolved_address + ) raise except Neo4jError as e: self.pool.on_neo4j_error(e, self) diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index ce763b239..4eea73002 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -356,7 +356,9 @@ def _process_message(self, tag, fields): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address) + self.pool.on_write_failure( + address=self.unresolved_address + ) raise except Neo4jError as e: if self.pool: diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index b89da8ba1..bea6eabc7 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -356,7 +356,9 @@ def _process_message(self, tag, fields): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address) + self.pool.on_write_failure( + address=self.unresolved_address + ) raise except Neo4jError as e: if self.pool: diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 946d2348c..b8adff017 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -946,10 +946,11 @@ def deactivate(self, address): log.debug("[#0000] _: deactivating address %r", address) # We use `discard` instead of `remove` here since the former # will not fail if the address has already been removed. - for database in self.routing_tables.keys(): - self.routing_tables[database].routers.discard(address) - self.routing_tables[database].readers.discard(address) - self.routing_tables[database].writers.discard(address) + with self.refresh_lock: + for database in self.routing_tables.keys(): + self.routing_tables[database].routers.discard(address) + self.routing_tables[database].readers.discard(address) + self.routing_tables[database].writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) super(Neo4jPool, self).deactivate(address) @@ -957,6 +958,7 @@ def on_write_failure(self, address): """ Remove a writer address from the routing table, if present. """ log.debug("[#0000] _: removing writer %r", address) - for database in self.routing_tables.keys(): - self.routing_tables[database].writers.discard(address) + with self.refresh_lock: + for database in self.routing_tables.keys(): + self.routing_tables[database].writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables)