Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion typedb/connection/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False

def _raise_transaction_closed(self):
errors = self._bidirectional_stream.drain_errors()
errors = self._bidirectional_stream.get_errors()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new name, aligned with client-nodejs

if len(errors) == 0:
raise TypeDBClientException.of(TRANSACTION_CLOSED)
else:
Expand Down
18 changes: 12 additions & 6 deletions typedb/stream/bidirectional_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def stream(self, req: transaction_proto.Transaction.Req) -> Iterator[transaction
self._dispatcher.dispatch(req)
return ResponsePartIterator(request_id, self, self._dispatcher)

def done(self, request_id: UUID):
self._response_collector.remove(request_id)

Comment on lines +65 to +67
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we now clean up completed single/stream from the response collectors, so we don't propagate and print errors for every old operation the transaction also handled but isn't active anymore!

def is_open(self) -> bool:
return self._is_open.get()

Expand All @@ -78,8 +81,9 @@ def fetch(self, request_id: UUID) -> Union[transaction_proto.Transaction.Res, tr
raise TypeDBClientException.of(TRANSACTION_CLOSED)
server_msg = next(self._response_iterator)
except RpcError as e:
self.close(e)
raise TypeDBClientException.of_rpc(e)
error = TypeDBClientException.of_rpc(e)
self.close(error)
raise error
Comment on lines +84 to +86
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed this so that we propagate our own exception instead of the gRPC error throughout the transaction queues

except StopIteration:
self.close()
raise TypeDBClientException.of(TRANSACTION_CLOSED)
Expand All @@ -100,10 +104,10 @@ def _collect(self, response: Union[transaction_proto.Transaction.Res, transactio
else:
raise TypeDBClientException.of(UNKNOWN_REQUEST_ID, request_id)

def drain_errors(self) -> List[RpcError]:
return self._response_collector.drain_errors()
def get_errors(self) -> List[TypeDBClientException]:
return self._response_collector.get_errors()

def close(self, error: RpcError = None):
def close(self, error: TypeDBClientException = None):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new type: use our own exception everywhere

if self._is_open.compare_and_set(True, False):
self._response_collector.close(error)
try:
Expand All @@ -127,7 +131,9 @@ def __init__(self, request_id: UUID, stream: "BidirectionalStream"):
self._stream = stream

def get(self) -> T:
return self._stream.fetch(self._request_id)
value = self._stream.fetch(self._request_id)
self._stream.done(self._request_id)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a Single can immediately remove itself from the transaction stream when the user retrieves the value via get()

return value


class RequestIterator(Iterator[Union[transaction_proto.Transaction.Req, StopIteration]]):
Expand Down
85 changes: 51 additions & 34 deletions typedb/stream/response_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,81 +21,98 @@

import queue
from threading import Lock
from typing import Generic, TypeVar, Dict, Optional, Union
from typing import Generic, TypeVar, Dict, Optional
from uuid import UUID

from grpc import RpcError

from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, ILLEGAL_STATE, \
TRANSACTION_CLOSED_WITH_ERRORS

R = TypeVar('R')


class ResponseCollector(Generic[R]):

def __init__(self):
self._collectors: Dict[UUID, ResponseCollector.Queue[R]] = {}
self._response_queues: Dict[UUID, ResponseCollector.Queue[R]] = {}
self._collectors_lock = Lock()

def new_queue(self, request_id: UUID):
with self._collectors_lock:
collector: ResponseCollector.Queue[R] = ResponseCollector.Queue()
self._collectors[request_id] = collector
self._response_queues[request_id] = collector
return collector

def get(self, request_id: UUID) -> Optional["ResponseCollector.Queue[R]"]:
return self._collectors.get(request_id)
return self._response_queues.get(request_id)

def remove(self, request_id: UUID):
with self._collectors_lock:
del self._response_queues[request_id]

def close(self, error: Optional[RpcError]):
def close(self, error: Optional[TypeDBClientException]):
with self._collectors_lock:
for collector in self._collectors.values():
for collector in self._response_queues.values():
collector.close(error)

def drain_errors(self) -> [RpcError]:
def get_errors(self) -> [TypeDBClientException]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new error typing

errors = []
with self._collectors_lock:
for collector in self._collectors.values():
errors.extend(collector.drain_errors())
for collector in self._response_queues.values():
error = collector.get_error()
if error is not None:
errors.append(error)
Comment on lines +61 to +63
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get error from each collector queue

return errors


class Queue(Generic[R]):

def __init__(self):
self._response_queue: queue.Queue[Union[Response[R], Done]] = queue.Queue()
self._response_queue: queue.Queue[Response] = queue.Queue()
self._error: TypeDBClientException = None

def get(self, block: bool) -> R:
response = self._response_queue.get(block=block)
if response.message:
return response.message
elif response.error:
raise TypeDBClientException.of_rpc(response.error)
else:
if response.is_value():
return response.value
elif response.is_done() and self._error is None:
raise TypeDBClientException.of(TRANSACTION_CLOSED)
elif response.is_done() and self._error is not None:
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, self._error)
else:
raise TypeDBClientException.of(ILLEGAL_STATE)

def put(self, response: R):
self._response_queue.put(Response(response))
self._response_queue.put(ValueResponse(response))

def close(self, error: Optional[RpcError]):
self._response_queue.put(Done(error))
def close(self, error: Optional[TypeDBClientException]):
self._error = error
self._response_queue.put(DoneResponse())

def drain_errors(self) -> [RpcError]:
errors = []
while not self._response_queue.empty():
response = self._response_queue.get(block = False)
if response.error:
errors.append(response.error)
return errors
def get_error(self) -> TypeDBClientException:
return self._error


class Response:

class Response(Generic[R]):
def is_value(self):
return False

def is_done(self):
return False


class ValueResponse(Response, Generic[R]):

def __init__(self, value: R):
self.message = value
self.value = value

def is_value(self):
return True

class Done:

def __init__(self, error: Optional[RpcError]):
self.error = error
class DoneResponse(Response):

def __init__(self):
pass

def is_done(self):
return True
4 changes: 2 additions & 2 deletions typedb/stream/response_part_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
# specific language governing permissions and limitations
# under the License.
#
from enum import Enum
from typing import Iterator, TYPE_CHECKING
from uuid import UUID

import typedb_protocol.common.transaction_pb2 as transaction_proto

from enum import Enum
from typedb.common.exception import TypeDBClientException, ILLEGAL_ARGUMENT, MISSING_RESPONSE, ILLEGAL_STATE
from typedb.common.rpc.request_builder import transaction_stream_req
from typedb.stream.request_transmitter import RequestTransmitter
Expand Down Expand Up @@ -78,6 +77,7 @@ def _has_next(self) -> bool:

def __next__(self) -> transaction_proto.Transaction.ResPart:
if not self._has_next():
self._bidirectional_stream.done(self._request_id)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also let the transaction stream know this query stream is done so it can be removed

raise StopIteration
self._state = ResponsePartIterator.State.EMPTY
return self._next