diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index f75ba27e58e76..0aa750e625436 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -434,10 +434,10 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", of targeted ops. }]; - let arguments = (ins StrAttr:$pass_name, + let arguments = (ins TransformHandleTypeInterface:$target, + StrAttr:$pass_name, DefaultValuedAttr:$options, - Variadic:$dynamic_options, - TransformHandleTypeInterface:$target); + Variadic:$dynamic_options); let results = (outs TransformHandleTypeInterface:$result); let assemblyFormat = [{ $pass_name (`with` `options` `=` diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 10a04b0cc14e0..bfe96b1b3e5d4 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp): def __init__( self, result: Type, - pass_name: Union[str, StringAttr], target: Union[Operation, Value, OpView], + pass_name: Union[str, StringAttr], *, options: Optional[ Dict[ Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView], + Union[Attribute, Value, Operation, OpView, str, int, bool], ] ] = None, loc=None, @@ -253,17 +253,21 @@ def __init__( cur_param_operand_idx += 1 elif isinstance(value, Attribute): options_dict[key] = value + # The following cases auto-convert Python values to attributes. + elif isinstance(value, bool): + options_dict[key] = BoolAttr.get(value) + elif isinstance(value, int): + default_int_type = IntegerType.get_signless(64, context) + options_dict[key] = IntegerAttr.get(default_int_type, value) elif isinstance(value, str): options_dict[key] = StringAttr.get(value) else: raise TypeError(f"Unsupported option type: {type(value)}") - if len(options_dict) > 0: - print(options_dict, cur_param_operand_idx) super().__init__( result, + _get_op_result_or_value(target), pass_name, dynamic_options, - target=_get_op_result_or_value(target), options=DictAttr.get(options_dict), loc=loc, ip=ip, @@ -272,13 +276,13 @@ def __init__( def apply_registered_pass( result: Type, - pass_name: Union[str, StringAttr], target: Union[Operation, Value, OpView], + pass_name: Union[str, StringAttr], *, options: Optional[ Dict[ Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView], + Union[Attribute, Value, Operation, OpView, str, int, bool], ] ] = None, loc=None, diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index 6e6d4eb7e249f..1d1be9eda3496 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} { "test-convergence" = true, "max-num-rewrites" = %max_rewrites } to %1 - : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op transform.yield } } @@ -171,7 +171,6 @@ func.func @invalid_options_as_str() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param // expected-error @+2 {{expected '{' in options dictionary}} %2 = transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op @@ -256,7 +255,7 @@ module attributes {transform.with_named_sequence} { // expected-error @+2 {{expected '{' in options dictionary}} transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 - : (!transform.any_param, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_param) -> !transform.any_op transform.yield } } @@ -276,7 +275,7 @@ module attributes {transform.with_named_sequence} { // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}} transform.apply_registered_pass "canonicalize" with options = { "top-down" = %topdown_options } to %1 - : (!transform.any_param, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_param) -> !transform.any_op transform.yield } } @@ -316,12 +315,12 @@ module attributes {transform.with_named_sequence} { %0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param // expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}} - %2 = "transform.apply_registered_pass"(%1, %0) <{ + %2 = "transform.apply_registered_pass"(%0, %1) <{ options = {"max-iterations" = #transform.param_operand, "test-convergence" = true, "top-down" = false}, pass_name = "canonicalize"}> - : (!transform.any_param, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_param) -> !transform.any_op "transform.yield"() : () -> () }) : () -> () }) {transform.with_named_sequence} : () -> () @@ -340,13 +339,13 @@ module attributes {transform.with_named_sequence} { %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param // expected-error @below {{dynamic option index 0 is already used in options}} - %3 = "transform.apply_registered_pass"(%1, %2, %0) <{ + %3 = "transform.apply_registered_pass"(%0, %1, %2) <{ options = {"max-iterations" = #transform.param_operand, "max-num-rewrites" = #transform.param_operand, "test-convergence" = true, "top-down" = false}, pass_name = "canonicalize"}> - : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op "transform.yield"() : () -> () }) : () -> () }) {transform.with_named_sequence} : () -> () @@ -364,12 +363,12 @@ module attributes {transform.with_named_sequence} { %1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param %2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param // expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}} - %3 = "transform.apply_registered_pass"(%1, %2, %0) <{ + %3 = "transform.apply_registered_pass"(%0, %1, %2) <{ options = {"max-iterations" = #transform.param_operand, "test-convergence" = true, "top-down" = false}, pass_name = "canonicalize"}> - : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op "transform.yield"() : () -> () }) : () -> () }) {transform.with_named_sequence} : () -> () diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 48bc9bad37a1e..eeb95605d7a9a 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -263,12 +263,12 @@ def testApplyRegisteredPassOp(module: Module): ) with InsertionPoint(sequence.body): mod = transform.ApplyRegisteredPassOp( - transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget + transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize" ) mod = transform.ApplyRegisteredPassOp( transform.AnyOpType.get(), - "canonicalize", mod.result, + "canonicalize", options={"top-down": BoolAttr.get(False)}, ) max_iter = transform.param_constant( @@ -281,12 +281,12 @@ def testApplyRegisteredPassOp(module: Module): ) transform.apply_registered_pass( transform.AnyOpType.get(), - "canonicalize", mod, + "canonicalize", options={ "top-down": BoolAttr.get(False), "max-iterations": max_iter, - "test-convergence": BoolAttr.get(True), + "test-convergence": True, "max-rewrites": max_rewrites, }, ) @@ -305,4 +305,4 @@ def testApplyRegisteredPassOp(module: Module): # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]], # CHECK-SAME: "test-convergence" = true, # CHECK-SAME: "top-down" = false} - # CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op