Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3470,6 +3470,13 @@ class LLVM_ABI TargetLoweringBase {
return MathUsed && (VT.isSimple() || !isOperationExpand(Opcode, VT));
}

// Return true if the target wants the umul/smul.with.overflow intrinsic to be
// optimized for the given \p VT. CodeGenPrepare queries this hook with the
// value type of the intrinsic's multiplication result and leaves the call
// untouched when it returns false.
// \p Context is the LLVMContext the query is made in.
// The default implementation opts out for every type.
virtual bool shouldOptimizeMulOverflowWithZeroHighBits(LLVMContext &Context,
EVT VT) const {
return false;
}

// Return true if it is profitable to use a scalar input to a BUILD_VECTOR
// even if the vector itself has multiple uses.
virtual bool aggressivelyPreferBuildVectorSources(EVT VecVT) const {
Expand Down
182 changes: 182 additions & 0 deletions llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
// Try to rewrite the umul/smul.with.overflow intrinsic call \p I (signedness
// given by \p IsSigned). Returns true if the IR was changed, in which case
// \p ModifiedDT is updated as well.
bool optimizeMulWithOverflow(Instruction *I, bool IsSigned,
ModifyDT &ModifiedDT);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
Expand Down Expand Up @@ -2778,6 +2780,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
}
}
return false;
case Intrinsic::umul_with_overflow:
return optimizeMulWithOverflow(II, /*IsSigned=*/false, ModifiedDT);
case Intrinsic::smul_with_overflow:
return optimizeMulWithOverflow(II, /*IsSigned=*/true, ModifiedDT);
}

SmallVector<Value *, 2> PtrOps;
Expand Down Expand Up @@ -6389,6 +6395,182 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}

// This is a helper for CodeGenPrepare::optimizeMulWithOverflow.
// Match the pattern we are interested in: the intrinsic \p I has at most two
// users, each of them a single-index extractvalue, and at most one of them
// extracts index 0 (the product) and at most one extracts index 1 (the
// overflow flag).
//
// \p MulExtract and \p OverflowExtract are reset to nullptr on entry and, on
// success, point at the extract of index 0 / index 1 respectively (each stays
// nullptr if that result is not extracted). Their values are unspecified when
// the function returns false.
//
// \returns true if every user of \p I fits the pattern.
static bool matchOverflowPattern(Instruction *&I, ExtractValueInst *&MulExtract,
                                 ExtractValueInst *&OverflowExtract) {
  // Do not rely on the caller having cleared the out-parameters: the
  // duplicate detection below depends on them starting out null.
  MulExtract = nullptr;
  OverflowExtract = nullptr;

  // Bail out if it's more than 2 users:
  if (I->hasNUsesOrMore(3))
    return false;

  for (User *U : I->users()) {
    auto *Extract = dyn_cast<ExtractValueInst>(U);
    if (!Extract || Extract->getNumIndices() != 1)
      return false;

    // Reject a second extract of an index we have already seen. The caller
    // rewrites only one extract per index and then moves the intrinsic into a
    // different basic block; a leftover duplicate would keep using the moved
    // intrinsic from a block it no longer dominates.
    unsigned Index = Extract->getIndices()[0];
    if (Index == 0 && !MulExtract)
      MulExtract = Extract;
    else if (Index == 1 && !OverflowExtract)
      OverflowExtract = Extract;
    else
      return false;
  }
  return true;
}

// Rewrite the mul_with_overflow intrinsic \p I by checking at run time whether
// both operands' values fit in the half-width type (LegalTy below). If they
// do, the product is computed with a plain 'mul' of the re-extended low halves
// and the overflow flag is the constant false; otherwise the original
// intrinsic is executed. Conceptually this belongs in type legalization, but
// it requires restructuring the IR (new basic blocks), which is not doable
// there, so we do it here.
//
// The IR after the optimization will look like:
// <original block name>: (instructions that preceded the intrinsic)
// if signed:
// ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
// overflow.no
// (BW here is the half bit width, i.e. the shift yields lhs_lo's sign)
// else:
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow.no
// overflow.no: mul of the sign/zero-extended low halves, br overflow.res
// overflow: the original intrinsic and its extractvalues, br overflow.res
// overflow.res: PHIs merging both paths, then the rest of the original block
//
// \param I the umul/smul.with.overflow intrinsic call to rewrite.
// \param IsSigned true for smul.with.overflow, false for umul.with.overflow.
// \param ModifiedDT set to ModifyDT::ModifyBBDT when the rewrite is applied.
// \returns true if optimization was applied
// TODO: This optimization can be further improved to optimize branching on
// overflow where the 'overflow.no' BB can branch directly to the false
// successor of overflow, but that would add additional complexity so we leave
// it for future work.
bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
ModifyDT &ModifiedDT) {
// Check if target supports this optimization. The hook is queried with the
// value type of the first member of the intrinsic's result struct.
if (!TLI->shouldOptimizeMulOverflowWithZeroHighBits(
I->getContext(),
TLI->getValueType(*DL, I->getType()->getContainedType(0))))
return false;

// Only handle intrinsics whose users are all extractvalue instructions (see
// matchOverflowPattern). Either extract may be absent and stay nullptr.
ExtractValueInst *MulExtract = nullptr, *OverflowExtract = nullptr;
if (!matchOverflowPattern(I, MulExtract, OverflowExtract))
return false;

// Keep track of the instruction to stop reoptimizing it again.
InsertedInsts.insert(I);

Value *LHS = I->getOperand(0);
Value *RHS = I->getOperand(1);
Type *Ty = LHS->getType();
// LegalTy is the half-width counterpart of the operand type.
unsigned VTHalfBitWidth = Ty->getScalarSizeInBits() / 2;
Type *LegalTy = Ty->getWithNewBitWidth(VTHalfBitWidth);

// New BBs:
// Split right before the intrinsic. The block holding everything that
// preceded it becomes the entry of the new control flow and inherits the
// original block's name; the intrinsic's own block is renamed further below.
BasicBlock *OverflowEntryBB =
I->getParent()->splitBasicBlock(I, "", /*Before*/ true);
OverflowEntryBB->takeName(I->getParent());
// Keep the 'br' instruction that is generated as a result of the split to be
// erased/replaced later.
Instruction *OldTerminator = OverflowEntryBB->getTerminator();
BasicBlock *NoOverflowBB =
BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction());
NoOverflowBB->moveAfter(OverflowEntryBB);
BasicBlock *OverflowBB =
BasicBlock::Create(I->getContext(), "overflow", I->getFunction());
OverflowBB->moveAfter(NoOverflowBB);

// BB overflow.entry:
// NOTE(review): the builder is constructed on the block rather than on
// OldTerminator, so presumably the new instructions are appended after the
// old 'br' until it is erased below - confirm.
IRBuilder<> Builder(OverflowEntryBB);
// Extract low and high halves of LHS:
Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");

// Extract low and high halves of RHS:
Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");

// IsAnyBitTrue is true when at least one operand does not fit in LegalTy,
// i.e. when the original intrinsic has to be executed.
Value *IsAnyBitTrue;
if (IsSigned) {
// Signed: an operand fits iff its high half equals the sign of its low half
// replicated across all bits (ashr by half-width - 1), so the xor of the two
// is zero exactly in that case.
Value *SignLoLHS =
Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
Value *SignLoRHS =
Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or,
ConstantInt::getNullValue(Or->getType()));
} else {
// Unsigned: an operand fits iff its high half is zero.
Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
ConstantInt::getNullValue(LegalTy));
Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
ConstantInt::getNullValue(LegalTy));
IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
}
// NOTE(review): a conditional branch needs a scalar i1 condition; this
// presumably relies on the target hook rejecting vector types - confirm.
Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);

// BB overflow.no:
// Both operands fit in LegalTy: re-extend the low halves to the full type
// (matching the signedness of the intrinsic) and multiply them directly.
Builder.SetInsertPoint(NoOverflowBB);
Value *ExtLoLHS, *ExtLoRHS;
if (IsSigned) {
ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
} else {
ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
}

Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no");

// Create the 'overflow.res' BB to merge the results of
// the two paths:
// (this is the block that currently still holds the intrinsic, renamed.)
BasicBlock *OverflowResBB = I->getParent();
OverflowResBB->setName("overflow.res");

// BB overflow.no: jump to overflow.res BB
Builder.CreateBr(OverflowResBB);
// Now we don't need the old terminator in overflow.entry BB, erase it:
OldTerminator->eraseFromParent();

// BB overflow.res:
Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
// Create PHI nodes to merge results from no.overflow BB and overflow BB to
// replace the extract instructions.
PHINode *OverflowResPHI = Builder.CreatePHI(Ty, 2),
*OverflowFlagPHI =
Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);

// Add the incoming values from no.overflow BB and later from overflow BB.
// On the overflow.no path the overflow flag is the constant false.
OverflowResPHI->addIncoming(Mul, NoOverflowBB);
OverflowFlagPHI->addIncoming(ConstantInt::getFalse(I->getContext()),
NoOverflowBB);

// Replace all users of MulExtract and OverflowExtract to use the PHI nodes.
// Either extract may be missing if that result of the intrinsic was unused.
if (MulExtract) {
MulExtract->replaceAllUsesWith(OverflowResPHI);
MulExtract->eraseFromParent();
}
if (OverflowExtract) {
OverflowExtract->replaceAllUsesWith(OverflowFlagPHI);
OverflowExtract->eraseFromParent();
}

// Remove the intrinsic from parent (overflow.res BB) as it will be part of
// overflow BB
I->removeFromParent();
// BB overflow:
// The original intrinsic now runs only on the path where an operand does not
// fit in LegalTy; fresh extracts feed the PHIs in overflow.res.
I->insertInto(OverflowBB, OverflowBB->end());
Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
Value *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
Value *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
Builder.CreateBr(OverflowResBB);

// Add The Extracted values to the PHINodes in the overflow.res BB.
OverflowResPHI->addIncoming(MulOverflow, OverflowBB);
OverflowFlagPHI->addIncoming(OverflowFlag, OverflowBB);

// Basic blocks were added and the CFG changed; report that to the caller.
ModifiedDT = ModifyDT::ModifyBBDT;
return true;
}

/// If there are any memory operands, use OptimizeMemoryInst to sink their
/// address computing into the block when possible / profitable.
bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) {
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18542,6 +18542,15 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
return (Index == 0 || Index == ResVT.getVectorMinNumElements());
}

// Opt into the mul-with-overflow rewrite only for integer types that type
// legalization would expand and whose half-width integer type is legal.
bool AArch64TargetLowering::shouldOptimizeMulOverflowWithZeroHighBits(
    LLVMContext &Context, EVT VT) const {
  const bool NeedsExpansion = getTypeAction(Context, VT) == TypeExpandInteger;
  if (NeedsExpansion) {
    // The half-width type is only computed for types that get expanded.
    const EVT HalfVT = EVT::getIntegerVT(Context, VT.getSizeInBits() / 2);
    return getTypeAction(Context, HalfVT) == TargetLowering::TypeLegal;
  }
  return false;
}

/// Turn vector tests of the signbit in the form of:
/// xor (sra X, elt_size(X)-1), -1
/// into:
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ class AArch64TargetLowering : public TargetLowering {
return TargetLowering::shouldFormOverflowOp(Opcode, VT, true);
}

// Return true if the target wants to optimize the mul overflow intrinsic
// for the given \p VT. The AArch64 implementation answers true only when
// \p VT is expanded by integer type legalization and the integer type of
// half its size is legal.
bool shouldOptimizeMulOverflowWithZeroHighBits(LLVMContext &Context,
EVT VT) const override;

Value *emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, Value *Addr,
AtomicOrdering Ord) const override;
Value *emitStoreConditional(IRBuilderBase &Builder, Value *Val, Value *Addr,
Expand Down
Loading
Loading