@@ -626,7 +626,6 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
626626 }];
627627}
628628
629-
630629def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
631630 [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
632631 TransformEachOpTrait, TransformOpInterface]> {
@@ -714,6 +713,89 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
714713 }];
715714}
716715
716+ def TileReductionUsingForeachThreadOp :
717+ Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
718+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
719+ TransformEachOpTrait, TransformOpInterface]> {
720+ let description = [{
721+ Tile a PartialReductionOpInterface op to a tiled `scf.foreach_thread` doing
722+ partial reduction.
723+
724+ This transformation tiles the `target` along the reduction dimensions. It
725+ creates a tensor initialized with the identity value. Then it creates a
726+ `scf.foreach_thread` loops with the number threads given by `num_threads`.
727+ The op is tiled op with a size equal to `floordiv(size, num_threads)`.
728+ All the partial reduction value is are parallel inserted to create a new
729+ tensor. After the loop a merge operation is created to do a final reduction
730+ with the partial reductions tensor.
731+
732+ #### Return modes
733+
734+ This 3 returned handles point to:
735+ - the fill op used to initialize the neutral element,
736+ - the parallel tiled op and
737+ - the result-combining op.
738+
739+ #### Example:
740+
741+ ```
742+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
743+ affine_map<(d0, d1) -> (d0)>],
744+ iterator_types = ["parallel", "reduction"]}
745+ ins(%arg0 : tensor<?x?xf32>)
746+ outs(%out : tensor<?xf32>) {
747+ ^bb0(%arg7: f32, %arg9: f32):
748+ %1 = arith.addf %arg7, %arg9 : f32
749+ linalg.yield %1 : f32
750+ } -> tensor<?xf32>
751+ return %red : tensor<?xf32>
752+ ```
753+
754+ is transformed into:
755+
756+ ```
757+ %0 = tensor.empty(%dim_1) : tensor<?x5xf32>
758+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
759+ %2 = scf.foreach_thread (%arg2) in (%c5) shared_outs(%arg3 = %1) -> (tensor<?x5xf32>) {
760+ %4 = affine.min #map(%arg2)[%dim_0]
761+ %5 = affine.max #map1(%4)
762+ %extracted_slice = tensor.extract_slice %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
763+ %6 = affine.apply #map2(%arg2)[%dim_0]
764+ %extracted_slice_2 = tensor.extract_slice %arg0[0, %6] [%dim, %5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
765+ %extracted_slice_3 = tensor.extract_slice %extracted_slice[0] [%dim] [1] : tensor<?xf32> to tensor<?xf32>
766+ %7 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2 : tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?xf32>) {
767+ ^bb0(%in: f32, %out: f32):
768+ %9 = arith.addf %in, %out : f32
769+ linalg.yield %9 : f32
770+ } -> tensor<?xf32>
771+ scf.foreach_thread.perform_concurrently {
772+ tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
773+ }
774+ } {thread_dim_mapping = []}
775+ %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<?x5xf32>) outs(%arg1 : tensor<?xf32>) {
776+ ^bb0(%in: f32, %out: f32):
777+ %4 = arith.addf %in, %out : f32
778+ linalg.yield %4 : f32
779+ } -> tensor<?xf32>
780+ ```
781+ }];
782+
783+ let arguments = (ins PDL_Operation:$target,
784+ DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads);
785+ let results = (outs PDL_Operation:$fill_op,
786+ PDL_Operation:$split_linalg_op,
787+ PDL_Operation:$combining_linalg_op);
788+
789+ let assemblyFormat = "$target attr-dict";
790+
791+ let extraClassDeclaration = [{
792+ ::mlir::DiagnosedSilenceableFailure applyToOne(
793+ ::mlir::linalg::LinalgOp target,
794+ ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
795+ ::mlir::transform::TransformState &state);
796+ }];
797+ }
798+
717799def TileOp : Op<Transform_Dialect, "structured.tile",
718800 [DeclareOpInterfaceMethods<TransformOpInterface>,
719801 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
0 commit comments