@@ -595,15 +595,11 @@ def visit_parameters(self, template: Parameters) -> list[Constraint]:
595595 return self .infer_against_any (template .arg_types , self .actual )
596596 if type_state .infer_polymorphic and isinstance (self .actual , Parameters ):
597597 # For polymorphic inference we need to be able to infer secondary constraints
598- # in situations like [x: T] <: P <: [x: int].
599- res = []
600- if len (template .arg_types ) == len (self .actual .arg_types ):
601- for tt , at in zip (template .arg_types , self .actual .arg_types ):
602- # This avoids bogus constraints like T <: P.args
603- if isinstance (at , ParamSpecType ):
604- continue
605- res .extend (infer_constraints (tt , at , self .direction ))
606- return res
598+ # in situations like [x: T] <: P <: [x: int]. Note we invert direction, since
599+ # this function expects direction between callables.
600+ return infer_callable_arguments_constraints (
601+ template , self .actual , neg_op (self .direction )
602+ )
607603 raise RuntimeError ("Parameters cannot be constrained to" )
608604
609605 # Non-leaf types
@@ -722,7 +718,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
722718 prefix = mapped_arg .prefix
723719 if isinstance (instance_arg , Parameters ):
724720 # No such thing as variance for ParamSpecs, consider them invariant
725- # TODO: constraints between prefixes
721+ # TODO: constraints between prefixes using
722+ # infer_callable_arguments_constraints()
726723 suffix : Type = instance_arg .copy_modified (
727724 instance_arg .arg_types [len (prefix .arg_types ) :],
728725 instance_arg .arg_kinds [len (prefix .arg_kinds ) :],
@@ -793,7 +790,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
793790 prefix = template_arg .prefix
794791 if isinstance (mapped_arg , Parameters ):
795792 # No such thing as variance for ParamSpecs, consider them invariant
796- # TODO: constraints between prefixes
793+ # TODO: constraints between prefixes using
794+ # infer_callable_arguments_constraints()
797795 suffix = mapped_arg .copy_modified (
798796 mapped_arg .arg_types [len (prefix .arg_types ) :],
799797 mapped_arg .arg_kinds [len (prefix .arg_kinds ) :],
@@ -962,24 +960,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
962960 unpack_constraints = build_constraints_for_simple_unpack (
963961 template_types , actual_types , neg_op (self .direction )
964962 )
965- template_args = []
966- cactual_args = []
967963 res .extend (unpack_constraints )
968964 else :
969- template_args = template .arg_types
970- cactual_args = cactual .arg_types
971- # TODO: use some more principled "formal to actual" logic
972- # instead of this lock-step loop over argument types. This identical
973- # logic should be used in 5 places: in Parameters vs Parameters
974- # inference, in Instance vs Instance inference for prefixes (two
975- # branches), and in Callable vs Callable inference (two branches).
976- for t , a in zip (template_args , cactual_args ):
977- # This avoids bogus constraints like T <: P.args
978- if isinstance (a , (ParamSpecType , UnpackType )):
979- # TODO: can we infer something useful for *T vs P?
980- continue
981965 # Negate direction due to function argument type contravariance.
982- res .extend (infer_constraints (t , a , neg_op (self .direction )))
966+ res .extend (
967+ infer_callable_arguments_constraints (template , cactual , self .direction )
968+ )
983969 else :
984970 prefix = param_spec .prefix
985971 prefix_len = len (prefix .arg_types )
@@ -1028,11 +1014,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10281014 arg_kinds = cactual .arg_kinds [:prefix_len ],
10291015 arg_names = cactual .arg_names [:prefix_len ],
10301016 )
1031-
1032- for t , a in zip (prefix .arg_types , cactual_prefix .arg_types ):
1033- if isinstance (a , ParamSpecType ):
1034- continue
1035- res .extend (infer_constraints (t , a , neg_op (self .direction )))
1017+ res .extend (
1018+ infer_callable_arguments_constraints (prefix , cactual_prefix , self .direction )
1019+ )
10361020
10371021 template_ret_type , cactual_ret_type = template .ret_type , cactual .ret_type
10381022 if template .type_guard is not None :
@@ -1435,3 +1419,89 @@ def build_constraints_for_unpack(
14351419 for template_arg , item in zip (template_unpack .items , mapped_middle ):
14361420 res .extend (infer_constraints (template_arg , item , direction ))
14371421 return res , mapped_prefix + mapped_suffix , template_prefix + template_suffix
1422+
1423+
1424+ def infer_directed_arg_constraints (left : Type , right : Type , direction : int ) -> list [Constraint ]:
1425+ """Infer constraints between two arguments using direction between original callables."""
1426+ if isinstance (left , (ParamSpecType , UnpackType )) or isinstance (
1427+ right , (ParamSpecType , UnpackType )
1428+ ):
1429+ # This avoids bogus constraints like T <: P.args
1430+ # TODO: can we infer something useful for *T vs P?
1431+ return []
1432+ if direction == SUBTYPE_OF :
1433+ # We invert direction to account for argument contravariance.
1434+ return infer_constraints (left , right , neg_op (direction ))
1435+ else :
1436+ return infer_constraints (right , left , neg_op (direction ))
1437+
1438+
1439+ def infer_callable_arguments_constraints (
1440+ template : CallableType | Parameters , actual : CallableType | Parameters , direction : int
1441+ ) -> list [Constraint ]:
1442+ """Infer constraints between argument types of two callables.
1443+
1444+ This function essentially extracts four steps from are_parameters_compatible() in
1445+ subtypes.py that involve subtype checks between argument types. We keep the argument
1446+ matching logic, but ignore various strictness flags present there, and checks that
1447+ do not involve subtyping. Then in place of every subtype check we put an infer_constraints()
1448+ call for the same types.
1449+ """
1450+ res = []
1451+ if direction == SUBTYPE_OF :
1452+ left , right = template , actual
1453+ else :
1454+ left , right = actual , template
1455+ left_star = left .var_arg ()
1456+ left_star2 = left .kw_arg ()
1457+ right_star = right .var_arg ()
1458+ right_star2 = right .kw_arg ()
1459+
1460+ # Numbering of steps below matches the one in are_parameters_compatible() for convenience.
1461+ # Phase 1a: compare star vs star arguments.
1462+ if left_star is not None and right_star is not None :
1463+ res .extend (infer_directed_arg_constraints (left_star .typ , right_star .typ , direction ))
1464+ if left_star2 is not None and right_star2 is not None :
1465+ res .extend (infer_directed_arg_constraints (left_star2 .typ , right_star2 .typ , direction ))
1466+
1467+ # Phase 1b: compare left args with corresponding non-star right arguments.
1468+ for right_arg in right .formal_arguments ():
1469+ left_arg = mypy .typeops .callable_corresponding_argument (left , right_arg )
1470+ if left_arg is None :
1471+ continue
1472+ res .extend (infer_directed_arg_constraints (left_arg .typ , right_arg .typ , direction ))
1473+
1474+ # Phase 1c: compare left args with right *args.
1475+ if right_star is not None :
1476+ right_by_position = right .try_synthesizing_arg_from_vararg (None )
1477+ assert right_by_position is not None
1478+ i = right_star .pos
1479+ assert i is not None
1480+ while i < len (left .arg_kinds ) and left .arg_kinds [i ].is_positional ():
1481+ left_by_position = left .argument_by_position (i )
1482+ assert left_by_position is not None
1483+ res .extend (
1484+ infer_directed_arg_constraints (
1485+ left_by_position .typ , right_by_position .typ , direction
1486+ )
1487+ )
1488+ i += 1
1489+
1490+ # Phase 1d: compare left args with right **kwargs.
1491+ if right_star2 is not None :
1492+ right_names = {name for name in right .arg_names if name is not None }
1493+ left_only_names = set ()
1494+ for name , kind in zip (left .arg_names , left .arg_kinds ):
1495+ if name is None or kind .is_star () or name in right_names :
1496+ continue
1497+ left_only_names .add (name )
1498+
1499+ right_by_name = right .try_synthesizing_arg_from_kwarg (None )
1500+ assert right_by_name is not None
1501+ for name in left_only_names :
1502+ left_by_name = left .argument_by_name (name )
1503+ assert left_by_name is not None
1504+ res .extend (
1505+ infer_directed_arg_constraints (left_by_name .typ , right_by_name .typ , direction )
1506+ )
1507+ return res
0 commit comments