diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py
index 863c73dfb..4804df0b8 100644
--- a/neo4j/io/__init__.py
+++ b/neo4j/io/__init__.py
@@ -487,8 +487,7 @@ def _append(self, signature, fields=(), response=None):
         :param response: a response object to handle callbacks
         """
         self.packer.pack_struct(signature, fields)
-        self.outbox.chunk()
-        self.outbox.chunk()
+        self.outbox.wrap_message()
         self.responses.append(response)
 
     def _send_all(self):
diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py
index 0eec2e2d5..becb7db47 100644
--- a/neo4j/io/_common.py
+++ b/neo4j/io/_common.py
@@ -89,54 +89,61 @@ def __next__(self):
 
 class Outbox:
 
-    def __init__(self, capacity=8192, max_chunk_size=16384):
+    def __init__(self, max_chunk_size=16384):
         self._max_chunk_size = max_chunk_size
-        self._header = 0
-        self._start = 2
-        self._end = 2
-        self._data = bytearray(capacity)
+        self._chunked_data = bytearray()
+        self._raw_data = bytearray()
+        # Bound to this exact bytearray: _raw_data must only ever be
+        # emptied in place, never rebound, or write() would go stale.
+        self.write = self._raw_data.extend
 
     def max_chunk_size(self):
         return self._max_chunk_size
 
     def clear(self):
-        self._header = 0
-        self._start = 2
-        self._end = 2
-        self._data[0:2] = b"\x00\x00"
-
-    def write(self, b):
-        to_write = len(b)
-        max_chunk_size = self._max_chunk_size
-        pos = 0
-        while to_write > 0:
-            chunk_size = self._end - self._start
-            remaining = max_chunk_size - chunk_size
-            if remaining == 0 or remaining < to_write <= max_chunk_size:
-                self.chunk()
-            else:
-                wrote = min(to_write, remaining)
-                new_end = self._end + wrote
-                self._data[self._end:new_end] = b[pos:pos+wrote]
-                self._end = new_end
-                pos += wrote
-                new_chunk_size = self._end - self._start
-                self._data[self._header:(self._header + 2)] = struct_pack(">H", new_chunk_size)
-                to_write -= wrote
-
-    def chunk(self):
-        self._header = self._end
-        self._start = self._header + 2
-        self._end = self._start
-        self._data[self._header:self._start] = b"\x00\x00"
+        self._chunked_data = bytearray()
+        self._raw_data.clear()
+
+    def _chunk_data(self):
+        """Move all pending raw bytes into the chunked buffer.
+
+        Each chunk is written as a 2-byte big-endian length header followed
+        by at most ``max_chunk_size`` payload bytes.
+        """
+        data_len = len(self._raw_data)
+        num_full_chunks, chunk_rest = divmod(
+            data_len, self._max_chunk_size
+        )
+        num_chunks = num_full_chunks + bool(chunk_rest)
+
+        header_start = len(self._chunked_data)
+        data_start = header_start + 2
+        raw_data_start = 0
+        # Release the view explicitly instead of relying on ``del``: a
+        # bytearray cannot be resized while a buffer export is still alive.
+        with memoryview(self._raw_data) as data_view:
+            for _ in range(num_chunks):
+                chunk_size = min(data_len - raw_data_start,
+                                 self._max_chunk_size)
+                self._chunked_data[header_start:data_start] = struct_pack(
+                    ">H", chunk_size
+                )
+                self._chunked_data[data_start:(data_start + chunk_size)] = \
+                    data_view[raw_data_start:(raw_data_start + chunk_size)]
+                header_start += chunk_size + 2
+                data_start = header_start + 2
+                raw_data_start += chunk_size
+        self._raw_data.clear()
+
+    def wrap_message(self):
+        """Chunk the pending data and append the ``00 00`` terminator."""
+        self._chunk_data()
+        self._chunked_data += b"\x00\x00"
 
     def view(self):
-        end = self._end
-        chunk_size = end - self._start
-        if chunk_size == 0:
-            return memoryview(self._data[:self._header])
-        else:
-            return memoryview(self._data[:end])
+        """Chunk the pending data and return a view of all chunked bytes."""
+        self._chunk_data()
+        return memoryview(self._chunked_data)
 
 
 class ConnectionErrorHandler:
diff --git a/neo4j/packstream.py b/neo4j/packstream.py
index 1b72451ba..406d761e4 100644
--- a/neo4j/packstream.py
+++ b/neo4j/packstream.py
@@ -125,12 +125,9 @@ def _pack(self, value):
             self.pack_raw(encoded)
 
         # Bytes
-        elif isinstance(value, bytes):
+        elif isinstance(value, (bytes, bytearray)):
             self.pack_bytes_header(len(value))
             self.pack_raw(value)
-        elif isinstance(value, bytearray):
-            self.pack_bytes_header(len(value))
-            self.pack_raw(bytes(value))
 
         # List
         elif isinstance(value, list):
@@ -169,38 +166,8 @@ def pack_bytes_header(self, size):
 
     def pack_string_header(self, size):
         write = self._write
-        if size == 0x00:
-            write(b"\x80")
-        elif size == 0x01:
-            write(b"\x81")
-        elif size == 0x02:
-            write(b"\x82")
-        elif size == 0x03:
-            write(b"\x83")
-        elif size == 0x04:
-            write(b"\x84")
-        elif size == 0x05:
-            write(b"\x85")
-        elif size == 0x06:
-            write(b"\x86")
-        elif size == 0x07:
-            write(b"\x87")
-        elif size == 0x08:
-            write(b"\x88")
-        elif size == 0x09:
-            write(b"\x89")
-        elif size == 0x0A:
-            write(b"\x8A")
-        elif size == 0x0B:
-            write(b"\x8B")
-        elif size == 0x0C:
-            write(b"\x8C")
-        elif size == 0x0D:
-            write(b"\x8D")
-        elif size == 0x0E:
-            write(b"\x8E")
-        elif size == 0x0F:
-            write(b"\x8F")
+        if size <= 0x0F:
+            write(bytes((0x80 | size,)))
         elif size < 0x100:
             write(b"\xD0")
             write(PACKED_UINT_8[size])
@@ -215,38 +182,8 @@ def pack_string_header(self, size):
 
     def pack_list_header(self, size):
         write = self._write
-        if size == 0x00:
-            write(b"\x90")
-        elif size == 0x01:
-            write(b"\x91")
-        elif size == 0x02:
-            write(b"\x92")
-        elif size == 0x03:
-            write(b"\x93")
-        elif size == 0x04:
-            write(b"\x94")
-        elif size == 0x05:
-            write(b"\x95")
-        elif size == 0x06:
-            write(b"\x96")
-        elif size == 0x07:
-            write(b"\x97")
-        elif size == 0x08:
-            write(b"\x98")
-        elif size == 0x09:
-            write(b"\x99")
-        elif size == 0x0A:
-            write(b"\x9A")
-        elif size == 0x0B:
-            write(b"\x9B")
-        elif size == 0x0C:
-            write(b"\x9C")
-        elif size == 0x0D:
-            write(b"\x9D")
-        elif size == 0x0E:
-            write(b"\x9E")
-        elif size == 0x0F:
-            write(b"\x9F")
+        if size <= 0x0F:
+            write(bytes((0x90 | size,)))
         elif size < 0x100:
             write(b"\xD4")
             write(PACKED_UINT_8[size])
@@ -264,38 +201,8 @@ def pack_list_stream_header(self):
 
     def pack_map_header(self, size):
         write = self._write
-        if size == 0x00:
-            write(b"\xA0")
-        elif size == 0x01:
-            write(b"\xA1")
-        elif size == 0x02:
-            write(b"\xA2")
-        elif size == 0x03:
-            write(b"\xA3")
-        elif size == 0x04:
-            write(b"\xA4")
-        elif size == 0x05:
-            write(b"\xA5")
-        elif size == 0x06:
-            write(b"\xA6")
-        elif size == 0x07:
-            write(b"\xA7")
-        elif size == 0x08:
-            write(b"\xA8")
-        elif size == 0x09:
-            write(b"\xA9")
-        elif size == 0x0A:
-            write(b"\xAA")
-        elif size == 0x0B:
-            write(b"\xAB")
-        elif size == 0x0C:
-            write(b"\xAC")
-        elif size == 0x0D:
-            write(b"\xAD")
-        elif size == 0x0E:
-            write(b"\xAE")
-        elif size == 0x0F:
-            write(b"\xAF")
+        if size <= 0x0F:
+            write(bytes((0xA0 | size,)))
         elif size < 0x100:
             write(b"\xD8")
             write(PACKED_UINT_8[size])
@@ -316,38 +223,8 @@ def pack_struct(self, signature, fields):
             raise ValueError("Structure signature must be a single byte value")
         write = self._write
         size = len(fields)
-        if size == 0x00:
-            write(b"\xB0")
-        elif size == 0x01:
-            write(b"\xB1")
-        elif size == 0x02:
-            write(b"\xB2")
-        elif size == 0x03:
-            write(b"\xB3")
-        elif size == 0x04:
-            write(b"\xB4")
-        elif size == 0x05:
-            write(b"\xB5")
-        elif size == 0x06:
-            write(b"\xB6")
-        elif size == 0x07:
-            write(b"\xB7")
-        elif size == 0x08:
-            write(b"\xB8")
-        elif size == 0x09:
-            write(b"\xB9")
-        elif size == 0x0A:
-            write(b"\xBA")
-        elif size == 0x0B:
-            write(b"\xBB")
-        elif size == 0x0C:
-            write(b"\xBC")
-        elif size == 0x0D:
-            write(b"\xBD")
-        elif size == 0x0E:
-            write(b"\xBE")
-        elif size == 0x0F:
-            write(b"\xBF")
+        if size <= 0x0F:
+            write(bytes((0xB0 | size,)))
         else:
             raise OverflowError("Structure size out of range")
         write(signature)
diff --git a/tests/unit/io/test__common.py b/tests/unit/io/test__common.py
new file mode 100644
index 000000000..3b61c7103
--- /dev/null
+++ b/tests/unit/io/test__common.py
@@ -0,0 +1,32 @@
+import pytest
+
+from neo4j.io._common import Outbox
+
+
+@pytest.mark.parametrize(("chunk_size", "data", "result"), (
+    (
+        2,
+        (bytes(range(10, 15)),),
+        bytes((0, 2, 10, 11, 0, 2, 12, 13, 0, 1, 14))
+    ),
+    (
+        2,
+        (bytes(range(10, 14)),),
+        bytes((0, 2, 10, 11, 0, 2, 12, 13))
+    ),
+    (
+        2,
+        (bytes((5, 6, 7)), bytes((8, 9))),
+        bytes((0, 2, 5, 6, 0, 2, 7, 8, 0, 1, 9))
+    ),
+))
+def test_outbox_chunking(chunk_size, data, result):
+    outbox = Outbox(max_chunk_size=chunk_size)
+    assert bytes(outbox.view()) == b""
+    for d in data:
+        outbox.write(d)
+    assert bytes(outbox.view()) == result
+    # make sure this works multiple times
+    assert bytes(outbox.view()) == result
+    outbox.clear()
+    assert bytes(outbox.view()) == b""