11from __future__ import annotations
22
3- from typing import NamedTuple
4-
3+ from mypy import checker
54from mypy .messages import MessageBuilder
65from mypy .nodes import (
6+ AssertStmt ,
77 AssignmentStmt ,
8+ BreakStmt ,
9+ ContinueStmt ,
10+ Expression ,
11+ ExpressionStmt ,
812 ForStmt ,
913 FuncDef ,
1014 FuncItem ,
15+ GeneratorExpr ,
1116 IfStmt ,
1217 ListExpr ,
1318 Lvalue ,
1419 NameExpr ,
20+ RaiseStmt ,
21+ ReturnStmt ,
1522 TupleExpr ,
1623 WhileStmt ,
1724)
18- from mypy .traverser import TraverserVisitor
25+ from mypy .traverser import ExtendedTraverserVisitor
26+ from mypy .types import Type , UninhabitedType
1927
2028
21- class DefinedVars ( NamedTuple ) :
22- """DefinedVars contains information about variable definition at the end of a branching statement.
29+ class BranchState :
30+ """BranchState contains information about variable definition at the end of a branching statement.
2331 `if` and `match` are examples of branching statements.
2432
2533 `may_be_defined` contains variables that were defined in only some branches.
2634 `must_be_defined` contains variables that were defined in all branches.
2735 """
2836
29- may_be_defined : set [str ]
30- must_be_defined : set [str ]
37+ def __init__ (
38+ self ,
39+ must_be_defined : set [str ] | None = None ,
40+ may_be_defined : set [str ] | None = None ,
41+ skipped : bool = False ,
42+ ) -> None :
43+ if may_be_defined is None :
44+ may_be_defined = set ()
45+ if must_be_defined is None :
46+ must_be_defined = set ()
47+
48+ self .may_be_defined = set (may_be_defined )
49+ self .must_be_defined = set (must_be_defined )
50+ self .skipped = skipped
3151
3252
3353class BranchStatement :
34- def __init__ (self , already_defined : DefinedVars ) -> None :
35- self .already_defined = already_defined
36- self .defined_by_branch : list [DefinedVars ] = [
37- DefinedVars ( may_be_defined = set (), must_be_defined = set ( already_defined . must_be_defined ) )
54+ def __init__ (self , initial_state : BranchState ) -> None :
55+ self .initial_state = initial_state
56+ self .branches : list [BranchState ] = [
57+ BranchState ( must_be_defined = self . initial_state . must_be_defined )
3858 ]
3959
4060 def next_branch (self ) -> None :
41- self .defined_by_branch .append (
42- DefinedVars (
43- may_be_defined = set (), must_be_defined = set (self .already_defined .must_be_defined )
44- )
45- )
61+ self .branches .append (BranchState (must_be_defined = self .initial_state .must_be_defined ))
4662
4763 def record_definition (self , name : str ) -> None :
48- assert len (self .defined_by_branch ) > 0
49- self .defined_by_branch [- 1 ].must_be_defined .add (name )
50- self .defined_by_branch [- 1 ].may_be_defined .discard (name )
51-
52- def record_nested_branch (self , vars : DefinedVars ) -> None :
53- assert len (self .defined_by_branch ) > 0
54- current_branch = self .defined_by_branch [- 1 ]
55- current_branch .must_be_defined .update (vars .must_be_defined )
56- current_branch .may_be_defined .update (vars .may_be_defined )
64+ assert len (self .branches ) > 0
65+ self .branches [- 1 ].must_be_defined .add (name )
66+ self .branches [- 1 ].may_be_defined .discard (name )
67+
68+ def record_nested_branch (self , state : BranchState ) -> None :
69+ assert len (self .branches ) > 0
70+ current_branch = self .branches [- 1 ]
71+ if state .skipped :
72+ current_branch .skipped = True
73+ return
74+ current_branch .must_be_defined .update (state .must_be_defined )
75+ current_branch .may_be_defined .update (state .may_be_defined )
5776 current_branch .may_be_defined .difference_update (current_branch .must_be_defined )
5877
78+ def skip_branch (self ) -> None :
79+ assert len (self .branches ) > 0
80+ self .branches [- 1 ].skipped = True
81+
5982 def is_possibly_undefined (self , name : str ) -> bool :
60- assert len (self .defined_by_branch ) > 0
61- return name in self .defined_by_branch [- 1 ].may_be_defined
83+ assert len (self .branches ) > 0
84+ return name in self .branches [- 1 ].may_be_defined
6285
63- def done (self ) -> DefinedVars :
64- assert len ( self .defined_by_branch ) > 0
65- if len (self . defined_by_branch ) == 1 :
66- # If there's only one branch, then we just return current.
67- # Note that this case is a different case when an empty branch is omitted (e.g. `if` without `else`).
68- return self . defined_by_branch [0 ]
86+ def done (self ) -> BranchState :
87+ branches = [ b for b in self .branches if not b . skipped ]
88+ if len (branches ) == 0 :
89+ return BranchState ( skipped = True )
90+ if len ( branches ) == 1 :
91+ return branches [0 ]
6992
7093 # must_be_defined is a union of must_be_defined of all branches.
71- must_be_defined = set (self . defined_by_branch [0 ].must_be_defined )
72- for branch_vars in self . defined_by_branch [1 :]:
73- must_be_defined .intersection_update (branch_vars .must_be_defined )
94+ must_be_defined = set (branches [0 ].must_be_defined )
95+ for b in branches [1 :]:
96+ must_be_defined .intersection_update (b .must_be_defined )
7497 # may_be_defined are all variables that are not must be defined.
7598 all_vars = set ()
76- for branch_vars in self . defined_by_branch :
77- all_vars .update (branch_vars .may_be_defined )
78- all_vars .update (branch_vars .must_be_defined )
99+ for b in branches :
100+ all_vars .update (b .may_be_defined )
101+ all_vars .update (b .must_be_defined )
79102 may_be_defined = all_vars .difference (must_be_defined )
80- return DefinedVars (may_be_defined = may_be_defined , must_be_defined = must_be_defined )
103+ return BranchState (may_be_defined = may_be_defined , must_be_defined = must_be_defined )
81104
82105
83106class DefinedVariableTracker :
84107 """DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor."""
85108
86109 def __init__ (self ) -> None :
87110 # There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
88- self .scopes : list [list [BranchStatement ]] = [
89- [BranchStatement (DefinedVars (may_be_defined = set (), must_be_defined = set ()))]
90- ]
111+ self .scopes : list [list [BranchStatement ]] = [[BranchStatement (BranchState ())]]
91112
92113 def _scope (self ) -> list [BranchStatement ]:
93114 assert len (self .scopes ) > 0
94115 return self .scopes [- 1 ]
95116
96117 def enter_scope (self ) -> None :
97118 assert len (self ._scope ()) > 0
98- self .scopes .append ([BranchStatement (self ._scope ()[- 1 ].defined_by_branch [- 1 ])])
119+ self .scopes .append ([BranchStatement (self ._scope ()[- 1 ].branches [- 1 ])])
99120
100121 def exit_scope (self ) -> None :
101122 self .scopes .pop ()
102123
103124 def start_branch_statement (self ) -> None :
104125 assert len (self ._scope ()) > 0
105- self ._scope ().append (BranchStatement (self ._scope ()[- 1 ].defined_by_branch [- 1 ]))
126+ self ._scope ().append (BranchStatement (self ._scope ()[- 1 ].branches [- 1 ]))
106127
107128 def next_branch (self ) -> None :
108129 assert len (self ._scope ()) > 1
@@ -113,6 +134,11 @@ def end_branch_statement(self) -> None:
113134 result = self ._scope ().pop ().done ()
114135 self ._scope ()[- 1 ].record_nested_branch (result )
115136
137+ def skip_branch (self ) -> None :
138+ # Only skip branch if we're outside of "root" branch statement.
139+ if len (self ._scope ()) > 1 :
140+ self ._scope ()[- 1 ].skip_branch ()
141+
116142 def record_declaration (self , name : str ) -> None :
117143 assert len (self .scopes ) > 0
118144 assert len (self .scopes [- 1 ]) > 0
@@ -125,7 +151,7 @@ def is_possibly_undefined(self, name: str) -> bool:
125151 return self ._scope ()[- 1 ].is_possibly_undefined (name )
126152
127153
128- class PartiallyDefinedVariableVisitor (TraverserVisitor ):
154+ class PartiallyDefinedVariableVisitor (ExtendedTraverserVisitor ):
129155 """Detect variables that are defined only part of the time.
130156
131157 This visitor detects the following case:
@@ -137,8 +163,9 @@ class PartiallyDefinedVariableVisitor(TraverserVisitor):
137163 handled by the semantic analyzer.
138164 """
139165
140- def __init__ (self , msg : MessageBuilder ) -> None :
166+ def __init__ (self , msg : MessageBuilder , type_map : dict [ Expression , Type ] ) -> None :
141167 self .msg = msg
168+ self .type_map = type_map
142169 self .tracker = DefinedVariableTracker ()
143170
144171 def process_lvalue (self , lvalue : Lvalue ) -> None :
@@ -175,6 +202,13 @@ def visit_func(self, o: FuncItem) -> None:
175202 self .tracker .record_declaration (arg .variable .name )
176203 super ().visit_func (o )
177204
205+ def visit_generator_expr (self , o : GeneratorExpr ) -> None :
206+ self .tracker .enter_scope ()
207+ for idx in o .indices :
208+ self .process_lvalue (idx )
209+ super ().visit_generator_expr (o )
210+ self .tracker .exit_scope ()
211+
178212 def visit_for_stmt (self , o : ForStmt ) -> None :
179213 o .expr .accept (self )
180214 self .process_lvalue (o .index )
@@ -186,13 +220,40 @@ def visit_for_stmt(self, o: ForStmt) -> None:
186220 o .else_body .accept (self )
187221 self .tracker .end_branch_statement ()
188222
223+ def visit_return_stmt (self , o : ReturnStmt ) -> None :
224+ super ().visit_return_stmt (o )
225+ self .tracker .skip_branch ()
226+
227+ def visit_assert_stmt (self , o : AssertStmt ) -> None :
228+ super ().visit_assert_stmt (o )
229+ if checker .is_false_literal (o .expr ):
230+ self .tracker .skip_branch ()
231+
232+ def visit_raise_stmt (self , o : RaiseStmt ) -> None :
233+ super ().visit_raise_stmt (o )
234+ self .tracker .skip_branch ()
235+
236+ def visit_continue_stmt (self , o : ContinueStmt ) -> None :
237+ super ().visit_continue_stmt (o )
238+ self .tracker .skip_branch ()
239+
240+ def visit_break_stmt (self , o : BreakStmt ) -> None :
241+ super ().visit_break_stmt (o )
242+ self .tracker .skip_branch ()
243+
244+ def visit_expression_stmt (self , o : ExpressionStmt ) -> None :
245+ if isinstance (self .type_map .get (o .expr , None ), UninhabitedType ):
246+ self .tracker .skip_branch ()
247+ super ().visit_expression_stmt (o )
248+
189249 def visit_while_stmt (self , o : WhileStmt ) -> None :
190250 o .expr .accept (self )
191251 self .tracker .start_branch_statement ()
192252 o .body .accept (self )
193- self .tracker .next_branch ()
194- if o .else_body :
195- o .else_body .accept (self )
253+ if not checker .is_true_literal (o .expr ):
254+ self .tracker .next_branch ()
255+ if o .else_body :
256+ o .else_body .accept (self )
196257 self .tracker .end_branch_statement ()
197258
198259 def visit_name_expr (self , o : NameExpr ) -> None :
0 commit comments