Skip to content

Commit 191ff29

Browse files
committed
[GR-48876] Fix Wasm OSR compilation.
PullRequest: graal/15760
2 parents 15b7312 + 0951269 commit 191ff29

File tree

6 files changed

+164
-53
lines changed

6 files changed

+164
-53
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* The Universal Permissive License (UPL), Version 1.0
6+
*
7+
* Subject to the condition set forth below, permission is hereby granted to any
8+
* person obtaining a copy of this software, associated documentation and/or
9+
* data (collectively the "Software"), free of charge and under any and all
10+
* copyright rights in the Software, and any and all patent rights owned or
11+
* freely licensable by each licensor hereunder covering either (i) the
12+
* unmodified Software as contributed to or provided by such licensor, or (ii)
13+
* the Larger Works (as defined below), to deal in both
14+
*
15+
* (a) the Software, and
16+
*
17+
* (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
18+
* one is included with the Software each a "Larger Work" to which the Software
19+
* is contributed by such licensors),
20+
*
21+
* without restriction, including without limitation the rights to copy, create
22+
* derivative works of, display, perform, and distribute the Software and make,
23+
* use, sell, offer for sale, import, export, have made, and have sold the
24+
* Software and the Larger Work(s), and to sublicense the foregoing rights on
25+
* either these or other terms.
26+
*
27+
* This license is subject to the following condition:
28+
*
29+
* The above copyright notice and either this complete permission notice or at a
30+
* minimum a reference to the UPL must be included in all copies or substantial
31+
* portions of the Software.
32+
*
33+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
34+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
35+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
36+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
37+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
38+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
39+
* SOFTWARE.
40+
*/
41+
package org.graalvm.wasm.test.suites.bytecode;
42+
43+
import static org.graalvm.wasm.utils.WasmBinaryTools.compileWat;
44+
45+
import java.io.IOException;
46+
47+
import org.graalvm.polyglot.Context;
48+
import org.graalvm.polyglot.Engine;
49+
import org.graalvm.polyglot.Source;
50+
import org.graalvm.polyglot.Value;
51+
import org.graalvm.polyglot.io.ByteSequence;
52+
import org.graalvm.wasm.WasmLanguage;
53+
import org.junit.Assert;
54+
import org.junit.Test;
55+
56+
public class WasmOSRSuite {
57+
private static final int N_CONTEXTS = 2;
58+
59+
@Test
60+
public void testOSR() throws IOException, InterruptedException {
61+
final ByteSequence binaryMain = ByteSequence.create(compileWat("main", """
62+
(module
63+
(type (;0;) (func (result i32)))
64+
(import "wasi_snapshot_preview1" "sched_yield" (func $__wasi_sched_yield (type 0)))
65+
(memory (;0;) 4)
66+
(export "memory" (memory 0))
67+
(func (export "_main") (type 0)
68+
(local $i i32)
69+
i32.const 1000
70+
local.set $i
71+
block
72+
loop
73+
local.get $i
74+
i32.const 1
75+
i32.sub
76+
local.tee $i
77+
call $__wasi_sched_yield
78+
drop
79+
i32.eqz
80+
br_if 1
81+
br 0
82+
end
83+
end
84+
i32.const 0
85+
)
86+
)
87+
"""));
88+
final Source sourceMain = Source.newBuilder(WasmLanguage.ID, binaryMain, "main").build();
89+
var eb = Engine.newBuilder().allowExperimentalOptions(true);
90+
eb.option("wasm.Builtins", "wasi_snapshot_preview1");
91+
eb.option("engine.OSRCompilationThreshold", "100");
92+
eb.option("engine.BackgroundCompilation", "false");
93+
try (Engine engine = eb.build()) {
94+
for (int i = 0; i < N_CONTEXTS; i++) {
95+
try (Context context = Context.newBuilder(WasmLanguage.ID).engine(engine).build()) {
96+
Value mainMod = context.eval(sourceMain);
97+
Value mainFun = mainMod.getMember("_main");
98+
Assert.assertEquals(0, mainFun.execute().asInt());
99+
}
100+
}
101+
}
102+
}
103+
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContext.java

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -222,29 +222,4 @@ public WasmContextOptions getContextOptions() {
222222
public static WasmContext get(Node node) {
223223
return REFERENCE.get(node);
224224
}
225-
226-
/**
227-
* @return The current primitive multi-value stack or null if it has never been resized.
228-
*/
229-
public long[] primitiveMultiValueStack() {
230-
return language.multiValueStack().primitiveStack();
231-
}
232-
233-
/**
234-
* @return the current reference multi-value stack or null if it has never been resized.
235-
*/
236-
public Object[] referenceMultiValueStack() {
237-
return language.multiValueStack().referenceStack();
238-
}
239-
240-
/**
241-
* Updates the size of the multi-value stack if needed. In case of a resize, the values are not
242-
* copied. Therefore, resizing should occur before any call to a function that uses the
243-
* multi-value stack.
244-
*
245-
* @param expectedSize The minimum expected size.
246-
*/
247-
public void resizeMultiValueStack(int expectedSize) {
248-
language.multiValueStack().resize(expectedSize);
249-
}
250225
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunctionInstance.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -143,28 +143,29 @@ Object execute(Object[] arguments,
143143
// At this point the multi-value stack has already been populated, therefore, we don't
144144
// have to check the size of the multi-value stack.
145145
if (result == WasmConstant.MULTI_VALUE) {
146-
return multiValueStackAsArray();
146+
return multiValueStackAsArray(WasmLanguage.get(self));
147147
}
148148
return result;
149149
} finally {
150150
c.leave(self, prev);
151151
}
152152
}
153153

154-
private Object multiValueStackAsArray() {
155-
final long[] multiValueStack = context().primitiveMultiValueStack();
156-
final Object[] referenceMultiValueStack = context().referenceMultiValueStack();
154+
private Object multiValueStackAsArray(WasmLanguage language) {
155+
final var multiValueStack = language.multiValueStack();
156+
final long[] primitiveMultiValueStack = multiValueStack.primitiveStack();
157+
final Object[] referenceMultiValueStack = multiValueStack.referenceStack();
157158
final int resultCount = function.resultCount();
158-
assert multiValueStack.length >= resultCount;
159+
assert primitiveMultiValueStack.length >= resultCount;
159160
assert referenceMultiValueStack.length >= resultCount;
160161
final Object[] values = new Object[resultCount];
161162
for (int i = 0; i < resultCount; i++) {
162163
byte resultType = function.resultTypeAt(i);
163164
values[i] = switch (resultType) {
164-
case WasmType.I32_TYPE -> (int) multiValueStack[i];
165-
case WasmType.I64_TYPE -> multiValueStack[i];
166-
case WasmType.F32_TYPE -> Float.intBitsToFloat((int) multiValueStack[i]);
167-
case WasmType.F64_TYPE -> Double.longBitsToDouble(multiValueStack[i]);
165+
case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i];
166+
case WasmType.I64_TYPE -> primitiveMultiValueStack[i];
167+
case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]);
168+
case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]);
168169
case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> referenceMultiValueStack[i];
169170
default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL);
170171
};

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,20 +211,33 @@ public MultiValueStack multiValueStack() {
211211
return multiValueStackThreadLocal.get();
212212
}
213213

214-
static final class MultiValueStack {
214+
public static final class MultiValueStack {
215215
private long[] primitiveStack;
216216
private Object[] referenceStack;
217217
// Initialize size to 1, so we only create the stack for more than 1 result value.
218218
private int size = 1;
219219

220+
/**
221+
* @return The current primitive multi-value stack or null if it has never been resized.
222+
*/
220223
public long[] primitiveStack() {
221224
return primitiveStack;
222225
}
223226

227+
/**
228+
* @return the current reference multi-value stack or null if it has never been resized.
229+
*/
224230
public Object[] referenceStack() {
225231
return referenceStack;
226232
}
227233

234+
/**
235+
* Updates the size of the multi-value stack if needed. In case of a resize, the values are
236+
* not copied. Therefore, resizing should occur before any call to a function that uses the
237+
* multi-value stack.
238+
*
239+
* @param expectedSize The minimum expected size.
240+
*/
228241
public void resize(int expectedSize) {
229242
if (expectedSize > size) {
230243
primitiveStack = new long[expectedSize];

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
import org.graalvm.wasm.WasmFunction;
7171
import org.graalvm.wasm.WasmFunctionInstance;
7272
import org.graalvm.wasm.WasmInstance;
73+
import org.graalvm.wasm.WasmLanguage;
7374
import org.graalvm.wasm.WasmMath;
7475
import org.graalvm.wasm.WasmModule;
7576
import org.graalvm.wasm.WasmTable;
@@ -88,6 +89,7 @@
8889
import com.oracle.truffle.api.ExactMath;
8990
import com.oracle.truffle.api.HostCompilerDirectives.BytecodeInterpreterSwitch;
9091
import com.oracle.truffle.api.TruffleContext;
92+
import com.oracle.truffle.api.frame.Frame;
9193
import com.oracle.truffle.api.frame.VirtualFrame;
9294
import com.oracle.truffle.api.interop.InteropLibrary;
9395
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
@@ -213,6 +215,21 @@ public void setOSRMetadata(Object osrMetadata) {
213215
this.osrMetadata = osrMetadata;
214216
}
215217

218+
/** Preserve the first argument, i.e. the {@link WasmInstance}. */
219+
@Override
220+
public Object[] storeParentFrameInArguments(VirtualFrame parentFrame) {
221+
CompilerAsserts.neverPartOfCompilation();
222+
WasmInstance instance = ((WasmRootNode) getRootNode()).instance(parentFrame);
223+
Object[] osrFrameArgs = new Object[]{instance, parentFrame};
224+
assert WasmArguments.isValid(osrFrameArgs);
225+
return osrFrameArgs;
226+
}
227+
228+
@Override
229+
public Frame restoreParentFrameFromArguments(Object[] arguments) {
230+
return (Frame) arguments[1];
231+
}
232+
216233
// endregion OSR support
217234

218235
/**
@@ -530,7 +547,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance,
530547
}
531548
break;
532549
} else {
533-
extractMultiValueResult(context, frame, stackPointer, result, resultCount, function.typeIndex());
550+
extractMultiValueResult(frame, stackPointer, result, resultCount, function.typeIndex());
534551
stackPointer += resultCount;
535552
break;
536553
}
@@ -675,7 +692,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance,
675692
}
676693
break;
677694
} else {
678-
extractMultiValueResult(context, frame, stackPointer, result, resultCount, expectedFunctionTypeIndex);
695+
extractMultiValueResult(frame, stackPointer, result, resultCount, expectedFunctionTypeIndex);
679696
stackPointer += resultCount;
680697
break;
681698
}
@@ -4069,26 +4086,27 @@ private static boolean profileBranchTable(byte[] data, final int counterOffset,
40694086
* @param functionTypeIndex The function type index of the called function.
40704087
*/
40714088
@ExplodeLoop
4072-
private void extractMultiValueResult(WasmContext context, VirtualFrame frame, int stackPointer, Object result, int resultCount, int functionTypeIndex) {
4089+
private void extractMultiValueResult(VirtualFrame frame, int stackPointer, Object result, int resultCount, int functionTypeIndex) {
40734090
CompilerAsserts.partialEvaluationConstant(resultCount);
40744091
if (result == WasmConstant.MULTI_VALUE) {
4075-
final long[] multiValueStack = context.primitiveMultiValueStack();
4076-
final Object[] referenceMultiValueStack = context.referenceMultiValueStack();
4092+
final var multiValueStack = WasmLanguage.get(this).multiValueStack();
4093+
final long[] primitiveMultiValueStack = multiValueStack.primitiveStack();
4094+
final Object[] referenceMultiValueStack = multiValueStack.referenceStack();
40774095
for (int i = 0; i < resultCount; i++) {
40784096
final byte resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i);
40794097
CompilerAsserts.partialEvaluationConstant(resultType);
40804098
switch (resultType) {
40814099
case WasmType.I32_TYPE:
4082-
pushInt(frame, stackPointer + i, (int) multiValueStack[i]);
4100+
pushInt(frame, stackPointer + i, (int) primitiveMultiValueStack[i]);
40834101
break;
40844102
case WasmType.I64_TYPE:
4085-
pushLong(frame, stackPointer + i, multiValueStack[i]);
4103+
pushLong(frame, stackPointer + i, primitiveMultiValueStack[i]);
40864104
break;
40874105
case WasmType.F32_TYPE:
4088-
pushFloat(frame, stackPointer + i, Float.intBitsToFloat((int) multiValueStack[i]));
4106+
pushFloat(frame, stackPointer + i, Float.intBitsToFloat((int) primitiveMultiValueStack[i]));
40894107
break;
40904108
case WasmType.F64_TYPE:
4091-
pushDouble(frame, stackPointer + i, Double.longBitsToDouble(multiValueStack[i]));
4109+
pushDouble(frame, stackPointer + i, Double.longBitsToDouble(primitiveMultiValueStack[i]));
40924110
break;
40934111
case WasmType.FUNCREF_TYPE:
40944112
case WasmType.EXTERNREF_TYPE:

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ public Object executeWithContext(VirtualFrame frame, WasmContext context, WasmIn
167167
final int resultCount = functionNode.resultCount();
168168
CompilerAsserts.partialEvaluationConstant(resultCount);
169169
if (resultCount > 1) {
170-
context.resizeMultiValueStack(resultCount);
170+
WasmLanguage.get(this).multiValueStack().resize(resultCount);
171171
}
172172

173173
try {
@@ -199,31 +199,32 @@ public Object executeWithContext(VirtualFrame frame, WasmContext context, WasmIn
199199
throw WasmException.format(Failure.UNSPECIFIED_INTERNAL, this, "Unknown result type: %d", resultType);
200200
}
201201
} else {
202-
moveResultValuesToMultiValueStack(frame, context, resultCount, localCount);
202+
moveResultValuesToMultiValueStack(frame, resultCount, localCount);
203203
return WasmConstant.MULTI_VALUE;
204204
}
205205
}
206206

207207
@ExplodeLoop
208-
private void moveResultValuesToMultiValueStack(VirtualFrame frame, WasmContext context, int resultCount, int localCount) {
208+
private void moveResultValuesToMultiValueStack(VirtualFrame frame, int resultCount, int localCount) {
209209
CompilerAsserts.partialEvaluationConstant(resultCount);
210-
final long[] multiValueStack = context.primitiveMultiValueStack();
211-
final Object[] referenceMultiValueStack = context.referenceMultiValueStack();
210+
final var multiValueStack = WasmLanguage.get(this).multiValueStack();
211+
final long[] primitiveMultiValueStack = multiValueStack.primitiveStack();
212+
final Object[] referenceMultiValueStack = multiValueStack.referenceStack();
212213
for (int i = 0; i < resultCount; i++) {
213214
final int resultType = functionNode.resultType(i);
214215
CompilerAsserts.partialEvaluationConstant(resultType);
215216
switch (resultType) {
216217
case WasmType.I32_TYPE:
217-
multiValueStack[i] = popInt(frame, localCount + i);
218+
primitiveMultiValueStack[i] = popInt(frame, localCount + i);
218219
break;
219220
case WasmType.I64_TYPE:
220-
multiValueStack[i] = popLong(frame, localCount + i);
221+
primitiveMultiValueStack[i] = popLong(frame, localCount + i);
221222
break;
222223
case WasmType.F32_TYPE:
223-
multiValueStack[i] = Float.floatToRawIntBits(popFloat(frame, localCount + i));
224+
primitiveMultiValueStack[i] = Float.floatToRawIntBits(popFloat(frame, localCount + i));
224225
break;
225226
case WasmType.F64_TYPE:
226-
multiValueStack[i] = Double.doubleToRawLongBits(popDouble(frame, localCount + i));
227+
primitiveMultiValueStack[i] = Double.doubleToRawLongBits(popDouble(frame, localCount + i));
227228
break;
228229
case WasmType.FUNCREF_TYPE:
229230
case WasmType.EXTERNREF_TYPE:

0 commit comments

Comments
 (0)