@@ -25,12 +25,40 @@ namespace mlir {
2525// / cache has very large lock contention.
2626template <typename ValueT>
2727class ThreadLocalCache {
28+ // Keep the owned values and their mutex in a separate state object managed by a
29+ // single shared_ptr, instead of using one shared_ptr per value. This avoids a
30+ // problem where a value's shared_ptr is lock()ed successfully and the
31+ // ThreadLocalCache is then destroyed before remove() can be called.
32+ struct PerInstanceState {
33+ // / Remove the given value entry, destroying the value it owns. This is
34+ // / generally called when a thread local cache is destructing.
35+ void remove (ValueT *value) {
36+ // Hold the instance mutex for the whole lookup-and-erase, then erase the
37+ // match directly, because the value is guaranteed to be in the list.
38+ llvm::sys::SmartScopedLock<true > threadInstanceLock (instanceMutex);
39+ auto it =
40+ llvm::find_if (instances, [&](std::unique_ptr<ValueT> &instance) {
41+ return instance.get () == value;
42+ });
43+ assert (it != instances.end () && " expected value to exist in cache" );
44+ instances.erase (it);
45+ }
46+
47+ // / Owning pointers to all of the values that have been constructed for this
48+ // / object in the static cache. Accesses are guarded by `instanceMutex`.
49+ SmallVector<std::unique_ptr<ValueT>, 1 > instances;
50+
51+ // / A mutex guarding `instances`; it is held while a new thread instance is
52+ // / added to the cache for this object and while an instance is removed.
53+ llvm::sys::SmartMutex<true > instanceMutex;
54+ };
55+
2856 // / The type used for the static thread_local cache. This is a map between the
2957 // / per-instance state of a non-static cache and a weak reference to an instance
3058 // / of ValueT. We use a weak reference here so that the object can be destroyed
3159 // / without needing to lock access to the cache itself.
32- struct CacheType : public llvm ::SmallDenseMap<ThreadLocalCache<ValueT> *,
33- std::weak_ptr<ValueT>> {
60+ struct CacheType
61+ : public llvm::SmallDenseMap<PerInstanceState *, std::weak_ptr<ValueT>> {
3462 ~CacheType () {
3563 // Remove the values of this cache that haven't already expired.
3664 for (auto &it : *this )
@@ -60,15 +88,16 @@ class ThreadLocalCache {
6088 ValueT &get () {
6189 // Check for an already existing instance for this thread.
6290 CacheType &staticCache = getStaticCache ();
63- std::weak_ptr<ValueT> &threadInstance = staticCache[this ];
91+ std::weak_ptr<ValueT> &threadInstance = staticCache[perInstanceState. get () ];
6492 if (std::shared_ptr<ValueT> value = threadInstance.lock ())
6593 return *value;
6694
6795 // Otherwise, create a new instance for this thread.
68- llvm::sys::SmartScopedLock<true > threadInstanceLock (instanceMutex);
69- instances.push_back (std::make_shared<ValueT>());
70- std::shared_ptr<ValueT> &instance = instances.back ();
71- threadInstance = instance;
96+ llvm::sys::SmartScopedLock<true > threadInstanceLock (
97+ perInstanceState->instanceMutex );
98+ perInstanceState->instances .push_back (std::make_unique<ValueT>());
99+ ValueT *instance = perInstanceState->instances .back ().get ();
100+ threadInstance = std::shared_ptr<ValueT>(perInstanceState, instance);
72101
73102 // Before returning the new instance, take the chance to clear out any used
74103 // entries in the static map. The cache is only cleared within the same
@@ -90,26 +119,8 @@ class ThreadLocalCache {
90119 return cache;
91120 }
92121
93- // / Remove the given value entry. This is generally called when a thread local
94- // / cache is destructing.
95- void remove (ValueT *value) {
96- // Erase the found value directly, because it is guaranteed to be in the
97- // list.
98- llvm::sys::SmartScopedLock<true > threadInstanceLock (instanceMutex);
99- auto it = llvm::find_if (instances, [&](std::shared_ptr<ValueT> &instance) {
100- return instance.get () == value;
101- });
102- assert (it != instances.end () && " expected value to exist in cache" );
103- instances.erase (it);
104- }
105-
106- // / Owning pointers to all of the values that have been constructed for this
107- // / object in the static cache.
108- SmallVector<std::shared_ptr<ValueT>, 1 > instances;
109-
110- // / A mutex used when a new thread instance has been added to the cache for
111- // / this object.
112- llvm::sys::SmartMutex<true > instanceMutex;
122+ std::shared_ptr<PerInstanceState> perInstanceState =
123+ std::make_shared<PerInstanceState>();
113124};
114125} // namespace mlir
115126
0 commit comments