diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h
index 33daf7ca26f49..f39b3bd853a2a 100644
--- a/mlir/include/mlir/AsmParser/AsmParser.h
+++ b/mlir/include/mlir/AsmParser/AsmParser.h
@@ -53,7 +53,8 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
 /// null terminated.
 Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
                          Type type = {}, size_t *numRead = nullptr,
-                         bool isKnownNullTerminated = false);
+                         bool isKnownNullTerminated = false,
+                         llvm::StringMap<Attribute> *attributesCache = nullptr);
 
 /// This parses a single MLIR type to an MLIR context if it was valid. If not,
 /// an error diagnostic is emitted to the context.
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 8b14e71118c3a..416d8eb5f40e0 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -245,6 +245,15 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
     return nullptr;
   }
 
+  if constexpr (std::is_same_v<Symbol, Attribute>) {
+    auto &cache = p.getState().symbols.attributesCache;
+    auto cacheIt = cache.find(symbolData);
+    // Skip cached attribute if it has type.
+    if (cacheIt != cache.end() && !p.getToken().is(Token::colon))
+      return cacheIt->second;
+
+    return cache[symbolData] = createSymbol(dialectName, symbolData, loc);
+  }
   return createSymbol(dialectName, symbolData, loc);
 }
 
@@ -337,6 +346,7 @@ Type Parser::parseExtendedType() {
 template <typename T, typename ParserFn>
 static T parseSymbol(StringRef inputStr, MLIRContext *context,
                      size_t *numReadOut, bool isKnownNullTerminated,
+                     llvm::StringMap<Attribute> *attributesCache,
                      ParserFn &&parserFn) {
   // Set the buffer name to the string being parsed, so that it appears in error
   // diagnostics.
@@ -348,6 +358,9 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SymbolState aliasState;
+  if (attributesCache)
+    aliasState.attributesCache = *attributesCache;
+
   ParserConfig config(context);
   ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
                     /*codeCompleteContext=*/nullptr);
@@ -358,6 +371,11 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
   if (!symbol)
     return T();
 
+  if constexpr (std::is_same_v<T, Attribute>) {
+    if (attributesCache)
+      *attributesCache = state.symbols.attributesCache;
+  }
+
   // Provide the number of bytes that were read.
   Token endTok = parser.getToken();
   size_t numRead =
@@ -374,13 +392,15 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
                                Type type, size_t *numRead,
-                               bool isKnownNullTerminated) {
+                               bool isKnownNullTerminated,
+                               llvm::StringMap<Attribute> *attributesCache) {
   return parseSymbol<Attribute>(
-      attrStr, context, numRead, isKnownNullTerminated,
+      attrStr, context, numRead, isKnownNullTerminated, attributesCache,
       [type](Parser &parser) { return parser.parseAttribute(type); });
 }
 
 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
                      bool isKnownNullTerminated) {
   return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
+                           /*attributesCache=*/nullptr,
                            [](Parser &parser) { return parser.parseType(); });
 }
diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
index 159058a18fa4e..aa53032107cbf 100644
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -40,6 +40,9 @@ struct SymbolState {
 
   /// A map from unique integer identifier to DistinctAttr.
   DenseMap<uint64_t, Attribute> distinctAttributes;
+
+  /// A map from unique string identifier to Attribute.
+  llvm::StringMap<Attribute> attributesCache;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 44458d010c6c8..0f97443433774 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -895,6 +895,10 @@ class AttrTypeReader {
   SmallVector<AttrEntry> attributes;
   SmallVector<TypeEntry> types;
 
+  /// The map of cached attributes, used to avoid re-parsing the same
+  /// attribute multiple times.
+  llvm::StringMap<Attribute> attributesCache;
+
   /// A location used for error emission.
   Location fileLoc;
 
@@ -1235,7 +1239,7 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
         ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
   else
     result = ::parseAttribute(asmStr, context, Type(), &numRead,
-                              /*isKnownNullTerminated=*/true);
+                              /*isKnownNullTerminated=*/true, &attributesCache);
   if (!result)
     return failure();
 
diff --git a/mlir/test/IR/recursive-distinct-attr.mlir b/mlir/test/IR/recursive-distinct-attr.mlir
new file mode 100644
index 0000000000000..5afb5c59e0fcf
--- /dev/null
+++ b/mlir/test/IR/recursive-distinct-attr.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt --mlir-print-debuginfo | FileCheck %s
+
+// Verify that the distinct attribute which is used transitively
+// through two aliases does not end up duplicated when round-tripped
+// through bytecode.
+
+// CHECK: distinct[0]
+// CHECK-NOT: distinct[1]
+#attr_ugly = #test<attr_ugly begin distinct[0]<> end>
+#attr_ugly1 = #test<attr_ugly begin #attr_ugly end>
+
+module attributes {test.alias = #attr_ugly, test.alias1 = #attr_ugly1} {
+}
\ No newline at end of file