From 56d096bb46f42bc7c82503b300faa94a8d00254b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Sat, 4 May 2024 19:51:04 +0100 Subject: [PATCH] [mlir][ODS] Deduplicate `ref` and `qualified` handling Both the attribute and type format generator and the op format generator independently implemented the parsing and verification of the `ref` and `qualified` directives with little to no differences. This PR moves the implementation of these into the common `FormatParser` class to deduplicate the implementations. --- .../attr-or-type-format-invalid.td | 2 +- .../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 52 ++++--------------- mlir/tools/mlir-tblgen/FormatGen.cpp | 36 +++++++++++++ mlir/tools/mlir-tblgen/FormatGen.h | 10 +++- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 40 ++------------ 5 files changed, 61 insertions(+), 79 deletions(-) diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td index d3be4d8b8022a..3a57cbca4d7bb 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td @@ -111,7 +111,7 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> { def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> { let parameters = (ins "int":$a); - // CHECK: `ref` is only allowed inside custom directives + // CHECK: 'ref' is only valid within a `custom` directive let assemblyFormat = "$a ref($a)"; } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index 6098808c646f7..abd1fbdaf8c64 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -940,6 +940,8 @@ class DefFormatParser : public FormatParser { ArrayRef elements, FormatElement *anchor) override; + LogicalResult markQualified(SMLoc loc, FormatElement *element) override; + /// Parse an attribute or type variable. FailureOr parseVariableImpl(SMLoc loc, StringRef name, Context ctx) override; @@ -950,12 +952,8 @@ class DefFormatParser : public FormatParser { private: /// Parse a `params` directive. FailureOr parseParamsDirective(SMLoc loc, Context ctx); - /// Parse a `qualified` directive. - FailureOr parseQualifiedDirective(SMLoc loc, Context ctx); /// Parse a `struct` directive. FailureOr parseStructDirective(SMLoc loc, Context ctx); - /// Parse a `ref` directive. - FailureOr parseRefDirective(SMLoc loc, Context ctx); /// Attribute or type tablegen def. const AttrOrTypeDef &def; @@ -1060,6 +1058,14 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, return success(); } +LogicalResult DefFormatParser::markQualified(SMLoc loc, + FormatElement *element) { + if (!isa(element)) + return emitError(loc, "`qualified` argument list expected a variable"); + cast(element)->setShouldBeQualified(); + return success(); +} + FailureOr DefFormatParser::parse() { FailureOr> elements = FormatParser::parse(); if (failed(elements)) @@ -1107,33 +1113,11 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, return parseParamsDirective(loc, ctx); case FormatToken::kw_struct: return parseStructDirective(loc, ctx); - case FormatToken::kw_ref: - return parseRefDirective(loc, ctx); - case FormatToken::kw_custom: - return parseCustomDirective(loc, ctx); - default: return emitError(loc, "unsupported directive kind"); } } -FailureOr -DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) { - if (failed(parseToken(FormatToken::l_paren, - "expected '(' before argument list"))) - return failure(); - FailureOr var = parseElement(ctx); - if (failed(var)) - return var; - if (!isa(*var)) - return emitError(loc, "`qualified` argument list expected a variable"); - cast(*var)->setShouldBeQualified(); - if (failed( - parseToken(FormatToken::r_paren, "expected ')' after argument list"))) - return failure(); - return var; -} - FailureOr DefFormatParser::parseParamsDirective(SMLoc loc, Context ctx) { // It doesn't make sense to allow references to all parameters in a custom @@ -1201,22 +1185,6 @@ FailureOr DefFormatParser::parseStructDirective(SMLoc loc, return create(std::move(vars)); } -FailureOr DefFormatParser::parseRefDirective(SMLoc loc, - Context ctx) { - if (ctx != CustomDirectiveContext) - return emitError(loc, "`ref` is only allowed inside custom directives"); - - // Parse the child parameter element. - FailureOr child; - if (failed(parseToken(FormatToken::l_paren, "expected '('")) || - failed(child = parseElement(RefDirectiveContext)) || - failed(parseToken(FormatToken::r_paren, "expeced ')'"))) - return failure(); - - // Only parameter elements are allowed to be parsed under a `ref` directive. - return create(*child); -} - //===----------------------------------------------------------------------===// // Interface //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp index d402748b96ad5..7540e584b8fac 100644 --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -308,6 +308,10 @@ FailureOr FormatParser::parseDirective(Context ctx) { if (tok.is(FormatToken::kw_custom)) return parseCustomDirective(loc, ctx); + if (tok.is(FormatToken::kw_ref)) + return parseRefDirective(loc, ctx); + if (tok.is(FormatToken::kw_qualified)) + return parseQualifiedDirective(loc, ctx); return parseDirectiveImpl(loc, tok.getKind(), ctx); } @@ -430,6 +434,38 @@ FailureOr FormatParser::parseCustomDirective(SMLoc loc, return create(nameTok->getSpelling(), std::move(arguments)); } +FailureOr FormatParser::parseRefDirective(SMLoc loc, + Context context) { + if (context != CustomDirectiveContext) + return emitError(loc, "'ref' is only valid within a `custom` directive"); + + FailureOr arg; + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list")) || + failed(arg = parseElement(RefDirectiveContext)) || + failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) + return failure(); + + return create(*arg); +} + +FailureOr FormatParser::parseQualifiedDirective(SMLoc loc, + Context ctx) { + if (failed(parseToken(FormatToken::l_paren, + "expected '(' before argument list"))) + return failure(); + FailureOr var = parseElement(ctx); + if (failed(var)) + return var; + if (failed(markQualified(loc, *var))) + return failure(); + if (failed( + parseToken(FormatToken::r_paren, "expected ')' after argument list"))) + return failure(); + return var; +} + //===----------------------------------------------------------------------===// // Utility Functions //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h index 18a410277fc10..b061d4d8ea7f0 100644 --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -495,9 +495,12 @@ class FormatParser { FailureOr parseDirective(Context ctx); /// Parse an optional group. FailureOr parseOptionalGroup(Context ctx); - /// Parse a custom directive. FailureOr parseCustomDirective(llvm::SMLoc loc, Context ctx); + /// Parse a ref directive. + FailureOr parseRefDirective(SMLoc loc, Context context); + /// Parse a qualified directive. + FailureOr parseQualifiedDirective(SMLoc loc, Context ctx); /// Parse a format-specific variable kind. virtual FailureOr @@ -522,6 +525,11 @@ class FormatParser { ArrayRef elements, FormatElement *anchor) = 0; + /// Mark 'element' as qualified. If 'element' cannot be qualified an error + /// should be emitted and failure returned. + virtual LogicalResult markQualified(llvm::SMLoc loc, + FormatElement *element) = 0; + //===--------------------------------------------------------------------===// // Lexer Utilities diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 806991035e668..f7cc0a292b8c5 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2547,6 +2547,8 @@ class OpFormatParser : public FormatParser { LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element, bool isAnchor); + LogicalResult markQualified(SMLoc loc, FormatElement *element) override; + /// Parse an operation variable. FailureOr parseVariableImpl(SMLoc loc, StringRef name, Context ctx) override; @@ -2622,10 +2624,6 @@ class OpFormatParser : public FormatParser { FailureOr parseOIListDirective(SMLoc loc, Context context); LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc); FailureOr parseOperandsDirective(SMLoc loc, Context context); - FailureOr parseQualifiedDirective(SMLoc loc, - Context context); - FailureOr parseReferenceDirective(SMLoc loc, - Context context); FailureOr parseRegionsDirective(SMLoc loc, Context context); FailureOr parseResultsDirective(SMLoc loc, Context context); FailureOr parseSuccessorsDirective(SMLoc loc, @@ -3224,16 +3222,12 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, return parseFunctionalTypeDirective(loc, ctx); case FormatToken::kw_operands: return parseOperandsDirective(loc, ctx); - case FormatToken::kw_qualified: - return parseQualifiedDirective(loc, ctx); case FormatToken::kw_regions: return parseRegionsDirective(loc, ctx); case FormatToken::kw_results: return parseResultsDirective(loc, ctx); case FormatToken::kw_successors: return parseSuccessorsDirective(loc, ctx); - case FormatToken::kw_ref: - return parseReferenceDirective(loc, ctx); case FormatToken::kw_type: return parseTypeDirective(loc, ctx); case FormatToken::kw_oilist: @@ -3338,22 +3332,6 @@ OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) { return create(); } -FailureOr -OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) { - if (context != CustomDirectiveContext) - return emitError(loc, "'ref' is only valid within a `custom` directive"); - - FailureOr arg; - if (failed(parseToken(FormatToken::l_paren, - "expected '(' before argument list")) || - failed(arg = parseElement(RefDirectiveContext)) || - failed( - parseToken(FormatToken::r_paren, "expected ')' after argument list"))) - return failure(); - - return create(*arg); -} - FailureOr OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) { if (context == TypeDirectiveContext) @@ -3495,19 +3473,11 @@ FailureOr OpFormatParser::parseTypeDirective(SMLoc loc, return create(*operand); } -FailureOr -OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) { - FailureOr element; - if (failed(parseToken(FormatToken::l_paren, - "expected '(' before argument list")) || - failed(element = parseElement(context)) || - failed( - parseToken(FormatToken::r_paren, "expected ')' after argument list"))) - return failure(); - return TypeSwitch>(*element) +LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) { + return TypeSwitch(element) .Case([](auto *element) { element->setShouldBeQualified(); - return element; + return success(); }) .Default([&](auto *element) { return this->emitError(