diff --git a/mlir/include/mlir/IR/DialectResourceBlobManager.h b/mlir/include/mlir/IR/DialectResourceBlobManager.h index e3f32b7a9ab5..6c30efde306e 100644 --- a/mlir/include/mlir/IR/DialectResourceBlobManager.h +++ b/mlir/include/mlir/IR/DialectResourceBlobManager.h @@ -93,9 +93,14 @@ class DialectResourceBlobManager { return HandleT(&entry, dialect); } + /// Provide access to all the registered blobs via a callable. During access + /// the blob map is guaranteed to remain unchanged. + void getBlobMap(llvm::function_ref &)> + accessor) const; + private: /// A mutex to protect access to the blob map. - llvm::sys::SmartRWMutex blobMapLock; + mutable llvm::sys::SmartRWMutex blobMapLock; /// The internal map of tracked blobs. StringMap stores entries in distinct /// allocations, so we can freely take references to the data without fear of diff --git a/mlir/lib/IR/DialectResourceBlobManager.cpp b/mlir/lib/IR/DialectResourceBlobManager.cpp index b83b31e30ef1..83cc1879241d 100644 --- a/mlir/lib/IR/DialectResourceBlobManager.cpp +++ b/mlir/lib/IR/DialectResourceBlobManager.cpp @@ -63,3 +63,11 @@ auto DialectResourceBlobManager::insert(StringRef name, nameStorage.resize(name.size() + 1); } while (true); } + +void DialectResourceBlobManager::getBlobMap( + llvm::function_ref &)> accessor) + const { + llvm::sys::SmartScopedReader reader(blobMapLock); + + accessor(blobMap); +} diff --git a/mlir/unittests/IR/BlobManagerTest.cpp b/mlir/unittests/IR/BlobManagerTest.cpp new file mode 100644 index 000000000000..d82482ddb793 --- /dev/null +++ b/mlir/unittests/IR/BlobManagerTest.cpp @@ -0,0 +1,74 @@ +//===- mlir/unittest/IR/BlobManagerTest.cpp - Blob management unit tests --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/Parser/Parser.h" + +#include "gtest/gtest.h" + +using namespace mlir; + +namespace { + +StringLiteral moduleStr = R"mlir( +"test.use1"() {attr = dense_resource : tensor<1xi64> } : () -> () + +{-# + dialect_resources: { + builtin: { + blob1: "0x08000000ABCDABCDABCDABCE" + } + } +#-} +)mlir"; + +TEST(DialectResourceBlobManagerTest, Lookup) { + MLIRContext context; + context.loadDialect(); + + OwningOpRef m = parseSourceString(moduleStr, &context); + ASSERT_TRUE(m); + + const auto &dialectManager = + mlir::DenseResourceElementsHandle::getManagerInterface(&context); + ASSERT_NE(dialectManager.getBlobManager().lookup("blob1"), nullptr); +} + +TEST(DialectResourceBlobManagerTest, GetBlobMap) { + MLIRContext context; + context.loadDialect(); + + OwningOpRef m = parseSourceString(moduleStr, &context); + ASSERT_TRUE(m); + + Block *block = m->getBody(); + auto &op = block->getOperations().front(); + auto resourceAttr = op.getAttrOfType("attr"); + ASSERT_NE(resourceAttr, nullptr); + + const auto &dialectManager = + resourceAttr.getRawHandle().getManagerInterface(&context); + + bool blobsArePresent = false; + dialectManager.getBlobManager().getBlobMap( + [&](const llvm::StringMap + &blobMap) { blobsArePresent = blobMap.contains("blob1"); }); + ASSERT_TRUE(blobsArePresent); + + // remove operations that use resources - resources must still be accessible + block->clear(); + + blobsArePresent = false; + dialectManager.getBlobManager().getBlobMap( + [&](const llvm::StringMap + &blobMap) { blobsArePresent = blobMap.contains("blob1"); }); + ASSERT_TRUE(blobsArePresent); +} + +} // end anonymous namespace diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 05cb36e19031..6ac4dfc99306 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_unittest(MLIRIRTests TypeTest.cpp TypeAttrNamesTest.cpp OpPropertiesTest.cpp + BlobManagerTest.cpp DEPENDS MLIRTestInterfaceIncGen