Skip to content

Commit e306a5e

Browse files
committed
Declare is_awaitable as a type guard for Awaitables
1 parent 018c599 commit e306a5e

File tree

7 files changed

+96
-81
lines changed

7 files changed

+96
-81
lines changed

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,15 @@
151151
"traceback",
152152
"TypeMap",
153153
"AwaitableOrValue",
154+
"CancellableStreamRecord",
154155
"DeferredFragmentRecord",
155156
"DeferredGroupedFieldSetRecord",
156157
"DeferredGroupedFieldSetResult",
157158
"DeferUsage",
158159
"EnterLeaveVisitor",
159160
"ExperimentalIncrementalExecutionResults",
160161
"FieldGroup",
162+
"FormattedCompletedResult",
161163
"FormattedIncrementalResult",
162164
"FormattedPendingResult",
163165
"FormattedSourceLocation",
@@ -167,6 +169,7 @@
167169
"GraphQLErrorExtensions",
168170
"GraphQLFieldResolver",
169171
"GraphQLInputType",
172+
"GraphQLLeafType",
170173
"GraphQLNullableType",
171174
"GraphQLOutputType",
172175
"GraphQLTypeResolver",

src/graphql/execution/execute.py

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from asyncio import (
66
CancelledError,
7+
TimeoutError, # only needed for Python < 3.11 # noqa: A004
78
ensure_future,
89
)
910
from contextlib import suppress
@@ -29,15 +30,6 @@
2930
cast,
3031
)
3132

32-
try:
33-
from typing import TypeAlias
34-
except ImportError: # Python < 3.10
35-
from typing_extensions import TypeAlias
36-
try: # only needed for Python < 3.11
37-
from asyncio.exceptions import TimeoutError # noqa: A004
38-
except ImportError: # Python < 3.7
39-
from concurrent.futures import TimeoutError # noqa: A004
40-
4133
from ..error import GraphQLError, located_error
4234
from ..language import (
4335
DocumentNode,
@@ -119,6 +111,11 @@
119111
if TYPE_CHECKING:
120112
from graphql.pyutils.undefined import UndefinedType
121113

114+
try:
115+
from typing import TypeAlias, TypeGuard
116+
except ImportError: # Python < 3.10
117+
from typing_extensions import TypeAlias, TypeGuard
118+
122119
try: # pragma: no cover
123120
anext # noqa: B018 # pyright: ignore
124121
except NameError: # pragma: no cover (Python < 3.10)
@@ -216,7 +213,9 @@ class ExecutionContext(IncrementalPublisherContext):
216213
cancellable_streams: set[CancellableStreamRecord] | None
217214
middleware_manager: MiddlewareManager | None
218215

219-
is_awaitable: Callable[[Any], bool] = staticmethod(default_is_awaitable)
216+
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] = staticmethod(
217+
default_is_awaitable # type: ignore
218+
)
220219

221220
def __init__(
222221
self,
@@ -230,7 +229,7 @@ def __init__(
230229
type_resolver: GraphQLTypeResolver,
231230
subscribe_field_resolver: GraphQLFieldResolver,
232231
middleware_manager: MiddlewareManager | None,
233-
is_awaitable: Callable[[Any], bool] | None,
232+
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
234233
) -> None:
235234
self.schema = schema
236235
self.fragments = fragments
@@ -242,8 +241,7 @@ def __init__(
242241
self.type_resolver = type_resolver
243242
self.subscribe_field_resolver = subscribe_field_resolver
244243
self.middleware_manager = middleware_manager
245-
if is_awaitable:
246-
self.is_awaitable = is_awaitable
244+
self.is_awaitable = is_awaitable or default_is_awaitable
247245
self.errors = None
248246
self.cancellable_streams = None
249247
self._canceled_iterators: set[AsyncIterator] = set()
@@ -264,7 +262,7 @@ def build(
264262
type_resolver: GraphQLTypeResolver | None = None,
265263
subscribe_field_resolver: GraphQLFieldResolver | None = None,
266264
middleware: Middleware | None = None,
267-
is_awaitable: Callable[[Any], bool] | None = None,
265+
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
268266
**custom_args: Any,
269267
) -> list[GraphQLError] | ExecutionContext:
270268
"""Build an execution context
@@ -422,7 +420,7 @@ async def await_result() -> (
422420
ExecutionResult | ExperimentalIncrementalExecutionResults
423421
):
424422
try:
425-
resolved = await graphql_wrapped_result # type: ignore
423+
resolved = await graphql_wrapped_result
426424
except GraphQLError as error:
427425
return ExecutionResult(None, with_error(self.errors, error))
428426
return self.build_data_response(
@@ -496,7 +494,7 @@ def reducer(
496494
if is_awaitable(result):
497495

498496
async def set_result() -> GraphQLWrappedResult[dict[str, Any]]:
499-
resolved = await result # type: ignore
497+
resolved = await result
500498
graphql_wrapped_result.result[response_name] = resolved.result
501499
graphql_wrapped_result.add_increments(resolved.increments)
502500
return graphql_wrapped_result
@@ -553,11 +551,12 @@ async def resolve(
553551
add_increments(resolved.increments)
554552
return resolved.result
555553

556-
results[response_name] = resolve(result) # type: ignore
554+
results[response_name] = resolve(result)
557555
append_awaitable(response_name)
558556
else:
559-
results[response_name] = result.result # type: ignore
560-
add_increments(result.increments) # type: ignore
557+
result = cast("GraphQLWrappedResult[dict[str, Any]]", result)
558+
results[response_name] = result.result
559+
add_increments(result.increments)
561560

562561
# If there are no coroutines, we can just return the object.
563562
if not awaitable_fields:
@@ -651,7 +650,7 @@ def execute_field(
651650
# noinspection PyShadowingNames
652651
async def await_completed() -> Any:
653652
try:
654-
return await completed # type: ignore
653+
return await completed
655654
except Exception as raw_error:
656655
# Before Python 3.8 CancelledError inherits Exception and
657656
# so gets caught here.
@@ -864,7 +863,7 @@ async def complete_awaitable_value(
864863
defer_map,
865864
)
866865
if self.is_awaitable(completed):
867-
completed = await completed # type: ignore
866+
completed = await completed
868867
except Exception as raw_error:
869868
# Before Python 3.8 CancelledError inherits Exception and
870869
# so gets caught here.
@@ -1276,8 +1275,9 @@ async def complete_awaitable_list_item_value(
12761275
defer_map,
12771276
)
12781277
if self.is_awaitable(completed):
1279-
completed = await completed # type: ignore
1280-
parent.add_increments(completed.increments) # type: ignore
1278+
completed = await completed
1279+
completed = cast("GraphQLWrappedResult[list[Any]]", completed)
1280+
parent.add_increments(completed.increments)
12811281
except Exception as raw_error:
12821282
self.handle_field_error(
12831283
raw_error,
@@ -1288,7 +1288,7 @@ async def complete_awaitable_list_item_value(
12881288
)
12891289
return None
12901290
else:
1291-
return completed.result # type: ignore
1291+
return completed.result
12921292

12931293
@staticmethod
12941294
def complete_leaf_value(return_type: GraphQLLeafType, result: Any) -> Any:
@@ -1326,7 +1326,6 @@ def complete_abstract_value(
13261326
runtime_type = resolve_type_fn(result, info, return_type)
13271327

13281328
if self.is_awaitable(runtime_type):
1329-
runtime_type = cast("Awaitable", runtime_type)
13301329

13311330
async def await_complete_object_value() -> Any:
13321331
value = self.complete_object_value(
@@ -1345,7 +1344,7 @@ async def await_complete_object_value() -> Any:
13451344
defer_map,
13461345
)
13471346
if self.is_awaitable(value):
1348-
return await value # type: ignore
1347+
return await value
13491348
return value # pragma: no cover
13501349

13511350
return await_complete_object_value()
@@ -1447,7 +1446,7 @@ def complete_object_value(
14471446
async def execute_subfields_async() -> GraphQLWrappedResult[
14481447
dict[str, Any]
14491448
]:
1450-
if not await is_type_of: # type: ignore
1449+
if not await is_type_of:
14511450
raise invalid_return_type_error(
14521451
return_type, result, field_group
14531452
)
@@ -1460,8 +1459,10 @@ async def execute_subfields_async() -> GraphQLWrappedResult[
14601459
defer_map,
14611460
)
14621461
if self.is_awaitable(graphql_wrapped_result): # pragma: no cover
1463-
return await graphql_wrapped_result # type: ignore
1464-
return graphql_wrapped_result # type: ignore
1462+
return await graphql_wrapped_result
1463+
return cast(
1464+
"GraphQLWrappedResult[dict[str, Any]]", graphql_wrapped_result
1465+
)
14651466

14661467
return execute_subfields_async()
14671468

@@ -1644,8 +1645,8 @@ async def executor(
16441645
defer_map,
16451646
)
16461647
if self.is_awaitable(result):
1647-
return await result # type: ignore
1648-
return result # type: ignore
1648+
return await result
1649+
return cast("DeferredGroupedFieldSetResult", result)
16491650

16501651
deferred_grouped_field_set_record = DeferredGroupedFieldSetRecord(
16511652
deferred_fragment_records,
@@ -1702,7 +1703,7 @@ def execute_deferred_grouped_field_set(
17021703

17031704
async def await_result() -> DeferredGroupedFieldSetResult:
17041705
try:
1705-
awaited_result = await result # type: ignore
1706+
awaited_result = await result
17061707
except GraphQLError as error:
17071708
return NonReconcilableDeferredGroupedFieldSetResult(
17081709
deferred_fragment_records,
@@ -1792,8 +1793,8 @@ async def await_result() -> StreamItemsResult:
17921793

17931794
result = first_stream_items.result
17941795
if is_awaitable(result):
1795-
return await result # type: ignore
1796-
return result # type: ignore
1796+
return await result
1797+
return cast("StreamItemsResult", result)
17971798

17981799
return StreamItemsRecord(stream_record, await_result())
17991800

@@ -1864,8 +1865,8 @@ async def get_next_async_stream_items_result(
18641865
result = self.prepend_next_stream_items(result, next_stream_items_record)
18651866

18661867
if self.is_awaitable(result):
1867-
return await result # type: ignore
1868-
return result # type: ignore
1868+
return await result
1869+
return cast("StreamItemsResult", result)
18691870

18701871
def complete_stream_items(
18711872
self,
@@ -1932,7 +1933,7 @@ async def await_item() -> StreamItemsResult:
19321933
async def await_item() -> StreamItemsResult:
19331934
try:
19341935
try:
1935-
awaited_item = await result # type: ignore
1936+
awaited_item = await result
19361937
except Exception as raw_error:
19371938
self.handle_field_error(
19381939
raw_error,
@@ -1967,13 +1968,13 @@ def prepend_next_stream_items(
19671968
if self.is_awaitable(result):
19681969

19691970
async def await_result() -> StreamItemsResult:
1970-
resolved = await result # type: ignore
1971+
resolved = await result
19711972
return prepend_next_resolved_stream_items(resolved, next_stream_items)
19721973

19731974
return await_result()
19741975

19751976
return prepend_next_resolved_stream_items(
1976-
result, # type: ignore
1977+
cast("StreamItemsResult", result),
19771978
next_stream_items,
19781979
)
19791980

@@ -1986,7 +1987,7 @@ def with_new_deferred_grouped_field_sets(
19861987
if self.is_awaitable(result):
19871988

19881989
async def await_result() -> GraphQLWrappedResult[dict[str, Any]]:
1989-
resolved = await result # type: ignore
1990+
resolved = await result
19901991
resolved.add_increments(new_deferred_grouped_field_set_records)
19911992
return resolved
19921993

@@ -2091,7 +2092,7 @@ def execute(
20912092
subscribe_field_resolver: GraphQLFieldResolver | None = None,
20922093
middleware: Middleware | None = None,
20932094
execution_context_class: type[ExecutionContext] | None = None,
2094-
is_awaitable: Callable[[Any], bool] | None = None,
2095+
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
20952096
**custom_context_args: Any,
20962097
) -> AwaitableOrValue[ExecutionResult]:
20972098
"""Execute a GraphQL operation.
@@ -2153,7 +2154,7 @@ def experimental_execute_incrementally(
21532154
subscribe_field_resolver: GraphQLFieldResolver | None = None,
21542155
middleware: Middleware | None = None,
21552156
execution_context_class: type[ExecutionContext] | None = None,
2156-
is_awaitable: Callable[[Any], bool] | None = None,
2157+
is_awaitable: Callable[[Any], TypeGuard[Awaitable]] | None = None,
21572158
**custom_context_args: Any,
21582159
) -> AwaitableOrValue[ExecutionResult | ExperimentalIncrementalExecutionResults]:
21592160
"""Execute GraphQL operation incrementally (internal implementation).
@@ -2193,7 +2194,7 @@ def experimental_execute_incrementally(
21932194
return context.execute_operation()
21942195

21952196

2196-
def assume_not_awaitable(_value: Any) -> bool:
2197+
def assume_not_awaitable(_value: Any) -> TypeGuard[Awaitable]:
21972198
"""Replacement for is_awaitable if everything is assumed to be synchronous."""
21982199
return False
21992200

@@ -2221,7 +2222,7 @@ def execute_sync(
22212222
Set check_sync to True to still run checks that no awaitable values are returned.
22222223
"""
22232224
is_awaitable = (
2224-
check_sync
2225+
cast("Callable[[Any], TypeGuard[Awaitable]]", check_sync)
22252226
if callable(check_sync)
22262227
else (None if check_sync else assume_not_awaitable)
22272228
)
@@ -2434,7 +2435,7 @@ def default_type_resolver(
24342435
return type_.name
24352436

24362437
if awaitable_is_type_of_results:
2437-
# noinspection PyShadowingNames
2438+
24382439
async def get_type() -> str | None:
24392440
is_type_of_results = await gather_with_cancel(*awaitable_is_type_of_results)
24402441
for is_type_of_result, type_ in zip(is_type_of_results, awaitable_types):
@@ -2533,9 +2534,9 @@ def subscribe(
25332534
result_or_stream = create_source_event_stream_impl(context)
25342535

25352536
if context.is_awaitable(result_or_stream):
2536-
# noinspection PyShadowingNames
2537+
25372538
async def await_result() -> Any:
2538-
awaited_result_or_stream = await result_or_stream # type: ignore
2539+
awaited_result_or_stream = await result_or_stream
25392540
if isinstance(awaited_result_or_stream, ExecutionResult):
25402541
return awaited_result_or_stream
25412542
return context.map_source_to_response(awaited_result_or_stream)
@@ -2616,12 +2617,10 @@ def create_source_event_stream_impl(
26162617
return ExecutionResult(None, errors=[error])
26172618

26182619
if context.is_awaitable(event_stream):
2619-
awaitable_event_stream = cast("Awaitable", event_stream)
26202620

2621-
# noinspection PyShadowingNames
26222621
async def await_event_stream() -> AsyncIterable[Any] | ExecutionResult:
26232622
try:
2624-
return await awaitable_event_stream
2623+
return await event_stream
26252624
except GraphQLError as error:
26262625
return ExecutionResult(None, errors=[error])
26272626

0 commit comments

Comments
 (0)