@@ -4089,36 +4089,57 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
40894089 if isinstance (subject_type , DeletedType ):
40904090 self .msg .deleted_as_rvalue (subject_type , s )
40914091
4092+ # We infer types of patterns twice. The first pass is used
4093+ # to infer the types of capture variables. The type of a
4094+ # capture variable may depend on multiple patterns (it
4095+ # will be a union of all capture types). This pass ignores
4096+ # guard expressions.
40924097 pattern_types = [self .pattern_checker .accept (p , subject_type ) for p in s .patterns ]
4093-
40944098 type_maps : List [TypeMap ] = [t .captures for t in pattern_types ]
4095- self .infer_variable_types_from_type_maps (type_maps )
4099+ inferred_types = self .infer_variable_types_from_type_maps (type_maps )
40964100
4097- for pattern_type , g , b in zip (pattern_types , s .guards , s .bodies ):
4101+ # The second pass narrows down the types and type checks bodies.
4102+ for p , g , b in zip (s .patterns , s .guards , s .bodies ):
4103+ current_subject_type = self .expr_checker .narrow_type_from_binder (s .subject ,
4104+ subject_type )
4105+ pattern_type = self .pattern_checker .accept (p , current_subject_type )
40984106 with self .binder .frame_context (can_skip = True , fall_through = 2 ):
40994107 if b .is_unreachable or isinstance (get_proper_type (pattern_type .type ),
41004108 UninhabitedType ):
41014109 self .push_type_map (None )
4110+ else_map : TypeMap = {}
41024111 else :
4103- self .binder .put (s .subject , pattern_type .type )
4112+ pattern_map , else_map = conditional_types_to_typemaps (
4113+ s .subject ,
4114+ pattern_type .type ,
4115+ pattern_type .rest_type
4116+ )
4117+ self .remove_capture_conflicts (pattern_type .captures ,
4118+ inferred_types )
4119+ self .push_type_map (pattern_map )
41044120 self .push_type_map (pattern_type .captures )
41054121 if g is not None :
4106- gt = get_proper_type (self .expr_checker .accept (g ))
4122+ with self .binder .frame_context (can_skip = True , fall_through = 3 ):
4123+ gt = get_proper_type (self .expr_checker .accept (g ))
41074124
4108- if isinstance (gt , DeletedType ):
4109- self .msg .deleted_as_rvalue (gt , s )
4125+ if isinstance (gt , DeletedType ):
4126+ self .msg .deleted_as_rvalue (gt , s )
41104127
4111- if_map , _ = self .find_isinstance_check (g )
4128+ guard_map , guard_else_map = self .find_isinstance_check (g )
4129+ else_map = or_conditional_maps (else_map , guard_else_map )
41124130
4113- self .push_type_map (if_map )
4114- self .accept (b )
4131+ self .push_type_map (guard_map )
4132+ self .accept (b )
4133+ else :
4134+ self .accept (b )
4135+ self .push_type_map (else_map )
41154136
41164137 # This is needed due to a quirk in frame_context. Without it types will stay narrowed
41174138 # after the match.
41184139 with self .binder .frame_context (can_skip = False , fall_through = 2 ):
41194140 pass
41204141
4121- def infer_variable_types_from_type_maps (self , type_maps : List [TypeMap ]) -> None :
4142+ def infer_variable_types_from_type_maps (self , type_maps : List [TypeMap ]) -> Dict [ Var , Type ] :
41224143 all_captures : Dict [Var , List [Tuple [NameExpr , Type ]]] = defaultdict (list )
41234144 for tm in type_maps :
41244145 if tm is not None :
@@ -4128,28 +4149,38 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None:
41284149 assert isinstance (node , Var )
41294150 all_captures [node ].append ((expr , typ ))
41304151
4152+ inferred_types : Dict [Var , Type ] = {}
41314153 for var , captures in all_captures .items ():
4132- conflict = False
4154+ already_exists = False
41334155 types : List [Type ] = []
41344156 for expr , typ in captures :
41354157 types .append (typ )
41364158
4137- previous_type , _ , inferred = self .check_lvalue (expr )
4159+ previous_type , _ , _ = self .check_lvalue (expr )
41384160 if previous_type is not None :
4139- conflict = True
4140- self .check_subtype (typ , previous_type , expr ,
4141- msg = message_registry .INCOMPATIBLE_TYPES_IN_CAPTURE ,
4142- subtype_label = "pattern captures type" ,
4143- supertype_label = "variable has type" )
4144- for type_map in type_maps :
4145- if type_map is not None and expr in type_map :
4146- del type_map [expr ]
4147-
4148- if not conflict :
4161+ already_exists = True
4162+ if self .check_subtype (typ , previous_type , expr ,
4163+ msg = message_registry .INCOMPATIBLE_TYPES_IN_CAPTURE ,
4164+ subtype_label = "pattern captures type" ,
4165+ supertype_label = "variable has type" ):
4166+ inferred_types [var ] = previous_type
4167+
4168+ if not already_exists :
41494169 new_type = UnionType .make_union (types )
41504170 # Infer the union type at the first occurrence
41514171 first_occurrence , _ = captures [0 ]
4172+ inferred_types [var ] = new_type
41524173 self .infer_variable_type (var , first_occurrence , new_type , first_occurrence )
4174+ return inferred_types
4175+
4176+ def remove_capture_conflicts (self , type_map : TypeMap , inferred_types : Dict [Var , Type ]) -> None :
4177+ if type_map :
4178+ for expr , typ in list (type_map .items ()):
4179+ if isinstance (expr , NameExpr ):
4180+ node = expr .node
4181+ assert isinstance (node , Var )
4182+ if node not in inferred_types or not is_subtype (typ , inferred_types [node ]):
4183+ del type_map [expr ]
41534184
41544185 def make_fake_typeinfo (self ,
41554186 curr_module_fullname : str ,
@@ -5637,6 +5668,14 @@ def conditional_types(current_type: Type,
56375668 None means no new information can be inferred. If default is set it is returned
56385669 instead."""
56395670 if proposed_type_ranges :
5671+ if len (proposed_type_ranges ) == 1 :
5672+ target = proposed_type_ranges [0 ].item
5673+ target = get_proper_type (target )
5674+ if isinstance (target , LiteralType ) and (target .is_enum_literal ()
5675+ or isinstance (target .value , bool )):
5676+ enum_name = target .fallback .type .fullname
5677+ current_type = try_expanding_sum_type_to_union (current_type ,
5678+ enum_name )
56405679 proposed_items = [type_range .item for type_range in proposed_type_ranges ]
56415680 proposed_type = make_simplified_union (proposed_items )
56425681 if isinstance (proposed_type , AnyType ):
0 commit comments