1818import sys
1919
2020from typing import Tuple , Union , TypeVar , Callable , Sequence , Optional , Any , cast , List
21- from mypy .sharedparse import special_function_elide_names , argument_elide_name
21+ from mypy .sharedparse import (
22+ special_function_elide_names , argument_elide_name , is_overload_part ,
23+ )
2224from mypy .nodes import (
2325 MypyFile , Node , ImportBase , Import , ImportAll , ImportFrom , FuncDef , OverloadedFuncDef ,
2426 ClassDef , Decorator , Block , Var , OperatorAssignmentStmt ,
@@ -209,19 +211,27 @@ def as_block(self, stmts: List[ast27.stmt], lineno: int) -> Block:
209211
210212 def fix_function_overloads (self , stmts : List [Statement ]) -> List [Statement ]:
211213 ret = [] # type: List[Statement]
212- current_overload = []
214+ current_overload = [] # type: List[Decorator]
213215 current_overload_name = None
214216 # mypy doesn't actually check that the decorator is literally @overload
215217 for stmt in stmts :
216- if isinstance (stmt , Decorator ) and stmt .name () == current_overload_name :
218+ if (isinstance (stmt , Decorator )
219+ and is_overload_part (stmt )
220+ and stmt .name () == current_overload_name ):
217221 current_overload .append (stmt )
222+ elif (isinstance (stmt , FuncDef )
223+ and stmt .name () == current_overload_name
224+ and stmt .name () is not None ):
225+ ret .append (OverloadedFuncDef (current_overload , stmt ))
226+ current_overload = []
227+ current_overload_name = None
218228 else :
219229 if len (current_overload ) == 1 :
220230 ret .append (current_overload [0 ])
221231 elif len (current_overload ) > 1 :
222- ret .append (OverloadedFuncDef (current_overload ))
232+ ret .append (OverloadedFuncDef (current_overload , None ))
223233
224- if isinstance (stmt , Decorator ):
234+ if isinstance (stmt , Decorator ) and is_overload_part ( stmt ) :
225235 current_overload = [stmt ]
226236 current_overload_name = stmt .name ()
227237 else :
@@ -232,7 +242,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
232242 if len (current_overload ) == 1 :
233243 ret .append (current_overload [0 ])
234244 elif len (current_overload ) > 1 :
235- ret .append (OverloadedFuncDef (current_overload ))
245+ ret .append (OverloadedFuncDef (current_overload , None ))
236246 return ret
237247
238248 def in_class (self ) -> bool :
0 commit comments