1 //===- StorageUniquer.cpp - Common Storage Class Uniquer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Support/StorageUniquer.h"
10 
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/Support/ThreadLocalCache.h"
13 #include "mlir/Support/TypeID.h"
14 #include "llvm/Support/RWMutex.h"
15 
16 using namespace mlir;
17 using namespace mlir::detail;
18 
19 namespace {
20 /// This class represents a uniquer for storage instances of a specific type. It
21 /// contains all of the necessary data to unique storage instances in a thread
22 /// safe way. This allows for the main uniquer to bucket each of the individual
23 /// sub-types removing the need to lock the main uniquer itself.
24 struct InstSpecificUniquer {
25   using BaseStorage = StorageUniquer::BaseStorage;
26   using StorageAllocator = StorageUniquer::StorageAllocator;
27 
28   /// A lookup key for derived instances of storage objects.
29   struct LookupKey {
30     /// The known derived kind for the storage.
31     unsigned kind;
32 
33     /// The known hash value of the key.
34     unsigned hashValue;
35 
36     /// An equality function for comparing with an existing storage instance.
37     function_ref<bool(const BaseStorage *)> isEqual;
38   };
39 
40   /// A utility wrapper object representing a hashed storage object. This class
41   /// contains a storage object and an existing computed hash value.
42   struct HashedStorage {
43     HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr)
44         : hashValue(hashValue), storage(storage) {}
45     unsigned hashValue;
46     BaseStorage *storage;
47   };
48 
49   /// Storage info for derived TypeStorage objects.
50   struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
51     static HashedStorage getEmptyKey() {
52       return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
53     }
54     static HashedStorage getTombstoneKey() {
55       return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
56     }
57 
58     static unsigned getHashValue(const HashedStorage &key) {
59       return key.hashValue;
60     }
61     static unsigned getHashValue(LookupKey key) { return key.hashValue; }
62 
63     static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
64       return lhs.storage == rhs.storage;
65     }
66     static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
67       if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
68         return false;
69       // If the lookup kind matches the kind of the storage, then invoke the
70       // equality function on the lookup key.
71       return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage);
72     }
73   };
74 
75   /// Unique types with specific hashing or storage constraints.
76   using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
77   StorageTypeSet complexInstances;
78 
79   /// Instances of this storage object.
80   llvm::SmallDenseMap<unsigned, BaseStorage *, 1> simpleInstances;
81 
82   /// Allocator to use when constructing derived instances.
83   StorageAllocator allocator;
84 
85   /// A mutex to keep type uniquing thread-safe.
86   llvm::sys::SmartRWMutex<true> mutex;
87 };
88 } // end anonymous namespace
89 
90 namespace mlir {
91 namespace detail {
92 /// This is the implementation of the StorageUniquer class.
93 struct StorageUniquerImpl {
94   using BaseStorage = StorageUniquer::BaseStorage;
95   using StorageAllocator = StorageUniquer::StorageAllocator;
96 
97   /// Get or create an instance of a complex derived type.
98   BaseStorage *
99   getOrCreate(TypeID id, unsigned kind, unsigned hashValue,
100               function_ref<bool(const BaseStorage *)> isEqual,
101               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
102     assert(instUniquers.count(id) && "creating unregistered storage instance");
103     InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
104     InstSpecificUniquer &storageUniquer = *instUniquers[id];
105     if (!threadingIsEnabled)
106       return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
107 
108     // Check for a instance of this object in the local cache.
109     auto localIt = complexStorageLocalCache->insert_as(
110         InstSpecificUniquer::HashedStorage(lookupKey.hashValue), lookupKey);
111     BaseStorage *&localInst = localIt.first->storage;
112     if (localInst)
113       return localInst;
114 
115     // Check for an existing instance in read-only mode.
116     {
117       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
118       auto it = storageUniquer.complexInstances.find_as(lookupKey);
119       if (it != storageUniquer.complexInstances.end())
120         return localInst = it->storage;
121     }
122 
123     // Acquire a writer-lock so that we can safely create the new type instance.
124     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
125     return localInst =
126                getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn);
127   }
128   /// Get or create an instance of a complex derived type in an thread-unsafe
129   /// fashion.
130   BaseStorage *
131   getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
132                     InstSpecificUniquer::LookupKey &key,
133                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
134     auto existing =
135         storageUniquer.complexInstances.insert_as({key.hashValue}, key);
136     if (!existing.second)
137       return existing.first->storage;
138 
139     // Otherwise, construct and initialize the derived storage for this type
140     // instance.
141     BaseStorage *storage =
142         initializeStorage(kind, storageUniquer.allocator, ctorFn);
143     return existing.first->storage = storage;
144   }
145 
146   /// Get or create an instance of a simple derived type.
147   BaseStorage *
148   getOrCreate(TypeID id, unsigned kind,
149               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
150     assert(instUniquers.count(id) && "creating unregistered storage instance");
151     InstSpecificUniquer &storageUniquer = *instUniquers[id];
152     if (!threadingIsEnabled)
153       return getOrCreateUnsafe(storageUniquer, kind, ctorFn);
154 
155     // Check for a instance of this object in the local cache.
156     BaseStorage *&localInst = (*simpleStorageLocalCache)[kind];
157     if (localInst)
158       return localInst;
159 
160     // Check for an existing instance in read-only mode.
161     {
162       llvm::sys::SmartScopedReader<true> typeLock(storageUniquer.mutex);
163       auto it = storageUniquer.simpleInstances.find(kind);
164       if (it != storageUniquer.simpleInstances.end())
165         return it->second;
166     }
167 
168     // Acquire a writer-lock so that we can safely create the new type instance.
169     llvm::sys::SmartScopedWriter<true> typeLock(storageUniquer.mutex);
170     return localInst = getOrCreateUnsafe(storageUniquer, kind, ctorFn);
171   }
172   /// Get or create an instance of a simple derived type in an thread-unsafe
173   /// fashion.
174   BaseStorage *
175   getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind,
176                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
177     auto &result = storageUniquer.simpleInstances[kind];
178     if (result)
179       return result;
180 
181     // Otherwise, create and return a new storage instance.
182     return result = initializeStorage(kind, storageUniquer.allocator, ctorFn);
183   }
184 
185   /// Erase an instance of a complex derived type.
186   void erase(TypeID id, unsigned kind, unsigned hashValue,
187              function_ref<bool(const BaseStorage *)> isEqual,
188              function_ref<void(BaseStorage *)> cleanupFn) {
189     assert(instUniquers.count(id) && "erasing unregistered storage instance");
190     InstSpecificUniquer &storageUniquer = *instUniquers[id];
191     InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual};
192 
193     // Acquire a writer-lock so that we can safely erase the type instance.
194     llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
195     auto existing = storageUniquer.complexInstances.find_as(lookupKey);
196     if (existing == storageUniquer.complexInstances.end())
197       return;
198 
199     // Cleanup the storage and remove it from the map.
200     cleanupFn(existing->storage);
201     storageUniquer.complexInstances.erase(existing);
202   }
203 
204   /// Mutates an instance of a derived storage in a thread-safe way.
205   LogicalResult
206   mutate(TypeID id,
207          function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
208     assert(instUniquers.count(id) && "mutating unregistered storage instance");
209     InstSpecificUniquer &storageUniquer = *instUniquers[id];
210     if (!threadingIsEnabled)
211       return mutationFn(storageUniquer.allocator);
212 
213     llvm::sys::SmartScopedWriter<true> lock(storageUniquer.mutex);
214     return mutationFn(storageUniquer.allocator);
215   }
216 
217   //===--------------------------------------------------------------------===//
218   // Instance Storage
219   //===--------------------------------------------------------------------===//
220 
221   /// Utility to create and initialize a storage instance.
222   BaseStorage *
223   initializeStorage(unsigned kind, StorageAllocator &allocator,
224                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
225     BaseStorage *storage = ctorFn(allocator);
226     storage->kind = kind;
227     return storage;
228   }
229 
230   /// Map of type ids to the storage uniquer to use for registered objects.
231   DenseMap<TypeID, std::unique_ptr<InstSpecificUniquer>> instUniquers;
232 
233   /// A thread local cache for simple and complex storage objects. This helps to
234   /// reduce the lock contention when an object already existing in the cache.
235   ThreadLocalCache<DenseMap<unsigned, BaseStorage *>> simpleStorageLocalCache;
236   ThreadLocalCache<InstSpecificUniquer::StorageTypeSet>
237       complexStorageLocalCache;
238 
239   /// Flag specifying if multi-threading is enabled within the uniquer.
240   bool threadingIsEnabled = true;
241 };
242 } // end namespace detail
243 } // namespace mlir
244 
245 StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
246 StorageUniquer::~StorageUniquer() {}
247 
248 /// Set the flag specifying if multi-threading is disabled within the uniquer.
249 void StorageUniquer::disableMultithreading(bool disable) {
250   impl->threadingIsEnabled = !disable;
251 }
252 
253 /// Register a new storage object with this uniquer using the given unique type
254 /// id.
255 void StorageUniquer::registerStorageType(TypeID id) {
256   impl->instUniquers.try_emplace(id, std::make_unique<InstSpecificUniquer>());
257 }
258 
259 /// Implementation for getting/creating an instance of a derived type with
260 /// complex storage.
261 auto StorageUniquer::getImpl(
262     const TypeID &id, unsigned kind, unsigned hashValue,
263     function_ref<bool(const BaseStorage *)> isEqual,
264     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
265   return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn);
266 }
267 
268 /// Implementation for getting/creating an instance of a derived type with
269 /// default storage.
270 auto StorageUniquer::getImpl(
271     const TypeID &id, unsigned kind,
272     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
273   return impl->getOrCreate(id, kind, ctorFn);
274 }
275 
276 /// Implementation for erasing an instance of a derived type with complex
277 /// storage.
278 void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind,
279                                unsigned hashValue,
280                                function_ref<bool(const BaseStorage *)> isEqual,
281                                function_ref<void(BaseStorage *)> cleanupFn) {
282   impl->erase(id, kind, hashValue, isEqual, cleanupFn);
283 }
284 
285 /// Implementation for mutating an instance of a derived storage.
286 LogicalResult StorageUniquer::mutateImpl(
287     const TypeID &id,
288     function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
289   return impl->mutate(id, mutationFn);
290 }
291