diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py
index a9affbda..154257db 100644
--- a/src/dispatch/__init__.py
+++ b/src/dispatch/__init__.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import dispatch.integrations
 
-from dispatch.coroutine import call, gather
+from dispatch.coroutine import all, any, call, gather, race
 from dispatch.function import DEFAULT_API_URL, Client
 from dispatch.id import DispatchID
 from dispatch.proto import Call, Error, Input, Output
@@ -20,4 +20,7 @@
     "Status",
     "call",
     "gather",
+    "all",
+    "any",
+    "race",
 ]
diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py
index c4ea53ad..cf9d4c93 100644
--- a/src/dispatch/coroutine.py
+++ b/src/dispatch/coroutine.py
@@ -17,12 +17,60 @@ def call(call: Call) -> Any:
 @coroutine
 @durable
 def gather(*awaitables: Awaitable[Any]) -> list[Any]:  # type: ignore[misc]
-    """Concurrently run a set of coroutines and block until all
-    results are available. If any coroutine fails with an uncaught
-    exception, it will be re-raised when awaiting a result here."""
-    return (yield Gather(awaitables))
+    """Alias for all."""
+    return all(*awaitables)
+
+
+@coroutine
+@durable
+def all(*awaitables: Awaitable[Any]) -> list[Any]:  # type: ignore[misc]
+    """Concurrently run a set of coroutines, blocking until all coroutines
+    return or any coroutine raises an error. If any coroutine fails with an
+    uncaught exception, the exception will be re-raised here."""
+    return (yield AllDirective(awaitables))
+
+
+@coroutine
+@durable
+def any(*awaitables: Awaitable[Any]) -> Any:  # type: ignore[misc]
+    """Concurrently run a set of coroutines, blocking until any coroutine
+    returns or all coroutines raise an error. If all coroutines fail with
+    uncaught exceptions, the exception(s) will be re-raised here."""
+    return (yield AnyDirective(awaitables))
+
+
+@coroutine
+@durable
+def race(*awaitables: Awaitable[Any]) -> Any:  # type: ignore[misc]
+    """Concurrently run a set of coroutines, blocking until any coroutine
+    returns or raises an error. If any coroutine fails with an uncaught
+    exception, the exception will be re-raised here."""
+    return (yield RaceDirective(awaitables))
 
 
 @dataclass(slots=True)
-class Gather:
+class AllDirective:
     awaitables: tuple[Awaitable[Any], ...]
+
+
+@dataclass(slots=True)
+class AnyDirective:
+    awaitables: tuple[Awaitable[Any], ...]
+
+
+@dataclass(slots=True)
+class RaceDirective:
+    awaitables: tuple[Awaitable[Any], ...]
+
+
+class AnyException(RuntimeError):
+    """Error indicating that all coroutines passed to any() failed
+    with an exception."""
+
+    __slots__ = ("exceptions",)
+
+    def __init__(self, exceptions: list[Exception]):
+        self.exceptions = exceptions
+
+    def __str__(self):
+        return f"{len(self.exceptions)} coroutine(s) failed with an exception"
diff --git a/src/dispatch/scheduler.py b/src/dispatch/scheduler.py
index 7b78c1a7..89c4d7af 100644
--- a/src/dispatch/scheduler.py
+++ b/src/dispatch/scheduler.py
@@ -1,10 +1,10 @@
 import logging
 import pickle
 import sys
-from dataclasses import dataclass
-from typing import Any, Callable, Protocol, TypeAlias
+from dataclasses import dataclass, field
+from typing import Any, Awaitable, Callable, Protocol, TypeAlias
 
-from dispatch.coroutine import Gather
+from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective
 from dispatch.error import IncompatibleStateError
 from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator
 from dispatch.proto import Call, Error, Input, Output
@@ -73,17 +73,18 @@ def error(self) -> Exception | None:
         return self.first_error
 
     def value(self) -> Any:
+        assert self.first_error is None
         assert self.result is not None
         return self.result.value
 
 
 @dataclass(slots=True)
-class GatherFuture:
-    """A future result of a dispatch.coroutine.gather() operation."""
+class AllFuture:
+    """A future result of a dispatch.coroutine.all() operation."""
 
-    order: list[CoroutineID]
-    waiting: set[CoroutineID]
-    results: dict[CoroutineID, CoroutineResult]
+    order: list[CoroutineID] = field(default_factory=list)
+    waiting: set[CoroutineID] = field(default_factory=set)
+    results: dict[CoroutineID, CoroutineResult] = field(default_factory=dict)
     first_error: Exception | None = None
 
     def add_result(self, result: CallResult | CoroutineResult):
@@ -94,13 +95,15 @@ def add_result(self, result: CallResult | CoroutineResult):
         except KeyError:
             return
 
-        if result.error is not None and self.first_error is None:
-            self.first_error = result.error
+        if result.error is not None:
+            if self.first_error is None:
+                self.first_error = result.error
+            return
 
         self.results[result.coroutine_id] = result
 
     def add_error(self, error: Exception):
-        if self.first_error is not None:
+        if self.first_error is None:
             self.first_error = error
 
     def ready(self) -> bool:
@@ -113,9 +116,113 @@ def error(self) -> Exception | None:
     def value(self) -> list[Any]:
         assert self.ready()
         assert len(self.waiting) == 0
+        assert self.first_error is None
         return [self.results[id].value for id in self.order]
 
 
+@dataclass(slots=True)
+class AnyFuture:
+    """A future result of a dispatch.coroutine.any() operation."""
+
+    order: list[CoroutineID] = field(default_factory=list)
+    waiting: set[CoroutineID] = field(default_factory=set)
+    first_result: CoroutineResult | None = None
+    errors: dict[CoroutineID, Exception] = field(default_factory=dict)
+    generic_error: Exception | None = None
+
+    def add_result(self, result: CallResult | CoroutineResult):
+        assert isinstance(result, CoroutineResult)
+
+        try:
+            self.waiting.remove(result.coroutine_id)
+        except KeyError:
+            return
+
+        if result.error is None:
+            if self.first_result is None:
+                self.first_result = result
+            return
+
+        self.errors[result.coroutine_id] = result.error
+
+    def add_error(self, error: Exception):
+        if self.generic_error is None:
+            self.generic_error = error
+
+    def ready(self) -> bool:
+        return (
+            self.generic_error is not None
+            or self.first_result is not None
+            or len(self.waiting) == 0
+        )
+
+    def error(self) -> Exception | None:
+        assert self.ready()
+        if self.generic_error is not None:
+            return self.generic_error
+        if self.first_result is not None or len(self.errors) == 0:
+            return None
+        match len(self.errors):
+            case 0:
+                return None
+            case 1:
+                return self.errors[self.order[0]]
+            case _:
+                return AnyException([self.errors[id] for id in self.order])
+
+    def value(self) -> Any:
+        assert self.ready()
+        if len(self.order) == 0:
+            return None
+        assert self.first_result is not None
+        return self.first_result.value
+
+
+@dataclass(slots=True)
+class RaceFuture:
+    """A future result of a dispatch.coroutine.race() operation."""
+
+    waiting: set[CoroutineID] = field(default_factory=set)
+    first_result: CoroutineResult | None = None
+    first_error: Exception | None = None
+
+    def add_result(self, result: CallResult | CoroutineResult):
+        assert isinstance(result, CoroutineResult)
+
+        # Ignore results from coroutines this future is not waiting on,
+        # consistent with AllFuture and AnyFuture.
+        try:
+            self.waiting.remove(result.coroutine_id)
+        except KeyError:
+            return
+
+        if result.error is not None:
+            if self.first_error is None:
+                self.first_error = result.error
+        else:
+            if self.first_result is None:
+                self.first_result = result
+
+    def add_error(self, error: Exception):
+        if self.first_error is None:
+            self.first_error = error
+
+    def ready(self) -> bool:
+        return (
+            self.first_error is not None
+            or self.first_result is not None
+            or len(self.waiting) == 0
+        )
+
+    def error(self) -> Exception | None:
+        assert self.ready()
+        return self.first_error
+
+    def value(self) -> Any:
+        assert self.first_error is None
+        return self.first_result.value if self.first_result else None
+
+
 @dataclass(slots=True)
 class Coroutine:
     """An in-flight coroutine."""
@@ -386,30 +493,35 @@ def _run(self, input: Input) -> Output:
                     state.prev_callers.append(coroutine)
                     state.outstanding_calls += 1
 
-                case Gather():
-                    gather = coroutine_yield
-
-                    children = []
-                    for awaitable in gather.awaitables:
-                        g = awaitable.__await__()
-                        if not isinstance(g, DurableGenerator):
-                            raise ValueError(
-                                "gather awaitable is not a @dispatch.function"
-                            )
-                        child_id = state.next_coroutine_id
-                        state.next_coroutine_id += 1
-                        child = Coroutine(
-                            id=child_id, parent_id=coroutine.id, coroutine=g
-                        )
-                        logger.debug("enqueuing %s for %s", child, coroutine)
-                        children.append(child)
+                case AllDirective():
+                    children = spawn_children(
+                        state, coroutine, coroutine_yield.awaitables
+                    )
 
-                    # Prepend children to get a depth-first traversal of coroutines.
-                    state.ready = children + state.ready
+                    child_ids = [child.id for child in children]
+                    coroutine.result = AllFuture(
+                        order=child_ids, waiting=set(child_ids)
+                    )
+                    state.suspended[coroutine.id] = coroutine
+
+                case AnyDirective():
+                    children = spawn_children(
+                        state, coroutine, coroutine_yield.awaitables
+                    )
 
                     child_ids = [child.id for child in children]
-                    coroutine.result = GatherFuture(
-                        order=child_ids, waiting=set(child_ids), results={}
+                    coroutine.result = AnyFuture(
+                        order=child_ids, waiting=set(child_ids)
+                    )
+                    state.suspended[coroutine.id] = coroutine
+
+                case RaceDirective():
+                    children = spawn_children(
+                        state, coroutine, coroutine_yield.awaitables
+                    )
+
+                    coroutine.result = RaceFuture(
+                        waiting={child.id for child in children}
                     )
                     state.suspended[coroutine.id] = coroutine
 
@@ -446,6 +558,26 @@ def _run(self, input: Input) -> Output:
         )
 
 
+def spawn_children(
+    state: State, coroutine: Coroutine, awaitables: tuple[Awaitable[Any], ...]
+) -> list[Coroutine]:
+    children = []
+    for awaitable in awaitables:
+        g = awaitable.__await__()
+        if not isinstance(g, DurableGenerator):
+            raise TypeError("awaitable is not a @dispatch.function")
+        child_id = state.next_coroutine_id
+        state.next_coroutine_id += 1
+        child = Coroutine(id=child_id, parent_id=coroutine.id, coroutine=g)
+        logger.debug("enqueuing %s for %s", child, coroutine)
+        children.append(child)
+
+    # Prepend children to get a depth-first traversal of coroutines.
+    state.ready = children + state.ready
+
+    return children
+
+
 def correlation_id(coroutine_id: CoroutineID, call_id: CallID) -> CorrelationID:
     return coroutine_id << 32 | call_id
 
diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py
index 75a83306..2bfc079a 100644
--- a/tests/dispatch/test_scheduler.py
+++ b/tests/dispatch/test_scheduler.py
@@ -1,11 +1,17 @@
 import unittest
 from typing import Any, Callable
 
-from dispatch.coroutine import call, gather
+from dispatch.coroutine import AnyException, any, call, gather, race
 from dispatch.experimental.durable import durable
 from dispatch.proto import Call, CallResult, Error, Input, Output
 from dispatch.proto import _any_unpickle as any_unpickle
-from dispatch.scheduler import OneShotScheduler
+from dispatch.scheduler import (
+    AllFuture,
+    AnyFuture,
+    CoroutineResult,
+    OneShotScheduler,
+    RaceFuture,
+)
 from dispatch.sdk.v1 import call_pb2 as call_pb
 from dispatch.sdk.v1 import exit_pb2 as exit_pb
 from dispatch.sdk.v1 import poll_pb2 as poll_pb
@@ -16,6 +22,16 @@ async def call_one(function):
     return await call(Call(function=function))
 
 
+@durable
+async def call_any(*functions):
+    return await any(*[call_one(function) for function in functions])
+
+
+@durable
+async def call_race(*functions):
+    return await race(*[call_one(function) for function in functions])
+
+
 @durable
 async def call_concurrently(*functions):
     return await gather(*[call_one(function) for function in functions])
@@ -162,6 +178,71 @@ async def main():
 
         self.assert_exit_result_value(output, 0 + 1 + 2 + 3)
 
+    def test_resume_after_any_result(self):
+        @durable
+        async def main():
+            return await call_any("a", "b", "c", "d")
+
+        output = self.start(main)
+        calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"])
+
+        output = self.resume(
+            main,
+            output,
+            [CallResult.from_value(23, correlation_id=calls[1].correlation_id)],
+        )
+        self.assert_exit_result_value(output, 23)
+
+    def test_resume_after_all_errors(self):
+        @durable
+        async def main():
+            return await call_any("a", "b", "c", "d")
+
+        output = self.start(main)
+        calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"])
+        results = [
+            CallResult.from_error(
+                Error.from_exception(RuntimeError(f"oops{i}")),
+                correlation_id=call.correlation_id,
+            )
+            for i, call in enumerate(calls)
+        ]
+        output = self.resume(main, output, results)
+        self.assert_exit_result_error(
+            output, AnyException, "4 coroutine(s) failed with an exception"
+        )
+
+    def test_resume_after_race_result(self):
+        @durable
+        async def main():
+            return await call_race("a", "b", "c", "d")
+
+        output = self.start(main)
+        calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"])
+
+        output = self.resume(
+            main,
+            output,
+            [CallResult.from_value(23, correlation_id=calls[1].correlation_id)],
+        )
+        self.assert_exit_result_value(output, 23)
+
+    def test_resume_after_race_error(self):
+        @durable
+        async def main():
+            return await call_race("a", "b", "c", "d")
+
+        output = self.start(main)
+        calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"])
+
+        error = Error.from_exception(RuntimeError("oops"))
+        output = self.resume(
+            main,
+            output,
+            [CallResult.from_error(error, correlation_id=calls[2].correlation_id)],
+        )
+        self.assert_exit_result_error(output, RuntimeError, "oops")
+
     def test_dag(self):
         @durable
         async def main():
@@ -408,3 +489,246 @@ def assert_poll_call_functions(
         if max_results is not None:
             self.assertEqual(max_results, poll.max_results)
         return poll.calls
+
+
+class TestAllFuture(unittest.TestCase):
+    def test_empty(self):
+        future = AllFuture()
+
+        self.assertTrue(future.ready())
+        self.assertListEqual(future.value(), [])
+        self.assertIsNone(future.error())
+
+    def test_one_result_value(self):
+        future = AllFuture(order=[10], waiting={10})
+
+        self.assertFalse(future.ready())
+        future.add_result(CoroutineResult(coroutine_id=10, value="foobar"))
+
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertListEqual(future.value(), ["foobar"])
+
+    def test_one_result_error(self):
+        future = AllFuture(order=[10], waiting={10})
+
+        self.assertFalse(future.ready())
+        error = RuntimeError("oops")
+        future.add_result(CoroutineResult(coroutine_id=10, error=error))
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+    def test_one_generic_error(self):
+        future = AllFuture(order=[10], waiting={10})
+
+        self.assertFalse(future.ready())
+        error = RuntimeError("oops")
+        future.add_error(error)
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+    def test_two_result_values(self):
+        future = AllFuture(order=[10, 20], waiting={10, 20})
+
+        self.assertFalse(future.ready())
+        future.add_result(CoroutineResult(coroutine_id=20, value="bar"))
+        self.assertFalse(future.ready())
+        future.add_result(CoroutineResult(coroutine_id=10, value="foo"))
+
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertListEqual(future.value(), ["foo", "bar"])
+
+    def test_two_result_errors(self):
+        future = AllFuture(order=[10, 20], waiting={10, 20})
+
+        self.assertFalse(future.ready())
+        error1 = RuntimeError("oops1")
+        error2 = RuntimeError("oops2")
+        future.add_result(CoroutineResult(coroutine_id=20, error=error2))
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error2)
+
+        future.add_result(CoroutineResult(coroutine_id=10, error=error1))
+        self.assertIs(future.error(), error2)
+
+        future.add_error(error1)
+        self.assertIs(future.error(), error2)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+
+class TestAnyFuture(unittest.TestCase):
+    def test_empty(self):
+        future = AnyFuture()
+
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.value())
+        self.assertIsNone(future.error())
+
+    def test_one_result_value(self):
+        future = AnyFuture(order=[10], waiting={10})
+
+        self.assertFalse(future.ready())
+        future.add_result(CoroutineResult(coroutine_id=10, value="foobar"))
+
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertEqual(future.value(), "foobar")
+
+    def test_one_result_error(self):
+        future = AnyFuture(order=[10], waiting={10})
+
+        self.assertFalse(future.ready())
+        error = RuntimeError("oops")
+        future.add_result(CoroutineResult(coroutine_id=10, error=error))
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+    def test_one_generic_error(self):
+        future = AnyFuture(order=[10], waiting={10})
+
+        self.assertFalse(future.ready())
+        error = RuntimeError("oops")
+        future.add_error(error)
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+    def test_two_result_values(self):
+        future = AnyFuture(order=[10, 20], waiting={10, 20})
+
+        self.assertFalse(future.ready())
+
+        future.add_result(CoroutineResult(coroutine_id=20, value="bar"))
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertEqual(future.value(), "bar")
+
+        future.add_result(CoroutineResult(coroutine_id=10, value="foo"))
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertEqual(future.value(), "bar")
+
+    def test_two_result_errors(self):
+        future = AnyFuture(order=[10, 20], waiting={10, 20})
+
+        self.assertFalse(future.ready())
+        error1 = RuntimeError("oops1")
+        error2 = RuntimeError("oops2")
+        future.add_result(CoroutineResult(coroutine_id=20, error=error2))
+
+        self.assertFalse(future.ready())
+        future.add_result(CoroutineResult(coroutine_id=10, error=error1))
+        self.assertTrue(future.ready())
+        self.assertEqual(repr(future.error()), repr(AnyException([error1, error2])))
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+
+class TestRaceFuture(unittest.TestCase):
+    def test_empty(self):
+        future = RaceFuture()
+
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.value())
+        self.assertIsNone(future.error())
+
+    def test_one_result_value(self):
+        future = RaceFuture(waiting={10})
+
+        self.assertFalse(future.ready())
+        future.add_result(CoroutineResult(coroutine_id=10, value="foobar"))
+
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertEqual(future.value(), "foobar")
+
+    def test_one_result_error(self):
+        future = RaceFuture(waiting={10})
+
+        self.assertFalse(future.ready())
+        error = RuntimeError("oops")
+        future.add_result(CoroutineResult(coroutine_id=10, error=error))
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+    def test_one_generic_error(self):
+        future = RaceFuture(waiting={10})
+
+        self.assertFalse(future.ready())
+        error = RuntimeError("oops")
+        future.add_error(error)
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+    def test_two_result_values(self):
+        future = RaceFuture(waiting={10, 20})
+
+        self.assertFalse(future.ready())
+
+        future.add_result(CoroutineResult(coroutine_id=20, value="bar"))
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertEqual(future.value(), "bar")
+
+        future.add_result(CoroutineResult(coroutine_id=10, value="foo"))
+        self.assertTrue(future.ready())
+        self.assertIsNone(future.error())
+        self.assertEqual(future.value(), "bar")
+
+    def test_two_result_errors(self):
+        future = RaceFuture(waiting={10, 20})
+
+        self.assertFalse(future.ready())
+        error1 = RuntimeError("oops")
+        future.add_result(CoroutineResult(coroutine_id=10, error=error1))
+
+        self.assertTrue(future.ready())
+        self.assertIs(future.error(), error1)
+
+        error2 = RuntimeError("oops2")
+        future.add_result(CoroutineResult(coroutine_id=20, error=error2))
+        self.assertIs(future.error(), error1)
+
+        future.add_error(error2)
+        self.assertIs(future.error(), error1)
+
+        with self.assertRaises(AssertionError):
+            future.value()
+
+    def test_unknown_result_ignored(self):
+        future = RaceFuture(waiting={10})
+
+        future.add_result(CoroutineResult(coroutine_id=20, value="stale"))
+        self.assertFalse(future.ready())
+
+        future.add_result(CoroutineResult(coroutine_id=10, value="foo"))
+        self.assertTrue(future.ready())
+        self.assertEqual(future.value(), "foo")