From 1086dcba0944a587f2c69f56cb666f2e6eb1dcfd Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 18 Nov 2024 22:28:34 +0100 Subject: [PATCH] Support `==`-based narrowing of Optional Closes #18135 This change implements the third approach mentioned in #18135, which is stricter than similar narrowings, as clarified by the new/modified code comments. Personally, I prefer this more stringent way but could also switch this PR to approach two if there is a consent that convenience is more important than type safety here. --- mypy/checker.py | 37 +++++++++++++++++------------ test-data/unit/check-narrowing.test | 4 ++-- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 1bee348bc252..ef3f7502d7ce 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6274,10 +6274,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: coerce_only_in_literal_context, ) - # Strictly speaking, we should also skip this check if the objects in the expr - # chain have custom __eq__ or __ne__ methods. But we (maybe optimistically) - # assume nobody would actually create a custom objects that considers itself - # equal to None. if if_map == {} and else_map == {}: if_map, else_map = self.refine_away_none_in_comparison( operands, operand_types, expr_indices, narrowable_operand_index_to_hash.keys() @@ -6602,25 +6598,36 @@ def refine_away_none_in_comparison( For more details about what the different arguments mean, see the docstring of 'refine_identity_comparison_expression' up above. """ + non_optional_types = [] for i in chain_indices: typ = operand_types[i] if not is_overlapping_none(typ): non_optional_types.append(typ) - # Make sure we have a mixture of optional and non-optional types. - if len(non_optional_types) == 0 or len(non_optional_types) == len(chain_indices): - return {}, {} + if_map, else_map = {}, {} - if_map = {} - for i in narrowable_operand_indices: - expr_type = operand_types[i] - if not is_overlapping_none(expr_type): - continue - if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): - if_map[operands[i]] = remove_optional(expr_type) + if not non_optional_types or (len(non_optional_types) != len(chain_indices)): - return if_map, {} + # Narrow e.g. `Optional[A] == "x"` or `Optional[A] is "x"` to `A` (which may be + # convenient but is strictly not type-safe): + for i in narrowable_operand_indices: + expr_type = operand_types[i] + if not is_overlapping_none(expr_type): + continue + if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): + if_map[operands[i]] = remove_optional(expr_type) + + # Narrow e.g. `Optional[A] != None` to `A` (which is stricter than the above step and + # so type-safe but less convenient, because e.g. `Optional[A] == None` still results + # in `Optional[A]`): + if any(isinstance(get_proper_type(ot), NoneType) for ot in operand_types): + for i in narrowable_operand_indices: + expr_type = operand_types[i] + if is_overlapping_none(expr_type): + else_map[operands[i]] = remove_optional(expr_type) + + return if_map, else_map def is_len_of_tuple(self, expr: Expression) -> bool: """Is this expression a `len(x)` call where x is a tuple or union of tuples?""" diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index d740708991d0..bc763095477e 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1385,9 +1385,9 @@ val: Optional[A] if val == None: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" else: - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "__main__.A" if val != None: - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "__main__.A" else: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"