@@ -2,7 +2,7 @@ import abc
22import sys
33from _typeshed import FileDescriptorOrPath , Unused
44from abc import ABC , abstractmethod
5- from collections .abc import AsyncGenerator , AsyncIterator , Awaitable , Callable , Generator , Iterator
5+ from collections .abc import AsyncGenerator , AsyncIterator , Awaitable , Callable , Generator
66from types import TracebackType
77from typing import IO , Any , Generic , Protocol , TypeVar , overload , runtime_checkable
88from typing_extensions import ParamSpec , Self , TypeAlias
@@ -36,6 +36,9 @@ _F = TypeVar("_F", bound=Callable[..., Any])
3636_G = TypeVar ("_G" , bound = Generator [Any , Any , Any ] | AsyncGenerator [Any , Any ], covariant = True )
3737_P = ParamSpec ("_P" )
3838
39+ _SendT_contra = TypeVar ("_SendT_contra" , contravariant = True , default = None )
40+ _ReturnT_co = TypeVar ("_ReturnT_co" , covariant = True , default = None )
41+
3942_ExitFunc : TypeAlias = Callable [[type [BaseException ] | None , BaseException | None , TracebackType | None ], bool | None ]
4043_CM_EF = TypeVar ("_CM_EF" , bound = AbstractContextManager [Any , Any ] | _ExitFunc )
4144
@@ -74,7 +77,9 @@ class _GeneratorContextManagerBase(Generic[_G]):
7477 kwds : dict [str , Any ]
7578
7679class _GeneratorContextManager (
77- _GeneratorContextManagerBase [Generator [_T_co , Any , Any ]], AbstractContextManager [_T_co , bool | None ], ContextDecorator
80+ _GeneratorContextManagerBase [Generator [_T_co , _SendT_contra , _ReturnT_co ]],
81+ AbstractContextManager [_T_co , bool | None ],
82+ ContextDecorator ,
7883):
7984 if sys .version_info >= (3 , 9 ):
8085 def __exit__ (
@@ -85,7 +90,9 @@ class _GeneratorContextManager(
8590 self , type : type [BaseException ] | None , value : BaseException | None , traceback : TracebackType | None
8691 ) -> bool | None : ...
8792
88- def contextmanager (func : Callable [_P , Iterator [_T_co ]]) -> Callable [_P , _GeneratorContextManager [_T_co ]]: ...
93+ def contextmanager (
94+ func : Callable [_P , Generator [_T_co , _SendT_contra , _ReturnT_co ]]
95+ ) -> Callable [_P , _GeneratorContextManager [_T_co , _SendT_contra , _ReturnT_co ]]: ...
8996
9097if sys .version_info >= (3 , 10 ):
9198 _AF = TypeVar ("_AF" , bound = Callable [..., Awaitable [Any ]])
@@ -95,7 +102,7 @@ if sys.version_info >= (3, 10):
95102 def __call__ (self , func : _AF ) -> _AF : ...
96103
97104 class _AsyncGeneratorContextManager (
98- _GeneratorContextManagerBase [AsyncGenerator [_T_co , Any ]],
105+ _GeneratorContextManagerBase [AsyncGenerator [_T_co , _SendT_contra ]],
99106 AbstractAsyncContextManager [_T_co , bool | None ],
100107 AsyncContextDecorator ,
101108 ):
@@ -105,7 +112,7 @@ if sys.version_info >= (3, 10):
105112
106113else :
107114 class _AsyncGeneratorContextManager (
108- _GeneratorContextManagerBase [AsyncGenerator [_T_co , Any ]], AbstractAsyncContextManager [_T_co , bool | None ]
115+ _GeneratorContextManagerBase [AsyncGenerator [_T_co , _SendT_contra ]], AbstractAsyncContextManager [_T_co , bool | None ]
109116 ):
110117 async def __aexit__ (
111118 self , typ : type [BaseException ] | None , value : BaseException | None , traceback : TracebackType | None
0 commit comments