@@ -64,6 +64,7 @@ def __init__(self, name: str) -> None: ...
6464
6565class Coroutine(Generic[_T_co, _S, _R]): ...
6666class Iterable(Generic[_T_co]): ...
67+ class Iterator(Iterable[_T_co]): ...
6768class Mapping(Generic[_K, _V]): ...
6869class Match(Generic[AnyStr]): ...
6970class Sequence(Iterable[_T_co]): ...
@@ -86,7 +87,9 @@ def __init__(self) -> None: pass
8687 def __repr__(self) -> str: pass
8788class type: ...
8889
89- class tuple(Sequence[T_co], Generic[T_co]): ...
90+ class tuple(Sequence[T_co], Generic[T_co]):
91+ def __ge__(self, __other: tuple[T_co, ...]) -> bool: pass
92+
9093class dict(Mapping[KT, VT]): ...
9194
9295class function: pass
@@ -105,6 +108,39 @@ def classmethod(f: T) -> T: ...
105108def staticmethod(f: T) -> T: ...
106109"""
107110
111+ stubtest_enum_stub = """
112+ import sys
113+ from typing import Any, TypeVar, Iterator
114+
115+ _T = TypeVar('_T')
116+
117+ class EnumMeta(type):
118+ def __len__(self) -> int: pass
119+ def __iter__(self: type[_T]) -> Iterator[_T]: pass
120+ def __reversed__(self: type[_T]) -> Iterator[_T]: pass
121+ def __getitem__(self: type[_T], name: str) -> _T: pass
122+
123+ class Enum(metaclass=EnumMeta):
124+ def __new__(cls: type[_T], value: object) -> _T: pass
125+ def __repr__(self) -> str: pass
126+ def __str__(self) -> str: pass
127+ def __format__(self, format_spec: str) -> str: pass
128+ def __hash__(self) -> Any: pass
129+ def __reduce_ex__(self, proto: Any) -> Any: pass
130+ name: str
131+ value: Any
132+
133+ class Flag(Enum):
134+ def __or__(self: _T, other: _T) -> _T: pass
135+ def __and__(self: _T, other: _T) -> _T: pass
136+ def __xor__(self: _T, other: _T) -> _T: pass
137+ def __invert__(self: _T) -> _T: pass
138+ if sys.version_info >= (3, 11):
139+ __ror__ = __or__
140+ __rand__ = __and__
141+ __rxor__ = __xor__
142+ """
143+
108144
109145def run_stubtest (
110146 stub : str , runtime : str , options : list [str ], config_file : str | None = None
@@ -114,6 +150,8 @@ def run_stubtest(
114150 f .write (stubtest_builtins_stub )
115151 with open ("typing.pyi" , "w" ) as f :
116152 f .write (stubtest_typing_stub )
153+ with open ("enum.pyi" , "w" ) as f :
154+ f .write (stubtest_enum_stub )
117155 with open (f"{ TEST_MODULE_NAME } .pyi" , "w" ) as f :
118156 f .write (stub )
119157 with open (f"{ TEST_MODULE_NAME } .py" , "w" ) as f :
@@ -954,23 +992,82 @@ def fizz(self): pass
954992
955993 @collect_cases
956994 def test_enum (self ) -> Iterator [Case ]:
995+ yield Case (stub = "import enum" , runtime = "import enum" , error = None )
957996 yield Case (
958997 stub = """
959- import enum
960998 class X(enum.Enum):
961999 a: int
9621000 b: str
9631001 c: str
9641002 """ ,
9651003 runtime = """
966- import enum
9671004 class X(enum.Enum):
9681005 a = 1
9691006 b = "asdf"
9701007 c = 2
9711008 """ ,
9721009 error = "X.c" ,
9731010 )
1011+ yield Case (
1012+ stub = """
1013+ class Flags1(enum.Flag):
1014+ a: int
1015+ b: int
1016+ def foo(x: Flags1 = ...) -> None: ...
1017+ """ ,
1018+ runtime = """
1019+ class Flags1(enum.Flag):
1020+ a = 1
1021+ b = 2
1022+ def foo(x=Flags1.a|Flags1.b): pass
1023+ """ ,
1024+ error = None ,
1025+ )
1026+ yield Case (
1027+ stub = """
1028+ class Flags2(enum.Flag):
1029+ a: int
1030+ b: int
1031+ def bar(x: Flags2 | None = None) -> None: ...
1032+ """ ,
1033+ runtime = """
1034+ class Flags2(enum.Flag):
1035+ a = 1
1036+ b = 2
1037+ def bar(x=Flags2.a|Flags2.b): pass
1038+ """ ,
1039+ error = "bar" ,
1040+ )
1041+ yield Case (
1042+ stub = """
1043+ class Flags3(enum.Flag):
1044+ a: int
1045+ b: int
1046+ def baz(x: Flags3 | None = ...) -> None: ...
1047+ """ ,
1048+ runtime = """
1049+ class Flags3(enum.Flag):
1050+ a = 1
1051+ b = 2
1052+ def baz(x=Flags3(0)): pass
1053+ """ ,
1054+ error = None ,
1055+ )
1056+ yield Case (
1057+ stub = """
1058+ class Flags4(enum.Flag):
1059+ a: int
1060+ b: int
1061+ def spam(x: Flags4 | None = None) -> None: ...
1062+ """ ,
1063+ runtime = """
1064+ class Flags4(enum.Flag):
1065+ a = 1
1066+ b = 2
1067+ def spam(x=Flags4(0)): pass
1068+ """ ,
1069+ error = "spam" ,
1070+ )
9741071
9751072 @collect_cases
9761073 def test_decorator (self ) -> Iterator [Case ]:
0 commit comments