diff --git a/goto.py b/goto.py index 3c67459..cc6d746 100644 --- a/goto.py +++ b/goto.py @@ -86,6 +86,28 @@ def _parse_instructions(code): extended_arg_offset = None yield (dis.opname[opcode], oparg, offset) +def _get_instruction_size(opname, oparg=0): + size = 1 + + extended_arg = oparg >> _BYTECODE.argument_bits + if extended_arg != 0: + size += _get_instruction_size('EXTENDED_ARG', extended_arg) + oparg &= (1 << _BYTECODE.argument_bits) - 1 + + opcode = dis.opmap[opname] + if opcode >= _BYTECODE.have_argument: + size += _BYTECODE.argument.size + + return size + +def _get_instructions_size(ops): + size = 0 + for op in ops: + if isinstance(op, str): + size += _get_instruction_size(op) + else: + size += _get_instruction_size(*op) + return size def _write_instruction(buf, pos, opname, oparg=0): extended_arg = oparg >> _BYTECODE.argument_bits @@ -103,6 +125,13 @@ def _write_instruction(buf, pos, opname, oparg=0): return pos +def _write_instructions(buf, pos, ops): + for op in ops: + if isinstance(op, str): + pos = _write_instruction(buf, pos, op) + else: + pos = _write_instruction(buf, pos, *op) + return pos def _find_labels_and_gotos(code): labels = {} @@ -165,36 +194,29 @@ def _patch_code(code): if origin_stack[:target_depth] != target_stack: raise SyntaxError('Jump into different block') - failed = False - try: - for i in range(len(origin_stack) - target_depth): - pos = _write_instruction(buf, pos, 'POP_BLOCK') - - if target >= end: - rel_target = (target - pos) // _BYTECODE.jump_unit - oparg_bits = 0 - - while True: - rel_target -= (1 + _BYTECODE.argument.size) // _BYTECODE.jump_unit - if rel_target >> oparg_bits == 0: - pos = _write_instruction(buf, pos, 'EXTENDED_ARG', 0) - break + ops = [] + for i in range(len(origin_stack) - target_depth): + ops.append('POP_BLOCK') + ops.append(('JUMP_ABSOLUTE', target // _BYTECODE.jump_unit)) - oparg_bits += _BYTECODE.argument_bits - if rel_target >> oparg_bits == 0: - break + if pos + _get_instructions_size(ops) > end: + # not enough space, add code at buffer end and jump there + buf_end = len(buf) - pos = _write_instruction(buf, pos, 'JUMP_FORWARD', rel_target) - else: - pos = _write_instruction(buf, pos, 'JUMP_ABSOLUTE', target // _BYTECODE.jump_unit) + go_to_end_ops = [('JUMP_ABSOLUTE', buf_end // _BYTECODE.jump_unit)] - except (IndexError, struct.error): - failed = True + if pos + _get_instructions_size(go_to_end_ops) > end: + # not sure if reachable + raise SyntaxError('Goto in an incredibly huge function') - if failed or pos > end: - raise SyntaxError('Jump out of too many nested blocks') + pos = _write_instructions(buf, pos, go_to_end_ops) + _inject_nop_sled(buf, pos, end) - _inject_nop_sled(buf, pos, end) + buf.extend([0] * _get_instructions_size(ops)) + _write_instructions(buf, buf_end, ops) + else: + pos = _write_instructions(buf, pos, ops) + _inject_nop_sled(buf, pos, end) return _make_code(code, _array_to_bytes(buf)) diff --git a/test_goto.py b/test_goto.py index 7ff1dc1..6b12f0c 100644 --- a/test_goto.py +++ b/test_goto.py @@ -71,63 +71,55 @@ def func(): pytest.raises(SyntaxError, with_goto, func) -if sys.version_info >= (3, 6): - def test_jump_out_of_nested_2_loops(): - @with_goto - def func(): - x = 1 - for i in range(2): - for j in range(2): - # These are more than 256 bytes of bytecode, requiring - # a JUMP_FORWARD below on Python 3.6+, since the absolute - # address would be too large, after leaving two blocks. - x += x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x - x += x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x - x += x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x - - goto .end - label .end - return (i, j) - - assert func() == (0, 0) - - def test_jump_out_of_nested_3_loops(): - def func(): - for i in range(2): - for j in range(2): - for k in range(2): - goto .end - label .end - return (i, j, k) - - pytest.raises(SyntaxError, with_goto, func) -else: - def test_jump_out_of_nested_4_loops(): - @with_goto - def func(): - for i in range(2): - for j in range(2): - for k in range(2): - for m in range(2): - goto .end - label .end - return (i, j, k, m) - - assert func() == (0, 0, 0, 0) - - def test_jump_out_of_nested_5_loops(): - def func(): - for i in range(2): - for j in range(2): - for k in range(2): - for m in range(2): - for n in range(2): - goto .end - label .end - return (i, j, k, m, n) - - pytest.raises(SyntaxError, with_goto, func) +def test_jump_out_of_nested_2_loops(): + @with_goto + def func(): + x = 1 + for i in range(2): + for j in range(2): + # These are more than 256 bytes of bytecode + x += x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x + x += x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x + x += x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x+x + + goto .end + label .end + return (i, j) + + assert func() == (0, 0) + +def test_jump_out_of_nested_11_loops(): + @with_goto + def func(): + x = 1 + for i1 in range(2): + for i2 in range(2): + for i3 in range(2): + for i4 in range(2): + for i5 in range(2): + for i6 in range(2): + for i7 in range(2): + for i8 in range(2): + for i9 in range(2): + for i10 in range(2): + for i11 in range(2): + # These are more than + # 256 bytes of bytecode + x += (x+x+x+x+x+x+x+x+x+ + x+x+x+x+x+x+x+x+x+ + x+x+x+x+x+x+x+x+x) + x += (x+x+x+x+x+x+x+x+x+ + x+x+x+x+x+x+x+x+x+ + x+x+x+x+x+x+x+x+x) + x += (x+x+x+x+x+x+x+x+x+ + x+x+x+x+x+x+x+x+x+ + x+x+x+x+x+x+x+x+x) + + goto .end + label .end + return (i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) + assert func() == (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) def test_jump_across_loops(): def func():