diff --git a/README.md b/README.md index e85b7cfc..3c76203a 100644 --- a/README.md +++ b/README.md @@ -22,8 +22,9 @@ Python package to develop applications with the Dispatch platform. - [Usage](#usage) - [Configuration](#configuration) - [Integration with FastAPI](#integration-with-fastapi) - - [Local testing with ngrok](#local-testing-with-ngrok) - - [Distributed coroutines for Python](#distributed-coroutines-for-python) + - [Local Testing](#local-testing) + - [Distributed Coroutines for Python](#distributed-coroutines-for-python) + - [Serialization](#serialization) - [Examples](#examples) - [Contributing](#contributing) @@ -123,10 +124,46 @@ program, driven by the Dispatch SDK. The instantiation of the `Dispatch` object on the `FastAPI` application automatically installs the HTTP route needed for Dispatch to invoke functions. -### Local testing with ngrok +### Local Testing -To enable local testing, a common approach consists of using [ngrok][ngrok] to -setup a public endpoint that forwards to the server running on localhost. +#### Mock Dispatch + +The SDK ships with a mock Dispatch server. It can be used to quickly test your +local functions, without requiring internet access. + +Note that the mock Dispatch server has very limited scheduling capabilities. + +```console +python -m dispatch.test $DISPATCH_ENDPOINT_URL +``` + +The command will start a mock Dispatch server and print the configuration +for the SDK. 
+ +For example, if your functions were exposed through a local endpoint +listening on `http://127.0.0.1:8000`, you could run: + +```console +$ python -m dispatch.test http://127.0.0.1:8000 +Spawned a mock Dispatch server on 127.0.0.1:4450 + +Dispatching function calls to the endpoint at http://127.0.0.1:8000 + +The Dispatch SDK can be configured with: + + export DISPATCH_API_URL="http://127.0.0.1:4450" + export DISPATCH_API_KEY="test" + export DISPATCH_ENDPOINT_URL="http://127.0.0.1:8000" + export DISPATCH_VERIFICATION_KEY="Z+nTe2VRcw8t8Ihx++D+nXtbO28nwjWIOTLRgzrelYs=" +``` + +#### Real Dispatch + +To test local functions with the production instance of Dispatch, it needs +to be able to access your local endpoint. + +A common approach consists of using [ngrok][ngrok] to set up a public endpoint +that forwards to the server running on localhost. For example, assuming the server is running on port 8000 (which is the default with FastAPI), the command to create a ngrok tunnel is: diff --git a/pyproject.toml b/pyproject.toml index a9d36b6f..527a1928 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,9 @@ dev = [ "coverage >= 7.4.1", "requests >= 2.31.0", "types-requests >= 2.31.0.20240125", + "docopt >= 0.6.2", + "types-docopt >= 0.6.11.4", + "uvicorn >= 0.28.0" ] docs = [ diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py index 320b2193..7b78c1a7 100644 --- a/src/dispatch/scheduler.py +++ b/src/dispatch/scheduler.py @@ -329,8 +329,8 @@ def _run(self, input: Input) -> Output: coroutine_id=coroutine.id, value=e.value ) except Exception as e: - logger.exception( - f"@dispatch.function: '{coroutine}' raised an exception" + logger.debug( + f"@dispatch.function: '{coroutine}' raised an exception", exc_info=e ) coroutine_result = CoroutineResult(coroutine_id=coroutine.id, error=e) diff --git a/src/dispatch/status.py b/src/dispatch/status.py index a82f315a..5a413802 100644 --- a/src/dispatch/status.py +++ b/src/dispatch/status.py @@ -30,6 +30,12 @@ 
class Status(int, enum.Enum): _proto: status_pb.Status + def __repr__(self): + return self.name + + def __str__(self): + return self.name + # Maybe we should find a better way to define that enum. It's that way to please # Mypy and provide documentation for the enum values. diff --git a/src/dispatch/test/__main__.py b/src/dispatch/test/__main__.py new file mode 100644 index 00000000..c0e45439 --- /dev/null +++ b/src/dispatch/test/__main__.py @@ -0,0 +1,92 @@ +"""Mock Dispatch server for use in test environments. + +Usage: + dispatch.test <endpoint> [--api-key=<key>] [--hostname=<name>] [--port=<port>] [-v | --verbose] + dispatch.test -h | --help + +Options: + --api-key=<key> API key to require when clients connect to the server [default: test]. + + --hostname=<name> Hostname to listen on [default: 127.0.0.1]. + --port=<port> Port to listen on [default: 4450]. + + -v --verbose Show verbose details in the log. + -h --help Show this help information. +""" + +import base64 +import logging +import os +import sys + +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey +from docopt import docopt + +from dispatch.test import DispatchServer, DispatchService, EndpointClient + + +def main(): + args = docopt(__doc__) + + if args["--help"]: + print(__doc__) + exit(0) + + endpoint = args["<endpoint>"] + api_key = args["--api-key"] + hostname = args["--hostname"] + port_str = args["--port"] + + try: + port = int(port_str) + except ValueError: + print(f"error: invalid port: {port_str}", file=sys.stderr) + exit(1) + + if not os.getenv("NO_COLOR"): + logging.addLevelName(logging.WARNING, f"\033[1;33mWARN\033[1;0m") + logging.addLevelName(logging.ERROR, "\033[1;31mERROR\033[1;0m") + + logger = logging.getLogger() + if args["--verbose"]: + logger.setLevel(logging.DEBUG) + fmt = "%(asctime)s [%(levelname)s] %(name)s - %(message)s" + else: + logger.setLevel(logging.INFO) + fmt = "%(asctime)s [%(levelname)s] %(message)s" + logging.getLogger("httpx").disabled = True + + log_formatter = logging.Formatter(fmt=fmt, 
datefmt="%Y-%m-%d %H:%M:%S") + log_handler = logging.StreamHandler(sys.stderr) + log_handler.setFormatter(log_formatter) + logger.addHandler(log_handler) + + # This private key was generated randomly. + signing_key = Ed25519PrivateKey.from_private_bytes( + b"\x0e\xca\xfb\xc9\xa9Gc'fR\xe4\x97y\xf0\xae\x90\x01\xe8\xd9\x94\xa6\xd4@\xf6\xa7!\x90b\\!z!" + ) + verification_key = base64.b64encode( + signing_key.public_key().public_bytes_raw() + ).decode() + + endpoint_client = EndpointClient.from_url(endpoint, signing_key=signing_key) + + with DispatchService(endpoint_client, api_key=api_key) as service: + with DispatchServer(service, hostname=hostname, port=port) as server: + print(f"Spawned a mock Dispatch server on {hostname}:{port}") + print() + print(f"Dispatching function calls to the endpoint at {endpoint}") + print() + print("The Dispatch SDK can be configured with:") + print() + print(f' export DISPATCH_API_URL="http://{hostname}:{port}"') + print(f' export DISPATCH_API_KEY="{api_key}"') + print(f' export DISPATCH_ENDPOINT_URL="{endpoint}"') + print(f' export DISPATCH_VERIFICATION_KEY="{verification_key}"') + print() + + server.wait() + + +if __name__ == "__main__": + main() diff --git a/src/dispatch/test/server.py b/src/dispatch/test/server.py index 775523d1..b6d9c11a 100644 --- a/src/dispatch/test/server.py +++ b/src/dispatch/test/server.py @@ -39,6 +39,10 @@ def start(self): """Start the server.""" self._server.start() + def wait(self): + """Block until the server terminates.""" + self._server.wait_for_termination() + def stop(self): """Stop the server.""" self._server.stop(0) diff --git a/src/dispatch/test/service.py b/src/dispatch/test/service.py index 731b8680..28fe5901 100644 --- a/src/dispatch/test/service.py +++ b/src/dispatch/test/service.py @@ -1,11 +1,14 @@ +import enum import logging import os import threading +import time from collections import OrderedDict from dataclasses import dataclass from typing import TypeAlias import grpc +import httpx 
import dispatch.sdk.v1.call_pb2 as call_pb import dispatch.sdk.v1.dispatch_pb2 as dispatch_pb @@ -34,6 +37,14 @@ """A request to a Dispatch endpoint, and the response that was received.""" +class CallType(enum.Enum): + """Type of function call.""" + + CALL = 0 + RESUME = 1 + RETRY = 2 + + class DispatchService(dispatch_grpc.DispatchServiceServicer): """Test instance of Dispatch that provides the bare minimum functionality required to test functions locally.""" @@ -70,7 +81,7 @@ def __init__( self._next_dispatch_id = 1 - self.queue: list[tuple[DispatchID, function_pb.RunRequest]] = [] + self.queue: list[tuple[DispatchID, function_pb.RunRequest, CallType]] = [] self.pollers: dict[DispatchID, Poller] = {} self.parents: dict[DispatchID, Poller] = {} @@ -93,15 +104,13 @@ def Dispatch(self, request: dispatch_pb.DispatchRequest, context): with self._work_signal: for call in request.calls: dispatch_id = self._make_dispatch_id() - logger.debug( - "enqueueing call to function %s as %s", call.function, dispatch_id - ) + logger.debug("enqueueing call to function: %s", call.function) resp.dispatch_ids.append(dispatch_id) run_request = function_pb.RunRequest( function=call.function, input=call.input, ) - self.queue.append((dispatch_id, run_request)) + self.queue.append((dispatch_id, run_request, CallType.CALL)) self._work_signal.notify() @@ -113,6 +122,9 @@ def _validate_authentication(self, context: grpc.ServicerContext): if key == "authorization": if value == expected: return + logger.warning( + "a client attempted to dispatch a function call with an incorrect API key. Is the client's DISPATCH_API_KEY correct?" + ) context.abort( grpc.StatusCode.UNAUTHENTICATED, f"Invalid authorization header. 
Expected '{expected}', got {value!r}", @@ -129,13 +141,23 @@ def dispatch_calls(self): configured endpoint.""" _next_queue = [] while self.queue: - dispatch_id, request = self.queue.pop(0) - - logger.debug( - "dispatching call to function %s (%s)", request.function, dispatch_id - ) - - response = self.endpoint_client.run(request) + dispatch_id, request, call_type = self.queue.pop(0) + + match call_type: + case CallType.CALL: + logger.info("calling function %s", request.function) + case CallType.RESUME: + logger.info("resuming function %s", request.function) + case CallType.RETRY: + logger.info("retrying function %s", request.function) + + try: + response = self.endpoint_client.run(request) + except: + logger.warning("call to function %s failed", request.function) + self.queue.extend(_next_queue) + self.queue.append((dispatch_id, request, CallType.RETRY)) + raise if self.roundtrips is not None: try: @@ -146,16 +168,26 @@ def dispatch_calls(self): roundtrips.append((request, response)) self.roundtrips[dispatch_id] = roundtrips - if Status(response.status) in self.retry_on_status: - logger.debug( - "retrying call to function %s (%s)", request.function, dispatch_id + status = Status(response.status) + if status == Status.OK: + logger.info("call to function %s succeeded", request.function) + else: + logger.warning( + "call to function %s failed (%s)", + request.function, + status, ) - _next_queue.append((dispatch_id, request)) + + if status in self.retry_on_status: + _next_queue.append((dispatch_id, request, CallType.RETRY)) elif response.HasField("poll"): assert not response.HasField("exit") + logger.info("suspending function %s", request.function) + logger.debug("registering poller %s", dispatch_id) + assert dispatch_id not in self.pollers poller = Poller( id=dispatch_id, @@ -172,7 +204,10 @@ def dispatch_calls(self): function=call.function, input=call.input, ) - _next_queue.append((child_dispatch_id, child_request)) + + _next_queue.append( + (child_dispatch_id, 
child_request, CallType.CALL) + ) self.parents[child_dispatch_id] = poller poller.waiting[child_dispatch_id] = call @@ -182,15 +217,14 @@ def dispatch_calls(self): if response.exit.HasField("tail_call"): tail_call = response.exit.tail_call logger.debug( - "enqueueing tail call to %s (%s)", + "enqueueing tail call for %s", tail_call.function, - dispatch_id, ) tail_call_request = function_pb.RunRequest( function=tail_call.function, input=tail_call.input, ) - _next_queue.append((dispatch_id, tail_call_request)) + _next_queue.append((dispatch_id, tail_call_request, CallType.CALL)) elif dispatch_id in self.parents: result = response.exit.result @@ -226,7 +260,9 @@ def dispatch_calls(self): ), ) del self.pollers[poller.id] - _next_queue.append((poller.id, poll_results_request)) + _next_queue.append( + (poller.id, poll_results_request, CallType.RESUME) + ) self.queue = _next_queue @@ -259,7 +295,32 @@ def _dispatch_continuously(self): if self._stop_event.is_set(): break - self.dispatch_calls() + try: + self.dispatch_calls() + except httpx.HTTPStatusError as e: + if e.response.status_code == 403: + logger.error( + "error dispatching function call to endpoint (403). Is the endpoint's DISPATCH_VERIFICATION_KEY correct?" + ) + else: + logger.exception(e) + except httpx.ConnectError as e: + logger.error( + "error connecting to the endpoint. Is it running and accessible from DISPATCH_ENDPOINT_URL?" + ) + except Exception as e: + logger.exception(e) + + # Introduce an artificial delay before continuing with + # follow-up work (retries, dispatching nested calls). + # This serves two purposes. Firstly, this is just a mock + # Dispatch server providing the bare minimum of functionality. + # Since there's no adaptive concurrency control, and no backoff + # between call attempts, the mock server may busy-loop without + # some sort of delay. Secondly, a bit of latency mimics the + # latency you would see in a production system and makes the + # log output easier to parse. 
+ time.sleep(0.15) def __enter__(self): self.start() diff --git a/tests/test_client.py b/tests/test_client.py index 54d4bed8..29a2ddec 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -67,7 +67,7 @@ def test_call_pickle(self): pending_calls = self.dispatch_service.queue self.assertEqual(len(pending_calls), 1) - dispatch_id, call = pending_calls[0] + dispatch_id, call, _ = pending_calls[0] self.assertEqual(dispatch_id, dispatch_ids[0]) self.assertEqual(call.function, "my-function") self.assertEqual(any_unpickle(call.input), 42)