@@ -3536,67 +3536,61 @@ def find_isinstance_check(self, node: Expression
35363536 vartype = type_map [expr ]
35373537 return self .conditional_callable_type_map (expr , vartype )
35383538 elif isinstance (node , ComparisonExpr ):
3539- operand_types = [coerce_to_literal (type_map [expr ])
3540- for expr in node .operands if expr in type_map ]
3541-
3542- is_not = node .operators == ['is not' ]
3543- if (is_not or node .operators == ['is' ]) and len (operand_types ) == len (node .operands ):
3544- if_vars = {} # type: TypeMap
3545- else_vars = {} # type: TypeMap
3546-
3547- for i , expr in enumerate (node .operands ):
3548- var_type = operand_types [i ]
3549- other_type = operand_types [1 - i ]
3550-
3551- if literal (expr ) == LITERAL_TYPE and is_singleton_type (other_type ):
3552- # This should only be true at most once: there should be
3553- # exactly two elements in node.operands and if the 'other type' is
3554- # a singleton type, it by definition does not need to be narrowed:
3555- # it already has the most precise type possible so does not need to
3556- # be narrowed/included in the output map.
3557- #
3558- # TODO: Generalize this to handle the case where 'other_type' is
3559- # a union of singleton types.
3539+ operand_types = []
3540+ for expr in node .operands :
3541+ if expr not in type_map :
3542+ return {}, {}
3543+ operand_types .append (coerce_to_literal (type_map [expr ]))
3544+
3545+ type_maps = []
3546+ for i , (operator , left_expr , right_expr ) in enumerate (node .pairwise ()):
3547+ left_type = operand_types [i ]
3548+ right_type = operand_types [i + 1 ]
3549+
3550+ if_map = {} # type: TypeMap
3551+ else_map = {} # type: TypeMap
3552+ if operator in {'in' , 'not in' }:
3553+ right_item_type = builtin_item_type (right_type )
3554+ if right_item_type is None or is_optional (right_item_type ):
3555+ continue
3556+ if (isinstance (right_item_type , Instance )
3557+ and right_item_type .type .fullname () == 'builtins.object' ):
3558+ continue
3559+
3560+ if (is_optional (left_type ) and literal (left_expr ) == LITERAL_TYPE
3561+ and not is_literal_none (left_expr ) and
3562+ is_overlapping_erased_types (left_type , right_item_type )):
3563+ if_map , else_map = {left_expr : remove_optional (left_type )}, {}
3564+ else :
3565+ continue
3566+ elif operator in {'==' , '!=' }:
3567+ if_map , else_map = self .narrow_given_equality (
3568+ left_expr , left_type , right_expr , right_type , assume_identity = False )
3569+ elif operator in {'is' , 'is not' }:
3570+ if_map , else_map = self .narrow_given_equality (
3571+ left_expr , left_type , right_expr , right_type , assume_identity = True )
3572+ else :
3573+ continue
35603574
3561- if isinstance (other_type , LiteralType ) and other_type .is_enum_literal ():
3562- fallback_name = other_type .fallback .type .fullname ()
3563- var_type = try_expanding_enum_to_union (var_type , fallback_name )
3575+ if operator in {'not in' , '!=' , 'is not' }:
3576+ if_map , else_map = else_map , if_map
35643577
3565- target_type = [TypeRange (other_type , is_upper_bound = False )]
3566- if_vars , else_vars = conditional_type_map (expr , var_type , target_type )
3567- break
3578+ type_maps .append ((if_map , else_map ))
35683579
3569- if is_not :
3570- if_vars , else_vars = else_vars , if_vars
3571- return if_vars , else_vars
3572- # Check for `x == y` where x is of type Optional[T] and y is of type T
3573- # or a type that overlaps with T (or vice versa).
3574- elif node .operators == ['==' ]:
3575- first_type = type_map [node .operands [0 ]]
3576- second_type = type_map [node .operands [1 ]]
3577- if is_optional (first_type ) != is_optional (second_type ):
3578- if is_optional (first_type ):
3579- optional_type , comp_type = first_type , second_type
3580- optional_expr = node .operands [0 ]
3581- else :
3582- optional_type , comp_type = second_type , first_type
3583- optional_expr = node .operands [1 ]
3584- if is_overlapping_erased_types (optional_type , comp_type ):
3585- return {optional_expr : remove_optional (optional_type )}, {}
3586- elif node .operators in [['in' ], ['not in' ]]:
3587- expr = node .operands [0 ]
3588- left_type = type_map [expr ]
3589- right_type = builtin_item_type (type_map [node .operands [1 ]])
3590- right_ok = right_type and (not is_optional (right_type ) and
3591- (not isinstance (right_type , Instance ) or
3592- right_type .type .fullname () != 'builtins.object' ))
3593- if (right_type and right_ok and is_optional (left_type ) and
3594- literal (expr ) == LITERAL_TYPE and not is_literal_none (expr ) and
3595- is_overlapping_erased_types (left_type , right_type )):
3596- if node .operators == ['in' ]:
3597- return {expr : remove_optional (left_type )}, {}
3598- if node .operators == ['not in' ]:
3599- return {}, {expr : remove_optional (left_type )}
3580+ if len (type_maps ) == 0 :
3581+ return {}, {}
3582+ elif len (type_maps ) == 1 :
3583+ return type_maps [0 ]
3584+ else :
3585+ # Comparisons like 'a == b == c is d' is the same thing as
3586+ # '(a == b) and (b == c) and (c is d)'. So after generating each
3587+ # individual comparison's typemaps, we "and" them together here.
3588+ # (Also see comments below where we handle the 'and' OpExpr.)
3589+ final_if_map , final_else_map = type_maps [0 ]
3590+ for if_map , else_map in type_maps [1 :]:
3591+ final_if_map = and_conditional_maps (final_if_map , if_map )
3592+ final_else_map = or_conditional_maps (final_else_map , else_map )
3593+ return final_if_map , final_else_map
36003594 elif isinstance (node , RefExpr ):
36013595 # Restrict the type of the variable to True-ish/False-ish in the if and else branches
36023596 # respectively
@@ -3630,6 +3624,78 @@ def find_isinstance_check(self, node: Expression
36303624 # Not a supported isinstance check
36313625 return {}, {}
36323626
3627+ def narrow_given_equality (self ,
3628+ left_expr : Expression ,
3629+ left_type : Type ,
3630+ right_expr : Expression ,
3631+ right_type : Type ,
3632+ assume_identity : bool ,
3633+ ) -> Tuple [TypeMap , TypeMap ]:
3634+ """Assuming that the given 'left' and 'right' exprs are equal to each other, try
3635+ producing TypeMaps refining the types of either the left or right exprs (or neither,
3636+ if we can't learn anything from the comparison).
3637+
3638+ For more details about what TypeMaps are, see the docstring in find_isinstance_check.
3639+
3640+ If 'assume_identity' is true, assume that this comparison was done using an
3641+ identity comparison (left_expr is right_expr), not just an equality comparison
3642+ (left_expr == right_expr). Identity checks are not overridable, so we can infer
3643+ more information in that case.
3644+ """
3645+
3646+ # For the sake of simplicity, we currently attempt inferring a more precise type
3647+ # for just one of the two variables.
3648+ comparisons = [
3649+ (left_expr , left_type , right_type ),
3650+ (right_expr , right_type , left_type ),
3651+ ]
3652+
3653+ for expr , expr_type , other_type in comparisons :
3654+ # The 'expr' isn't an expression that we can refine the type of. Skip
3655+ # attempting to refine this expr.
3656+ if literal (expr ) != LITERAL_TYPE :
3657+ continue
3658+
3659+ # Case 1: If the 'other_type' is a singleton (only one value has
3660+ # the specified type), attempt to narrow 'expr_type' to just that
3661+ # singleton type.
3662+ if is_singleton_type (other_type ):
3663+ if isinstance (other_type , LiteralType ) and other_type .is_enum_literal ():
3664+ if not assume_identity :
3665+ # Our checks need to be more conservative if the operand is
3666+ # '==' or '!=': all bets are off if either of the two operands
3667+ # has a custom `__eq__` or `__ne__` method.
3668+ #
3669+ # So, we permit this check to succeed only if 'other_type' does
3670+ # not define custom equality logic
3671+ if not uses_default_equality_checks (expr_type ):
3672+ continue
3673+ if not uses_default_equality_checks (other_type .fallback ):
3674+ continue
3675+ fallback_name = other_type .fallback .type .fullname ()
3676+ expr_type = try_expanding_enum_to_union (expr_type , fallback_name )
3677+
3678+ target_type = [TypeRange (other_type , is_upper_bound = False )]
3679+ return conditional_type_map (expr , expr_type , target_type )
3680+
3681+ # Case 2: Given expr_type=Union[A, None] and other_type=A, narrow to just 'A'.
3682+ #
3683+ # Note: This check is actually strictly speaking unsafe: stripping away the 'None'
3684+ # would be unsound in the case where A defines an '__eq__' method that always
3685+ # returns 'True', for example.
3686+ #
3687+ # We implement this check partly for backwards-compatibility reasons and partly
3688+ # because those kinds of degenerate '__eq__' implementations are probably rare
3689+ # enough that this is fine in practice.
3690+ #
3691+ # We could also probably generalize this block to strip away *any* singleton type,
3692+ # if we were fine with a bit more unsoundness.
3693+ if is_optional (expr_type ) and not is_optional (other_type ):
3694+ if is_overlapping_erased_types (expr_type , other_type ):
3695+ return {expr : remove_optional (expr_type )}, {}
3696+
3697+ return {}, {}
3698+
36333699 #
36343700 # Helpers
36353701 #
@@ -4505,6 +4571,32 @@ def is_private(node_name: str) -> bool:
45054571 return node_name .startswith ('__' ) and not node_name .endswith ('__' )
45064572
45074573
4574+ def uses_default_equality_checks (typ : Type ) -> bool :
4575+ """Returns 'true' if we know for certain that the given type is using
4576+ the default __eq__ and __ne__ checks defined in 'builtins.object'.
4577+ We can use this information to make more aggressive inferences when
4578+ analyzing things like equality checks.
4579+
4580+ When in doubt, this function will conservatively bias towards
4581+ returning False.
4582+ """
4583+ if isinstance (typ , UnionType ):
4584+ return all (map (uses_default_equality_checks , typ .items ))
4585+ # TODO: Generalize this so it'll handle other types with fallbacks
4586+ if isinstance (typ , LiteralType ):
4587+ typ = typ .fallback
4588+ if isinstance (typ , Instance ):
4589+ typeinfo = typ .type
4590+ eq_sym = typeinfo .get ('__eq__' )
4591+ ne_sym = typeinfo .get ('__ne__' )
4592+ if eq_sym is None or ne_sym is None :
4593+ return False
4594+ return (eq_sym .fullname == 'builtins.object.__eq__'
4595+ and ne_sym .fullname == 'builtins.object.__ne__' )
4596+ else :
4597+ return False
4598+
4599+
45084600def is_singleton_type (typ : Type ) -> bool :
45094601 """Returns 'true' if this type is a "singleton type" -- if there exists
45104602 exactly only one runtime value associated with this type.
0 commit comments