diff --git a/typedb/connection/transaction.py b/typedb/connection/transaction.py index fa1cc02d..6885d32b 100644 --- a/typedb/connection/transaction.py +++ b/typedb/connection/transaction.py @@ -113,7 +113,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False def _raise_transaction_closed(self): - errors = self._bidirectional_stream.drain_errors() + errors = self._bidirectional_stream.get_errors() if len(errors) == 0: raise TypeDBClientException.of(TRANSACTION_CLOSED) else: diff --git a/typedb/stream/bidirectional_stream.py b/typedb/stream/bidirectional_stream.py index c88850a2..b00d71ae 100644 --- a/typedb/stream/bidirectional_stream.py +++ b/typedb/stream/bidirectional_stream.py @@ -62,6 +62,9 @@ def stream(self, req: transaction_proto.Transaction.Req) -> Iterator[transaction self._dispatcher.dispatch(req) return ResponsePartIterator(request_id, self, self._dispatcher) + def done(self, request_id: UUID): + self._response_collector.remove(request_id) + def is_open(self) -> bool: return self._is_open.get() @@ -78,8 +81,9 @@ def fetch(self, request_id: UUID) -> Union[transaction_proto.Transaction.Res, tr raise TypeDBClientException.of(TRANSACTION_CLOSED) server_msg = next(self._response_iterator) except RpcError as e: - self.close(e) - raise TypeDBClientException.of_rpc(e) + error = TypeDBClientException.of_rpc(e) + self.close(error) + raise error except StopIteration: self.close() raise TypeDBClientException.of(TRANSACTION_CLOSED) @@ -100,10 +104,10 @@ def _collect(self, response: Union[transaction_proto.Transaction.Res, transactio else: raise TypeDBClientException.of(UNKNOWN_REQUEST_ID, request_id) - def drain_errors(self) -> List[RpcError]: - return self._response_collector.drain_errors() + def get_errors(self) -> List[TypeDBClientException]: + return self._response_collector.get_errors() - def close(self, error: RpcError = None): + def close(self, error: TypeDBClientException = None): if self._is_open.compare_and_set(True, False): self._response_collector.close(error) try: @@ -127,7 +131,9 @@ def __init__(self, request_id: UUID, stream: "BidirectionalStream"): self._stream = stream def get(self) -> T: - return self._stream.fetch(self._request_id) + value = self._stream.fetch(self._request_id) + self._stream.done(self._request_id) + return value class RequestIterator(Iterator[Union[transaction_proto.Transaction.Req, StopIteration]]): diff --git a/typedb/stream/response_collector.py b/typedb/stream/response_collector.py index 5ad08534..47fe283c 100644 --- a/typedb/stream/response_collector.py +++ b/typedb/stream/response_collector.py @@ -21,12 +21,11 @@ import queue from threading import Lock -from typing import Generic, TypeVar, Dict, Optional, Union +from typing import Generic, TypeVar, Dict, Optional from uuid import UUID -from grpc import RpcError - -from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED +from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, ILLEGAL_STATE, \ + TRANSACTION_CLOSED_WITH_ERRORS R = TypeVar('R') @@ -34,68 +33,86 @@ class ResponseCollector(Generic[R]): def __init__(self): - self._collectors: Dict[UUID, ResponseCollector.Queue[R]] = {} + self._response_queues: Dict[UUID, ResponseCollector.Queue[R]] = {} self._collectors_lock = Lock() def new_queue(self, request_id: UUID): with self._collectors_lock: collector: ResponseCollector.Queue[R] = ResponseCollector.Queue() - self._collectors[request_id] = collector + self._response_queues[request_id] = collector return collector def get(self, request_id: UUID) -> Optional["ResponseCollector.Queue[R]"]: - return self._collectors.get(request_id) + return self._response_queues.get(request_id) + + def remove(self, request_id: UUID): + with self._collectors_lock: + del self._response_queues[request_id] - def close(self, error: Optional[RpcError]): + def close(self, error: Optional[TypeDBClientException]): with self._collectors_lock: - for collector in self._collectors.values(): + for collector in self._response_queues.values(): collector.close(error) - def drain_errors(self) -> [RpcError]: + def get_errors(self) -> [TypeDBClientException]: errors = [] with self._collectors_lock: - for collector in self._collectors.values(): - errors.extend(collector.drain_errors()) + for collector in self._response_queues.values(): + error = collector.get_error() + if error is not None: + errors.append(error) return errors - class Queue(Generic[R]): def __init__(self): - self._response_queue: queue.Queue[Union[Response[R], Done]] = queue.Queue() + self._response_queue: queue.Queue[Response] = queue.Queue() + self._error: TypeDBClientException = None def get(self, block: bool) -> R: response = self._response_queue.get(block=block) - if response.message: - return response.message - elif response.error: - raise TypeDBClientException.of_rpc(response.error) - else: + if response.is_value(): + return response.value + elif response.is_done() and self._error is None: raise TypeDBClientException.of(TRANSACTION_CLOSED) + elif response.is_done() and self._error is not None: + raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, self._error) + else: + raise TypeDBClientException.of(ILLEGAL_STATE) def put(self, response: R): - self._response_queue.put(Response(response)) + self._response_queue.put(ValueResponse(response)) - def close(self, error: Optional[RpcError]): - self._response_queue.put(Done(error)) + def close(self, error: Optional[TypeDBClientException]): + self._error = error + self._response_queue.put(DoneResponse()) - def drain_errors(self) -> [RpcError]: - errors = [] - while not self._response_queue.empty(): - response = self._response_queue.get(block = False) - if response.error: - errors.append(response.error) - return errors + def get_error(self) -> TypeDBClientException: + return self._error +class Response: -class Response(Generic[R]): + def is_value(self): + return False + + def is_done(self): + return False + + +class ValueResponse(Response, Generic[R]): def __init__(self, value: R): - self.message = value + self.value = value + def is_value(self): + return True -class Done: - def __init__(self, error: Optional[RpcError]): - self.error = error +class DoneResponse(Response): + + def __init__(self): + pass + + def is_done(self): + return True diff --git a/typedb/stream/response_part_iterator.py b/typedb/stream/response_part_iterator.py index 08d49b8e..f934378a 100644 --- a/typedb/stream/response_part_iterator.py +++ b/typedb/stream/response_part_iterator.py @@ -18,12 +18,11 @@ # specific language governing permissions and limitations # under the License. # -from enum import Enum from typing import Iterator, TYPE_CHECKING from uuid import UUID import typedb_protocol.common.transaction_pb2 as transaction_proto - +from enum import Enum from typedb.common.exception import TypeDBClientException, ILLEGAL_ARGUMENT, MISSING_RESPONSE, ILLEGAL_STATE from typedb.common.rpc.request_builder import transaction_stream_req from typedb.stream.request_transmitter import RequestTransmitter @@ -78,6 +77,7 @@ def _has_next(self) -> bool: def __next__(self) -> transaction_proto.Transaction.ResPart: if not self._has_next(): + self._bidirectional_stream.done(self._request_id) raise StopIteration self._state = ResponsePartIterator.State.EMPTY return self._next