Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/auto_retry/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
endpoint_client = EndpointClient(TestClient(app))
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
Expand Down
2 changes: 1 addition & 1 deletion examples/getting_started/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
endpoint_client = EndpointClient(TestClient(app))
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
Expand Down
2 changes: 1 addition & 1 deletion examples/github_stats/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_app(self):
from .app import app, dispatch

# Setup a fake Dispatch server.
endpoint_client = EndpointClient.from_app(app)
endpoint_client = EndpointClient(TestClient(app))
dispatch_service = DispatchService(endpoint_client, collect_roundtrips=True)
with DispatchServer(dispatch_service) as dispatch_server:
# Use it when dispatching function calls.
Expand Down
13 changes: 1 addition & 12 deletions src/dispatch/test/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from datetime import datetime
from typing import Optional

import fastapi
import grpc
import httpx
from fastapi.testclient import TestClient

from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.sdk.v1 import function_pb2_grpc as function_grpc
Expand All @@ -22,7 +20,7 @@ class EndpointClient:
Note that this is different from dispatch.Client, which is a client
for the Dispatch API. The EndpointClient is a client similar to the one
that Dispatch itself would use to interact with an endpoint that provides
functions, for example a FastAPI app.
functions.
"""

def __init__(
Expand Down Expand Up @@ -54,15 +52,6 @@ def from_url(cls, url: str, signing_key: Optional[Ed25519PrivateKey] = None):
http_client = httpx.Client(base_url=url)
return EndpointClient(http_client, signing_key)

@classmethod
def from_app(
cls, app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
):
"""Returns an EndpointClient for a Dispatch endpoint bound to a
FastAPI app instance."""
http_client = TestClient(app)
return EndpointClient(http_client, signing_key)


class _HttpxGrpcChannel(grpc.Channel):
def __init__(
Expand Down
124 changes: 112 additions & 12 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,36 @@
import pickle
import struct
import unittest
from typing import Any
from typing import Any, Optional
from unittest import mock

import fastapi
import google.protobuf.any_pb2
import google.protobuf.wrappers_pb2
import httpx
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
Ed25519PrivateKey,
Ed25519PublicKey,
)
from fastapi.testclient import TestClient

import dispatch
from dispatch.experimental.durable.registry import clear_functions
from dispatch.fastapi import Dispatch
from dispatch.function import Arguments, Error, Function, Input, Output
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import parse_verification_key, public_key_from_pem
from dispatch.signature import (
parse_verification_key,
private_key_from_pem,
public_key_from_pem,
)
from dispatch.status import Status
from dispatch.test import EndpointClient
from dispatch.test import DispatchServer, DispatchService, EndpointClient


def create_dispatch_instance(app, endpoint):
def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str):
return Dispatch(
app,
endpoint=endpoint,
Expand All @@ -33,6 +41,13 @@ def create_dispatch_instance(app, endpoint):
)


def create_endpoint_client(
app: fastapi.FastAPI, signing_key: Optional[Ed25519PrivateKey] = None
):
http_client = TestClient(app)
return EndpointClient(http_client, signing_key)


class TestFastAPI(unittest.TestCase):
def test_Dispatch(self):
app = fastapi.FastAPI()
Expand All @@ -54,10 +69,6 @@ def read_root():
resp = client.post("/dispatch.sdk.v1.FunctionService/Run")
self.assertEqual(resp.status_code, 400)

def test_Dispatch_no_app(self):
with self.assertRaises(ValueError):
create_dispatch_instance(None, endpoint="http://127.0.0.1:9999")

@mock.patch.dict(os.environ, {"DISPATCH_ENDPOINT_URL": ""})
def test_Dispatch_no_endpoint(self):
app = fastapi.FastAPI()
Expand All @@ -79,8 +90,7 @@ def my_function(input: Input) -> Output:
f"You told me: '{input.input}' ({len(input.input)} characters)"
)

client = EndpointClient.from_app(app)

client = create_endpoint_client(app)
pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
Expand All @@ -102,6 +112,96 @@ def my_function(input: Input) -> Output:
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")


signing_key = private_key_from_pem(
"""
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIJ+DYvh6SEqVTm50DFtMDoQikTmiCqirVv9mWG9qfSnF
-----END PRIVATE KEY-----
"""
)

verification_key = public_key_from_pem(
"""
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=
-----END PUBLIC KEY-----
"""
)


class TestFullFastapi(unittest.TestCase):
def setUp(self):
self.endpoint_app = fastapi.FastAPI()
endpoint_client = create_endpoint_client(self.endpoint_app, signing_key)

api_key = "0000000000000000"
self.dispatch_service = DispatchService(
endpoint_client, api_key, collect_roundtrips=True
)
self.dispatch_server = DispatchServer(self.dispatch_service)
self.dispatch_client = dispatch.Client(
api_key, api_url=self.dispatch_server.url
)

self.dispatch = Dispatch(
self.endpoint_app,
endpoint="http://function-service", # unused
verification_key=verification_key,
api_key=api_key,
api_url=self.dispatch_server.url,
)

self.dispatch_server.start()

def tearDown(self):
self.dispatch_server.stop()

def test_simple_end_to_end(self):
# The FastAPI server.
@self.dispatch.function
def my_function(name: str) -> str:
return f"Hello world: {name}"

call = my_function.build_call(52)
self.assertEqual(call.function.split(".")[-1], "my_function")

# The client.
[dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)])

# Simulate execution for testing purposes.
self.dispatch_service.dispatch_calls()

# Validate results.
roundtrips = self.dispatch_service.roundtrips[dispatch_id]
self.assertEqual(len(roundtrips), 1)
_, response = roundtrips[0]
self.assertEqual(any_unpickle(response.exit.result.output), "Hello world: 52")

def test_simple_missing_signature(self):
@self.dispatch.function
async def my_function(name: str) -> str:
return f"Hello world: {name}"

call = my_function.build_call(52)
self.assertEqual(call.function.split(".")[-1], "my_function")

[dispatch_id] = self.dispatch_client.dispatch([call])

self.dispatch_service.endpoint_client = create_endpoint_client(
self.endpoint_app
) # no signing key
try:
self.dispatch_service.dispatch_calls()
except httpx.HTTPStatusError as e:
assert e.response.status_code == 403
assert e.response.json() == {
"code": "permission_denied",
"message": 'Expected "Signature-Input" header field to be present',
}
else:
assert False, "Expected HTTPStatusError"


def response_output(resp: function_pb.RunResponse) -> Any:
return any_unpickle(resp.exit.result.output)

Expand All @@ -120,7 +220,7 @@ def root():
self.app, endpoint="https://127.0.0.1:9999"
)
self.http_client = TestClient(self.app)
self.client = EndpointClient.from_app(self.app)
self.client = create_endpoint_client(self.app)

def execute(
self, func: Function, input=None, state=None, calls=None
Expand Down
99 changes: 0 additions & 99 deletions tests/test_full.py

This file was deleted.