diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 8fed18140433c..6a481aad9141e 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -413,16 +413,17 @@ class DistinctAttributeAllocator { /// Allocates a distinct attribute storage using a thread local bump pointer /// allocator to enable synchronization free parallel allocations. DistinctAttrStorage *allocate(Attribute referencedAttr) { + std::unique_lock lock(allocatorMutex); if (!useThreadLocalAllocator && threadingIsEnabled) { - std::scoped_lock lock(allocatorMutex); - return allocateImpl(referencedAttr); + return allocateImpl(referencedAttr, lock); } - return allocateImpl(referencedAttr); + return allocateImpl(referencedAttr, lock); } /// Sets a flag that stores if multithreading is enabled. The flag is used to /// decide if locking is needed when using a non thread-safe allocator. void disableMultiThreading(bool disable = true) { + std::scoped_lock lock(allocatorMutex); threadingIsEnabled = !disable; } @@ -431,12 +432,15 @@ class DistinctAttributeAllocator { /// beyond the lifetime of a child thread calling this function while ensuring /// thread-safe allocation. void disableThreadLocalStorage(bool disable = true) { + std::scoped_lock lock(allocatorMutex); useThreadLocalAllocator = !disable; } private: - DistinctAttrStorage *allocateImpl(Attribute referencedAttr) { - return new (getAllocatorInUse().Allocate()) + DistinctAttrStorage *allocateImpl(Attribute referencedAttr, + const std::unique_lock &lock) { + assert(lock.owns_lock()); + return new (getAllocatorInUse(lock).Allocate()) DistinctAttrStorage(referencedAttr); } @@ -444,7 +448,9 @@ class DistinctAttributeAllocator { /// thread-local, non-thread safe bump pointer allocator is used instead to /// prevent use-after-free errors whenever attribute storage created on a /// crash recover thread is accessed after the thread joins. - llvm::BumpPtrAllocator &getAllocatorInUse() { + llvm::BumpPtrAllocator & + getAllocatorInUse(const std::unique_lock &lock) { + assert(lock.owns_lock()); if (useThreadLocalAllocator) return allocatorCache.get(); return allocator;