[VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations #148350
Conversation
This patch generalizes the foldBitOpOfBitcasts function (renamed to foldBitOpOfCastops) to handle additional cast operations beyond just bitcast. The optimization now supports:
- trunc (truncate)
- sext (sign extend)
- zext (zero extend)
- bitcast (original functionality)

The optimization transforms bitop(cast(x), cast(y)) -> cast(bitop(x, y)). This reduces the number of cast instructions from 2 to 1, which can improve performance on targets where cast operations are expensive or where performing bitwise operations on narrower types is beneficial.

Changes:
- Renamed foldBitOpOfBitcasts to foldBitOpOfCastops
- Extended pattern matching to handle any CastInst
- Added validation for each cast type's constraints
- Updated cost model to use the actual cast opcode
- Added comprehensive tests for all supported cast types

Fixes: llvm#146037
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-transforms

Author: Rahul Yadav (rhyadav)

Changes

This patch generalizes the existing foldBitOpOfBitcasts optimization in the VectorCombine pass to handle additional cast operations beyond just bitcast.

Fixes: #146037

Summary

The optimization now supports folding bitwise operations (AND/OR/XOR) with the following cast operations:
- trunc
- sext
- zext
- bitcast

The transformation pattern is bitop(castop(x), castop(y)) -> castop(bitop(x, y)). This reduces the number of cast instructions from 2 to 1, improving performance on targets where cast operations are expensive or where performing bitwise operations on narrower types is beneficial.

Implementation Details

- foldBitOpOfBitcasts is renamed to foldBitOpOfCastops and now matches any pair of identical cast opcodes (bitcast, trunc, sext, zext) feeding a bitwise logic op, validating each cast type's constraints and costing the fold with the actual cast opcode, including multi-use casts.

Testing

- New test file llvm/test/Transforms/VectorCombine/bitop-of-castops.ll covers and/or/xor with each supported cast type, plus negative and multi-use cases.

Full diff: https://github.com/llvm/llvm-project/pull/148350.diff

2 Files Affected:
- llvm/lib/Transforms/Vectorize/VectorCombine.cpp
- llvm/test/Transforms/VectorCombine/bitop-of-castops.ll
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index fe8d74c43dfdc..58aa53694b22e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -52,9 +52,9 @@ STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");
STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
-static cl::opt<bool> DisableVectorCombine(
- "disable-vector-combine", cl::init(false), cl::Hidden,
- cl::desc("Disable all vector combine transforms"));
+static cl::opt<bool>
+ DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden,
+ cl::desc("Disable all vector combine transforms"));
static cl::opt<bool> DisableBinopExtractShuffle(
"disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
@@ -115,7 +115,7 @@ class VectorCombine {
bool foldInsExtFNeg(Instruction &I);
bool foldInsExtBinop(Instruction &I);
bool foldInsExtVectorToShuffle(Instruction &I);
- bool foldBitOpOfBitcasts(Instruction &I);
+ bool foldBitOpOfCastops(Instruction &I);
bool foldBitcastShuffle(Instruction &I);
bool scalarizeOpOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
@@ -808,46 +808,105 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
return true;
}
-bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
- // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
- Value *LHSSrc, *RHSSrc;
- if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
- m_BitCast(m_Value(RHSSrc)))))
+bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
+ // Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
+ // Supports: bitcast, trunc, sext, zext
+
+ // Check if this is a bitwise logic operation
+ auto *BinOp = dyn_cast<BinaryOperator>(&I);
+ if (!BinOp || !BinOp->isBitwiseLogicOp())
+ return false;
+
+ LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n");
+
+ // Get the cast instructions
+ auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
+ auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
+ if (!LHSCast || !RHSCast) {
+ LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
+ return false;
+ }
+
+ LLVM_DEBUG(dbgs() << " LHS cast: " << *LHSCast << "\n");
+ LLVM_DEBUG(dbgs() << " RHS cast: " << *RHSCast << "\n");
+
+ // Both casts must be the same type
+ Instruction::CastOps CastOpcode = LHSCast->getOpcode();
+ if (CastOpcode != RHSCast->getOpcode())
return false;
+ // Only handle supported cast operations
+ switch (CastOpcode) {
+ case Instruction::BitCast:
+ case Instruction::Trunc:
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ break;
+ default:
+ return false;
+ }
+
+ Value *LHSSrc = LHSCast->getOperand(0);
+ Value *RHSSrc = RHSCast->getOperand(0);
+
// Source types must match
if (LHSSrc->getType() != RHSSrc->getType())
return false;
- if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
- return false;
- // Only handle vector types
+ // Only handle vector types with integer elements
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
if (!SrcVecTy || !DstVecTy)
return false;
- // Same total bit width
- assert(SrcVecTy->getPrimitiveSizeInBits() ==
- DstVecTy->getPrimitiveSizeInBits() &&
- "Bitcast should preserve total bit width");
+ if (!SrcVecTy->getScalarType()->isIntegerTy() ||
+ !DstVecTy->getScalarType()->isIntegerTy())
+ return false;
+
+ // Validate cast operation constraints
+ switch (CastOpcode) {
+ case Instruction::BitCast:
+ // Total bit width must be preserved
+ if (SrcVecTy->getPrimitiveSizeInBits() !=
+ DstVecTy->getPrimitiveSizeInBits())
+ return false;
+ break;
+ case Instruction::Trunc:
+ // Source elements must be wider
+ if (SrcVecTy->getScalarSizeInBits() <= DstVecTy->getScalarSizeInBits())
+ return false;
+ break;
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ // Source elements must be narrower
+ if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits())
+ return false;
+ break;
+ }
// Cost Check :
- // OldCost = bitlogic + 2*bitcasts
- // NewCost = bitlogic + bitcast
- auto *BinOp = cast<BinaryOperator>(&I);
+ // OldCost = bitlogic + 2*casts
+ // NewCost = bitlogic + cast
InstructionCost OldCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
- TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
- TTI::CastContextHint::None) +
- TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
- TTI::CastContextHint::None);
+ TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+ TTI::CastContextHint::None) *
+ 2;
+
InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
- TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
+ TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);
- LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
+ // Account for multi-use casts
+ if (!LHSCast->hasOneUse())
+ NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+ TTI::CastContextHint::None);
+ if (!RHSCast->hasOneUse())
+ NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
+ TTI::CastContextHint::None);
+
+ LLVM_DEBUG(dbgs() << "Found bitwise logic op of cast ops: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
<< "\n");
@@ -862,8 +921,15 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
Worklist.pushValue(NewOp);
- // Bitcast the result back
- Value *Result = Builder.CreateBitCast(NewOp, I.getType());
+ // Create the cast operation
+ Value *Result = Builder.CreateCast(CastOpcode, NewOp, I.getType());
+
+ // Preserve cast instruction flags
+ if (auto *NewCast = dyn_cast<CastInst>(Result)) {
+ NewCast->copyIRFlags(LHSCast);
+ NewCast->andIRFlags(RHSCast);
+ }
+
replaceValue(I, *Result);
return true;
}
@@ -1020,8 +1086,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
// Determine scalar opcode
- std::optional<unsigned> FunctionalOpcode =
- VPI.getFunctionalOpcode();
+ std::optional<unsigned> FunctionalOpcode = VPI.getFunctionalOpcode();
std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
if (!FunctionalOpcode) {
ScalarIntrID = VPI.getFunctionalIntrinsicID();
@@ -1044,8 +1109,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
(SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
- LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
- << "\n");
+ LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI << "\n");
LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
<< ", Cost of scalarizing:" << NewCost << "\n");
@@ -2015,10 +2079,12 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
}
unsigned NumOpElts = Op0Ty->getNumElements();
- bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
+ bool IsIdentity0 =
+ ShuffleDstTy == Op0Ty &&
all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
- bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
+ bool IsIdentity1 =
+ ShuffleDstTy == Op1Ty &&
all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);
@@ -3773,7 +3839,7 @@ bool VectorCombine::run() {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
- MadeChange |= foldBitOpOfBitcasts(I);
+ MadeChange |= foldBitOpOfCastops(I);
break;
default:
MadeChange |= shrinkType(I);
diff --git a/llvm/test/Transforms/VectorCombine/bitop-of-castops.ll b/llvm/test/Transforms/VectorCombine/bitop-of-castops.ll
new file mode 100644
index 0000000000000..003e14bebd169
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/bitop-of-castops.ll
@@ -0,0 +1,263 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64-- | FileCheck %s
+
+; Test bitwise operations with bitcast
+define <4 x i32> @and_bitcast_v4f32_to_v4i32(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @and_bitcast_v4f32_to_v4i32(
+; CHECK-NEXT: [[BC1:%.*]] = bitcast <4 x float> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT: [[BC2:%.*]] = bitcast <4 x float> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT: [[AND:%.*]] = and <4 x i32> [[BC1]], [[BC2]]
+; CHECK-NEXT: ret <4 x i32> [[AND]]
+;
+ %bc1 = bitcast <4 x float> %a to <4 x i32>
+ %bc2 = bitcast <4 x float> %b to <4 x i32>
+ %and = and <4 x i32> %bc1, %bc2
+ ret <4 x i32> %and
+}
+
+define <4 x i32> @or_bitcast_v4f32_to_v4i32(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @or_bitcast_v4f32_to_v4i32(
+; CHECK-NEXT: [[BC1:%.*]] = bitcast <4 x float> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT: [[BC2:%.*]] = bitcast <4 x float> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[BC1]], [[BC2]]
+; CHECK-NEXT: ret <4 x i32> [[OR]]
+;
+ %bc1 = bitcast <4 x float> %a to <4 x i32>
+ %bc2 = bitcast <4 x float> %b to <4 x i32>
+ %or = or <4 x i32> %bc1, %bc2
+ ret <4 x i32> %or
+}
+
+define <4 x i32> @xor_bitcast_v4f32_to_v4i32(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @xor_bitcast_v4f32_to_v4i32(
+; CHECK-NEXT: [[BC1:%.*]] = bitcast <4 x float> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT: [[BC2:%.*]] = bitcast <4 x float> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT: [[XOR:%.*]] = xor <4 x i32> [[BC1]], [[BC2]]
+; CHECK-NEXT: ret <4 x i32> [[XOR]]
+;
+ %bc1 = bitcast <4 x float> %a to <4 x i32>
+ %bc2 = bitcast <4 x float> %b to <4 x i32>
+ %xor = xor <4 x i32> %bc1, %bc2
+ ret <4 x i32> %xor
+}
+
+; Test bitwise operations with truncate
+define <4 x i16> @and_trunc_v4i32_to_v4i16(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @and_trunc_v4i32_to_v4i16(
+; CHECK-NEXT: [[AND_INNER:%.*]] = and <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[AND:%.*]] = trunc <4 x i32> [[AND_INNER]] to <4 x i16>
+; CHECK-NEXT: ret <4 x i16> [[AND]]
+;
+ %t1 = trunc <4 x i32> %a to <4 x i16>
+ %t2 = trunc <4 x i32> %b to <4 x i16>
+ %and = and <4 x i16> %t1, %t2
+ ret <4 x i16> %and
+}
+
+define <8 x i8> @or_trunc_v8i16_to_v8i8(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: @or_trunc_v8i16_to_v8i8(
+; CHECK-NEXT: [[OR_INNER:%.*]] = or <8 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = trunc <8 x i16> [[OR_INNER]] to <8 x i8>
+; CHECK-NEXT: ret <8 x i8> [[OR]]
+;
+ %t1 = trunc <8 x i16> %a to <8 x i8>
+ %t2 = trunc <8 x i16> %b to <8 x i8>
+ %or = or <8 x i8> %t1, %t2
+ ret <8 x i8> %or
+}
+
+define <2 x i32> @xor_trunc_v2i64_to_v2i32(<2 x i64> %a, <2 x i64> %b) {
+; CHECK-LABEL: @xor_trunc_v2i64_to_v2i32(
+; CHECK-NEXT: [[XOR_INNER:%.*]] = xor <2 x i64> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[XOR:%.*]] = trunc <2 x i64> [[XOR_INNER]] to <2 x i32>
+; CHECK-NEXT: ret <2 x i32> [[XOR]]
+;
+ %t1 = trunc <2 x i64> %a to <2 x i32>
+ %t2 = trunc <2 x i64> %b to <2 x i32>
+ %xor = xor <2 x i32> %t1, %t2
+ ret <2 x i32> %xor
+}
+
+; Test bitwise operations with zero extend
+define <4 x i32> @and_zext_v4i16_to_v4i32(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_zext_v4i16_to_v4i32(
+; CHECK-NEXT: [[AND_INNER:%.*]] = and <4 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[AND:%.*]] = zext <4 x i16> [[AND_INNER]] to <4 x i32>
+; CHECK-NEXT: ret <4 x i32> [[AND]]
+;
+ %z1 = zext <4 x i16> %a to <4 x i32>
+ %z2 = zext <4 x i16> %b to <4 x i32>
+ %and = and <4 x i32> %z1, %z2
+ ret <4 x i32> %and
+}
+
+define <8 x i16> @or_zext_v8i8_to_v8i16(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: @or_zext_v8i8_to_v8i16(
+; CHECK-NEXT: [[OR_INNER:%.*]] = or <8 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = zext <8 x i8> [[OR_INNER]] to <8 x i16>
+; CHECK-NEXT: ret <8 x i16> [[OR]]
+;
+ %z1 = zext <8 x i8> %a to <8 x i16>
+ %z2 = zext <8 x i8> %b to <8 x i16>
+ %or = or <8 x i16> %z1, %z2
+ ret <8 x i16> %or
+}
+
+define <2 x i64> @xor_zext_v2i32_to_v2i64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @xor_zext_v2i32_to_v2i64(
+; CHECK-NEXT: [[XOR_INNER:%.*]] = xor <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[XOR:%.*]] = zext <2 x i32> [[XOR_INNER]] to <2 x i64>
+; CHECK-NEXT: ret <2 x i64> [[XOR]]
+;
+ %z1 = zext <2 x i32> %a to <2 x i64>
+ %z2 = zext <2 x i32> %b to <2 x i64>
+ %xor = xor <2 x i64> %z1, %z2
+ ret <2 x i64> %xor
+}
+
+; Test bitwise operations with sign extend
+define <4 x i32> @and_sext_v4i16_to_v4i32(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_sext_v4i16_to_v4i32(
+; CHECK-NEXT: [[AND_INNER:%.*]] = and <4 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[AND:%.*]] = sext <4 x i16> [[AND_INNER]] to <4 x i32>
+; CHECK-NEXT: ret <4 x i32> [[AND]]
+;
+ %s1 = sext <4 x i16> %a to <4 x i32>
+ %s2 = sext <4 x i16> %b to <4 x i32>
+ %and = and <4 x i32> %s1, %s2
+ ret <4 x i32> %and
+}
+
+define <8 x i16> @or_sext_v8i8_to_v8i16(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: @or_sext_v8i8_to_v8i16(
+; CHECK-NEXT: [[OR_INNER:%.*]] = or <8 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = sext <8 x i8> [[OR_INNER]] to <8 x i16>
+; CHECK-NEXT: ret <8 x i16> [[OR]]
+;
+ %s1 = sext <8 x i8> %a to <8 x i16>
+ %s2 = sext <8 x i8> %b to <8 x i16>
+ %or = or <8 x i16> %s1, %s2
+ ret <8 x i16> %or
+}
+
+define <2 x i64> @xor_sext_v2i32_to_v2i64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: @xor_sext_v2i32_to_v2i64(
+; CHECK-NEXT: [[XOR_INNER:%.*]] = xor <2 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[XOR:%.*]] = sext <2 x i32> [[XOR_INNER]] to <2 x i64>
+; CHECK-NEXT: ret <2 x i64> [[XOR]]
+;
+ %s1 = sext <2 x i32> %a to <2 x i64>
+ %s2 = sext <2 x i32> %b to <2 x i64>
+ %xor = xor <2 x i64> %s1, %s2
+ ret <2 x i64> %xor
+}
+
+; Negative test: mismatched cast types (zext and sext)
+define <4 x i32> @and_zext_sext_mismatch(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_zext_sext_mismatch(
+; CHECK-NEXT: [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT: [[S2:%.*]] = sext <4 x i16> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT: [[AND:%.*]] = and <4 x i32> [[Z1]], [[S2]]
+; CHECK-NEXT: ret <4 x i32> [[AND]]
+;
+ %z1 = zext <4 x i16> %a to <4 x i32>
+ %s2 = sext <4 x i16> %b to <4 x i32>
+ %and = and <4 x i32> %z1, %s2
+ ret <4 x i32> %and
+}
+
+; Negative test: mismatched source types
+define <4 x i32> @or_zext_different_src_types(<4 x i16> %a, <4 x i8> %b) {
+; CHECK-LABEL: @or_zext_different_src_types(
+; CHECK-NEXT: [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT: [[Z2:%.*]] = zext <4 x i8> [[B:%.*]] to <4 x i32>
+; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], [[Z2]]
+; CHECK-NEXT: ret <4 x i32> [[OR]]
+;
+ %z1 = zext <4 x i16> %a to <4 x i32>
+ %z2 = zext <4 x i8> %b to <4 x i32>
+ %or = or <4 x i32> %z1, %z2
+ ret <4 x i32> %or
+}
+
+; Negative test: scalar types (not vectors)
+define i32 @xor_zext_scalar(i16 %a, i16 %b) {
+; CHECK-LABEL: @xor_zext_scalar(
+; CHECK-NEXT: [[Z1:%.*]] = zext i16 [[A:%.*]] to i32
+; CHECK-NEXT: [[Z2:%.*]] = zext i16 [[B:%.*]] to i32
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Z1]], [[Z2]]
+; CHECK-NEXT: ret i32 [[XOR]]
+;
+ %z1 = zext i16 %a to i32
+ %z2 = zext i16 %b to i32
+ %xor = xor i32 %z1, %z2
+ ret i32 %xor
+}
+
+; Test multi-use: one cast has multiple uses
+define <4 x i32> @and_zext_multiuse(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @and_zext_multiuse(
+; CHECK-NEXT: [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
+; CHECK-NEXT: [[AND_INNER:%.*]] = and <4 x i16> [[A]], [[B:%.*]]
+; CHECK-NEXT: [[AND:%.*]] = zext <4 x i16> [[AND_INNER]] to <4 x i32>
+; CHECK-NEXT: [[ADD:%.*]] = add <4 x i32> [[Z1]], [[AND]]
+; CHECK-NEXT: ret <4 x i32> [[ADD]]
+;
+ %z1 = zext <4 x i16> %a to <4 x i32>
+ %z2 = zext <4 x i16> %b to <4 x i32>
+ %and = and <4 x i32> %z1, %z2
+ %add = add <4 x i32> %z1, %and ; z1 has multiple uses
+ ret <4 x i32> %add
+}
+
+; Test with different vector sizes
+define <16 x i16> @or_zext_v16i8_to_v16i16(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: @or_zext_v16i8_to_v16i16(
+; CHECK-NEXT: [[OR_INNER:%.*]] = or <16 x i8> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = zext <16 x i8> [[OR_INNER]] to <16 x i16>
+; CHECK-NEXT: ret <16 x i16> [[OR]]
+;
+ %z1 = zext <16 x i8> %a to <16 x i16>
+ %z2 = zext <16 x i8> %b to <16 x i16>
+ %or = or <16 x i16> %z1, %z2
+ ret <16 x i16> %or
+}
+
+; Test bitcast with different element counts
+define <8 x i16> @xor_bitcast_v4i32_to_v8i16(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @xor_bitcast_v4i32_to_v8i16(
+; CHECK-NEXT: [[XOR_INNER:%.*]] = xor <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[XOR:%.*]] = bitcast <4 x i32> [[XOR_INNER]] to <8 x i16>
+; CHECK-NEXT: ret <8 x i16> [[XOR]]
+;
+ %bc1 = bitcast <4 x i32> %a to <8 x i16>
+ %bc2 = bitcast <4 x i32> %b to <8 x i16>
+ %xor = xor <8 x i16> %bc1, %bc2
+ ret <8 x i16> %xor
+}
+
+; Test truncate with flag preservation
+define <4 x i16> @and_trunc_nuw_nsw(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: @and_trunc_nuw_nsw(
+; CHECK-NEXT: [[AND_INNER:%.*]] = and <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[AND:%.*]] = trunc nuw nsw <4 x i32> [[AND_INNER]] to <4 x i16>
+; CHECK-NEXT: ret <4 x i16> [[AND]]
+;
+ %t1 = trunc nuw nsw <4 x i32> %a to <4 x i16>
+ %t2 = trunc nuw nsw <4 x i32> %b to <4 x i16>
+ %and = and <4 x i16> %t1, %t2
+ ret <4 x i16> %and
+}
+
+; Test sign extend with nneg flag
+define <4 x i32> @or_zext_nneg(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: @or_zext_nneg(
+; CHECK-NEXT: [[OR_INNER:%.*]] = or <4 x i16> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT: [[OR:%.*]] = zext nneg <4 x i16> [[OR_INNER]] to <4 x i32>
+; CHECK-NEXT: ret <4 x i32> [[OR]]
+;
+ %z1 = zext nneg <4 x i16> %a to <4 x i32>
+ %z2 = zext nneg <4 x i16> %b to <4 x i32>
+ %or = or <4 x i32> %z1, %z2
+ ret <4 x i32> %or
+}
@RKSimon requesting your review.
cl::desc("Disable all vector combine transforms")); | ||
static cl::opt<bool> | ||
DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden, | ||
cl::desc("Disable all vector combine transforms")); |
(style) don't clang-format lines unrelated to a patch
m_BitCast(m_Value(RHSSrc)))))
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
// Supports: bitcast, trunc, sext, zext
(style) Move the description comment up a few lines to outside the function def
if (!BinOp || !BinOp->isBitwiseLogicOp())
return false;

LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n");
probably drop this?
}

LLVM_DEBUG(dbgs() << " LHS cast: " << *LHSCast << "\n");
LLVM_DEBUG(dbgs() << " RHS cast: " << *RHSCast << "\n");
probably drop this?
if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits())
return false;
break;
}
not sure if any of these checks are necessary - even the old assertion didn't contribute much, as Builder.CreateCast should assert the cast is valid for us if we get to that stage.
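For illustration, a minimal sketch of the reduced check this comment points toward. It is not part of the patch; the helper name isFoldableCastPair is hypothetical, and the only LLVM APIs used are CastInst accessors and the generic CastInst::castIsValid query. The point is that once both operands are the same kind of cast over the same source/destination types, the per-opcode width comparisons are redundant, because the recreated cast reuses exactly the same type pair.

```cpp
#include "llvm/IR/InstrTypes.h"

// Hypothetical helper, not in the patch: any cast that was valid on the
// operands stays valid when recreated over the same (Src -> Dst) type pair
// after the bitwise op, so no per-opcode size checks are needed.
static bool isFoldableCastPair(const llvm::CastInst *LHSCast,
                               const llvm::CastInst *RHSCast) {
  using namespace llvm;
  Instruction::CastOps Opc = LHSCast->getOpcode();
  if (Opc != RHSCast->getOpcode() ||
      LHSCast->getSrcTy() != RHSCast->getSrcTy() ||
      LHSCast->getDestTy() != RHSCast->getDestTy())
    return false;
  // Belt-and-braces only; IRBuilder::CreateCast would assert this anyway.
  return CastInst::castIsValid(Opc, LHSCast->getSrcTy(), LHSCast->getDestTy());
}
```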
TTI::CastContextHint::None);
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None) *
2;
Can we hoist the separate getCastInstrCost calls here to avoid calling them again for the !hasOneUse cases below? Add the Instruction* args as well to help improve costs - we can't do that for the new cost calc, but it's still useful for the old costs. We're missing the CostKind as well.
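For reference, a rough sketch of what the hoisted cost computation might look like; the same CostKind argument also addresses the two "Missing CostKind." notes below. It assumes VectorCombine's usual TTI::TargetCostKind member (CostKind) is in scope and reuses the names from the patch; it is illustrative, not the exact wording of the fix.

```cpp
// Hoist the per-cast cost queries; pass the original cast instructions to
// sharpen the old-cost estimate (the new cast has no Instruction* yet).
InstructionCost LHSCastCost =
    TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
                         TTI::CastContextHint::None, CostKind, LHSCast);
InstructionCost RHSCastCost =
    TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
                         TTI::CastContextHint::None, CostKind, RHSCast);
InstructionCost NewCastCost = TTI.getCastInstrCost(
    CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind);

InstructionCost OldCost =
    TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy, CostKind) +
    LHSCastCost + RHSCastCost;
InstructionCost NewCost =
    TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy, CostKind) +
    NewCastCost;
// Reuse the hoisted costs for casts that must stay alive, instead of
// querying TTI a second time.
if (!LHSCast->hasOneUse())
  NewCost += LHSCastCost;
if (!RHSCast->hasOneUse())
  NewCost += RHSCastCost;
```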
InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);
Missing CostKind.
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None) *
2;

InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
Missing CostKind.
@@ -1020,8 +1086,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

// Determine scalar opcode
std::optional<unsigned> FunctionalOpcode =
VPI.getFunctionalOpcode();
std::optional<unsigned> FunctionalOpcode = VPI.getFunctionalOpcode();
(style) don't clang-format lines unrelated to a patch