Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
NoneType,
Overloaded,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
ProperType,
Expand All @@ -36,7 +37,6 @@
UninhabitedType,
UnionType,
UnpackType,
expand_param_spec,
flatten_nested_unions,
get_proper_type,
)
Expand Down Expand Up @@ -247,7 +247,33 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
# TODO: why does this case even happen? Instances aren't plural.
return repl
elif isinstance(repl, (ParamSpecType, Parameters, CallableType)):
return expand_param_spec(t, repl)
if isinstance(repl, ParamSpecType):
return repl.copy_modified(
flavor=t.flavor,
prefix=t.prefix.copy_modified(
arg_types=t.prefix.arg_types + repl.prefix.arg_types,
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
),
)
else:
# if the paramspec is *P.args or **P.kwargs:
if t.flavor != ParamSpecFlavor.BARE:
assert isinstance(repl, CallableType), "Should not be able to get here."
# Is this always the right thing to do?
param_spec = repl.param_spec()
if param_spec:
return param_spec.with_flavor(t.flavor)
else:
return repl
else:
return Parameters(
t.prefix.arg_types + repl.arg_types,
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
)

else:
# TODO: should this branch be removed? better not to fail silently
return repl
Expand Down
116 changes: 24 additions & 92 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,14 @@ def _expand_once(self) -> Type:
# as their target.
assert isinstance(self.alias.target, Instance) # type: ignore[misc]
return self.alias.target.copy_modified(args=self.args)
return replace_alias_tvars(
self.alias.target, self.alias.alias_tvars, self.args, self.line, self.column
replacer = InstantiateAliasVisitor(
{v.id: s for (v, s) in zip(self.alias.alias_tvars, self.args)}
)
new_tp = self.alias.target.accept(replacer)
new_tp.accept(LocationSetter(self.line, self.column))
new_tp.line = self.line
new_tp.column = self.column
return new_tp

def _partial_expansion(self, nothing_args: bool = False) -> tuple[ProperType, bool]:
# Private method mostly for debugging and testing.
Expand Down Expand Up @@ -3243,49 +3248,6 @@ def is_named_instance(t: Type, fullnames: str | tuple[str, ...]) -> TypeGuard[In
return isinstance(t, Instance) and t.type.fullname in fullnames


class InstantiateAliasVisitor(TrivialSyntheticTypeTranslator):
def __init__(self, vars: list[TypeVarLikeType], subs: list[Type]) -> None:
self.replacements = {v.id: s for (v, s) in zip(vars, subs)}

def visit_type_alias_type(self, typ: TypeAliasType) -> Type:
return typ.copy_modified(args=[t.accept(self) for t in typ.args])

def visit_type_var(self, typ: TypeVarType) -> Type:
if typ.id in self.replacements:
return self.replacements[typ.id]
return typ

def visit_callable_type(self, t: CallableType) -> Type:
param_spec = t.param_spec()
if param_spec is not None:
# TODO: this branch duplicates the one in expand_type(), find a way to reuse it
# without import cycle types <-> typeanal <-> expandtype.
repl = get_proper_type(self.replacements.get(param_spec.id))
if isinstance(repl, (CallableType, Parameters)):
prefix = param_spec.prefix
t = t.expand_param_spec(repl, no_prefix=True)
return t.copy_modified(
arg_types=[t.accept(self) for t in prefix.arg_types] + t.arg_types,
arg_kinds=prefix.arg_kinds + t.arg_kinds,
arg_names=prefix.arg_names + t.arg_names,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
return super().visit_callable_type(t)

def visit_param_spec(self, typ: ParamSpecType) -> Type:
if typ.id in self.replacements:
repl = get_proper_type(self.replacements[typ.id])
# TODO: all the TODOs from same logic in expand_type() apply here.
if isinstance(repl, Instance):
return repl
elif isinstance(repl, (ParamSpecType, Parameters, CallableType)):
return expand_param_spec(typ, repl)
else:
return repl
return typ


class LocationSetter(TypeTraverserVisitor):
# TODO: Should we update locations of other Type subclasses?
def __init__(self, line: int, column: int) -> None:
Expand All @@ -3298,20 +3260,6 @@ def visit_instance(self, typ: Instance) -> None:
super().visit_instance(typ)


def replace_alias_tvars(
tp: Type, vars: list[TypeVarLikeType], subs: list[Type], newline: int, newcolumn: int
) -> Type:
"""Replace type variables in a generic type alias tp with substitutions subs
resetting context. Length of subs should be already checked.
"""
replacer = InstantiateAliasVisitor(vars, subs)
new_tp = tp.accept(replacer)
new_tp.accept(LocationSetter(newline, newcolumn))
new_tp.line = newline
new_tp.column = newcolumn
return new_tp


class HasTypeVars(BoolTypeQuery):
def __init__(self) -> None:
super().__init__(ANY_STRATEGY)
Expand Down Expand Up @@ -3408,36 +3356,20 @@ def callable_with_ellipsis(any_type: AnyType, ret_type: Type, fallback: Instance
)


def expand_param_spec(
t: ParamSpecType, repl: ParamSpecType | Parameters | CallableType
) -> ProperType:
"""This is shared part of the logic w.r.t. ParamSpec instantiation.

It is shared between type aliases and proper types, that currently use somewhat different
logic for instantiation."""
if isinstance(repl, ParamSpecType):
return repl.copy_modified(
flavor=t.flavor,
prefix=t.prefix.copy_modified(
arg_types=t.prefix.arg_types + repl.prefix.arg_types,
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
),
)
else:
# if the paramspec is *P.args or **P.kwargs:
if t.flavor != ParamSpecFlavor.BARE:
assert isinstance(repl, CallableType), "Should not be able to get here."
# Is this always the right thing to do?
param_spec = repl.param_spec()
if param_spec:
return param_spec.with_flavor(t.flavor)
else:
return repl
else:
return Parameters(
t.prefix.arg_types + repl.arg_types,
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
)
# This cyclic import is unfortunate, but to avoid it we would need to move away all uses
# of get_proper_type() from types.py. Majority of them have been removed, but few remaining
# are quite tricky to get rid of, but ultimately we want to do it at some point.
from mypy.expandtype import ExpandTypeVisitor


class InstantiateAliasVisitor(ExpandTypeVisitor):
def visit_union_type(self, t: UnionType) -> Type:
# Unlike regular expand_type(), we don't do any simplification for unions,
# not even removing strict duplicates. There are three reasons for this:
# * get_proper_type() is a very hot function, even slightest slow down will
# cause a perf regression
# * We want to preserve this historical behaviour, to avoid possible
# regressions
# * Simplifying unions may (indirectly) call get_proper_type(), causing
# infinite recursion.
return TypeTranslator.visit_union_type(self, t)
15 changes: 13 additions & 2 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,11 @@ from typing_extensions import TypeGuard
class Z:
def typeguard1(self, *, x: object) -> TypeGuard[int]: # line 4
...

@staticmethod
def typeguard2(x: object) -> TypeGuard[int]:
...

@staticmethod # line 11
def typeguard3(*, x: object) -> TypeGuard[int]:
...
Expand Down Expand Up @@ -688,3 +688,14 @@ if typeguard(x=x, y="42"):
if typeguard(y="42", x=x):
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/tuple.pyi]

[case testGenericAliasWithTypeGuard]
from typing import Callable, List, TypeVar
from typing_extensions import TypeGuard, TypeAlias

A = Callable[[object], TypeGuard[List[T]]]
def foo(x: object) -> TypeGuard[List[str]]: ...

def test(f: A[T]) -> T: ...
reveal_type(test(foo)) # N: Revealed type is "builtins.str"
[builtins fixtures/list.pyi]