@@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76017601 setResultRanges (getResult (), result);
76027602}
76037603
7604+ namespace {
7605+
7606+ // / Fold `vector.step -> arith.cmpi` when the step value is compared to a
7607+ // / constant large enough such that the result is the same at all indices.
7608+ // /
7609+ // / For example, rewrite the 'greater than' comparison below,
7610+ // /
7611+ // / ```mlir
7612+ // / %cst = arith.constant dense<7> : vector<3xindex>
7613+ // / %stp = vector.step : vector<3xindex>
7614+ // / %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
7615+ // / ```
7616+ // /
7617+ // / as,
7618+ // /
7619+ // / ```mlir
7620+ // / %out = arith.constant dense<false> : vector<3xi1>.
7621+ // / ```
7622+ // /
7623+ // / Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
7624+ // / is false at ALL indices we fold. If the constant was 1, then
7625+ // / `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
7626+ // / conservatively preferring the 'compact' vector.step representation.
7627+ // /
7628+ // / Note: this folder only works for the case where the constant (`%cst` above)
7629+ // / is the second operand of the comparison. The arith.cmpi canonicalizer will
7630+ // / ensure that constants are always second (on the right).
7631+ struct StepCompareFolder : public OpRewritePattern <StepOp> {
7632+ using Base::Base;
7633+
7634+ LogicalResult matchAndRewrite (StepOp stepOp,
7635+ PatternRewriter &rewriter) const override {
7636+ const int64_t stepSize = stepOp.getResult ().getType ().getNumElements ();
7637+
7638+ for (OpOperand &use : stepOp.getResult ().getUses ()) {
7639+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner ());
7640+ if (!cmpiOp)
7641+ continue ;
7642+
7643+ // arith.cmpi canonicalizer makes constants final operands.
7644+ const unsigned stepOperandNumber = use.getOperandNumber ();
7645+ if (stepOperandNumber != 0 )
7646+ continue ;
7647+
7648+ // Check that operand 1 is a constant.
7649+ unsigned constOperandNumber = 1 ;
7650+ Value otherOperand = cmpiOp.getOperand (constOperandNumber);
7651+ std::optional<int64_t > maybeConstValue =
7652+ getConstantIntValue (otherOperand);
7653+ if (!maybeConstValue.has_value ())
7654+ continue ;
7655+
7656+ int64_t constValue = maybeConstValue.value ();
7657+ arith::CmpIPredicate pred = cmpiOp.getPredicate ();
7658+
7659+ auto maybeSplat = [&]() -> std::optional<bool > {
7660+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
7661+ if ((pred == arith::CmpIPredicate::ult ||
7662+ pred == arith::CmpIPredicate::uge) &&
7663+ stepSize <= constValue)
7664+ return pred == arith::CmpIPredicate::ult;
7665+
7666+ // Handle ule and ugt.
7667+ if ((pred == arith::CmpIPredicate::ule ||
7668+ pred == arith::CmpIPredicate::ugt) &&
7669+ stepSize - 1 <= constValue) {
7670+ return pred == arith::CmpIPredicate::ule;
7671+ }
7672+
7673+ // Handle eq and ne.
7674+ if ((pred == arith::CmpIPredicate::eq ||
7675+ pred == arith::CmpIPredicate::ne) &&
7676+ stepSize <= constValue)
7677+ return pred == arith::CmpIPredicate::ne;
7678+
7679+ return std::nullopt ;
7680+ }();
7681+
7682+ if (!maybeSplat.has_value ())
7683+ continue ;
7684+
7685+ rewriter.setInsertionPointAfter (cmpiOp);
7686+
7687+ auto type = dyn_cast<VectorType>(cmpiOp.getResult ().getType ());
7688+ if (!type)
7689+ continue ;
7690+
7691+ auto boolAttr = DenseElementsAttr::get (type, maybeSplat.value ());
7692+ Value splat = mlir::arith::ConstantOp::create (rewriter, cmpiOp.getLoc (),
7693+ type, boolAttr);
7694+
7695+ rewriter.replaceOp (cmpiOp, splat);
7696+ return success ();
7697+ }
7698+
7699+ return failure ();
7700+ }
7701+ };
7702+ } // namespace
7703+
7704+ void StepOp::getCanonicalizationPatterns (RewritePatternSet &results,
7705+ MLIRContext *context) {
7706+ results.add <StepCompareFolder>(context);
7707+ }
7708+
76047709// ===----------------------------------------------------------------------===//
76057710// Vector Masking Utilities
76067711// ===----------------------------------------------------------------------===//
0 commit comments