From 13914a329a1b6249d37ccd7a7d3c82eb111c7dec Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 12:31:06 +1000 Subject: [PATCH 1/9] Support 3.9 in the extension --- pyproject.toml | 2 +- src/dispatch/experimental/durable/frame.c | 18 ++- src/dispatch/experimental/durable/frame309.h | 144 +++++++++++++++++++ 3 files changed, 156 insertions(+), 8 deletions(-) create mode 100644 src/dispatch/experimental/durable/frame309.h diff --git a/pyproject.toml b/pyproject.toml index 546014b5..125732bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "dispatch-py" description = "Develop reliable distributed systems on the Dispatch platform." readme = "README.md" dynamic = ["version"] -requires-python = ">= 3.10" +requires-python = ">= 3.9" dependencies = [ "grpcio >= 1.60.0", "protobuf >= 4.24.0", diff --git a/src/dispatch/experimental/durable/frame.c b/src/dispatch/experimental/durable/frame.c index 1cfb6c6a..b3bb4517 100644 --- a/src/dispatch/experimental/durable/frame.c +++ b/src/dispatch/experimental/durable/frame.c @@ -6,11 +6,12 @@ #define PY_SSIZE_T_CLEAN #include -#if PY_MAJOR_VERSION != 3 || (PY_MINOR_VERSION < 10 || PY_MINOR_VERSION > 13) -# error Python 3.10-3.13 is required +#if PY_MAJOR_VERSION != 3 || (PY_MINOR_VERSION < 9 || PY_MINOR_VERSION > 13) +# error Python 3.9-3.13 is required #endif -// This is a redefinition of the private PyTryBlock from 3.10. +// This is a redefinition of the private PyTryBlock from <= 3.10. +// https://github.com/python/cpython/blob/3.9/Include/cpython/frameobject.h#L11 // https://github.com/python/cpython/blob/3.10/Include/cpython/frameobject.h#L22 typedef struct { int b_type; @@ -18,7 +19,8 @@ typedef struct { int b_level; } PyTryBlock; -// This is a redefinition of the private PyCoroWrapper from 3.10-3.13. +// This is a redefinition of the private PyCoroWrapper from 3.9-3.13. +// https://github.com/python/cpython/blob/3.9/Objects/genobject.c#L830 // https://github.com/python/cpython/blob/3.10/Objects/genobject.c#L884 // https://github.com/python/cpython/blob/3.11/Objects/genobject.c#L1016 // https://github.com/python/cpython/blob/3.12/Objects/genobject.c#L1003 @@ -51,7 +53,9 @@ static int get_frame_iblock(Frame *frame); static void set_frame_iblock(Frame *frame, int iblock); static PyTryBlock *get_frame_blockstack(Frame *frame); -#if PY_MINOR_VERSION == 10 +#if PY_MINOR_VERSION == 9 +#include "frame309.h" +#elif PY_MINOR_VERSION == 10 #include "frame310.h" #elif PY_MINOR_VERSION == 11 #include "frame311.h" @@ -78,7 +82,7 @@ static const char *get_type_name(PyObject *obj) { static PyGenObject *get_generator_like_object(PyObject *obj) { if (PyGen_Check(obj) || PyCoro_CheckExact(obj) || PyAsyncGen_CheckExact(obj)) { - // Note: In Python 3.10-3.13, the PyGenObject, PyCoroObject and PyAsyncGenObject + // Note: In Python 3.9-3.13, the PyGenObject, PyCoroObject and PyAsyncGenObject // have the same layout, they just have different field prefixes (gi_, cr_, ag_). // We cast to PyGenObject here so that the remainder of the code can use the gi_ // prefix for all three cases. @@ -386,7 +390,7 @@ static PyObject *ext_set_frame_stack_at(PyObject *self, PyObject *args) { } PyObject **localsplus = get_frame_localsplus(frame); PyObject *prev = localsplus[index]; - if (Py_IsTrue(unset)) { + if (PyObject_IsTrue(unset)) { localsplus[index] = NULL; } else { Py_INCREF(stack_obj); diff --git a/src/dispatch/experimental/durable/frame309.h b/src/dispatch/experimental/durable/frame309.h new file mode 100644 index 00000000..31de5fdf --- /dev/null +++ b/src/dispatch/experimental/durable/frame309.h @@ -0,0 +1,144 @@ +// This is a redefinition of the private/opaque frame object. +// https://github.com/python/cpython/blob/3.9/Include/cpython/frameobject.h#L17 +// +// In Python <= 3.10, `struct _frame` is both the PyFrameObject and +// PyInterpreterFrame. From Python 3.11 onwards, the two were split with the +// PyFrameObject (struct _frame) pointing to struct _PyInterpreterFrame. +struct Frame { + PyObject_VAR_HEAD + struct Frame *f_back; // struct _frame + PyCodeObject *f_code; + PyObject *f_builtins; + PyObject *f_globals; + PyObject *f_locals; + PyObject **f_valuestack; + PyObject **f_stacktop; + PyObject *f_trace; + char f_trace_lines; + char f_trace_opcodes; + PyObject *f_gen; + int f_lasti; + int f_lineno; + int f_iblock; + char f_executing; + PyTryBlock f_blockstack[CO_MAXBLOCKS]; + PyObject *f_localsplus[1]; +}; + +// Python 3.9 and prior didn't have an explicit enum of frame states, +// but we can derive them based on the presence of a frame, and other +// information found on the frame, for compatibility with later versions. +typedef enum _framestate { + FRAME_CREATED = -2, + FRAME_EXECUTING = 0, + FRAME_CLEARED = 4 +} FrameState; + +/* +// This is the definition of PyGenObject for reference to developers +// working on the extension. +// +// Note that PyCoroObject and PyAsyncGenObject have the same layout as +// PyGenObject, however the struct fields have a cr_ and ag_ prefix +// (respectively) rather than a gi_ prefix. In Python <= 3.10, PyCoroObject +// and PyAsyncGenObject have extra fields compared to PyGenObject. In Python +// 3.11 onwards, the three objects are identical (except for field name +// prefixes). The extra fields in Python <= 3.10 are not applicable to the +// extension at this time. +// +// https://github.com/python/cpython/blob/3.9/Include/genobject.h#L15 +typedef struct { + PyObject_HEAD + PyFrameObject *gi_frame; + char gi_running; + PyObject *gi_code; + PyObject *gi_weakreflist; + PyObject *gi_name; + PyObject *gi_qualname; + _PyErr_StackItem gi_exc_state; +} PyGenObject; +*/ + +static Frame *get_frame(PyGenObject *gen_like) { + Frame *frame = (Frame *)(gen_like->gi_frame); + assert(frame); + return frame; +} + +static PyCodeObject *get_frame_code(Frame *frame) { + PyCodeObject *code = frame->f_code; + assert(code); + return code; +} + +static int get_frame_lasti(Frame *frame) { + return frame->f_lasti; +} + +static void set_frame_lasti(Frame *frame, int lasti) { + frame->f_lasti = lasti; +} + +static int get_frame_state(PyGenObject *gen_like) { + // Python 3.9 doesn't have frame states, but we can derive + // some for compatibility with later versions and to simplify + // the extension. + Frame *frame = (Frame *)(gen_like->gi_frame); + if (!frame) { + return FRAME_CLEARED; + } + return frame->f_executing ? FRAME_EXECUTING : FRAME_CREATED; +} + +static void set_frame_state(PyGenObject *gen_like, int fs) { + Frame *frame = get_frame(gen_like); + frame->f_executing = (fs == FRAME_EXECUTING); +} + +static int valid_frame_state(int fs) { + return fs == FRAME_CREATED || fs == FRAME_EXECUTING || fs == FRAME_CLEARED; +} + +static int get_frame_stacktop_limit(Frame *frame) { + PyCodeObject *code = get_frame_code(frame); + return code->co_stacksize + code->co_nlocals; +} + +static int get_frame_stacktop(Frame *frame) { + assert(frame->f_localsplus); + int stacktop = (int)(frame->f_stacktop - frame->f_localsplus); + assert(stacktop >= 0 && stacktop < get_frame_stacktop_limit(frame)); + return stacktop; +} + +static void set_frame_stacktop(Frame *frame, int stacktop) { + assert(stacktop >= 0 && stacktop < get_frame_stacktop_limit(frame)); + assert(frame->f_localsplus); + frame->f_stacktop = frame->f_localsplus + stacktop; +} + +static PyObject **get_frame_localsplus(Frame *frame) { + PyObject **localsplus = frame->f_localsplus; + assert(localsplus); + return localsplus; +} + +static int get_frame_iblock_limit(Frame *frame) { + return CO_MAXBLOCKS; +} + +static int get_frame_iblock(Frame *frame) { + return frame->f_iblock; +} + +static void set_frame_iblock(Frame *frame, int iblock) { + assert(iblock >= 0 && iblock < get_frame_iblock_limit(frame)); + frame->f_iblock = iblock; +} + +static PyTryBlock *get_frame_blockstack(Frame *frame) { + PyTryBlock *blockstack = frame->f_blockstack; + assert(blockstack); + return blockstack; +} + From 09d390ee3d2ab4b90dadd0ab3560684f0d755465 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 12:43:15 +1000 Subject: [PATCH 2/9] Avoid match keyword (PEP 636) which wasn't added until Python 3.10 --- src/dispatch/function.py | 22 +++--- src/dispatch/integrations/http.py | 50 ++++++------ src/dispatch/integrations/httpx.py | 17 ++-- src/dispatch/integrations/openai.py | 9 +-- src/dispatch/integrations/requests.py | 17 ++-- src/dispatch/integrations/slack.py | 6 +- src/dispatch/scheduler.py | 108 +++++++++++--------------- src/dispatch/status.py | 46 ++++++----- src/dispatch/test/service.py | 13 ++-- 9 files changed, 135 insertions(+), 153 deletions(-) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 97f18176..0c8daae9 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -308,13 +308,12 @@ def __setstate__(self, state): def _init_stub(self): result = urlparse(self.api_url) - match result.scheme: - case "http": - creds = grpc.local_channel_credentials() - case "https": - creds = grpc.ssl_channel_credentials() - case _: - raise ValueError(f"Invalid API scheme: '{result.scheme}'") + if result.scheme == "http": + creds = grpc.local_channel_credentials() + elif result.scheme == "https": + creds = grpc.ssl_channel_credentials() + else: + raise ValueError(f"Invalid API scheme: '{result.scheme}'") call_creds = grpc.access_token_call_credentials(self.api_key) creds = grpc.composite_channel_credentials(creds, call_creds) @@ -344,11 +343,10 @@ def dispatch(self, calls: Iterable[Call]) -> list[DispatchID]: resp = self._stub.Dispatch(req) except grpc.RpcError as e: status_code = e.code() - match status_code: - case grpc.StatusCode.UNAUTHENTICATED: - raise PermissionError( - f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" - ) from e + if status_code == grpc.StatusCode.UNAUTHENTICATED: + raise PermissionError( + f"Dispatch received an invalid authentication token (check {self.api_key_from} is correct)" + ) from e raise dispatch_ids = [DispatchID(x) for x in resp.dispatch_ids] diff --git a/src/dispatch/integrations/http.py b/src/dispatch/integrations/http.py index 6846deac..19c3c263 100644 --- a/src/dispatch/integrations/http.py +++ b/src/dispatch/integrations/http.py @@ -4,33 +4,31 @@ def http_response_code_status(code: int) -> Status: """Returns a Status that's broadly equivalent to an HTTP response status code.""" - match code: - case 400: # Bad Request - return Status.INVALID_ARGUMENT - case 401: # Unauthorized - return Status.UNAUTHENTICATED - case 403: # Forbidden - return Status.PERMISSION_DENIED - case 404: # Not Found - return Status.NOT_FOUND - case 408: # Request Timeout - return Status.TIMEOUT - case 429: # Too Many Requests - return Status.THROTTLED - case 501: # Not Implemented - return Status.PERMANENT_ERROR + if code == 400: # Bad Request + return Status.INVALID_ARGUMENT + elif code == 401: # Unauthorized + return Status.UNAUTHENTICATED + elif code == 403: # Forbidden + return Status.PERMISSION_DENIED + elif code == 404: # Not Found + return Status.NOT_FOUND + elif code == 408: # Request Timeout + return Status.TIMEOUT + elif code == 429: # Too Many Requests + return Status.THROTTLED + elif code == 501: # Not Implemented + return Status.PERMANENT_ERROR category = code // 100 - match category: - case 1: # 1xx informational - return Status.PERMANENT_ERROR - case 2: # 2xx success - return Status.OK - case 3: # 3xx redirection - return Status.PERMANENT_ERROR - case 4: # 4xx client error - return Status.PERMANENT_ERROR - case 5: # 5xx server error - return Status.TEMPORARY_ERROR + if category == 1: # 1xx informational + return Status.PERMANENT_ERROR + elif category == 2: # 2xx success + return Status.OK + elif category == 3: # 3xx redirection + return Status.PERMANENT_ERROR + elif category == 4: # 4xx client error + return Status.PERMANENT_ERROR + elif category == 5: # 5xx server error + return Status.TEMPORARY_ERROR return Status.UNSPECIFIED diff --git a/src/dispatch/integrations/httpx.py b/src/dispatch/integrations/httpx.py index 3d60a65a..64a07588 100644 --- a/src/dispatch/integrations/httpx.py +++ b/src/dispatch/integrations/httpx.py @@ -6,15 +6,14 @@ def httpx_error_status(error: Exception) -> Status: # See https://www.python-httpx.org/exceptions/ - match error: - case httpx.HTTPStatusError(): - return httpx_response_status(error.response) - case httpx.InvalidURL(): - return Status.INVALID_ARGUMENT - case httpx.UnsupportedProtocol(): - return Status.INVALID_ARGUMENT - case httpx.TimeoutException(): - return Status.TIMEOUT + if isinstance(error, httpx.HTTPStatusError): + return httpx_response_status(error.response) + elif isinstance(error, httpx.InvalidURL): + return Status.INVALID_ARGUMENT + elif isinstance(error, httpx.UnsupportedProtocol): + return Status.INVALID_ARGUMENT + elif isinstance(error, httpx.TimeoutException): + return Status.TIMEOUT return Status.TEMPORARY_ERROR diff --git a/src/dispatch/integrations/openai.py b/src/dispatch/integrations/openai.py index 0d781fe4..533133d4 100644 --- a/src/dispatch/integrations/openai.py +++ b/src/dispatch/integrations/openai.py @@ -6,11 +6,10 @@ def openai_error_status(error: Exception) -> Status: # See https://github.com/openai/openai-python/blob/main/src/openai/_exceptions.py - match error: - case openai.APITimeoutError(): - return Status.TIMEOUT - case openai.APIStatusError(): - return http_response_code_status(error.status_code) + if isinstance(error, openai.APITimeoutError): + return Status.TIMEOUT + elif isinstance(error, openai.APIStatusError): + return http_response_code_status(error.status_code) return Status.TEMPORARY_ERROR diff --git a/src/dispatch/integrations/requests.py b/src/dispatch/integrations/requests.py index b61ed21c..89de804f 100644 --- a/src/dispatch/integrations/requests.py +++ b/src/dispatch/integrations/requests.py @@ -7,14 +7,15 @@ def requests_error_status(error: Exception) -> Status: # See https://requests.readthedocs.io/en/latest/api/#exceptions # and https://requests.readthedocs.io/en/latest/_modules/requests/exceptions/ - match error: - case requests.HTTPError(): - if error.response is not None: - return requests_response_status(error.response) - case requests.Timeout(): - return Status.TIMEOUT - case ValueError(): # base class of things like requests.InvalidURL, etc. - return Status.INVALID_ARGUMENT + if isinstance(error, requests.HTTPError): + if error.response is not None: + return requests_response_status(error.response) + elif isinstance(error, requests.Timeout): + return Status.TIMEOUT + elif isinstance( + error, ValueError + ): # base class of things like requests.InvalidURL, etc. + return Status.INVALID_ARGUMENT return Status.TEMPORARY_ERROR diff --git a/src/dispatch/integrations/slack.py b/src/dispatch/integrations/slack.py index 78040945..28e73718 100644 --- a/src/dispatch/integrations/slack.py +++ b/src/dispatch/integrations/slack.py @@ -8,10 +8,8 @@ def slack_error_status(error: Exception) -> Status: # See https://github.com/slackapi/python-slack-sdk/blob/main/slack/errors.py - match error: - case slack_sdk.errors.SlackApiError(): - if error.response is not None: - return slack_response_status(error.response) + if isinstance(error, slack_sdk.errors.SlackApiError) and error.response is not None: + return slack_response_status(error.response) return Status.TEMPORARY_ERROR diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 89c4d7af..1e46a5fe 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -162,13 +162,12 @@ def error(self) -> Exception | None: return self.generic_error if self.first_result is not None or len(self.errors) == 0: return None - match len(self.errors): - case 0: - return None - case 1: - return self.errors[self.order[0]] - case _: - return AnyException([self.errors[id] for id in self.order]) + if len(self.errors) == 0: + return None + elif len(self.errors) == 1: + return self.errors[self.order[0]] + else: + return AnyException([self.errors[id] for id in self.order]) def value(self) -> Any: assert self.ready() @@ -470,60 +469,47 @@ def _run(self, input: Input) -> Output: # Handle coroutines that yield. logger.debug("%s yielded %s", coroutine, coroutine_yield) - match coroutine_yield: - case Call(): - call = coroutine_yield - call_id = state.next_call_id - state.next_call_id += 1 - call.correlation_id = correlation_id(coroutine.id, call_id) - logger.debug( - "enqueuing call %d (%s) for %s", - call_id, - call.function, - coroutine, - ) - pending_calls.append(call) - coroutine.result = CallFuture() - state.suspended[coroutine.id] = coroutine - state.prev_callers.append(coroutine) - state.outstanding_calls += 1 - - case AllDirective(): - children = spawn_children( - state, coroutine, coroutine_yield.awaitables - ) - - child_ids = [child.id for child in children] - coroutine.result = AllFuture( - order=child_ids, waiting=set(child_ids) - ) - state.suspended[coroutine.id] = coroutine - - case AnyDirective(): - children = spawn_children( - state, coroutine, coroutine_yield.awaitables - ) - - child_ids = [child.id for child in children] - coroutine.result = AnyFuture( - order=child_ids, waiting=set(child_ids) - ) - state.suspended[coroutine.id] = coroutine - - case RaceDirective(): - children = spawn_children( - state, coroutine, coroutine_yield.awaitables - ) - - coroutine.result = RaceFuture( - waiting={child.id for child in children} - ) - state.suspended[coroutine.id] = coroutine - - case _: - raise RuntimeError( - f"coroutine unexpectedly yielded '{coroutine_yield}'" - ) + if isinstance(coroutine_yield, Call): + call = coroutine_yield + call_id = state.next_call_id + state.next_call_id += 1 + call.correlation_id = correlation_id(coroutine.id, call_id) + logger.debug( + "enqueuing call %d (%s) for %s", + call_id, + call.function, + coroutine, + ) + pending_calls.append(call) + coroutine.result = CallFuture() + state.suspended[coroutine.id] = coroutine + state.prev_callers.append(coroutine) + state.outstanding_calls += 1 + + elif isinstance(coroutine_yield, AllDirective): + children = spawn_children(state, coroutine, coroutine_yield.awaitables) + + child_ids = [child.id for child in children] + coroutine.result = AllFuture(order=child_ids, waiting=set(child_ids)) + state.suspended[coroutine.id] = coroutine + + elif isinstance(coroutine_yield, AnyDirective): + children = spawn_children(state, coroutine, coroutine_yield.awaitables) + + child_ids = [child.id for child in children] + coroutine.result = AnyFuture(order=child_ids, waiting=set(child_ids)) + state.suspended[coroutine.id] = coroutine + + elif isinstance(coroutine_yield, RaceDirective): + children = spawn_children(state, coroutine, coroutine_yield.awaitables) + + coroutine.result = RaceFuture(waiting={child.id for child in children}) + state.suspended[coroutine.id] = coroutine + + else: + raise RuntimeError( + f"coroutine unexpectedly yielded '{coroutine_yield}'" + ) # Serialize coroutines and scheduler state. logger.debug("serializing state") diff --git a/src/dispatch/status.py b/src/dispatch/status.py index 5a413802..1a8f34d2 100644 --- a/src/dispatch/status.py +++ b/src/dispatch/status.py @@ -92,27 +92,31 @@ def status_for_error(error: Exception) -> Status: # If not, resort to standard error categorization. # # See https://docs.python.org/3/library/exceptions.html - match error: - case IncompatibleStateError(): - return Status.INCOMPATIBLE_STATE - case TimeoutError(): - return Status.TIMEOUT - case TypeError() | ValueError(): - return Status.INVALID_ARGUMENT - case ConnectionError(): - return Status.TCP_ERROR - case PermissionError(): - return Status.PERMISSION_DENIED - case FileNotFoundError(): - return Status.NOT_FOUND - case EOFError() | InterruptedError() | KeyboardInterrupt() | OSError(): - # For OSError, we might want to categorize the values of errnon - # to determine whether the error is temporary or permanent. - # - # In general, permanent errors from the OS are rare because they - # tend to be caused by invalid use of syscalls, which are - # unlikely at higher abstraction levels. - return Status.TEMPORARY_ERROR + if isinstance(error, IncompatibleStateError): + return Status.INCOMPATIBLE_STATE + elif isinstance(error, TimeoutError): + return Status.TIMEOUT + elif isinstance(error, TypeError) or isinstance(error, ValueError): + return Status.INVALID_ARGUMENT + elif isinstance(error, ConnectionError): + return Status.TCP_ERROR + elif isinstance(error, PermissionError): + return Status.PERMISSION_DENIED + elif isinstance(error, FileNotFoundError): + return Status.NOT_FOUND + elif ( + isinstance(error, EOFError) + or isinstance(error, InterruptedError) + or isinstance(error, KeyboardInterrupt) + or isinstance(error, OSError) + ): + # For OSError, we might want to categorize the values of errnon + # to determine whether the error is temporary or permanent. + # + # In general, permanent errors from the OS are rare because they + # tend to be caused by invalid use of syscalls, which are + # unlikely at higher abstraction levels. + return Status.TEMPORARY_ERROR return Status.PERMANENT_ERROR diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py index 09729396..edcf6759 100644 --- a/src/dispatch/test/service.py +++ b/src/dispatch/test/service.py @@ -143,13 +143,12 @@ def dispatch_calls(self): while self.queue: dispatch_id, request, call_type = self.queue.pop(0) - match call_type: - case CallType.CALL: - logger.info("calling function %s", request.function) - case CallType.RESUME: - logger.info("resuming function %s", request.function) - case CallType.RETRY: - logger.info("retrying function %s", request.function) + if call_type == CallType.CALL: + logger.info("calling function %s", request.function) + elif call_type == CallType.RESUME: + logger.info("resuming function %s", request.function) + elif call_type == CallType.RETRY: + logger.info("retrying function %s", request.function) try: response = self.endpoint_client.run(request) From 52811f939451811e6e91accde1cf7ba199c28afb Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 12:43:21 +1000 Subject: [PATCH 3/9] Add 3.9 to test matrix --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1acf51d1..c5fc2f47 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ['3.10', '3.11', '3.12'] + python: ['3.9', '3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python }} From 5788dcf213c2e9102781cbbee10e8ce7e2150bb5 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 12:45:13 +1000 Subject: [PATCH 4/9] Avoid @dataclass(slots=True) which wasn't added until Python 3.10 --- src/dispatch/coroutine.py | 6 +++--- src/dispatch/experimental/durable/registry.py | 2 +- src/dispatch/proto.py | 10 +++++----- src/dispatch/scheduler.py | 16 ++++++++-------- src/dispatch/signature/key.py | 2 +- src/dispatch/signature/request.py | 2 +- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index cf9d4c93..79701a73 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -48,17 +48,17 @@ def race(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc] return (yield RaceDirective(awaitables)) -@dataclass(slots=True) +@dataclass class AllDirective: awaitables: tuple[Awaitable[Any], ...] -@dataclass(slots=True) +@dataclass class AnyDirective: awaitables: tuple[Awaitable[Any], ...] -@dataclass(slots=True) +@dataclass class RaceDirective: awaitables: tuple[Awaitable[Any], ...] diff --git a/src/dispatch/experimental/durable/registry.py b/src/dispatch/experimental/durable/registry.py index 3a5d9765..da8f2c28 100644 --- a/src/dispatch/experimental/durable/registry.py +++ b/src/dispatch/experimental/durable/registry.py @@ -3,7 +3,7 @@ from types import FunctionType -@dataclass(slots=True) +@dataclass class RegisteredFunction: """A function that can be referenced in durable state.""" diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 3e0d90e0..3782108a 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -139,7 +139,7 @@ def from_poll_results( ) -@dataclass(slots=True) +@dataclass class Arguments: """A container for positional and keyword arguments.""" @@ -147,7 +147,7 @@ class Arguments: kwargs: dict[str, Any] -@dataclass(slots=True) +@dataclass class Output: """The output of a primitive function. @@ -240,7 +240,7 @@ def poll( # the current Python process. -@dataclass(slots=True) +@dataclass class Call: """Instruction to call a function. @@ -263,7 +263,7 @@ def _as_proto(self) -> call_pb.Call: ) -@dataclass(slots=True) +@dataclass class CallResult: """Result of a Call.""" @@ -305,7 +305,7 @@ def from_error(cls, error: Error, correlation_id: int | None = None) -> CallResu return CallResult(correlation_id=correlation_id, error=error) -@dataclass(slots=True) +@dataclass class Error: """Error when running a function. diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 1e46a5fe..c97bce94 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -17,7 +17,7 @@ CorrelationID: TypeAlias = int -@dataclass(slots=True) +@dataclass class CoroutineResult: """The result from running a coroutine to completion.""" @@ -26,7 +26,7 @@ class CoroutineResult: error: Exception | None = None -@dataclass(slots=True) +@dataclass class CallResult: """The result of an asynchronous function call.""" @@ -47,7 +47,7 @@ def error(self) -> Exception | None: ... def value(self) -> Any: ... -@dataclass(slots=True) +@dataclass class CallFuture: """A future result of a dispatch.coroutine.call() operation.""" @@ -78,7 +78,7 @@ def value(self) -> Any: return self.result.value -@dataclass(slots=True) +@dataclass class AllFuture: """A future result of a dispatch.coroutine.all() operation.""" @@ -120,7 +120,7 @@ def value(self) -> list[Any]: return [self.results[id].value for id in self.order] -@dataclass(slots=True) +@dataclass class AnyFuture: """A future result of a dispatch.coroutine.any() operation.""" @@ -177,7 +177,7 @@ def value(self) -> Any: return self.first_result.value -@dataclass(slots=True) +@dataclass class RaceFuture: """A future result of a dispatch.coroutine.race() operation.""" @@ -217,7 +217,7 @@ def value(self) -> Any: return self.first_result.value if self.first_result else None -@dataclass(slots=True) +@dataclass class Coroutine: """An in-flight coroutine.""" @@ -241,7 +241,7 @@ def __repr__(self): return f"Coroutine({self.id}, {self.coroutine.__qualname__})" -@dataclass(slots=True) +@dataclass class State: """State of the scheduler and the coroutines it's managing.""" diff --git a/src/dispatch/signature/key.py b/src/dispatch/signature/key.py index 7fc1cee0..b13b4563 100644 --- a/src/dispatch/signature/key.py +++ b/src/dispatch/signature/key.py @@ -48,7 +48,7 @@ def private_key_from_bytes(key: bytes) -> Ed25519PrivateKey: return Ed25519PrivateKey.from_private_bytes(key) -@dataclass(slots=True) +@dataclass class KeyResolver(HTTPSignatureKeyResolver): """KeyResolver provides public and private keys. diff --git a/src/dispatch/signature/request.py b/src/dispatch/signature/request.py index 15ac24c9..256ed24a 100644 --- a/src/dispatch/signature/request.py +++ b/src/dispatch/signature/request.py @@ -3,7 +3,7 @@ from http_message_signatures.structures import CaseInsensitiveDict -@dataclass(slots=True) +@dataclass class Request: """A framework-agnostic representation of an HTTP request.""" From 9c5bda456b070bf66ee941b027f5b64f9947b580 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 13:08:54 +1000 Subject: [PATCH 5/9] Avoid type|type (PEP 604) which wasn't added until Python 3.10 --- pyproject.toml | 3 +- src/dispatch/experimental/durable/frame.pyi | 26 ++++----- src/dispatch/experimental/durable/function.py | 18 +++--- src/dispatch/fastapi.py | 15 ++--- src/dispatch/function.py | 12 ++-- src/dispatch/id.py | 2 +- src/dispatch/proto.py | 42 +++++++------- src/dispatch/scheduler.py | 55 ++++++++++--------- src/dispatch/signature/digest.py | 5 +- src/dispatch/signature/key.py | 9 +-- src/dispatch/signature/request.py | 3 +- src/dispatch/test/client.py | 11 ++-- src/dispatch/test/service.py | 9 +-- tests/dispatch/test_scheduler.py | 6 +- 14 files changed, 112 insertions(+), 104 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 125732bf..95ab6ed6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ "tblib >= 3.0.0", "docopt >= 0.6.2", "types-docopt >= 0.6.11.4", - "httpx >= 0.27.0" + "httpx >= 0.27.0", + "typing_extensions >= 4.10" ] [project.optional-dependencies] diff --git a/src/dispatch/experimental/durable/frame.pyi b/src/dispatch/experimental/durable/frame.pyi index e701afd0..d20115d0 100644 --- a/src/dispatch/experimental/durable/frame.pyi +++ b/src/dispatch/experimental/durable/frame.pyi @@ -1,31 +1,31 @@ from types import FrameType -from typing import Any, AsyncGenerator, Coroutine, Generator, Tuple +from typing import Any, AsyncGenerator, Coroutine, Generator, Tuple, Union -def get_frame_ip(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int: +def get_frame_ip(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get instruction pointer of a generator or coroutine.""" -def set_frame_ip(frame: FrameType | Coroutine | Generator | AsyncGenerator, ip: int): +def set_frame_ip(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], ip: int): """Set instruction pointer of a generator or coroutine.""" -def get_frame_sp(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int: +def get_frame_sp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get stack pointer of a generator or coroutine.""" -def set_frame_sp(frame: FrameType | Coroutine | Generator | AsyncGenerator, sp: int): +def set_frame_sp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], sp: int): """Set stack pointer of a generator or coroutine.""" -def get_frame_bp(frame: FrameType | Coroutine | Generator | AsyncGenerator) -> int: +def get_frame_bp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get block pointer of a generator or coroutine.""" -def set_frame_bp(frame: FrameType | Coroutine | Generator | AsyncGenerator, bp: int): +def set_frame_bp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], bp: int): """Set block pointer of a generator or coroutine.""" def get_frame_stack_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, index: int + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int ) -> Tuple[bool, Any]: """Get an object from a generator or coroutine's stack, as an (is_null, obj) tuple.""" def set_frame_stack_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int, unset: bool, value: Any, @@ -33,23 +33,23 @@ def set_frame_stack_at( """Set or unset an object on the stack of a generator or coroutine.""" def get_frame_block_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, index: int + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int ) -> Tuple[int, int, int]: """Get a block from a generator or coroutine.""" def set_frame_block_at( - frame: FrameType | Coroutine | Generator | AsyncGenerator, + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], index: int, value: Tuple[int, int, int], ): """Restore a block of a generator or coroutine.""" def get_frame_state( - frame: FrameType | Coroutine | Generator | AsyncGenerator, + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], ) -> int: """Get frame state of a generator or coroutine.""" def set_frame_state( - frame: FrameType | Coroutine | Generator | AsyncGenerator, state: int + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], state: int ): """Set frame state of a generator or coroutine.""" diff --git a/src/dispatch/experimental/durable/function.py b/src/dispatch/experimental/durable/function.py index a740be41..3c721540 100644 --- a/src/dispatch/experimental/durable/function.py +++ b/src/dispatch/experimental/durable/function.py @@ -9,7 +9,7 @@ MethodType, TracebackType, ) -from typing import Any, Callable, Coroutine, Generator, TypeVar, Union, cast +from typing import Any, Callable, Coroutine, Generator, TypeVar, Union, cast, Optional from . import frame as ext from .registry import RegisteredFunction, lookup_function, register_function @@ -75,7 +75,7 @@ class Serializable: "__qualname__", ) - g: GeneratorType | CoroutineType + g: Union[GeneratorType, CoroutineType] registered_fn: RegisteredFunction wrapped_coroutine: Union["DurableCoroutine", None] args: tuple[Any, ...] @@ -83,7 +83,7 @@ class Serializable: def __init__( self, - g: GeneratorType | CoroutineType, + g: Union[GeneratorType, CoroutineType], registered_fn: RegisteredFunction, wrapped_coroutine: Union["DurableCoroutine", None], *args: Any, @@ -243,7 +243,7 @@ def __await__(self) -> Generator[Any, None, _ReturnT]: def send(self, send: _SendT) -> _YieldT: return self.coroutine.send(send) - def throw(self, typ, val=None, tb: TracebackType | None = None) -> _YieldT: + def throw(self, typ, val=None, tb: Optional[TracebackType] = None) -> _YieldT: return self.coroutine.throw(typ, val, tb) def close(self) -> None: @@ -270,11 +270,11 @@ def cr_frame(self) -> FrameType: return self.coroutine.cr_frame @property - def cr_await(self) -> Any | None: + def cr_await(self) -> Any: return self.coroutine.cr_await @property - def cr_origin(self) -> tuple[tuple[str, int, str], ...] | None: + def cr_origin(self) -> Optional[tuple[tuple[str, int, str], ...]]: return self.coroutine.cr_origin def __repr__(self) -> str: @@ -291,7 +291,7 @@ def __init__( self, generator: GeneratorType, registered_fn: RegisteredFunction, - coroutine: DurableCoroutine | None, + coroutine: Optional[DurableCoroutine], *args: Any, **kwargs: Any, ): @@ -309,7 +309,7 @@ def __next__(self) -> _YieldT: def send(self, send: _SendT) -> _YieldT: return self.generator.send(send) - def throw(self, typ, val=None, tb: TracebackType | None = None) -> _YieldT: + def throw(self, typ, val=None, tb: Optional[TracebackType] = None) -> _YieldT: return self.generator.throw(typ, val, tb) def close(self) -> None: @@ -336,7 +336,7 @@ def gi_frame(self) -> FrameType: return self.generator.gi_frame @property - def gi_yieldfrom(self) -> GeneratorType | None: + def gi_yieldfrom(self) -> Optional[GeneratorType]: return self.generator.gi_yieldfrom def __repr__(self) -> str: diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 239c3b36..9e7c450c 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -22,6 +22,7 @@ def read_root(): import os from datetime import timedelta from urllib.parse import urlparse +from typing import Optional, Union import fastapi import fastapi.responses @@ -51,10 +52,10 @@ class Dispatch(Registry): def __init__( self, app: fastapi.FastAPI, - endpoint: str | None = None, - verification_key: Ed25519PublicKey | str | bytes | None = None, - api_key: str | None = None, - api_url: str | None = None, + endpoint: Optional[str] = None, + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]] = None, + api_key: Optional[str] = None, + api_url: Optional[str] = None, ): """Initialize a Dispatch endpoint, and integrate it into a FastAPI app. @@ -122,8 +123,8 @@ def __init__( def parse_verification_key( - verification_key: Ed25519PublicKey | str | bytes | None, -) -> Ed25519PublicKey | None: + verification_key: Optional[Union[Ed25519PublicKey, str, bytes]], +) -> Optional[Ed25519PublicKey]: if isinstance(verification_key, Ed25519PublicKey): return verification_key @@ -169,7 +170,7 @@ def __init__(self, status, code, message): self.message = message -def _new_app(function_registry: Dispatch, verification_key: Ed25519PublicKey | None): +def _new_app(function_registry: Dispatch, verification_key: Optional[Ed25519PublicKey]): app = fastapi.FastAPI() @app.exception_handler(_ConnectError) diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 0c8daae9..5de267e2 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -12,11 +12,11 @@ Dict, Generic, Iterable, - ParamSpec, - TypeAlias, TypeVar, + Optional, overload, ) +from typing_extensions import ParamSpec, TypeAlias from urllib.parse import urlparse import grpc @@ -73,7 +73,7 @@ def _primitive_dispatch(self, input: Any = None) -> DispatchID: return dispatch_id def _build_primitive_call( - self, input: Any, correlation_id: int | None = None + self, input: Any, correlation_id: Optional[int] = None ) -> Call: return Call( correlation_id=correlation_id, @@ -137,7 +137,7 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID: return self._primitive_dispatch(Arguments(args, kwargs)) def build_call( - self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs + self, *args: P.args, correlation_id: Optional[int] = None, **kwargs: P.kwargs ) -> Call: """Create a Call for this function with the provided input. Useful to generate calls when using the Client. @@ -162,7 +162,7 @@ class Registry: __slots__ = ("functions", "endpoint", "client") def __init__( - self, endpoint: str, api_key: str | None = None, api_url: str | None = None + self, endpoint: str, api_key: Optional[str] = None, api_url: Optional[str] = None ): """Initialize a function registry. @@ -261,7 +261,7 @@ class Client: __slots__ = ("api_url", "api_key", "_stub", "api_key_from") - def __init__(self, api_key: None | str = None, api_url: None | str = None): + def __init__(self, api_key: Optional[str] = None, api_url: Optional[str] = None): """Create a new Dispatch client. Args: diff --git a/src/dispatch/id.py b/src/dispatch/id.py index ee3cce2a..d5f669be 100644 --- a/src/dispatch/id.py +++ b/src/dispatch/id.py @@ -1,4 +1,4 @@ -from typing import TypeAlias +from typing_extensions import TypeAlias DispatchID: TypeAlias = str """Unique identifier in Dispatch. diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 3782108a..14a71724 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from traceback import format_exception from types import TracebackType -from typing import Any +from typing import Any, Optional import google.protobuf.any_pb2 import google.protobuf.message @@ -97,7 +97,7 @@ def call_results(self) -> list[CallResult]: return self._call_results @property - def poll_error(self) -> Error | None: + def poll_error(self) -> Optional[Error]: self._assert_resume() return self._poll_error @@ -125,7 +125,7 @@ def from_poll_results( function: str, coroutine_state: Any, call_results: list[CallResult], - error: Error | None = None, + error: Optional[Error] = None, ): return Input( req=function_pb.RunRequest( @@ -163,7 +163,7 @@ def __init__(self, proto: function_pb.RunResponse): self._message = proto @classmethod - def value(cls, value: Any, status: Status | None = None) -> Output: + def value(cls, value: Any, status: Optional[Status] = None) -> Output: """Terminally exit the function with the provided return value.""" if status is None: status = status_for_output(value) @@ -183,8 +183,8 @@ def tail_call(cls, tail_call: Call) -> Output: @classmethod def exit( cls, - result: CallResult | None = None, - tail_call: Call | None = None, + result: Optional[CallResult] = None, + tail_call: Optional[Call] = None, status: Status = Status.OK, ) -> Output: """Terminally exit the function.""" @@ -201,10 +201,10 @@ def exit( def poll( cls, state: Any, - calls: None | list[Call] = None, + calls: Optional[list[Call]] = None, min_results: int = 1, max_results: int = 10, - max_wait_seconds: int | None = None, + max_wait_seconds: Optional[int] = None, ) -> Output: """Suspend the function with a set of Calls, instructing the orchestrator to resume the function with the provided state when @@ -249,9 +249,9 @@ class Call: """ function: str - input: Any | None = None - endpoint: str | None = None - correlation_id: int | None = None + input: Optional[Any] = None + endpoint: Optional[str] = None + correlation_id: Optional[int] = None def _as_proto(self) -> call_pb.Call: input_bytes = _pb_any_pickle(self.input) @@ -267,9 +267,9 @@ def _as_proto(self) -> call_pb.Call: class CallResult: """Result of a Call.""" - correlation_id: int | None = None - output: Any | None = None - error: Error | None = None + correlation_id: Optional[int] = None + output: Optional[Any] = None + error: Optional[Error] = None def _as_proto(self) -> call_pb.CallResult: output_any = None @@ -297,11 +297,11 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult: ) @classmethod - def from_value(cls, output: Any, correlation_id: int | None = None) -> CallResult: + def from_value(cls, output: Any, correlation_id: Optional[int] = None) -> CallResult: return CallResult(correlation_id=correlation_id, output=output) @classmethod - def from_error(cls, error: Error, correlation_id: int | None = None) -> CallResult: + def from_error(cls, error: Error, correlation_id: Optional[int] = None) -> CallResult: return CallResult(correlation_id=correlation_id, error=error) @@ -316,16 +316,16 @@ class Error: status: Status type: str message: str - value: Exception | None = None - traceback: bytes | None = None + value: Optional[Exception] = None + traceback: Optional[bytes] = None def __init__( self, status: Status, type: str, message: str, - value: Exception | None = None, - traceback: bytes | None = None, + value: Optional[Exception] = None, + traceback: Optional[bytes] = None, ): """Create a new Error. @@ -355,7 +355,7 @@ def __init__( self.traceback = "".join(format_exception(value)).encode("utf-8") @classmethod - def from_exception(cls, ex: Exception, status: Status | None = None) -> Error: + def from_exception(cls, ex: Exception, status: Optional[Status] = None) -> Error: """Create an Error from a Python exception, using its class qualified named as type. diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index c97bce94..bd86a441 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -2,7 +2,8 @@ import pickle import sys from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Protocol, TypeAlias +from typing import Any, Awaitable, Callable, Protocol, Optional, Union +from typing_extensions import TypeAlias from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective from dispatch.error import IncompatibleStateError @@ -22,8 +23,8 @@ class CoroutineResult: """The result from running a coroutine to completion.""" coroutine_id: CoroutineID - value: Any | None = None - error: Exception | None = None + value: Optional[Any] = None + error: Optional[Exception] = None @dataclass @@ -31,18 +32,18 @@ class CallResult: """The result of an asynchronous function call.""" call_id: CallID - value: Any | None = None - error: Exception | None = None + value: Optional[Any] = None + error: Optional[Exception] = None class Future(Protocol): - def add_result(self, result: CallResult | CoroutineResult): ... + def add_result(self, result: Union[CallResult, CoroutineResult]): ... def add_error(self, error: Exception): ... def ready(self) -> bool: ... - def error(self) -> Exception | None: ... + def error(self) -> Optional[Exception]: ... def value(self) -> Any: ... @@ -51,10 +52,10 @@ def value(self) -> Any: ... class CallFuture: """A future result of a dispatch.coroutine.call() operation.""" - result: CallResult | None = None - first_error: Exception | None = None + result: Optional[CallResult] = None + first_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CallResult) if self.result is None: self.result = result @@ -68,7 +69,7 @@ def add_error(self, error: Exception): def ready(self) -> bool: return self.first_error is not None or self.result is not None - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() return self.first_error @@ -85,9 +86,9 @@ class AllFuture: order: list[CoroutineID] = field(default_factory=list) waiting: set[CoroutineID] = field(default_factory=set) results: dict[CoroutineID, CoroutineResult] = field(default_factory=dict) - first_error: Exception | None = None + first_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CoroutineResult) try: @@ -109,7 +110,7 @@ def add_error(self, error: Exception): def ready(self) -> bool: return self.first_error is not None or len(self.waiting) == 0 - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() return self.first_error @@ -126,11 +127,11 @@ class AnyFuture: order: list[CoroutineID] = field(default_factory=list) waiting: set[CoroutineID] = field(default_factory=set) - first_result: CoroutineResult | None = None + first_result: Optional[CoroutineResult] = None errors: dict[CoroutineID, Exception] = field(default_factory=dict) - generic_error: Exception | None = None + generic_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CoroutineResult) try: @@ -156,7 +157,7 @@ def ready(self) -> bool: or len(self.waiting) == 0 ) - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() if self.generic_error is not None: return self.generic_error @@ -182,10 +183,10 @@ class RaceFuture: """A future result of a dispatch.coroutine.race() operation.""" waiting: set[CoroutineID] = field(default_factory=set) - first_result: CoroutineResult | None = None - first_error: Exception | None = None + first_result: Optional[CoroutineResult] = None + first_error: Optional[Exception] = None - def add_result(self, result: CallResult | CoroutineResult): + def add_result(self, result: Union[CallResult, CoroutineResult]): assert isinstance(result, CoroutineResult) if result.error is not None: @@ -208,7 +209,7 @@ def ready(self) -> bool: or len(self.waiting) == 0 ) - def error(self) -> Exception | None: + def error(self) -> Optional[Exception]: assert self.ready() return self.first_error @@ -222,9 +223,9 @@ class Coroutine: """An in-flight coroutine.""" id: CoroutineID - parent_id: CoroutineID | None - coroutine: DurableCoroutine | DurableGenerator - result: Future | None = None + parent_id: Optional[CoroutineID] + coroutine: Union[DurableCoroutine, DurableGenerator] + result: Optional[Future] = None def run(self) -> Any: if self.result is None: @@ -278,7 +279,7 @@ def __init__( version: str = sys.version, poll_min_results: int = 1, poll_max_results: int = 10, - poll_max_wait_seconds: int | None = None, + poll_max_wait_seconds: Optional[int] = None, ): """Initialize the scheduler. @@ -422,7 +423,7 @@ def _run(self, input: Input) -> Output: assert coroutine.id not in state.suspended coroutine_yield = None - coroutine_result: CoroutineResult | None = None + coroutine_result: Optional[CoroutineResult] = None try: coroutine_yield = coroutine.run() except StopIteration as e: diff --git a/src/dispatch/signature/digest.py b/src/dispatch/signature/digest.py index 4bc264a3..c5602126 100644 --- a/src/dispatch/signature/digest.py +++ b/src/dispatch/signature/digest.py @@ -1,11 +1,12 @@ import hashlib import hmac +from typing import Union import http_sfv from http_message_signatures import InvalidSignature -def generate_content_digest(body: str | bytes) -> str: +def generate_content_digest(body: Union[str, bytes]) -> str: """Returns a SHA-512 Content-Digest header, according to https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-digest-headers-13 """ @@ -16,7 +17,7 @@ def generate_content_digest(body: str | bytes) -> str: return str(http_sfv.Dictionary({"sha-512": digest})) -def verify_content_digest(digest_header: str | bytes, body: str | bytes): +def verify_content_digest(digest_header: Union[str, bytes], body: Union[str, bytes]): """Verify a SHA-256 or SHA-512 Content-Digest header matches a request body.""" if isinstance(body, str): diff --git a/src/dispatch/signature/key.py b/src/dispatch/signature/key.py index b13b4563..84b44e9e 100644 --- a/src/dispatch/signature/key.py +++ b/src/dispatch/signature/key.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Union, Optional from cryptography.hazmat.primitives.asymmetric.ed25519 import ( Ed25519PrivateKey, @@ -11,7 +12,7 @@ from http_message_signatures import HTTPSignatureKeyResolver -def public_key_from_pem(pem: str | bytes) -> Ed25519PublicKey: +def public_key_from_pem(pem: Union[str, bytes]) -> Ed25519PublicKey: """Returns an Ed25519 public key given a PEM representation.""" if isinstance(pem, str): pem = pem.encode() @@ -28,7 +29,7 @@ def public_key_from_bytes(key: bytes) -> Ed25519PublicKey: def private_key_from_pem( - pem: str | bytes, password: bytes | None = None + pem: Union[str, bytes], password: Optional[bytes] = None ) -> Ed25519PrivateKey: """Returns an Ed25519 private key given a PEM representation and optional password.""" @@ -57,8 +58,8 @@ class KeyResolver(HTTPSignatureKeyResolver): """ key_id: str - public_key: Ed25519PublicKey | None = None - private_key: Ed25519PrivateKey | None = None + public_key: Optional[Ed25519PublicKey] = None + private_key: Optional[Ed25519PrivateKey] = None def resolve_public_key(self, key_id: str): if key_id != self.key_id or self.public_key is None: diff --git a/src/dispatch/signature/request.py b/src/dispatch/signature/request.py index 256ed24a..ee6f13fc 100644 --- a/src/dispatch/signature/request.py +++ b/src/dispatch/signature/request.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Union from http_message_signatures.structures import CaseInsensitiveDict @@ -10,4 +11,4 @@ class Request: method: str url: str headers: CaseInsensitiveDict - body: str | bytes + body: Union[str, bytes] diff --git a/src/dispatch/test/client.py b/src/dispatch/test/client.py index 0b3f4539..01078aec 100644 --- a/src/dispatch/test/client.py +++ b/src/dispatch/test/client.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional import fastapi import grpc @@ -25,7 +26,7 @@ class EndpointClient: """ def __init__( - self, http_client: httpx.Client, signing_key: Ed25519PrivateKey | None = None + self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None ): """Initialize the client. @@ -48,14 +49,14 @@ def run(self, request: function_pb.RunRequest) -> function_pb.RunResponse: return self._stub.Run(request) @classmethod - def from_url(cls, url: str, signing_key: Ed25519PrivateKey | None = None): + def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None): """Returns an EndpointClient for a Dispatch endpoint URL.""" http_client = httpx.Client(base_url=url) return EndpointClient(http_client, signing_key) @classmethod def from_app( - cls, app: fastapi.FastAPI, signing_key: Ed25519PrivateKey | None = None + cls, app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None ): """Returns an EndpointClient for a Dispatch endpoint bound to a FastAPI app instance.""" @@ -65,7 +66,7 @@ def from_app( class _HttpxGrpcChannel(grpc.Channel): def __init__( - self, http_client: httpx.Client, signing_key: Ed25519PrivateKey | None = None + self, http_client: httpx.Client, signing_key: Optional[Ed25519PrivateKey] = None ): self.http_client = http_client self.signing_key = signing_key @@ -113,7 +114,7 @@ def __init__( method, request_serializer, response_deserializer, - signing_key: Ed25519PrivateKey | None = None, + signing_key: Optional[Ed25519PrivateKey] = None, ): self.client = client self.method = method diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py index edcf6759..39367658 100644 --- a/src/dispatch/test/service.py +++ b/src/dispatch/test/service.py @@ -5,7 +5,8 @@ import time from collections import OrderedDict from dataclasses import dataclass -from typing import TypeAlias +from typing import Optional +from typing_extensions import TypeAlias import grpc import httpx @@ -52,8 +53,8 @@ class DispatchService(dispatch_grpc.DispatchServiceServicer): def __init__( self, endpoint_client: EndpointClient, - api_key: str | None = None, - retry_on_status: set[Status] | None = None, + api_key: Optional[str] = None, + retry_on_status: Optional[set[Status]] = None, collect_roundtrips: bool = False, ): """Initialize the Dispatch service. @@ -90,7 +91,7 @@ def __init__( if collect_roundtrips: self.roundtrips = OrderedDict() - self._thread: threading.Thread | None = None + self._thread: threading.Optional[Thread] = None self._stop_event = threading.Event() self._work_signal = threading.Condition() diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index 2bfc079a..c5189de2 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -1,5 +1,5 @@ import unittest -from typing import Any, Callable +from typing import Any, Callable, Optional from dispatch.coroutine import AnyException, any, call, gather, race from dispatch.experimental.durable import durable @@ -414,7 +414,7 @@ def resume( main: Callable, prev_output: Output, call_results: list[CallResult], - poll_error: Exception | None = None, + poll_error: Optional[Exception] = None, ): poll = self.assert_poll(prev_output) input = Input.from_poll_results( @@ -444,7 +444,7 @@ def assert_exit_result_value(self, output: Output, expect: Any): self.assertEqual(expect, any_unpickle(result.output)) def assert_exit_result_error( - self, output: Output, expect: type[Exception], message: str | None = None + self, output: Output, expect: type[Exception], message: Optional[str] = None ): result = self.assert_exit_result(output) self.assertFalse(result.HasField("output")) From 171ff38ab376efde557a4c923f2477f863c1ee29 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 13:17:07 +1000 Subject: [PATCH 6/9] Remove async generator which isn't applicable the SDK for now --- .../experimental/durable/test_frame.py | 41 ------------------- 1 file changed, 41 deletions(-) diff --git a/tests/dispatch/experimental/durable/test_frame.py b/tests/dispatch/experimental/durable/test_frame.py index b4e9c2c1..ea2f6287 100644 --- a/tests/dispatch/experimental/durable/test_frame.py +++ b/tests/dispatch/experimental/durable/test_frame.py @@ -28,14 +28,6 @@ async def coroutine(a): await Yields(a) -async def async_generator(a): - await Yields(a) - a += 1 - yield a - a += 1 - await Yields(a) - - class TestFrame(unittest.TestCase): def test_generator_copy(self): # Create an instance and run it to the first yield point. @@ -74,39 +66,6 @@ def test_coroutine_copy(self): assert next(g) == 2 assert next(g) == 3 - def test_async_generator_copy(self): - # Create an instance and run it to the first yield point. - ag = async_generator(1) - next_awaitable = anext(ag) - g = next_awaitable.__await__() - assert next(g) == 1 - - # Copy the async generator. - ag2 = async_generator(1) - self.copy_to(ag, ag2) - next_awaitable2 = anext(ag2) - g2 = next_awaitable2.__await__() - - # The copy should start from where the previous generator was suspended. - try: - next(g2) - raise RuntimeError - except StopIteration as e: - assert e.value == 2 - next_awaitable2 = anext(ag2) - g2 = next_awaitable2.__await__() - assert next(g2) == 3 - - # Original generator is not affected. - try: - next(g) - raise RuntimeError - except StopIteration as e: - assert e.value == 2 - next_awaitable = anext(ag) - g = next_awaitable.__await__() - assert next(g) == 3 - def copy_to(self, from_obj, to_obj): ext.set_frame_state(to_obj, ext.get_frame_state(from_obj)) ext.set_frame_ip(to_obj, ext.get_frame_ip(from_obj)) From e46a395f8eeba371029c095602e5cfaa70889013 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 13:17:24 +1000 Subject: [PATCH 7/9] traceback.format_exception changed in 3.10; use compatible call --- src/dispatch/proto.py | 12 +++++++++--- tests/dispatch/test_error.py | 24 +++++++++++++++++++++--- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 14a71724..ef70e84c 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -297,11 +297,15 @@ def _from_proto(cls, proto: call_pb.CallResult) -> CallResult: ) @classmethod - def from_value(cls, output: Any, correlation_id: Optional[int] = None) -> CallResult: + def from_value( + cls, output: Any, correlation_id: Optional[int] = None + ) -> CallResult: return CallResult(correlation_id=correlation_id, output=output) @classmethod - def from_error(cls, error: Error, correlation_id: Optional[int] = None) -> CallResult: + def from_error( + cls, error: Error, correlation_id: Optional[int] = None + ) -> CallResult: return CallResult(correlation_id=correlation_id, error=error) @@ -352,7 +356,9 @@ def __init__( self.value = value self.traceback = traceback if not traceback and value: - self.traceback = "".join(format_exception(value)).encode("utf-8") + self.traceback = "".join( + format_exception(value.__class__, value, value.__traceback__) + ).encode("utf-8") @classmethod def from_exception(cls, ex: Exception, status: Optional[Status] = None) -> Error: diff --git a/tests/dispatch/test_error.py b/tests/dispatch/test_error.py index f036df06..df78436b 100644 --- a/tests/dispatch/test_error.py +++ b/tests/dispatch/test_error.py @@ -12,7 +12,13 @@ def test_conversion_between_exception_and_error(self): except Exception as e: original_exception = e error = Error.from_exception(e) - original_traceback = "".join(traceback.format_exception(original_exception)) + original_traceback = "".join( + traceback.format_exception( + original_exception.__class__, + original_exception, + original_exception.__traceback__, + ) + ) # For some reasons traceback.format_exception does not include the caret # (^) in the original traceback, but it does in the reconstructed one, @@ -24,7 +30,13 @@ def strip_caret(s): reconstructed_exception = error.to_exception() reconstructed_traceback = strip_caret( - "".join(traceback.format_exception(reconstructed_exception)) + "".join( + traceback.format_exception( + reconstructed_exception.__class__, + reconstructed_exception, + reconstructed_exception.__traceback__, + ) + ) ) assert type(reconstructed_exception) is type(original_exception) @@ -34,7 +46,13 @@ def strip_caret(s): error2 = Error.from_exception(reconstructed_exception) reconstructed_exception2 = error2.to_exception() reconstructed_traceback2 = strip_caret( - "".join(traceback.format_exception(reconstructed_exception2)) + "".join( + traceback.format_exception( + reconstructed_exception2.__class__, + reconstructed_exception2, + reconstructed_exception2.__traceback__, + ) + ) ) assert type(reconstructed_exception2) is type(original_exception) From 0dd9a239ffa364d3d4f8d0f10f6c554fdf210031 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 13:20:30 +1000 Subject: [PATCH 8/9] Fix some minor typing issues --- src/dispatch/test/service.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py index 39367658..d711a4ed 100644 --- a/src/dispatch/test/service.py +++ b/src/dispatch/test/service.py @@ -6,10 +6,10 @@ from collections import OrderedDict from dataclasses import dataclass from typing import Optional -from typing_extensions import TypeAlias import grpc import httpx +from typing_extensions import TypeAlias import dispatch.sdk.v1.call_pb2 as call_pb import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb @@ -87,11 +87,11 @@ def __init__( self.pollers: dict[DispatchID, Poller] = {} self.parents: dict[DispatchID, Poller] = {} - self.roundtrips: OrderedDict[DispatchID, list[RoundTrip]] | None = None + self.roundtrips: Optional[OrderedDict[DispatchID, list[RoundTrip]]] = None if collect_roundtrips: self.roundtrips = OrderedDict() - self._thread: threading.Optional[Thread] = None + self._thread: Optional[threading.Thread] = None self._stop_event = threading.Event() self._work_signal = threading.Condition() From d9cef9698fb38d781b2092186d800fcdaa8fec54 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 27 Mar 2024 13:20:38 +1000 Subject: [PATCH 9/9] make fmt --- src/dispatch/experimental/durable/frame.pyi | 12 +++++++++--- src/dispatch/experimental/durable/function.py | 2 +- src/dispatch/fastapi.py | 2 +- src/dispatch/function.py | 9 ++++++--- src/dispatch/scheduler.py | 3 ++- src/dispatch/signature/key.py | 2 +- 6 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/dispatch/experimental/durable/frame.pyi b/src/dispatch/experimental/durable/frame.pyi index d20115d0..ec3e50e0 100644 --- a/src/dispatch/experimental/durable/frame.pyi +++ b/src/dispatch/experimental/durable/frame.pyi @@ -4,19 +4,25 @@ from typing import Any, AsyncGenerator, Coroutine, Generator, Tuple, Union def get_frame_ip(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get instruction pointer of a generator or coroutine.""" -def set_frame_ip(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], ip: int): +def set_frame_ip( + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], ip: int +): """Set instruction pointer of a generator or coroutine.""" def get_frame_sp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get stack pointer of a generator or coroutine.""" -def set_frame_sp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], sp: int): +def set_frame_sp( + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], sp: int +): """Set stack pointer of a generator or coroutine.""" def get_frame_bp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator]) -> int: """Get block pointer of a generator or coroutine.""" -def set_frame_bp(frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], bp: int): +def set_frame_bp( + frame: Union[FrameType, Coroutine, Generator, AsyncGenerator], bp: int +): """Set block pointer of a generator or coroutine.""" def get_frame_stack_at( diff --git a/src/dispatch/experimental/durable/function.py b/src/dispatch/experimental/durable/function.py index 3c721540..925b0ee5 100644 --- a/src/dispatch/experimental/durable/function.py +++ b/src/dispatch/experimental/durable/function.py @@ -9,7 +9,7 @@ MethodType, TracebackType, ) -from typing import Any, Callable, Coroutine, Generator, TypeVar, Union, cast, Optional +from typing import Any, Callable, Coroutine, Generator, Optional, TypeVar, Union, cast from . import frame as ext from .registry import RegisteredFunction, lookup_function, register_function diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 9e7c450c..8f1c8094 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -21,8 +21,8 @@ def read_root(): import logging import os from datetime import timedelta -from urllib.parse import urlparse from typing import Optional, Union +from urllib.parse import urlparse import fastapi import fastapi.responses diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 5de267e2..0caeb087 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -12,14 +12,14 @@ Dict, Generic, Iterable, - TypeVar, Optional, + TypeVar, overload, ) -from typing_extensions import ParamSpec, TypeAlias from urllib.parse import urlparse import grpc +from typing_extensions import ParamSpec, TypeAlias import dispatch.coroutine import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb @@ -162,7 +162,10 @@ class Registry: __slots__ = ("functions", "endpoint", "client") def __init__( - self, endpoint: str, api_key: Optional[str] = None, api_url: Optional[str] = None + self, + endpoint: str, + api_key: Optional[str] = None, + api_url: Optional[str] = None, ): """Initialize a function registry. diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index bd86a441..42915450 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -2,7 +2,8 @@ import pickle import sys from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Protocol, Optional, Union +from typing import Any, Awaitable, Callable, Optional, Protocol, Union + from typing_extensions import TypeAlias from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective diff --git a/src/dispatch/signature/key.py b/src/dispatch/signature/key.py index 84b44e9e..5cce28e5 100644 --- a/src/dispatch/signature/key.py +++ b/src/dispatch/signature/key.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Union, Optional +from typing import Optional, Union from cryptography.hazmat.primitives.asymmetric.ed25519 import ( Ed25519PrivateKey,