@@ -1043,72 +1043,102 @@ def check_overlapping_op_methods(self,
10431043 """Check for overlapping method and reverse method signatures.
10441044
10451045 Assume reverse method has valid argument count and kinds.
1046+
1047+ Precondition:
1048+ If the reverse operator method accepts some argument of type
1049+ X, the forward operator method must belong to class X.
1050+
1051+ For example, if we have the reverse operator `A.__radd__(B)`, then the
1052+ corresponding forward operator must have the type `B.__add__(...)`.
10461053 """
10471054
1048- # Reverse operator method that overlaps unsafely with the
1049- # forward operator method can result in type unsafety. This is
1050- # similar to overlapping overload variants.
1055+ # Note: Suppose we have two operator methods "A.__rOP__(B) -> R1" and
1056+ # "B.__OP__(C) -> R2". We check if these two methods are unsafely overlapping
1057+ # by using the following algorithm:
1058+ #
1059+ # 1. Rewrite "B.__OP__(C) -> R1" to "temp1(B, C) -> R1"
1060+ #
1061+ # 2. Rewrite "A.__rOP__(B) -> R2" to "temp2(B, A) -> R2"
1062+ #
1063+ # 3. Treat temp1 and temp2 as if they were both variants in the same
1064+ # overloaded function. (This mirrors how the Python runtime calls
1065+ # operator methods: we first try __OP__, then __rOP__.)
10511066 #
1052- # This example illustrates the issue:
1067+ # If the first signature is unsafely overlapping with the second,
1068+ # report an error.
10531069 #
1054- # class X: pass
1055- # class A:
1056- # def __add__(self, x: X) -> int:
1057- # if isinstance(x, X):
1058- # return 1
1059- # return NotImplemented
1060- # class B:
1061- # def __radd__(self, x: A) -> str: return 'x'
1062- # class C(X, B): pass
1063- # def f(b: B) -> None:
1064- # A() + b # Result is 1, even though static type seems to be str!
1065- # f(C())
1070+ # 4. However, if temp1 shadows temp2 (e.g. the __rOP__ method can never
1071+ # be called), do NOT report an error.
10661072 #
1067- # The reason for the problem is that B and X are overlapping
1068- # types, and the return types are different. Also, if the type
1069- # of x in __radd__ would not be A, the methods could be
1070- # non-overlapping.
1073+ # This behavior deviates from how we handle overloads -- many of the
1074+ # modules in typeshed seem to define __OP__ methods that shadow the
1075+ # corresponding __rOP__ method.
1076+ #
1077+ # Note: we do not attempt to handle unsafe overlaps related to multiple
1078+ # inheritance.
10711079
10721080 for forward_item in union_items (forward_type ):
10731081 if isinstance (forward_item , CallableType ):
1074- # TODO check argument kinds
1075- if len (forward_item .arg_types ) < 1 :
1076- # Not a valid operator method -- can't succeed anyway.
1077- return
1078-
1079- # Construct normalized function signatures corresponding to the
1080- # operator methods. The first argument is the left operand and the
1081- # second operand is the right argument -- we switch the order of
1082- # the arguments of the reverse method.
1083- forward_tweaked = CallableType (
1084- [forward_base , forward_item .arg_types [0 ]],
1085- [nodes .ARG_POS ] * 2 ,
1086- [None ] * 2 ,
1087- forward_item .ret_type ,
1088- forward_item .fallback ,
1089- name = forward_item .name )
1090- reverse_args = reverse_type .arg_types
1091- reverse_tweaked = CallableType (
1092- [reverse_args [1 ], reverse_args [0 ]],
1093- [nodes .ARG_POS ] * 2 ,
1094- [None ] * 2 ,
1095- reverse_type .ret_type ,
1096- fallback = self .named_type ('builtins.function' ),
1097- name = reverse_type .name )
1098-
1099- if is_unsafe_overlapping_operator_signatures (
1100- forward_tweaked , reverse_tweaked ):
1082+ if self .is_unsafe_overlapping_op (forward_item , forward_base , reverse_type ):
11011083 self .msg .operator_method_signatures_overlap (
11021084 reverse_class , reverse_name ,
11031085 forward_base , forward_name , context )
11041086 elif isinstance (forward_item , Overloaded ):
11051087 for item in forward_item .items ():
1106- self .check_overlapping_op_methods (
1107- reverse_type , reverse_name , reverse_class ,
1108- item , forward_name , forward_base , context )
1088+ if self .is_unsafe_overlapping_op (item , forward_base , reverse_type ):
1089+ self .msg .operator_method_signatures_overlap (
1090+ reverse_class , reverse_name ,
1091+ forward_base , forward_name ,
1092+ context )
11091093 elif not isinstance (forward_item , AnyType ):
11101094 self .msg .forward_operator_not_callable (forward_name , context )
11111095
1096+ def is_unsafe_overlapping_op (self ,
1097+ forward_item : CallableType ,
1098+ forward_base : Type ,
1099+ reverse_type : CallableType ) -> bool :
1100+ # TODO check argument kinds
1101+ if len (forward_item .arg_types ) < 1 :
1102+ # Not a valid operator method -- can't succeed anyway.
1103+ return False
1104+
1105+ # Erase the type if necessary to make sure we don't have a dangling
1106+ # TypeVar in forward_tweaked
1107+ forward_base_erased = forward_base
1108+ if isinstance (forward_base , TypeVarType ):
1109+ forward_base_erased = erase_to_bound (forward_base )
1110+
1111+ # Construct normalized function signatures corresponding to the
1112+ # operator methods. The first argument is the left operand and the
1113+ # second operand is the right argument -- we switch the order of
1114+ # the arguments of the reverse method.
1115+
1116+ forward_tweaked = forward_item .copy_modified (
1117+ arg_types = [forward_base_erased , forward_item .arg_types [0 ]],
1118+ arg_kinds = [nodes .ARG_POS ] * 2 ,
1119+ arg_names = [None ] * 2 ,
1120+ )
1121+ reverse_tweaked = reverse_type .copy_modified (
1122+ arg_types = [reverse_type .arg_types [1 ], reverse_type .arg_types [0 ]],
1123+ arg_kinds = [nodes .ARG_POS ] * 2 ,
1124+ arg_names = [None ] * 2 ,
1125+ )
1126+
1127+ reverse_base_erased = reverse_type .arg_types [0 ]
1128+ if isinstance (reverse_base_erased , TypeVarType ):
1129+ reverse_base_erased = erase_to_bound (reverse_base_erased )
1130+
1131+ if is_same_type (reverse_base_erased , forward_base_erased ):
1132+ return False
1133+ elif is_subtype (reverse_base_erased , forward_base_erased ):
1134+ first = reverse_tweaked
1135+ second = forward_tweaked
1136+ else :
1137+ first = forward_tweaked
1138+ second = reverse_tweaked
1139+
1140+ return is_unsafe_overlapping_overload_signatures (first , second )
1141+
11121142 def check_inplace_operator_method (self , defn : FuncBase ) -> None :
11131143 """Check an inplace operator method such as __iadd__.
11141144
0 commit comments