
Commit 99833cd

[mlir][linalg] Add reduction tiling using scf.foreachthread
This adds a transformation that tiles reduction operations into partial reductions using scf.foreachthread. It uses PartialReductionOpInterface to create a merge operation for the partial tiles.

Differential Revision: https://reviews.llvm.org/D137912
1 parent: b40126b

File tree

5 files changed: +467, -48 lines


mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 83 additions & 1 deletion
@@ -626,7 +626,6 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
   }];
 }
 
-
 def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
@@ -714,6 +713,89 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
   }];
 }
 
+def TileReductionUsingForeachThreadOp :
+  Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Tile a PartialReductionOpInterface op to a tiled `scf.foreach_thread` doing
+    partial reduction.
+
+    This transformation tiles the `target` along the reduction dimensions. It
+    creates a tensor initialized with the identity value. Then it creates an
+    `scf.foreach_thread` loop with the number of threads given by `num_threads`.
+    The op is tiled with a size equal to `floordiv(size, num_threads)`.
+    All the partial reduction values are parallel-inserted to create a new
+    tensor. After the loop, a merge operation is created to do a final reduction
+    with the partial reduction tensor.
+
+    #### Return modes
+
+    The 3 returned handles point to:
+      - the fill op used to initialize the neutral element,
+      - the parallel tiled op and
+      - the result-combining op.
+
+    #### Example:
+
+    ```
+      %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                              affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%out : tensor<?xf32>) {
+        ^bb0(%arg7: f32, %arg9: f32):
+        %1 = arith.addf %arg7, %arg9 : f32
+        linalg.yield %1 : f32
+      } -> tensor<?xf32>
+      return %red : tensor<?xf32>
+    ```
+
+    is transformed into:
+
+    ```
+      %0 = tensor.empty(%dim_1) : tensor<?x5xf32>
+      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
+      %2 = scf.foreach_thread (%arg2) in (%c5) shared_outs(%arg3 = %1) -> (tensor<?x5xf32>) {
+        %4 = affine.min #map(%arg2)[%dim_0]
+        %5 = affine.max #map1(%4)
+        %extracted_slice = tensor.extract_slice %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+        %6 = affine.apply #map2(%arg2)[%dim_0]
+        %extracted_slice_2 = tensor.extract_slice %arg0[0, %6] [%dim, %5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %extracted_slice_3 = tensor.extract_slice %extracted_slice[0] [%dim] [1] : tensor<?xf32> to tensor<?xf32>
+        %7 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2 : tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %9 = arith.addf %in, %out : f32
+          linalg.yield %9 : f32
+        } -> tensor<?xf32>
+        scf.foreach_thread.perform_concurrently {
+          tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+        }
+      } {thread_dim_mapping = []}
+      %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<?x5xf32>) outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %4 = arith.addf %in, %out : f32
+        linalg.yield %4 : f32
+      } -> tensor<?xf32>
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads);
+  let results = (outs PDL_Operation:$fill_op,
+                 PDL_Operation:$split_linalg_op,
+                 PDL_Operation:$combining_linalg_op);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
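
For orientation, here is a minimal sketch of how the new op could be driven from a transform sequence. The commit's test files are not shown in this diff, so the match pattern and the `num_threads = [0, 5]` values below are illustrative assumptions, not code from the commit:

```mlir
// Illustrative sketch only: invoke the new op from the transform dialect.
// Matching "linalg.generic" and num_threads = [0, 5] (tile only the reduction
// dimension across 5 threads) are assumptions for this example.
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
  // Three handles come back: the fill op initializing the partial-reduction
  // tensor, the parallel tiled op, and the final combining op.
  %fill, %split, %combine =
    transform.structured.tile_reduction_using_foreach_thread %0
      { num_threads = [0, 5] }
}
```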

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 41 additions & 0 deletions
@@ -445,6 +445,47 @@ tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
                                     ArrayRef<OpFoldResult> tileSizes,
                                     Optional<ArrayAttr> mapping);
 
+/// Transformation information returned after reduction tiling.
+struct ForeachThreadReductionTilingResult {
+  /// The partial reduction tiled op generated.
+  Operation *parallelTiledOp;
+  /// The final reduction operation merging all the partial reductions.
+  Operation *mergeOp;
+  /// The op initializing the tensor used for partial reductions.
+  Operation *initialOp;
+  /// The `scf.foreach_thread` operation that iterates over the tiles.
+  scf::ForeachThreadOp loops;
+};
+
+/// Method to tile a reduction to parallel iterations computing partial
+/// reductions. After the loop, all the partial reductions are merged into a
+/// final reduction. For example, the following sequence
+///
+/// ```mlir
+///  %0 = linalg.generic %in ["parallel", "reduction"]
+///  : tensor<7x9xf32> -> tensor<7xf32>
+/// ```
+///
+/// is transformed into:
+///
+/// ```mlir
+///  %0 = linalg.fill ... : tensor<7x4xf32>
+///  %1 = scf.foreach_thread (%iv) in (%c4) shared_outs(%arg0 = %0)
+///    -> (tensor<7x4xf32>) {
+///    %2 = tensor.extract_slice %arg3 : tensor<7x4xf32> to tensor<7xf32>
+///    %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
+///    %4 = linalg.generic %2, %3 ["parallel", "reduction"]
+///      : tensor<7x?xf32> -> tensor<7xf32>
+///    %5 = tensor.insert_slice %3, %arg0[0, %iv] : tensor<7x4xf32>
+///  }
+///  %6 = linalg.generic %1 ["parallel", "reduction"]
+///    : tensor<7x4xf32> -> tensor<7xf32>
+/// ```
+FailureOr<ForeachThreadReductionTilingResult>
+tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op,
+                                ArrayRef<OpFoldResult> numThreads,
+                                Optional<ArrayAttr> mapping);
+
 /// All indices returned by IndexOp should be invariant with respect to
 /// tiling. Therefore, if an operation is tiled, we have to transform the
 /// indices accordingly, i.e. offset them by the values of the corresponding
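
The doc comment above uses shorthand (`linalg.fill ...`, bracketed iterator lists) for its `tensor<7x9xf32>` example. As a hedged illustration of the final merge step it sketches, the combining `linalg.generic` on the `tensor<7x4xf32>` partial result could be written out in full roughly as below; the value names and map aliases are assumptions for this sketch, not code from the commit:

```mlir
// Hedged illustration of the combining step: reduce the extra "num_threads"
// dimension of the partial-reduction tensor into the original result shape.
#partial = affine_map<(d0, d1) -> (d0, d1)>
#merged  = affine_map<(d0, d1) -> (d0)>
%merged = linalg.generic
    {indexing_maps = [#partial, #merged],
     iterator_types = ["parallel", "reduction"]}
    ins(%partial : tensor<7x4xf32>) outs(%init : tensor<7xf32>) {
  ^bb0(%in: f32, %out: f32):
    %sum = arith.addf %in, %out : f32
    linalg.yield %sum : f32
} -> tensor<7xf32>
```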

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 33 additions & 0 deletions
@@ -1165,6 +1165,39 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// TileReductionUsingForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::TileReductionUsingForeachThreadOp::applyToOne(
+    linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
+  SmallVector<OpFoldResult> numThreadResults;
+  for (int64_t num : numThreads) {
+    numThreadResults.push_back(rewriter.getIndexAttr(num));
+  }
+
+  FailureOr<linalg::ForeachThreadReductionTilingResult> result =
+      linalg::tileReductionUsingForeachThread(
+          rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+          numThreadResults, /*mapping=*/llvm::None);
+
+  if (failed(result)) {
+    results.assign(3, nullptr);
+    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Remark);
+    diag << "could not tile reduction in target.";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+  results.push_back(result->initialOp);
+  results.push_back(result->parallelTiledOp);
+  results.push_back(result->mergeOp);
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//
