11"""Code generation for native function bodies."""
22
3- from typing import Union , Optional
3+ from typing import List , Union , Optional
44from typing_extensions import Final
55
66from mypyc .common import (
77 REG_PREFIX , NATIVE_PREFIX , STATIC_PREFIX , TYPE_PREFIX , MODULE_PREFIX ,
88)
99from mypyc .codegen .emit import Emitter
1010from mypyc .ir .ops import (
11- OpVisitor , Goto , Branch , Return , Assign , Integer , LoadErrorValue , GetAttr , SetAttr ,
11+ Op , OpVisitor , Goto , Branch , Return , Assign , Integer , LoadErrorValue , GetAttr , SetAttr ,
1212 LoadStatic , InitStatic , TupleGet , TupleSet , Call , IncRef , DecRef , Box , Cast , Unbox ,
1313 BasicBlock , Value , MethodCall , Unreachable , NAMESPACE_STATIC , NAMESPACE_TYPE , NAMESPACE_MODULE ,
1414 RaiseStandardError , CallC , LoadGlobal , Truncate , IntOp , LoadMem , GetElementPtr ,
@@ -88,8 +88,13 @@ def generate_native_function(fn: FuncIR,
8888 next_block = blocks [i + 1 ]
8989 body .emit_label (block )
9090 visitor .next_block = next_block
91- for op in block .ops :
92- op .accept (visitor )
91+
92+ ops = block .ops
93+ visitor .ops = ops
94+ visitor .op_index = 0
95+ while visitor .op_index < len (ops ):
96+ ops [visitor .op_index ].accept (visitor )
97+ visitor .op_index += 1
9398
9499 body .emit_line ('}' )
95100
@@ -110,7 +115,12 @@ def __init__(self,
110115 self .module_name = module_name
111116 self .literals = emitter .context .literals
112117 self .rare = False
118+ # Next basic block to be processed after the current one (if any), set by caller
113119 self .next_block : Optional [BasicBlock ] = None
120+ # Ops in the basic block currently being processed, set by caller
121+ self .ops : List [Op ] = []
122+ # Current index within ops; visit methods can increment this to skip/merge ops
123+ self .op_index = 0
114124
115125 def temp_name (self ) -> str :
116126 return self .emitter .temp_name ()
@@ -293,16 +303,44 @@ def visit_get_attr(self, op: GetAttr) -> None:
293303 attr_expr = self .get_attr_expr (obj , op , decl_cl )
294304 self .emitter .emit_line ('{} = {};' .format (dest , attr_expr ))
295305 self .emitter .emit_undefined_attr_check (
296- attr_rtype , attr_expr , '==' , unlikely = True
306+ attr_rtype , dest , '==' , unlikely = True
297307 )
298308 exc_class = 'PyExc_AttributeError'
299- self .emitter .emit_line (
300- 'PyErr_SetString({}, "attribute {} of {} undefined");' .format (
301- exc_class , repr (op .attr ), repr (cl .name )))
309+ merged_branch = None
310+ branch = self .next_branch ()
311+ if branch is not None :
312+ if (branch .value is op
313+ and branch .op == Branch .IS_ERROR
314+ and branch .traceback_entry is not None
315+ and not branch .negated ):
316+ # Generate code for the following branch here to avoid
317+ # redundant branches in the generate code.
318+ self .emit_attribute_error (branch , cl .name , op .attr )
319+ self .emit_line ('goto %s;' % self .label (branch .true ))
320+ merged_branch = branch
321+ self .emitter .emit_line ('}' )
322+ if not merged_branch :
323+ self .emitter .emit_line (
324+ 'PyErr_SetString({}, "attribute {} of {} undefined");' .format (
325+ exc_class , repr (op .attr ), repr (cl .name )))
326+
302327 if attr_rtype .is_refcounted :
303- self .emitter .emit_line ('} else {' )
304- self .emitter .emit_inc_ref (attr_expr , attr_rtype )
305- self .emitter .emit_line ('}' )
328+ if not merged_branch :
329+ self .emitter .emit_line ('} else {' )
330+ self .emitter .emit_inc_ref (dest , attr_rtype )
331+ if merged_branch :
332+ if merged_branch .false is not self .next_block :
333+ self .emit_line ('goto %s;' % self .label (merged_branch .false ))
334+ self .op_index += 1
335+ else :
336+ self .emitter .emit_line ('}' )
337+
338+ def next_branch (self ) -> Optional [Branch ]:
339+ if self .op_index + 1 < len (self .ops ):
340+ next_op = self .ops [self .op_index + 1 ]
341+ if isinstance (next_op , Branch ):
342+ return next_op
343+ return None
306344
307345 def visit_set_attr (self , op : SetAttr ) -> None :
308346 dest = self .reg (op )
@@ -603,6 +641,19 @@ def emit_traceback(self, op: Branch) -> None:
603641 if DEBUG_ERRORS :
604642 self .emit_line ('assert(PyErr_Occurred() != NULL && "failure w/o err!");' )
605643
644+ def emit_attribute_error (self , op : Branch , class_name : str , attr : str ) -> None :
645+ assert op .traceback_entry is not None
646+ globals_static = self .emitter .static_name ('globals' , self .module_name )
647+ self .emit_line ('CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);' % (
648+ self .source_path .replace ("\\ " , "\\ \\ " ),
649+ op .traceback_entry [0 ],
650+ class_name ,
651+ attr ,
652+ op .traceback_entry [1 ],
653+ globals_static ))
654+ if DEBUG_ERRORS :
655+ self .emit_line ('assert(PyErr_Occurred() != NULL && "failure w/o err!");' )
656+
606657 def emit_signed_int_cast (self , type : RType ) -> str :
607658 if is_tagged (type ):
608659 return '(Py_ssize_t)'
0 commit comments