diff --git a/neo4j/work/result.py b/neo4j/work/result.py index bd1834526..75921b0ba 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -25,6 +25,7 @@ from neo4j.data import DataDehydrator from neo4j.io import ConnectionErrorHandler from neo4j.work.summary import ResultSummary +from neo4j.exceptions import ResultConsumedError class Result: @@ -192,20 +193,37 @@ def __iter__(self): self._closed = True def _attach(self): - """Sets the Result object in an attached state by fetching messages from the connection to the buffer. + """Sets the Result object in an attached state by fetching messages from + the connection to the buffer. """ if self._closed is False: while self._attached is False: self._connection.fetch_message() - def _buffer_all(self): - """Sets the Result object in an detached state by fetching all records from the connection to the buffer. + def _buffer(self, n=None): + """Try to fill `self_record_buffer` with n records. + + Might end up with more records in the buffer if the fetch size makes it + overshoot. + Might ent up with fewer records in the buffer if there are not enough + records available. """ record_buffer = deque() for record in self: record_buffer.append(record) + if n is not None and len(record_buffer) >= n: + break self._closed = False - self._record_buffer = record_buffer + if n is None: + self._record_buffer = record_buffer + else: + self._record_buffer.extend(record_buffer) + + def _buffer_all(self): + """Sets the Result object in an detached state by fetching all records + from the connection to the buffer. + """ + self._buffer() def _obtain_summary(self): """Obtain the summary of this result, buffering any remaining records. @@ -278,6 +296,13 @@ def single(self): :returns: the next :class:`neo4j.Record` or :const:`None` if none remain :warns: if more than one record is available """ + # TODO in 5.0 replace with this code that raises an error if there's not + # exactly one record in the left result stream. + # self._buffer(2). + # if len(self._record_buffer) != 1: + # raise SomeError("Expected exactly 1 record, found %i" + # % len(self._record_buffer)) + # return self._record_buffer.popleft() records = list(self) # TODO: exhausts the result with self.consume if there are more records. size = len(records) if size == 0: @@ -292,16 +317,9 @@ def peek(self): :returns: the next :class:`.Record` or :const:`None` if none remain """ + self._buffer(1) if self._record_buffer: return self._record_buffer[0] - if not self._attached: - return None - while self._attached: - self._connection.fetch_message() - if self._record_buffer: - return self._record_buffer[0] - - return None def graph(self): """Return a :class:`neo4j.graph.Graph` instance containing all the graph objects diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index 53f595f98..070996564 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -301,6 +301,20 @@ def ResultNext(backend, data): backend.send_response("Record", totestkit.record(record)) +def ResultSingle(backend, data): + result = backend.results[data["resultId"]] + backend.send_response("Record", totestkit.record(result.single())) + + +def ResultPeek(backend, data): + result = backend.results[data["resultId"]] + record = result.peek() + if record is not None: + backend.send_response("Record", totestkit.record(record)) + else: + backend.send_response("NullRecord", {}) + + def ResultConsume(backend, data): result = backend.results[data["resultId"]] summary = result.consume() diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 282b34549..ebe59eae9 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -38,6 +38,8 @@ "TLSv1.1 and below are disabled in the driver" }, "features": { + "Feature:API:Result.Single": "Does not raise error when not exactly one record is available. To be fixed in 5.0", + "Feature:API:Result.Peek": true, "AuthorizationExpiredTreatment": true, "Optimization:ImplicitDefaultArguments": true, "Optimization:MinimalResets": true, diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index 7f2d15bb5..df33293a2 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -277,12 +277,14 @@ def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) result._run("CYPHER", {}, None, "r", None) - record = result.peek() - if not records: - assert record is None - else: - assert isinstance(record, Record) - assert record.get("x") == records[0][0] + for i in range(len(records) + 1): + record = result.peek() + if i == len(records): + assert record is None + else: + assert isinstance(record, Record) + assert record.get("x") == records[i][0] + next(iter(result)) # consume the record @pytest.mark.parametrize("records", ([[1], [2]], [[1]], []))