|
19 | 19 | # limitations under the License. |
20 | 20 |
|
21 | 21 |
|
| 22 | +from contextlib import contextmanager |
22 | 23 | import socket |
23 | 24 | from struct import pack as struct_pack |
24 | 25 |
|
25 | 26 | from neo4j.exceptions import ( |
26 | | - AuthError, |
27 | 27 | Neo4jError, |
28 | 28 | ServiceUnavailable, |
29 | 29 | SessionExpired, |
@@ -94,11 +94,14 @@ def __init__(self, max_chunk_size=16384): |
94 | 94 | self._chunked_data = bytearray() |
95 | 95 | self._raw_data = bytearray() |
96 | 96 | self.write = self._raw_data.extend |
| 97 | + self._tmp_buffering = 0 |
97 | 98 |
|
98 | 99 | def max_chunk_size(self): |
99 | 100 | return self._max_chunk_size |
100 | 101 |
|
101 | 102 | def clear(self): |
| 103 | + if self._tmp_buffering: |
| 104 | + raise RuntimeError("Cannot clear while buffering") |
102 | 105 | self._chunked_data = bytearray() |
103 | 106 | self._raw_data.clear() |
104 | 107 |
|
@@ -128,13 +131,29 @@ def _chunk_data(self): |
128 | 131 | self._raw_data.clear() |
129 | 132 |
|
130 | 133 | def wrap_message(self): |
| 134 | + if self._tmp_buffering: |
| 135 | + raise RuntimeError("Cannot wrap message while buffering") |
131 | 136 | self._chunk_data() |
132 | 137 | self._chunked_data += b"\x00\x00" |
133 | 138 |
|
134 | 139 | def view(self): |
| 140 | + if self._tmp_buffering: |
| 141 | + raise RuntimeError("Cannot view while buffering") |
135 | 142 | self._chunk_data() |
136 | 143 | return memoryview(self._chunked_data) |
137 | 144 |
|
| 145 | + @contextmanager |
| 146 | + def tmp_buffer(self): |
| 147 | + self._tmp_buffering += 1 |
| 148 | + old_len = len(self._raw_data) |
| 149 | + try: |
| 150 | + yield |
| 151 | + except Exception: |
| 152 | + del self._raw_data[old_len:] |
| 153 | + raise |
| 154 | + finally: |
| 155 | + self._tmp_buffering -= 1 |
| 156 | + |
138 | 157 |
|
139 | 158 | class ConnectionErrorHandler: |
140 | 159 | """ |
|
0 commit comments