From 36e43678644709ccb72007182cb4bb045fbc4abf Mon Sep 17 00:00:00 2001 From: Julien Fabre Date: Tue, 26 Mar 2024 23:11:48 -0600 Subject: [PATCH 1/2] Add experimentation Lambda integration --- Makefile | 9 +- src/dispatch/experimental/lambda_handler.py | 122 ++++++++++++++++++++ src/dispatch/fastapi.py | 4 +- src/dispatch/function.py | 11 ++ src/dispatch/proto.py | 5 +- src/dispatch/sdk/v1/call_pb2.py | 10 +- src/dispatch/sdk/v1/dispatch_pb2.py | 14 +-- 7 files changed, 159 insertions(+), 16 deletions(-) create mode 100644 src/dispatch/experimental/lambda_handler.py diff --git a/Makefile b/Makefile index c025bb0d..9a28a2e9 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,13 @@ PYTHON := python +ifdef PROTO_VERSION +PROTO_TARGET := buf.build/stealthrocket/dispatch-proto:$(PROTO_VERSION) +else +PROTO_TARGET := buf.build/stealthrocket/dispatch-proto +endif + + all: test install: @@ -39,7 +46,7 @@ test: typecheck unittest mkdir -p $@ .proto/dispatch-proto: .proto - buf export buf.build/stealthrocket/dispatch-proto --output=.proto/dispatch-proto + buf export $(PROTO_TARGET) --output=.proto/dispatch-proto update-proto: $(MAKE) clean diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py new file mode 100644 index 00000000..35c687e1 --- /dev/null +++ b/src/dispatch/experimental/lambda_handler.py @@ -0,0 +1,122 @@ +"""Integration of Dispatch programmable endpoints for FastAPI. + +Example: + + from dispatch.experimental.lambda_handler import Dispatch + + dispatch = Dispatch(api_key="test-key") + + @dispatch.function + def my_function(): + return "Hello World!" + + @dispatch.entrypoint + def entrypoint(): + my_function() + + def handler(event, context): + dispatch.handle(event, context) + """ + +import base64 +import logging +import json + +from dispatch.function import Registry +from dispatch.proto import Input +from dispatch.sdk.v1 import function_pb2 as function_pb +from dispatch.status import Status + +logger = logging.getLogger(__name__) + + +class Dispatch(Registry): + def __init__( + self, + api_key: str | None = None, + api_url: str | None = None, + ): + """Initializes a Dispatch Lambda handler. + + Args: + api_key: Dispatch API key to use for authentication. Uses the value of + the DISPATCH_API_KEY environment variable by default. + + api_url: The URL of the Dispatch API to use. Uses the value of the + DISPATCH_API_URL environment variable if set, otherwise + defaults to the public Dispatch API (DEFAULT_API_URL). + + """ + + # The endpoint (AWS Lambda ARN) is set when handling the first request. + super().__init__(endpoint=None, api_key=api_key, api_url=api_url) + + def handle(self, event, context): + # Use the context to determine the ARN of the Lambda function. + self.endpoint = context.invoked_function_arn + + logger.debug("Dispatch handler invoked for %s with event: %s", self.endpoint, event) + + if not event: + raise ValueError("event is required") + + try: + raw = base64.b64decode(event) + except Exception as e: + raise ValueError(f"event is not base64 encoded: {e}") + + req = function_pb.RunRequest.FromString(raw) + print(req) + if not req.function: + req.function = "entrypoint" + # FIXME raise ValueError("function is required") + + try: + func = self.functions[req.function] + except KeyError: + raise ValueError(f"function {req.function} not found") + + input = Input(req) + try: + output = func._primitive_call(input) + except Exception: + logger.error("function '%s' fatal error", req.function, exc_info=True) + raise # FIXME + else: + response = output._message + status = Status(response.status) + + if response.HasField("poll"): + logger.debug( + "function '%s' polling with %d call(s)", + req.function, + len(response.poll.calls), + ) + elif response.HasField("exit"): + exit = response.exit + if not exit.HasField("result"): + logger.debug("function '%s' exiting with no result", req.function) + else: + result = exit.result + if result.HasField("output"): + logger.debug( + "function '%s' exiting with output value", req.function + ) + elif result.HasField("error"): + err = result.error + logger.debug( + "function '%s' exiting with error: %s (%s)", + req.function, + err.message, + err.type, + ) + if exit.HasField("tail_call"): + logger.debug( + "function '%s' tail calling function '%s'", + exit.tail_call.function, + ) + + logger.debug("finished handling run request with status %s", status.name) + resp = response.SerializeToString() + resp = base64.b64encode(resp).decode("utf-8") + return bytes(json.dumps(resp), "utf-8") diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 1f8de120..abe4e825 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -8,7 +8,7 @@ app = fastapi.FastAPI() dispatch = Dispatch(app, api_key="test-key") - @dispatch.function() + @dispatch.function def my_function(): return "Hello World!" @@ -29,7 +29,7 @@ def read_root(): import fastapi.responses from http_message_signatures import InvalidSignature -from dispatch.function import Batch, Client, Registry +from dispatch.function import Batch, Registry from dispatch.proto import Input from dispatch.sdk.v1 import function_pb2 as function_pb from dispatch.signature import ( diff --git a/src/dispatch/function.py b/src/dispatch/function.py index c2fbfcb8..ab19a883 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -210,6 +210,17 @@ def function(self, func): logger.info("registering coroutine: %s", name) return self._register_coroutine(name, func) + # TODO: ensure we have only 1 entrypoint per app. + def entrypoint(self, func): + """Decorated that registers the program entrypoint.""" + name = "entrypoint" + if not inspect.iscoroutinefunction(func): + logger.info("registering entrypoint function: %s", name) + return self._register_function(name, func) + + logger.info("registering entrypoint coroutine: %s", name) + return self._register_coroutine(name, func) + def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: func = durable(func) diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index d297e64a..02a37e02 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -64,7 +64,10 @@ def __init__(self, req: function_pb.RunRequest): input_pb = google.protobuf.wrappers_pb2.BytesValue() req.input.Unpack(input_pb) input_bytes = input_pb.value - self._input = pickle.loads(input_bytes) + try: + self._input = pickle.loads(input_bytes) + except EOFError: + self._input = Arguments(args=(req.input,), kwargs={}) else: state_bytes = req.poll_result.coroutine_state if len(state_bytes) > 0: diff --git a/src/dispatch/sdk/v1/call_pb2.py b/src/dispatch/sdk/v1/call_pb2.py index f478791d..4404f87d 100644 --- a/src/dispatch/sdk/v1/call_pb2.py +++ b/src/dispatch/sdk/v1/call_pb2.py @@ -20,7 +20,7 @@ from dispatch.sdk.v1 import error_pb2 as dispatch_dot_sdk_dot_v1_dot_error__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1a\x64ispatch/sdk/v1/call.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1b\x64ispatch/sdk/v1/error.proto\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto"\x84\x02\n\x04\x43\x61ll\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12$\n\x08\x65ndpoint\x18\x02 \x01(\tB\x08\xbaH\x05r\x03\x88\x01\x01R\x08\x65ndpoint\x12>\n\x08\x66unction\x18\x03 \x01(\tB"\xbaH\x1fr\x1a\x32\x18^[a-zA-Z_][a-zA-Z0-9_]*$\xc8\x01\x01R\x08\x66unction\x12*\n\x05input\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05input\x12\x43\n\nexpiration\x18\x05 \x01(\x0b\x32\x19.google.protobuf.DurationB\x08\xbaH\x05\xaa\x01\x02\x32\x00R\nexpiration"\x8f\x01\n\nCallResult\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12,\n\x06output\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x06output\x12,\n\x05\x65rror\x18\x03 \x01(\x0b\x32\x16.dispatch.sdk.v1.ErrorR\x05\x65rrorB~\n\x13\x63om.dispatch.sdk.v1B\tCallProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3' + b'\n\x1a\x64ispatch/sdk/v1/call.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1b\x64ispatch/sdk/v1/error.proto\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto"\x87\x02\n\x04\x43\x61ll\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12$\n\x08\x65ndpoint\x18\x02 \x01(\tB\x08\xbaH\x05r\x03\x88\x01\x01R\x08\x65ndpoint\x12\x41\n\x08\x66unction\x18\x03 \x01(\tB%\xbaH"r\x1d\x32\x1b^[a-zA-Z_][a-zA-Z0-9_<>.]*$\xc8\x01\x01R\x08\x66unction\x12*\n\x05input\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyR\x05input\x12\x43\n\nexpiration\x18\x05 \x01(\x0b\x32\x19.google.protobuf.DurationB\x08\xbaH\x05\xaa\x01\x02\x32\x00R\nexpiration"\x8f\x01\n\nCallResult\x12%\n\x0e\x63orrelation_id\x18\x01 \x01(\x04R\rcorrelationId\x12,\n\x06output\x18\x02 \x01(\x0b\x32\x14.google.protobuf.AnyR\x06output\x12,\n\x05\x65rror\x18\x03 \x01(\x0b\x32\x16.dispatch.sdk.v1.ErrorR\x05\x65rrorB~\n\x13\x63om.dispatch.sdk.v1B\tCallProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3' ) _globals = globals() @@ -40,13 +40,13 @@ _globals["_CALL"].fields_by_name["function"]._options = None _globals["_CALL"].fields_by_name[ "function" - ]._serialized_options = b"\272H\037r\0322\030^[a-zA-Z_][a-zA-Z0-9_]*$\310\001\001" + ]._serialized_options = b'\272H"r\0352\033^[a-zA-Z_][a-zA-Z0-9_<>.]*$\310\001\001' _globals["_CALL"].fields_by_name["expiration"]._options = None _globals["_CALL"].fields_by_name[ "expiration" ]._serialized_options = b"\272H\005\252\001\0022\000" _globals["_CALL"]._serialized_start = 165 - _globals["_CALL"]._serialized_end = 425 - _globals["_CALLRESULT"]._serialized_start = 428 - _globals["_CALLRESULT"]._serialized_end = 571 + _globals["_CALL"]._serialized_end = 428 + _globals["_CALLRESULT"]._serialized_start = 431 + _globals["_CALLRESULT"]._serialized_end = 574 # @@protoc_insertion_point(module_scope) diff --git a/src/dispatch/sdk/v1/dispatch_pb2.py b/src/dispatch/sdk/v1/dispatch_pb2.py index 250b42db..3d084eb9 100644 --- a/src/dispatch/sdk/v1/dispatch_pb2.py +++ b/src/dispatch/sdk/v1/dispatch_pb2.py @@ -17,7 +17,7 @@ from dispatch.sdk.v1 import call_pb2 as dispatch_dot_sdk_dot_v1_dot_call__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b"\n\x1e\x64ispatch/sdk/v1/dispatch.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1a\x64ispatch/sdk/v1/call.proto\"\xf3\x02\n\x0f\x44ispatchRequest\x12+\n\x05\x63\x61lls\x18\x01 \x03(\x0b\x32\x15.dispatch.sdk.v1.CallR\x05\x63\x61lls:\xb2\x02\xbaH\xae\x02\x1as\n(dispatch.request.calls.endpoint.nonempty\x12\x1d\x43\x61ll endpoint cannot be empty\x1a(this.calls.all(call, has(call.endpoint))\x1a\xb6\x01\n&dispatch.request.calls.endpoint.scheme\x12)Call endpoint must be a http or https URL\x1a\x61this.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://'))\"5\n\x10\x44ispatchResponse\x12!\n\x0c\x64ispatch_ids\x18\x01 \x03(\tR\x0b\x64ispatchIds2d\n\x0f\x44ispatchService\x12Q\n\x08\x44ispatch\x12 .dispatch.sdk.v1.DispatchRequest\x1a!.dispatch.sdk.v1.DispatchResponse\"\x00\x42\x82\x01\n\x13\x63om.dispatch.sdk.v1B\rDispatchProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3" + b"\n\x1e\x64ispatch/sdk/v1/dispatch.proto\x12\x0f\x64ispatch.sdk.v1\x1a\x1b\x62uf/validate/validate.proto\x1a\x1a\x64ispatch/sdk/v1/call.proto\"\xb7\x03\n\x0f\x44ispatchRequest\x12+\n\x05\x63\x61lls\x18\x01 \x03(\x0b\x32\x15.dispatch.sdk.v1.CallR\x05\x63\x61lls:\xf6\x02\xbaH\xf2\x02\x1as\n(dispatch.request.calls.endpoint.nonempty\x12\x1d\x43\x61ll endpoint cannot be empty\x1a(this.calls.all(call, has(call.endpoint))\x1a\xfa\x01\n&dispatch.request.calls.endpoint.scheme\x12>Call endpoint must be a http or https URL or an AWS Lambda ARN\x1a\x8f\x01this.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://') || call.endpoint.startsWith('arn:aws:lambda'))\"5\n\x10\x44ispatchResponse\x12!\n\x0c\x64ispatch_ids\x18\x01 \x03(\tR\x0b\x64ispatchIds2d\n\x0f\x44ispatchService\x12Q\n\x08\x44ispatch\x12 .dispatch.sdk.v1.DispatchRequest\x1a!.dispatch.sdk.v1.DispatchResponse\"\x00\x42\x82\x01\n\x13\x63om.dispatch.sdk.v1B\rDispatchProtoP\x01\xa2\x02\x03\x44SX\xaa\x02\x0f\x44ispatch.Sdk.V1\xca\x02\x0f\x44ispatch\\Sdk\\V1\xe2\x02\x1b\x44ispatch\\Sdk\\V1\\GPBMetadata\xea\x02\x11\x44ispatch::Sdk::V1b\x06proto3" ) _globals = globals() @@ -32,12 +32,12 @@ ) _globals["_DISPATCHREQUEST"]._options = None _globals["_DISPATCHREQUEST"]._serialized_options = ( - b"\272H\256\002\032s\n(dispatch.request.calls.endpoint.nonempty\022\035Call endpoint cannot be empty\032(this.calls.all(call, has(call.endpoint))\032\266\001\n&dispatch.request.calls.endpoint.scheme\022)Call endpoint must be a http or https URL\032athis.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://'))" + b"\272H\362\002\032s\n(dispatch.request.calls.endpoint.nonempty\022\035Call endpoint cannot be empty\032(this.calls.all(call, has(call.endpoint))\032\372\001\n&dispatch.request.calls.endpoint.scheme\022>Call endpoint must be a http or https URL or an AWS Lambda ARN\032\217\001this.calls.all(call, call.endpoint.startsWith('http://') || call.endpoint.startsWith('https://') || call.endpoint.startsWith('arn:aws:lambda'))" ) _globals["_DISPATCHREQUEST"]._serialized_start = 109 - _globals["_DISPATCHREQUEST"]._serialized_end = 480 - _globals["_DISPATCHRESPONSE"]._serialized_start = 482 - _globals["_DISPATCHRESPONSE"]._serialized_end = 535 - _globals["_DISPATCHSERVICE"]._serialized_start = 537 - _globals["_DISPATCHSERVICE"]._serialized_end = 637 + _globals["_DISPATCHREQUEST"]._serialized_end = 548 + _globals["_DISPATCHRESPONSE"]._serialized_start = 550 + _globals["_DISPATCHRESPONSE"]._serialized_end = 603 + _globals["_DISPATCHSERVICE"]._serialized_start = 605 + _globals["_DISPATCHSERVICE"]._serialized_end = 705 # @@protoc_insertion_point(module_scope) From 7fd42a1c643b2619d902be3de6e53c21ec7650fe Mon Sep 17 00:00:00 2001 From: Julien Fabre Date: Tue, 26 Mar 2024 23:20:28 -0600 Subject: [PATCH 2/2] Review feedback, remove the entrypoint decorator, fix tests and linter --- .github/workflows/test.yml | 56 ++++++++++----------- Makefile | 3 ++ examples/auto_retry/test_app.py | 1 - examples/getting_started/test_app.py | 1 - examples/github_stats/test_app.py | 1 - pyproject.toml | 4 +- src/dispatch/experimental/lambda_handler.py | 56 +++++++++++++-------- src/dispatch/function.py | 19 +++---- src/dispatch/proto.py | 27 +++++++--- src/dispatch/scheduler.py | 1 - tests/dispatch/test_error.py | 1 - tests/dispatch/test_status.py | 2 - tests/test_fastapi.py | 1 - 13 files changed, 95 insertions(+), 78 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7b944c15..afdbffb7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,7 @@ name: Test on: -- push + - push concurrency: group: ${{ github.workflow }}-${{ github.event.number || github.ref }} @@ -15,43 +15,43 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: ['3.8', '3.9', '3.10', '3.11', '3.12'] + python: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python }} - cache: pip - - run: make dev - - run: make test + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: pip + - run: make dev lambda + - run: make test examples: runs-on: ubuntu-latest strategy: matrix: - python: ['3.12'] + python: ["3.12"] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python }} - cache: pip - - run: make dev - - run: make exampletest + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: pip + - run: make dev + - run: make exampletest format: runs-on: ubuntu-latest strategy: matrix: - python: ['3.12'] + python: ["3.12"] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python }} - cache: pip - - run: make dev - - run: make fmt-check + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: pip + - run: make dev lambda + - run: make fmt-check diff --git a/Makefile b/Makefile index 9a28a2e9..3b70667c 100644 --- a/Makefile +++ b/Makefile @@ -17,6 +17,9 @@ install: dev: $(PYTHON) -m pip install -e .[dev] +lambda: + $(PYTHON) -m pip install -e .[lambda] + fmt: $(PYTHON) -m isort . $(PYTHON) -m black . diff --git a/examples/auto_retry/test_app.py b/examples/auto_retry/test_app.py index 894ae20b..81555369 100644 --- a/examples/auto_retry/test_app.py +++ b/examples/auto_retry/test_app.py @@ -28,7 +28,6 @@ def test_app(self): endpoint_client = EndpointClient.from_app(app) dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: - # Use it when dispatching function calls. dispatch.set_client(Client(api_url=dispatch_server.url)) diff --git a/examples/getting_started/test_app.py b/examples/getting_started/test_app.py index 108886b9..3f420e73 100644 --- a/examples/getting_started/test_app.py +++ b/examples/getting_started/test_app.py @@ -27,7 +27,6 @@ def test_app(self): endpoint_client = EndpointClient.from_app(app) dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: - # Use it when dispatching function calls. dispatch.set_client(Client(api_url=dispatch_server.url)) diff --git a/examples/github_stats/test_app.py b/examples/github_stats/test_app.py index 56c69ccd..dd68ddbe 100644 --- a/examples/github_stats/test_app.py +++ b/examples/github_stats/test_app.py @@ -27,7 +27,6 @@ def test_app(self): endpoint_client = EndpointClient.from_app(app) dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True) with DispatchServer(dispatch_service) as dispatch_server: - # Use it when dispatching function calls. dispatch.set_client(Client(api_url=dispatch_server.url)) diff --git a/pyproject.toml b/pyproject.toml index d2bef0cc..93e98636 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ [project.optional-dependencies] fastapi = ["fastapi", "httpx"] +lambda = ["awslambdaric"] dev = [ "black >= 24.1.0", @@ -33,7 +34,8 @@ dev = [ "coverage >= 7.4.1", "requests >= 2.31.0", "types-requests >= 2.31.0.20240125", - "uvicorn >= 0.28.0" + "uvicorn >= 0.28.0", + "awslambdaric-stubs" ] docs = [ diff --git a/src/dispatch/experimental/lambda_handler.py b/src/dispatch/experimental/lambda_handler.py index 35c687e1..5dd25398 100644 --- a/src/dispatch/experimental/lambda_handler.py +++ b/src/dispatch/experimental/lambda_handler.py @@ -10,17 +10,20 @@ def my_function(): return "Hello World!" - @dispatch.entrypoint + @dispatch.function def entrypoint(): my_function() def handler(event, context): - dispatch.handle(event, context) + dispatch.handle(event, context, entrypoint="entrypoint") """ import base64 -import logging import json +import logging +from typing import Optional + +from awslambdaric.lambda_context import LambdaContext from dispatch.function import Registry from dispatch.proto import Input @@ -33,14 +36,14 @@ def handler(event, context): class Dispatch(Registry): def __init__( self, - api_key: str | None = None, - api_url: str | None = None, + api_key: Optional[str] = None, + api_url: Optional[str] = None, ): """Initializes a Dispatch Lambda handler. Args: - api_key: Dispatch API key to use for authentication. Uses the value of - the DISPATCH_API_KEY environment variable by default. + api_key: Dispatch API key to use for authentication. Uses the value + of the DISPATCH_API_KEY environment variable by default. api_url: The URL of the Dispatch API to use. Uses the value of the DISPATCH_API_URL environment variable if set, otherwise @@ -48,14 +51,16 @@ def __init__( """ - # The endpoint (AWS Lambda ARN) is set when handling the first request. - super().__init__(endpoint=None, api_key=api_key, api_url=api_url) + super().__init__(endpoint="not configured", api_key=api_key, api_url=api_url) - def handle(self, event, context): - # Use the context to determine the ARN of the Lambda function. - self.endpoint = context.invoked_function_arn - - logger.debug("Dispatch handler invoked for %s with event: %s", self.endpoint, event) + def handle( + self, event: str, context: LambdaContext, entrypoint: Optional[str] = None + ): + # The ARN is not none until the first invocation of the Lambda function. + # We override the endpoint of all registered functions before any execution. + if context.invoked_function_arn: + self.endpoint = context.invoked_function_arn + self.override_endpoint(self.endpoint) if not event: raise ValueError("event is required") @@ -63,13 +68,20 @@ def handle(self, event, context): try: raw = base64.b64decode(event) except Exception as e: - raise ValueError(f"event is not base64 encoded: {e}") + raise ValueError("event is not base64 encoded") from e req = function_pb.RunRequest.FromString(raw) - print(req) - if not req.function: - req.function = "entrypoint" - # FIXME raise ValueError("function is required") + + function: Optional[str] = req.function if req.function else entrypoint + if not function: + raise ValueError("function is required") + + logger.debug( + "Dispatch handler invoked for %s function %s with runRequest: %s", + self.endpoint, + function, + req, + ) try: func = self.functions[req.function] @@ -117,6 +129,6 @@ def handle(self, event, context): ) logger.debug("finished handling run request with status %s", status.name) - resp = response.SerializeToString() - resp = base64.b64encode(resp).decode("utf-8") - return bytes(json.dumps(resp), "utf-8") + respBytes = response.SerializeToString() + respStr = base64.b64encode(respBytes).decode("utf-8") + return bytes(json.dumps(respStr), "utf-8") diff --git a/src/dispatch/function.py b/src/dispatch/function.py index ab19a883..a711880d 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -62,6 +62,10 @@ def __init__( def endpoint(self) -> str: return self._endpoint + @endpoint.setter + def endpoint(self, value: str): + self._endpoint = value + @property def name(self) -> str: return self._name @@ -210,17 +214,6 @@ def function(self, func): logger.info("registering coroutine: %s", name) return self._register_coroutine(name, func) - # TODO: ensure we have only 1 entrypoint per app. - def entrypoint(self, func): - """Decorated that registers the program entrypoint.""" - name = "entrypoint" - if not inspect.iscoroutinefunction(func): - logger.info("registering entrypoint function: %s", name) - return self._register_function(name, func) - - logger.info("registering entrypoint coroutine: %s", name) - return self._register_coroutine(name, func) - def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]: func = durable(func) @@ -278,6 +271,10 @@ def batch(self) -> Batch: a set of calls to dispatch.""" return self.client.batch() + def override_endpoint(self, endpoint: str): + for fn in self.functions.values(): + fn.endpoint = endpoint + class Client: """Client for the Dispatch API.""" diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 02a37e02..c3fe286f 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -10,7 +10,7 @@ import google.protobuf.message import google.protobuf.wrappers_pb2 import tblib # type: ignore[import-untyped] -from google.protobuf import duration_pb2 +from google.protobuf import descriptor_pool, duration_pb2, message_factory from dispatch.sdk.v1 import call_pb2 as call_pb from dispatch.sdk.v1 import error_pb2 as error_pb @@ -61,13 +61,16 @@ class Input: def __init__(self, req: function_pb.RunRequest): self._has_input = req.HasField("input") if self._has_input: - input_pb = google.protobuf.wrappers_pb2.BytesValue() - req.input.Unpack(input_pb) - input_bytes = input_pb.value - try: - self._input = pickle.loads(input_bytes) - except EOFError: - self._input = Arguments(args=(req.input,), kwargs={}) + if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): + input_pb = google.protobuf.wrappers_pb2.BytesValue() + req.input.Unpack(input_pb) + input_bytes = input_pb.value + try: + self._input = pickle.loads(input_bytes) + except Exception as e: + self._input = input_bytes + else: + self._input = _pb_any_unpack(req.input) else: state_bytes = req.poll_result.coroutine_state if len(state_bytes) > 0: @@ -436,3 +439,11 @@ def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any: pb_any = google.protobuf.any_pb2.Any() pb_any.Pack(pb_bytes) return pb_any + + +def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any: + pool = descriptor_pool.Default() + msg_descriptor = pool.FindMessageTypeByName(x.TypeName()) + proto = message_factory.GetMessageClass(msg_descriptor)() + x.Unpack(proto) + return proto diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index ada9b2a3..90ee5b68 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -374,7 +374,6 @@ def _rebuild_state(self, input: Input): raise IncompatibleStateError from e def _run(self, input: Input) -> Output: - if input.is_first_call: state = self._init_state(input) else: diff --git a/tests/dispatch/test_error.py b/tests/dispatch/test_error.py index df78436b..a3a0b7d0 100644 --- a/tests/dispatch/test_error.py +++ b/tests/dispatch/test_error.py @@ -5,7 +5,6 @@ class TestError(unittest.TestCase): - def test_conversion_between_exception_and_error(self): try: raise ValueError("test") diff --git a/tests/dispatch/test_status.py b/tests/dispatch/test_status.py index 78689b81..1b445d81 100644 --- a/tests/dispatch/test_status.py +++ b/tests/dispatch/test_status.py @@ -5,7 +5,6 @@ class TestErrorStatus(unittest.TestCase): - def test_status_for_Exception(self): assert status_for_error(Exception()) is Status.PERMANENT_ERROR @@ -71,7 +70,6 @@ class CustomError(TimeoutError): class TestHTTPStatusCodes(unittest.TestCase): - def test_http_response_code_status_400(self): assert http_response_code_status(400) is Status.INVALID_ARGUMENT diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 33ee7160..39eef0d0 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -39,7 +39,6 @@ def create_dispatch_instance(app, endpoint): class TestFastAPI(unittest.TestCase): - def test_Dispatch(self): app = fastapi.FastAPI() create_dispatch_instance(app, "https://127.0.0.1:9999")