Skip to content

LV: Expand llvm.histogram intrinsic to support umax, umin, and uadd.sat operations #127399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20294,6 +20294,9 @@ More update operation types may be added in the future.

declare void @llvm.experimental.vector.histogram.add.v8p0.i32(<8 x ptr> %ptrs, i32 %inc, <8 x i1> %mask)
declare void @llvm.experimental.vector.histogram.add.nxv2p0.i64(<vscale x 2 x ptr> %ptrs, i64 %inc, <vscale x 2 x i1> %mask)
declare void @llvm.experimental.vector.histogram.uadd.sat.v8p0.i32(<8 x ptr> %ptrs, i32 %inc, <8 x i1> %mask)
declare void @llvm.experimental.vector.histogram.umax.v8p0.i32(<8 x ptr> %ptrs, i32 %val, <8 x i1> %mask)
declare void @llvm.experimental.vector.histogram.umin.v8p0.i32(<8 x ptr> %ptrs, i32 %val, <8 x i1> %mask)

Arguments:
""""""""""
Expand Down
18 changes: 18 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1947,6 +1947,24 @@ def int_experimental_vector_histogram_add : DefaultAttrsIntrinsic<[],
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
[ IntrArgMemOnly ]>;

def int_experimental_vector_histogram_uadd_sat : DefaultAttrsIntrinsic<[],
[ llvm_anyvector_ty, // Vector of pointers
llvm_anyint_ty, // Increment
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
[ IntrArgMemOnly ]>;

def int_experimental_vector_histogram_umin : DefaultAttrsIntrinsic<[],
[ llvm_anyvector_ty, // Vector of pointers
llvm_anyint_ty, // Update value
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
[ IntrArgMemOnly ]>;

def int_experimental_vector_histogram_umax : DefaultAttrsIntrinsic<[],
[ llvm_anyvector_ty, // Vector of pointers
llvm_anyint_ty, // Update value
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], // Mask
[ IntrArgMemOnly ]>;

// Experimental match
def int_experimental_vector_match : DefaultAttrsIntrinsic<
[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ],
Expand Down
25 changes: 23 additions & 2 deletions llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,8 +997,26 @@ static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
Builder.SetInsertPoint(CondBlock->getTerminator());
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
Value *Add = Builder.CreateAdd(Load, Inc);
Builder.CreateStore(Add, Ptr);
Value *UpdateOp;
switch (cast<IntrinsicInst>(CI)->getIntrinsicID()) {
case Intrinsic::experimental_vector_histogram_add:
UpdateOp = Builder.CreateAdd(Load, Inc);
break;
case Intrinsic::experimental_vector_histogram_uadd_sat:
UpdateOp =
Builder.CreateIntrinsic(Intrinsic::uadd_sat, {EltTy}, {Load, Inc});
break;
case Intrinsic::experimental_vector_histogram_umin:
UpdateOp = Builder.CreateIntrinsic(Intrinsic::umin, {EltTy}, {Load, Inc});
break;
case Intrinsic::experimental_vector_histogram_umax:
UpdateOp = Builder.CreateIntrinsic(Intrinsic::umax, {EltTy}, {Load, Inc});
break;

default:
llvm_unreachable("Unexpected histogram intrinsic");
}
Builder.CreateStore(UpdateOp, Ptr);

// Create "else" block, fill it in the next iteration
BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
Expand Down Expand Up @@ -1089,6 +1107,9 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
default:
break;
case Intrinsic::experimental_vector_histogram_add:
case Intrinsic::experimental_vector_histogram_uadd_sat:
case Intrinsic::experimental_vector_histogram_umin:
case Intrinsic::experimental_vector_histogram_umax:
if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
CI->getArgOperand(1)->getType()))
return false;
Expand Down
33 changes: 21 additions & 12 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,34 +1072,43 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {

/// Find histogram operations that match high-level code in loops:
/// \code
/// buckets[indices[i]]+=step;
/// buckets[indices[i]] = UpdateOperator(buckets[indices[i]], Val);
/// \endcode
/// When updateOperator can be add, sub, add.sat, umin, umax.
///
/// It matches a pattern starting from \p HSt, which Stores to the 'buckets'
/// array the computed histogram. It uses a BinOp to sum all counts, storing
/// them using a loop-variant index Load from the 'indices' input array.
/// array the computed histogram. It uses an update instruction to update all
/// counts, storing them using a loop-variant index Load from the 'indices'
/// input array.
///
/// On successful matches it updates the STATISTIC 'HistogramsDetected',
/// regardless of hardware support. When there is support, it additionally
/// stores the BinOp/Load pairs in \p HistogramCounts, as well the pointers
/// stores the UpdateOp/Load pairs in \p HistogramCounts, as well the pointers
/// used to update histogram in \p HistogramPtrs.
static bool findHistogram(LoadInst *LI, StoreInst *HSt, Loop *TheLoop,
const PredicatedScalarEvolution &PSE,
SmallVectorImpl<HistogramInfo> &Histograms) {

// Store value must come from a Binary Operation.
Instruction *HPtrInstr = nullptr;
BinaryOperator *HBinOp = nullptr;
if (!match(HSt, m_Store(m_BinOp(HBinOp), m_Instruction(HPtrInstr))))
Instruction *HUpdateOp = nullptr;
if (!match(HSt, m_Store(m_Instruction(HUpdateOp), m_Instruction(HPtrInstr))))
return false;

// BinOp must be an Add or a Sub modifying the bucket value by a
// loop invariant amount.
// FIXME: We assume the loop invariant term is on the RHS.
// Fine for an immediate/constant, but maybe not a generic value?
Value *HIncVal = nullptr;
if (!match(HBinOp, m_Add(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))) &&
!match(HBinOp, m_Sub(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))))
if (!match(HUpdateOp,
m_Add(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))) &&
!match(HUpdateOp,
m_Sub(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))) &&
!match(HUpdateOp, m_Intrinsic<Intrinsic::uadd_sat>(
m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))) &&
!match(HUpdateOp, m_Intrinsic<Intrinsic::umax>(
m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))) &&
!match(HUpdateOp, m_Intrinsic<Intrinsic::umin>(
m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))))
return false;

// Make sure the increment value is loop invariant.
Expand Down Expand Up @@ -1141,15 +1150,15 @@ static bool findHistogram(LoadInst *LI, StoreInst *HSt, Loop *TheLoop,

// Ensure we'll have the same mask by checking that all parts of the histogram
// (gather load, update, scatter store) are in the same block.
LoadInst *IndexedLoad = cast<LoadInst>(HBinOp->getOperand(0));
LoadInst *IndexedLoad = cast<LoadInst>(HUpdateOp->getOperand(0));
BasicBlock *LdBB = IndexedLoad->getParent();
if (LdBB != HBinOp->getParent() || LdBB != HSt->getParent())
if (LdBB != HUpdateOp->getParent() || LdBB != HSt->getParent())
return false;

LLVM_DEBUG(dbgs() << "LV: Found histogram for: " << *HSt << "\n");

// Store the operations that make up the histogram.
Histograms.emplace_back(IndexedLoad, HBinOp, HSt);
Histograms.emplace_back(IndexedLoad, HUpdateOp, HSt);
return true;
}

Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8656,14 +8656,16 @@ VPRecipeBuilder::tryToWidenHistogram(const HistogramInfo *HI,
ArrayRef<VPValue *> Operands) {
// FIXME: Support other operations.
unsigned Opcode = HI->Update->getOpcode();
assert((Opcode == Instruction::Add || Opcode == Instruction::Sub) &&
"Histogram update operation must be an Add or Sub");
assert(VPHistogramRecipe::isLegalUpdateInstruction(HI->Update) &&
"Found Illegal update instruction for histogram");

SmallVector<VPValue *, 3> HGramOps;
// Bucket address.
HGramOps.push_back(Operands[1]);
// Increment value.
HGramOps.push_back(getVPValueOrAddLiveIn(HI->Update->getOperand(1)));
// Update Instruction.
HGramOps.push_back(getVPValueOrAddLiveIn(HI->Update));

// In case of predicated execution (due to tail-folding, or conditional
// execution, or both), pass the relevant mask.
Expand Down
9 changes: 8 additions & 1 deletion llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -1415,9 +1415,16 @@ class VPHistogramRecipe : public VPRecipeBase {
/// Return the mask operand if one was provided, or a null pointer if all
/// lanes should be executed unconditionally.
VPValue *getMask() const {
return getNumOperands() == 3 ? getOperand(2) : nullptr;
return getNumOperands() == 4 ? getOperand(3) : nullptr;
}

/// Returns true if \p I is a legal update instruction of histogram operation.
static bool isLegalUpdateInstruction(Instruction *I);

/// Given update instruction \p I, returns the opcode of the coresponding
/// histogram instruction.
static unsigned getHistogramIntrinsicID(Instruction *I);

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe
void print(raw_ostream &O, const Twine &Indent,
Expand Down
83 changes: 74 additions & 9 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,7 @@ void VPHistogramRecipe::execute(VPTransformState &State) {

Value *Address = State.get(getOperand(0));
Value *IncAmt = State.get(getOperand(1), /*IsScalar=*/true);
Instruction *UpdateInst = cast<Instruction>(State.get(getOperand(2)));
VectorType *VTy = cast<VectorType>(Address->getType());

// The histogram intrinsic requires a mask even if the recipe doesn't;
Expand All @@ -1236,10 +1237,10 @@ void VPHistogramRecipe::execute(VPTransformState &State) {
// add a separate intrinsic in future, but for now we'll try this.
if (Opcode == Instruction::Sub)
IncAmt = Builder.CreateNeg(IncAmt);
else
assert(Opcode == Instruction::Add && "only add or sub supported for now");
assert(isLegalUpdateInstruction(UpdateInst) &&
"Found Illegal update instruction for histogram");

State.Builder.CreateIntrinsic(Intrinsic::experimental_vector_histogram_add,
State.Builder.CreateIntrinsic(getHistogramIntrinsicID(UpdateInst),
{VTy, IncAmt->getType()},
{Address, IncAmt, Mask});
}
Expand Down Expand Up @@ -1274,23 +1275,48 @@ InstructionCost VPHistogramRecipe::computeCost(ElementCount VF,
IntrinsicCostAttributes ICA(Intrinsic::experimental_vector_histogram_add,
Type::getVoidTy(Ctx.LLVMCtx),
{PtrTy, IncTy, MaskTy});
auto *UpdateInst = getOperand(2)->getUnderlyingValue();
InstructionCost UpdateCost;
if (isa<IntrinsicInst>(UpdateInst)) {
IntrinsicCostAttributes UpdateICA(Opcode, IncTy, {IncTy, IncTy});
UpdateCost = Ctx.TTI.getIntrinsicInstrCost(UpdateICA, Ctx.CostKind);
} else
UpdateCost = Ctx.TTI.getArithmeticInstrCost(Opcode, VTy, Ctx.CostKind);

// Add the costs together with the add/sub operation.
return Ctx.TTI.getIntrinsicInstrCost(ICA, Ctx.CostKind) + MulCost +
Ctx.TTI.getArithmeticInstrCost(Opcode, VTy, Ctx.CostKind);
UpdateCost;
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void VPHistogramRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
auto *UpdateInst = cast<Instruction>(getOperand(2)->getUnderlyingValue());
assert(isLegalUpdateInstruction(UpdateInst) &&
"Found Illegal update instruction for histogram");
O << Indent << "WIDEN-HISTOGRAM buckets: ";
getOperand(0)->printAsOperand(O, SlotTracker);

if (Opcode == Instruction::Sub)
O << ", dec: ";
else {
assert(Opcode == Instruction::Add);
O << ", inc: ";
if (isa<BinaryOperator>(UpdateInst)) {
if (Opcode == Instruction::Sub)
O << ", dec: ";
else {
O << ", inc: ";
}
} else {
switch (cast<IntrinsicInst>(UpdateInst)->getIntrinsicID()) {
case Intrinsic::uadd_sat:
O << ", saturated inc: ";
break;
case Intrinsic::umax:
O << ", max: ";
break;
case Intrinsic::umin:
O << ", min: ";
break;
default:
llvm_unreachable("Found Illegal update instruction for histogram");
}
}
getOperand(1)->printAsOperand(O, SlotTracker);

Expand All @@ -1300,6 +1326,45 @@ void VPHistogramRecipe::print(raw_ostream &O, const Twine &Indent,
}
}

bool VPHistogramRecipe::isLegalUpdateInstruction(Instruction *I) {
// We only support add and sub instructions and the following list of
// intrinsics: uadd.sat, umax, umin.
if (isa<BinaryOperator>(I))
return I->getOpcode() == Instruction::Add ||
I->getOpcode() == Instruction::Sub;
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
case Intrinsic::uadd_sat:
case Intrinsic::umax:
case Intrinsic::umin:
return true;
default:
return false;
}
}
return false;
}

unsigned VPHistogramRecipe::getHistogramIntrinsicID(Instruction *I) {
// We only support add and sub instructions and the following list of
// intrinsics: uadd.sat, umax, umin.
assert(isLegalUpdateInstruction(I) &&
"Found Illegal update instruction for histogram");
if (isa<BinaryOperator>(I))
return Intrinsic::experimental_vector_histogram_add;
auto *II = cast<IntrinsicInst>(I);
switch (II->getIntrinsicID()) {
case Intrinsic::uadd_sat:
return Intrinsic::experimental_vector_histogram_uadd_sat;
case Intrinsic::umax:
return Intrinsic::experimental_vector_histogram_umax;
case Intrinsic::umin:
return Intrinsic::experimental_vector_histogram_umin;
default:
llvm_unreachable("Found Illegal update instruction for histogram");
}
}

void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent,
VPSlotTracker &SlotTracker) const {
O << Indent << "WIDEN-SELECT ";
Expand Down
Loading
Loading