11"""Helpers for introspecting and wrapping annotations."""
22
33import ast
4+ import builtins
45import enum
56import functools
7+ import keyword
68import sys
79import types
810
@@ -154,8 +156,19 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
154156 globals [param_name ] = param
155157 locals .pop (param_name , None )
156158
157- code = self .__forward_code__
158- value = eval (code , globals = globals , locals = locals )
159+ arg = self .__forward_arg__
160+ if arg .isidentifier () and not keyword .iskeyword (arg ):
161+ if arg in locals :
162+ value = locals [arg ]
163+ elif arg in globals :
164+ value = globals [arg ]
165+ elif hasattr (builtins , arg ):
166+ return getattr (builtins , arg )
167+ else :
168+ raise NameError (arg )
169+ else :
170+ code = self .__forward_code__
171+ value = eval (code , globals = globals , locals = locals )
159172 self .__forward_evaluated__ = True
160173 self .__forward_value__ = value
161174 return value
@@ -254,7 +267,9 @@ class _Stringifier:
254267 __slots__ = _SLOTS
255268
256269 def __init__ (self , node , globals = None , owner = None , is_class = False , cell = None ):
257- assert isinstance (node , ast .AST )
270+ # Either an AST node or a simple str (for the common case where a ForwardRef
271+ # represent a single name).
272+ assert isinstance (node , (ast .AST , str ))
258273 self .__arg__ = None
259274 self .__forward_evaluated__ = False
260275 self .__forward_value__ = None
@@ -267,18 +282,26 @@ def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
267282 self .__cell__ = cell
268283 self .__owner__ = owner
269284
270- def __convert (self , other ):
285+ def __convert_to_ast (self , other ):
271286 if isinstance (other , _Stringifier ):
287+ if isinstance (other .__ast_node__ , str ):
288+ return ast .Name (id = other .__ast_node__ )
272289 return other .__ast_node__
273290 elif isinstance (other , slice ):
274291 return ast .Slice (
275- lower = self .__convert (other .start ) if other .start is not None else None ,
276- upper = self .__convert (other .stop ) if other .stop is not None else None ,
277- step = self .__convert (other .step ) if other .step is not None else None ,
292+ lower = self .__convert_to_ast (other .start ) if other .start is not None else None ,
293+ upper = self .__convert_to_ast (other .stop ) if other .stop is not None else None ,
294+ step = self .__convert_to_ast (other .step ) if other .step is not None else None ,
278295 )
279296 else :
280297 return ast .Constant (value = other )
281298
299+ def __get_ast (self ):
300+ node = self .__ast_node__
301+ if isinstance (node , str ):
302+ return ast .Name (id = node )
303+ return node
304+
282305 def __make_new (self , node ):
283306 return _Stringifier (
284307 node , self .__globals__ , self .__owner__ , self .__forward_is_class__
@@ -292,38 +315,37 @@ def __hash__(self):
292315 def __getitem__ (self , other ):
293316 # Special case, to avoid stringifying references to class-scoped variables
294317 # as '__classdict__["x"]'.
295- if (
296- isinstance (self .__ast_node__ , ast .Name )
297- and self .__ast_node__ .id == "__classdict__"
298- ):
318+ if self .__ast_node__ == "__classdict__" :
299319 raise KeyError
300320 if isinstance (other , tuple ):
301- elts = [self .__convert (elt ) for elt in other ]
321+ elts = [self .__convert_to_ast (elt ) for elt in other ]
302322 other = ast .Tuple (elts )
303323 else :
304- other = self .__convert (other )
324+ other = self .__convert_to_ast (other )
305325 assert isinstance (other , ast .AST ), repr (other )
306- return self .__make_new (ast .Subscript (self .__ast_node__ , other ))
326+ return self .__make_new (ast .Subscript (self .__get_ast () , other ))
307327
308328 def __getattr__ (self , attr ):
309- return self .__make_new (ast .Attribute (self .__ast_node__ , attr ))
329+ return self .__make_new (ast .Attribute (self .__get_ast () , attr ))
310330
311331 def __call__ (self , * args , ** kwargs ):
312332 return self .__make_new (
313333 ast .Call (
314- self .__ast_node__ ,
315- [self .__convert (arg ) for arg in args ],
334+ self .__get_ast () ,
335+ [self .__convert_to_ast (arg ) for arg in args ],
316336 [
317- ast .keyword (key , self .__convert (value ))
337+ ast .keyword (key , self .__convert_to_ast (value ))
318338 for key , value in kwargs .items ()
319339 ],
320340 )
321341 )
322342
323343 def __iter__ (self ):
324- yield self .__make_new (ast .Starred (self .__ast_node__ ))
344+ yield self .__make_new (ast .Starred (self .__get_ast () ))
325345
326346 def __repr__ (self ):
347+ if isinstance (self .__ast_node__ , str ):
348+ return self .__ast_node__
327349 return ast .unparse (self .__ast_node__ )
328350
329351 def __format__ (self , format_spec ):
@@ -332,7 +354,7 @@ def __format__(self, format_spec):
332354 def _make_binop (op : ast .AST ):
333355 def binop (self , other ):
334356 return self .__make_new (
335- ast .BinOp (self .__ast_node__ , op , self .__convert (other ))
357+ ast .BinOp (self .__get_ast () , op , self .__convert_to_ast (other ))
336358 )
337359
338360 return binop
@@ -356,7 +378,7 @@ def binop(self, other):
356378 def _make_rbinop (op : ast .AST ):
357379 def rbinop (self , other ):
358380 return self .__make_new (
359- ast .BinOp (self .__convert (other ), op , self .__ast_node__ )
381+ ast .BinOp (self .__convert_to_ast (other ), op , self .__get_ast () )
360382 )
361383
362384 return rbinop
@@ -381,9 +403,9 @@ def _make_compare(op):
381403 def compare (self , other ):
382404 return self .__make_new (
383405 ast .Compare (
384- left = self .__ast_node__ ,
406+ left = self .__get_ast () ,
385407 ops = [op ],
386- comparators = [self .__convert (other )],
408+ comparators = [self .__convert_to_ast (other )],
387409 )
388410 )
389411
@@ -400,7 +422,7 @@ def compare(self, other):
400422
401423 def _make_unary_op (op ):
402424 def unary_op (self ):
403- return self .__make_new (ast .UnaryOp (op , self .__ast_node__ ))
425+ return self .__make_new (ast .UnaryOp (op , self .__get_ast () ))
404426
405427 return unary_op
406428
@@ -422,7 +444,7 @@ def __init__(self, namespace, globals=None, owner=None, is_class=False):
422444
423445 def __missing__ (self , key ):
424446 fwdref = _Stringifier (
425- ast . Name ( id = key ) ,
447+ key ,
426448 globals = self .globals ,
427449 owner = self .owner ,
428450 is_class = self .is_class ,
@@ -480,7 +502,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
480502 name = freevars [i ]
481503 else :
482504 name = "__cell__"
483- fwdref = _Stringifier (ast . Name ( id = name ) )
505+ fwdref = _Stringifier (name )
484506 new_closure .append (types .CellType (fwdref ))
485507 closure = tuple (new_closure )
486508 else :
@@ -532,7 +554,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
532554 else :
533555 name = "__cell__"
534556 fwdref = _Stringifier (
535- ast . Name ( id = name ) ,
557+ name ,
536558 cell = cell ,
537559 owner = owner ,
538560 globals = annotate .__globals__ ,
@@ -555,6 +577,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
555577 result = func (Format .VALUE )
556578 for obj in globals .stringifiers :
557579 obj .__class__ = ForwardRef
580+ if isinstance (obj .__ast_node__ , str ):
581+ obj .__arg__ = obj .__ast_node__
582+ obj .__ast_node__ = None
558583 return result
559584 elif format == Format .VALUE :
560585 # Should be impossible because __annotate__ functions must not raise
0 commit comments