Skip to content

Commit f2b5d04

Browse files
[LLVM][InstSimplify] Add folds for SVE integer reduction intrinsics. (#167519)
[andv, eorv, orv, s/uaddv, s/umaxv, s/uminv] sve_reduce_##(none, ?) -> op's neutral value sve_reduce_##(any, neutral) -> op's neutral value [andv, orv, s/umaxv, s/uminv] sve_reduce_##(all, splat(X)) -> X [eorv] sve_reduce_##(all, splat(X)) -> 0
1 parent 6fc2bc1 commit f2b5d04

File tree

5 files changed

+1002
-0
lines changed

5 files changed

+1002
-0
lines changed

llvm/include/llvm/IR/Constant.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ class Constant : public User {
7979
/// Return true if the value is the smallest signed value.
8080
LLVM_ABI bool isMinSignedValue() const;
8181

82+
/// Return true if the value is the largest signed value.
83+
LLVM_ABI bool isMaxSignedValue() const;
84+
8285
/// Return true if this is a finite and non-zero floating-point scalar
8386
/// constant or a fixed width vector constant with all finite and non-zero
8487
/// elements.

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "llvm/IR/Dominators.h"
4242
#include "llvm/IR/InstrTypes.h"
4343
#include "llvm/IR/Instructions.h"
44+
#include "llvm/IR/IntrinsicsAArch64.h"
4445
#include "llvm/IR/Operator.h"
4546
#include "llvm/IR/PatternMatch.h"
4647
#include "llvm/IR/Statepoint.h"
@@ -6676,6 +6677,62 @@ static MinMaxOptResult OptimizeConstMinMax(const Constant *RHSConst,
66766677
return MinMaxOptResult::CannotOptimize;
66776678
}
66786679

6680+
static Value *simplifySVEIntReduction(Intrinsic::ID IID, Type *ReturnType,
6681+
Value *Op0, Value *Op1) {
6682+
Constant *C0 = dyn_cast<Constant>(Op0);
6683+
Constant *C1 = dyn_cast<Constant>(Op1);
6684+
unsigned Width = ReturnType->getPrimitiveSizeInBits();
6685+
6686+
// All false predicate or reduction of neutral values ==> neutral result.
6687+
switch (IID) {
6688+
case Intrinsic::aarch64_sve_eorv:
6689+
case Intrinsic::aarch64_sve_orv:
6690+
case Intrinsic::aarch64_sve_saddv:
6691+
case Intrinsic::aarch64_sve_uaddv:
6692+
case Intrinsic::aarch64_sve_umaxv:
6693+
if ((C0 && C0->isNullValue()) || (C1 && C1->isNullValue()))
6694+
return ConstantInt::get(ReturnType, 0);
6695+
break;
6696+
case Intrinsic::aarch64_sve_andv:
6697+
case Intrinsic::aarch64_sve_uminv:
6698+
if ((C0 && C0->isNullValue()) || (C1 && C1->isAllOnesValue()))
6699+
return ConstantInt::get(ReturnType, APInt::getMaxValue(Width));
6700+
break;
6701+
case Intrinsic::aarch64_sve_smaxv:
6702+
if ((C0 && C0->isNullValue()) || (C1 && C1->isMinSignedValue()))
6703+
return ConstantInt::get(ReturnType, APInt::getSignedMinValue(Width));
6704+
break;
6705+
case Intrinsic::aarch64_sve_sminv:
6706+
if ((C0 && C0->isNullValue()) || (C1 && C1->isMaxSignedValue()))
6707+
return ConstantInt::get(ReturnType, APInt::getSignedMaxValue(Width));
6708+
break;
6709+
}
6710+
6711+
switch (IID) {
6712+
case Intrinsic::aarch64_sve_andv:
6713+
case Intrinsic::aarch64_sve_orv:
6714+
case Intrinsic::aarch64_sve_smaxv:
6715+
case Intrinsic::aarch64_sve_sminv:
6716+
case Intrinsic::aarch64_sve_umaxv:
6717+
case Intrinsic::aarch64_sve_uminv:
6718+
// sve_reduce_##(all, splat(X)) ==> X
6719+
if (C0 && C0->isAllOnesValue()) {
6720+
if (Value *SplatVal = getSplatValue(Op1)) {
6721+
assert(SplatVal->getType() == ReturnType && "Unexpected result type!");
6722+
return SplatVal;
6723+
}
6724+
}
6725+
break;
6726+
case Intrinsic::aarch64_sve_eorv:
6727+
// sve_reduce_xor(all, splat(X)) ==> 0
6728+
if (C0 && C0->isAllOnesValue())
6729+
return ConstantInt::get(ReturnType, 0);
6730+
break;
6731+
}
6732+
6733+
return nullptr;
6734+
}
6735+
66796736
Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
66806737
Value *Op0, Value *Op1,
66816738
const SimplifyQuery &Q,
@@ -7037,6 +7094,17 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
70377094

70387095
break;
70397096
}
7097+
7098+
case Intrinsic::aarch64_sve_andv:
7099+
case Intrinsic::aarch64_sve_eorv:
7100+
case Intrinsic::aarch64_sve_orv:
7101+
case Intrinsic::aarch64_sve_saddv:
7102+
case Intrinsic::aarch64_sve_smaxv:
7103+
case Intrinsic::aarch64_sve_sminv:
7104+
case Intrinsic::aarch64_sve_uaddv:
7105+
case Intrinsic::aarch64_sve_umaxv:
7106+
case Intrinsic::aarch64_sve_uminv:
7107+
return simplifySVEIntReduction(IID, ReturnType, Op0, Op1);
70407108
default:
70417109
break;
70427110
}

llvm/lib/IR/Constants.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,23 @@ bool Constant::isMinSignedValue() const {
183183
return false;
184184
}
185185

186+
bool Constant::isMaxSignedValue() const {
187+
// Check for INT_MAX integers
188+
if (const ConstantInt *CI = dyn_cast<ConstantInt>(this))
189+
return CI->isMaxValue(/*isSigned=*/true);
190+
191+
// Check for FP which are bitcasted from INT_MAX integers
192+
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(this))
193+
return CFP->getValueAPF().bitcastToAPInt().isMaxSignedValue();
194+
195+
// Check for splats of INT_MAX values.
196+
if (getType()->isVectorTy())
197+
if (const auto *SplatVal = getSplatValue())
198+
return SplatVal->isMaxSignedValue();
199+
200+
return false;
201+
}
202+
186203
bool Constant::isNotMinSignedValue() const {
187204
// Check for INT_MIN integers
188205
if (const ConstantInt *CI = dyn_cast<ConstantInt>(this))

0 commit comments

Comments
 (0)