@@ -6274,10 +6274,6 @@ def has_no_custom_eq_checks(t: Type) -> bool:
62746274 coerce_only_in_literal_context ,
62756275 )
62766276
6277- # Strictly speaking, we should also skip this check if the objects in the expr
6278- # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
6279- # assume nobody would actually create a custom objects that considers itself
6280- # equal to None.
62816277 if if_map == {} and else_map == {}:
62826278 if_map , else_map = self .refine_away_none_in_comparison (
62836279 operands , operand_types , expr_indices , narrowable_operand_index_to_hash .keys ()
@@ -6602,25 +6598,36 @@ def refine_away_none_in_comparison(
66026598 For more details about what the different arguments mean, see the
66036599 docstring of 'refine_identity_comparison_expression' up above.
66046600 """
6601+
66056602 non_optional_types = []
66066603 for i in chain_indices :
66076604 typ = operand_types [i ]
66086605 if not is_overlapping_none (typ ):
66096606 non_optional_types .append (typ )
66106607
6611- # Make sure we have a mixture of optional and non-optional types.
6612- if len (non_optional_types ) == 0 or len (non_optional_types ) == len (chain_indices ):
6613- return {}, {}
6608+ if_map , else_map = {}, {}
66146609
6615- if_map = {}
6616- for i in narrowable_operand_indices :
6617- expr_type = operand_types [i ]
6618- if not is_overlapping_none (expr_type ):
6619- continue
6620- if any (is_overlapping_erased_types (expr_type , t ) for t in non_optional_types ):
6621- if_map [operands [i ]] = remove_optional (expr_type )
6610+ if not non_optional_types or (len (non_optional_types ) != len (chain_indices )):
66226611
6623- return if_map , {}
6612+ # Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be
6613+ # convenient but is strictly not type-safe):
6614+ for i in narrowable_operand_indices :
6615+ expr_type = operand_types [i ]
6616+ if not is_overlapping_none (expr_type ):
6617+ continue
6618+ if any (is_overlapping_erased_types (expr_type , t ) for t in non_optional_types ):
6619+ if_map [operands [i ]] = remove_optional (expr_type )
6620+
6621+ # Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and
6622+ # so type-safe but less convenient, because e.g. `Optional[A] == None` still results
6623+ # in `Optional[A]`):
6624+ if any (isinstance (get_proper_type (ot ), NoneType ) for ot in operand_types ):
6625+ for i in narrowable_operand_indices :
6626+ expr_type = operand_types [i ]
6627+ if is_overlapping_none (expr_type ):
6628+ else_map [operands [i ]] = remove_optional (expr_type )
6629+
6630+ return if_map , else_map
66246631
66256632 def is_len_of_tuple (self , expr : Expression ) -> bool :
66266633 """Is this expression a `len(x)` call where x is a tuple or union of tuples?"""
0 commit comments