Skip to content

Commit 0fbef95

Browse files
committed
[GR-61244] Implement direct jump table fast path
PullRequest: graal/19687
2 parents 70eb85b + 6c0a1ab commit 0fbef95

File tree

8 files changed

+348
-94
lines changed

8 files changed

+348
-94
lines changed

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/core/aarch64/AArch64LIRGenerator.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ private static ConditionFlag toFloatConditionFlag(Condition cond, boolean unorde
423423
/**
424424
* Takes a Condition and returns the correct AArch64 specific ConditionFlag.
425425
*/
426-
private static ConditionFlag toIntConditionFlag(Condition cond) {
426+
public static ConditionFlag toIntConditionFlag(Condition cond) {
427427
switch (cond) {
428428
case EQ:
429429
return ConditionFlag.EQ;
@@ -566,8 +566,8 @@ protected StrategySwitchOp createStrategySwitchOp(SwitchStrategy strategy, Label
566566
}
567567

568568
@Override
569-
protected void emitRangeTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] targets, AllocatableValue key) {
570-
append(new RangeTableSwitchOp(lowKey, defaultTarget, targets, key));
569+
protected void emitRangeTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] targets, SwitchStrategy remainingStrategy, LabelRef[] remainingTargets, AllocatableValue key) {
570+
append(new RangeTableSwitchOp(lowKey, defaultTarget, targets, remainingStrategy, remainingTargets, key));
571571
}
572572

573573
@Override

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/core/amd64/AMD64LIRGenerator.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,10 +1116,8 @@ public void emitStrategySwitch(SwitchStrategy strategy, AllocatableValue key, La
11161116
}
11171117

11181118
@Override
1119-
protected void emitRangeTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] targets, AllocatableValue key) {
1120-
Variable scratch = newVariable(LIRKind.value(target().arch.getWordKind()));
1121-
Variable idxScratch = newVariable(key.getValueKind());
1122-
append(new RangeTableSwitchOp(lowKey, defaultTarget, targets, key, scratch, idxScratch));
1119+
protected void emitRangeTableSwitch(int lowKey, LabelRef defaultTarget, LabelRef[] targets, SwitchStrategy remainingStrategy, LabelRef[] remainingTargets, AllocatableValue key) {
1120+
append(new RangeTableSwitchOp(this, lowKey, defaultTarget, targets, remainingStrategy, remainingTargets, key));
11231121
}
11241122

11251123
@Override

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/SwitchStrategy.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2013, 2018, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2013, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -31,6 +31,7 @@
3131
import jdk.graal.compiler.asm.Label;
3232
import jdk.graal.compiler.core.common.calc.Condition;
3333
import jdk.graal.compiler.debug.Assertions;
34+
import jdk.graal.compiler.debug.GraalError;
3435
import jdk.graal.compiler.lir.asm.CompilationResultBuilder;
3536

3637
import jdk.vm.ci.meta.Constant;
@@ -212,7 +213,6 @@ public double getAverageEffort() {
212213
private EffortClosure effortClosure;
213214

214215
public SwitchStrategy(double[] keyProbabilities) {
215-
assert keyProbabilities.length >= 2 : Assertions.errorMessage(keyProbabilities);
216216
this.keyProbabilities = keyProbabilities;
217217
}
218218

@@ -334,8 +334,9 @@ public static class RangesStrategy extends PrimitiveStrategy {
334334

335335
public RangesStrategy(final double[] keyProbabilities, JavaConstant[] keyConstants) {
336336
super(keyProbabilities, keyConstants);
337-
338337
int keyCount = keyConstants.length;
338+
GraalError.guarantee(keyCount > 1, "%s", Arrays.toString(keyConstants));
339+
339340
indexes = new Integer[keyCount];
340341
for (int i = 0; i < keyCount; i++) {
341342
indexes[i] = i;
@@ -398,6 +399,7 @@ public static class BinaryStrategy extends PrimitiveStrategy {
398399

399400
public BinaryStrategy(double[] keyProbabilities, JavaConstant[] keyConstants) {
400401
super(keyProbabilities, keyConstants);
402+
GraalError.guarantee(keyProbabilities.length > 1, "%s", Arrays.toString(keyConstants));
401403
probabilitySums = new double[keyProbabilities.length + 1];
402404
double sum = 0;
403405
for (int i = 0; i < keyConstants.length; i++) {
@@ -501,8 +503,14 @@ private void recurseBinarySwitch(SwitchClosure closure, int left, int right, int
501503
public abstract void run(SwitchClosure closure);
502504

503505
private static SwitchStrategy[] getStrategies(double[] keyProbabilities, JavaConstant[] keyConstants, LabelRef[] keyTargets) {
504-
SwitchStrategy[] strategies = new SwitchStrategy[]{new SequentialStrategy(keyProbabilities, keyConstants), new RangesStrategy(keyProbabilities, keyConstants),
505-
new BinaryStrategy(keyProbabilities, keyConstants)};
506+
SwitchStrategy[] strategies;
507+
if (keyProbabilities.length == 1) {
508+
strategies = new SwitchStrategy[]{new SequentialStrategy(keyProbabilities, keyConstants)};
509+
} else {
510+
strategies = new SwitchStrategy[]{new SequentialStrategy(keyProbabilities, keyConstants), new RangesStrategy(keyProbabilities, keyConstants),
511+
new BinaryStrategy(keyProbabilities, keyConstants)};
512+
}
513+
506514
for (SwitchStrategy strategy : strategies) {
507515
strategy.effortClosure = strategy.new EffortClosure(keyTargets);
508516
strategy.run(strategy.effortClosure);

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/aarch64/AArch64ControlFlow.java

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2013, 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2013, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -44,6 +44,7 @@
4444
import jdk.graal.compiler.asm.aarch64.AArch64MacroAssembler.ScratchRegister;
4545
import jdk.graal.compiler.code.CompilationResult.JumpTable;
4646
import jdk.graal.compiler.code.CompilationResult.JumpTable.EntryFormat;
47+
import jdk.graal.compiler.core.aarch64.AArch64LIRGenerator;
4748
import jdk.graal.compiler.core.common.NumUtil;
4849
import jdk.graal.compiler.core.common.calc.Condition;
4950
import jdk.graal.compiler.debug.Assertions;
@@ -385,7 +386,8 @@ private static void emitCompareHelper(CompilationResultBuilder crb, AArch64Macro
385386

386387
/**
387388
* This operation jumps to the appropriate destination as specified within a JumpTable, or to
388-
* the default condition if there is no match within the JumpTable.
389+
* another destination according to {@code remainingStrategy} if there is no match within the
390+
* JumpTable.
389391
*
390392
* <p>
391393
* The JumpTable contains a series of target offsets, relative to the start of the jump table,
@@ -395,8 +397,12 @@ private static void emitCompareHelper(CompilationResultBuilder crb, AArch64Macro
395397
* <ol>
396398
* <li>Determine whether the index is within the JumpTable. This is accomplished by first
397399
* normalizing the index (normalizedIdx == index - lowKey), and then checking whether
398-
* <code>(unsigned(normalizedIdx) &lt;= highKey - lowKey</code>). If not, then one must jump to
399-
* the defaultTarget.</li>
400+
* <code>(unsigned(normalizedIdx) &lt;= highKey - lowKey</code>).
401+
*
402+
* <li>If normalizedIdx is not in the JumpTable, the destination is decided by {@code
403+
* remainingStrategy}. If {@code remainingStrategy == null}, then the destination is {@code
404+
* defaultTarget}. Otherwise, the destination is {@code defaultTarget} or one of {@code
405+
* remainingTargets} based on the value of {@code key}.</li>
400406
*
401407
* <li>If normalizedIdx is within the JumpTable, then jump to JumpTableStart +
402408
* JumpTable[normalizedIdx].</li>
@@ -407,53 +413,80 @@ public static final class RangeTableSwitchOp extends AArch64BlockEndOp {
407413
private final int lowKey;
408414
private final LabelRef defaultTarget;
409415
private final LabelRef[] targets;
410-
@Use({REG}) protected AllocatableValue index;
416+
private final SwitchStrategy remainingStrategy;
417+
private final LabelRef[] remainingTargets;
418+
@Alive({REG}) protected AllocatableValue key;
411419

412-
public RangeTableSwitchOp(final int lowKey, final LabelRef defaultTarget, final LabelRef[] targets, AllocatableValue index) {
420+
public RangeTableSwitchOp(int lowKey, LabelRef defaultTarget, LabelRef[] targets, SwitchStrategy remainingStrategy, LabelRef[] remainingTargets, AllocatableValue key) {
413421
super(TYPE);
414422
this.lowKey = lowKey;
415423
assert defaultTarget != null;
416424
this.defaultTarget = defaultTarget;
417425
this.targets = targets;
418-
this.index = index;
426+
this.remainingStrategy = remainingStrategy;
427+
this.remainingTargets = remainingTargets;
428+
this.key = key;
419429
}
420430

421431
@Override
422432
public void emitCode(CompilationResultBuilder crb, AArch64MacroAssembler masm) {
423433
try (ScratchRegister sc1 = masm.getScratchRegister(); ScratchRegister sc2 = masm.getScratchRegister()) {
434+
Register keyReg = asRegister(key);
424435
Register scratch1 = sc1.getRegister();
425436
Register scratch2 = sc2.getRegister();
437+
GraalError.guarantee(!keyReg.equals(scratch1) && !keyReg.equals(scratch2), "must not alias");
426438
/* Compare index against jump table bounds */
427439
int highKey = lowKey + targets.length - 1;
428-
masm.sub(32, scratch2, asRegister(index), lowKey);
429-
int keyDiff = highKey - lowKey; // equivalent to targets.length - 1
430-
if (AArch64MacroAssembler.isComparisonImmediate(keyDiff)) {
431-
masm.compare(32, scratch2, keyDiff);
440+
Register keyOffsetReg = keyReg;
441+
if (lowKey != 0) {
442+
masm.sub(32, scratch2, keyReg, lowKey);
443+
keyOffsetReg = scratch2;
444+
}
445+
446+
int interval = highKey - lowKey;
447+
if (AArch64MacroAssembler.isComparisonImmediate(interval)) {
448+
masm.compare(32, keyOffsetReg, interval);
432449
} else {
433-
masm.mov(scratch1, keyDiff);
434-
masm.cmp(32, scratch2, scratch1);
450+
masm.mov(scratch1, interval);
451+
masm.cmp(32, keyOffsetReg, scratch1);
435452
}
436453

437-
// Jump to default target if index is not within the jump table
438-
masm.branchConditionally(ConditionFlag.HI, defaultTarget.label());
454+
Label outOfRangeLabel = defaultTarget.label();
455+
if (remainingStrategy != null) {
456+
Label remainingLabel = new Label();
457+
outOfRangeLabel = remainingLabel;
458+
459+
crb.getLIR().addSlowPath(this, () -> {
460+
masm.bind(remainingLabel);
461+
new StrategySwitchOp(remainingStrategy, remainingTargets, defaultTarget, key, AArch64LIRGenerator::toIntConditionFlag).emitCode(crb, masm);
462+
});
463+
}
464+
// Jump to outOfRangeLabel if index is not within the jump table
465+
masm.branchConditionally(ConditionFlag.HI, outOfRangeLabel);
439466

440-
emitJumpTable(crb, masm, scratch1, scratch2, lowKey, highKey, Arrays.stream(targets).map(LabelRef::label));
467+
emitJumpTable(crb, masm, keyOffsetReg, scratch1, scratch2, lowKey, highKey, Arrays.stream(targets).map(LabelRef::label));
441468
}
442469
}
443470

444-
public static void emitJumpTable(CompilationResultBuilder crb, AArch64MacroAssembler masm, Register scratch, Register index, int lowKey, int highKey, Stream<Label> targets) {
471+
public static void emitJumpTable(CompilationResultBuilder crb, AArch64MacroAssembler masm, Register scratch, Register keyScratch, int lowKey, int highKey, Stream<Label> targets) {
472+
emitJumpTable(crb, masm, keyScratch, scratch, keyScratch, lowKey, highKey, targets);
473+
}
474+
475+
private static void emitJumpTable(CompilationResultBuilder crb, AArch64MacroAssembler masm, Register key, Register scratch, Register idxScratch, int lowKey, int highKey,
476+
Stream<Label> targets) {
477+
GraalError.guarantee(!key.equals(scratch), "must not alias");
445478
Label jumpTable = new Label();
446479
// load start of jump table
447480
masm.adr(scratch, jumpTable);
448481
/*
449482
* Note scratch holds the start of the jump table and index stores the normalized index.
450483
* Because each jumpTable index is 4 bytes large, index should be scaled.
451484
*/
452-
AArch64Address jumpTableEntryAddr = AArch64Address.createExtendedRegisterOffsetAddress(32, scratch, index, true, ExtendType.UXTW);
485+
AArch64Address jumpTableEntryAddr = AArch64Address.createExtendedRegisterOffsetAddress(32, scratch, key, true, ExtendType.UXTW);
453486
// load relative target offset
454-
masm.ldrs(64, 32, index, jumpTableEntryAddr);
487+
masm.ldrs(64, 32, idxScratch, jumpTableEntryAddr);
455488
// compute target address (jumpTableStart + target offset)
456-
masm.add(64, scratch, scratch, index);
489+
masm.add(64, scratch, scratch, idxScratch);
457490
// jump to target
458491
masm.jmp(scratch);
459492

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/AMD64ControlFlow.java

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2011, 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2011, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -43,6 +43,7 @@
4343
import jdk.graal.compiler.asm.amd64.AMD64MacroAssembler;
4444
import jdk.graal.compiler.code.CompilationResult.JumpTable;
4545
import jdk.graal.compiler.code.CompilationResult.JumpTable.EntryFormat;
46+
import jdk.graal.compiler.core.common.LIRKind;
4647
import jdk.graal.compiler.core.common.NumUtil;
4748
import jdk.graal.compiler.core.common.Stride;
4849
import jdk.graal.compiler.core.common.calc.Condition;
@@ -57,6 +58,7 @@
5758
import jdk.graal.compiler.lir.SwitchStrategy;
5859
import jdk.graal.compiler.lir.Variable;
5960
import jdk.graal.compiler.lir.asm.CompilationResultBuilder;
61+
import jdk.graal.compiler.lir.gen.LIRGenerator;
6062
import jdk.vm.ci.amd64.AMD64;
6163
import jdk.vm.ci.amd64.AMD64Kind;
6264
import jdk.vm.ci.code.Register;
@@ -655,59 +657,98 @@ protected void conditionalJump(int index, Condition condition, Label target) {
655657
}
656658
}
657659

660+
/**
661+
* See {@code LIRGenerator::emitRangeTableSwitch}.
662+
*/
658663
public static final class RangeTableSwitchOp extends AMD64BlockEndOp {
659664
public static final LIRInstructionClass<RangeTableSwitchOp> TYPE = LIRInstructionClass.create(RangeTableSwitchOp.class);
660665
private final int lowKey;
661666
private final LabelRef defaultTarget;
662667
private final LabelRef[] targets;
663-
@LIRInstruction.Use protected Value index;
664-
@LIRInstruction.Temp({LIRInstruction.OperandFlag.REG, LIRInstruction.OperandFlag.HINT}) protected Value idxScratch;
665-
@LIRInstruction.Temp protected Value scratch;
668+
private final SwitchStrategy remainingStrategy;
669+
private final LabelRef[] remainingTargets;
670+
@LIRInstruction.Use(OperandFlag.REG) protected AllocatableValue key;
671+
@LIRInstruction.Temp(OperandFlag.REG) protected AllocatableValue scratch1;
672+
@LIRInstruction.Temp(OperandFlag.REG) protected AllocatableValue scratch2;
666673

667-
public RangeTableSwitchOp(final int lowKey, final LabelRef defaultTarget, final LabelRef[] targets, Value index, Variable scratch, Variable idxScratch) {
674+
public RangeTableSwitchOp(LIRGenerator gen, int lowKey, LabelRef defaultTarget, LabelRef[] targets, SwitchStrategy remainingStrategy, LabelRef[] remainingTargets, AllocatableValue key) {
668675
super(TYPE);
669676
this.lowKey = lowKey;
670677
assert defaultTarget != null;
671678
this.defaultTarget = defaultTarget;
672679
this.targets = targets;
673-
this.index = index;
674-
this.scratch = scratch;
675-
this.idxScratch = idxScratch;
680+
this.remainingStrategy = remainingStrategy;
681+
this.remainingTargets = remainingTargets;
682+
this.key = key;
683+
this.scratch1 = gen.newVariable(LIRKind.value(AMD64Kind.DWORD));
684+
this.scratch2 = gen.newVariable(LIRKind.value(AMD64Kind.DWORD));
676685
}
677686

678687
@Override
679688
public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
680-
Register indexReg = asRegister(index, AMD64Kind.DWORD);
681-
Register idxScratchReg = asRegister(idxScratch, AMD64Kind.DWORD);
682-
Register scratchReg = asRegister(scratch, AMD64Kind.QWORD);
683-
684-
if (!indexReg.equals(idxScratchReg)) {
685-
masm.movl(idxScratchReg, indexReg);
689+
Register keyReg = asRegister(key);
690+
AllocatableValue scratch;
691+
AllocatableValue idxScratch;
692+
if (asRegister(scratch1).equals(keyReg)) {
693+
// keyReg cannot alias with scratchReg but may alias with idxScratchReg
694+
scratch = scratch2;
695+
idxScratch = key;
696+
} else {
697+
scratch = scratch1;
698+
idxScratch = scratch2;
686699
}
700+
Register scratchReg = asRegister(scratch);
701+
Register idxScratchReg = asRegister(idxScratch);
687702

688703
// Compare index against jump table bounds
689704
int highKey = lowKey + targets.length - 1;
705+
Register keyOffsetReg;
690706
if (lowKey != 0) {
691707
// subtract the low value from the switch value
692-
masm.subl(idxScratchReg, lowKey);
708+
if (keyReg.equals(idxScratchReg)) {
709+
masm.addl(idxScratchReg, -lowKey);
710+
} else {
711+
masm.lead(idxScratchReg, new AMD64Address(keyReg, -lowKey));
712+
}
693713
masm.cmpl(idxScratchReg, highKey - lowKey);
714+
keyOffsetReg = idxScratchReg;
694715
} else {
695-
masm.cmpl(idxScratchReg, highKey);
716+
masm.cmpl(keyReg, highKey);
717+
keyOffsetReg = keyReg;
696718
}
697719

698-
// Jump to default target if index is not within the jump table
699-
masm.jcc(ConditionFlag.Above, defaultTarget.label());
720+
Label outOfRangeLabel = defaultTarget.label();
721+
if (remainingStrategy != null) {
722+
Label remainingLabel = new Label();
723+
outOfRangeLabel = remainingLabel;
724+
boolean needsKeyRecover = lowKey != 0 && keyReg.equals(idxScratchReg);
700725

701-
emitJumpTable(crb, masm, scratchReg, idxScratchReg, lowKey, highKey, Arrays.stream(targets).map(LabelRef::label));
726+
crb.getLIR().addSlowPath(this, () -> {
727+
masm.bind(remainingLabel);
728+
if (needsKeyRecover) {
729+
masm.addl(keyReg, lowKey);
730+
}
731+
new StrategySwitchOp(remainingStrategy, remainingTargets, defaultTarget, key, scratch).emitCode(crb, masm);
732+
});
733+
}
734+
masm.jcc(ConditionFlag.Above, outOfRangeLabel);
735+
736+
emitJumpTable(crb, masm, keyOffsetReg, scratchReg, idxScratchReg, lowKey, highKey, Arrays.stream(targets).map(LabelRef::label));
702737
}
703738

704739
public static void emitJumpTable(CompilationResultBuilder crb, AMD64MacroAssembler masm, Register scratchReg, Register idxScratchReg, int lowKey, int highKey, Stream<Label> targets) {
740+
emitJumpTable(crb, masm, idxScratchReg, scratchReg, idxScratchReg, lowKey, highKey, targets);
741+
}
742+
743+
private static void emitJumpTable(CompilationResultBuilder crb, AMD64MacroAssembler masm, Register keyReg, Register scratchReg, Register idxScratchReg, int lowKey, int highKey,
744+
Stream<Label> targets) {
745+
GraalError.guarantee(!keyReg.equals(scratchReg), "must not alias");
705746
// Set scratch to address of jump table
706747
masm.leaq(scratchReg, new AMD64Address(AMD64.rip, 0));
707748
final int afterLea = masm.position();
708749

709750
// Load jump table entry into scratch and jump to it
710-
masm.movslq(idxScratchReg, new AMD64Address(scratchReg, idxScratchReg, Stride.S4, 0));
751+
masm.movslq(idxScratchReg, new AMD64Address(scratchReg, keyReg, Stride.S4, 0));
711752
masm.addq(scratchReg, idxScratchReg);
712753
masm.jmp(scratchReg);
713754

0 commit comments

Comments
 (0)