4
4
5
5
from asyncio import (
6
6
CancelledError ,
7
+ TimeoutError , # only needed for Python < 3.11 # noqa: A004
7
8
ensure_future ,
8
9
)
9
10
from contextlib import suppress
29
30
cast ,
30
31
)
31
32
32
- try :
33
- from typing import TypeAlias
34
- except ImportError : # Python < 3.10
35
- from typing_extensions import TypeAlias
36
- try : # only needed for Python < 3.11
37
- from asyncio .exceptions import TimeoutError # noqa: A004
38
- except ImportError : # Python < 3.7
39
- from concurrent .futures import TimeoutError # noqa: A004
40
-
41
33
from ..error import GraphQLError , located_error
42
34
from ..language import (
43
35
DocumentNode ,
119
111
if TYPE_CHECKING :
120
112
from graphql .pyutils .undefined import UndefinedType
121
113
114
+ try :
115
+ from typing import TypeAlias , TypeGuard
116
+ except ImportError : # Python < 3.10
117
+ from typing_extensions import TypeAlias , TypeGuard
118
+
122
119
try : # pragma: no cover
123
120
anext # noqa: B018 # pyright: ignore
124
121
except NameError : # pragma: no cover (Python < 3.10)
@@ -216,7 +213,9 @@ class ExecutionContext(IncrementalPublisherContext):
216
213
cancellable_streams : set [CancellableStreamRecord ] | None
217
214
middleware_manager : MiddlewareManager | None
218
215
219
- is_awaitable : Callable [[Any ], bool ] = staticmethod (default_is_awaitable )
216
+ is_awaitable : Callable [[Any ], TypeGuard [Awaitable ]] = staticmethod (
217
+ default_is_awaitable # type: ignore
218
+ )
220
219
221
220
def __init__ (
222
221
self ,
@@ -230,7 +229,7 @@ def __init__(
230
229
type_resolver : GraphQLTypeResolver ,
231
230
subscribe_field_resolver : GraphQLFieldResolver ,
232
231
middleware_manager : MiddlewareManager | None ,
233
- is_awaitable : Callable [[Any ], bool ] | None ,
232
+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ]] | None = None ,
234
233
) -> None :
235
234
self .schema = schema
236
235
self .fragments = fragments
@@ -242,8 +241,7 @@ def __init__(
242
241
self .type_resolver = type_resolver
243
242
self .subscribe_field_resolver = subscribe_field_resolver
244
243
self .middleware_manager = middleware_manager
245
- if is_awaitable :
246
- self .is_awaitable = is_awaitable
244
+ self .is_awaitable = is_awaitable or default_is_awaitable
247
245
self .errors = None
248
246
self .cancellable_streams = None
249
247
self ._canceled_iterators : set [AsyncIterator ] = set ()
@@ -264,7 +262,7 @@ def build(
264
262
type_resolver : GraphQLTypeResolver | None = None ,
265
263
subscribe_field_resolver : GraphQLFieldResolver | None = None ,
266
264
middleware : Middleware | None = None ,
267
- is_awaitable : Callable [[Any ], bool ] | None = None ,
265
+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ] ] | None = None ,
268
266
** custom_args : Any ,
269
267
) -> list [GraphQLError ] | ExecutionContext :
270
268
"""Build an execution context
@@ -422,7 +420,7 @@ async def await_result() -> (
422
420
ExecutionResult | ExperimentalIncrementalExecutionResults
423
421
):
424
422
try :
425
- resolved = await graphql_wrapped_result # type: ignore
423
+ resolved = await graphql_wrapped_result
426
424
except GraphQLError as error :
427
425
return ExecutionResult (None , with_error (self .errors , error ))
428
426
return self .build_data_response (
@@ -496,7 +494,7 @@ def reducer(
496
494
if is_awaitable (result ):
497
495
498
496
async def set_result () -> GraphQLWrappedResult [dict [str , Any ]]:
499
- resolved = await result # type: ignore
497
+ resolved = await result
500
498
graphql_wrapped_result .result [response_name ] = resolved .result
501
499
graphql_wrapped_result .add_increments (resolved .increments )
502
500
return graphql_wrapped_result
@@ -553,11 +551,12 @@ async def resolve(
553
551
add_increments (resolved .increments )
554
552
return resolved .result
555
553
556
- results [response_name ] = resolve (result ) # type: ignore
554
+ results [response_name ] = resolve (result )
557
555
append_awaitable (response_name )
558
556
else :
559
- results [response_name ] = result .result # type: ignore
560
- add_increments (result .increments ) # type: ignore
557
+ result = cast ("GraphQLWrappedResult[dict[str, Any]]" , result )
558
+ results [response_name ] = result .result
559
+ add_increments (result .increments )
561
560
562
561
# If there are no coroutines, we can just return the object.
563
562
if not awaitable_fields :
@@ -651,7 +650,7 @@ def execute_field(
651
650
# noinspection PyShadowingNames
652
651
async def await_completed () -> Any :
653
652
try :
654
- return await completed # type: ignore
653
+ return await completed
655
654
except Exception as raw_error :
656
655
# Before Python 3.8 CancelledError inherits Exception and
657
656
# so gets caught here.
@@ -864,7 +863,7 @@ async def complete_awaitable_value(
864
863
defer_map ,
865
864
)
866
865
if self .is_awaitable (completed ):
867
- completed = await completed # type: ignore
866
+ completed = await completed
868
867
except Exception as raw_error :
869
868
# Before Python 3.8 CancelledError inherits Exception and
870
869
# so gets caught here.
@@ -1276,8 +1275,9 @@ async def complete_awaitable_list_item_value(
1276
1275
defer_map ,
1277
1276
)
1278
1277
if self .is_awaitable (completed ):
1279
- completed = await completed # type: ignore
1280
- parent .add_increments (completed .increments ) # type: ignore
1278
+ completed = await completed
1279
+ completed = cast ("GraphQLWrappedResult[list[Any]]" , completed )
1280
+ parent .add_increments (completed .increments )
1281
1281
except Exception as raw_error :
1282
1282
self .handle_field_error (
1283
1283
raw_error ,
@@ -1288,7 +1288,7 @@ async def complete_awaitable_list_item_value(
1288
1288
)
1289
1289
return None
1290
1290
else :
1291
- return completed .result # type: ignore
1291
+ return completed .result
1292
1292
1293
1293
@staticmethod
1294
1294
def complete_leaf_value (return_type : GraphQLLeafType , result : Any ) -> Any :
@@ -1326,7 +1326,6 @@ def complete_abstract_value(
1326
1326
runtime_type = resolve_type_fn (result , info , return_type )
1327
1327
1328
1328
if self .is_awaitable (runtime_type ):
1329
- runtime_type = cast ("Awaitable" , runtime_type )
1330
1329
1331
1330
async def await_complete_object_value () -> Any :
1332
1331
value = self .complete_object_value (
@@ -1345,7 +1344,7 @@ async def await_complete_object_value() -> Any:
1345
1344
defer_map ,
1346
1345
)
1347
1346
if self .is_awaitable (value ):
1348
- return await value # type: ignore
1347
+ return await value
1349
1348
return value # pragma: no cover
1350
1349
1351
1350
return await_complete_object_value ()
@@ -1447,7 +1446,7 @@ def complete_object_value(
1447
1446
async def execute_subfields_async () -> GraphQLWrappedResult [
1448
1447
dict [str , Any ]
1449
1448
]:
1450
- if not await is_type_of : # type: ignore
1449
+ if not await is_type_of :
1451
1450
raise invalid_return_type_error (
1452
1451
return_type , result , field_group
1453
1452
)
@@ -1460,8 +1459,10 @@ async def execute_subfields_async() -> GraphQLWrappedResult[
1460
1459
defer_map ,
1461
1460
)
1462
1461
if self .is_awaitable (graphql_wrapped_result ): # pragma: no cover
1463
- return await graphql_wrapped_result # type: ignore
1464
- return graphql_wrapped_result # type: ignore
1462
+ return await graphql_wrapped_result
1463
+ return cast (
1464
+ "GraphQLWrappedResult[dict[str, Any]]" , graphql_wrapped_result
1465
+ )
1465
1466
1466
1467
return execute_subfields_async ()
1467
1468
@@ -1644,8 +1645,8 @@ async def executor(
1644
1645
defer_map ,
1645
1646
)
1646
1647
if self .is_awaitable (result ):
1647
- return await result # type: ignore
1648
- return result # type: ignore
1648
+ return await result
1649
+ return cast ( "DeferredGroupedFieldSetResult" , result )
1649
1650
1650
1651
deferred_grouped_field_set_record = DeferredGroupedFieldSetRecord (
1651
1652
deferred_fragment_records ,
@@ -1702,7 +1703,7 @@ def execute_deferred_grouped_field_set(
1702
1703
1703
1704
async def await_result () -> DeferredGroupedFieldSetResult :
1704
1705
try :
1705
- awaited_result = await result # type: ignore
1706
+ awaited_result = await result
1706
1707
except GraphQLError as error :
1707
1708
return NonReconcilableDeferredGroupedFieldSetResult (
1708
1709
deferred_fragment_records ,
@@ -1792,8 +1793,8 @@ async def await_result() -> StreamItemsResult:
1792
1793
1793
1794
result = first_stream_items .result
1794
1795
if is_awaitable (result ):
1795
- return await result # type: ignore
1796
- return result # type: ignore
1796
+ return await result
1797
+ return cast ( "StreamItemsResult" , result )
1797
1798
1798
1799
return StreamItemsRecord (stream_record , await_result ())
1799
1800
@@ -1864,8 +1865,8 @@ async def get_next_async_stream_items_result(
1864
1865
result = self .prepend_next_stream_items (result , next_stream_items_record )
1865
1866
1866
1867
if self .is_awaitable (result ):
1867
- return await result # type: ignore
1868
- return result # type: ignore
1868
+ return await result
1869
+ return cast ( "StreamItemsResult" , result )
1869
1870
1870
1871
def complete_stream_items (
1871
1872
self ,
@@ -1932,7 +1933,7 @@ async def await_item() -> StreamItemsResult:
1932
1933
async def await_item () -> StreamItemsResult :
1933
1934
try :
1934
1935
try :
1935
- awaited_item = await result # type: ignore
1936
+ awaited_item = await result
1936
1937
except Exception as raw_error :
1937
1938
self .handle_field_error (
1938
1939
raw_error ,
@@ -1967,13 +1968,13 @@ def prepend_next_stream_items(
1967
1968
if self .is_awaitable (result ):
1968
1969
1969
1970
async def await_result () -> StreamItemsResult :
1970
- resolved = await result # type: ignore
1971
+ resolved = await result
1971
1972
return prepend_next_resolved_stream_items (resolved , next_stream_items )
1972
1973
1973
1974
return await_result ()
1974
1975
1975
1976
return prepend_next_resolved_stream_items (
1976
- result , # type: ignore
1977
+ cast ( "StreamItemsResult" , result ),
1977
1978
next_stream_items ,
1978
1979
)
1979
1980
@@ -1986,7 +1987,7 @@ def with_new_deferred_grouped_field_sets(
1986
1987
if self .is_awaitable (result ):
1987
1988
1988
1989
async def await_result () -> GraphQLWrappedResult [dict [str , Any ]]:
1989
- resolved = await result # type: ignore
1990
+ resolved = await result
1990
1991
resolved .add_increments (new_deferred_grouped_field_set_records )
1991
1992
return resolved
1992
1993
@@ -2091,7 +2092,7 @@ def execute(
2091
2092
subscribe_field_resolver : GraphQLFieldResolver | None = None ,
2092
2093
middleware : Middleware | None = None ,
2093
2094
execution_context_class : type [ExecutionContext ] | None = None ,
2094
- is_awaitable : Callable [[Any ], bool ] | None = None ,
2095
+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ] ] | None = None ,
2095
2096
** custom_context_args : Any ,
2096
2097
) -> AwaitableOrValue [ExecutionResult ]:
2097
2098
"""Execute a GraphQL operation.
@@ -2153,7 +2154,7 @@ def experimental_execute_incrementally(
2153
2154
subscribe_field_resolver : GraphQLFieldResolver | None = None ,
2154
2155
middleware : Middleware | None = None ,
2155
2156
execution_context_class : type [ExecutionContext ] | None = None ,
2156
- is_awaitable : Callable [[Any ], bool ] | None = None ,
2157
+ is_awaitable : Callable [[Any ], TypeGuard [ Awaitable ] ] | None = None ,
2157
2158
** custom_context_args : Any ,
2158
2159
) -> AwaitableOrValue [ExecutionResult | ExperimentalIncrementalExecutionResults ]:
2159
2160
"""Execute GraphQL operation incrementally (internal implementation).
@@ -2193,7 +2194,7 @@ def experimental_execute_incrementally(
2193
2194
return context .execute_operation ()
2194
2195
2195
2196
2196
- def assume_not_awaitable (_value : Any ) -> bool :
2197
+ def assume_not_awaitable (_value : Any ) -> TypeGuard [ Awaitable ] :
2197
2198
"""Replacement for is_awaitable if everything is assumed to be synchronous."""
2198
2199
return False
2199
2200
@@ -2221,7 +2222,7 @@ def execute_sync(
2221
2222
Set check_sync to True to still run checks that no awaitable values are returned.
2222
2223
"""
2223
2224
is_awaitable = (
2224
- check_sync
2225
+ cast ( "Callable[[Any], TypeGuard[Awaitable]]" , check_sync )
2225
2226
if callable (check_sync )
2226
2227
else (None if check_sync else assume_not_awaitable )
2227
2228
)
@@ -2434,7 +2435,7 @@ def default_type_resolver(
2434
2435
return type_ .name
2435
2436
2436
2437
if awaitable_is_type_of_results :
2437
- # noinspection PyShadowingNames
2438
+
2438
2439
async def get_type () -> str | None :
2439
2440
is_type_of_results = await gather_with_cancel (* awaitable_is_type_of_results )
2440
2441
for is_type_of_result , type_ in zip (is_type_of_results , awaitable_types ):
@@ -2533,9 +2534,9 @@ def subscribe(
2533
2534
result_or_stream = create_source_event_stream_impl (context )
2534
2535
2535
2536
if context .is_awaitable (result_or_stream ):
2536
- # noinspection PyShadowingNames
2537
+
2537
2538
async def await_result () -> Any :
2538
- awaited_result_or_stream = await result_or_stream # type: ignore
2539
+ awaited_result_or_stream = await result_or_stream
2539
2540
if isinstance (awaited_result_or_stream , ExecutionResult ):
2540
2541
return awaited_result_or_stream
2541
2542
return context .map_source_to_response (awaited_result_or_stream )
@@ -2616,12 +2617,10 @@ def create_source_event_stream_impl(
2616
2617
return ExecutionResult (None , errors = [error ])
2617
2618
2618
2619
if context .is_awaitable (event_stream ):
2619
- awaitable_event_stream = cast ("Awaitable" , event_stream )
2620
2620
2621
- # noinspection PyShadowingNames
2622
2621
async def await_event_stream () -> AsyncIterable [Any ] | ExecutionResult :
2623
2622
try :
2624
- return await awaitable_event_stream
2623
+ return await event_stream
2625
2624
except GraphQLError as error :
2626
2625
return ExecutionResult (None , errors = [error ])
2627
2626
0 commit comments