-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[CGP]: Optimize mul.overflow. #148343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CGP]: Optimize mul.overflow. #148343
Conversation
hassnaaHamdi
commented
Jul 12, 2025
- Detect cases where LHS & RHS values will not cause overflow (when the Hi parts are zero).
- Detect cases where either of LHS or RHS values could not cause overflow (when one of the Hi parts is zero).
|
@llvm/pr-subscribers-backend-powerpc @llvm/pr-subscribers-backend-arm Author: Hassnaa Hamdi (hassnaaHamdi) Changes
Patch is 676.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148343.diff 20 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 9bbb89e37865d..d9859ed246604 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
+ bool optimizeUMulWithOverflow(Instruction *I);
+ bool optimizeSMulWithOverflow(Instruction *I);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
@@ -2769,6 +2771,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
return optimizeGatherScatterInst(II, II->getArgOperand(0));
case Intrinsic::masked_scatter:
return optimizeGatherScatterInst(II, II->getArgOperand(1));
+ case Intrinsic::umul_with_overflow:
+ return optimizeUMulWithOverflow(II);
+ case Intrinsic::smul_with_overflow:
+ return optimizeSMulWithOverflow(II);
}
SmallVector<Value *, 2> PtrOps;
@@ -6386,6 +6392,573 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the umul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+ auto *ShrHiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ auto *HiRHS = BuilderEntryBB.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+ auto *ShrHiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ auto *HiLHS = BuilderEntryBB.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateZExt(LoLHS, Ty, "lo.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, ShrHiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ auto *ResLoEx =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(P0Lo, Ty, "rhs.res_lo.zext");
+ auto *ResMid =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(AddOResMid, Ty, "rhs.res_mid.zext");
+ auto *ResMidShl = BuilderNoOverflowRHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "rhs.res_mid.shl");
+ auto *FinalRes = BuilderNoOverflowRHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "rhs.res_lo.or.mid");
+ auto *IsOverflow = BuilderNoOverflowRHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "rhs.check.overflow");
+
+ StructType *STy = StructType::get(
+ I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowRHS = PoisonValue::get(STy);
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, FinalRes, {0});
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, IsOverflow, {1});
+ BuilderNoOverflowRHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow_no_lhs_only:
+
+ LoRHS = BuilderNoOverflowLHSonlyBB.CreateZExt(LoRHS, Ty, "lo.rhs");
+
+ // P0 = (LHS * LoRHS)
+ P0 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, LoRHS,
+ "mul.no.overflow.lhs.lorhs");
+ P0Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.lhs");
+ P0Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.lsr.lhs");
+ P0Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.lhs");
+
+ // P1 = (LHS * HiRHS)
+ P1 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, ShrHiRHS,
+ "mul.no.overflow.lhs.hirhs");
+ P1Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.lhs");
+ P1Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.lhs.lsr");
+ P1Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.lhs");
+
+ AddOverflow = BuilderNoOverflowLHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ AddOResMid = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 0,
+ "lhs.p0.p1.res");
+ Carry = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 1,
+ "lhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(Carry, LegalTy, "lhs.carry.zext");
+ ResHi = BuilderNoOverflowLHSonlyBB.CreateAdd(P1Hi, Carry, "lhs.p1.carry");
+
+ ResLoEx = BuilderNoOverflowLHSonlyBB.CreateZExt(P0Lo, Ty, "lhs.res_lo.zext");
+ ResMid =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(AddOResMid, Ty, "lhs.res_mid.zext");
+ ResMidShl = BuilderNoOverflowLHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "lhs.res_mid.shl");
+ FinalRes = BuilderNoOverflowLHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "lhs.res_lo.or.mid");
+ IsOverflow = BuilderNoOverflowLHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "lhs.check.overflow");
+
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowLHS = PoisonValue::get(STy);
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, FinalRes, {0});
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, IsOverflow, {1});
+
+ BuilderNoOverflowLHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow.no:
+ auto *Mul = BuilderNoOverflowBB.CreateMul(LHS, RHS, "mul.no.overflow");
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflow = PoisonValue::get(STy);
+ StructValNoOverflow =
+ BuilderNoOverflowBB.CreateInsertValue(StructValNoOverflow, Mul, {0});
+ StructValNoOverflow = BuilderNoOverflowBB.CreateInsertValue(
+ StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1});
+ BuilderNoOverflowBB.CreateBr(OverflowResBB);
+
+ // BB overflow.res:
+ auto *PHINode = BuilderOverflowResBB.CreatePHI(STy, 2);
+ PHINode->addIncoming(StructValNoOverflow, NoOverflowBB);
+ PHINode->addIncoming(StructValNoOverflowLHS, NoOverflowLHSonlyBB);
+ PHINode->addIncoming(StructValNoOverflowRHS, NoOverflowRHSonlyBB);
+
+ // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
+ // uses by PHINode.
+ I->replaceAllUsesWith(PHINode);
+
+ // BB overflow:
+ PHINode->addIncoming(I, OverflowBB);
+ I->removeFromParent();
+ I->insertInto(OverflowBB, OverflowBB->end());
+ IRBuilder<>(OverflowBB, OverflowBB->end()).CreateBr(OverflowResBB);
+
+ // return false to stop reprocessing the function.
+ return false;
+}
+
+// Rewrite the smul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ auto *SignLoRHS =
+ BuilderEntryBB.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ auto *HiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = BuilderEntryBB.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ auto *SignLoLHS =
+ BuilderEntryBB.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ auto *HiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = BuilderEntryBB.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is within 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ // check sign of RHS:
+ auto *IsNegRHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(RHS, "rhs.isneg");
+ auto *AbsRHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, RHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.rhs");
+ auto *AbsRHS = BuilderNoOverflowRHSonlyBB.CreateSelect(
+ IsNegRHS, AbsRHSIntr, RHS, "lo.abs.rhs.select");
+
+ // check sign of LHS:
+ auto *IsNegLHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(LHS, "lhs.isneg");
+ auto *AbsLHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, LHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.lhs");
+ auto *AbsLHS = BuilderNoOverflowRHSonlyBB.CreateSelect(IsNegLHS, AbsLHSIntr,
+ LHS, "abs.lhs.select");
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateAnd(
+ AbsLHS,
+ ConstantInt::get(Ty, APInt::getLowBitsSet(VTBitWidth, VTHalfBitWidth)),
+ "lo.abs.lhs");
+ HiLHS = BuilderNoOverflowRHSonlyBB.CreateLShr(AbsLHS, VTHalfBitWidth,
+ "hi.abs.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, HiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ // sign handling:
+ auto *IsNeg = BuilderNoOverflowRHSonlyBB.CreateXor(IsNegRHS, IsNegLHS); // i1
+ auto *Mask =
+ BuilderNoOverflowRHSonlyBB.CreateSExt(IsNeg, LegalTy, "rhs.sign.mask");
+ auto *Add_1 =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(IsNeg, LegalTy, "rhs.add.1");
+ auto *ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateXor(P0Lo, Mask, "rhs.res_lo.xor.mask");
+ ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(ResLo, Add_1, "rhs.res_lo.add.1");
+
+ Carry = BuilderNoOverflowRHSonlyBB.CreateCmp(ICmpInst::ICMP_ULT, ResLo, Add_1,
+ "rhs.check.res_lo.c...
[truncated]
|
|
@llvm/pr-subscribers-backend-aarch64 Author: Hassnaa Hamdi (hassnaaHamdi) Changes
Patch is 676.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148343.diff 20 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 9bbb89e37865d..d9859ed246604 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
+ bool optimizeUMulWithOverflow(Instruction *I);
+ bool optimizeSMulWithOverflow(Instruction *I);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
@@ -2769,6 +2771,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
return optimizeGatherScatterInst(II, II->getArgOperand(0));
case Intrinsic::masked_scatter:
return optimizeGatherScatterInst(II, II->getArgOperand(1));
+ case Intrinsic::umul_with_overflow:
+ return optimizeUMulWithOverflow(II);
+ case Intrinsic::smul_with_overflow:
+ return optimizeSMulWithOverflow(II);
}
SmallVector<Value *, 2> PtrOps;
@@ -6386,6 +6392,573 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the umul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+ auto *ShrHiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ auto *HiRHS = BuilderEntryBB.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+ auto *ShrHiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ auto *HiLHS = BuilderEntryBB.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateZExt(LoLHS, Ty, "lo.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, ShrHiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ auto *ResLoEx =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(P0Lo, Ty, "rhs.res_lo.zext");
+ auto *ResMid =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(AddOResMid, Ty, "rhs.res_mid.zext");
+ auto *ResMidShl = BuilderNoOverflowRHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "rhs.res_mid.shl");
+ auto *FinalRes = BuilderNoOverflowRHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "rhs.res_lo.or.mid");
+ auto *IsOverflow = BuilderNoOverflowRHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "rhs.check.overflow");
+
+ StructType *STy = StructType::get(
+ I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowRHS = PoisonValue::get(STy);
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, FinalRes, {0});
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, IsOverflow, {1});
+ BuilderNoOverflowRHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow_no_lhs_only:
+
+ LoRHS = BuilderNoOverflowLHSonlyBB.CreateZExt(LoRHS, Ty, "lo.rhs");
+
+ // P0 = (LHS * LoRHS)
+ P0 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, LoRHS,
+ "mul.no.overflow.lhs.lorhs");
+ P0Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.lhs");
+ P0Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.lsr.lhs");
+ P0Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.lhs");
+
+ // P1 = (LHS * HiRHS)
+ P1 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, ShrHiRHS,
+ "mul.no.overflow.lhs.hirhs");
+ P1Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.lhs");
+ P1Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.lhs.lsr");
+ P1Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.lhs");
+
+ AddOverflow = BuilderNoOverflowLHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ AddOResMid = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 0,
+ "lhs.p0.p1.res");
+ Carry = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 1,
+ "lhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(Carry, LegalTy, "lhs.carry.zext");
+ ResHi = BuilderNoOverflowLHSonlyBB.CreateAdd(P1Hi, Carry, "lhs.p1.carry");
+
+ ResLoEx = BuilderNoOverflowLHSonlyBB.CreateZExt(P0Lo, Ty, "lhs.res_lo.zext");
+ ResMid =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(AddOResMid, Ty, "lhs.res_mid.zext");
+ ResMidShl = BuilderNoOverflowLHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "lhs.res_mid.shl");
+ FinalRes = BuilderNoOverflowLHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "lhs.res_lo.or.mid");
+ IsOverflow = BuilderNoOverflowLHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "lhs.check.overflow");
+
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowLHS = PoisonValue::get(STy);
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, FinalRes, {0});
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, IsOverflow, {1});
+
+ BuilderNoOverflowLHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow.no:
+ auto *Mul = BuilderNoOverflowBB.CreateMul(LHS, RHS, "mul.no.overflow");
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflow = PoisonValue::get(STy);
+ StructValNoOverflow =
+ BuilderNoOverflowBB.CreateInsertValue(StructValNoOverflow, Mul, {0});
+ StructValNoOverflow = BuilderNoOverflowBB.CreateInsertValue(
+ StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1});
+ BuilderNoOverflowBB.CreateBr(OverflowResBB);
+
+ // BB overflow.res:
+ auto *PHINode = BuilderOverflowResBB.CreatePHI(STy, 2);
+ PHINode->addIncoming(StructValNoOverflow, NoOverflowBB);
+ PHINode->addIncoming(StructValNoOverflowLHS, NoOverflowLHSonlyBB);
+ PHINode->addIncoming(StructValNoOverflowRHS, NoOverflowRHSonlyBB);
+
+ // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
+ // uses by PHINode.
+ I->replaceAllUsesWith(PHINode);
+
+ // BB overflow:
+ PHINode->addIncoming(I, OverflowBB);
+ I->removeFromParent();
+ I->insertInto(OverflowBB, OverflowBB->end());
+ IRBuilder<>(OverflowBB, OverflowBB->end()).CreateBr(OverflowResBB);
+
+ // return false to stop reprocessing the function.
+ return false;
+}
+
+// Rewrite the smul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ auto *SignLoRHS =
+ BuilderEntryBB.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ auto *HiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = BuilderEntryBB.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ auto *SignLoLHS =
+ BuilderEntryBB.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ auto *HiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = BuilderEntryBB.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is within 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ // check sign of RHS:
+ auto *IsNegRHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(RHS, "rhs.isneg");
+ auto *AbsRHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, RHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.rhs");
+ auto *AbsRHS = BuilderNoOverflowRHSonlyBB.CreateSelect(
+ IsNegRHS, AbsRHSIntr, RHS, "lo.abs.rhs.select");
+
+ // check sign of LHS:
+ auto *IsNegLHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(LHS, "lhs.isneg");
+ auto *AbsLHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, LHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.lhs");
+ auto *AbsLHS = BuilderNoOverflowRHSonlyBB.CreateSelect(IsNegLHS, AbsLHSIntr,
+ LHS, "abs.lhs.select");
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateAnd(
+ AbsLHS,
+ ConstantInt::get(Ty, APInt::getLowBitsSet(VTBitWidth, VTHalfBitWidth)),
+ "lo.abs.lhs");
+ HiLHS = BuilderNoOverflowRHSonlyBB.CreateLShr(AbsLHS, VTHalfBitWidth,
+ "hi.abs.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, HiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ // sign handling:
+ auto *IsNeg = BuilderNoOverflowRHSonlyBB.CreateXor(IsNegRHS, IsNegLHS); // i1
+ auto *Mask =
+ BuilderNoOverflowRHSonlyBB.CreateSExt(IsNeg, LegalTy, "rhs.sign.mask");
+ auto *Add_1 =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(IsNeg, LegalTy, "rhs.add.1");
+ auto *ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateXor(P0Lo, Mask, "rhs.res_lo.xor.mask");
+ ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(ResLo, Add_1, "rhs.res_lo.add.1");
+
+ Carry = BuilderNoOverflowRHSonlyBB.CreateCmp(ICmpInst::ICMP_ULT, ResLo, Add_1,
+ "rhs.check.res_lo.c...
[truncated]
|
|
@llvm/pr-subscribers-backend-risc-v Author: Hassnaa Hamdi (hassnaaHamdi) Changes
Patch is 676.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148343.diff 20 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 9bbb89e37865d..d9859ed246604 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
+ bool optimizeUMulWithOverflow(Instruction *I);
+ bool optimizeSMulWithOverflow(Instruction *I);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
@@ -2769,6 +2771,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
return optimizeGatherScatterInst(II, II->getArgOperand(0));
case Intrinsic::masked_scatter:
return optimizeGatherScatterInst(II, II->getArgOperand(1));
+ case Intrinsic::umul_with_overflow:
+ return optimizeUMulWithOverflow(II);
+ case Intrinsic::smul_with_overflow:
+ return optimizeSMulWithOverflow(II);
}
SmallVector<Value *, 2> PtrOps;
@@ -6386,6 +6392,573 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}
+// Rewrite the umul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
+ auto *ShrHiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ auto *HiRHS = BuilderEntryBB.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
+ auto *ShrHiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ auto *HiLHS = BuilderEntryBB.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
+ ConstantInt::getNullValue(LegalTy));
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateZExt(LoLHS, Ty, "lo.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(RHS, ShrHiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ auto *ResLoEx =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(P0Lo, Ty, "rhs.res_lo.zext");
+ auto *ResMid =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(AddOResMid, Ty, "rhs.res_mid.zext");
+ auto *ResMidShl = BuilderNoOverflowRHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "rhs.res_mid.shl");
+ auto *FinalRes = BuilderNoOverflowRHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "rhs.res_lo.or.mid");
+ auto *IsOverflow = BuilderNoOverflowRHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "rhs.check.overflow");
+
+ StructType *STy = StructType::get(
+ I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowRHS = PoisonValue::get(STy);
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, FinalRes, {0});
+ StructValNoOverflowRHS = BuilderNoOverflowRHSonlyBB.CreateInsertValue(
+ StructValNoOverflowRHS, IsOverflow, {1});
+ BuilderNoOverflowRHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow_no_lhs_only:
+
+ LoRHS = BuilderNoOverflowLHSonlyBB.CreateZExt(LoRHS, Ty, "lo.rhs");
+
+ // P0 = (LHS * LoRHS)
+ P0 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, LoRHS,
+ "mul.no.overflow.lhs.lorhs");
+ P0Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.lhs");
+ P0Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.lsr.lhs");
+ P0Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.lhs");
+
+ // P1 = (LHS * HiRHS)
+ P1 = BuilderNoOverflowLHSonlyBB.CreateMul(LHS, ShrHiRHS,
+ "mul.no.overflow.lhs.hirhs");
+ P1Lo = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.lhs");
+ P1Hi =
+ BuilderNoOverflowLHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.lhs.lsr");
+ P1Hi = BuilderNoOverflowLHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.lhs");
+
+ AddOverflow = BuilderNoOverflowLHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ AddOResMid = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 0,
+ "lhs.p0.p1.res");
+ Carry = BuilderNoOverflowLHSonlyBB.CreateExtractValue(AddOverflow, 1,
+ "lhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(Carry, LegalTy, "lhs.carry.zext");
+ ResHi = BuilderNoOverflowLHSonlyBB.CreateAdd(P1Hi, Carry, "lhs.p1.carry");
+
+ ResLoEx = BuilderNoOverflowLHSonlyBB.CreateZExt(P0Lo, Ty, "lhs.res_lo.zext");
+ ResMid =
+ BuilderNoOverflowLHSonlyBB.CreateZExt(AddOResMid, Ty, "lhs.res_mid.zext");
+ ResMidShl = BuilderNoOverflowLHSonlyBB.CreateShl(ResMid, VTHalfBitWidth,
+ "lhs.res_mid.shl");
+ FinalRes = BuilderNoOverflowLHSonlyBB.CreateOr(ResLoEx, ResMidShl,
+ "lhs.res_lo.or.mid");
+ IsOverflow = BuilderNoOverflowLHSonlyBB.CreateICmp(
+ ICmpInst::ICMP_NE, ResHi, Constant::getNullValue(LegalTy),
+ "lhs.check.overflow");
+
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflowLHS = PoisonValue::get(STy);
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, FinalRes, {0});
+ StructValNoOverflowLHS = BuilderNoOverflowLHSonlyBB.CreateInsertValue(
+ StructValNoOverflowLHS, IsOverflow, {1});
+
+ BuilderNoOverflowLHSonlyBB.CreateBr(OverflowResBB);
+ //------------------------------------------------------------------------------
+
+ // BB overflow.no:
+ auto *Mul = BuilderNoOverflowBB.CreateMul(LHS, RHS, "mul.no.overflow");
+ STy = StructType::get(I->getContext(),
+ {Ty, IntegerType::getInt1Ty(I->getContext())});
+ Value *StructValNoOverflow = PoisonValue::get(STy);
+ StructValNoOverflow =
+ BuilderNoOverflowBB.CreateInsertValue(StructValNoOverflow, Mul, {0});
+ StructValNoOverflow = BuilderNoOverflowBB.CreateInsertValue(
+ StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1});
+ BuilderNoOverflowBB.CreateBr(OverflowResBB);
+
+ // BB overflow.res:
+ auto *PHINode = BuilderOverflowResBB.CreatePHI(STy, 2);
+ PHINode->addIncoming(StructValNoOverflow, NoOverflowBB);
+ PHINode->addIncoming(StructValNoOverflowLHS, NoOverflowLHSonlyBB);
+ PHINode->addIncoming(StructValNoOverflowRHS, NoOverflowRHSonlyBB);
+
+ // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
+ // uses by PHINode.
+ I->replaceAllUsesWith(PHINode);
+
+ // BB overflow:
+ PHINode->addIncoming(I, OverflowBB);
+ I->removeFromParent();
+ I->insertInto(OverflowBB, OverflowBB->end());
+ IRBuilder<>(OverflowBB, OverflowBB->end()).CreateBr(OverflowResBB);
+
+ // return false to stop reprocessing the function.
+ return false;
+}
+
+// Rewrite the smul_with_overflow intrinsic by checking if any/both of the
+// operands' value range is within the legal type. If so, we can optimize the
+// multiplication algorithm. This code is supposed to be written during the step
+// of type legalization, but given that we need to reconstruct the IR which is
+// not doable there, we do it here.
+bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
+ if (TLI->getTypeAction(
+ I->getContext(),
+ TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
+ TargetLowering::TypeExpandInteger)
+ return false;
+ Value *LHS = I->getOperand(0);
+ Value *RHS = I->getOperand(1);
+ auto *Ty = LHS->getType();
+ unsigned VTBitWidth = Ty->getScalarSizeInBits();
+ unsigned VTHalfBitWidth = VTBitWidth / 2;
+ auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
+
+ assert(
+ (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) ==
+ TargetLowering::TypeLegal) &&
+ "Expected the type to be legal for the target lowering");
+
+ I->getParent()->setName("overflow.res");
+ auto *OverflowResBB = I->getParent();
+ auto *OverflowoEntryBB =
+ I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
+ BasicBlock *OverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowRHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.rhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowLHSonlyBB = BasicBlock::Create(
+ I->getContext(), "overflow.no.lhs.only", I->getFunction(), OverflowResBB);
+ BasicBlock *NoOverflowBB = BasicBlock::Create(
+ I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
+ BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
+ I->getFunction(), OverflowResBB);
+ // new blocks should be:
+ // entry:
+ // lhs_lo ne lhs_hi ? overflow_yes_lhs, overflow_no_lhs
+
+ // overflow_yes_lhs:
+ // rhs_lo ne rhs_hi ? overflow : overflow_no_rhs_only
+
+ // overflow_no_lhs:
+ // rhs_lo ne rhs_hi ? overflow_no_lhs_only : overflow_no
+
+ // overflow_no_rhs_only:
+ // overflow_no_lhs_only:
+ // overflow_no:
+ // overflow:
+ // overflow.res:
+
+ IRBuilder<> BuilderEntryBB(OverflowoEntryBB->getTerminator());
+ IRBuilder<> BuilderOverflowLHSBB(OverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowLHSBB(NoOverflowLHSBB);
+ IRBuilder<> BuilderNoOverflowRHSonlyBB(NoOverflowRHSonlyBB);
+ IRBuilder<> BuilderNoOverflowLHSonlyBB(NoOverflowLHSonlyBB);
+ IRBuilder<> BuilderNoOverflowBB(NoOverflowBB);
+ IRBuilder<> BuilderOverflowResBB(OverflowResBB,
+ OverflowResBB->getFirstInsertionPt());
+
+ //------------------------------------------------------------------------------
+ // BB overflow.entry:
+ // get Lo and Hi of RHS & LHS:
+
+ auto *LoRHS = BuilderEntryBB.CreateTrunc(RHS, LegalTy, "lo.rhs");
+ auto *SignLoRHS =
+ BuilderEntryBB.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
+ auto *HiRHS = BuilderEntryBB.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
+ HiRHS = BuilderEntryBB.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
+
+ auto *LoLHS = BuilderEntryBB.CreateTrunc(LHS, LegalTy, "lo.lhs");
+ auto *SignLoLHS =
+ BuilderEntryBB.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
+ auto *HiLHS = BuilderEntryBB.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
+ HiLHS = BuilderEntryBB.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
+
+ auto *Cmp = BuilderEntryBB.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
+ BuilderEntryBB.CreateCondBr(Cmp, OverflowLHSBB, NoOverflowLHSBB);
+ OverflowoEntryBB->getTerminator()->eraseFromParent();
+
+ //------------------------------------------------------------------------------
+ // BB overflow_yes_lhs:
+ Cmp = BuilderOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderOverflowLHSBB.CreateCondBr(Cmp, OverflowBB, NoOverflowRHSonlyBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_lhs:
+ Cmp = BuilderNoOverflowLHSBB.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
+ BuilderNoOverflowLHSBB.CreateCondBr(Cmp, NoOverflowLHSonlyBB, NoOverflowBB);
+
+ //------------------------------------------------------------------------------
+ // BB overflow_no_rhs_only:
+ // RHS is within 64 value range, LHS is 128
+ // P0 = RHS * LoLHS
+ // P1 = RHS * HiLHS
+
+ // check sign of RHS:
+ auto *IsNegRHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(RHS, "rhs.isneg");
+ auto *AbsRHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, RHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.rhs");
+ auto *AbsRHS = BuilderNoOverflowRHSonlyBB.CreateSelect(
+ IsNegRHS, AbsRHSIntr, RHS, "lo.abs.rhs.select");
+
+ // check sign of LHS:
+ auto *IsNegLHS = BuilderNoOverflowRHSonlyBB.CreateIsNeg(LHS, "lhs.isneg");
+ auto *AbsLHSIntr = BuilderNoOverflowRHSonlyBB.CreateBinaryIntrinsic(
+ Intrinsic::abs, LHS, ConstantInt::getFalse(I->getContext()), {},
+ "abs.lhs");
+ auto *AbsLHS = BuilderNoOverflowRHSonlyBB.CreateSelect(IsNegLHS, AbsLHSIntr,
+ LHS, "abs.lhs.select");
+ LoLHS = BuilderNoOverflowRHSonlyBB.CreateAnd(
+ AbsLHS,
+ ConstantInt::get(Ty, APInt::getLowBitsSet(VTBitWidth, VTHalfBitWidth)),
+ "lo.abs.lhs");
+ HiLHS = BuilderNoOverflowRHSonlyBB.CreateLShr(AbsLHS, VTHalfBitWidth,
+ "hi.abs.lhs");
+
+ // P0 = (RHS * LoLHS)
+ auto *P0 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, LoLHS,
+ "mul.no.overflow.rhs.lolhs");
+ auto *P0Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0, LegalTy, "p0.lo.rhs");
+ auto *P0Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P0, VTHalfBitWidth, "p0.rhs.lsr");
+ P0Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P0Hi, LegalTy, "p0.hi.rhs");
+
+ // P1 = (RHS * HiLHS)
+ auto *P1 = BuilderNoOverflowRHSonlyBB.CreateMul(AbsRHS, HiLHS,
+ "mul.no.overflow.rhs.hilhs");
+ auto *P1Lo = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1, LegalTy, "p1.lo.rhs");
+ auto *P1Hi =
+ BuilderNoOverflowRHSonlyBB.CreateLShr(P1, VTHalfBitWidth, "p1.rhs.lsr");
+ P1Hi = BuilderNoOverflowRHSonlyBB.CreateTrunc(P1Hi, LegalTy, "p1.hi.rhs");
+
+ auto *AddOverflow = BuilderNoOverflowRHSonlyBB.CreateIntrinsic(
+ Intrinsic::uadd_with_overflow, LegalTy, {P0Hi, P1Lo});
+ auto *AddOResMid = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 0, "rhs.p0.p1.res");
+ auto *Carry = BuilderNoOverflowRHSonlyBB.CreateExtractValue(
+ AddOverflow, 1, "rhs.p0.p1.carry");
+ Carry =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(Carry, LegalTy, "rhs.carry.zext");
+ auto *ResHi =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(P1Hi, Carry, "rhs.p1.carry");
+
+ // sign handling:
+ auto *IsNeg = BuilderNoOverflowRHSonlyBB.CreateXor(IsNegRHS, IsNegLHS); // i1
+ auto *Mask =
+ BuilderNoOverflowRHSonlyBB.CreateSExt(IsNeg, LegalTy, "rhs.sign.mask");
+ auto *Add_1 =
+ BuilderNoOverflowRHSonlyBB.CreateZExt(IsNeg, LegalTy, "rhs.add.1");
+ auto *ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateXor(P0Lo, Mask, "rhs.res_lo.xor.mask");
+ ResLo =
+ BuilderNoOverflowRHSonlyBB.CreateAdd(ResLo, Add_1, "rhs.res_lo.add.1");
+
+ Carry = BuilderNoOverflowRHSonlyBB.CreateCmp(ICmpInst::ICMP_ULT, ResLo, Add_1,
+ "rhs.check.res_lo.c...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the 2 instruction of madd, there should be extra patch to fold them away as x3 and x1 are expected to be 0.
nikic
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I question the value of this optimization, especially as implemented right now. This might (and I'm really not convinced) make more sense if you only specialize for the case of both high halves being zero without covering mixed cases, which blows up code size a lot.
|
I don't think that's a good idea as this is currently implemented and I'd like to see some benchmark results. In our database system, we make heavy use of 128-bit mul-with-overflow. Having spent some time optimizing this specific case:
For x86-64, the full out-of-line function is hand-optimized assembly code, which performs slightly better than LLVM's expansion (e.g., uses one less register, exploits some x86-specific flags tricks, and has optimized scheduling for some recent uArches) (also note: GCC's expansion is (was?) horrible with lots of data-dependent branches). For AArch64, we use the compiler-rt function, we haven't felt the need to look closer into this so far (most of our benchmarks target x86-64 :-) ). |
|
FWIW, the reason why LLVM expands this inline and does not use the compiler-rt builtin is that we, unfortunately, have to cater to the lowest common denominator, which is libgcc, which does not support the __muloti4 builtin :( |
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran an experiment a little while ago, about 6 months ago now. It was to try and measure what the average cost of a divide should be, given the distributions found in real code. It used dynamorio to interrupt the program whenever a divide was found and print the numerator and denominator. In doing that that more that 20%, across all of the llvm-test-suite + 3 x spec + some other benchmarks, about 22% of all divides were 1 divided by 1.
So I can fully imagine that i128 smulo are often biased towards low values and it is quite a bit more efficient to check for the upper halfs being zero and jump straight to the "no overflow". GCC apparently considered it useful enough to implement and we see cases where it is performing better from using the expanded form. It sounds like we should maybe make it opt-in by the target for each type as it might be better/worse depending on the relative performance (like whether there is a mulo instruction that sets flags or requires a libcall anyway). The asymmetric case that GCC runs can also probably be removed to cut down on the codesize, and it looks like the signed case is more beneficial than the unsigned case. It might be worth focussing there to begin with.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just have one Builder and set the insertion point with SetInsertPoint.
|
Hi, |
aengelke
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The generated code could be better. I still would like to see some actual (real) benchmark results as well as some data on the performance impact of only-slow-case/mixed-with-many-branch-misses.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be here, x1 and x3 are known to be zero at this point. The bic below also shouldn't be executed on the fast path. The fast path really should only have:
umulh x1, x0, x2
mul x0, x0, x2
mov w2, wzr // or mov w2, 0
Likewise on many other occasions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
About the bic, it's in the overflow.res BB, not the fast path, correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overflow.res is executed in both paths and therefore also on the fast path -- for which the instruction is unnecessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to fold these into a single branch, maybe by reducing the number of covered cases a bit further (e.g., only positive numbers)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? is it really worth reducing the cases ? I think this should be postponed until I test its performance? And also the overall performance as you asked earlier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Consider the case where the first operand is sometimes small and sometimes large but the second operand is always large. With this implementation, the first branch is likely mispredicted, although in the end it's always the slow path that gets executed.
- Many out-of-order CPUs (e.g., recent Apple CPUs) can't decode across branches in the same cycle. Having just two instructions before a conditional branch reduces the throughput of the front-end.
- I'm generally worried about introducing data-dependent branches that were not present in the original code, these can hurt performance when they are mispredicted frequently. We already make it hard for users to get branchless code if they want to (e.g., if they know that the condition is unpredictable) and we shouldn't add branches unless there's an extremely good reason.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I see. Thanks for the clarification :))
What about the new proposed solution instead of reducing the covered cases ?
64c9e33 to
1a234ec
Compare
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't create PHIs of aggregate types, create two separate PHIs instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a target hook where the target can opt-in, per Type. Something like shouldFormOverflowOp.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? There is already a check testing if the target requires expanding the type, and if the new type is legal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't have something in CodeGenPrepare that says "if (isAArch64())". It doesn't scale or create clean interfaces when there are many targets.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the advantages of this method is that when we hit the fast-path we know that is will never overflow and we can jump straight to the "true" block without the extra conditional instruction. Can we implement something like that here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not guarantee that there always will be an overflow check, that's why here I just replace all uses with the new instruction.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there is no overflow check then we can use a simple multiply. The overflow could be used by something other than a branch, but if it is it should be useful to jump-thread them directly as it drops a conditional branch from the fast path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the commented out code you have is better. It should be something like ahi != alo >> 64 || bhi != blo >> 63, and ideally turn into some ccmps:
asr x8, x0, #63
cmp x3, x2, asr #63
ccmp x1, x8, #0, eq
b.eq ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
During SDAG, the Br instruction gets lowered to sequence of branches when the conditions are Or'd or And'd.
So I think the current code is good ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it producing this at the moment?
; CHECK-NEXT: eor x8, x3, x2, asr #63
; CHECK-NEXT: eor x9, x1, x0, asr #63
; CHECK-NEXT: orr x8, x9, x8
; CHECK-NEXT: cmp x8, #1
; CHECK-NEXT: b.ne .LBB5_2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now it's producing:
; CHECK-NEXT: eor x8, x3, x2, asr #63
; CHECK-NEXT: eor x9, x1, x0, asr #63
; CHECK-NEXT: orr x8, x9, x8
; CHECK-NEXT: cbz x8, .LBB5_3
|
Hi @aengelke |
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't have something in CodeGenPrepare that says "if (isAArch64())". It doesn't scale or create clean interfaces when there are many targets.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we try and change this to detect if the instruction is already in a block guarded by whether the top bits are zero? That would help catch other cases too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be an unconditional branch, w8 is known to be zero here. Maybe CGP is too late and we want another InstCombine/SimplifyCFG after this transformation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
|
Re benchmarks: I can look whether we have something usable, but if not, micro-benchmark numbers showing the regression/improvement for different distributions (and predictability) of small/large numbers should be ok. |
- Detect cases where LHS & RHS values will not cause overflow (when the Hi parts are zero). - Detect cases where either of LHS or RHS values could not cause overflow (when one of the Hi parts is zero).
- Enable optimization for AArch64 only. - Optimize only when both LHS, RHS value range are within legal type. - Use a single Builder Change-Id: I11d674440364594e4bca839495036975cd403aa5
Change-Id: Ib0619bde982a8d2a5eba889e12c9412705afebee
use multiple PHIs of scalar types instead of aggregate type. Change-Id: Ie6bc78eda41f454e9edeea7b3bf2c21da1a89693
For the simple case where IR just checks the overflow, skip the check when we're sure that there is no overflow.
Change-Id: I4afe203a6cedb0134812143e7211ca9e80ce6687
64046b1 to
0610edf
Compare
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
stop reprocessing, and instead use the current list of 'InsertedInsts' to keep track of processed instructions. The work of detecting the pattern needs to detect 4 patterns for cases when one/both of the parameters are constant. So the new solution is simpler and more secure.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| // not doable there, we do it here. | ||
| bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned, | ||
| ModifyDT &ModifiedDT) { | ||
| if (!TLI->shouldOptimizeMulOverflowIntrinsic()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would pass the type into the shouldOptimizeMulOverflowIntrinsic, and have the target return whether it should expand for the given type. Some of the legality checks bellow could then be removed from here.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| return false; | ||
|
|
||
| // Make sure that the I->getType() is a struct type with two elements. | ||
| if (!I->getType()->isStructTy() || I->getType()->getStructNumElements() != 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should always be true?
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| // should be: | ||
| // entry: | ||
| // if signed: | ||
| // (lhs_lo ^ lhs_hi) || (rhs_lo ^ rhs_hi) ? overflow, overflow_no |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does there needs to be a >> BW-1 on the lhs_lo/rhs_lo?
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| // Look for the pattern in the users of I, and make sure that all the users | ||
| // are either part of the pattern or NOT in the same BB as I. | ||
| for (User *U : I->users()) { | ||
| if (auto *Instr = dyn_cast<Instruction>(U); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be worth grabbing the 2 users of the MulWithOverflow (they should be the two extracts, if there are more bail out). Then we can replace the first with the result of the mul (through the new phi), and the overflow with the second part. That way we do not need to create the InsertValue.
The case where the overflow bit is used in a branch needs to make sure there are no other instruction between the mul and the branch (other than debug instructions). Otherwise they might need to be duplicated.
A User of an instruction will always be an Instruction.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| ExtUser->getIndices()[0] == 1) { | ||
| if (auto *Br = dyn_cast<BranchInst>(*ExtUser->user_begin())) { | ||
| DetectNoOverflowBrBB = Br->getSuccessor(1) /*if.end*/; | ||
| continue; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
continue->break, if we found what we are interested in?
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| // If we come here, it means that either the pattern doesn't exist or | ||
| // there are multiple users in the same BB | ||
| DetectNoOverflowBrBB = nullptr; | ||
| break; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is needed, and we can look through all the users. There should only be 2 but we might need to look at the second.
| // for the given \p VT. | ||
| bool shouldOptimizeMulOverflowIntrinsic(LLVMContext &Context, | ||
| EVT VT) const override { | ||
| return getTypeAction(Context, VT) == TypeExpandInteger; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this just mean VT == MVT::i64? Does it apply to vector types? It might be simpler to be explicit.
|
|
||
| // Return true if the target wants to optimize the mul overflow intrinsic | ||
| // for the given \p VT. | ||
| virtual bool shouldOptimizeMulOverflowIntrinsic(LLVMContext &Context, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we need a little more descriptive of a name - something like shouldOptimizeMulOverflowIntrinsicWithHighHalf maybe.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
|
|
||
| // Skip the optimization if the type with HalfBitWidth is not legal for the | ||
| // target. | ||
| if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) != |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be removed now?
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| return false; | ||
|
|
||
| // Check the pattern we are interested in where there are maximum 2 uses | ||
| // of the intrinsic which are the extracts instructions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extract instructions
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
|
|
||
| // Keep track of the instruction to stop reoptimizing it again. | ||
| InsertedInsts.insert(I); | ||
| // ---------------------------- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this ---- line be removed?
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| } | ||
| } | ||
| if (NoOverflowBrBB) { | ||
| // Duplicate instructions from I's BB to the NoOverflowBB: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a limit on the number of instructions we should duplicate here? It could be quite a few. There are certain cannotDuplicate instructions to watch out for too.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| } | ||
| } | ||
| } | ||
| if (!PN) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to check that PN was the MulExtract or something? We could be checking multiple unrelated phis above.
And could OverflowExtract have other uses?
|
If you wanted to try and keep the first patch simpler and add the instruction duplication in a follow-on, that might help with the first patch. Up to you what you think. |
…tion as it's common for using the intrinsic Change-Id: I84ad52b4fd326499ce498e02d6ac0b326662ecd9
for future work as it's getting complex. There could be fast path optimization where we can jump directly to the NoOverflow branch if exists, but this has a lot of implications. It would be better to move forward with the basic optimization first, and in the future we improve it.
davemgreen
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - I agree it sounds good to keep this first version simple. The other path was a lot more complex than I had expected.
LGTM
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| // TODO: This optimization can be further improved but it will get more complex, | ||
| // so we leave it for future work. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps "TODO: This optimization can be further improved to optimize branching on overflow." It helps explain what would change.
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| Type *Ty = LHS->getType(); | ||
| unsigned VTHalfBitWidth = Ty->getScalarSizeInBits() / 2; | ||
| IntegerType *LegalTy = | ||
| IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps Type *LegalTy = Ty->getWithNewBitWidth(VTHalfBitWidth);
aengelke
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with nits
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| // of the intrinsic which are the extract instructions. | ||
| static bool matchOverflowPattern(Instruction *&I, ExtractValueInst *&MulExtract, | ||
| ExtractValueInst *&OverflowExtract) { | ||
| if (I->getNumUses() > 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hasNUsesOrMore
llvm/lib/CodeGen/CodeGenPrepare.cpp
Outdated
| Type *LegalTy = Ty->getWithNewBitWidth(VTHalfBitWidth); | ||
|
|
||
| // New BBs: | ||
| std::string OriginalBlockName = I->getParent()->getName().str(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to duplicate the string if it's only used in for OverflowEntryBB, created right below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have to cache the name because if I renamed the new block by the original name, it will be numbered -entry1-, because the name is already used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
takeName?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't know this. I use it, thanks.