diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index 700a29139a35b..14df7e23a430f 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -68,30 +68,6 @@ def ForallToForOp : Op]> { - let summary = "Gets a handle to the parent 'for' loop of the given operation"; - let description = [{ - Produces a handle to the n-th (default 1) parent `scf.for` or `affine.for` - (when the affine flag is true) loop for each Payload IR operation - associated with the operand. Fails if such a loop cannot be found. The list - of operations associated with the handle contains parent operations in the - same order as the list associated with the operand, except for operations - that are parents to more than one input which are only present once. - }]; - - let arguments = - (ins TransformHandleTypeInterface:$target, - DefaultValuedAttr, - "1">:$num_loops, - DefaultValuedAttr:$affine); - let results = (outs TransformHandleTypeInterface : $parent); - - let assemblyFormat = - "$target attr-dict `:` functional-type(operands, results)"; -} - def LoopOutlineOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 2fd0e80db96fe..307257f4a582b 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -620,10 +620,11 @@ def GetParentOp : TransformDialectOp<"get_parent_op", that case for each target op, the closest parent op that fulfills all requirements, is returned. - `isolated_from_above`: the parent op must be isolated from above - - `allow_empty_results`: get_parent_op is allowed to return an empty list and - still succeeds. In such a case, if get_parent_op fails for any operation - in the list, the entire transform returns an empty handle. + - `allow_empty_results`: get_parent_op is allowed to return an empty list + and still succeeds. In such a case, if get_parent_op fails for any + operation in the list, the entire transform returns an empty handle. - `op_name`: the parent op must have the specified name + - `nth_parent`: get the n-th parent of that satisfies the above requirements If `deduplicate` is set, the result handle does not contain any duplicate ops. For example, given the list @@ -641,7 +642,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op", UnitAttr:$isolated_from_above, UnitAttr:$allow_empty_results, OptionalAttr:$op_name, - UnitAttr:$deduplicate); + UnitAttr:$deduplicate, + DefaultValuedAttr, + "1">:$nth_parent); let results = (outs TransformHandleTypeInterface:$parent); let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 65d503d7c4ad8..62370604142cd 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -49,39 +49,6 @@ void transform::ApplySCFStructuralConversionPatternsOp:: conversionTarget); } -//===----------------------------------------------------------------------===// -// GetParentForOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::GetParentForOp::apply(transform::TransformRewriter &rewriter, - transform::TransformResults &results, - transform::TransformState &state) { - SetVector parents; - for (Operation *target : state.getPayloadOps(getTarget())) { - Operation *loop, *current = target; - for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { - loop = getAffine() - ? current->getParentOfType().getOperation() - : current->getParentOfType().getOperation(); - if (!loop) { - DiagnosedSilenceableFailure diag = - emitSilenceableError() - << "could not find an '" - << (getAffine() ? AffineForOp::getOperationName() - : scf::ForOp::getOperationName()) - << "' parent"; - diag.attachNote(target->getLoc()) << "target op"; - return diag; - } - current = loop; - } - parents.insert(loop); - } - results.set(cast(getResult()), parents.getArrayRef()); - return DiagnosedSilenceableFailure::success(); -} - //===----------------------------------------------------------------------===// // ForallToForOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 514a75b5d5904..7136e423470a2 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1232,27 +1232,30 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, SmallVector parents; DenseSet resultSet; for (Operation *target : state.getPayloadOps(getTarget())) { - Operation *parent = target->getParentOp(); - while (parent) { - bool checkIsolatedFromAbove = - !getIsolatedFromAbove() || - parent->hasTrait(); - bool checkOpName = !getOpName().has_value() || - parent->getName().getStringRef() == *getOpName(); - if (checkIsolatedFromAbove && checkOpName) - break; + Operation *parent = target; + for (int64_t i = 0, e = getNthParent(); i < e; ++i) { parent = parent->getParentOp(); - } - if (!parent) { - if (getAllowEmptyResults()) { - results.set(llvm::cast(getResult()), parents); - return DiagnosedSilenceableFailure::success(); + while (parent) { + bool checkIsolatedFromAbove = + !getIsolatedFromAbove() || + parent->hasTrait(); + bool checkOpName = !getOpName().has_value() || + parent->getName().getStringRef() == *getOpName(); + if (checkIsolatedFromAbove && checkOpName) + break; + parent = parent->getParentOp(); + } + if (!parent) { + if (getAllowEmptyResults()) { + results.set(llvm::cast(getResult()), parents); + return DiagnosedSilenceableFailure::success(); + } + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "could not find a parent op that matches all requirements"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; } - DiagnosedSilenceableFailure diag = - emitSilenceableError() - << "could not find a parent op that matches all requirements"; - diag.attachNote(target->getLoc()) << "target op"; - return diag; } if (getDeduplicate()) { if (!resultSet.contains(parent)) { diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index f7a2026e800ae..166c5c5ca4ec3 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -52,26 +52,28 @@ def patterns(self) -> Block: @_ods_cext.register_operation(_Dialect, replace=True) class GetParentOp(GetParentOp): - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - isolated_from_above: bool = False, - op_name: Optional[str] = None, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - isolated_from_above=isolated_from_above, - op_name=op_name, - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + nth_parent: int = 1, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + nth_parent=nth_parent, + loc=loc, + ip=ip, + ) @_ods_cext.register_operation(_Dialect, replace=True) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py index 6c89025f41383..3bdd9ca3b22f0 100644 --- a/mlir/python/mlir/dialects/transform/loop.py +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -17,30 +17,6 @@ from typing import Optional, Union -@_ods_cext.register_operation(_Dialect, replace=True) -class GetParentForOp(GetParentForOp): - """Extension for GetParentForOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: Optional[int] = None, - ip=None, - loc=None, - ): - if num_loops is None: - num_loops = 1 - super().__init__( - result_type, - _get_op_result_or_value(target), - num_loops=num_loops, - ip=ip, - loc=loc, - ) - - @_ods_cext.register_operation(_Dialect, replace=True) class LoopOutlineOp(LoopOutlineOp): """Extension for LoopOutlineOp.""" diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir index 96c57d4716d37..59b824d4ca262 100644 --- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir @@ -32,7 +32,7 @@ func.func @test_loops_do_not_get_unrolled() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for"> + %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for"> // expected-error @below {{failed to unroll}} transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for"> transform.yield @@ -81,7 +81,7 @@ func.func @test_loops_do_not_get_peeled() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> + %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for"> // expected-error @below {{failed to peel}} transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op) transform.yield diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir index 6d1ba48d3b935..74601cf5b34a1 100644 --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -1,53 +1,5 @@ // RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s -// CHECK-LABEL: @get_parent_for_op -func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) { - // expected-remark @below {{first loop}} - scf.for %i = %arg0 to %arg1 step %arg2 { - // expected-remark @below {{second loop}} - scf.for %j = %arg0 to %arg1 step %arg2 { - // expected-remark @below {{third loop}} - scf.for %k = %arg0 to %arg1 step %arg2 { - arith.addi %i, %j : index - } - } - } - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // CHECK: = transform.loop.get_parent_for - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> - %2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!transform.any_op) -> !transform.op<"scf.for"> - %3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!transform.any_op) -> !transform.op<"scf.for"> - transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for"> - transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for"> - transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for"> - transform.yield - } -} - -// ----- - -func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { - // expected-note @below {{target op}} - arith.addi %arg0, %arg1 : index - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @below {{could not find an 'scf.for' parent}} - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> - transform.yield - } -} - -// ----- - // Outlined functions: // // CHECK: func @foo(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) @@ -81,7 +33,7 @@ func.func @loop_outline_op(%arg0: index, %arg1: index, %arg2: index) { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> + %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for"> // CHECK: = transform.loop.outline %{{.*}} transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op) transform.yield @@ -114,7 +66,7 @@ func.func @loop_peel_op() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> + %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for"> %main_loop, %remainder = transform.loop.peel %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">, !transform.op<"scf.for">) // Make sure transform.test_print_remark_at_operand %main_loop, "main loop" : !transform.op<"scf.for"> @@ -152,7 +104,7 @@ func.func @loop_pipeline_op(%A: memref, %result: memref) { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> + %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for"> %2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !transform.any_op // Verify that the returned handle is usable. transform.test_print_remark_at_operand %2, "transformed" : !transform.any_op @@ -178,7 +130,7 @@ func.func @loop_unroll_op() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.op<"scf.for"> + %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for"> transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for"> transform.yield } @@ -186,54 +138,6 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: @get_parent_for_op -func.func @get_parent_for_op(%arg0: index, %arg1: index, %arg2: index) { - // expected-remark @below {{first loop}} - affine.for %i = %arg0 to %arg1 { - // expected-remark @below {{second loop}} - affine.for %j = %arg0 to %arg1 { - // expected-remark @below {{third loop}} - affine.for %k = %arg0 to %arg1 { - arith.addi %i, %j : index - } - } - } - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // CHECK: = transform.loop.get_parent_for - %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for"> - %2 = transform.loop.get_parent_for %0 { num_loops = 2, affine = true } : (!transform.any_op) -> !transform.op<"affine.for"> - %3 = transform.loop.get_parent_for %0 { num_loops = 3, affine = true } : (!transform.any_op) -> !transform.op<"affine.for"> - transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"affine.for"> - transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"affine.for"> - transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"affine.for"> - transform.yield - } -} - -// ----- - -func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) { - // expected-note @below {{target op}} - arith.addi %arg0, %arg1 : index - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @below {{could not find an 'affine.for' parent}} - %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for"> - transform.yield - } -} - -// ----- - func.func @loop_unroll_op() { %c0 = arith.constant 0 : index %c42 = arith.constant 42 : index @@ -250,7 +154,7 @@ func.func @loop_unroll_op() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 { affine = true } : (!transform.any_op) -> !transform.op<"affine.for"> + %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for"> transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for"> transform.loop.unroll %1 { factor = 4, affine = true } : !transform.op<"affine.for"> transform.yield @@ -277,7 +181,7 @@ func.func @test_mixed_loops() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 { num_loops = 1, affine = true } : (!transform.any_op) -> !transform.op<"affine.for"> + %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for"> transform.test_print_remark_at_operand %1, "affine for loop" : !transform.op<"affine.for"> transform.loop.unroll %1 { factor = 4 } : !transform.op<"affine.for"> transform.yield diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 3891c16b41155..d9a11994eb9d9 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -116,6 +116,32 @@ transform.with_pdl_patterns { // ----- +func.func @test_get_nth_parent() { + "test.foo"() ({ + // expected-remark @below{{2nd parent}} + "test.foo"() ({ + "test.qux"() ({ + // expected-remark @below{{1st parent}} + "test.foo"() ({ + "test.bar"() : () -> () + }) : () -> () + }) : () -> () + }) : () -> () + }) : () -> () +} + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %f = transform.structured.match ops{["test.bar"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %parent = get_parent_op %f {nth_parent = 1, op_name = "test.foo"} : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %parent, "1st parent" : !transform.any_op + %parent2 = get_parent_op %f {nth_parent = 2, op_name = "test.foo"} : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %parent2, "2nd parent" : !transform.any_op + transform.yield +} + +// ----- + func.func @foo() { %0 = arith.constant 0 : i32 return @@ -355,7 +381,7 @@ transform.with_pdl_patterns { sequence %arg0 : !transform.any_op failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.pdl_match @match_const in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.loop.get_parent_for %0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.any_op // expected-error @below {{only isolated-from-above ops can be alternative scopes}} alternatives %1 : !transform.any_op { ^bb2(%arg2: !transform.any_op): diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 481d774572010..d778172a607a3 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -162,13 +162,16 @@ def testGetParentOp(): ) with InsertionPoint(sequence.body): transform.GetParentOp( - transform.AnyOpType.get(), sequence.bodyTarget, isolated_from_above=True + transform.AnyOpType.get(), + sequence.bodyTarget, + isolated_from_above=True, + nth_parent=2, ) transform.YieldOp() # CHECK-LABEL: TEST: testGetParentOp # CHECK: transform.sequence # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op): - # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above} + # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64} @run diff --git a/mlir/test/python/dialects/transform_loop_ext.py b/mlir/test/python/dialects/transform_loop_ext.py index daec6707d6743..840e7a46e7ce0 100644 --- a/mlir/test/python/dialects/transform_loop_ext.py +++ b/mlir/test/python/dialects/transform_loop_ext.py @@ -16,21 +16,6 @@ def run(f): return f -@run -def getParentLoop(): - sequence = transform.SequenceOp( - transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get() - ) - with InsertionPoint(sequence.body): - loop.GetParentForOp( - transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2 - ) - transform.YieldOp() - # CHECK-LABEL: TEST: getParentLoop - # CHECK: = transform.loop.get_parent_for % - # CHECK: num_loops = 2 - - @run def loopOutline(): sequence = transform.SequenceOp(