Skip to content

Commit 30325cb

Browse files
author
git apple-llvm automerger
committed
Merge commit 'bcc10817d556' from llvm.org/main into next
2 parents efb3355 + bcc1081 commit 30325cb

File tree

1 file changed

+38
-27
lines changed

1 file changed

+38
-27
lines changed

mlir/include/mlir/Support/ThreadLocalCache.h

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,40 @@ namespace mlir {
2525
/// cache has very large lock contention.
2626
template <typename ValueT>
2727
class ThreadLocalCache {
28+
// Keep a separate shared_ptr protected state that can be acquired atomically
29+
// instead of using shared_ptr's for each value. This avoids a problem
30+
// where the instance shared_ptr is locked() successfully, and then the
31+
// ThreadLocalCache gets destroyed before remove() can be called successfully.
32+
struct PerInstanceState {
33+
/// Remove the given value entry. This is generally called when a thread
34+
/// local cache is destructing.
35+
void remove(ValueT *value) {
36+
// Erase the found value directly, because it is guaranteed to be in the
37+
// list.
38+
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
39+
auto it =
40+
llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
41+
return instance.get() == value;
42+
});
43+
assert(it != instances.end() && "expected value to exist in cache");
44+
instances.erase(it);
45+
}
46+
47+
/// Owning pointers to all of the values that have been constructed for this
48+
/// object in the static cache.
49+
SmallVector<std::unique_ptr<ValueT>, 1> instances;
50+
51+
/// A mutex used when a new thread instance has been added to the cache for
52+
/// this object.
53+
llvm::sys::SmartMutex<true> instanceMutex;
54+
};
55+
2856
/// The type used for the static thread_local cache. This is a map between an
2957
/// instance of the non-static cache and a weak reference to an instance of
3058
/// ValueT. We use a weak reference here so that the object can be destroyed
3159
/// without needing to lock access to the cache itself.
32-
struct CacheType : public llvm::SmallDenseMap<ThreadLocalCache<ValueT> *,
33-
std::weak_ptr<ValueT>> {
60+
struct CacheType
61+
: public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
3462
~CacheType() {
3563
// Remove the values of this cache that haven't already expired.
3664
for (auto &it : *this)
@@ -60,15 +88,16 @@ class ThreadLocalCache {
6088
ValueT &get() {
6189
// Check for an already existing instance for this thread.
6290
CacheType &staticCache = getStaticCache();
63-
std::weak_ptr<ValueT> &threadInstance = staticCache[this];
91+
std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState.get()];
6492
if (std::shared_ptr<ValueT> value = threadInstance.lock())
6593
return *value;
6694

6795
// Otherwise, create a new instance for this thread.
68-
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
69-
instances.push_back(std::make_shared<ValueT>());
70-
std::shared_ptr<ValueT> &instance = instances.back();
71-
threadInstance = instance;
96+
llvm::sys::SmartScopedLock<true> threadInstanceLock(
97+
perInstanceState->instanceMutex);
98+
perInstanceState->instances.push_back(std::make_unique<ValueT>());
99+
ValueT *instance = perInstanceState->instances.back().get();
100+
threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
72101

73102
// Before returning the new instance, take the chance to clear out any used
74103
// entries in the static map. The cache is only cleared within the same
@@ -90,26 +119,8 @@ class ThreadLocalCache {
90119
return cache;
91120
}
92121

93-
/// Remove the given value entry. This is generally called when a thread local
94-
/// cache is destructing.
95-
void remove(ValueT *value) {
96-
// Erase the found value directly, because it is guaranteed to be in the
97-
// list.
98-
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
99-
auto it = llvm::find_if(instances, [&](std::shared_ptr<ValueT> &instance) {
100-
return instance.get() == value;
101-
});
102-
assert(it != instances.end() && "expected value to exist in cache");
103-
instances.erase(it);
104-
}
105-
106-
/// Owning pointers to all of the values that have been constructed for this
107-
/// object in the static cache.
108-
SmallVector<std::shared_ptr<ValueT>, 1> instances;
109-
110-
/// A mutex used when a new thread instance has been added to the cache for
111-
/// this object.
112-
llvm::sys::SmartMutex<true> instanceMutex;
122+
std::shared_ptr<PerInstanceState> perInstanceState =
123+
std::make_shared<PerInstanceState>();
113124
};
114125
} // namespace mlir
115126

0 commit comments

Comments
 (0)