@@ -82,19 +82,15 @@ def __repr__(self) -> str:
8282 op_str = "<:"
8383 if self .op == SUPERTYPE_OF :
8484 op_str = ":>"
85- return f"{ self .origin_type_var } { op_str } { self .target } "
85+ return f"{ self .type_var } { op_str } { self .target } "
8686
8787 def __hash__ (self ) -> int :
88- return hash ((self .origin_type_var , self .op , self .target ))
88+ return hash ((self .type_var , self .op , self .target ))
8989
9090 def __eq__ (self , other : object ) -> bool :
9191 if not isinstance (other , Constraint ):
9292 return False
93- return (self .origin_type_var , self .op , self .target ) == (
94- other .origin_type_var ,
95- other .op ,
96- other .target ,
97- )
93+ return (self .type_var , self .op , self .target ) == (other .type_var , other .op , other .target )
9894
9995
10096def infer_constraints_for_callable (
@@ -702,54 +698,25 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
702698 )
703699 elif isinstance (tvar , ParamSpecType ) and isinstance (mapped_arg , ParamSpecType ):
704700 suffix = get_proper_type (instance_arg )
705- prefix = mapped_arg .prefix
706- length = len (prefix .arg_types )
707701
708702 if isinstance (suffix , CallableType ):
703+ prefix = mapped_arg .prefix
709704 from_concat = bool (prefix .arg_types ) or suffix .from_concatenate
710705 suffix = suffix .copy_modified (from_concatenate = from_concat )
711706
712707 if isinstance (suffix , (Parameters , CallableType )):
713708 # no such thing as variance for ParamSpecs
714709 # TODO: is there a case I am missing?
715- length = min (length , len (suffix .arg_types ))
716-
717- constrained_to = suffix .copy_modified (
718- suffix .arg_types [length :],
719- suffix .arg_kinds [length :],
720- suffix .arg_names [length :],
721- )
722- constrained_from = mapped_arg .copy_modified (
723- prefix = prefix .copy_modified (
724- prefix .arg_types [length :],
725- prefix .arg_kinds [length :],
726- prefix .arg_names [length :],
727- )
710+ # TODO: constraints between prefixes
711+ prefix = mapped_arg .prefix
712+ suffix = suffix .copy_modified (
713+ suffix .arg_types [len (prefix .arg_types ) :],
714+ suffix .arg_kinds [len (prefix .arg_kinds ) :],
715+ suffix .arg_names [len (prefix .arg_names ) :],
728716 )
729-
730- res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained_to ))
731- res .append (Constraint (constrained_from , SUBTYPE_OF , constrained_to ))
717+ res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
732718 elif isinstance (suffix , ParamSpecType ):
733- suffix_prefix = suffix .prefix
734- length = min (length , len (suffix_prefix .arg_types ))
735-
736- constrained = suffix .copy_modified (
737- prefix = suffix_prefix .copy_modified (
738- suffix_prefix .arg_types [length :],
739- suffix_prefix .arg_kinds [length :],
740- suffix_prefix .arg_names [length :],
741- )
742- )
743- constrained_from = mapped_arg .copy_modified (
744- prefix = prefix .copy_modified (
745- prefix .arg_types [length :],
746- prefix .arg_kinds [length :],
747- prefix .arg_names [length :],
748- )
749- )
750-
751- res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained ))
752- res .append (Constraint (constrained_from , SUBTYPE_OF , constrained ))
719+ res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
753720 else :
754721 # This case should have been handled above.
755722 assert not isinstance (tvar , TypeVarTupleType )
@@ -801,56 +768,26 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
801768 template_arg , ParamSpecType
802769 ):
803770 suffix = get_proper_type (mapped_arg )
804- prefix = template_arg .prefix
805- length = len (prefix .arg_types )
806771
807772 if isinstance (suffix , CallableType ):
808773 prefix = template_arg .prefix
809774 from_concat = bool (prefix .arg_types ) or suffix .from_concatenate
810775 suffix = suffix .copy_modified (from_concatenate = from_concat )
811776
812- # TODO: this is almost a copy-paste of code above: make this into a function
813777 if isinstance (suffix , (Parameters , CallableType )):
814778 # no such thing as variance for ParamSpecs
815779 # TODO: is there a case I am missing?
816- length = min (length , len (suffix .arg_types ))
780+ # TODO: constraints between prefixes
781+ prefix = template_arg .prefix
817782
818- constrained_to = suffix .copy_modified (
819- suffix .arg_types [length :],
820- suffix .arg_kinds [length :],
821- suffix .arg_names [length :],
783+ suffix = suffix .copy_modified (
784+ suffix .arg_types [len ( prefix . arg_types ) :],
785+ suffix .arg_kinds [len ( prefix . arg_kinds ) :],
786+ suffix .arg_names [len ( prefix . arg_names ) :],
822787 )
823- constrained_from = template_arg .copy_modified (
824- prefix = prefix .copy_modified (
825- prefix .arg_types [length :],
826- prefix .arg_kinds [length :],
827- prefix .arg_names [length :],
828- )
829- )
830-
831- res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained_to ))
832- res .append (Constraint (constrained_from , SUBTYPE_OF , constrained_to ))
788+ res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
833789 elif isinstance (suffix , ParamSpecType ):
834- suffix_prefix = suffix .prefix
835- length = min (length , len (suffix_prefix .arg_types ))
836-
837- constrained = suffix .copy_modified (
838- prefix = suffix_prefix .copy_modified (
839- suffix_prefix .arg_types [length :],
840- suffix_prefix .arg_kinds [length :],
841- suffix_prefix .arg_names [length :],
842- )
843- )
844- constrained_from = template_arg .copy_modified (
845- prefix = prefix .copy_modified (
846- prefix .arg_types [length :],
847- prefix .arg_kinds [length :],
848- prefix .arg_names [length :],
849- )
850- )
851-
852- res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained ))
853- res .append (Constraint (constrained_from , SUBTYPE_OF , constrained ))
790+ res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
854791 else :
855792 # This case should have been handled above.
856793 assert not isinstance (tvar , TypeVarTupleType )
@@ -1017,19 +954,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
1017954 prefix_len = len (prefix .arg_types )
1018955 cactual_ps = cactual .param_spec ()
1019956
1020- cactual_prefix : Parameters | CallableType
1021- if cactual_ps :
1022- cactual_prefix = cactual_ps .prefix
1023- else :
1024- cactual_prefix = cactual
1025-
1026- max_prefix_len = len (
1027- [k for k in cactual_prefix .arg_kinds if k in (ARG_POS , ARG_OPT )]
1028- )
1029- prefix_len = min (prefix_len , max_prefix_len )
1030-
1031- # we could check the prefixes match here, but that should be caught elsewhere.
1032957 if not cactual_ps :
958+ max_prefix_len = len ([k for k in cactual .arg_kinds if k in (ARG_POS , ARG_OPT )])
959+ prefix_len = min (prefix_len , max_prefix_len )
1033960 res .append (
1034961 Constraint (
1035962 param_spec ,
@@ -1043,17 +970,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
1043970 )
1044971 )
1045972 else :
1046- # earlier, cactual_prefix = cactual_ps.prefix. thus, this is guaranteed
1047- assert isinstance (cactual_prefix , Parameters )
1048-
1049- constrained_by = cactual_ps .copy_modified (
1050- prefix = cactual_prefix .copy_modified (
1051- cactual_prefix .arg_types [prefix_len :],
1052- cactual_prefix .arg_kinds [prefix_len :],
1053- cactual_prefix .arg_names [prefix_len :],
1054- )
1055- )
1056- res .append (Constraint (param_spec , SUBTYPE_OF , constrained_by ))
973+ res .append (Constraint (param_spec , SUBTYPE_OF , cactual_ps ))
1057974
1058975 # compare prefixes
1059976 cactual_prefix = cactual .copy_modified (
0 commit comments