Skip to content

Commit d57d60a

Browse files
committed
Allow passing in multiple separate option strings and/or params
1 parent 8977306 commit d57d60a

File tree

3 files changed

+187
-61
lines changed

3 files changed

+187
-61
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,9 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
406406
This transform applies the specified pass or pass pipeline to the targeted
407407
ops. The name of the pass/pipeline is specified as a string attribute, as
408408
set during pass/pipeline registration. Optionally, pass options may be
409-
specified as a string attribute with the option to pass the attribute as a
410-
param. The pass options syntax is identical to the one used with "mlir-opt".
409+
specified as (space-separated) string attributes with the option to pass
410+
these attributes via params. The pass options syntax is identical to the one
411+
used with "mlir-opt".
411412

412413
This op first looks for a pass pipeline with the specified name. If no such
413414
pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,16 +421,17 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
420421
of targeted ops.
421422
}];
422423

423-
let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
424-
TransformHandleTypeInterface:$target,
425-
StrAttr:$pass_name,
426-
DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
424+
let arguments = (ins StrAttr:$pass_name,
425+
DefaultValuedAttr<ArrayAttr, "{}">:$options,
426+
Variadic<TransformParamTypeInterface>:$dynamic_options,
427+
TransformHandleTypeInterface:$target);
427428
let results = (outs TransformHandleTypeInterface:$result);
428429
let assemblyFormat = [{
429430
$pass_name (`with` `options` `=`
430-
custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
431+
custom<ApplyRegisteredPassOptions>($options, $dynamic_options)^)?
431432
`to` $target attr-dict `:` functional-type(operands, results)
432433
}];
434+
let hasVerifier = 1;
433435
}
434436

435437
def CastOp : TransformDialectOp<"cast",

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 106 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,11 @@
5454
using namespace mlir;
5555

5656
static ParseResult parseApplyRegisteredPassOptions(
57-
OpAsmParser &parser,
58-
std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
59-
StringAttr &staticOptions);
57+
OpAsmParser &parser, ArrayAttr &options,
58+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
6059
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
61-
Operation *op, Value dynamicOptions,
62-
StringAttr staticOptions);
60+
Operation *op, ArrayAttr options,
61+
ValueRange dynamicOptions);
6362
static ParseResult parseSequenceOpOperands(
6463
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
6564
Type &rootType,
@@ -785,25 +784,40 @@ DiagnosedSilenceableFailure
785784
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
786785
transform::TransformResults &results,
787786
transform::TransformState &state) {
788-
// Check whether pass options are specified, either as a dynamic param or
789-
// a static attribute. In either case, options are passed as a single string.
790-
StringRef options;
791-
if (auto dynamicOptions = getDynamicOptions()) {
792-
ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
793-
if (dynamicOptionsParam.size() != 1) {
794-
return emitSilenceableError()
795-
<< "options passed as a param must be a single value, got "
796-
<< dynamicOptionsParam.size();
797-
}
798-
if (auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0])) {
799-
options = optionsStrAttr.getValue();
787+
// Obtain a single options-string from options passed statically as
788+
// string attributes as well as "dynamically" through params.
789+
std::string options;
790+
OperandRange dynamicOptions = getDynamicOptions();
791+
size_t dynamicOptionsIdx = 0;
792+
for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) {
793+
if (idx > 0)
794+
options += " "; // Interleave options seperator.
795+
796+
if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
797+
options += strAttr.getValue();
798+
} else if (isa<UnitAttr>(optionAttr)) {
799+
assert(dynamicOptionsIdx < dynamicOptions.size() &&
800+
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
801+
"should be the same as the number of options passed as params");
802+
ArrayRef<Attribute> dynamicOption =
803+
state.getParams(dynamicOptions[dynamicOptionsIdx++]);
804+
if (dynamicOption.size() != 1)
805+
return emitSilenceableError() << "options passed as a param must have "
806+
"a single value associated, param "
807+
<< dynamicOptionsIdx - 1 << " associates "
808+
<< dynamicOption.size();
809+
810+
if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) {
811+
options += dynamicOptionStr.getValue();
812+
} else {
813+
return emitSilenceableError()
814+
<< "options passed as a param must be a string, got "
815+
<< dynamicOption[0];
816+
}
800817
} else {
801-
return emitSilenceableError()
802-
<< "options passed as a param must be a string, got "
803-
<< dynamicOptionsParam[0];
818+
assert(false &&
819+
"expected options element to be either StringAttr or UnitAttr");
804820
}
805-
} else {
806-
options = getStaticOptions();
807821
}
808822

809823
// Get pass or pass pipeline from registry.
@@ -850,43 +864,88 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
850864
}
851865

852866
static ParseResult parseApplyRegisteredPassOptions(
853-
OpAsmParser &parser,
854-
std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
855-
StringAttr &staticOptions) {
856-
dynamicOptions = std::nullopt;
857-
OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
858-
OptionalParseResult hasDynamicOptions =
859-
parser.parseOptionalOperand(dynamicOptionsOperand);
860-
861-
if (hasDynamicOptions.has_value()) {
862-
if (failed(hasDynamicOptions.value()))
863-
return failure();
867+
OpAsmParser &parser, ArrayAttr &options,
868+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
869+
auto dynamicOptionMarker = UnitAttr::get(parser.getContext());
870+
SmallVector<Attribute> optionsArray;
871+
872+
auto parseOperandOrString = [&]() -> OptionalParseResult {
873+
OpAsmParser::UnresolvedOperand operand;
874+
OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand);
875+
if (parsedOperand.has_value()) {
876+
if (failed(parsedOperand.value()))
877+
return failure();
864878

865-
dynamicOptions = dynamicOptionsOperand;
866-
return success();
867-
}
879+
dynamicOptions.push_back(operand);
880+
optionsArray.push_back(
881+
dynamicOptionMarker); // Placeholder for knowing where to
882+
// inject the dynamic option-as-param.
883+
return success();
884+
}
868885

869-
OptionalParseResult hasStaticOptions =
870-
parser.parseOptionalAttribute(staticOptions);
871-
if (hasStaticOptions.has_value()) {
872-
if (failed(hasStaticOptions.value()))
886+
StringAttr stringAttr;
887+
OptionalParseResult parsedStringAttr =
888+
parser.parseOptionalAttribute(stringAttr);
889+
if (parsedStringAttr.has_value()) {
890+
if (failed(parsedStringAttr.value()))
891+
return failure();
892+
optionsArray.push_back(stringAttr);
893+
return success();
894+
}
895+
896+
return std::nullopt;
897+
};
898+
899+
OptionalParseResult parsedOptionsElement = parseOperandOrString();
900+
while (parsedOptionsElement.has_value()) {
901+
if (failed(parsedOptionsElement.value()))
873902
return failure();
874-
return success();
903+
parsedOptionsElement = parseOperandOrString();
875904
}
876905

906+
if (optionsArray.empty()) {
907+
return parser.emitError(parser.getCurrentLocation())
908+
<< "expected at least one option (either a string or a param)";
909+
}
910+
options = parser.getBuilder().getArrayAttr(optionsArray);
877911
return success();
878912
}
879913

880914
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
881-
Operation *op, Value dynamicOptions,
882-
StringAttr staticOptions) {
883-
if (dynamicOptions) {
884-
printer.printOperand(dynamicOptions);
885-
} else if (!staticOptions.getValue().empty()) {
886-
printer.printAttribute(staticOptions);
915+
Operation *op, ArrayAttr options,
916+
ValueRange dynamicOptions) {
917+
size_t currentDynamicOptionIdx = 0;
918+
for (Attribute optionAttr : options) {
919+
if (currentDynamicOptionIdx > 0)
920+
printer << " "; // Interleave options separator.
921+
922+
if (isa<UnitAttr>(optionAttr))
923+
printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]);
924+
else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
925+
printer.printAttribute(strAttr);
926+
else
927+
assert(false && "each option should be either a StringAttr or UnitAttr");
887928
}
888929
}
889930

931+
LogicalResult transform::ApplyRegisteredPassOp::verify() {
932+
size_t numUnitsInOptions = 0;
933+
for (Attribute optionsElement : getOptions()) {
934+
if (isa<UnitAttr>(optionsElement))
935+
numUnitsInOptions++;
936+
else if (!isa<StringAttr>(optionsElement))
937+
return emitOpError() << "expected each option to be either a StringAttr "
938+
<< "or a UnitAttr, got " << optionsElement;
939+
}
940+
941+
if (getDynamicOptions().size() != numUnitsInOptions)
942+
return emitOpError()
943+
<< "expected the same number of options passed as params as "
944+
<< "UnitAttr elements in options ArrayAttr";
945+
946+
return success();
947+
}
948+
890949
//===----------------------------------------------------------------------===//
891950
// CastOp
892951
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Transform/test-pass-application.mlir

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,84 @@ module attributes {transform.with_named_sequence} {
101101

102102
// -----
103103

104-
// CHECK-LABEL: func @valid_dynamic_pass_option()
105-
func.func @valid_dynamic_pass_option() {
104+
// CHECK-LABEL: func @valid_pass_options()
105+
func.func @valid_pass_options() {
106106
return
107107
}
108108

109109
module attributes {transform.with_named_sequence} {
110110
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
111111
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
112-
%pass_options = transform.param.constant "top-down=false" -> !transform.any_param
113-
transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
112+
//transform.apply_registered_pass "canonicalize" with options = "top-down=false,max-iterations=10" to %1 : (!transform.any_op) -> !transform.any_op
113+
transform.apply_registered_pass "canonicalize" with options = "top-down=false test-convergence=true" to %1 : (!transform.any_op) -> !transform.any_op
114+
transform.yield
115+
}
116+
}
117+
118+
// -----
119+
120+
// CHECK-LABEL: func @valid_pass_options_as_list()
121+
func.func @valid_pass_options_as_list() {
122+
return
123+
}
124+
125+
module attributes {transform.with_named_sequence} {
126+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
127+
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
128+
transform.apply_registered_pass "canonicalize" with options = "top-down=false" "max-iterations=0" to %1 : (!transform.any_op) -> !transform.any_op
114129
transform.yield
115130
}
116131
}
132+
133+
// -----
134+
135+
// CHECK-LABEL: func @valid_dynamic_pass_options()
136+
func.func @valid_dynamic_pass_options() {
137+
return
138+
}
139+
140+
module attributes {transform.with_named_sequence} {
141+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
142+
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
143+
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
144+
%max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param
145+
%2 = transform.apply_registered_pass "canonicalize" with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1 : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
146+
transform.yield
147+
}
148+
}
149+
150+
// -----
151+
152+
func.func @invalid_dynamic_options_as_array() {
153+
return
154+
}
155+
156+
module attributes {transform.with_named_sequence} {
157+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
158+
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
159+
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
160+
// expected-error @below {{expected at least one option (either a string or a param)}}
161+
%2 = transform.apply_registered_pass "canonicalize" with options = ["top-down=false" %max_iter] to %1 : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
162+
transform.yield
163+
}
164+
}
165+
166+
// -----
167+
168+
func.func @invalid_options_as_pairs() {
169+
return
170+
}
171+
172+
module attributes {transform.with_named_sequence} {
173+
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
174+
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
175+
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
176+
// expected-error @below {{expected 'to'}}
177+
%2 = transform.apply_registered_pass "canonicalize" with options = "top-down=" false to %1 : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
178+
transform.yield
179+
}
180+
}
181+
117182
// -----
118183

119184
func.func @invalid_pass_option_param() {
@@ -126,7 +191,6 @@ module attributes {transform.with_named_sequence} {
126191
%pass_options = transform.param.constant 42 -> !transform.any_param
127192
// expected-error @below {{options passed as a param must be a string, got 42}}
128193
transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
129-
transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
130194
transform.yield
131195
}
132196
}
@@ -141,8 +205,9 @@ module attributes {transform.with_named_sequence} {
141205
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
142206
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
143207
%x = transform.param.constant "x" -> !transform.any_param
144-
%pass_options = transform.merge_handles %x, %x : !transform.any_param
145-
// expected-error @below {{options passed as a param must be a single value, got 2}}
208+
%y = transform.param.constant "y" -> !transform.any_param
209+
%pass_options = transform.merge_handles %x, %y : !transform.any_param
210+
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
146211
transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
147212
transform.yield
148213
}

0 commit comments

Comments
 (0)