Skip to content

Commit 86651e3

Browse files
committed
[GR-32841] Integrate LoopConditionProfile into LoopNode.
PullRequest: graal/9441
2 parents 0b8e4e3 + 37b1461 commit 86651e3

File tree

10 files changed

+358
-74
lines changed

10 files changed

+358
-74
lines changed

compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/IfNode.java

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@
5454
import org.graalvm.compiler.graph.NodeClass;
5555
import org.graalvm.compiler.graph.NodeSourcePosition;
5656
import org.graalvm.compiler.graph.iterators.NodeIterable;
57-
import org.graalvm.compiler.nodes.spi.Simplifiable;
58-
import org.graalvm.compiler.nodes.spi.SimplifierTool;
5957
import org.graalvm.compiler.nodeinfo.InputType;
6058
import org.graalvm.compiler.nodeinfo.NodeInfo;
6159
import org.graalvm.compiler.nodes.ProfileData.BranchProbabilityData;
@@ -79,6 +77,8 @@
7977
import org.graalvm.compiler.nodes.java.LoadFieldNode;
8078
import org.graalvm.compiler.nodes.spi.LIRLowerable;
8179
import org.graalvm.compiler.nodes.spi.NodeLIRBuilderTool;
80+
import org.graalvm.compiler.nodes.spi.Simplifiable;
81+
import org.graalvm.compiler.nodes.spi.SimplifierTool;
8282
import org.graalvm.compiler.nodes.spi.SwitchFoldable;
8383
import org.graalvm.compiler.nodes.util.GraphUtil;
8484

@@ -1445,9 +1445,6 @@ private boolean removeIntermediateMaterialization(SimplifierTool tool) {
14451445
assert !ends.hasNext();
14461446
assert falseEnds.size() + trueEnds.size() == xs.length;
14471447

1448-
connectEnds(falseEnds, phi, phiValues, oldFalseSuccessor, merge, tool);
1449-
connectEnds(trueEnds, phi, phiValues, oldTrueSuccessor, merge, tool);
1450-
14511448
if (this.getTrueSuccessorProbability() == 0.0) {
14521449
for (AbstractEndNode endNode : trueEnds) {
14531450
propagateZeroProbability(endNode);
@@ -1460,6 +1457,13 @@ private boolean removeIntermediateMaterialization(SimplifierTool tool) {
14601457
}
14611458
}
14621459

1460+
if (this.getProfileData().getProfileSource() == ProfileSource.INJECTED && trueEnds.size() == 1 && falseEnds.size() == 1) {
1461+
propagateInjectedProfile(this.getProfileData(), trueEnds.get(0), falseEnds.get(0));
1462+
}
1463+
1464+
connectEnds(falseEnds, phi, phiValues, oldFalseSuccessor, merge, tool);
1465+
connectEnds(trueEnds, phi, phiValues, oldTrueSuccessor, merge, tool);
1466+
14631467
/*
14641468
* Remove obsolete ends only after processing all ends, otherwise oldTrueSuccessor or
14651469
* oldFalseSuccessor might have been removed if it is a LoopExitNode.
@@ -1516,6 +1520,41 @@ private static void propagateZeroProbability(FixedNode startNode) {
15161520
}
15171521
}
15181522

1523+
private static IfNode predecessorIf(FixedNode end) {
1524+
for (FixedNode node : GraphUtil.predecessorIterable(end)) {
1525+
if (node instanceof IfNode) {
1526+
return (IfNode) node;
1527+
} else if (node instanceof AbstractMergeNode) {
1528+
return null;
1529+
}
1530+
}
1531+
return null;
1532+
}
1533+
1534+
private static void propagateInjectedProfile(BranchProbabilityData profile, EndNode trueEnd, EndNode falseEnd) {
1535+
Node prev = null;
1536+
for (FixedNode node : GraphUtil.predecessorIterable(trueEnd)) {
1537+
if (node instanceof IfNode) {
1538+
IfNode ifNode = (IfNode) node;
1539+
if (!ProfileSource.isTrusted(ifNode.getProfileData().getProfileSource())) {
1540+
if (ifNode == predecessorIf(falseEnd)) {
1541+
if (ifNode.trueSuccessor() == prev) {
1542+
ifNode.setTrueSuccessorProbability(profile);
1543+
} else if (ifNode.falseSuccessor() == prev) {
1544+
ifNode.setTrueSuccessorProbability(profile.negated());
1545+
} else {
1546+
throw new GraalError("Illegal state");
1547+
}
1548+
}
1549+
}
1550+
return;
1551+
} else if (node instanceof AbstractMergeNode) {
1552+
return;
1553+
}
1554+
prev = node;
1555+
}
1556+
}
1557+
15191558
/**
15201559
* Connects a set of ends to a given successor, inserting a merge node if there is more than one
15211560
* end. If {@code ends} is not empty, then {@code successor} is added to {@code tool}'s

compiler/src/org.graalvm.compiler.nodes/src/org/graalvm/compiler/nodes/extended/BranchProbabilityNode.java

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,13 @@
2727
import static org.graalvm.compiler.nodeinfo.NodeCycles.CYCLES_0;
2828
import static org.graalvm.compiler.nodeinfo.NodeSize.SIZE_0;
2929

30-
import jdk.vm.ci.meta.JavaKind;
3130
import org.graalvm.compiler.core.common.calc.CanonicalCondition;
3231
import org.graalvm.compiler.core.common.type.IntegerStamp;
3332
import org.graalvm.compiler.core.common.type.StampFactory;
3433
import org.graalvm.compiler.debug.GraalError;
3534
import org.graalvm.compiler.graph.Node;
3635
import org.graalvm.compiler.graph.NodeClass;
3736
import org.graalvm.compiler.graph.iterators.NodePredicates;
38-
import org.graalvm.compiler.nodes.spi.Canonicalizable;
39-
import org.graalvm.compiler.nodes.spi.CanonicalizerTool;
40-
import org.graalvm.compiler.nodes.spi.Simplifiable;
41-
import org.graalvm.compiler.nodes.spi.SimplifierTool;
4237
import org.graalvm.compiler.nodeinfo.NodeInfo;
4338
import org.graalvm.compiler.nodes.FixedGuardNode;
4439
import org.graalvm.compiler.nodes.IfNode;
@@ -51,8 +46,15 @@
5146
import org.graalvm.compiler.nodes.calc.IntegerEqualsNode;
5247
import org.graalvm.compiler.nodes.calc.NarrowNode;
5348
import org.graalvm.compiler.nodes.calc.ZeroExtendNode;
49+
import org.graalvm.compiler.nodes.spi.Canonicalizable;
50+
import org.graalvm.compiler.nodes.spi.CanonicalizerTool;
5451
import org.graalvm.compiler.nodes.spi.Lowerable;
5552
import org.graalvm.compiler.nodes.spi.LoweringTool;
53+
import org.graalvm.compiler.nodes.spi.Simplifiable;
54+
import org.graalvm.compiler.nodes.spi.SimplifierTool;
55+
import org.graalvm.compiler.nodes.util.GraphUtil;
56+
57+
import jdk.vm.ci.meta.JavaKind;
5658

5759
/**
5860
* Instances of this node class will look for a preceding if node and put the given probability into
@@ -176,7 +178,45 @@ public void simplify(SimplifierTool tool) {
176178
}
177179
replaceAndDelete(currentCondition);
178180
if (tool != null) {
179-
tool.addToWorkList(currentCondition.usages());
181+
// @formatter:off
182+
// Try to eliminate useless Conditional == Constant eagerly, e.g.:
183+
//
184+
// <condition>
185+
// | C(1) C(0)
186+
// | | /
187+
// Conditional
188+
// |
189+
// BranchProbability
190+
// | C(0|1)
191+
// | /
192+
// IntegerEquals
193+
// |
194+
// If
195+
//
196+
// Should be directly simplified to:
197+
//
198+
// <condition>
199+
// |
200+
// If
201+
//
202+
// This allows the If to be simplified immediately after injecting the profile.
203+
// @formatter:on
204+
if (currentCondition instanceof ConditionalNode &&
205+
((ConditionalNode) currentCondition).trueValue().isConstant() && ((ConditionalNode) currentCondition).falseValue().isConstant()) {
206+
for (IntegerEqualsNode eq : currentCondition.usages().filter(IntegerEqualsNode.class).snapshot()) {
207+
if (eq.getY().isConstant() || eq.getX().isConstant()) {
208+
ValueNode canonical = eq.canonical(tool);
209+
if (canonical != eq && canonical != null) {
210+
tool.addToWorkList(eq.usages());
211+
eq.replaceAtUsages(graph().maybeAddOrUnique(canonical));
212+
GraphUtil.killWithUnusedFloatingInputs(eq);
213+
}
214+
}
215+
}
216+
}
217+
if (currentCondition.hasUsages()) {
218+
tool.addToWorkList(currentCondition.usages());
219+
}
180220
}
181221
} else {
182222
if (!isSubstitutionGraph()) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation. Oracle designates this
8+
* particular file as subject to the "Classpath" exception as provided
9+
* by Oracle in the LICENSE file that accompanied this code.
10+
*
11+
* This code is distributed in the hope that it will be useful, but WITHOUT
12+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14+
* version 2 for more details (a copy is included in the LICENSE file that
15+
* accompanied this code).
16+
*
17+
* You should have received a copy of the GNU General Public License version
18+
* 2 along with this work; if not, write to the Free Software Foundation,
19+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20+
*
21+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22+
* or visit www.oracle.com if you need additional information or have any
23+
* questions.
24+
*/
25+
package org.graalvm.compiler.truffle.runtime;
26+
27+
import java.util.Objects;
28+
29+
import com.oracle.truffle.api.CompilerDirectives;
30+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
31+
import com.oracle.truffle.api.frame.VirtualFrame;
32+
import com.oracle.truffle.api.nodes.LoopNode;
33+
import com.oracle.truffle.api.nodes.RepeatingNode;
34+
35+
abstract class AbstractOptimizedLoopNode extends LoopNode {
36+
37+
@Child protected RepeatingNode repeatingNode;
38+
39+
@CompilationFinal private long trueCount; // long for long running loops.
40+
@CompilationFinal private int falseCount;
41+
42+
protected AbstractOptimizedLoopNode(RepeatingNode repeatingNode) {
43+
this.repeatingNode = Objects.requireNonNull(repeatingNode);
44+
}
45+
46+
@Override
47+
public final RepeatingNode getRepeatingNode() {
48+
return repeatingNode;
49+
}
50+
51+
@SuppressWarnings("deprecation")
52+
@Override
53+
public final void executeLoop(VirtualFrame frame) {
54+
execute(frame);
55+
}
56+
57+
protected final void profileCounted(long iterations) {
58+
if (CompilerDirectives.inInterpreter()) {
59+
long trueCountLocal = trueCount + iterations;
60+
if (trueCountLocal >= 0) { // don't write overflow values
61+
trueCount = trueCountLocal;
62+
int falseCountLocal = falseCount;
63+
if (falseCountLocal < Integer.MAX_VALUE) {
64+
falseCount = falseCountLocal + 1;
65+
}
66+
}
67+
}
68+
}
69+
70+
protected final boolean inject(boolean condition) {
71+
if (CompilerDirectives.inCompiledCode()) {
72+
return CompilerDirectives.injectBranchProbability(calculateProbability(trueCount, falseCount), condition);
73+
} else {
74+
return condition;
75+
}
76+
}
77+
78+
private static double calculateProbability(long trueCountLocal, int falseCountLocal) {
79+
if (falseCountLocal == 0 && trueCountLocal == 0) {
80+
// Avoid division by zero and assume default probability for AOT.
81+
return 0.5;
82+
} else {
83+
return (double) trueCountLocal / (double) (trueCountLocal + falseCountLocal);
84+
}
85+
}
86+
}

compiler/src/org.graalvm.compiler.truffle.runtime/src/org/graalvm/compiler/truffle/runtime/GraalCompilerDirectives.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
*/
3232
public class GraalCompilerDirectives {
3333
/**
34-
* Returns a boolean value indicating whether the method is executed in a compiled tier wich can
35-
* be replaced with a higher tier (e.g. a first tier compilation can be replaced with a second
36-
* tier compilation).
34+
* Returns a boolean value indicating whether the method is executed in a compiled tier which
35+
* can be replaced with a higher tier (e.g. a first tier compilation can be replaced with a
36+
* second tier compilation).
3737
*
3838
* {@link PolyglotCompilerOptions#MultiTier}
3939
*

compiler/src/org.graalvm.compiler.truffle.runtime/src/org/graalvm/compiler/truffle/runtime/OptimizedLoopNode.java

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,38 +30,26 @@
3030
import com.oracle.truffle.api.nodes.LoopNode;
3131
import com.oracle.truffle.api.nodes.RepeatingNode;
3232

33-
public final class OptimizedLoopNode extends LoopNode {
34-
35-
@Child private RepeatingNode repeatingNode;
33+
public final class OptimizedLoopNode extends AbstractOptimizedLoopNode {
3634

3735
OptimizedLoopNode(RepeatingNode repeatingNode) {
38-
this.repeatingNode = repeatingNode;
39-
}
40-
41-
@Override
42-
public RepeatingNode getRepeatingNode() {
43-
return repeatingNode;
44-
}
45-
46-
@SuppressWarnings("deprecation")
47-
@Override
48-
public void executeLoop(VirtualFrame frame) {
49-
execute(frame);
36+
super(repeatingNode);
5037
}
5138

5239
@Override
5340
public Object execute(VirtualFrame frame) {
5441
Object status;
5542
long loopCount = 0;
5643
try {
57-
while (repeatingNode.shouldContinue(status = repeatingNode.executeRepeatingWithValue(frame))) {
44+
while (inject(repeatingNode.shouldContinue(status = repeatingNode.executeRepeatingWithValue(frame)))) {
5845
if (CompilerDirectives.inInterpreter() || GraalCompilerDirectives.hasNextTier()) {
5946
loopCount++;
6047
}
6148
TruffleSafepoint.poll(this);
6249
}
6350
return status;
6451
} finally {
52+
profileCounted(loopCount);
6553
reportLoopCount(this, OptimizedOSRLoopNode.toIntOrMaxInt(loopCount));
6654
}
6755
}

0 commit comments

Comments
 (0)