Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import dispatch.integrations
from dispatch.coroutine import call, gather
from dispatch.coroutine import all, any, call, gather, race
from dispatch.function import DEFAULT_API_URL, Client
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
Expand All @@ -20,4 +20,7 @@
"Status",
"call",
"gather",
"all",
"any",
"race",
]
58 changes: 53 additions & 5 deletions src/dispatch/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,60 @@ def call(call: Call) -> Any:
@coroutine
@durable
def gather(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
"""Concurrently run a set of coroutines and block until all
results are available. If any coroutine fails with an uncaught
exception, it will be re-raised when awaiting a result here."""
return (yield Gather(awaitables))
"""Alias for all."""
return all(*awaitables)


@coroutine
@durable
def all(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
"""Concurrently run a set of coroutines, blocking until all coroutines
return or any coroutine raises an error. If any coroutine fails with an
uncaught exception, the exception will be re-raised here."""
return (yield AllDirective(awaitables))


@coroutine
@durable
def any(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
"""Concurrently run a set of coroutines, blocking until any coroutine
returns or all coroutines raises an error. If all coroutines fail with
uncaught exceptions, the exception(s) will be re-raised here."""
return (yield AnyDirective(awaitables))


@coroutine
@durable
def race(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
"""Concurrently run a set of coroutines, blocking until any coroutine
returns or raises an error. If any coroutine fails with an uncaught
exception, the exception will be re-raised here."""
return (yield RaceDirective(awaitables))


@dataclass(slots=True)
class Gather:
class AllDirective:
awaitables: tuple[Awaitable[Any], ...]


@dataclass(slots=True)
class AnyDirective:
awaitables: tuple[Awaitable[Any], ...]


@dataclass(slots=True)
class RaceDirective:
awaitables: tuple[Awaitable[Any], ...]


class AnyException(RuntimeError):
"""Error indicating that all coroutines passed to any() failed
with an exception."""

__slots__ = ("exceptions",)

def __init__(self, exceptions: list[Exception]):
self.exceptions = exceptions

def __str__(self):
return f"{len(self.exceptions)} coroutine(s) failed with an exception"
191 changes: 159 additions & 32 deletions src/dispatch/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import pickle
import sys
from dataclasses import dataclass
from typing import Any, Callable, Protocol, TypeAlias
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Protocol, TypeAlias

from dispatch.coroutine import Gather
from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective
from dispatch.error import IncompatibleStateError
from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator
from dispatch.proto import Call, Error, Input, Output
Expand Down Expand Up @@ -73,17 +73,18 @@ def error(self) -> Exception | None:
return self.first_error

def value(self) -> Any:
assert self.first_error is None
assert self.result is not None
return self.result.value


@dataclass(slots=True)
class GatherFuture:
"""A future result of a dispatch.coroutine.gather() operation."""
class AllFuture:
"""A future result of a dispatch.coroutine.all() operation."""

order: list[CoroutineID]
waiting: set[CoroutineID]
results: dict[CoroutineID, CoroutineResult]
order: list[CoroutineID] = field(default_factory=list)
waiting: set[CoroutineID] = field(default_factory=set)
results: dict[CoroutineID, CoroutineResult] = field(default_factory=dict)
first_error: Exception | None = None

def add_result(self, result: CallResult | CoroutineResult):
Expand All @@ -94,13 +95,15 @@ def add_result(self, result: CallResult | CoroutineResult):
except KeyError:
return

if result.error is not None and self.first_error is None:
self.first_error = result.error
if result.error is not None:
if self.first_error is None:
self.first_error = result.error
return

self.results[result.coroutine_id] = result

def add_error(self, error: Exception):
if self.first_error is not None:
if self.first_error is None:
self.first_error = error

def ready(self) -> bool:
Expand All @@ -113,9 +116,108 @@ def error(self) -> Exception | None:
def value(self) -> list[Any]:
assert self.ready()
assert len(self.waiting) == 0
assert self.first_error is None
return [self.results[id].value for id in self.order]


@dataclass(slots=True)
class AnyFuture:
"""A future result of a dispatch.coroutine.any() operation."""

order: list[CoroutineID] = field(default_factory=list)
waiting: set[CoroutineID] = field(default_factory=set)
first_result: CoroutineResult | None = None
errors: dict[CoroutineID, Exception] = field(default_factory=dict)
generic_error: Exception | None = None

def add_result(self, result: CallResult | CoroutineResult):
assert isinstance(result, CoroutineResult)

try:
self.waiting.remove(result.coroutine_id)
except KeyError:
return

if result.error is None:
if self.first_result is None:
self.first_result = result
return

self.errors[result.coroutine_id] = result.error

def add_error(self, error: Exception):
if self.generic_error is None:
self.generic_error = error

def ready(self) -> bool:
return (
self.generic_error is not None
or self.first_result is not None
or len(self.waiting) == 0
)

def error(self) -> Exception | None:
assert self.ready()
if self.generic_error is not None:
return self.generic_error
if self.first_result is not None or len(self.errors) == 0:
return None
match len(self.errors):
case 0:
return None
case 1:
return self.errors[self.order[0]]
case _:
return AnyException([self.errors[id] for id in self.order])

def value(self) -> Any:
assert self.ready()
if len(self.order) == 0:
return None
assert self.first_result is not None
return self.first_result.value


@dataclass(slots=True)
class RaceFuture:
"""A future result of a dispatch.coroutine.race() operation."""

waiting: set[CoroutineID] = field(default_factory=set)
first_result: CoroutineResult | None = None
first_error: Exception | None = None

def add_result(self, result: CallResult | CoroutineResult):
assert isinstance(result, CoroutineResult)

if result.error is not None:
if self.first_error is None:
self.first_error = result.error
else:
if self.first_result is None:
self.first_result = result

self.waiting.remove(result.coroutine_id)

def add_error(self, error: Exception):
if self.first_error is None:
self.first_error = error

def ready(self) -> bool:
return (
self.first_error is not None
or self.first_result is not None
or len(self.waiting) == 0
)

def error(self) -> Exception | None:
assert self.ready()
return self.first_error

def value(self) -> Any:
assert self.first_error is None
return self.first_result.value if self.first_result else None


@dataclass(slots=True)
class Coroutine:
"""An in-flight coroutine."""
Expand Down Expand Up @@ -386,30 +488,35 @@ def _run(self, input: Input) -> Output:
state.prev_callers.append(coroutine)
state.outstanding_calls += 1

case Gather():
gather = coroutine_yield

children = []
for awaitable in gather.awaitables:
g = awaitable.__await__()
if not isinstance(g, DurableGenerator):
raise ValueError(
"gather awaitable is not a @dispatch.function"
)
child_id = state.next_coroutine_id
state.next_coroutine_id += 1
child = Coroutine(
id=child_id, parent_id=coroutine.id, coroutine=g
)
logger.debug("enqueuing %s for %s", child, coroutine)
children.append(child)
case AllDirective():
children = spawn_children(
state, coroutine, coroutine_yield.awaitables
)

# Prepend children to get a depth-first traversal of coroutines.
state.ready = children + state.ready
child_ids = [child.id for child in children]
coroutine.result = AllFuture(
order=child_ids, waiting=set(child_ids)
)
state.suspended[coroutine.id] = coroutine

case AnyDirective():
children = spawn_children(
state, coroutine, coroutine_yield.awaitables
)

child_ids = [child.id for child in children]
coroutine.result = GatherFuture(
order=child_ids, waiting=set(child_ids), results={}
coroutine.result = AnyFuture(
order=child_ids, waiting=set(child_ids)
)
state.suspended[coroutine.id] = coroutine

case RaceDirective():
children = spawn_children(
state, coroutine, coroutine_yield.awaitables
)

coroutine.result = RaceFuture(
waiting={child.id for child in children}
)
state.suspended[coroutine.id] = coroutine

Expand Down Expand Up @@ -446,6 +553,26 @@ def _run(self, input: Input) -> Output:
)


def spawn_children(
state: State, coroutine: Coroutine, awaitables: tuple[Awaitable[Any], ...]
) -> list[Coroutine]:
children = []
for awaitable in awaitables:
g = awaitable.__await__()
if not isinstance(g, DurableGenerator):
raise TypeError("awaitable is not a @dispatch.function")
child_id = state.next_coroutine_id
state.next_coroutine_id += 1
child = Coroutine(id=child_id, parent_id=coroutine.id, coroutine=g)
logger.debug("enqueuing %s for %s", child, coroutine)
children.append(child)

# Prepend children to get a depth-first traversal of coroutines.
state.ready = children + state.ready

return children


def correlation_id(coroutine_id: CoroutineID, call_id: CallID) -> CorrelationID:
return coroutine_id << 32 | call_id

Expand Down
Loading