-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations #148350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,9 +52,9 @@ STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed"); | |
STATISTIC(NumScalarCmp, "Number of scalar compares formed"); | ||
STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed"); | ||
|
||
static cl::opt<bool> DisableVectorCombine( | ||
"disable-vector-combine", cl::init(false), cl::Hidden, | ||
cl::desc("Disable all vector combine transforms")); | ||
static cl::opt<bool> | ||
DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden, | ||
cl::desc("Disable all vector combine transforms")); | ||
|
||
static cl::opt<bool> DisableBinopExtractShuffle( | ||
"disable-binop-extract-shuffle", cl::init(false), cl::Hidden, | ||
|
@@ -115,7 +115,7 @@ class VectorCombine { | |
bool foldInsExtFNeg(Instruction &I); | ||
bool foldInsExtBinop(Instruction &I); | ||
bool foldInsExtVectorToShuffle(Instruction &I); | ||
bool foldBitOpOfBitcasts(Instruction &I); | ||
bool foldBitOpOfCastops(Instruction &I); | ||
bool foldBitcastShuffle(Instruction &I); | ||
bool scalarizeOpOrCmp(Instruction &I); | ||
bool scalarizeVPIntrinsic(Instruction &I); | ||
|
@@ -808,46 +808,105 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) { | |
return true; | ||
} | ||
|
||
bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) { | ||
// Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y)) | ||
Value *LHSSrc, *RHSSrc; | ||
if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)), | ||
m_BitCast(m_Value(RHSSrc))))) | ||
bool VectorCombine::foldBitOpOfCastops(Instruction &I) { | ||
// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y)) | ||
// Supports: bitcast, trunc, sext, zext | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (style) Move the description comment up a few lines to outside the function def |
||
|
||
// Check if this is a bitwise logic operation | ||
auto *BinOp = dyn_cast<BinaryOperator>(&I); | ||
if (!BinOp || !BinOp->isBitwiseLogicOp()) | ||
return false; | ||
|
||
LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. probably drop this? |
||
|
||
// Get the cast instructions | ||
auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0)); | ||
auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1)); | ||
if (!LHSCast || !RHSCast) { | ||
LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n"); | ||
return false; | ||
} | ||
|
||
LLVM_DEBUG(dbgs() << " LHS cast: " << *LHSCast << "\n"); | ||
LLVM_DEBUG(dbgs() << " RHS cast: " << *RHSCast << "\n"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. probably drop this? |
||
|
||
// Both casts must be the same type | ||
Instruction::CastOps CastOpcode = LHSCast->getOpcode(); | ||
if (CastOpcode != RHSCast->getOpcode()) | ||
return false; | ||
|
||
// Only handle supported cast operations | ||
switch (CastOpcode) { | ||
case Instruction::BitCast: | ||
case Instruction::Trunc: | ||
case Instruction::SExt: | ||
case Instruction::ZExt: | ||
break; | ||
default: | ||
return false; | ||
} | ||
|
||
Value *LHSSrc = LHSCast->getOperand(0); | ||
Value *RHSSrc = RHSCast->getOperand(0); | ||
|
||
// Source types must match | ||
if (LHSSrc->getType() != RHSSrc->getType()) | ||
return false; | ||
if (!LHSSrc->getType()->getScalarType()->isIntegerTy()) | ||
return false; | ||
|
||
// Only handle vector types | ||
// Only handle vector types with integer elements | ||
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType()); | ||
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType()); | ||
if (!SrcVecTy || !DstVecTy) | ||
return false; | ||
|
||
// Same total bit width | ||
assert(SrcVecTy->getPrimitiveSizeInBits() == | ||
DstVecTy->getPrimitiveSizeInBits() && | ||
"Bitcast should preserve total bit width"); | ||
if (!SrcVecTy->getScalarType()->isIntegerTy() || | ||
!DstVecTy->getScalarType()->isIntegerTy()) | ||
return false; | ||
|
||
// Validate cast operation constraints | ||
switch (CastOpcode) { | ||
case Instruction::BitCast: | ||
// Total bit width must be preserved | ||
if (SrcVecTy->getPrimitiveSizeInBits() != | ||
DstVecTy->getPrimitiveSizeInBits()) | ||
return false; | ||
break; | ||
case Instruction::Trunc: | ||
// Source elements must be wider | ||
if (SrcVecTy->getScalarSizeInBits() <= DstVecTy->getScalarSizeInBits()) | ||
return false; | ||
break; | ||
case Instruction::SExt: | ||
case Instruction::ZExt: | ||
// Source elements must be narrower | ||
if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits()) | ||
return false; | ||
break; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure if any of these checks are necessary - even the old assertion didn't contribute much, as Builder.CreateCast should assert that the cast is valid for us if we get to that stage. |
||
|
||
// Cost Check : | ||
// OldCost = bitlogic + 2*bitcasts | ||
// NewCost = bitlogic + bitcast | ||
auto *BinOp = cast<BinaryOperator>(&I); | ||
// OldCost = bitlogic + 2*casts | ||
// NewCost = bitlogic + cast | ||
InstructionCost OldCost = | ||
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) + | ||
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(), | ||
TTI::CastContextHint::None) + | ||
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(), | ||
TTI::CastContextHint::None); | ||
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy, | ||
TTI::CastContextHint::None) * | ||
2; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we hoist the separate getCastInstrCost calls here to avoid calling it again for the !hasOneUse cases below? Add the Instruction* args as well to help improve costs - we can't do it for the new cost calc, but it's still useful for old costs. We're missing the CostKind as well. |
||
|
||
InstructionCost NewCost = | ||
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) + | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing CostKind. |
||
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy, | ||
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy, | ||
TTI::CastContextHint::None); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing CostKind. |
||
|
||
LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I | ||
// Account for multi-use casts | ||
if (!LHSCast->hasOneUse()) | ||
NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy, | ||
TTI::CastContextHint::None); | ||
if (!RHSCast->hasOneUse()) | ||
NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy, | ||
TTI::CastContextHint::None); | ||
|
||
LLVM_DEBUG(dbgs() << "Found bitwise logic op of cast ops: " << I | ||
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost | ||
<< "\n"); | ||
|
||
|
@@ -862,8 +921,15 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) { | |
|
||
Worklist.pushValue(NewOp); | ||
|
||
// Bitcast the result back | ||
Value *Result = Builder.CreateBitCast(NewOp, I.getType()); | ||
// Create the cast operation | ||
Value *Result = Builder.CreateCast(CastOpcode, NewOp, I.getType()); | ||
|
||
// Preserve cast instruction flags | ||
if (auto *NewCast = dyn_cast<CastInst>(Result)) { | ||
NewCast->copyIRFlags(LHSCast); | ||
NewCast->andIRFlags(RHSCast); | ||
} | ||
|
||
replaceValue(I, *Result); | ||
return true; | ||
} | ||
|
@@ -1020,8 +1086,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { | |
InstructionCost OldCost = 2 * SplatCost + VectorOpCost; | ||
|
||
// Determine scalar opcode | ||
std::optional<unsigned> FunctionalOpcode = | ||
VPI.getFunctionalOpcode(); | ||
std::optional<unsigned> FunctionalOpcode = VPI.getFunctionalOpcode(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (style) don't clang-format lines unrelated to a patch |
||
std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt; | ||
if (!FunctionalOpcode) { | ||
ScalarIntrID = VPI.getFunctionalIntrinsicID(); | ||
|
@@ -1044,8 +1109,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { | |
(SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse()); | ||
InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats; | ||
|
||
LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI | ||
<< "\n"); | ||
LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI << "\n"); | ||
LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost | ||
<< ", Cost of scalarizing:" << NewCost << "\n"); | ||
|
||
|
@@ -2015,10 +2079,12 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) { | |
} | ||
|
||
unsigned NumOpElts = Op0Ty->getNumElements(); | ||
bool IsIdentity0 = ShuffleDstTy == Op0Ty && | ||
bool IsIdentity0 = | ||
ShuffleDstTy == Op0Ty && | ||
all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) && | ||
ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts); | ||
bool IsIdentity1 = ShuffleDstTy == Op1Ty && | ||
bool IsIdentity1 = | ||
ShuffleDstTy == Op1Ty && | ||
all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) && | ||
ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts); | ||
|
||
|
@@ -3773,7 +3839,7 @@ bool VectorCombine::run() { | |
case Instruction::And: | ||
case Instruction::Or: | ||
case Instruction::Xor: | ||
MadeChange |= foldBitOpOfBitcasts(I); | ||
MadeChange |= foldBitOpOfCastops(I); | ||
break; | ||
default: | ||
MadeChange |= shrinkType(I); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(style) don't clang-format lines unrelated to a patch