Skip to content

Commit 806b134

Browse files
zmievsaAlexWaygood
andauthored
Add get_origin annotations (#9811)
Co-authored-by: Alex Waygood <[email protected]>
1 parent 13325d4 commit 806b134

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

stdlib/typing.pyi

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ from types import (
2020
)
2121
from typing_extensions import Never as _Never, ParamSpec as _ParamSpec, final as _final
2222

23+
if sys.version_info >= (3, 10):
24+
from types import UnionType
25+
if sys.version_info >= (3, 9):
26+
from types import GenericAlias
27+
2328
__all__ = [
2429
"AbstractSet",
2530
"Any",
@@ -745,9 +750,21 @@ else:
745750
) -> dict[str, Any]: ...
746751

747752
if sys.version_info >= (3, 8):
748-
def get_origin(tp: Any) -> Any | None: ...
749753
def get_args(tp: Any) -> tuple[Any, ...]: ...
750754

755+
if sys.version_info >= (3, 10):
756+
@overload
757+
def get_origin(tp: ParamSpecArgs | ParamSpecKwargs) -> ParamSpec: ...
758+
@overload
759+
def get_origin(tp: UnionType) -> type[UnionType]: ...
760+
if sys.version_info >= (3, 9):
761+
@overload
762+
def get_origin(tp: GenericAlias) -> type: ...
763+
@overload
764+
def get_origin(tp: Any) -> Any | None: ...
765+
else:
766+
def get_origin(tp: Any) -> Any | None: ...
767+
751768
@overload
752769
def cast(typ: Type[_T], val: Any) -> _T: ...
753770
@overload

stdlib/typing_extensions.pyi

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ from typing import ( # noqa: Y022,Y039
3232
type_check_only,
3333
)
3434

35+
if sys.version_info >= (3, 10):
36+
from types import UnionType
37+
if sys.version_info >= (3, 9):
38+
from types import GenericAlias
39+
3540
__all__ = [
3641
"Any",
3742
"ClassVar",
@@ -155,6 +160,18 @@ def get_type_hints(
155160
include_extras: bool = False,
156161
) -> dict[str, Any]: ...
157162
def get_args(tp: Any) -> tuple[Any, ...]: ...
163+
164+
if sys.version_info >= (3, 10):
165+
@overload
166+
def get_origin(tp: UnionType) -> type[UnionType]: ...
167+
168+
if sys.version_info >= (3, 9):
169+
@overload
170+
def get_origin(tp: GenericAlias) -> type: ...
171+
172+
@overload
173+
def get_origin(tp: ParamSpecArgs | ParamSpecKwargs) -> ParamSpec: ...
174+
@overload
158175
def get_origin(tp: Any) -> Any | None: ...
159176

160177
Annotated: _SpecialForm

0 commit comments

Comments
 (0)