66from typing import Iterable , Sequence
77from typing_extensions import TypeAlias as _TypeAlias
88
9- from mypy .constraints import SUBTYPE_OF , SUPERTYPE_OF , Constraint , infer_constraints
9+ from mypy .constraints import SUBTYPE_OF , SUPERTYPE_OF , Constraint , infer_constraints , neg_op
1010from mypy .expandtype import expand_type
1111from mypy .graph_utils import prepare_sccs , strongly_connected_components , topsort
1212from mypy .join import join_types
@@ -69,6 +69,10 @@ def solve_constraints(
6969 extra_vars .extend ([v .id for v in c .extra_tvars if v .id not in vars + extra_vars ])
7070 originals .update ({v .id : v for v in c .extra_tvars if v .id not in originals })
7171
72+ if allow_polymorphic :
73+ # Constraints inferred from unions require special handling in polymorphic inference.
74+ constraints = skip_reverse_union_constraints (constraints )
75+
7276 # Collect a list of constraints for each type variable.
7377 cmap : dict [TypeVarId , list [Constraint ]] = {tv : [] for tv in vars + extra_vars }
7478 for con in constraints :
@@ -431,19 +435,15 @@ def transitive_closure(
431435 uppers [l ] |= uppers [upper ]
432436 for lt in lowers [lower ]:
433437 for ut in uppers [upper ]:
434- # TODO: what if secondary constraints result in inference
435- # against polymorphic actual (also in below branches)?
436- remaining |= set (infer_constraints (lt , ut , SUBTYPE_OF ))
437- remaining |= set (infer_constraints (ut , lt , SUPERTYPE_OF ))
438+ add_secondary_constraints (remaining , lt , ut )
438439 elif c .op == SUBTYPE_OF :
439440 if c .target in uppers [c .type_var ]:
440441 continue
441442 for l in tvars :
442443 if (l , c .type_var ) in graph :
443444 uppers [l ].add (c .target )
444445 for lt in lowers [c .type_var ]:
445- remaining |= set (infer_constraints (lt , c .target , SUBTYPE_OF ))
446- remaining |= set (infer_constraints (c .target , lt , SUPERTYPE_OF ))
446+ add_secondary_constraints (remaining , lt , c .target )
447447 else :
448448 assert c .op == SUPERTYPE_OF
449449 if c .target in lowers [c .type_var ]:
@@ -452,11 +452,24 @@ def transitive_closure(
452452 if (c .type_var , u ) in graph :
453453 lowers [u ].add (c .target )
454454 for ut in uppers [c .type_var ]:
455- remaining |= set (infer_constraints (ut , c .target , SUPERTYPE_OF ))
456- remaining |= set (infer_constraints (c .target , ut , SUBTYPE_OF ))
455+ add_secondary_constraints (remaining , c .target , ut )
457456 return graph , lowers , uppers
458457
459458
459+ def add_secondary_constraints (cs : set [Constraint ], lower : Type , upper : Type ) -> None :
460+ """Add secondary constraints inferred between lower and upper (in place)."""
461+ if isinstance (get_proper_type (upper ), UnionType ) and isinstance (
462+ get_proper_type (lower ), UnionType
463+ ):
464+ # When both types are unions, this can lead to inferring spurious constraints,
465+ # for example Union[T, int] <: S <: Union[T, int] may infer T <: int.
466+ # To avoid this, just skip them for now.
467+ return
468+ # TODO: what if secondary constraints result in inference against polymorphic actual?
469+ cs .update (set (infer_constraints (lower , upper , SUBTYPE_OF )))
470+ cs .update (set (infer_constraints (upper , lower , SUPERTYPE_OF )))
471+
472+
460473def compute_dependencies (
461474 tvars : list [TypeVarId ], graph : Graph , lowers : Bounds , uppers : Bounds
462475) -> dict [TypeVarId , list [TypeVarId ]]:
@@ -494,6 +507,28 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:
494507 return True
495508
496509
510+ def skip_reverse_union_constraints (cs : list [Constraint ]) -> list [Constraint ]:
511+ """Avoid ambiguities for constraints inferred from unions during polymorphic inference.
512+
513+ Polymorphic inference implicitly relies on assumption that a reverse of a linear constraint
514+ is a linear constraint. This is however not true in presence of union types, for example
515+ T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
516+ as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
517+ solution T = Union[S, int], S = <free>.
518+
519+ TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
520+ this would require passing around a flag through all infer_constraints() calls.
521+ """
522+ reverse_union_cs = set ()
523+ for c in cs :
524+ p_target = get_proper_type (c .target )
525+ if isinstance (p_target , UnionType ):
526+ for item in p_target .items :
527+ if isinstance (item , TypeVarType ):
528+ reverse_union_cs .add (Constraint (item , neg_op (c .op ), c .origin_type_var ))
529+ return [c for c in cs if c not in reverse_union_cs ]
530+
531+
497532def get_vars (target : Type , vars : list [TypeVarId ]) -> set [TypeVarId ]:
498533 """Find type variables for which we are solving in a target type."""
499534 return {tv .id for tv in get_all_type_vars (target )} & set (vars )
0 commit comments