diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h index e6b928d8ebecc..2ed96afbace81 100644 --- a/mlir/include/mlir/Transforms/RegionUtils.h +++ b/mlir/include/mlir/Transforms/RegionUtils.h @@ -11,6 +11,7 @@ #include "mlir/IR/Region.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "llvm/ADT/SetVector.h" @@ -80,6 +81,16 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op, Operation *insertionPoint); +/// Move definitions of `values` before an insertion point. Current support is +/// only for movement of definitions within the same basic block. Note that this +/// is an all-or-nothing approach. Either definitions of all values are moved +/// before insertion point, or none of them are. +LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, + Operation *insertionPoint, + DominanceInfo &dominance); +LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values, + Operation *insertionPoint); + /// Run a set of structural simplifications over the given regions. This /// includes transformations like unreachable block elimination, dead argument /// elimination, as well as some other DCE. This function returns success if any diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index da0d486f0fdcb..18e079d153161 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -1070,7 +1070,7 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, // in different basic blocks. if (op->getBlock() != insertionPoint->getBlock()) { return rewriter.notifyMatchFailure( - op, "unsupported caes where operation and insertion point are not in " + op, "unsupported case where operation and insertion point are not in " "the same basic block"); } // If `insertionPoint` does not dominate `op`, do nothing @@ -1115,3 +1115,70 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter, DominanceInfo dominance(op); return moveOperationDependencies(rewriter, op, insertionPoint, dominance); } + +LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, + ValueRange values, + Operation *insertionPoint, + DominanceInfo &dominance) { + // Remove the values that already dominate the insertion point. + SmallVector prunedValues; + for (auto value : values) { + if (dominance.properlyDominates(value, insertionPoint)) { + continue; + } + // Block arguments are not supported. + if (isa(value)) { + return rewriter.notifyMatchFailure( + insertionPoint, + "unsupported case of moving block argument before insertion point"); + } + // Check for currently unsupported case if the insertion point is in a + // different block. + if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) { + return rewriter.notifyMatchFailure( + insertionPoint, + "unsupported case of moving definition of value before an insertion " + "point in a different basic block"); + } + prunedValues.push_back(value); + } + + // Find the backward slice of operation for each `Value` the operation + // depends on. Prune the slice to only include operations not already + // dominated by the `insertionPoint` + BackwardSliceOptions options; + options.inclusive = true; + options.omitUsesFromAbove = false; + // Since current support is to only move within a same basic block, + // the slices dont need to look past block arguments. + options.omitBlockArguments = true; + options.filter = [&](Operation *sliceBoundaryOp) { + return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint); + }; + llvm::SetVector slice; + for (auto value : prunedValues) { + getBackwardSlice(value, &slice, options); + } + + // If the slice contains `insertionPoint` cannot move the dependencies. + if (slice.contains(insertionPoint)) { + return rewriter.notifyMatchFailure( + insertionPoint, + "cannot move dependencies before operation in backward slice of op"); + } + + // Sort operations topologically before moving. + mlir::topologicalSort(slice); + + for (Operation *op : slice) { + rewriter.moveOpBefore(op, insertionPoint); + } + return success(); +} + +LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter, + ValueRange values, + Operation *insertionPoint) { + DominanceInfo dominance(insertionPoint); + return moveValueDefinitions(rewriter, values, insertionPoint, dominance); +} diff --git a/mlir/test/Transforms/move-operation-deps.mlir b/mlir/test/Transforms/move-operation-deps.mlir index 37637152938f6..aa7b5dc2a240a 100644 --- a/mlir/test/Transforms/move-operation-deps.mlir +++ b/mlir/test/Transforms/move-operation-deps.mlir @@ -234,3 +234,229 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// Check simple move value definitions before insertion operation. +func.func @simple_move_values() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op_1"() : () -> (f32) + %2 = "moved_op_2"() : () -> (f32) + %3 = "foo"(%1, %2) : (f32, f32) -> (f32) + return %3 : f32 +} +// CHECK-LABEL: func @simple_move_values() +// CHECK: %[[MOVED1:.+]] = "moved_op_1" +// CHECK: %[[MOVED2:.+]] = "moved_op_2" +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED1]], %[[MOVED2]]) +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["moved_op_1"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["moved_op_2"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op3 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value + %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value + transform.test.move_value_defns %v1, %v2 before %op3 + : (!transform.any_value, !transform.any_value), !transform.any_op + transform.yield + } +} + +// ----- + +// Compute slice including the implicitly captured values. +func.func @move_region_dependencies_values() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op_1"() : () -> (f32) + %2 = "moved_op_2"() ({ + %3 = "inner_op"(%1) : (f32) -> (f32) + "yield"(%3) : (f32) -> () + }) : () -> (f32) + return %2 : f32 +} +// CHECK-LABEL: func @move_region_dependencies_values() +// CHECK: %[[MOVED1:.+]] = "moved_op_1" +// CHECK: %[[MOVED2:.+]] = "moved_op_2" +// CHECK: %[[BEFORE:.+]] = "before" + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["moved_op_2"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value + transform.test.move_value_defns %v1 before %op2 + : (!transform.any_value), !transform.any_op + transform.yield + } +} + +// ----- + +// Move operations in toplogical sort order +func.func @move_values_in_topological_sort_order() -> f32 { + %0 = "before"() : () -> (f32) + %1 = "moved_op_1"() : () -> (f32) + %2 = "moved_op_2"() : () -> (f32) + %3 = "moved_op_3"(%1) : (f32) -> (f32) + %4 = "moved_op_4"(%1, %3) : (f32, f32) -> (f32) + %5 = "moved_op_5"(%2) : (f32) -> (f32) + %6 = "foo"(%4, %5) : (f32, f32) -> (f32) + return %6 : f32 +} +// CHECK-LABEL: func @move_values_in_topological_sort_order() +// CHECK: %[[MOVED_1:.+]] = "moved_op_1" +// CHECK-DAG: %[[MOVED_2:.+]] = "moved_op_3"(%[[MOVED_1]]) +// CHECK-DAG: %[[MOVED_3:.+]] = "moved_op_4"(%[[MOVED_1]], %[[MOVED_2]]) +// CHECK-DAG: %[[MOVED_4:.+]] = "moved_op_2" +// CHECK-DAG: %[[MOVED_5:.+]] = "moved_op_5"(%[[MOVED_4]]) +// CHECK: %[[BEFORE:.+]] = "before" +// CHECK: %[[FOO:.+]] = "foo"(%[[MOVED_3]], %[[MOVED_5]]) +// CHECK: return %[[FOO]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["moved_op_4"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["moved_op_5"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op3 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value + %v2 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value + transform.test.move_value_defns %v1, %v2 before %op3 + : (!transform.any_value, !transform.any_value), !transform.any_op + transform.yield + } +} + +// ----- + +// Move only those value definitions that are not dominated by insertion point + +func.func @move_only_required_defns() -> (f32, f32, f32, f32) { + %0 = "unmoved_op"() : () -> (f32) + %1 = "dummy_op"() : () -> (f32) + %2 = "before"() : () -> (f32) + %3 = "moved_op"() : () -> (f32) + return %0, %1, %2, %3 : f32, f32, f32, f32 +} +// CHECK-LABEL: func @move_only_required_defns() +// CHECK: %[[UNMOVED:.+]] = "unmoved_op" +// CHECK: %[[DUMMY:.+]] = "dummy_op" +// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[BEFORE:.+]] = "before" + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["dummy_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op3 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op4 = transform.structured.match ops{["moved_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value + %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value + transform.test.move_value_defns %v1, %v2 before %op3 + : (!transform.any_value, !transform.any_value), !transform.any_op + transform.yield + } +} + +// ----- + +// Move only those value definitions that are not dominated by insertion point + +func.func @move_only_required_defns() -> (f32, f32, f32, f32) { + %0 = "unmoved_op"() : () -> (f32) + %1 = "dummy_op"() : () -> (f32) + %2 = "before"() : () -> (f32) + %3 = "moved_op"() : () -> (f32) + return %0, %1, %2, %3 : f32, f32, f32, f32 +} +// CHECK-LABEL: func @move_only_required_defns() +// CHECK: %[[UNMOVED:.+]] = "unmoved_op" +// CHECK: %[[DUMMY:.+]] = "dummy_op" +// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[BEFORE:.+]] = "before" + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["unmoved_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["dummy_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op3 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op4 = transform.structured.match ops{["moved_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %v1 = transform.get_result %op1[0] : (!transform.any_op) -> !transform.any_value + %v2 = transform.get_result %op4[0] : (!transform.any_op) -> !transform.any_value + transform.test.move_value_defns %v1, %v2 before %op3 + : (!transform.any_value, !transform.any_value), !transform.any_op + transform.yield + } +} + +// ----- + +// Check handling of block arguments +func.func @move_only_required_defns() -> (f32, f32) { + %0 = "unmoved_op"() : () -> (f32) + cf.br ^bb0(%0 : f32) + ^bb0(%arg0 : f32) : + %1 = "before"() : () -> (f32) + %2 = "moved_op"(%arg0) : (f32) -> (f32) + return %1, %2 : f32, f32 +} +// CHECK-LABEL: func @move_only_required_defns() +// CHECK: %[[MOVED:.+]] = "moved_op" +// CHECK: %[[BEFORE:.+]] = "before" + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["moved_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value + transform.test.move_value_defns %v1 before %op1 + : (!transform.any_value), !transform.any_op + transform.yield + } +} + +// ----- + +// Do not move across basic blocks +func.func @no_move_across_basic_blocks() -> (f32, f32) { + %0 = "unmoved_op"() : () -> (f32) + %1 = "before"() : () -> (f32) + cf.br ^bb0(%0 : f32) + ^bb0(%arg0 : f32) : + %2 = "moved_op"(%arg0) : (f32) -> (f32) + return %1, %2 : f32, f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) { + %op1 = transform.structured.match ops{["before"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %op2 = transform.structured.match ops{["moved_op"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value + // expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}} + transform.test.move_value_defns %v1 before %op1 + : (!transform.any_value), !transform.any_op + transform.yield + } +} diff --git a/mlir/test/lib/Transforms/TestTransformsOps.cpp b/mlir/test/lib/Transforms/TestTransformsOps.cpp index aaa566d9938a3..c05b32bed9b94 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.cpp +++ b/mlir/test/lib/Transforms/TestTransformsOps.cpp @@ -39,6 +39,23 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } +DiagnosedSilenceableFailure +transform::TestMoveValueDefns::apply(TransformRewriter &rewriter, + TransformResults &TransformResults, + TransformState &state) { + SmallVector values; + for (auto tdValue : getValues()) { + values.push_back(*state.getPayloadValues(tdValue).begin()); + } + Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin(); + if (failed(moveValueDefinitions(rewriter, values, moveBefore))) { + auto listener = cast(rewriter.getListener()); + std::string errorMsg = listener->getLatestMatchFailureMessage(); + (void)emitRemark(errorMsg); + } + return DiagnosedSilenceableFailure::success(); +} + namespace { class TestTransformsDialectExtension diff --git a/mlir/test/lib/Transforms/TestTransformsOps.td b/mlir/test/lib/Transforms/TestTransformsOps.td index f514702cef5bc..495579b452dfc 100644 --- a/mlir/test/lib/Transforms/TestTransformsOps.td +++ b/mlir/test/lib/Transforms/TestTransformsOps.td @@ -38,4 +38,26 @@ def TestMoveOperandDeps : }]; } +def TestMoveValueDefns : + Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Moves all dependencies of on operation before another operation. + }]; + + let arguments = + (ins Variadic:$values, + TransformHandleTypeInterface:$insertion_point); + + let results = (outs); + + let assemblyFormat = [{ + $values `before` $insertion_point attr-dict + `:` `(` type($values) `)` `` `,` type($insertion_point) + }]; +} + + #endif // TEST_TRANSFORM_OPS