From 4566abde04e98e693650de1b2bc2955b64ad45e8 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 20 Dec 2023 10:47:46 +0900 Subject: [PATCH] [mlir][SCF] `scf.parallel`: Make reductions part of the terminator This commit makes reductions part of the terminator. Instead of `scf.yield`, `scf.reduce` now terminates the body of `scf.parallel` ops. `scf.reduce` may contain an arbitrary number of reductions, with one region per reduction. `scf.reduce` operations can no longer be interleaved with other ops in the body of `scf.parallel`. This simplifies the op and makes it possible to assign the `RecursiveMemoryEffects` trait to `scf.reduce`. (This was not possible before because the op was not a terminator, causing the op to be DCE'd.) --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 111 +++++++------ .../AffineToStandard/AffineToStandard.cpp | 27 +-- .../SCFToControlFlow/SCFToControlFlow.cpp | 24 ++- .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 115 +++++++------ .../Async/Transforms/AsyncParallelFor.cpp | 3 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 155 ++++++++---------- .../SCF/Transforms/ParallelLoopTiling.cpp | 5 + .../Transforms/SparseGPUCodegen.cpp | 3 + .../Transforms/Utils/LoopEmitter.cpp | 2 +- .../AffineToStandard/lower-affine.mlir | 24 +-- .../SCFToControlFlow/convert-to-cfg.mlir | 13 +- .../Conversion/SCFToGPU/parallel_loop.mlir | 8 +- .../Conversion/SCFToOpenMP/reductions.mlir | 19 +-- .../Conversion/SCFToSPIRV/unsupported.mlir | 8 +- mlir/test/Dialect/Linalg/parallel-loops.mlir | 2 +- .../Dialect/Linalg/transform-op-match.mlir | 2 +- .../test/Dialect/SCF/buffer-deallocation.mlir | 2 +- mlir/test/Dialect/SCF/canonicalize.mlir | 23 ++- mlir/test/Dialect/SCF/invalid.mlir | 35 ++-- mlir/test/Dialect/SCF/ops.mlir | 22 ++- .../Dialect/SCF/parallel-loop-fusion.mlir | 66 ++++---- .../SparseTensor/sparse_parallel_reduce.mlir | 5 +- .../invalid-parallel-loop-collapsing.mlir | 4 +- .../loop-invariant-code-motion.mlir | 2 +- 
.../Transforms/parallel-loop-collapsing.mlir | 2 +- .../single-parallel-loop-collapsing.mlir | 2 +- 26 files changed, 344 insertions(+), 340 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 573e804b405e8..8d65d3dd820ba 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -770,7 +770,7 @@ def ParallelOp : SCF_Op<"parallel", "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>, RecursiveMemoryEffects, DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"scf::YieldOp">]> { + SingleBlockImplicitTerminator<"scf::ReduceOp">]> { let summary = "parallel for operation"; let description = [{ The "scf.parallel" operation represents a loop nest taking 4 groups of SSA @@ -791,27 +791,36 @@ def ParallelOp : SCF_Op<"parallel", The parallel loop operation supports reduction of values produced by individual iterations into a single result. This is modeled using the - scf.reduce operation (see scf.reduce for details). Each result of a - scf.parallel operation is associated with an initial value operand and - reduce operation that is an immediate child. Reductions are matched to - result and initial values in order of their appearance in the body. - Consequently, we require that the body region has the same number of - results and initial values as it has reduce operations. - - The body region must contain exactly one block that terminates with - "scf.yield" without operands. Parsing ParallelOp will create such a region - and insert the terminator when it is absent from the custom format. + "scf.reduce" terminator operation (see "scf.reduce" for details). The i-th + result of an "scf.parallel" operation is associated with the i-th initial + value operand, the i-th operand of the "scf.reduce" operation (the value to + be reduced) and the i-th region of the "scf.reduce" operation (the reduction + function). 
Consequently, we require that the number of results of an + "scf.parallel" op matches the number of initial values and the number of + reductions in the "scf.reduce" terminator. + + The body region must contain exactly one block that terminates with a + "scf.reduce" operation. If an "scf.parallel" op has no reductions, the + terminator has no operands and no regions. The "scf.parallel" parser will + automatically insert the terminator for ops that have no reductions if it is + absent. Example: ```mlir %init = arith.constant 0.0 : f32 - scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init) -> f32 { - %elem_to_reduce = load %buffer[%iv] : memref<100xf32> - scf.reduce(%elem_to_reduce) : f32 { + %r:2 = scf.parallel (%iv) = (%lb) to (%ub) step (%step) init (%init, %init) + -> f32, f32 { + %elem_to_reduce1 = load %buffer1[%iv] : memref<100xf32> + %elem_to_reduce2 = load %buffer2[%iv] : memref<100xf32> + scf.reduce(%elem_to_reduce1, %elem_to_reduce2 : f32, f32) { ^bb0(%lhs : f32, %rhs: f32): %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 + }, { + ^bb0(%lhs : f32, %rhs: f32): + %res = arith.mulf %lhs, %rhs : f32 + scf.reduce.return %res : f32 } } ``` @@ -853,36 +862,36 @@ def ParallelOp : SCF_Op<"parallel", // ReduceOp //===----------------------------------------------------------------------===// -def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> { - let summary = "reduce operation for parallel for"; +def ReduceOp : SCF_Op<"reduce", [ + Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects, + DeclareOpInterfaceMethods]> { + let summary = "reduce operation for scf.parallel"; let description = [{ - "scf.reduce" is an operation occurring inside "scf.parallel" operations. - It consists of one block with two arguments which have the same type as the - operand of "scf.reduce". - - "scf.reduce" is used to model the value for reduction computations of a - "scf.parallel" operation. 
It has to appear as an immediate child of a - "scf.parallel" and is associated with a result value of its parent - operation. - - Association is in the order of appearance in the body where the first - result of a parallel loop operation corresponds to the first "scf.reduce" - in the operation's body region. The reduce operation takes a single - operand, which is the value to be used in the reduction. - - The reduce operation contains a region whose entry block expects two - arguments of the same type as the operand. As the iteration order of the - parallel loop and hence reduction order is unspecified, the result of - reduction may be non-deterministic unless the operation is associative and - commutative. - - The result of the reduce operation's body must have the same type as the - operands and associated result value of the parallel loop operation. + "scf.reduce" is the terminator for "scf.parallel" operations. It can model + an arbitrary number of reductions. It has one region per reduction. Each + region has one block with two arguments which have the same type as the + corresponding operand of "scf.reduce". The operands of the op are the values + that should be reduced; one value per reduction. + + The i-th reduction (i.e., the i-th region and the i-th operand) corresponds + to the i-th initial value and the i-th result of the enclosing "scf.parallel" + op. + + The "scf.reduce" operation contains regions whose entry blocks expect two + arguments of the same type as the corresponding operand. As the iteration + order of the enclosing parallel loop and hence reduction order is + unspecified, the results of the reductions may be non-deterministic unless + the reductions are associative and commutative. + + The result of a reduction region ("scf.reduce.return" operand) must have the + same type as the corresponding "scf.reduce" operand and the corresponding + "scf.parallel" initial value. 
+ Example: ```mlir %operand = arith.constant 1.0 : f32 - scf.reduce(%operand) : f32 { + scf.reduce(%operand : f32) { ^bb0(%lhs : f32, %rhs: f32): %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 @@ -892,14 +901,15 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> { let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "Value":$operand, - CArg<"function_ref", - "nullptr">:$bodyBuilderFn)> + OpBuilder<(ins "ValueRange":$operands)>, + OpBuilder<(ins)> ]; - let arguments = (ins AnyType:$operand); - let hasCustomAssemblyFormat = 1; - let regions = (region SizedRegion<1>:$reductionOperator); + let arguments = (ins Variadic:$operands); + let assemblyFormat = [{ + (`(` $operands^ `:` type($operands) `)`)? $reductions attr-dict + }]; + let regions = (region VariadicRegion>:$reductions); let hasRegionVerifier = 1; } @@ -908,13 +918,14 @@ def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> { //===----------------------------------------------------------------------===// def ReduceReturnOp : - SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure, - Terminator]> { + SCF_Op<"reduce.return", [HasParent<"ReduceOp">, Pure, Terminator]> { let summary = "terminator for reduce operation"; let description = [{ "scf.reduce.return" is a special terminator operation for the block inside - "scf.reduce". It terminates the region. It should have the same type as - the operand of "scf.reduce". Example for the custom format: + "scf.reduce" regions. It terminates the region. It should have the same + operand type as the corresponding operand of the enclosing "scf.reduce" op. 
+ + Example: ```mlir scf.reduce.return %res : f32 @@ -1150,7 +1161,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects, def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator, ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp", - "ParallelOp", "WhileOp"]>]> { + "WhileOp"]>]> { let summary = "loop yield and termination operation"; let description = [{ "scf.yield" yields an SSA value from the SCF dialect op region and diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 7dbbf015182f3..15ad6d8cdf629 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -137,10 +137,9 @@ class AffineYieldOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(AffineYieldOp op, PatternRewriter &rewriter) const override { if (isa(op->getParentOp())) { - // scf.parallel does not yield any values via its terminator scf.yield but - // models reductions differently using additional ops in its region. - rewriter.replaceOpWithNewOp(op); - return success(); + // Terminator is rewritten as part of the "affine.parallel" lowering + // pattern. + return failure(); } rewriter.replaceOpWithNewOp(op, op.getOperands()); return success(); @@ -203,7 +202,8 @@ class AffineParallelLowering : public OpRewritePattern { steps.push_back(rewriter.create(loc, step)); // Get the terminator op. - Operation *affineParOpTerminator = op.getBody()->getTerminator(); + auto affineParOpTerminator = + cast(op.getBody()->getTerminator()); scf::ParallelOp parOp; if (op.getResults().empty()) { // Case with no reduction operations/return values. 
@@ -214,6 +214,8 @@ class AffineParallelLowering : public OpRewritePattern { rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(), parOp.getRegion().end()); rewriter.replaceOp(op, parOp.getResults()); + rewriter.setInsertionPoint(affineParOpTerminator); + rewriter.replaceOpWithNewOp(affineParOpTerminator); return success(); } // Case with affine.parallel with reduction operations/return values. @@ -243,6 +245,11 @@ class AffineParallelLowering : public OpRewritePattern { parOp.getRegion().end()); assert(reductions.size() == affineParOpTerminator->getNumOperands() && "Unequal number of reductions and operands."); + + // Emit new "scf.reduce" terminator. + rewriter.setInsertionPoint(affineParOpTerminator); + auto reduceOp = rewriter.replaceOpWithNewOp( + affineParOpTerminator, affineParOpTerminator->getOperands()); for (unsigned i = 0, end = reductions.size(); i < end; i++) { // For each of the reduction operations get the respective mlir::Value. std::optional reductionOp = @@ -251,13 +258,11 @@ class AffineParallelLowering : public OpRewritePattern { assert(reductionOp && "Reduction Operation cannot be of None Type"); arith::AtomicRMWKind reductionOpValue = *reductionOp; rewriter.setInsertionPoint(&parOp.getBody()->back()); - auto reduceOp = rewriter.create( - loc, affineParOpTerminator->getOperand(i)); - rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front()); + Block &reductionBody = reduceOp.getReductions()[i].front(); + rewriter.setInsertionPointToEnd(&reductionBody); Value reductionResult = arith::getReductionOp( - reductionOpValue, rewriter, loc, - reduceOp.getReductionOperator().front().getArgument(0), - reduceOp.getReductionOperator().front().getArgument(1)); + reductionOpValue, rewriter, loc, reductionBody.getArgument(0), + reductionBody.getArgument(1)); rewriter.create(loc, reductionResult); } rewriter.replaceOp(op, parOp.getResults()); diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp 
b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index c9b45fd4a7957..9eb8a289d7d65 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -471,6 +471,7 @@ LogicalResult ParallelLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { Location loc = parallelOp.getLoc(); + auto reductionOp = cast(parallelOp.getBody()->getTerminator()); // For a parallel loop, we essentially need to create an n-dimensional loop // nest. We do this by translating to scf.for ops and have those lowered in @@ -506,23 +507,20 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, } // First, merge reduction blocks into the main region. - SmallVector yieldOperands; + SmallVector yieldOperands; yieldOperands.reserve(parallelOp.getNumResults()); - for (auto &op : *parallelOp.getBody()) { - auto reduce = dyn_cast(op); - if (!reduce) - continue; - - Block &reduceBlock = reduce.getReductionOperator().front(); + for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) { + Block &reductionBody = reductionOp.getReductions()[i].front(); Value arg = iterArgs[yieldOperands.size()]; - yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0)); - rewriter.eraseOp(reduceBlock.getTerminator()); - rewriter.inlineBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()}); - rewriter.eraseOp(reduce); + yieldOperands.push_back( + cast(reductionBody.getTerminator()).getResult()); + rewriter.eraseOp(reductionBody.getTerminator()); + rewriter.inlineBlockBefore(&reductionBody, reductionOp, + {arg, reductionOp.getOperands()[i]}); } + rewriter.eraseOp(reductionOp); // Then merge the loop body without the terminator. 
- rewriter.eraseOp(parallelOp.getBody()->getTerminator()); Block *newBody = rewriter.getInsertionBlock(); if (newBody->empty()) rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs); @@ -711,7 +709,7 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, parallelOp.getRegion().begin()); // Replace the terminator. rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front()); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( parallelOp.getRegion().front().getTerminator()); // Erase the scf.forall op. diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 67033ba812946..2f8b3f7e11de1 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -181,32 +181,34 @@ static Attribute minMaxValueForUnsignedInt(Type type, bool min) { /// Creates an OpenMP reduction declaration and inserts it into the provided /// symbol table. The declaration has a constant initializer with the neutral -/// value `initValue`, and the reduction combiner carried over from `reduce`. -static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, - SymbolTable &symbolTable, - scf::ReduceOp reduce, - Attribute initValue) { +/// value `initValue`, and the `reductionIndex`-th reduction combiner carried +/// over from `reduce`. 
+static omp::ReductionDeclareOp +createDecl(PatternRewriter &builder, SymbolTable &symbolTable, + scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) { OpBuilder::InsertionGuard guard(builder); - auto decl = builder.create( - reduce.getLoc(), "__scf_reduction", reduce.getOperand().getType()); + Type type = reduce.getOperands()[reductionIndex].getType(); + auto decl = builder.create(reduce.getLoc(), + "__scf_reduction", type); symbolTable.insert(decl); - Type type = reduce.getOperand().getType(); builder.createBlock(&decl.getInitializerRegion(), decl.getInitializerRegion().end(), {type}, - {reduce.getOperand().getLoc()}); + {reduce.getOperands()[reductionIndex].getLoc()}); builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); Value init = builder.create(reduce.getLoc(), type, initValue); builder.create(reduce.getLoc(), init); - Operation *terminator = &reduce.getRegion().front().back(); + Operation *terminator = + &reduce.getReductions()[reductionIndex].front().back(); assert(isa(terminator) && "expected reduce op to be terminated by redure return"); builder.setInsertionPoint(terminator); builder.replaceOpWithNewOp(terminator, terminator->getOperands()); - builder.inlineRegionBefore(reduce.getRegion(), decl.getReductionRegion(), + builder.inlineRegionBefore(reduce.getReductions()[reductionIndex], + decl.getReductionRegion(), decl.getReductionRegion().end()); return decl; } @@ -216,10 +218,11 @@ static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::ReductionDeclareOp decl, - scf::ReduceOp reduce) { + scf::ReduceOp reduce, + int64_t reductionIndex) { OpBuilder::InsertionGuard guard(builder); auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); - Location reduceOperandLoc = reduce.getOperand().getLoc(); + Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc(); 
builder.createBlock(&decl.getAtomicReductionRegion(), decl.getAtomicReductionRegion().end(), {ptrType, ptrType}, {reduceOperandLoc, reduceOperandLoc}); @@ -239,7 +242,8 @@ static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, /// the neutral value, necessary for the OpenMP declaration. If the reduction /// cannot be recognized, returns null. static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, - scf::ReduceOp reduce) { + scf::ReduceOp reduce, + int64_t reductionIndex) { Operation *container = SymbolTable::getNearestSymbolTable(reduce); SymbolTable symbolTable(container); @@ -251,49 +255,58 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, OpBuilder::InsertionGuard guard(builder); builder.setInsertionPoint(insertionPoint); - assert(llvm::hasSingleElement(reduce.getRegion()) && + assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) && "expected reduction region to have a single element"); // Match simple binary reductions that can be expressed with atomicrmw. 
- Type type = reduce.getOperand().getType(); - Block &reduction = reduce.getRegion().front(); + Type type = reduce.getOperands()[reductionIndex].getType(); + Block &reduction = reduce.getReductions()[reductionIndex].front(); if (matchSimpleReduction(reduction)) { - omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, - builder.getFloatAttr(type, 0.0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); + omp::ReductionDeclareOp decl = + createDecl(builder, symbolTable, reduce, reductionIndex, + builder.getFloatAttr(type, 0.0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce, + reductionIndex); } if (matchSimpleReduction(reduction)) { - omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, - builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); + omp::ReductionDeclareOp decl = + createDecl(builder, symbolTable, reduce, reductionIndex, + builder.getIntegerAttr(type, 0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce, + reductionIndex); } if (matchSimpleReduction(reduction)) { - omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, - builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); + omp::ReductionDeclareOp decl = + createDecl(builder, symbolTable, reduce, reductionIndex, + builder.getIntegerAttr(type, 0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce, + reductionIndex); } if (matchSimpleReduction(reduction)) { - omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, - builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); + omp::ReductionDeclareOp decl = + createDecl(builder, symbolTable, reduce, reductionIndex, + builder.getIntegerAttr(type, 0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce, + reductionIndex); } if 
(matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl( - builder, symbolTable, reduce, + builder, symbolTable, reduce, reductionIndex, builder.getIntegerAttr( type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth()))); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce, + reductionIndex); } // Match simple binary reductions that cannot be expressed with atomicrmw. // TODO: add atomic region using cmpxchg (which needs atomic load to be // available as an op). if (matchSimpleReduction(reduction)) { - return createDecl(builder, symbolTable, reduce, + return createDecl(builder, symbolTable, reduce, reductionIndex, builder.getFloatAttr(type, 1.0)); } if (matchSimpleReduction(reduction)) { - return createDecl(builder, symbolTable, reduce, + return createDecl(builder, symbolTable, reduce, reductionIndex, builder.getIntegerAttr(type, 1)); } @@ -305,7 +318,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, matchSelectReduction( reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { - return createDecl(builder, symbolTable, reduce, + return createDecl(builder, symbolTable, reduce, reductionIndex, minMaxValueForFloat(type, !isMin)); } if (matchSelectReduction( @@ -314,11 +327,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, matchSelectReduction( reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { - omp::ReductionDeclareOp decl = createDecl( - builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin)); + omp::ReductionDeclareOp decl = + createDecl(builder, symbolTable, reduce, reductionIndex, + minMaxValueForSignedInt(type, !isMin)); return addAtomicRMW(builder, isMin ? 
LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, - decl, reduce); + decl, reduce, reductionIndex); } if (matchSelectReduction( reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, @@ -326,11 +340,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, matchSelectReduction( reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { - omp::ReductionDeclareOp decl = createDecl( - builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin)); + omp::ReductionDeclareOp decl = + createDecl(builder, symbolTable, reduce, reductionIndex, + minMaxValueForUnsignedInt(type, !isMin)); return addAtomicRMW( builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, - decl, reduce); + decl, reduce, reductionIndex); } return nullptr; @@ -352,8 +367,9 @@ struct ParallelOpLowering : public OpRewritePattern { // TODO: consider checking it here is already a compatible reduction // declaration and use it instead of redeclaring. SmallVector reductionDeclSymbols; - for (auto reduce : parallelOp.getOps()) { - omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce); + auto reduce = cast(parallelOp.getBody()->getTerminator()); + for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) { + omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce, i); if (!decl) return failure(); reductionDeclSymbols.push_back( @@ -382,14 +398,13 @@ struct ParallelOpLowering : public OpRewritePattern { // Replace the reduction operations contained in this loop. Must be done // here rather than in a separate pattern to have access to the list of // reduction variables. 
- for (auto pair : - llvm::zip(parallelOp.getOps(), reductionVariables)) { + for (auto [x, y] : + llvm::zip_equal(reductionVariables, reduce.getOperands())) { OpBuilder::InsertionGuard guard(rewriter); - scf::ReduceOp reduceOp = std::get<0>(pair); - rewriter.setInsertionPoint(reduceOp); - rewriter.replaceOpWithNewOp( - reduceOp, reduceOp.getOperand(), std::get<1>(pair)); + rewriter.setInsertionPoint(reduce); + rewriter.create(reduce.getLoc(), y, x); } + rewriter.eraseOp(reduce); Value numThreadsVar; if (numThreads > 0) { @@ -432,10 +447,8 @@ struct ParallelOpLowering : public OpRewritePattern { rewriter.create(loc, ValueRange()); Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); rewriter.mergeBlocks(ops, scopeBlock); - auto oldYield = cast(scopeBlock->getTerminator()); rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); - rewriter.replaceOpWithNewOp( - oldYield, oldYield->getOperands()); + rewriter.create(loc, ValueRange()); if (!reductionVariables.empty()) { loop.setReductionsAttr( ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index 12a28c2e23b22..428a3c945581b 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -429,8 +429,9 @@ static ParallelComputeFunction createParallelComputeFunction( mapping.map(op.getInductionVars(), computeBlockInductionVars); mapping.map(computeFuncType.captures, captures); - for (auto &bodyOp : op.getRegion().getOps()) + for (auto &bodyOp : op.getRegion().front().without_terminator()) b.clone(bodyOp, mapping); + b.create(loc); }; }; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 55bb5788108bd..5570c2ec688c8 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2643,7 +2643,9 @@ void ParallelOp::build( 
bodyBlock->getArguments().take_front(numIVs), bodyBlock->getArguments().drop_front(numIVs)); } - ParallelOp::ensureTerminator(*bodyRegion, builder, result.location); + // Add terminator only if there are no reductions. + if (initVals.empty()) + ParallelOp::ensureTerminator(*bodyRegion, builder, result.location); } void ParallelOp::build( @@ -2693,19 +2695,15 @@ LogicalResult ParallelOp::verify() { return emitOpError( "expects arguments for the induction variable to be of index type"); - // Check that the yield has no results - auto yield = verifyAndGetTerminator( - *this, getRegion(), "expects body to terminate with 'scf.yield'"); - if (!yield) + // Check that the terminator is an scf.reduce op. + auto reduceOp = verifyAndGetTerminator( + *this, getRegion(), "expects body to terminate with 'scf.reduce'"); + if (!reduceOp) return failure(); - if (yield->getNumOperands() != 0) - return yield.emitOpError() << "not allowed to have operands inside '" - << ParallelOp::getOperationName() << "'"; - // Check that the number of results is the same as the number of ReduceOps. - SmallVector reductions(body->getOps()); + // Check that the number of results is the same as the number of reductions. auto resultsSize = getResults().size(); - auto reductionsSize = reductions.size(); + auto reductionsSize = reduceOp.getReductions().size(); auto initValsSize = getInitVals().size(); if (resultsSize != reductionsSize) return emitOpError() << "expects number of results: " << resultsSize @@ -2717,14 +2715,15 @@ LogicalResult ParallelOp::verify() { << initValsSize; // Check that the types of the results and reductions are the same. 
- for (auto resultAndReduce : llvm::zip(getResults(), reductions)) { - auto resultType = std::get<0>(resultAndReduce).getType(); - auto reduceOp = std::get<1>(resultAndReduce); - auto reduceType = reduceOp.getOperand().getType(); - if (resultType != reduceType) + for (int64_t i = 0; i < static_cast(reductionsSize); ++i) { + auto resultType = getOperation()->getResult(i).getType(); + auto reductionOperandType = reduceOp.getOperands()[i].getType(); + if (resultType != reductionOperandType) return reduceOp.emitOpError() - << "expects type of reduce: " << reduceType - << " to be the same as result type: " << resultType; + << "expects type of " << i + << "-th reduction operand: " << reductionOperandType + << " to be the same as the " << i + << "-th result type: " << resultType; } return success(); } @@ -2792,7 +2791,7 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); // Add a terminator if none was parsed. - ForOp::ensureTerminator(*body, builder, result.location); + ParallelOp::ensureTerminator(*body, builder, result.location); return success(); } @@ -2887,17 +2886,15 @@ struct ParallelOpSingleOrZeroIterationDimsFolder // loop body and nested ReduceOp's SmallVector results; results.reserve(op.getInitVals().size()); - for (auto &bodyOp : op.getBody()->without_terminator()) { - auto reduce = dyn_cast(bodyOp); - if (!reduce) { - rewriter.clone(bodyOp, mapping); - continue; - } - Block &reduceBlock = reduce.getReductionOperator().front(); + for (auto &bodyOp : op.getBody()->without_terminator()) + rewriter.clone(bodyOp, mapping); + auto reduceOp = cast(op.getBody()->getTerminator()); + for (int64_t i = 0, e = reduceOp.getReductions().size(); i < e; ++i) { + Block &reduceBlock = reduceOp.getReductions()[i].front(); auto initValIndex = results.size(); mapping.map(reduceBlock.getArgument(0), op.getInitVals()[initValIndex]); mapping.map(reduceBlock.getArgument(1), - mapping.lookupOrDefault(reduce.getOperand())); + 
mapping.lookupOrDefault(reduceOp.getOperands()[i])); for (auto &reduceBodyOp : reduceBlock.without_terminator()) rewriter.clone(reduceBodyOp, mapping); @@ -2905,6 +2902,7 @@ struct ParallelOpSingleOrZeroIterationDimsFolder cast(reduceBlock.getTerminator()).getResult()); results.push_back(result); } + rewriter.replaceOp(op, results); return success(); } @@ -3008,67 +3006,48 @@ void ParallelOp::getSuccessorRegions( // ReduceOp //===----------------------------------------------------------------------===// -void ReduceOp::build( - OpBuilder &builder, OperationState &result, Value operand, - function_ref bodyBuilderFn) { - auto type = operand.getType(); - result.addOperands(operand); +void ReduceOp::build(OpBuilder &builder, OperationState &result) {} - OpBuilder::InsertionGuard guard(builder); - Region *bodyRegion = result.addRegion(); - Block *body = builder.createBlock(bodyRegion, {}, ArrayRef{type, type}, - {result.location, result.location}); - if (bodyBuilderFn) - bodyBuilderFn(builder, result.location, body->getArgument(0), - body->getArgument(1)); +void ReduceOp::build(OpBuilder &builder, OperationState &result, + ValueRange operands) { + result.addOperands(operands); + for (Value v : operands) { + OpBuilder::InsertionGuard guard(builder); + Region *bodyRegion = result.addRegion(); + builder.createBlock(bodyRegion, {}, + ArrayRef{v.getType(), v.getType()}, + {result.location, result.location}); + } } LogicalResult ReduceOp::verifyRegions() { - // The region of a ReduceOp has two arguments of the same type as its operand. 
- auto type = getOperand().getType(); - Block &block = getReductionOperator().front(); - if (block.empty()) - return emitOpError("the block inside reduce should not be empty"); - if (block.getNumArguments() != 2 || - llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) { - return arg.getType() != type; - })) - return emitOpError() << "expects two arguments to reduce block of type " - << type; - - // Check that the block is terminated by a ReduceReturnOp. - if (!isa(block.getTerminator())) - return emitOpError("the block inside reduce should be terminated with a " - "'scf.reduce.return' op"); - - return success(); -} - -ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse an opening `(` followed by the reduced value followed by `)` - OpAsmParser::UnresolvedOperand operand; - if (parser.parseLParen() || parser.parseOperand(operand) || - parser.parseRParen()) - return failure(); - - Type resultType; - // Parse the type of the operand (and also what reduce computes on). - if (parser.parseColonType(resultType) || - parser.resolveOperand(operand, resultType, result.operands)) - return failure(); - - // Now parse the body. - Region *body = result.addRegion(); - if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) - return failure(); + // The region of a ReduceOp has two arguments of the same type as its + // corresponding operand. + for (int64_t i = 0, e = getReductions().size(); i < e; ++i) { + auto type = getOperands()[i].getType(); + Block &block = getReductions()[i].front(); + if (block.empty()) + return emitOpError() << i << "-th reduction has an empty body"; + if (block.getNumArguments() != 2 || + llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) { + return arg.getType() != type; + })) + return emitOpError() << "expected two block arguments with type " << type + << " in the " << i << "-th reduction region"; + + // Check that the block is terminated by a ReduceReturnOp. 
+ if (!isa(block.getTerminator())) + return emitOpError("reduction bodies must be terminated with an " + "'scf.reduce.return' op"); + } return success(); } -void ReduceOp::print(OpAsmPrinter &p) { - p << "(" << getOperand() << ") "; - p << " : " << getOperand().getType() << ' '; - p.printRegion(getReductionOperator()); +MutableOperandRange +ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { + // No operands are forwarded to the next iteration. + return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } //===----------------------------------------------------------------------===// @@ -3076,13 +3055,15 @@ void ReduceOp::print(OpAsmPrinter &p) { //===----------------------------------------------------------------------===// LogicalResult ReduceReturnOp::verify() { - // The type of the return value should be the same type as the type of the - // operand of the enclosing ReduceOp. - auto reduceOp = cast((*this)->getParentOp()); - Type reduceType = reduceOp.getOperand().getType(); - if (reduceType != getResult().getType()) - return emitOpError() << "needs to have type " << reduceType - << " (the type of the enclosing ReduceOp)"; + // The type of the return value should be the same type as the types of the + // block arguments of the reduction body. + Block *reductionBody = getOperation()->getBlock(); + // Should already be verified by an op trait. 
+ assert(isa(reductionBody->getParentOp()) && "expected scf.reduce"); + Type expectedResultType = reductionBody->getArgument(0).getType(); + if (expectedResultType != getResult().getType()) + return emitOpError() << "must have type " << expectedResultType + << " (the type of the reduction inputs)"; return success(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp index fdc28060917fb..ed73d81198f29 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -159,6 +159,11 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef tileSizes, /*hasElseRegion*/ false); ifInbound.getThenRegion().takeBody(op.getRegion()); Block &thenBlock = ifInbound.getThenRegion().front(); + // Replace the scf.reduce terminator with an scf.yield terminator. + Operation *reduceOp = thenBlock.getTerminator(); + b.setInsertionPointToEnd(&thenBlock); + b.create(reduceOp->getLoc()); + reduceOp->erase(); b.setInsertionPointToStart(innerLoop.getBody()); for (const auto &ivs : llvm::enumerate(llvm::zip( innerLoop.getInductionVars(), outerLoop.getInductionVars()))) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index 69fd1eb746ffe..8af3b694c4d97 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -315,6 +315,9 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc, rewriter.eraseBlock(forOp.getBody()); rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(), forOp.getRegion().begin(), irMap); + // Replace the scf.reduce terminator. + rewriter.setInsertionPoint(forOp.getBody()->getTerminator()); + rewriter.replaceOpWithNewOp(forOp.getBody()->getTerminator()); // Done. 
rewriter.setInsertionPointAfter(forOp); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index 35faf1769746d..d60b6ccd73216 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -1371,7 +1371,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, rewriter.setInsertionPointAfter(redExp); auto redOp = rewriter.create(loc, curVal); // Attach to the reduction op. - Block *redBlock = &redOp.getRegion().getBlocks().front(); + Block *redBlock = &redOp.getReductions().front().front(); rewriter.setInsertionPointToEnd(redBlock); Operation *newRed = rewriter.clone(*redExp); // Replaces arguments of the reduction expression by using the block diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 6158de33e4aef..92608135d24b0 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -763,7 +763,7 @@ func.func @affine_parallel_tiled(%o: memref<100x100xf32>, %a: memref<100x100xf32 // CHECK: %[[A3:.*]] = memref.load %[[ARG1]][%[[arg6]], %[[arg8]]] : memref<100x100xf32> // CHECK: %[[A4:.*]] = memref.load %[[ARG2]][%[[arg8]], %[[arg7]]] : memref<100x100xf32> // CHECK: arith.mulf %[[A3]], %[[A4]] : f32 -// CHECK: scf.yield +// CHECK: scf.reduce ///////////////////////////////////////////////////////////////////// @@ -789,7 +789,7 @@ func.func @affine_parallel_simple(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32> // CHECK-NEXT: %[[VAL_2:.*]] = memref.load // CHECK-NEXT: %[[PRODUCT:.*]] = arith.mulf // CHECK-NEXT: store -// CHECK-NEXT: scf.yield +// CHECK-NEXT: scf.reduce // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } @@ -820,7 +820,7 @@ func.func @affine_parallel_simple_dynamic_bounds(%arg0: memref, %arg1: // CHECK-NEXT: 
%[[VAL_2:.*]] = memref.load // CHECK-NEXT: %[[PRODUCT:.*]] = arith.mulf // CHECK-NEXT: store -// CHECK-NEXT: scf.yield +// CHECK-NEXT: scf.reduce // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } @@ -851,17 +851,15 @@ func.func @affine_parallel_with_reductions(%arg0: memref<3x3xf32>, %arg1: memref // CHECK-NEXT: %[[VAL_2:.*]] = memref.load // CHECK-NEXT: %[[PRODUCT:.*]] = arith.mulf // CHECK-NEXT: %[[SUM:.*]] = arith.addf -// CHECK-NEXT: scf.reduce(%[[PRODUCT]]) : f32 { +// CHECK-NEXT: scf.reduce(%[[PRODUCT]], %[[SUM]] : f32, f32) { // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): // CHECK-NEXT: %[[RES:.*]] = arith.addf // CHECK-NEXT: scf.reduce.return %[[RES]] : f32 -// CHECK-NEXT: } -// CHECK-NEXT: scf.reduce(%[[SUM]]) : f32 { +// CHECK-NEXT: }, { // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): // CHECK-NEXT: %[[RES:.*]] = arith.mulf // CHECK-NEXT: scf.reduce.return %[[RES]] : f32 // CHECK-NEXT: } -// CHECK-NEXT: scf.yield // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } @@ -892,17 +890,15 @@ func.func @affine_parallel_with_reductions_f64(%arg0: memref<3x3xf64>, %arg1: me // CHECK: %[[VAL_2:.*]] = memref.load // CHECK: %[[PRODUCT:.*]] = arith.mulf // CHECK: %[[SUM:.*]] = arith.addf -// CHECK: scf.reduce(%[[PRODUCT]]) : f64 { +// CHECK: scf.reduce(%[[PRODUCT]], %[[SUM]] : f64, f64) { // CHECK: ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64): // CHECK: %[[RES:.*]] = arith.addf // CHECK: scf.reduce.return %[[RES]] : f64 -// CHECK: } -// CHECK: scf.reduce(%[[SUM]]) : f64 { +// CHECK: }, { // CHECK: ^bb0(%[[LHS:.*]]: f64, %[[RHS:.*]]: f64): // CHECK: %[[RES:.*]] = arith.mulf // CHECK: scf.reduce.return %[[RES]] : f64 // CHECK: } -// CHECK: scf.yield // CHECK: } ///////////////////////////////////////////////////////////////////// @@ -931,15 +927,13 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me // CHECK: %[[VAL_2:.*]] = memref.load // CHECK: %[[PRODUCT:.*]] = arith.muli // CHECK: %[[SUM:.*]] = arith.addi -// 
CHECK: scf.reduce(%[[PRODUCT]]) : i64 { +// CHECK: scf.reduce(%[[PRODUCT]], %[[SUM]] : i64, i64) { // CHECK: ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64): // CHECK: %[[RES:.*]] = arith.addi // CHECK: scf.reduce.return %[[RES]] : i64 -// CHECK: } -// CHECK: scf.reduce(%[[SUM]]) : i64 { +// CHECK: }, { // CHECK: ^bb0(%[[LHS:.*]]: i64, %[[RHS:.*]]: i64): // CHECK: %[[RES:.*]] = arith.muli // CHECK: scf.reduce.return %[[RES]] : i64 // CHECK: } -// CHECK: scf.yield // CHECK: } diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index 99b47ea94cc0b..caf17bc91ced2 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -254,6 +254,7 @@ func.func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index, scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) { %c1 = arith.constant 1 : index + scf.reduce } return } @@ -347,7 +348,7 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index, // CHECK: return %[[ITER_ARG]] %0 = scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) -> f32 { %cst = arith.constant 42.0 : f32 - scf.reduce(%cst) : f32 { + scf.reduce(%cst : f32) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.mulf %lhs, %rhs : f32 scf.reduce.return %1 : f32 @@ -383,14 +384,12 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index, %0:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init(%arg5, %init) -> (f32, i64) { %cf = arith.constant 42.0 : f32 - scf.reduce(%cf) : f32 { + %2 = func.call @generate() : () -> i64 + scf.reduce(%cf, %2 : f32, i64) { ^bb0(%lhs: f32, %rhs: f32): %1 = arith.addf %lhs, %rhs : f32 scf.reduce.return %1 : f32 - } - - %2 = func.call @generate() : () -> i64 - scf.reduce(%2) : i64 { + }, { ^bb0(%lhs: i64, %rhs: i64): %3 = arith.ori %lhs, %rhs : i64 scf.reduce.return %3 : i64 @@ -580,7 +579,7 @@ 
func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, scf.yield %2 : index } } - scf.yield + scf.reduce } // CHECK: ^[[LOOP_CONT]]: diff --git a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir index deeaec2f81a94..59441e5ed6629 100644 --- a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir +++ b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir @@ -232,9 +232,9 @@ module { %19 = memref.load %16[%arg5, %arg6] : memref> %20 = arith.addf %17, %18 : f32 memref.store %20, %16[%arg5, %arg6] : memref> - scf.yield + scf.reduce } {mapping = [#gpu.loop_dim_map (d0), map = (d0) -> (d0), processor = thread_x>, #gpu.loop_dim_map (d0), map = (d0) -> (d0), processor = thread_y>]} - scf.yield + scf.reduce } {mapping = [#gpu.loop_dim_map (d0), map = (d0) -> (d0), processor = block_x>, #gpu.loop_dim_map (d0), map = (d0) -> (d0), processor = block_y>]} return } @@ -404,9 +404,9 @@ func.func @step_invariant() { %1 = memref.load %alloc_0[%arg0, %arg1] : memref<1x1xf64> %2 = arith.addf %0, %1 : f64 memref.store %2, %alloc[%arg0, %arg1] : memref<1x1xf64> - scf.yield + scf.reduce } {mapping = [#gpu.loop_dim_map (d0), bound = (d0) -> (d0)>]} - scf.yield + scf.reduce } {mapping = [#gpu.loop_dim_map (d0), bound = (d0) -> (d0)>]} memref.dealloc %alloc_1 : memref<1x1xf64> memref.dealloc %alloc_0 : memref<1x1xf64> diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir index 25b18b58a6adb..faf5ec4aba7d4 100644 --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -34,7 +34,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index, // CHECK: %[[CST_INNER:.*]] = arith.constant 1.0 %one = arith.constant 1.0 : f32 // CHECK: omp.reduction %[[CST_INNER]], %[[BUF]] - scf.reduce(%one) : f32 { + scf.reduce(%one : f32) { ^bb0(%lhs : f32, %rhs: f32): %res = arith.addf %lhs, %rhs : f32 scf.reduce.return 
%res : f32 @@ -70,7 +70,7 @@ func.func @reduction2(%arg0 : index, %arg1 : index, %arg2 : index, scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%zero) -> (f32) { %one = arith.constant 1.0 : f32 - scf.reduce(%one) : f32 { + scf.reduce(%one : f32) { ^bb0(%lhs : f32, %rhs: f32): %res = arith.mulf %lhs, %rhs : f32 scf.reduce.return %res : f32 @@ -107,7 +107,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index, step (%arg4, %step) init (%one) -> (i32) { // CHECK: omp.reduction %pow2 = arith.constant 2 : i32 - scf.reduce(%pow2) : i32 { + scf.reduce(%pow2 : i32) { ^bb0(%lhs : i32, %rhs: i32): %res = arith.muli %lhs, %rhs : i32 scf.reduce.return %res : i32 @@ -141,7 +141,7 @@ func.func @reduction3(%arg0 : index, %arg1 : index, %arg2 : index, scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%zero) -> (f32) { %one = arith.constant 1.0 : f32 - scf.reduce(%one) : f32 { + scf.reduce(%one : f32) { ^bb0(%lhs : f32, %rhs: f32): %cmp = arith.cmpf oge, %lhs, %rhs : f32 %res = arith.select %cmp, %lhs, %rhs : f32 @@ -205,17 +205,16 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index, %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%zero, %ione) -> (f32, i64) { %one = arith.constant 1.0 : f32 + // CHECK: arith.fptosi + %1 = arith.fptosi %one : f32 to i64 // CHECK: omp.reduction %{{.*}}, %[[BUF1]] - scf.reduce(%one) : f32 { + // CHECK: omp.reduction %{{.*}}, %[[BUF2]] + scf.reduce(%one, %1 : f32, i64) { ^bb0(%lhs : f32, %rhs: f32): %cmp = arith.cmpf oge, %lhs, %rhs : f32 %res = arith.select %cmp, %lhs, %rhs : f32 scf.reduce.return %res : f32 - } - // CHECK: arith.fptosi - %1 = arith.fptosi %one : f32 to i64 - // CHECK: omp.reduction %{{.*}}, %[[BUF2]] - scf.reduce(%1) : i64 { + }, { ^bb1(%lhs: i64, %rhs: i64): %cmp = arith.cmpi slt, %lhs, %rhs : i64 %res = arith.select %cmp, %rhs, %lhs : i64 diff --git 
a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir index 6f388f366f744..71bf2f3d918e8 100644 --- a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir +++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir @@ -1,13 +1,13 @@ // RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s // `scf.parallel` conversion is not supported yet. -// Make sure that we do not accidentally invalidate this functio by removing -// `scf.yield`. +// Make sure that we do not accidentally invalidate this function by removing +// `scf.reduce`. // CHECK-LABEL: func.func @func // CHECK: scf.parallel // CHECK-NEXT: spirv.Constant // CHECK-NEXT: memref.store -// CHECK-NEXT: scf.yield +// CHECK-NEXT: scf.reduce // CHECK: spirv.Return func.func @func(%arg0: i64) { %0 = arith.index_cast %arg0 : i64 to index @@ -15,7 +15,7 @@ func.func @func(%arg0: i64) { scf.parallel (%arg1) = (%0) to (%0) step (%0) { %cst = arith.constant 1.000000e+00 : f32 memref.store %cst, %alloc[%arg1] : memref<16xf32> - scf.yield + scf.reduce } return } diff --git a/mlir/test/Dialect/Linalg/parallel-loops.mlir b/mlir/test/Dialect/Linalg/parallel-loops.mlir index 15bce63caabcf..c04f27608d445 100644 --- a/mlir/test/Dialect/Linalg/parallel-loops.mlir +++ b/mlir/test/Dialect/Linalg/parallel-loops.mlir @@ -25,7 +25,7 @@ func.func @linalg_generic_sum(%lhs: memref<2x2xf32>, // CHECK: %[[RHS_ELEM:.*]] = memref.load %[[RHS]][%[[I]], %[[J]]] // CHECK: %[[SUM:.*]] = arith.addf %[[LHS_ELEM]], %[[RHS_ELEM]] : f32 // CHECK: store %[[SUM]], %{{.*}}[%[[I]], %[[J]]] -// CHECK: scf.yield +// CHECK: scf.reduce // ----- diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir index fed3c007d9b6d..15942db9b5db2 100644 --- a/mlir/test/Dialect/Linalg/transform-op-match.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir @@ -153,7 +153,7 @@ func.func @foo(%lb: index, %ub: index, %step: index) { // expected-remark @below {{loop-like}} 
scf.parallel (%i) = (%lb) to (%ub) step (%step) { func.call @callee() : () -> () - scf.yield + scf.reduce } // expected-remark @below {{loop-like}} scf.forall (%i) in (%ub) { diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir index 99cfed99c02d1..8451b1524fd2a 100644 --- a/mlir/test/Dialect/SCF/buffer-deallocation.mlir +++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir @@ -31,7 +31,7 @@ func.func @reduce(%buffer: memref<100xf32>) { %c1 = arith.constant 1 : index scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 { %elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32> - scf.reduce(%elem_to_reduce) : f32 { + scf.reduce(%elem_to_reduce : f32) { ^bb0(%lhs : f32, %rhs: f32): %alloc = memref.alloc() : memref<2xf32> memref.store %lhs, %alloc [%c0] : memref<2xf32> diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 41e028028616a..52e0fdfa36d6c 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -11,7 +11,7 @@ func.func @single_iteration_some(%A: memref) { scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c2, %c3) { %c42 = arith.constant 42 : i32 memref.store %c42, %A[%i0, %i1, %i2] : memref - scf.yield + scf.reduce } return } @@ -26,7 +26,7 @@ func.func @single_iteration_some(%A: memref) { // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index // CHECK: scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) { // CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK: } // CHECK: return @@ -42,7 +42,7 @@ func.func @single_iteration_all(%A: memref) { scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c3, %c3) { %c42 = arith.constant 42 : i32 memref.store %c42, %A[%i0, %i1, %i2] : memref - scf.yield + scf.reduce } return } @@ -55,7 +55,7 @@ func.func 
@single_iteration_all(%A: memref) { // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index // CHECK-NOT: scf.parallel // CHECK: memref.store [[C42]], [[ARG0]]{{\[}}[[C0]], [[C3]], [[C7]]] : memref -// CHECK-NOT: scf.yield +// CHECK-NOT: scf.reduce // CHECK: return // ----- @@ -67,17 +67,15 @@ func.func @single_iteration_reduce(%A: index, %B: index) -> (index, index) { %c3 = arith.constant 3 : index %c6 = arith.constant 6 : index %0:2 = scf.parallel (%i0, %i1) = (%c1, %c3) to (%c2, %c6) step (%c1, %c3) init(%A, %B) -> (index, index) { - scf.reduce(%i0) : index { + scf.reduce(%i0, %i1 : index, index) { ^bb0(%lhs: index, %rhs: index): %1 = arith.addi %lhs, %rhs : index scf.reduce.return %1 : index - } - scf.reduce(%i1) : index { + }, { ^bb0(%lhs: index, %rhs: index): %2 = arith.muli %lhs, %rhs : index scf.reduce.return %2 : index } - scf.yield } return %0#0, %0#1 : index, index } @@ -109,11 +107,11 @@ func.func @nested_parallel(%0: memref) -> memref { scf.parallel (%arg3) = (%c0) to (%3) step (%c1) { %5 = memref.load %0[%arg1, %arg2, %arg3] : memref memref.store %5, %4[%arg1, %arg2, %arg3] : memref - scf.yield + scf.reduce } - scf.yield + scf.reduce } - scf.yield + scf.reduce } return %4 : memref } @@ -759,12 +757,11 @@ func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) { // CHECK-NOT: test.transform %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 { %1 = "test.produce"() : () -> f32 - scf.reduce(%1) : f32 { + scf.reduce(%1 : f32) { ^bb0(%lhs: f32, %rhs: f32): %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32 scf.reduce.return %2 : f32 } - scf.yield } // CHECK: "test.consume"(%[[INIT]]) "test.consume"(%0) : (f32) -> () diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index ad07a8b11327d..fac9d825568f7 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -235,7 +235,7 @@ func.func @parallel_fewer_results_than_reduces( 
// expected-error@+1 {{expects number of results: 0 to be the same as number of reductions: 1}} scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) { %c0 = arith.constant 1.0 : f32 - scf.reduce(%c0) : f32 { + scf.reduce(%c0 : f32) { ^bb0(%lhs: f32, %rhs: f32): scf.reduce.return %lhs : f32 } @@ -261,7 +261,7 @@ func.func @parallel_more_results_than_initial_values( %arg0 : index, %arg1: index, %arg2: index) { // expected-error@+1 {{'scf.parallel' 0 operands present, but expected 1}} %res = scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) -> f32 { - scf.reduce(%arg0) : index { + scf.reduce(%arg0 : index) { ^bb0(%lhs: index, %rhs: index): scf.reduce.return %lhs : index } @@ -275,8 +275,8 @@ func.func @parallel_different_types_of_results_and_reduces( %zero = arith.constant 0.0 : f32 %res = scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) init (%zero) -> f32 { - // expected-error@+1 {{expects type of reduce: 'index' to be the same as result type: 'f32'}} - scf.reduce(%arg0) : index { + // expected-error@+1 {{expects type of 0-th reduction operand: 'index' to be the same as the 0-th result type: 'f32'}} + scf.reduce(%arg0 : index) { ^bb0(%lhs: index, %rhs: index): scf.reduce.return %lhs : index } @@ -288,7 +288,7 @@ func.func @parallel_different_types_of_results_and_reduces( func.func @top_level_reduce(%arg0 : f32) { // expected-error@+1 {{expects parent op 'scf.parallel'}} - scf.reduce(%arg0) : f32 { + scf.reduce(%arg0 : f32) { ^bb0(%lhs : f32, %rhs : f32): scf.reduce.return %lhs : f32 } @@ -302,7 +302,7 @@ func.func @reduce_empty_block(%arg0 : index, %arg1 : f32) { %res = scf.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 { // expected-error@+1 {{empty block: expect at least a terminator}} - scf.reduce(%arg1) : f32 { + scf.reduce(%arg1 : f32) { ^bb0(%lhs : f32, %rhs : f32): } } @@ -315,8 +315,8 @@ func.func @reduce_too_many_args(%arg0 : index, %arg1 : f32) { %zero = arith.constant 0.0 : f32 %res = scf.parallel (%i0) = (%arg0) to (%arg0) 
step (%arg0) init (%zero) -> f32 { - // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}} - scf.reduce(%arg1) : f32 { + // expected-error@+1 {{expected two block arguments with type 'f32' in the 0-th reduction region}} + scf.reduce(%arg1 : f32) { ^bb0(%lhs : f32, %rhs : f32, %other : f32): scf.reduce.return %lhs : f32 } @@ -330,8 +330,8 @@ func.func @reduce_wrong_args(%arg0 : index, %arg1 : f32) { %zero = arith.constant 0.0 : f32 %res = scf.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 { - // expected-error@+1 {{expects two arguments to reduce block of type 'f32'}} - scf.reduce(%arg1) : f32 { + // expected-error@+1 {{expected two block arguments with type 'f32' in the 0-th reduction region}} + scf.reduce(%arg1 : f32) { ^bb0(%lhs : f32, %rhs : i32): scf.reduce.return %lhs : f32 } @@ -346,8 +346,8 @@ func.func @reduce_wrong_terminator(%arg0 : index, %arg1 : f32) { %zero = arith.constant 0.0 : f32 %res = scf.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 { - // expected-error@+1 {{the block inside reduce should be terminated with a 'scf.reduce.return' op}} - scf.reduce(%arg1) : f32 { + // expected-error@+1 {{reduction bodies must be terminated with an 'scf.reduce.return' op}} + scf.reduce(%arg1 : f32) { ^bb0(%lhs : f32, %rhs : f32): "test.finish" () : () -> () } @@ -361,10 +361,10 @@ func.func @reduceReturn_wrong_type(%arg0 : index, %arg1: f32) { %zero = arith.constant 0.0 : f32 %res = scf.parallel (%i0) = (%arg0) to (%arg0) step (%arg0) init (%zero) -> f32 { - scf.reduce(%arg1) : f32 { + scf.reduce(%arg1 : f32) { ^bb0(%lhs : f32, %rhs : f32): %c0 = arith.constant 1 : index - // expected-error@+1 {{needs to have type 'f32' (the type of the enclosing ReduceOp)}} + // expected-error@+1 {{must have type 'f32' (the type of the reduction inputs)}} scf.reduce.return %c0 : index } } @@ -475,9 +475,10 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind func.func 
@parallel_invalid_yield( %arg0: index, %arg1: index, %arg2: index) { + // expected-error@below {{expects body to terminate with 'scf.reduce'}} scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) { %c0 = arith.constant 1.0 : f32 - // expected-error@+1 {{'scf.yield' op not allowed to have operands inside 'scf.parallel'}} + // expected-note@below {{terminator here}} scf.yield %c0 : f32 } return @@ -487,7 +488,7 @@ func.func @parallel_invalid_yield( func.func @yield_invalid_parent_op() { "my.op"() ({ - // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.parallel, scf.while'}} + // expected-error@+1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.while'}} scf.yield }) : () -> () return @@ -749,7 +750,7 @@ func.func @switch_missing_terminator(%arg0: index, %arg1: i32) { // ----- func.func @parallel_missing_terminator(%0 : index) { - // expected-error @below {{'scf.parallel' op expects body to terminate with 'scf.yield'}} + // expected-error @below {{expects body to terminate with 'scf.reduce'}} "scf.parallel"(%0, %0, %0) ({ ^bb0(%arg1: index): // expected-note @below {{terminator here}} diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir index 46d175d6870ce..7f457ef3b6ba0 100644 --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -87,18 +87,18 @@ func.func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index, %red:2 = scf.parallel (%i2) = (%min) to (%max) step (%i1) init (%zero, %int_zero) -> (f32, i32) { %one = arith.constant 1.0 : f32 - scf.reduce(%one) : f32 { + %int_one = arith.constant 1 : i32 + scf.reduce(%one, %int_one : f32, i32) { ^bb0(%lhs : f32, %rhs: f32): %res = arith.addf %lhs, %rhs : f32 scf.reduce.return %res : f32 - } - %int_one = arith.constant 1 : i32 - scf.reduce(%int_one) : i32 { + }, { ^bb0(%lhs : i32, %rhs: i32): %res = arith.muli %lhs, %rhs : i32 
scf.reduce.return %res : i32 } } + scf.reduce } return } @@ -121,25 +121,23 @@ func.func @std_parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index, // CHECK-SAME: step (%[[I1]]) // CHECK-SAME: init (%[[ZERO]], %[[INT_ZERO]]) -> (f32, i32) { // CHECK-NEXT: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-NEXT: scf.reduce(%[[ONE]]) : f32 { +// CHECK-NEXT: %[[INT_ONE:.*]] = arith.constant 1 : i32 +// CHECK-NEXT: scf.reduce(%[[ONE]], %[[INT_ONE]] : f32, i32) { // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): // CHECK-NEXT: %[[RES:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: scf.reduce.return %[[RES]] : f32 -// CHECK-NEXT: } -// CHECK-NEXT: %[[INT_ONE:.*]] = arith.constant 1 : i32 -// CHECK-NEXT: scf.reduce(%[[INT_ONE]]) : i32 { +// CHECK-NEXT: }, { // CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): // CHECK-NEXT: %[[RES:.*]] = arith.muli %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: scf.reduce.return %[[RES]] : i32 // CHECK-NEXT: } -// CHECK-NEXT: scf.yield // CHECK-NEXT: } -// CHECK-NEXT: scf.yield +// CHECK-NEXT: scf.reduce func.func @parallel_explicit_yield( %arg0: index, %arg1: index, %arg2: index) { scf.parallel (%i0) = (%arg0) to (%arg1) step (%arg2) { - scf.yield + scf.reduce } return } @@ -149,7 +147,7 @@ func.func @parallel_explicit_yield( // CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: // CHECK-SAME: %[[ARG2:[A-Za-z0-9]+]]: // CHECK-NEXT: scf.parallel (%{{.*}}) = (%[[ARG0]]) to (%[[ARG1]]) step (%[[ARG2]]) -// CHECK-NEXT: scf.yield +// CHECK-NEXT: scf.reduce // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir index 8a42b3a1000ed..9fd33b4e52471 100644 --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -5,10 +5,10 @@ func.func @fuse_empty_loops() { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step 
(%c1, %c1) { - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } return } @@ -18,7 +18,7 @@ func.func @fuse_empty_loops() { // CHECK: [[C1:%.*]] = arith.constant 1 : index // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) { -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK: } // CHECK-NOT: scf.parallel @@ -35,14 +35,14 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C_elem = memref.load %C[%i, %j] : memref<2x2xf32> %sum_elem = arith.addf %B_elem, %C_elem : f32 memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> %product_elem = arith.mulf %sum_elem, %A_elem : f32 memref.store %product_elem, %result[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } memref.dealloc %sum : memref<2x2xf32> return @@ -64,7 +64,7 @@ func.func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>, // CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] // CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] // CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]] -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK: } // CHECK: memref.dealloc [[SUM]] @@ -81,20 +81,20 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>, scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) { %rhs_elem = memref.load %rhs[%i] : memref<100xf32> memref.store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32> - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) { %lhs_elem = memref.load %lhs[%i, %j] : memref<100x10xf32> %broadcast_rhs_elem = memref.load %broadcast_rhs[%i, %j] : memref<100x10xf32> %diff_elem = arith.subf %lhs_elem, 
%broadcast_rhs_elem : f32 memref.store %diff_elem, %diff[%i, %j] : memref<100x10xf32> - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) { %diff_elem = memref.load %diff[%i, %j] : memref<100x10xf32> %exp_elem = math.exp %diff_elem : f32 memref.store %exp_elem, %result[%i, %j] : memref<100x10xf32> - scf.yield + scf.reduce } memref.dealloc %broadcast_rhs : memref<100x10xf32> memref.dealloc %diff : memref<100x10xf32> @@ -120,7 +120,7 @@ func.func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>, // CHECK: [[DIFF_ELEM_:%.*]] = memref.load [[DIFF]]{{\[}}[[I]], [[J]]] // CHECK: [[EXP_ELEM:%.*]] = math.exp [[DIFF_ELEM_]] // CHECK: memref.store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]] -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK: } // CHECK: memref.dealloc [[BROADCAST_RHS]] // CHECK: memref.dealloc [[DIFF]] @@ -133,12 +133,12 @@ func.func @do_not_fuse_nested_ploop1() { %c1 = arith.constant 1 : index scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } return } @@ -154,13 +154,13 @@ func.func @do_not_fuse_nested_ploop2() { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } - scf.yield + scf.reduce } return } @@ -176,10 +176,10 @@ func.func @do_not_fuse_loops_unmatching_num_loops() { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } scf.parallel (%i) = (%c0) to (%c2) step (%c1) { - scf.yield + scf.reduce } return } @@ 
-194,11 +194,11 @@ func.func @do_not_fuse_loops_with_side_effecting_ops_in_between() { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } %buffer = memref.alloc() : memref<2x2xf32> scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } return } @@ -214,10 +214,10 @@ func.func @do_not_fuse_loops_unmatching_iteration_space() { %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) { - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } return } @@ -239,7 +239,7 @@ func.func @do_not_fuse_unmatching_write_read_patterns( %C_elem = memref.load %C[%i, %j] : memref<2x2xf32> %sum_elem = arith.addf %B_elem, %C_elem : f32 memref.store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %k = arith.addi %i, %c1 : index @@ -247,7 +247,7 @@ func.func @do_not_fuse_unmatching_write_read_patterns( %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> %product_elem = arith.mulf %sum_elem, %A_elem : f32 memref.store %product_elem, %result[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } memref.dealloc %common_buf : memref<2x2xf32> return @@ -269,7 +269,7 @@ func.func @do_not_fuse_unmatching_read_write_patterns( %C_elem = memref.load %common_buf[%i, %j] : memref<2x2xf32> %sum_elem = arith.addf %B_elem, %C_elem : f32 memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %k = arith.addi %i, %c1 : index @@ -277,7 +277,7 @@ func.func @do_not_fuse_unmatching_read_write_patterns( %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> %product_elem = arith.mulf %sum_elem, %A_elem : f32 memref.store %product_elem, %common_buf[%j, %i] : 
memref<2x2xf32> - scf.yield + scf.reduce } memref.dealloc %sum : memref<2x2xf32> return @@ -294,13 +294,13 @@ func.func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() { %c1 = arith.constant 1 : index %buffer = memref.alloc() : memref<2x2xf32> scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1] : memref<2x2xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>> %A_elem = memref.load %A[%i, %j] : memref<?x?xf32, strided<[?, ?], offset: ?>> - scf.yield + scf.reduce } return } @@ -322,14 +322,14 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C_elem = memref.load %C[%i, %j] : memref<2x2xf32> %sum_elem = arith.addf %B_elem, %C_elem : f32 memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> %A_elem = memref.load %A[%i, %j] : memref<2x2xf32> %product_elem = arith.mulf %sum_elem, %A_elem : f32 memref.store %product_elem, %result[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } } memref.dealloc %sum : memref<2x2xf32> @@ -353,7 +353,7 @@ func.func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>, // CHECK: [[A_ELEM:%.*]] = memref.load [[A]]{{\[}}[[I]], [[J]]] // CHECK: [[PRODUCT_ELEM:%.*]] = arith.mulf [[SUM_ELEM_]], [[A_ELEM]] // CHECK: memref.store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]] -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK: } // CHECK: } // CHECK: memref.dealloc [[SUM]] @@ -371,14 +371,14 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>, %C_elem = memref.load %C[%i, %j] : memref<2x2xf32> %sum_elem = arith.addf %B_elem, %C_elem : f32 memref.store %sum_elem, %sum[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %sum_elem = memref.load %sum[%i, %j] : memref<2x2xf32> %A_elem =
memref.load %A[%i, %j] : memref<2x2xf32> %product_elem = arith.mulf %sum_elem, %A_elem : f32 memref.store %product_elem, %result[%i, %j] : memref<2x2xf32> - scf.yield + scf.reduce } return } diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir index 7a35e0ff0c3a9..61b50bcd7d0c6 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir @@ -36,15 +36,14 @@ // CHECK: %[[TMP_12:.*]] = memref.load %[[TMP_2]][%[[TMP_arg4]]] : memref<?xf32> // CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_3]][%[[TMP_11]]] : memref<32xf32> // CHECK: %[[TMP_14:.*]] = arith.mulf %[[TMP_12]], %[[TMP_13]] : f32 -// CHECK: scf.reduce(%[[TMP_14]]) : f32 { +// CHECK: scf.reduce(%[[TMP_14]] : f32) { // CHECK: ^bb0(%[[TMP_arg5:.*]]: f32, %[[TMP_arg6:.*]]: f32): // CHECK: %[[TMP_15:.*]] = arith.addf %[[TMP_arg5]], %[[TMP_arg6]] : f32 // CHECK: scf.reduce.return %[[TMP_15]] : f32 // CHECK: } -// CHECK: scf.yield // CHECK: } // CHECK: memref.store %[[TMP_10]], %[[TMP_4]][%[[TMP_arg3]]] : memref<16xf32> -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK: } // CHECK: %[[TMP_5:.*]] = bufferization.to_tensor %[[TMP_4]] : memref<16xf32> // CHECK: return %[[TMP_5]] : tensor<16xf32> diff --git a/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir b/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir index 6f98d2c062a25..4a3e4dc35d4f1 100644 --- a/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir +++ b/mlir/test/Transforms/invalid-parallel-loop-collapsing.mlir @@ -20,7 +20,7 @@ func.func @too_few_iters(%arg0: index, %arg1: index, %arg2: index) { // expected-error @+1 {{op has 1 iter args while this limited functionality testing pass was configured only for loops with exactly 2 iter args.}} scf.parallel (%arg3) = (%arg0) to (%arg1) step (%arg2) { - scf.yield + scf.reduce } return } @@ -28,7 +28,7 @@ func.func @too_few_iters(%arg0: index,
%arg1: index, %arg2: index) { func.func @too_many_iters(%arg0: index, %arg1: index, %arg2: index) { // expected-error @+1 {{op has 3 iter args while this limited functionality testing pass was configured only for loops with exactly 2 iter args.}} scf.parallel (%arg3, %arg4, %arg5) = (%arg0, %arg0, %arg0) to (%arg1, %arg1, %arg1) step (%arg2, %arg2, %arg2) { - scf.yield + scf.reduce } return } diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir index 1415583dde9da..dcc314f36ae0a 100644 --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -374,7 +374,7 @@ func.func @parallel_loop_with_invariant() { // CHECK-NEXT: arith.addi // CHECK-NEXT: scf.parallel (%[[A:.*]],{{.*}}) = // CHECK-NEXT: arith.addi %[[A]] - // CHECK-NEXT: yield + // CHECK-NEXT: reduce // CHECK-NEXT: } // CHECK-NEXT: return diff --git a/mlir/test/Transforms/parallel-loop-collapsing.mlir b/mlir/test/Transforms/parallel-loop-collapsing.mlir index c606fe7588526..660d7edb2fbb3 100644 --- a/mlir/test/Transforms/parallel-loop-collapsing.mlir +++ b/mlir/test/Transforms/parallel-loop-collapsing.mlir @@ -43,4 +43,4 @@ func.func @parallel_many_dims() { // CHECK: [[V2:%.*]] = arith.muli [[V0]], [[C10]] : index // CHECK: [[I3:%.*]] = arith.addi [[V2]], [[C9]] : index // CHECK: "magic.op"([[I0]], [[C3]], [[C6]], [[I3]], [[C12]]) : (index, index, index, index, index) -> index -// CHECK: scf.yield +// CHECK: scf.reduce diff --git a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir index 7b6883896dc10..542786b5fa5e5 100644 --- a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir +++ b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir @@ -29,6 +29,6 @@ func.func @collapse_to_single() { // CHECK: [[V1:%.*]] = arith.muli [[I1_COUNT]], [[C3]] : index // CHECK: [[I0:%.*]] = arith.addi [[V1]], [[C3]] : index 
// CHECK: "magic.op"([[I0]], [[I1]]) : (index, index) -> index -// CHECK: scf.yield +// CHECK: scf.reduce // CHECK-NEXT: } // CHECK-NEXT: return