diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index ccde6da01..2ed57bc78 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -123,7 +123,7 @@ class Bolt(abc.ABC): PROTOCOL_VERSION = None # flag if connection needs RESET to go back to READY state - _is_reset = True + is_reset = False # The socket in_use = False @@ -460,10 +460,6 @@ def rollback(self, **handlers): """ Appends a ROLLBACK message to the output queue.""" pass - @property - def is_reset(self): - return self._is_reset - @abc.abstractmethod def reset(self): """ Appends a RESET message to the outgoing queue, sends it and consumes diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 53d1109a1..fe7608b0b 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -52,6 +53,53 @@ log = getLogger("neo4j") +class ServerStates(Enum): + CONNECTED = "CONNECTED" + READY = "READY" + STREAMING = "STREAMING" + TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING" + FAILED = "FAILED" + + +class ServerStateManager: + _STATE_TRANSITIONS = { + ServerStates.CONNECTED: { + "hello": ServerStates.READY, + }, + ServerStates.READY: { + "run": ServerStates.STREAMING, + "begin": ServerStates.TX_READY_OR_TX_STREAMING, + }, + ServerStates.STREAMING: { + "pull": ServerStates.READY, + "discard": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates.READY, + "rollback": ServerStates.READY, + "reset": ServerStates.READY, + }, + ServerStates.FAILED: { + "reset": ServerStates.READY, + } + } + + def __init__(self, init_state, on_change=None): + self.state = init_state + self._on_change = on_change + + def transition(self, message, metadata): + if metadata.get("has_more"): + return + state_before = self.state + self.state = self._STATE_TRANSITIONS\ + .get(self.state, {})\ + .get(message, self.state) + if state_before != self.state and callable(self._on_change): + self._on_change(state_before, self.state) + + class Bolt3(Bolt): """ Protocol handler for Bolt 3. @@ -64,6 +112,25 @@ class Bolt3(Bolt): supports_multiple_databases = False + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) + + @property + def is_reset(self): + if self.responses: + # We can't be sure of the server's state as there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. + return self.responses[-1].message == "reset" + return self._server_state_manager.state == ServerStates.READY + @property def encrypted(self): return isinstance(self.socket, SSLSocket) @@ -92,7 +159,8 @@ def hello(self): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, on_success=self.server_info.update)) + response=InitResponse(self, "hello", + on_success=self.server_info.update)) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -155,21 +223,20 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) if query.upper() == u"COMMIT": - self._append(b"\x10", fields, CommitResponse(self, **handlers)) + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) else: - self._append(b"\x10", fields, Response(self, **handlers)) - self._is_reset = False + self._append(b"\x10", fields, Response(self, "run", **handlers)) def discard(self, n=-1, qid=-1, **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: DISCARD_ALL", self.local_port) - self._append(b"\x2F", (), Response(self, **handlers)) + self._append(b"\x2F", (), Response(self, "discard", **handlers)) def pull(self, n=-1, qid=-1, **handlers): # Just ignore n and qid, it is not supported in the Bolt 3 Protocol. log.debug("[#%04X] C: PULL_ALL", self.local_port) - self._append(b"\x3F", (), Response(self, **handlers)) - self._is_reset = False + self._append(b"\x3F", (), Response(self, "pull", **handlers)) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): if db is not None: @@ -193,16 +260,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, except TypeError: raise TypeError("Timeout must be specified as a number of seconds") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, **handlers)) - self._is_reset = False + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) def commit(self, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, **handlers)) + self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) def rollback(self, **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, **handlers)) + self._append(b"\x13", (), Response(self, "rollback", **handlers)) def reset(self): """ Add a RESET message to the outgoing queue, send @@ -213,10 +279,9 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, address=self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, on_failure=fail)) + self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) self.send_all() self.fetch_all() - self._is_reset = True def fetch_message(self): """ Receive at most one message from the server, if available. @@ -249,12 +314,15 @@ def fetch_message(self): response.complete = True if summary_signature == b"\x70": log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) + self._server_state_manager.transition(response.message, + summary_metadata) response.on_success(summary_metadata or {}) elif summary_signature == b"\x7E": log.debug("[#%04X] S: IGNORED", self.local_port) response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) + self._server_state_manager.state = ServerStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 8f123e337..4b7c2045c 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -37,7 +38,6 @@ Neo4jError, NotALeader, ServiceUnavailable, - SessionExpired, ) from neo4j.io import ( Bolt, @@ -48,6 +48,10 @@ InitResponse, Response, ) +from neo4j.io._bolt3 import ( + ServerStateManager, + ServerStates, +) log = getLogger("neo4j") @@ -65,6 +69,25 @@ class Bolt4x0(Bolt): supports_multiple_databases = True + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager( + ServerStates.CONNECTED, on_change=self._on_server_state_change + ) + + def _on_server_state_change(self, old_state, new_state): + log.debug("[#%04X] State: %s > %s", self.local_port, + old_state.name, new_state.name) + + @property + def is_reset(self): + if self.responses: + # We can't be sure of the server's state as there are still pending + # responses. Unless the last message we sent was RESET. In that case + # the server state will always be READY when we're done. + return self.responses[-1].message == "reset" + return self._server_state_manager.state == ServerStates.READY + @property def encrypted(self): return isinstance(self.socket, SSLSocket) @@ -93,7 +116,8 @@ def hello(self): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, on_success=self.server_info.update)) + response=InitResponse(self, "hello", + on_success=self.server_info.update)) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) @@ -162,25 +186,24 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, fields = (query, parameters, extra) log.debug("[#%04X] C: RUN %s", self.local_port, " ".join(map(repr, fields))) if query.upper() == u"COMMIT": - self._append(b"\x10", fields, CommitResponse(self, **handlers)) + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) else: - self._append(b"\x10", fields, Response(self, **handlers)) - self._is_reset = False + self._append(b"\x10", fields, Response(self, "run", **handlers)) def discard(self, n=-1, qid=-1, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: DISCARD %r", self.local_port, extra) - self._append(b"\x2F", (extra,), Response(self, **handlers)) + self._append(b"\x2F", (extra,), Response(self, "discard", **handlers)) def pull(self, n=-1, qid=-1, **handlers): extra = {"n": n} if qid != -1: extra["qid"] = qid log.debug("[#%04X] C: PULL %r", self.local_port, extra) - self._append(b"\x3F", (extra,), Response(self, **handlers)) - self._is_reset = False + self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): @@ -205,16 +228,15 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, except TypeError: raise TypeError("Timeout must be specified as a number of seconds") log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) - self._append(b"\x11", (extra,), Response(self, **handlers)) - self._is_reset = False + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) def commit(self, **handlers): log.debug("[#%04X] C: COMMIT", self.local_port) - self._append(b"\x12", (), CommitResponse(self, **handlers)) + self._append(b"\x12", (), CommitResponse(self, "commit", **handlers)) def rollback(self, **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) - self._append(b"\x13", (), Response(self, **handlers)) + self._append(b"\x13", (), Response(self, "rollback", **handlers)) def reset(self): """ Add a RESET message to the outgoing queue, send @@ -225,10 +247,9 @@ def fail(metadata): raise BoltProtocolError("RESET failed %r" % metadata, self.unresolved_address) log.debug("[#%04X] C: RESET", self.local_port) - self._append(b"\x0F", response=Response(self, on_failure=fail)) + self._append(b"\x0F", response=Response(self, "reset", on_failure=fail)) self.send_all() self.fetch_all() - self._is_reset = True def fetch_message(self): """ Receive at most one message from the server, if available. @@ -261,12 +282,15 @@ def fetch_message(self): response.complete = True if summary_signature == b"\x70": log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata) + self._server_state_manager.transition(response.message, + summary_metadata) response.on_success(summary_metadata or {}) elif summary_signature == b"\x7E": log.debug("[#%04X] S: IGNORED", self.local_port) response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) + self._server_state_manager.state = ServerStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -372,7 +396,9 @@ def fail(md): else: bookmarks = list(bookmarks) self._append(b"\x66", (routing_context, bookmarks, database), - response=Response(self, on_success=metadata.update, on_failure=fail)) + response=Response(self, "route", + on_success=metadata.update, + on_failure=fail)) self.send_all() self.fetch_all() return [metadata.get("rt")] @@ -400,7 +426,8 @@ def on_success(metadata): logged_headers["credentials"] = "*******" log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) self._append(b"\x01", (headers,), - response=InitResponse(self, on_success=on_success)) + response=InitResponse(self, "hello", + on_success=on_success)) self.send_all() self.fetch_all() check_supported_server_product(self.server_info.agent) diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index fc543499e..0dc8b2a3b 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -144,9 +144,10 @@ class Response: more detail messages followed by one summary message). """ - def __init__(self, connection, **handlers): + def __init__(self, connection, message, **handlers): self.connection = connection self.handlers = handlers + self.message = message self.complete = False def on_records(self, records): diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 858a86641..8e4450e8b 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -29,14 +29,12 @@ "stub.session_run_parameters.test_session_run_parameters.TestSessionRunParameters.test_empty_query": "Driver rejects empty queries before sending it to the server", "tls.tlsversions.TestTlsVersions.test_1_1": - "TLSv1.1 and below are disabled in the driver", - "stub.disconnects.test_disconnects.TestDisconnects.test_fail_on_reset": - "Driver silently ignores all errors on releasing connections back into the pool." + "TLSv1.1 and below are disabled in the driver" }, "features": { "AuthorizationExpiredTreatment": true, "Optimization:ImplicitDefaultArguments": true, - "Optimization:MinimalResets": "Driver resets some clean connections when put back into pool", + "Optimization:MinimalResets": true, "Optimization:ConnectionReuse": true, "Optimization:PullPipelining": true, "ConfHint:connection.recv_timeout_seconds": true,