Skip to content
3 changes: 2 additions & 1 deletion mlir/include/mlir/AsmParser/AsmParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
/// null terminated.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
Type type = {}, size_t *numRead = nullptr,
bool isKnownNullTerminated = false);
bool isKnownNullTerminated = false,
llvm::StringMap<Attribute> *attributesCache = nullptr);

/// This parses a single MLIR type to an MLIR context if it was valid. If not,
/// an error diagnostic is emitted to the context.
Expand Down
24 changes: 22 additions & 2 deletions mlir/lib/AsmParser/DialectSymbolParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,15 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
return nullptr;
}

if constexpr (std::is_same_v<Symbol, Attribute>) {
auto &cache = p.getState().symbols.attributesCache;
auto cacheIt = cache.find(symbolData);
// Skip the cached attribute if it is followed by a type (a `:` token).
if (cacheIt != cache.end() && !p.getToken().is(Token::colon))
return cacheIt->second;

return cache[symbolData] = createSymbol(dialectName, symbolData, loc);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You cannot rely on the symbolData alone as the cache key here. The key also needs to include the dialect's namespace.

}
return createSymbol(dialectName, symbolData, loc);
}

Expand Down Expand Up @@ -337,6 +346,7 @@ Type Parser::parseExtendedType() {
template <typename T, typename ParserFn>
static T parseSymbol(StringRef inputStr, MLIRContext *context,
size_t *numReadOut, bool isKnownNullTerminated,
llvm::StringMap<Attribute> *attributesCache,
ParserFn &&parserFn) {
// Set the buffer name to the string being parsed, so that it appears in error
// diagnostics.
Expand All @@ -348,6 +358,9 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SymbolState aliasState;
if (attributesCache)
aliasState.attributesCache = *attributesCache;

ParserConfig config(context);
ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
/*codeCompleteContext=*/nullptr);
Expand All @@ -358,6 +371,11 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
if (!symbol)
return T();

if constexpr (std::is_same_v<T, Attribute>) {
if (attributesCache)
*attributesCache = state.symbols.attributesCache;
}

// Provide the number of bytes that were read.
Token endTok = parser.getToken();
size_t numRead =
Expand All @@ -374,13 +392,15 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,

Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
Type type, size_t *numRead,
bool isKnownNullTerminated) {
bool isKnownNullTerminated,
llvm::StringMap<Attribute> *attributesCache) {
return parseSymbol<Attribute>(
attrStr, context, numRead, isKnownNullTerminated,
attrStr, context, numRead, isKnownNullTerminated, attributesCache,
[type](Parser &parser) { return parser.parseAttribute(type); });
}
/// Parses `typeStr` as a single MLIR type by forwarding `typeStr`, `context`,
/// `numRead` and `isKnownNullTerminated` unchanged to the shared `parseSymbol`
/// helper, with a callback that invokes `Parser::parseType`. No attribute
/// cache is supplied on this path: nullptr is passed for `attributesCache`.
Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
bool isKnownNullTerminated) {
return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
/*attributesCache=*/nullptr,
[](Parser &parser) { return parser.parseType(); });
}
3 changes: 3 additions & 0 deletions mlir/lib/AsmParser/ParserState.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ struct SymbolState {

/// A map from unique integer identifier to DistinctAttr.
DenseMap<uint64_t, DistinctAttr> distinctAttributes;

/// A map from unique string identifier to Attribute.
llvm::StringMap<Attribute> attributesCache;
};

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 5 additions & 1 deletion mlir/lib/Bytecode/Reader/BytecodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,10 @@ class AttrTypeReader {
SmallVector<AttrEntry> attributes;
SmallVector<TypeEntry> types;

/// The map of cached attributes, used to avoid re-parsing the same
/// attribute multiple times.
llvm::StringMap<Attribute> attributesCache;

/// A location used for error emission.
Location fileLoc;

Expand Down Expand Up @@ -1235,7 +1239,7 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
else
result = ::parseAttribute(asmStr, context, Type(), &numRead,
/*isKnownNullTerminated=*/true);
/*isKnownNullTerminated=*/true, &attributesCache);
if (!result)
return failure();

Expand Down
13 changes: 13 additions & 0 deletions mlir/test/IR/recursive-distinct-attr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt -emit-bytecode %s | mlir-opt --mlir-print-debuginfo | FileCheck %s

// Verify that the distinct attribute which is used transitively
// through two aliases does not end up duplicated when round-tripped
// through bytecode.

// CHECK: distinct[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// CHECK: distinct[0]
// Verify that the distinct attribute which is used transitively
// through two aliases does not end up duplicated when round-tripped
// through bytecode.
// CHECK: distinct[0]

// CHECK-NOT: distinct[1]
#attr_ugly = #test<attr_ugly begin distinct[0]<> end>
#attr_ugly1 = #test<attr_ugly begin #attr_ugly end>

module attributes {test.alias = #attr_ugly, test.alias1 = #attr_ugly1} {
}