-
Notifications
You must be signed in to change notification settings - Fork 24
Errors propagate through transaction #247
Changes from all commits
346a762
b6d1014
1668a05
391641d
7797c67
f2884cc
6ecd55e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,9 @@ def stream(self, req: transaction_proto.Transaction.Req) -> Iterator[transaction | |
self._dispatcher.dispatch(req) | ||
return ResponsePartIterator(request_id, self, self._dispatcher) | ||
|
||
def done(self, request_id: UUID): | ||
self._response_collector.remove(request_id) | ||
|
||
Comment on lines
+65
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we now clean up completed |
||
def is_open(self) -> bool: | ||
return self._is_open.get() | ||
|
||
|
@@ -78,8 +81,9 @@ def fetch(self, request_id: UUID) -> Union[transaction_proto.Transaction.Res, tr | |
raise TypeDBClientException.of(TRANSACTION_CLOSED) | ||
server_msg = next(self._response_iterator) | ||
except RpcError as e: | ||
self.close(e) | ||
raise TypeDBClientException.of_rpc(e) | ||
error = TypeDBClientException.of_rpc(e) | ||
self.close(error) | ||
raise error | ||
Comment on lines
+84
to
+86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed this so that we propagate our own exception instead of the gRPC error throughout the transaction queues |
||
except StopIteration: | ||
self.close() | ||
raise TypeDBClientException.of(TRANSACTION_CLOSED) | ||
|
@@ -100,10 +104,10 @@ def _collect(self, response: Union[transaction_proto.Transaction.Res, transactio | |
else: | ||
raise TypeDBClientException.of(UNKNOWN_REQUEST_ID, request_id) | ||
|
||
def drain_errors(self) -> List[RpcError]: | ||
return self._response_collector.drain_errors() | ||
def get_errors(self) -> List[TypeDBClientException]: | ||
return self._response_collector.get_errors() | ||
|
||
def close(self, error: RpcError = None): | ||
def close(self, error: TypeDBClientException = None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new type: use our own exception everywhere |
||
if self._is_open.compare_and_set(True, False): | ||
self._response_collector.close(error) | ||
try: | ||
|
@@ -127,7 +131,9 @@ def __init__(self, request_id: UUID, stream: "BidirectionalStream"): | |
self._stream = stream | ||
|
||
def get(self) -> T: | ||
return self._stream.fetch(self._request_id) | ||
value = self._stream.fetch(self._request_id) | ||
self._stream.done(self._request_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a |
||
return value | ||
|
||
|
||
class RequestIterator(Iterator[Union[transaction_proto.Transaction.Req, StopIteration]]): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,81 +21,98 @@ | |
|
||
import queue | ||
from threading import Lock | ||
from typing import Generic, TypeVar, Dict, Optional, Union | ||
from typing import Generic, TypeVar, Dict, Optional | ||
from uuid import UUID | ||
|
||
from grpc import RpcError | ||
|
||
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED | ||
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, ILLEGAL_STATE, \ | ||
TRANSACTION_CLOSED_WITH_ERRORS | ||
|
||
R = TypeVar('R') | ||
|
||
|
||
class ResponseCollector(Generic[R]): | ||
|
||
def __init__(self): | ||
self._collectors: Dict[UUID, ResponseCollector.Queue[R]] = {} | ||
self._response_queues: Dict[UUID, ResponseCollector.Queue[R]] = {} | ||
self._collectors_lock = Lock() | ||
|
||
def new_queue(self, request_id: UUID): | ||
with self._collectors_lock: | ||
collector: ResponseCollector.Queue[R] = ResponseCollector.Queue() | ||
self._collectors[request_id] = collector | ||
self._response_queues[request_id] = collector | ||
return collector | ||
|
||
def get(self, request_id: UUID) -> Optional["ResponseCollector.Queue[R]"]: | ||
return self._collectors.get(request_id) | ||
return self._response_queues.get(request_id) | ||
|
||
def remove(self, request_id: UUID): | ||
with self._collectors_lock: | ||
del self._response_queues[request_id] | ||
|
||
def close(self, error: Optional[RpcError]): | ||
def close(self, error: Optional[TypeDBClientException]): | ||
with self._collectors_lock: | ||
for collector in self._collectors.values(): | ||
for collector in self._response_queues.values(): | ||
collector.close(error) | ||
|
||
def drain_errors(self) -> [RpcError]: | ||
def get_errors(self) -> [TypeDBClientException]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new error typing |
||
errors = [] | ||
with self._collectors_lock: | ||
for collector in self._collectors.values(): | ||
errors.extend(collector.drain_errors()) | ||
for collector in self._response_queues.values(): | ||
error = collector.get_error() | ||
if error is not None: | ||
errors.append(error) | ||
Comment on lines
+61
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. get error from each collector queue |
||
return errors | ||
|
||
|
||
class Queue(Generic[R]): | ||
|
||
def __init__(self): | ||
self._response_queue: queue.Queue[Union[Response[R], Done]] = queue.Queue() | ||
self._response_queue: queue.Queue[Response] = queue.Queue() | ||
self._error: TypeDBClientException = None | ||
|
||
def get(self, block: bool) -> R: | ||
response = self._response_queue.get(block=block) | ||
if response.message: | ||
return response.message | ||
elif response.error: | ||
raise TypeDBClientException.of_rpc(response.error) | ||
else: | ||
if response.is_value(): | ||
return response.value | ||
elif response.is_done() and self._error is None: | ||
raise TypeDBClientException.of(TRANSACTION_CLOSED) | ||
elif response.is_done() and self._error is not None: | ||
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, self._error) | ||
else: | ||
raise TypeDBClientException.of(ILLEGAL_STATE) | ||
|
||
def put(self, response: R): | ||
self._response_queue.put(Response(response)) | ||
self._response_queue.put(ValueResponse(response)) | ||
|
||
def close(self, error: Optional[RpcError]): | ||
self._response_queue.put(Done(error)) | ||
def close(self, error: Optional[TypeDBClientException]): | ||
self._error = error | ||
self._response_queue.put(DoneResponse()) | ||
|
||
def drain_errors(self) -> [RpcError]: | ||
errors = [] | ||
while not self._response_queue.empty(): | ||
response = self._response_queue.get(block = False) | ||
if response.error: | ||
errors.append(response.error) | ||
return errors | ||
def get_error(self) -> TypeDBClientException: | ||
return self._error | ||
|
||
|
||
class Response: | ||
|
||
class Response(Generic[R]): | ||
def is_value(self): | ||
return False | ||
|
||
def is_done(self): | ||
return False | ||
|
||
|
||
class ValueResponse(Response, Generic[R]): | ||
|
||
def __init__(self, value: R): | ||
self.message = value | ||
self.value = value | ||
|
||
def is_value(self): | ||
return True | ||
|
||
class Done: | ||
|
||
def __init__(self, error: Optional[RpcError]): | ||
self.error = error | ||
class DoneResponse(Response): | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def is_done(self): | ||
return True |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,12 +18,11 @@ | |
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
from enum import Enum | ||
from typing import Iterator, TYPE_CHECKING | ||
from uuid import UUID | ||
|
||
import typedb_protocol.common.transaction_pb2 as transaction_proto | ||
|
||
from enum import Enum | ||
from typedb.common.exception import TypeDBClientException, ILLEGAL_ARGUMENT, MISSING_RESPONSE, ILLEGAL_STATE | ||
from typedb.common.rpc.request_builder import transaction_stream_req | ||
from typedb.stream.request_transmitter import RequestTransmitter | ||
|
@@ -78,6 +77,7 @@ def _has_next(self) -> bool: | |
|
||
def __next__(self) -> transaction_proto.Transaction.ResPart: | ||
if not self._has_next(): | ||
self._bidirectional_stream.done(self._request_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we also let the transaction stream know this query stream is done so it can be removed |
||
raise StopIteration | ||
self._state = ResponsePartIterator.State.EMPTY | ||
return self._next |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new name, aligned with client-nodejs