Skip to content

Commit 73661f6

Browse files
committed
add the extra type variables
1 parent 4f83419 commit 73661f6

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

stdlib/contextlib.pyi

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import abc
22
import sys
33
from _typeshed import FileDescriptorOrPath, Unused
44
from abc import ABC, abstractmethod
5-
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Generator, Iterator
5+
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Generator
66
from types import TracebackType
77
from typing import IO, Any, Generic, Protocol, TypeVar, overload, runtime_checkable
88
from typing_extensions import ParamSpec, Self, TypeAlias
@@ -36,6 +36,9 @@ _F = TypeVar("_F", bound=Callable[..., Any])
3636
_G = TypeVar("_G", bound=Generator[Any, Any, Any] | AsyncGenerator[Any, Any], covariant=True)
3737
_P = ParamSpec("_P")
3838

39+
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
40+
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
41+
3942
_ExitFunc: TypeAlias = Callable[[type[BaseException] | None, BaseException | None, TracebackType | None], bool | None]
4043
_CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any, Any] | _ExitFunc)
4144

@@ -74,7 +77,9 @@ class _GeneratorContextManagerBase(Generic[_G]):
7477
kwds: dict[str, Any]
7578

7679
class _GeneratorContextManager(
77-
_GeneratorContextManagerBase[Generator[_T_co, Any, Any]], AbstractContextManager[_T_co, bool | None], ContextDecorator
80+
_GeneratorContextManagerBase[Generator[_T_co, _SendT_contra, _ReturnT_co]],
81+
AbstractContextManager[_T_co, bool | None],
82+
ContextDecorator,
7883
):
7984
if sys.version_info >= (3, 9):
8085
def __exit__(
@@ -85,7 +90,9 @@ class _GeneratorContextManager(
8590
self, type: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None
8691
) -> bool | None: ...
8792

88-
def contextmanager(func: Callable[_P, Iterator[_T_co]]) -> Callable[_P, _GeneratorContextManager[_T_co]]: ...
93+
def contextmanager(
94+
func: Callable[_P, Generator[_T_co, _SendT_contra, _ReturnT_co]]
95+
) -> Callable[_P, _GeneratorContextManager[_T_co, _SendT_contra, _ReturnT_co]]: ...
8996

9097
if sys.version_info >= (3, 10):
9198
_AF = TypeVar("_AF", bound=Callable[..., Awaitable[Any]])
@@ -95,7 +102,7 @@ if sys.version_info >= (3, 10):
95102
def __call__(self, func: _AF) -> _AF: ...
96103

97104
class _AsyncGeneratorContextManager(
98-
_GeneratorContextManagerBase[AsyncGenerator[_T_co, Any]],
105+
_GeneratorContextManagerBase[AsyncGenerator[_T_co, _SendT_contra]],
99106
AbstractAsyncContextManager[_T_co, bool | None],
100107
AsyncContextDecorator,
101108
):
@@ -105,7 +112,7 @@ if sys.version_info >= (3, 10):
105112

106113
else:
107114
class _AsyncGeneratorContextManager(
108-
_GeneratorContextManagerBase[AsyncGenerator[_T_co, Any]], AbstractAsyncContextManager[_T_co, bool | None]
115+
_GeneratorContextManagerBase[AsyncGenerator[_T_co, _SendT_contra]], AbstractAsyncContextManager[_T_co, bool | None]
109116
):
110117
async def __aexit__(
111118
self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None

0 commit comments

Comments
 (0)