Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Test

on:
- push
- push

concurrency:
group: ${{ github.workflow }}-${{ github.event.number || github.ref }}
Expand All @@ -15,43 +15,43 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python: ['3.8', '3.9', '3.10', '3.11', '3.12']
python: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: pip
- run: make dev
- run: make test
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: pip
- run: make dev lambda
- run: make test

examples:
runs-on: ubuntu-latest
strategy:
matrix:
python: ['3.12']
python: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: pip
- run: make dev
- run: make exampletest
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: pip
- run: make dev
- run: make exampletest

format:
runs-on: ubuntu-latest
strategy:
matrix:
python: ['3.12']
python: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: pip
- run: make dev
- run: make fmt-check
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
cache: pip
- run: make dev lambda
- run: make fmt-check
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

PYTHON := python

ifdef PROTO_VERSION
PROTO_TARGET := buf.build/stealthrocket/dispatch-proto:$(PROTO_VERSION)
else
PROTO_TARGET := buf.build/stealthrocket/dispatch-proto
endif


all: test

install:
Expand All @@ -10,6 +17,9 @@ install:
dev:
$(PYTHON) -m pip install -e .[dev]

lambda:
$(PYTHON) -m pip install -e .[lambda]

fmt:
$(PYTHON) -m isort .
$(PYTHON) -m black .
Expand Down Expand Up @@ -39,7 +49,7 @@ test: typecheck unittest
mkdir -p $@

.proto/dispatch-proto: .proto
buf export buf.build/stealthrocket/dispatch-proto --output=.proto/dispatch-proto
buf export $(PROTO_TARGET) --output=.proto/dispatch-proto

update-proto:
$(MAKE) clean
Expand Down
1 change: 0 additions & 1 deletion examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def test_app(self):
endpoint_client = EndpointClient.from_app(app)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:

# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

Expand Down
1 change: 0 additions & 1 deletion examples/getting_started/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_app(self):
endpoint_client = EndpointClient.from_app(app)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:

# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

Expand Down
1 change: 0 additions & 1 deletion examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_app(self):
endpoint_client = EndpointClient.from_app(app)
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:

# Use it when dispatching function calls.
dispatch.set_client(Client(api_url=dispatch_server.url))

Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [

[project.optional-dependencies]
fastapi = ["fastapi", "httpx"]
lambda = ["awslambdaric"]

dev = [
"black >= 24.1.0",
Expand All @@ -33,7 +34,8 @@ dev = [
"coverage >= 7.4.1",
"requests >= 2.31.0",
"types-requests >= 2.31.0.20240125",
"uvicorn >= 0.28.0"
"uvicorn >= 0.28.0",
"awslambdaric-stubs"
]

docs = [
Expand Down
134 changes: 134 additions & 0 deletions src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Integration of Dispatch programmable endpoints for FastAPI.

Example:

from dispatch.experimental.lambda_handler import Dispatch

dispatch = Dispatch(api_key="test-key")

@dispatch.function
def my_function():
return "Hello World!"

@dispatch.function
def entrypoint():
my_function()

def handler(event, context):
dispatch.handle(event, context, entrypoint="entrypoint")
"""

import base64
import json
import logging
from typing import Optional

from awslambdaric.lambda_context import LambdaContext

from dispatch.function import Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.status import Status

logger = logging.getLogger(__name__)


class Dispatch(Registry):
def __init__(
self,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
):
"""Initializes a Dispatch Lambda handler.

Args:
api_key: Dispatch API key to use for authentication. Uses the value
of the DISPATCH_API_KEY environment variable by default.

api_url: The URL of the Dispatch API to use. Uses the value of the
DISPATCH_API_URL environment variable if set, otherwise
defaults to the public Dispatch API (DEFAULT_API_URL).

"""

super().__init__(endpoint="not configured", api_key=api_key, api_url=api_url)

def handle(
self, event: str, context: LambdaContext, entrypoint: Optional[str] = None
):
# The ARN is not none until the first invocation of the Lambda function.
# We override the endpoint of all registered functions before any execution.
if context.invoked_function_arn:
self.endpoint = context.invoked_function_arn
self.override_endpoint(self.endpoint)

if not event:
raise ValueError("event is required")

try:
raw = base64.b64decode(event)
except Exception as e:
raise ValueError("event is not base64 encoded") from e

req = function_pb.RunRequest.FromString(raw)

function: Optional[str] = req.function if req.function else entrypoint
if not function:
raise ValueError("function is required")

logger.debug(
"Dispatch handler invoked for %s function %s with runRequest: %s",
self.endpoint,
function,
req,
)

try:
func = self.functions[req.function]
except KeyError:
raise ValueError(f"function {req.function} not found")

input = Input(req)
try:
output = func._primitive_call(input)
except Exception:
logger.error("function '%s' fatal error", req.function, exc_info=True)
raise # FIXME
else:
response = output._message
status = Status(response.status)

if response.HasField("poll"):
logger.debug(
"function '%s' polling with %d call(s)",
req.function,
len(response.poll.calls),
)
elif response.HasField("exit"):
exit = response.exit
if not exit.HasField("result"):
logger.debug("function '%s' exiting with no result", req.function)
else:
result = exit.result
if result.HasField("output"):
logger.debug(
"function '%s' exiting with output value", req.function
)
elif result.HasField("error"):
err = result.error
logger.debug(
"function '%s' exiting with error: %s (%s)",
req.function,
err.message,
err.type,
)
if exit.HasField("tail_call"):
logger.debug(
"function '%s' tail calling function '%s'",
exit.tail_call.function,
)

logger.debug("finished handling run request with status %s", status.name)
respBytes = response.SerializeToString()
respStr = base64.b64encode(respBytes).decode("utf-8")
return bytes(json.dumps(respStr), "utf-8")
4 changes: 2 additions & 2 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
app = fastapi.FastAPI()
dispatch = Dispatch(app, api_key="test-key")

@dispatch.function()
@dispatch.function
def my_function():
return "Hello World!"

Expand All @@ -29,7 +29,7 @@ def read_root():
import fastapi.responses
from http_message_signatures import InvalidSignature

from dispatch.function import Batch, Client, Registry
from dispatch.function import Batch, Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
Expand Down
8 changes: 8 additions & 0 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __init__(
def endpoint(self) -> str:
return self._endpoint

@endpoint.setter
def endpoint(self, value: str):
self._endpoint = value

@property
def name(self) -> str:
return self._name
Expand Down Expand Up @@ -267,6 +271,10 @@ def batch(self) -> Batch:
a set of calls to dispatch."""
return self.client.batch()

def override_endpoint(self, endpoint: str):
for fn in self.functions.values():
fn.endpoint = endpoint


class Client:
"""Client for the Dispatch API."""
Expand Down
24 changes: 19 additions & 5 deletions src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import google.protobuf.message
import google.protobuf.wrappers_pb2
import tblib # type: ignore[import-untyped]
from google.protobuf import duration_pb2
from google.protobuf import descriptor_pool, duration_pb2, message_factory

from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import error_pb2 as error_pb
Expand Down Expand Up @@ -61,10 +61,16 @@ class Input:
def __init__(self, req: function_pb.RunRequest):
self._has_input = req.HasField("input")
if self._has_input:
input_pb = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_pb)
input_bytes = input_pb.value
self._input = pickle.loads(input_bytes)
if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR):
input_pb = google.protobuf.wrappers_pb2.BytesValue()
req.input.Unpack(input_pb)
input_bytes = input_pb.value
try:
self._input = pickle.loads(input_bytes)
except Exception as e:
self._input = input_bytes
else:
self._input = _pb_any_unpack(req.input)
else:
state_bytes = req.poll_result.coroutine_state
if len(state_bytes) > 0:
Expand Down Expand Up @@ -433,3 +439,11 @@ def _pb_any_pickle(x: Any) -> google.protobuf.any_pb2.Any:
pb_any = google.protobuf.any_pb2.Any()
pb_any.Pack(pb_bytes)
return pb_any


def _pb_any_unpack(x: google.protobuf.any_pb2.Any) -> Any:
pool = descriptor_pool.Default()
msg_descriptor = pool.FindMessageTypeByName(x.TypeName())
proto = message_factory.GetMessageClass(msg_descriptor)()
x.Unpack(proto)
return proto
1 change: 0 additions & 1 deletion src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ def _rebuild_state(self, input: Input):
raise IncompatibleStateError from e

def _run(self, input: Input) -> Output:

if input.is_first_call:
state = self._init_state(input)
else:
Expand Down
Loading