Skip to content

Commit bea77ed

Browse files
authored
[mlir][Vector] Fold vector.step compared to constant (#161615)
This PR adds a canonicalizer to vector.step that folds vector.step iff the result of the fold is a splat value. An alternative would be to always constant fold it, but that might result in some very large/cumbersome constants. I do wonder if vector.step might be better represented as some sort of attribute in the arith dialect, like %step = arith.constant iota<32> : vector<32xindex>. --------- Signed-off-by: James Newling <[email protected]>
1 parent cfe6bec commit bea77ed

File tree

3 files changed

+417
-0
lines changed

3 files changed

+417
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [
29992999
}];
30003000
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
30013001
let assemblyFormat = "attr-dict `:` type($result)";
3002+
let hasCanonicalizer = 1;
30023003
}
30033004

30043005
def Vector_YieldOp : Vector_Op<"yield", [

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76017601
setResultRanges(getResult(), result);
76027602
}
76037603

7604+
namespace {
7605+
7606+
/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
7607+
/// constant large enough such that the result is the same at all indices.
7608+
///
7609+
/// For example, rewrite the 'greater than' comparison below,
7610+
///
7611+
/// ```mlir
7612+
/// %cst = arith.constant dense<7> : vector<3xindex>
7613+
/// %stp = vector.step : vector<3xindex>
7614+
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
7615+
/// ```
7616+
///
7617+
/// as,
7618+
///
7619+
/// ```mlir
7620+
/// %out = arith.constant dense<false> : vector<3xi1>.
7621+
/// ```
7622+
///
7623+
/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
7624+
/// is false at ALL indices we fold. If the constant was 1, then
7625+
/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
7626+
/// conservatively preferring the 'compact' vector.step representation.
7627+
///
7628+
/// Note: this folder only works for the case where the constant (`%cst` above)
7629+
/// is the second operand of the comparison. The arith.cmpi canonicalizer will
7630+
/// ensure that constants are always second (on the right).
7631+
struct StepCompareFolder : public OpRewritePattern<StepOp> {
7632+
using Base::Base;
7633+
7634+
LogicalResult matchAndRewrite(StepOp stepOp,
7635+
PatternRewriter &rewriter) const override {
7636+
const int64_t stepSize = stepOp.getResult().getType().getNumElements();
7637+
7638+
for (OpOperand &use : stepOp.getResult().getUses()) {
7639+
auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
7640+
if (!cmpiOp)
7641+
continue;
7642+
7643+
// arith.cmpi canonicalizer makes constants final operands.
7644+
const unsigned stepOperandNumber = use.getOperandNumber();
7645+
if (stepOperandNumber != 0)
7646+
continue;
7647+
7648+
// Check that operand 1 is a constant.
7649+
unsigned constOperandNumber = 1;
7650+
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7651+
std::optional<int64_t> maybeConstValue =
7652+
getConstantIntValue(otherOperand);
7653+
if (!maybeConstValue.has_value())
7654+
continue;
7655+
7656+
int64_t constValue = maybeConstValue.value();
7657+
arith::CmpIPredicate pred = cmpiOp.getPredicate();
7658+
7659+
auto maybeSplat = [&]() -> std::optional<bool> {
7660+
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7661+
if ((pred == arith::CmpIPredicate::ult ||
7662+
pred == arith::CmpIPredicate::uge) &&
7663+
stepSize <= constValue)
7664+
return pred == arith::CmpIPredicate::ult;
7665+
7666+
// Handle ule and ugt.
7667+
if ((pred == arith::CmpIPredicate::ule ||
7668+
pred == arith::CmpIPredicate::ugt) &&
7669+
stepSize - 1 <= constValue) {
7670+
return pred == arith::CmpIPredicate::ule;
7671+
}
7672+
7673+
// Handle eq and ne.
7674+
if ((pred == arith::CmpIPredicate::eq ||
7675+
pred == arith::CmpIPredicate::ne) &&
7676+
stepSize <= constValue)
7677+
return pred == arith::CmpIPredicate::ne;
7678+
7679+
return std::nullopt;
7680+
}();
7681+
7682+
if (!maybeSplat.has_value())
7683+
continue;
7684+
7685+
rewriter.setInsertionPointAfter(cmpiOp);
7686+
7687+
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7688+
if (!type)
7689+
continue;
7690+
7691+
auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
7692+
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7693+
type, boolAttr);
7694+
7695+
rewriter.replaceOp(cmpiOp, splat);
7696+
return success();
7697+
}
7698+
7699+
return failure();
7700+
}
7701+
};
7702+
} // namespace
7703+
7704+
void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
7705+
MLIRContext *context) {
7706+
results.add<StepCompareFolder>(context);
7707+
}
7708+
76047709
//===----------------------------------------------------------------------===//
76057710
// Vector Masking Utilities
76067711
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)