@@ -2919,75 +2919,116 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
29192919 That is, 'a < b > c == d' is check as 'a < b and b > c and c == d'
29202920 """
29212921 result : Type | None = None
2922- sub_result : Type | None = None
2922+ sub_result : Type
29232923
29242924 # Check each consecutive operand pair and their operator
29252925 for left , right , operator in zip (e .operands , e .operands [1 :], e .operators ):
29262926 left_type = self .accept (left )
29272927
2928- method_type : mypy .types .Type | None = None
2929-
29302928 if operator == "in" or operator == "not in" :
2929+ # This case covers both iterables and containers, which have different meanings.
2930+ # For a container, the in operator calls the __contains__ method.
2931+ # For an iterable, the in operator iterates over the iterable, and compares each item one-by-one.
2932+ # We allow `in` for a union of containers and iterables as long as at least one of them matches the
2933+ # type of the left operand, as the operation will simply return False if the union's container/iterator
2934+ # type doesn't match the left operand.
2935+
29312936 # If the right operand has partial type, look it up without triggering
29322937 # a "Need type annotation ..." message, as it would be noise.
29332938 right_type = self .find_partial_type_ref_fast_path (right )
29342939 if right_type is None :
29352940 right_type = self .accept (right ) # Validate the right operand
29362941
2937- # Keep track of whether we get type check errors (these won't be reported, they
2938- # are just to verify whether something is valid typing wise).
2939- with self .msg .filter_errors (save_filtered_errors = True ) as local_errors :
2940- _ , method_type = self .check_method_call_by_name (
2941- method = "__contains__" ,
2942- base_type = right_type ,
2943- args = [left ],
2944- arg_kinds = [ARG_POS ],
2945- context = e ,
2946- )
2942+ right_type = get_proper_type (right_type )
2943+ item_types : Sequence [Type ] = [right_type ]
2944+ if isinstance (right_type , UnionType ):
2945+ item_types = list (right_type .items )
29472946
29482947 sub_result = self .bool_type ()
2949- # Container item type for strict type overlap checks. Note: we need to only
2950- # check for nominal type, because a usual "Unsupported operands for in"
2951- # will be reported for types incompatible with __contains__().
2952- # See testCustomContainsCheckStrictEquality for an example.
2953- cont_type = self .chk .analyze_container_item_type (right_type )
2954- if isinstance (right_type , PartialType ):
2955- # We don't really know if this is an error or not, so just shut up.
2956- pass
2957- elif (
2958- local_errors .has_new_errors ()
2959- and
2960- # is_valid_var_arg is True for any Iterable
2961- self .is_valid_var_arg (right_type )
2962- ):
2963- _ , itertype = self .chk .analyze_iterable_item_type (right )
2964- method_type = CallableType (
2965- [left_type ],
2966- [nodes .ARG_POS ],
2967- [None ],
2968- self .bool_type (),
2969- self .named_type ("builtins.function" ),
2970- )
2971- if not is_subtype (left_type , itertype ):
2972- self .msg .unsupported_operand_types ("in" , left_type , right_type , e )
2973- # Only show dangerous overlap if there are no other errors.
2974- elif (
2975- not local_errors .has_new_errors ()
2976- and cont_type
2977- and self .dangerous_comparison (
2978- left_type , cont_type , original_container = right_type , prefer_literal = False
2979- )
2980- ):
2981- self .msg .dangerous_comparison (left_type , cont_type , "container" , e )
2982- else :
2983- self .msg .add_errors (local_errors .filtered_errors ())
2948+
2949+ container_types : list [Type ] = []
2950+ iterable_types : list [Type ] = []
2951+ failed_out = False
2952+ encountered_partial_type = False
2953+
2954+ for item_type in item_types :
2955+ # Keep track of whether we get type check errors (these won't be reported, they
2956+ # are just to verify whether something is valid typing wise).
2957+ with self .msg .filter_errors (save_filtered_errors = True ) as container_errors :
2958+ _ , method_type = self .check_method_call_by_name (
2959+ method = "__contains__" ,
2960+ base_type = item_type ,
2961+ args = [left ],
2962+ arg_kinds = [ARG_POS ],
2963+ context = e ,
2964+ original_type = right_type ,
2965+ )
2966+ # Container item type for strict type overlap checks. Note: we need to only
2967+ # check for nominal type, because a usual "Unsupported operands for in"
2968+ # will be reported for types incompatible with __contains__().
2969+ # See testCustomContainsCheckStrictEquality for an example.
2970+ cont_type = self .chk .analyze_container_item_type (item_type )
2971+
2972+ if isinstance (item_type , PartialType ):
2973+ # We don't really know if this is an error or not, so just shut up.
2974+ encountered_partial_type = True
2975+ pass
2976+ elif (
2977+ container_errors .has_new_errors ()
2978+ and
2979+ # is_valid_var_arg is True for any Iterable
2980+ self .is_valid_var_arg (item_type )
2981+ ):
2982+ # it's not a container, but it is an iterable
2983+ with self .msg .filter_errors (save_filtered_errors = True ) as iterable_errors :
2984+ _ , itertype = self .chk .analyze_iterable_item_type_without_expression (
2985+ item_type , e
2986+ )
2987+ if iterable_errors .has_new_errors ():
2988+ self .msg .add_errors (iterable_errors .filtered_errors ())
2989+ failed_out = True
2990+ else :
2991+ method_type = CallableType (
2992+ [left_type ],
2993+ [nodes .ARG_POS ],
2994+ [None ],
2995+ self .bool_type (),
2996+ self .named_type ("builtins.function" ),
2997+ )
2998+ e .method_types .append (method_type )
2999+ iterable_types .append (itertype )
3000+ elif not container_errors .has_new_errors () and cont_type :
3001+ container_types .append (cont_type )
3002+ e .method_types .append (method_type )
3003+ else :
3004+ self .msg .add_errors (container_errors .filtered_errors ())
3005+ failed_out = True
3006+
3007+ if not encountered_partial_type and not failed_out :
3008+ iterable_type = UnionType .make_union (iterable_types )
3009+ if not is_subtype (left_type , iterable_type ):
3010+ if len (container_types ) == 0 :
3011+ self .msg .unsupported_operand_types ("in" , left_type , right_type , e )
3012+ else :
3013+ container_type = UnionType .make_union (container_types )
3014+ if self .dangerous_comparison (
3015+ left_type ,
3016+ container_type ,
3017+ original_container = right_type ,
3018+ prefer_literal = False ,
3019+ ):
3020+ self .msg .dangerous_comparison (
3021+ left_type , container_type , "container" , e
3022+ )
3023+
29843024 elif operator in operators .op_methods :
29853025 method = operators .op_methods [operator ]
29863026
29873027 with ErrorWatcher (self .msg .errors ) as w :
29883028 sub_result , method_type = self .check_op (
29893029 method , left_type , right , e , allow_reverse = True
29903030 )
3031+ e .method_types .append (method_type )
29913032
29923033 # Only show dangerous overlap if there are no other errors. See
29933034 # testCustomEqCheckStrictEquality for an example.
@@ -3007,12 +3048,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
30073048 left_type = try_getting_literal (left_type )
30083049 right_type = try_getting_literal (right_type )
30093050 self .msg .dangerous_comparison (left_type , right_type , "identity" , e )
3010- method_type = None
3051+ e . method_types . append ( None )
30113052 else :
30123053 raise RuntimeError (f"Unknown comparison operator { operator } " )
30133054
3014- e .method_types .append (method_type )
3015-
30163055 # Determine type of boolean-and of result and sub_result
30173056 if result is None :
30183057 result = sub_result
0 commit comments