1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/Identifier.h"
22 #include "mlir/IR/IntegerSet.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Support/DebugAction.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/SetVector.h"
30 #include "llvm/ADT/SmallString.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/ADT/Twine.h"
33 #include "llvm/Support/Allocator.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/ThreadPool.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40 
41 #define DEBUG_TYPE "mlircontext"
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 using llvm::hash_combine;
47 using llvm::hash_combine_range;
48 
49 //===----------------------------------------------------------------------===//
50 // MLIRContext CommandLine Options
51 //===----------------------------------------------------------------------===//
52 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options.
///
/// The option objects only exist once this struct has been constructed, which
/// happens lazily through `clOptions` below (for example from
/// mlir::registerMLIRContextCLOptions).
struct MLIRContextOptions {
  /// --mlir-disable-threading: when set, contexts are created with threading
  /// disabled and MLIRContext::disableMultithreading becomes a no-op.
  llvm::cl::opt<bool> disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
                     "further call to MLIRContext::enableMultiThreading()")};

  /// --mlir-print-op-on-diagnostic (defaults to true): applied to every context
  /// constructed after the options have been created.
  llvm::cl::opt<bool> printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  /// --mlir-print-stacktrace-on-diagnostic: applied to every context
  /// constructed after the options have been created.
  llvm::cl::opt<bool> printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // end anonymous namespace

/// Lazily constructed storage for the options above. Readers check
/// `clOptions.isConstructed()` first so that merely querying the options never
/// creates them.
static llvm::ManagedStatic<MLIRContextOptions> clOptions;
77 
/// Returns true when threading must stay off regardless of what a context
/// requests: always when LLVM was built without thread support, otherwise only
/// when --mlir-disable-threading was registered and set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  // Options that were never constructed cannot have been set on the command
  // line, so treat that case as "not disabled".
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
85 
86 /// Register a set of useful command-line options that can be used to configure
87 /// various flags within the MLIRContext. These flags are used when constructing
88 /// an MLIR context for initialization.
89 void mlir::registerMLIRContextCLOptions() {
90   // Make sure that the options struct has been initialized.
91   *clOptions;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Locking Utilities
96 //===----------------------------------------------------------------------===//
97 
98 namespace {
99 /// Utility reader lock that takes a runtime flag that specifies if we really
100 /// need to lock.
101 struct ScopedReaderLock {
102   ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
103       : mutex(shouldLock ? &mutexParam : nullptr) {
104     if (mutex)
105       mutex->lock_shared();
106   }
107   ~ScopedReaderLock() {
108     if (mutex)
109       mutex->unlock_shared();
110   }
111   llvm::sys::SmartRWMutex<true> *mutex;
112 };
113 /// Utility writer lock that takes a runtime flag that specifies if we really
114 /// need to lock.
115 struct ScopedWriterLock {
116   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
117       : mutex(shouldLock ? &mutexParam : nullptr) {
118     if (mutex)
119       mutex->lock();
120   }
121   ~ScopedWriterLock() {
122     if (mutex)
123       mutex->unlock();
124   }
125   llvm::sys::SmartRWMutex<true> *mutex;
126 };
127 } // end anonymous namespace.
128 
129 //===----------------------------------------------------------------------===//
130 // AffineMap and IntegerSet hashing
131 //===----------------------------------------------------------------------===//
132 
133 /// A utility function to safely get or create a uniqued instance within the
134 /// given set container.
135 template <typename ValueT, typename DenseInfoT, typename KeyT,
136           typename ConstructorFn>
137 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
138                               KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
139                               bool threadingIsEnabled,
140                               ConstructorFn &&constructorFn) {
141   // Check for an existing instance in read-only mode.
142   if (threadingIsEnabled) {
143     llvm::sys::SmartScopedReader<true> instanceLock(mutex);
144     auto it = container.find_as(key);
145     if (it != container.end())
146       return *it;
147   }
148 
149   // Acquire a writer-lock so that we can safely create the new instance.
150   ScopedWriterLock instanceLock(mutex, threadingIsEnabled);
151 
152   // Check for an existing instance again here, because another writer thread
153   // may have already created one. Otherwise, construct a new instance.
154   auto existing = container.insert_as(ValueT(), key);
155   if (existing.second)
156     return *existing.first = constructorFn();
157   return *existing.first;
158 }
159 
160 namespace {
161 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
162   // Affine maps are uniqued based on their dim/symbol counts and affine
163   // expressions.
164   using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>;
165   using DenseMapInfo<AffineMap>::isEqual;
166 
167   static unsigned getHashValue(const AffineMap &key) {
168     return getHashValue(
169         KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults()));
170   }
171 
172   static unsigned getHashValue(KeyTy key) {
173     return hash_combine(
174         std::get<0>(key), std::get<1>(key),
175         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
176   }
177 
178   static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
179     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
180       return false;
181     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
182                                   rhs.getResults());
183   }
184 };
185 
186 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
187   // Integer sets are uniqued based on their dim/symbol counts, affine
188   // expressions appearing in the LHS of constraints, and eqFlags.
189   using KeyTy =
190       std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>;
191   using DenseMapInfo<IntegerSet>::isEqual;
192 
193   static unsigned getHashValue(const IntegerSet &key) {
194     return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(),
195                               key.getConstraints(), key.getEqFlags()));
196   }
197 
198   static unsigned getHashValue(KeyTy key) {
199     return hash_combine(
200         std::get<0>(key), std::get<1>(key),
201         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
202         hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
203   }
204 
205   static bool isEqual(const KeyTy &lhs, IntegerSet rhs) {
206     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
207       return false;
208     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
209                                   rhs.getConstraints(), rhs.getEqFlags());
210   }
211 };
212 } // end anonymous namespace.
213 
214 //===----------------------------------------------------------------------===//
215 // MLIRContextImpl
216 //===----------------------------------------------------------------------===//
217 
namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
class MLIRContextImpl {
public:
  //===--------------------------------------------------------------------===//
  // Debugging
  //===--------------------------------------------------------------------===//

  /// An action manager for use within the context.
  DebugActionManager debugActionManager;

  //===--------------------------------------------------------------------===//
  // Identifier uniquing
  //===--------------------------------------------------------------------===//

  // Identifier allocator and mutex for thread safety. The allocator backs the
  // `identifiers` map below and is also used by AbstractOperation::insert for
  // the cached attribute-name arrays. It must be declared before `identifiers`
  // so that it is initialized first.
  llvm::BumpPtrAllocator identifierAllocator;
  llvm::sys::SmartRWMutex<true> identifierMutex;

  //===--------------------------------------------------------------------===//
  // Diagnostics
  //===--------------------------------------------------------------------===//
  DiagnosticEngine diagEngine;

  //===--------------------------------------------------------------------===//
  // Options
  //===--------------------------------------------------------------------===//

  /// In most cases, creating an operation in an unregistered dialect is not
  /// desired and indicates a misconfiguration of the compiler. Keeping this
  /// option false allows such use cases to be detected.
  bool allowUnregisteredDialects = false;

  /// Enable support for multi-threading within MLIR.
  bool threadingIsEnabled = true;

  /// Track if we are currently executing in a threaded execution environment
  /// (like the pass-manager): this is only a debugging feature to help reduce
  /// the chances of data races on some context APIs. Only present in builds
  /// with assertions enabled.
#ifndef NDEBUG
  std::atomic<int> multiThreadedExecutionContext{0};
#endif

  /// If the operation should be attached to diagnostics printed via the
  /// Operation::emit methods.
  bool printOpOnDiagnostic = true;

  /// If the current stack trace should be attached when emitting diagnostics.
  bool printStackTraceOnDiagnostic = false;

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//

  /// This points to the ThreadPool used when processing MLIR tasks in parallel.
  /// It can't be nullptr when multi-threading is enabled. Otherwise if
  /// multi-threading is disabled, and the threadpool wasn't externally provided
  /// using `setThreadPool`, this will be nullptr.
  llvm::ThreadPool *threadPool = nullptr;

  /// In case where the thread pool is owned by the context, this ensures
  /// destruction with the context.
  std::unique_ptr<llvm::ThreadPool> ownedThreadPool;

  /// The dialects loaded into this context, keyed by dialect namespace. The
  /// MLIRContext owns the dialect objects.
  DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
  /// The registry of dialects that may be loaded on demand.
  DialectRegistry dialectsRegistry;

  /// This is a mapping from operation name to AbstractOperation for registered
  /// operations.
  llvm::StringMap<AbstractOperation> registeredOperations;

  /// Identifiers are uniqued by string value; the string storage is allocated
  /// from `identifierAllocator`. Each value is either the dialect associated
  /// with the identifier or the context itself. MLIRContext::getOrLoadDialect
  /// rebinds context-valued entries to a newly loaded dialect whose namespace
  /// prefixes them.
  llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
                  llvm::BumpPtrAllocator &>
      identifiers;

  /// An allocator used for AbstractAttribute and AbstractType objects.
  llvm::BumpPtrAllocator abstractDialectSymbolAllocator;

  //===--------------------------------------------------------------------===//
  // Affine uniquing
  //===--------------------------------------------------------------------===//

  // Affine allocator and mutex for thread safety.
  llvm::BumpPtrAllocator affineAllocator;
  llvm::sys::SmartRWMutex<true> affineMutex;

  // Affine map uniquing.
  using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
  AffineMapSet affineMaps;

  // Integer set uniquing.
  using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
  IntegerSets integerSets;

  // Affine expression uniquing.
  StorageUniquer affineUniquer;

  //===--------------------------------------------------------------------===//
  // Type uniquing
  //===--------------------------------------------------------------------===//

  DenseMap<TypeID, AbstractType *> registeredTypes;
  StorageUniquer typeUniquer;

  /// Cached Type Instances.
  BFloat16Type bf16Ty;
  Float16Type f16Ty;
  Float32Type f32Ty;
  Float64Type f64Ty;
  Float80Type f80Ty;
  Float128Type f128Ty;
  IndexType indexTy;
  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  NoneType noneType;

  //===--------------------------------------------------------------------===//
  // Attribute uniquing
  //===--------------------------------------------------------------------===//

  DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
  StorageUniquer attributeUniquer;

  /// Cached Attribute Instances.
  BoolAttr falseAttr, trueAttr;
  UnitAttr unitAttr;
  UnknownLoc unknownLocAttr;
  DictionaryAttr emptyDictionaryAttr;
  StringAttr emptyStringAttr;

public:
  /// Creates the implementation state. A context-owned thread pool is created
  /// only when `threadingIsEnabled` is true.
  MLIRContextImpl(bool threadingIsEnabled)
      : threadingIsEnabled(threadingIsEnabled),
        identifiers(identifierAllocator) {
    if (threadingIsEnabled) {
      ownedThreadPool = std::make_unique<llvm::ThreadPool>();
      threadPool = ownedThreadPool.get();
    }
  }
  /// AbstractType/AbstractAttribute objects are placement-constructed inside
  /// `abstractDialectSymbolAllocator` (see Dialect::addType/addAttribute), so
  /// their destructors have to be run explicitly here. The allocator then
  /// releases the memory itself.
  ~MLIRContextImpl() {
    for (auto typeMapping : registeredTypes)
      typeMapping.second->~AbstractType();
    for (auto attrMapping : registeredAttributes)
      attrMapping.second->~AbstractAttribute();
  }
};
} // end namespace mlir
369 
370 MLIRContext::MLIRContext(Threading setting)
371     : MLIRContext(DialectRegistry(), setting) {}
372 
/// Construct a context whose dialect registry is populated from `registry`.
/// Threading is enabled only if `setting` is Threading::ENABLED and threading
/// has not been disabled globally (see isThreadingGloballyDisabled).
///
/// NOTE(review): this constructor does not call `disableMultithreading` on the
/// affine/attribute/type uniquers when `impl->threadingIsEnabled` ends up
/// false; only MLIRContext::disableMultithreading updates them. Confirm the
/// uniquers' default mode is acceptable in that case.
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
    : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
                               !isThreadingGloballyDisabled())) {
  // Initialize values based on the command line flags if they were provided.
  if (clOptions.isConstructed()) {
    printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
    printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
  }

  // Pre-populate the registry.
  registry.appendTo(impl->dialectsRegistry);

  // Ensure the builtin dialect is always pre-loaded.
  getOrLoadDialect<BuiltinDialect>();

  // Initialize several common attributes and types to avoid the need to lock
  // the context when accessing them.

  //// Types.
  /// Floating-point Types.
  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
  impl->f16Ty = TypeUniquer::get<Float16Type>(this);
  impl->f32Ty = TypeUniquer::get<Float32Type>(this);
  impl->f64Ty = TypeUniquer::get<Float64Type>(this);
  impl->f80Ty = TypeUniquer::get<Float80Type>(this);
  impl->f128Ty = TypeUniquer::get<Float128Type>(this);
  /// Index Type.
  impl->indexTy = TypeUniquer::get<IndexType>(this);
  /// Integer Types (all signless).
  impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
  impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
  impl->int16Ty =
      TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
  impl->int32Ty =
      TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
  impl->int64Ty =
      TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
  impl->int128Ty =
      TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
  /// None Type.
  impl->noneType = TypeUniquer::get<NoneType>(this);

  //// Attributes.
  //// Note: These must be registered after the types as they may generate one
  //// of the above types internally.
  /// Unknown Location Attribute.
  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
  /// Bool Attributes (built on the cached i1 type above).
  impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
  impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
  /// Unit Attribute.
  impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
  /// The empty dictionary attribute.
  impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
  /// The empty string attribute.
  impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);

  // Register the affine storage objects with the uniquer.
  impl->affineUniquer
      .registerParametricStorageType<AffineBinaryOpExprStorage>();
  impl->affineUniquer
      .registerParametricStorageType<AffineConstantExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
}
437 
438 MLIRContext::~MLIRContext() {}
439 
440 /// Copy the specified array of elements into memory managed by the provided
441 /// bump pointer allocator.  This assumes the elements are all PODs.
442 template <typename T>
443 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
444                                     ArrayRef<T> elements) {
445   auto result = allocator.Allocate<T>(elements.size());
446   std::uninitialized_copy(elements.begin(), elements.end(), result);
447   return ArrayRef<T>(result, elements.size());
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // Debugging
452 //===----------------------------------------------------------------------===//
453 
454 DebugActionManager &MLIRContext::getDebugActionManager() {
455   return getImpl().debugActionManager;
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // Diagnostic Handlers
460 //===----------------------------------------------------------------------===//
461 
462 /// Returns the diagnostic engine for this context.
463 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
464 
465 //===----------------------------------------------------------------------===//
466 // Dialect and Operation Registration
467 //===----------------------------------------------------------------------===//
468 
469 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
470   registry.appendTo(impl->dialectsRegistry);
471 
472   // For the already loaded dialects, register the interfaces immediately.
473   for (const auto &kvp : impl->loadedDialects)
474     registry.registerDelayedInterfaces(kvp.second.get());
475 }
476 
477 const DialectRegistry &MLIRContext::getDialectRegistry() {
478   return impl->dialectsRegistry;
479 }
480 
481 /// Return information about all registered IR dialects.
482 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
483   std::vector<Dialect *> result;
484   result.reserve(impl->loadedDialects.size());
485   for (auto &dialect : impl->loadedDialects)
486     result.push_back(dialect.second.get());
487   llvm::array_pod_sort(result.begin(), result.end(),
488                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
489                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
490                        });
491   return result;
492 }
493 std::vector<StringRef> MLIRContext::getAvailableDialects() {
494   std::vector<StringRef> result;
495   for (auto dialect : impl->dialectsRegistry.getDialectNames())
496     result.push_back(dialect);
497   return result;
498 }
499 
500 /// Get a registered IR dialect with the given namespace. If none is found,
501 /// then return nullptr.
502 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
503   // Dialects are sorted by name, so we can use binary search for lookup.
504   auto it = impl->loadedDialects.find(name);
505   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
506 }
507 
508 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
509   Dialect *dialect = getLoadedDialect(name);
510   if (dialect)
511     return dialect;
512   DialectAllocatorFunctionRef allocator =
513       impl->dialectsRegistry.getDialectAllocator(name);
514   return allocator ? allocator(this) : nullptr;
515 }
516 
517 /// Get a dialect for the provided namespace and TypeID: abort the program if a
518 /// dialect exist for this namespace with different TypeID. Returns a pointer to
519 /// the dialect owned by the context.
520 Dialect *
521 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
522                               function_ref<std::unique_ptr<Dialect>()> ctor) {
523   auto &impl = getImpl();
524   // Get the correct insertion position sorted by namespace.
525   std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
526 
527   if (!dialect) {
528     LLVM_DEBUG(llvm::dbgs()
529                << "Load new dialect in Context " << dialectNamespace << "\n");
530 #ifndef NDEBUG
531     if (impl.multiThreadedExecutionContext != 0)
532       llvm::report_fatal_error(
533           "Loading a dialect (" + dialectNamespace +
534           ") while in a multi-threaded execution context (maybe "
535           "the PassManager): this can indicate a "
536           "missing `dependentDialects` in a pass for example.");
537 #endif
538     dialect = ctor();
539     assert(dialect && "dialect ctor failed");
540 
541     // Refresh all the identifiers dialect field, this catches cases where a
542     // dialect may be loaded after identifier prefixed with this dialect name
543     // were already created.
544     llvm::SmallString<32> dialectPrefix(dialectNamespace);
545     dialectPrefix.push_back('.');
546     for (auto &identifierEntry : impl.identifiers)
547       if (identifierEntry.second.is<MLIRContext *>() &&
548           identifierEntry.first().startswith(dialectPrefix))
549         identifierEntry.second = dialect.get();
550 
551     // Actually register the interfaces with delayed registration.
552     impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
553     return dialect.get();
554   }
555 
556   // Abort if dialect with namespace has already been registered.
557   if (dialect->getTypeID() != dialectID)
558     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
559                              "' has already been registered");
560 
561   return dialect.get();
562 }
563 
564 void MLIRContext::loadAllAvailableDialects() {
565   for (StringRef name : getAvailableDialects())
566     getOrLoadDialect(name);
567 }
568 
569 llvm::hash_code MLIRContext::getRegistryHash() {
570   llvm::hash_code hash(0);
571   // Factor in number of loaded dialects, attributes, operations, types.
572   hash = llvm::hash_combine(hash, impl->loadedDialects.size());
573   hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
574   hash = llvm::hash_combine(hash, impl->registeredOperations.size());
575   hash = llvm::hash_combine(hash, impl->registeredTypes.size());
576   return hash;
577 }
578 
579 bool MLIRContext::allowsUnregisteredDialects() {
580   return impl->allowUnregisteredDialects;
581 }
582 
583 void MLIRContext::allowUnregisteredDialects(bool allowing) {
584   impl->allowUnregisteredDialects = allowing;
585 }
586 
587 /// Return true if multi-threading is enabled by the context.
588 bool MLIRContext::isMultithreadingEnabled() {
589   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
590 }
591 
592 /// Set the flag specifying if multi-threading is disabled by the context.
593 void MLIRContext::disableMultithreading(bool disable) {
594   // This API can be overridden by the global debugging flag
595   // --mlir-disable-threading
596   if (isThreadingGloballyDisabled())
597     return;
598 
599   impl->threadingIsEnabled = !disable;
600 
601   // Update the threading mode for each of the uniquers.
602   impl->affineUniquer.disableMultithreading(disable);
603   impl->attributeUniquer.disableMultithreading(disable);
604   impl->typeUniquer.disableMultithreading(disable);
605 
606   // Destroy thread pool (stop all threads) if it is no longer needed, or create
607   // a new one if multithreading was re-enabled.
608   if (disable) {
609     // If the thread pool is owned, explicitly set it to nullptr to avoid
610     // keeping a dangling pointer around. If the thread pool is externally
611     // owned, we don't do anything.
612     if (impl->ownedThreadPool) {
613       assert(impl->threadPool);
614       impl->threadPool = nullptr;
615       impl->ownedThreadPool.reset();
616     }
617   } else if (!impl->threadPool) {
618     // The thread pool isn't externally provided.
619     assert(!impl->ownedThreadPool);
620     impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
621     impl->threadPool = impl->ownedThreadPool.get();
622   }
623 }
624 
625 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
626   assert(!isMultithreadingEnabled() &&
627          "expected multi-threading to be disabled when setting a ThreadPool");
628   impl->threadPool = &pool;
629   impl->ownedThreadPool.reset();
630   enableMultithreading();
631 }
632 
633 llvm::ThreadPool &MLIRContext::getThreadPool() {
634   assert(isMultithreadingEnabled() &&
635          "expected multi-threading to be enabled within the context");
636   assert(impl->threadPool &&
637          "multi-threading is enabled but threadpool not set");
638   return *impl->threadPool;
639 }
640 
641 void MLIRContext::enterMultiThreadedExecution() {
642 #ifndef NDEBUG
643   ++impl->multiThreadedExecutionContext;
644 #endif
645 }
646 void MLIRContext::exitMultiThreadedExecution() {
647 #ifndef NDEBUG
648   --impl->multiThreadedExecutionContext;
649 #endif
650 }
651 
652 /// Return true if we should attach the operation to diagnostics emitted via
653 /// Operation::emit.
654 bool MLIRContext::shouldPrintOpOnDiagnostic() {
655   return impl->printOpOnDiagnostic;
656 }
657 
658 /// Set the flag specifying if we should attach the operation to diagnostics
659 /// emitted via Operation::emit.
660 void MLIRContext::printOpOnDiagnostic(bool enable) {
661   impl->printOpOnDiagnostic = enable;
662 }
663 
664 /// Return true if we should attach the current stacktrace to diagnostics when
665 /// emitted.
666 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
667   return impl->printStackTraceOnDiagnostic;
668 }
669 
670 /// Set the flag specifying if we should attach the current stacktrace when
671 /// emitting diagnostics.
672 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
673   impl->printStackTraceOnDiagnostic = enable;
674 }
675 
676 /// Return information about all registered operations.  This isn't very
677 /// efficient, typically you should ask the operations about their properties
678 /// directly.
679 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
680   // We just have the operations in a non-deterministic hash table order. Dump
681   // into a temporary array, then sort it by operation name to get a stable
682   // ordering.
683   llvm::StringMap<AbstractOperation> &registeredOps =
684       impl->registeredOperations;
685 
686   std::vector<AbstractOperation *> result;
687   result.reserve(registeredOps.size());
688   for (auto &elt : registeredOps)
689     result.push_back(&elt.second);
690   llvm::array_pod_sort(
691       result.begin(), result.end(),
692       [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
693         return (*lhs)->name.compare((*rhs)->name);
694       });
695 
696   return result;
697 }
698 
699 bool MLIRContext::isOperationRegistered(StringRef name) {
700   return impl->registeredOperations.count(name);
701 }
702 
703 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
704   auto &impl = context->getImpl();
705   assert(impl.multiThreadedExecutionContext == 0 &&
706          "Registering a new type kind while in a multi-threaded execution "
707          "context");
708   auto *newInfo =
709       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
710           AbstractType(std::move(typeInfo));
711   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
712     llvm::report_fatal_error("Dialect Type already registered.");
713 }
714 
715 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
716   auto &impl = context->getImpl();
717   assert(impl.multiThreadedExecutionContext == 0 &&
718          "Registering a new attribute kind while in a multi-threaded execution "
719          "context");
720   auto *newInfo =
721       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
722           AbstractAttribute(std::move(attrInfo));
723   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
724     llvm::report_fatal_error("Dialect Attribute already registered.");
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // AbstractAttribute
729 //===----------------------------------------------------------------------===//
730 
731 /// Get the dialect that registered the attribute with the provided typeid.
732 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
733                                                    MLIRContext *context) {
734   const AbstractAttribute *abstract = lookupMutable(typeID, context);
735   if (!abstract)
736     llvm::report_fatal_error("Trying to create an Attribute that was not "
737                              "registered in this MLIRContext.");
738   return *abstract;
739 }
740 
741 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
742                                                     MLIRContext *context) {
743   auto &impl = context->getImpl();
744   auto it = impl.registeredAttributes.find(typeID);
745   if (it == impl.registeredAttributes.end())
746     return nullptr;
747   return it->second;
748 }
749 
750 //===----------------------------------------------------------------------===//
751 // AbstractOperation
752 //===----------------------------------------------------------------------===//
753 
754 ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser,
755                                              OperationState &result) const {
756   return parseAssemblyFn(parser, result);
757 }
758 
759 /// Look up the specified operation in the operation set and return a pointer
760 /// to it if present. Otherwise, return a null pointer.
761 AbstractOperation *AbstractOperation::lookupMutable(StringRef opName,
762                                                     MLIRContext *context) {
763   auto &impl = context->getImpl();
764   auto it = impl.registeredOperations.find(opName);
765   if (it != impl.registeredOperations.end())
766     return &it->second;
767   return nullptr;
768 }
769 
770 void AbstractOperation::insert(
771     StringRef name, Dialect &dialect, TypeID typeID,
772     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
773     VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
774     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
775     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
776     ArrayRef<StringRef> attrNames) {
777   MLIRContext *ctx = dialect.getContext();
778   auto &impl = ctx->getImpl();
779   assert(impl.multiThreadedExecutionContext == 0 &&
780          "Registering a new operation kind while in a multi-threaded execution "
781          "context");
782 
783   // Register the attribute names of this operation.
784   MutableArrayRef<Identifier> cachedAttrNames;
785   if (!attrNames.empty()) {
786     cachedAttrNames = MutableArrayRef<Identifier>(
787         impl.identifierAllocator.Allocate<Identifier>(attrNames.size()),
788         attrNames.size());
789     for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
790       new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx));
791   }
792 
793   // Register the information for this operation.
794   AbstractOperation opInfo(
795       name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly),
796       std::move(verifyInvariants), std::move(foldHook),
797       std::move(getCanonicalizationPatterns), std::move(interfaceMap),
798       std::move(hasTrait), cachedAttrNames);
799   if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
800     llvm::errs() << "error: operation named '" << name
801                  << "' is already registered.\n";
802     abort();
803   }
804 }
805 
/// Construct the description of a registered operation kind.
///
/// The constructor parameters shadow the members of the same name. Inside each
/// member initializer expression the identifier therefore refers to the
/// parameter; for example, the `name` passed to Identifier::get is the
/// StringRef parameter. `name` is uniqued as an Identifier in `dialect`'s
/// context, and the hook callbacks and interface map are moved into the new
/// object. `attrNames` is stored as given.
/// NOTE(review): if `attributeNames` is a non-owning view, the caller must keep
/// that storage alive. AbstractOperation::insert allocates it from the
/// context's identifierAllocator; confirm against the header.
AbstractOperation::AbstractOperation(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<Identifier> attrNames)
    : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
      typeID(typeID), interfaceMap(std::move(interfaceMap)),
      foldHookFn(std::move(foldHook)),
      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
      hasTraitFn(std::move(hasTrait)),
      parseAssemblyFn(std::move(parseAssembly)),
      printAssemblyFn(std::move(printAssembly)),
      verifyInvariantsFn(std::move(verifyInvariants)),
      attributeNames(attrNames) {}
822 
823 //===----------------------------------------------------------------------===//
824 // AbstractType
825 //===----------------------------------------------------------------------===//
826 
827 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
828   const AbstractType *type = lookupMutable(typeID, context);
829   if (!type)
830     llvm::report_fatal_error(
831         "Trying to create a Type that was not registered in this MLIRContext.");
832   return *type;
833 }
834 
835 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
836   auto &impl = context->getImpl();
837   auto it = impl.registeredTypes.find(typeID);
838   if (it == impl.registeredTypes.end())
839     return nullptr;
840   return it->second;
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // Identifier uniquing
845 //===----------------------------------------------------------------------===//
846 
847 /// Return an identifier for the specified string.
848 Identifier Identifier::get(const Twine &string, MLIRContext *context) {
849   SmallString<32> tempStr;
850   StringRef str = string.toStringRef(tempStr);
851 
852   // Check invariants after seeing if we already have something in the
853   // identifier table - if we already had it in the table, then it already
854   // passed invariant checks.
855   assert(!str.empty() && "Cannot create an empty identifier");
856   assert(str.find('\0') == StringRef::npos &&
857          "Cannot create an identifier with a nul character");
858 
859   auto getDialectOrContext = [&]() {
860     PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
861     auto dialectNamePair = str.split('.');
862     if (!dialectNamePair.first.empty())
863       if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
864         dialectOrContext = dialect;
865     return dialectOrContext;
866   };
867 
868   auto &impl = context->getImpl();
869   if (!context->isMultithreadingEnabled()) {
870     auto insertedIt = impl.identifiers.insert({str, nullptr});
871     if (insertedIt.second)
872       insertedIt.first->second = getDialectOrContext();
873     return Identifier(&*insertedIt.first);
874   }
875 
876   // Check for an existing identifier in read-only mode.
877   {
878     llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
879     auto it = impl.identifiers.find(str);
880     if (it != impl.identifiers.end())
881       return Identifier(&*it);
882   }
883 
884   // Acquire a writer-lock so that we can safely create the new instance.
885   llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
886   auto it = impl.identifiers.insert({str, getDialectOrContext()}).first;
887   return Identifier(&*it);
888 }
889 
890 Dialect *Identifier::getDialect() {
891   return entry->second.dyn_cast<Dialect *>();
892 }
893 
894 MLIRContext *Identifier::getContext() {
895   if (Dialect *dialect = getDialect())
896     return dialect->getContext();
897   return entry->second.get<MLIRContext *>();
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // Type uniquing
902 //===----------------------------------------------------------------------===//
903 
904 /// Returns the storage uniquer used for constructing type storage instances.
905 /// This should not be used directly.
906 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
907 
908 BFloat16Type BFloat16Type::get(MLIRContext *context) {
909   return context->getImpl().bf16Ty;
910 }
911 Float16Type Float16Type::get(MLIRContext *context) {
912   return context->getImpl().f16Ty;
913 }
914 Float32Type Float32Type::get(MLIRContext *context) {
915   return context->getImpl().f32Ty;
916 }
917 Float64Type Float64Type::get(MLIRContext *context) {
918   return context->getImpl().f64Ty;
919 }
920 Float80Type Float80Type::get(MLIRContext *context) {
921   return context->getImpl().f80Ty;
922 }
923 Float128Type Float128Type::get(MLIRContext *context) {
924   return context->getImpl().f128Ty;
925 }
926 
927 /// Get an instance of the IndexType.
928 IndexType IndexType::get(MLIRContext *context) {
929   return context->getImpl().indexTy;
930 }
931 
932 /// Return an existing integer type instance if one is cached within the
933 /// context.
934 static IntegerType
935 getCachedIntegerType(unsigned width,
936                      IntegerType::SignednessSemantics signedness,
937                      MLIRContext *context) {
938   if (signedness != IntegerType::Signless)
939     return IntegerType();
940 
941   switch (width) {
942   case 1:
943     return context->getImpl().int1Ty;
944   case 8:
945     return context->getImpl().int8Ty;
946   case 16:
947     return context->getImpl().int16Ty;
948   case 32:
949     return context->getImpl().int32Ty;
950   case 64:
951     return context->getImpl().int64Ty;
952   case 128:
953     return context->getImpl().int128Ty;
954   default:
955     return IntegerType();
956   }
957 }
958 
959 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
960                              IntegerType::SignednessSemantics signedness) {
961   if (auto cached = getCachedIntegerType(width, signedness, context))
962     return cached;
963   return Base::get(context, width, signedness);
964 }
965 
966 IntegerType
967 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
968                         MLIRContext *context, unsigned width,
969                         SignednessSemantics signedness) {
970   if (auto cached = getCachedIntegerType(width, signedness, context))
971     return cached;
972   return Base::getChecked(emitError, context, width, signedness);
973 }
974 
975 /// Get an instance of the NoneType.
976 NoneType NoneType::get(MLIRContext *context) {
977   if (NoneType cachedInst = context->getImpl().noneType)
978     return cachedInst;
979   // Note: May happen when initializing the singleton attributes of the builtin
980   // dialect.
981   return Base::get(context);
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // Attribute uniquing
986 //===----------------------------------------------------------------------===//
987 
988 /// Returns the storage uniquer used for constructing attribute storage
989 /// instances. This should not be used directly.
990 StorageUniquer &MLIRContext::getAttributeUniquer() {
991   return getImpl().attributeUniquer;
992 }
993 
994 /// Initialize the given attribute storage instance.
995 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
996                                                   MLIRContext *ctx,
997                                                   TypeID attrID) {
998   storage->initialize(AbstractAttribute::lookup(attrID, ctx));
999 
1000   // If the attribute did not provide a type, then default to NoneType.
1001   if (!storage->getType())
1002     storage->setType(NoneType::get(ctx));
1003 }
1004 
1005 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
1006   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
1007 }
1008 
1009 UnitAttr UnitAttr::get(MLIRContext *context) {
1010   return context->getImpl().unitAttr;
1011 }
1012 
1013 UnknownLoc UnknownLoc::get(MLIRContext *context) {
1014   return context->getImpl().unknownLocAttr;
1015 }
1016 
1017 /// Return empty dictionary.
1018 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
1019   return context->getImpl().emptyDictionaryAttr;
1020 }
1021 
1022 /// Return an empty string.
1023 StringAttr StringAttr::get(MLIRContext *context) {
1024   return context->getImpl().emptyStringAttr;
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // AffineMap uniquing
1029 //===----------------------------------------------------------------------===//
1030 
1031 StorageUniquer &MLIRContext::getAffineUniquer() {
1032   return getImpl().affineUniquer;
1033 }
1034 
/// Return the uniqued AffineMap with `dimCount` dimensions, `symbolCount`
/// symbols and the given result expressions. The map is created in `context`
/// if it does not exist yet. Lookup and creation go through safeGetOrCreate
/// with the context's affine mutex and threading flag.
AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
                             ArrayRef<AffineExpr> results,
                             MLIRContext *context) {
  auto &impl = context->getImpl();
  // The key copies the `results` ArrayRef (a view of the caller's storage)
  // before the constructor lambda below rebinds `results` to an
  // allocator-owned copy.
  auto key = std::make_tuple(dimCount, symbolCount, results);

  // Safely get or create an AffineMap instance.
  return safeGetOrCreate(
      impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
        auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();

        // Copy the results into the bump pointer allocator so the stored map
        // does not refer to the caller's memory. `results` is captured by
        // reference, so this rebinds the function parameter.
        results = copyArrayRefInto(impl.affineAllocator, results);

        // Initialize the memory using placement new.
        new (res)
            detail::AffineMapStorage{dimCount, symbolCount, results, context};
        return AffineMap(res);
      });
}
1055 
1056 AffineMap AffineMap::get(MLIRContext *context) {
1057   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
1058 }
1059 
1060 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1061                          MLIRContext *context) {
1062   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
1063 }
1064 
1065 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1066                          AffineExpr result) {
1067   return getImpl(dimCount, symbolCount, {result}, result.getContext());
1068 }
1069 
1070 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1071                          ArrayRef<AffineExpr> results, MLIRContext *context) {
1072   return getImpl(dimCount, symbolCount, results, context);
1073 }
1074 
1075 //===----------------------------------------------------------------------===//
// Integer Sets: these are allocated into the bump pointer allocator and are
// immutable. Unlike AffineMaps, these are uniqued only if they are small.
1078 //===----------------------------------------------------------------------===//
1079 
/// Return an IntegerSet with `dimCount` dimensions, `symbolCount` symbols and
/// the given constraints. `eqFlags[i]` presumably marks whether constraint `i`
/// is an equality; confirm against the IntegerSet header.
///
/// Sets with fewer than IntegerSet::kUniquingThreshold constraints are
/// uniqued. Larger sets get a fresh storage instance on every call.
IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
                           ArrayRef<AffineExpr> constraints,
                           ArrayRef<bool> eqFlags) {
  // The number of constraints can't be zero: constraints[0] is used below to
  // find the context.
  assert(!constraints.empty());
  assert(constraints.size() == eqFlags.size());

  auto &impl = constraints[0].getContext()->getImpl();

  // A utility function to construct a new IntegerSetStorage instance. It
  // captures `constraints` and `eqFlags` by reference and rebinds them to
  // copies owned by the context's affine allocator.
  auto constructorFn = [&] {
    auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>();

    // Copy the constraints and equality flags into the bump pointer allocator.
    constraints = copyArrayRefInto(impl.affineAllocator, constraints);
    eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags);

    // Initialize the memory using placement new.
    new (res)
        detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
    return IntegerSet(res);
  };

  // If this instance is uniqued, then we handle it separately so that multiple
  // threads may simultaneously access existing instances. The key is built
  // from the caller's arrays before constructorFn can rebind them.
  if (constraints.size() < IntegerSet::kUniquingThreshold) {
    auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
    return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
                           impl.threadingIsEnabled, constructorFn);
  }

  // Otherwise, acquire a writer-lock so that we can safely create the new
  // instance. NOTE(review): presumably this is required because affineAllocator
  // is shared and not thread-safe.
  ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
  return constructorFn();
}
1116 
1117 //===----------------------------------------------------------------------===//
1118 // StorageUniquerSupport
1119 //===----------------------------------------------------------------------===//
1120 
1121 /// Utility method to generate a callback that can be used to generate a
1122 /// diagnostic when checking the construction invariants of a storage object.
1123 /// This is defined out-of-line to avoid the need to include Location.h.
1124 llvm::unique_function<InFlightDiagnostic()>
1125 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1126   return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1127 }
1128 llvm::unique_function<InFlightDiagnostic()>
1129 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1130   return [=] { return emitError(loc); };
1131 }
1132