Skip to content

[VectorCombine] Generalize foldBitOpOfBitcasts to support more cast operations #148350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 100 additions & 34 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");
STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");

static cl::opt<bool> DisableVectorCombine(
"disable-vector-combine", cl::init(false), cl::Hidden,
cl::desc("Disable all vector combine transforms"));
static cl::opt<bool>
DisableVectorCombine("disable-vector-combine", cl::init(false), cl::Hidden,
cl::desc("Disable all vector combine transforms"));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) don't clang-format lines unrelated to a patch


static cl::opt<bool> DisableBinopExtractShuffle(
"disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
Expand Down Expand Up @@ -115,7 +115,7 @@ class VectorCombine {
bool foldInsExtFNeg(Instruction &I);
bool foldInsExtBinop(Instruction &I);
bool foldInsExtVectorToShuffle(Instruction &I);
bool foldBitOpOfBitcasts(Instruction &I);
bool foldBitOpOfCastops(Instruction &I);
bool foldBitcastShuffle(Instruction &I);
bool scalarizeOpOrCmp(Instruction &I);
bool scalarizeVPIntrinsic(Instruction &I);
Expand Down Expand Up @@ -808,46 +808,105 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
return true;
}

bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
// Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
Value *LHSSrc, *RHSSrc;
if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
m_BitCast(m_Value(RHSSrc)))))
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
// Supports: bitcast, trunc, sext, zext
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) Move the description comment up a few lines to outside the function def


// Check if this is a bitwise logic operation
auto *BinOp = dyn_cast<BinaryOperator>(&I);
if (!BinOp || !BinOp->isBitwiseLogicOp())
return false;

LLVM_DEBUG(dbgs() << "Found bitwise logic op: " << I << "\n");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably drop this?


// Get the cast instructions
auto *LHSCast = dyn_cast<CastInst>(BinOp->getOperand(0));
auto *RHSCast = dyn_cast<CastInst>(BinOp->getOperand(1));
if (!LHSCast || !RHSCast) {
LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
return false;
}

LLVM_DEBUG(dbgs() << " LHS cast: " << *LHSCast << "\n");
LLVM_DEBUG(dbgs() << " RHS cast: " << *RHSCast << "\n");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably drop this?


// Both casts must be the same type
Instruction::CastOps CastOpcode = LHSCast->getOpcode();
if (CastOpcode != RHSCast->getOpcode())
return false;

// Only handle supported cast operations
switch (CastOpcode) {
case Instruction::BitCast:
case Instruction::Trunc:
case Instruction::SExt:
case Instruction::ZExt:
break;
default:
return false;
}

Value *LHSSrc = LHSCast->getOperand(0);
Value *RHSSrc = RHSCast->getOperand(0);

// Source types must match
if (LHSSrc->getType() != RHSSrc->getType())
return false;
if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
return false;

// Only handle vector types
// Only handle vector types with integer elements
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
if (!SrcVecTy || !DstVecTy)
return false;

// Same total bit width
assert(SrcVecTy->getPrimitiveSizeInBits() ==
DstVecTy->getPrimitiveSizeInBits() &&
"Bitcast should preserve total bit width");
if (!SrcVecTy->getScalarType()->isIntegerTy() ||
!DstVecTy->getScalarType()->isIntegerTy())
return false;

// Validate cast operation constraints
switch (CastOpcode) {
case Instruction::BitCast:
// Total bit width must be preserved
if (SrcVecTy->getPrimitiveSizeInBits() !=
DstVecTy->getPrimitiveSizeInBits())
return false;
break;
case Instruction::Trunc:
// Source elements must be wider
if (SrcVecTy->getScalarSizeInBits() <= DstVecTy->getScalarSizeInBits())
return false;
break;
case Instruction::SExt:
case Instruction::ZExt:
// Source elements must be narrower
if (SrcVecTy->getScalarSizeInBits() >= DstVecTy->getScalarSizeInBits())
return false;
break;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if any of these checks are necessary - even the old assertion didn't contribute much, as Builder.CreateCast should assert that the cast is valid for us if we get to that stage.


// Cost Check :
// OldCost = bitlogic + 2*bitcasts
// NewCost = bitlogic + bitcast
auto *BinOp = cast<BinaryOperator>(&I);
// OldCost = bitlogic + 2*casts
// NewCost = bitlogic + cast
InstructionCost OldCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
TTI::CastContextHint::None) +
TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
TTI::CastContextHint::None);
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None) *
2;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we hoist the separate getCastInstrCost calls here to avoid calling them again for the !hasOneUse cases below? Add the Instruction* args as well to help improve costs - we can't do that for the new cost calc, but it's still useful for the old costs. We're missing the CostKind as well.


InstructionCost NewCost =
TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing CostKind.

TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing CostKind.


LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
// Account for multi-use casts
if (!LHSCast->hasOneUse())
NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);
if (!RHSCast->hasOneUse())
NewCost += TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
TTI::CastContextHint::None);

LLVM_DEBUG(dbgs() << "Found bitwise logic op of cast ops: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
<< "\n");

Expand All @@ -862,8 +921,15 @@ bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {

Worklist.pushValue(NewOp);

// Bitcast the result back
Value *Result = Builder.CreateBitCast(NewOp, I.getType());
// Create the cast operation
Value *Result = Builder.CreateCast(CastOpcode, NewOp, I.getType());

// Preserve cast instruction flags
if (auto *NewCast = dyn_cast<CastInst>(Result)) {
NewCast->copyIRFlags(LHSCast);
NewCast->andIRFlags(RHSCast);
}

replaceValue(I, *Result);
return true;
}
Expand Down Expand Up @@ -1020,8 +1086,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

// Determine scalar opcode
std::optional<unsigned> FunctionalOpcode =
VPI.getFunctionalOpcode();
std::optional<unsigned> FunctionalOpcode = VPI.getFunctionalOpcode();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) don't clang-format lines unrelated to a patch

std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
if (!FunctionalOpcode) {
ScalarIntrID = VPI.getFunctionalIntrinsicID();
Expand All @@ -1044,8 +1109,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
(SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;

LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
<< "\n");
LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI << "\n");
LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
<< ", Cost of scalarizing:" << NewCost << "\n");

Expand Down Expand Up @@ -2015,10 +2079,12 @@ bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
}

unsigned NumOpElts = Op0Ty->getNumElements();
bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
bool IsIdentity0 =
ShuffleDstTy == Op0Ty &&
all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
bool IsIdentity1 =
ShuffleDstTy == Op1Ty &&
all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);

Expand Down Expand Up @@ -3773,7 +3839,7 @@ bool VectorCombine::run() {
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
MadeChange |= foldBitOpOfBitcasts(I);
MadeChange |= foldBitOpOfCastops(I);
break;
default:
MadeChange |= shrinkType(I);
Expand Down
Loading
Loading