Skip to content

Commit 163f463

Browse files
author
Spenser Bauman
committed
[mlir][scf] Implement conversion from scf.forall to scf.parallel
There is currently no path to lower scf.forall to scf.parallel with the goal of targeting the OpenMP dialect. In the SCF->ControlFlow conversion, scf.forall is briefly converted to scf.parallel, but the scf.parallel is lowered directly to a sequential loop. This makes experimenting with scf.forall for CPU execution difficult. This change factors out the rewrite in the SCF->ControlFlow pass into a utility function that can then be used in the SCF->ControlFlow lowering, but also in a separate -scf-forall-to-parallel pass.
1 parent 2db190f commit 163f463

File tree

10 files changed

+292
-27
lines changed

10 files changed

+292
-27
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
6868
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
6969
}
7070

71+
def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
72+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
73+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
74+
let summary = "Converts scf.forall into a nest of scf.for operations";
75+
let description = [{
76+
Converts the `scf.forall` operation pointed to by the given handle into an
77+
`scf.parallel` operation.
78+
79+
The operand handle must be associated with exactly one payload operation.
80+
81+
Loops with outputs are not supported.
82+
83+
#### Return Modes
84+
85+
Consumes the operand handle. Produces a silenceable failure if the operand
86+
is not associated with a single `scf.forall` payload operation.
87+
Returns a handle to the new `scf.parallel` operation.
88+
Produces a silenceable failure if another number of resulting handles is
89+
requested.
90+
}];
91+
let arguments = (ins TransformHandleTypeInterface:$target);
92+
let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
93+
94+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
95+
}
96+
7197
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
7298
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
7399
DeclareOpInterfaceMethods<TransformOpInterface>]> {

mlir/include/mlir/Dialect/SCF/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
6262
/// Creates a pass that converts SCF forall loops to SCF for loops.
6363
std::unique_ptr<Pass> createForallToForLoopPass();
6464

65+
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
66+
std::unique_ptr<Pass> createForallToParallelLoopPass();
67+
6568
// Creates a pass which lowers for loops into while loops.
6669
std::unique_ptr<Pass> createForToWhileLoopPass();
6770

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
125125
let constructor = "mlir::createForallToForLoopPass()";
126126
}
127127

128+
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
129+
let summary = "Convert SCF forall loops to SCF parallel loops";
130+
let constructor = "mlir::createForallToParallelLoopPass()";
131+
}
132+
128133
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
129134
let summary = "Convert SCF for loops to SCF while loops";
130135
let constructor = "mlir::createForToWhileLoopPass()";

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ class WhileOp;
3939
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
4040
SmallVectorImpl<Operation *> *results = nullptr);
4141

42+
/// Try converting scf.forall into an scf.parallel loop.
43+
/// The conversion is only supported for forall operations with no results.
44+
LogicalResult forallToParallelLoop(RewriterBase &rewriter,
45+
ForallOp forallOp,
46+
ParallelOp *result = nullptr);
47+
4248
/// Fuses all adjacent scf.parallel operations with identical bounds and step
4349
/// into one scf.parallel operations. Uses a naive aliasing and dependency
4450
/// analysis.

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1818
#include "mlir/Dialect/SCF/IR/SCF.h"
19+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/BuiltinOps.h"
2122
#include "mlir/IR/IRMapping.h"
@@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
688689

689690
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
690691
PatternRewriter &rewriter) const {
691-
Location loc = forallOp.getLoc();
692-
if (!forallOp.getOutputs().empty())
693-
return rewriter.notifyMatchFailure(
694-
forallOp,
695-
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
696-
697-
// Convert mixed bounds and steps to SSA values.
698-
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
699-
rewriter, loc, forallOp.getMixedLowerBound());
700-
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
701-
rewriter, loc, forallOp.getMixedUpperBound());
702-
SmallVector<Value> steps =
703-
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
704-
705-
// Create empty scf.parallel op.
706-
auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps);
707-
rewriter.eraseBlock(&parallelOp.getRegion().front());
708-
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
709-
parallelOp.getRegion().begin());
710-
// Replace the terminator.
711-
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
712-
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
713-
parallelOp.getRegion().front().getTerminator());
714-
715-
// Erase the scf.forall op.
716-
rewriter.replaceOp(forallOp, parallelOp);
717-
return success();
692+
return scf::forallToParallelLoop(rewriter, forallOp);
718693
}
719694

720695
void mlir::populateSCFToControlFlowConversionPatterns(

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
9898
return DiagnosedSilenceableFailure::success();
9999
}
100100

101+
//===----------------------------------------------------------------------===//
102+
// ForallToForOp
103+
//===----------------------------------------------------------------------===//
104+
105+
DiagnosedSilenceableFailure
106+
transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
107+
transform::TransformResults &results,
108+
transform::TransformState &state) {
109+
auto payload = state.getPayloadOps(getTarget());
110+
if (!llvm::hasSingleElement(payload))
111+
return emitSilenceableError() << "expected a single payload op";
112+
113+
auto target = dyn_cast<scf::ForallOp>(*payload.begin());
114+
if (!target) {
115+
DiagnosedSilenceableFailure diag =
116+
emitSilenceableError() << "expected the payload to be scf.forall";
117+
diag.attachNote((*payload.begin())->getLoc()) << "payload op";
118+
return diag;
119+
}
120+
121+
if (!target.getOutputs().empty()) {
122+
return emitSilenceableError()
123+
<< "unsupported shared outputs (didn't bufferize?)";
124+
}
125+
126+
if (getNumResults() != 1) {
127+
DiagnosedSilenceableFailure diag = emitSilenceableError()
128+
<< "op expects one result, given "
129+
<< getNumResults();
130+
diag.attachNote(target.getLoc()) << "payload op";
131+
return diag;
132+
}
133+
134+
scf::ParallelOp opResult;
135+
if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
136+
DiagnosedSilenceableFailure diag = emitSilenceableError()
137+
<< "failed to convert forall into parallel";
138+
return diag;
139+
}
140+
141+
results.set(cast<OpResult>(getTransformed()[0]), {opResult});
142+
return DiagnosedSilenceableFailure::success();
143+
}
144+
101145
//===----------------------------------------------------------------------===//
102146
// LoopOutlineOp
103147
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
33
BufferizableOpInterfaceImpl.cpp
44
Bufferize.cpp
55
ForallToFor.cpp
6+
ForallToParallel.cpp
67
ForToWhile.cpp
78
LoopCanonicalization.cpp
89
LoopPipelining.cpp
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.ForallOp's into SCF.ParallelOps's.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
14+
#include "mlir/Dialect/SCF/IR/SCF.h"
15+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
16+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
namespace mlir {
20+
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
21+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22+
} // namespace mlir
23+
24+
using namespace mlir;
25+
26+
LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
27+
scf::ForallOp forallOp,
28+
scf::ParallelOp *result) {
29+
OpBuilder::InsertionGuard guard(rewriter);
30+
rewriter.setInsertionPoint(forallOp);
31+
32+
Location loc = forallOp.getLoc();
33+
if (!forallOp.getOutputs().empty())
34+
return rewriter.notifyMatchFailure(
35+
forallOp,
36+
"only fully bufferized scf.forall ops can be lowered to scf.parallel");
37+
38+
// Convert mixed bounds and steps to SSA values.
39+
SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
40+
rewriter, loc, forallOp.getMixedLowerBound());
41+
SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
42+
rewriter, loc, forallOp.getMixedUpperBound());
43+
SmallVector<Value> steps =
44+
getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
45+
46+
// Create empty scf.parallel op.
47+
auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
48+
rewriter.eraseBlock(&parallelOp.getRegion().front());
49+
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
50+
parallelOp.getRegion().begin());
51+
// Replace the terminator.
52+
rewriter.setInsertionPointToEnd(&parallelOp.getRegion().front());
53+
rewriter.replaceOpWithNewOp<scf::ReduceOp>(
54+
parallelOp.getRegion().front().getTerminator());
55+
56+
// Erase the scf.forall op.
57+
rewriter.replaceOp(forallOp, parallelOp);
58+
59+
if (result)
60+
*result = parallelOp;
61+
62+
return success();
63+
}
64+
65+
namespace {
66+
struct ForallToParallelLoop final
67+
: public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
68+
void runOnOperation() override {
69+
Operation *parentOp = getOperation();
70+
IRRewriter rewriter(parentOp->getContext());
71+
72+
parentOp->walk([&](scf::ForallOp forallOp) {
73+
if (failed(scf::forallToParallelLoop(rewriter, forallOp))) {
74+
return signalPassFailure();
75+
}
76+
});
77+
}
78+
};
79+
} // namespace
80+
81+
std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
82+
return std::make_unique<ForallToParallelLoop>();
83+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
// CHECK-LABEL: @two_iters
6+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
7+
func.func @two_iters(%ub1: index, %ub2: index) {
8+
scf.forall (%i, %j) in (%ub1, %ub2) {
9+
func.call @callee(%i, %j) : (index, index) -> ()
10+
}
11+
12+
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
13+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
14+
// CHECK: scf.reduce
15+
return
16+
}
17+
18+
// -----
19+
20+
func.func private @callee(%i: index, %j: index)
21+
22+
// CHECK-LABEL: @repeated
23+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
24+
func.func @repeated(%ub1: index, %ub2: index) {
25+
scf.forall (%i, %j) in (%ub1, %ub2) {
26+
func.call @callee(%i, %j) : (index, index) -> ()
27+
}
28+
29+
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
30+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
31+
// CHECK: scf.reduce
32+
scf.forall (%i, %j) in (%ub1, %ub2) {
33+
func.call @callee(%i, %j) : (index, index) -> ()
34+
}
35+
36+
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
37+
// CHECK: func.call @callee(%[[IV3]], %[[IV4]])
38+
// CHECK: scf.reduce
39+
return
40+
}
41+
42+
// -----
43+
44+
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
45+
46+
// CHECK-LABEL: @nested
47+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
48+
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
49+
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
50+
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
51+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
52+
// CHECK: scf.reduce
53+
// CHECK: }
54+
// CHECK: scf.reduce
55+
// CHECK: }
56+
scf.forall (%i, %j) in (%ub1, %ub2) {
57+
scf.forall (%k, %l) in (%ub3, %ub4) {
58+
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
59+
}
60+
}
61+
return
62+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
// CHECK-LABEL: @two_iters
6+
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
7+
func.func @two_iters(%ub1: index, %ub2: index) {
8+
scf.forall (%i, %j) in (%ub1, %ub2) {
9+
func.call @callee(%i, %j) : (index, index) -> ()
10+
}
11+
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
12+
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
13+
// CHECK: scf.reduce
14+
return
15+
}
16+
17+
module attributes {transform.with_named_sequence} {
18+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
19+
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
20+
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
21+
transform.yield
22+
}
23+
}
24+
25+
// -----
26+
27+
func.func private @callee(%i: index, %j: index)
28+
29+
func.func @repeated(%ub1: index, %ub2: index) {
30+
scf.forall (%i, %j) in (%ub1, %ub2) {
31+
func.call @callee(%i, %j) : (index, index) -> ()
32+
}
33+
scf.forall (%i, %j) in (%ub1, %ub2) {
34+
func.call @callee(%i, %j) : (index, index) -> ()
35+
}
36+
return
37+
}
38+
39+
module attributes {transform.with_named_sequence} {
40+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
41+
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
42+
// expected-error @below {{expected a single payload op}}
43+
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
44+
transform.yield
45+
}
46+
}
47+
48+
// -----
49+
50+
// expected-note @below {{payload op}}
51+
func.func private @callee(%i: index, %j: index)
52+
53+
module attributes {transform.with_named_sequence} {
54+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
55+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
56+
// expected-error @below {{expected the payload to be scf.forall}}
57+
transform.loop.forall_to_for %0 : (!transform.any_op) -> !transform.any_op
58+
transform.yield
59+
}
60+
}

0 commit comments

Comments
 (0)