//===- MLIRContext.cpp - MLIR Context Implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Support/DebugAction.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Support/Allocator.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/Mutex.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/ThreadPool.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40 
41 #define DEBUG_TYPE "mlircontext"
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 using llvm::hash_combine;
47 using llvm::hash_combine_range;
48 
49 //===----------------------------------------------------------------------===//
50 // MLIRContext CommandLine Options
51 //===----------------------------------------------------------------------===//
52 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options: the options are only registered once the
/// struct is constructed (see `clOptions` below).
struct MLIRContextOptions {
  /// When set, isThreadingGloballyDisabled() returns true: new contexts are
  /// created single-threaded and MLIRContext::disableMultithreading() becomes
  /// a no-op.
  llvm::cl::opt<bool> disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
                     "further call to MLIRContext::enableMultiThreading()")};

  /// Initial value for MLIRContext::printOpOnDiagnostic (defaults to true).
  llvm::cl::opt<bool> printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  /// Initial value for MLIRContext::printStackTraceOnDiagnostic.
  llvm::cl::opt<bool> printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // end anonymous namespace

/// Lazily-constructed holder for the options above. Until something
/// dereferences it (e.g. registerMLIRContextCLOptions()), `isConstructed()` is
/// false and contexts fall back to the defaults in MLIRContextImpl.
static llvm::ManagedStatic<MLIRContextOptions> clOptions;
77 
/// Returns true when threading is off for every context: either LLVM was built
/// without thread support, or `--mlir-disable-threading` was passed (which only
/// counts if the command-line options have been constructed).
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
85 
86 /// Register a set of useful command-line options that can be used to configure
87 /// various flags within the MLIRContext. These flags are used when constructing
88 /// an MLIR context for initialization.
89 void mlir::registerMLIRContextCLOptions() {
90   // Make sure that the options struct has been initialized.
91   *clOptions;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Locking Utilities
96 //===----------------------------------------------------------------------===//
97 
98 namespace {
99 /// Utility writer lock that takes a runtime flag that specifies if we really
100 /// need to lock.
101 struct ScopedWriterLock {
102   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
103       : mutex(shouldLock ? &mutexParam : nullptr) {
104     if (mutex)
105       mutex->lock();
106   }
107   ~ScopedWriterLock() {
108     if (mutex)
109       mutex->unlock();
110   }
111   llvm::sys::SmartRWMutex<true> *mutex;
112 };
113 } // end anonymous namespace.
114 
115 //===----------------------------------------------------------------------===//
116 // AffineMap and IntegerSet hashing
117 //===----------------------------------------------------------------------===//
118 
119 /// A utility function to safely get or create a uniqued instance within the
120 /// given set container.
121 template <typename ValueT, typename DenseInfoT, typename KeyT,
122           typename ConstructorFn>
123 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
124                               KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
125                               bool threadingIsEnabled,
126                               ConstructorFn &&constructorFn) {
127   // Check for an existing instance in read-only mode.
128   if (threadingIsEnabled) {
129     llvm::sys::SmartScopedReader<true> instanceLock(mutex);
130     auto it = container.find_as(key);
131     if (it != container.end())
132       return *it;
133   }
134 
135   // Acquire a writer-lock so that we can safely create the new instance.
136   ScopedWriterLock instanceLock(mutex, threadingIsEnabled);
137 
138   // Check for an existing instance again here, because another writer thread
139   // may have already created one. Otherwise, construct a new instance.
140   auto existing = container.insert_as(ValueT(), key);
141   if (existing.second)
142     return *existing.first = constructorFn();
143   return *existing.first;
144 }
145 
146 namespace {
147 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
148   // Affine maps are uniqued based on their dim/symbol counts and affine
149   // expressions.
150   using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>;
151   using DenseMapInfo<AffineMap>::isEqual;
152 
153   static unsigned getHashValue(const AffineMap &key) {
154     return getHashValue(
155         KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults()));
156   }
157 
158   static unsigned getHashValue(KeyTy key) {
159     return hash_combine(
160         std::get<0>(key), std::get<1>(key),
161         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
162   }
163 
164   static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
165     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
166       return false;
167     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
168                                   rhs.getResults());
169   }
170 };
171 
172 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
173   // Integer sets are uniqued based on their dim/symbol counts, affine
174   // expressions appearing in the LHS of constraints, and eqFlags.
175   using KeyTy =
176       std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>;
177   using DenseMapInfo<IntegerSet>::isEqual;
178 
179   static unsigned getHashValue(const IntegerSet &key) {
180     return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(),
181                               key.getConstraints(), key.getEqFlags()));
182   }
183 
184   static unsigned getHashValue(KeyTy key) {
185     return hash_combine(
186         std::get<0>(key), std::get<1>(key),
187         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
188         hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
189   }
190 
191   static bool isEqual(const KeyTy &lhs, IntegerSet rhs) {
192     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
193       return false;
194     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
195                                   rhs.getConstraints(), rhs.getEqFlags());
196   }
197 };
198 } // end anonymous namespace.
199 
200 //===----------------------------------------------------------------------===//
201 // MLIRContextImpl
202 //===----------------------------------------------------------------------===//
203 
namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
///
/// NOTE: members are destroyed in reverse declaration order; take that into
/// account before reordering fields.
class MLIRContextImpl {
public:
  //===--------------------------------------------------------------------===//
  // Debugging
  //===--------------------------------------------------------------------===//

  /// An action manager for use within the context.
  DebugActionManager debugActionManager;

  //===--------------------------------------------------------------------===//
  // Diagnostics
  //===--------------------------------------------------------------------===//
  /// The engine returned by MLIRContext::getDiagEngine().
  DiagnosticEngine diagEngine;

  //===--------------------------------------------------------------------===//
  // Options
  //===--------------------------------------------------------------------===//

  /// In most cases, creating an operation in an unregistered dialect is not
  /// desired and indicates a misconfiguration of the compiler. Keeping this
  /// option false (the default) makes it possible to detect such use cases.
  bool allowUnregisteredDialects = false;

  /// Enable support for multi-threading within MLIR. Set by the constructor and
  /// updated by MLIRContext::disableMultithreading().
  bool threadingIsEnabled = true;

  /// Track if we are currently executing in a threaded execution environment
  /// (like the pass-manager): this is only a debugging feature to help reduce
  /// the chances of data races on some context APIs. It is a nesting counter
  /// driven by MLIRContext::enter/exitMultiThreadedExecution() and only exists
  /// in builds with assertions enabled.
#ifndef NDEBUG
  std::atomic<int> multiThreadedExecutionContext{0};
#endif

  /// If the operation should be attached to diagnostics printed via the
  /// Operation::emit methods.
  bool printOpOnDiagnostic = true;

  /// If the current stack trace should be attached when emitting diagnostics.
  bool printStackTraceOnDiagnostic = false;

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//

  /// This points to the ThreadPool used when processing MLIR tasks in parallel.
  /// It can't be nullptr when multi-threading is enabled. Otherwise if
  /// multi-threading is disabled, and the threadpool wasn't externally provided
  /// using `setThreadPool`, this will be nullptr.
  llvm::ThreadPool *threadPool = nullptr;

  /// In case where the thread pool is owned by the context, this ensures
  /// destruction with the context. When set, `threadPool` points at it.
  std::unique_ptr<llvm::ThreadPool> ownedThreadPool;

  /// The dialects loaded in this context, keyed by namespace. The MLIRContext
  /// owns the objects. Iteration order of this hash map is unspecified.
  DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
  /// Registry used to load dialects on demand by name and to attach delayed
  /// interface registrations when a dialect is loaded.
  DialectRegistry dialectsRegistry;

  /// An allocator used for AbstractAttribute and AbstractType objects (and for
  /// the cached attribute names of registered operations). It does not run
  /// destructors; ~MLIRContextImpl destroys the abstract objects explicitly.
  llvm::BumpPtrAllocator abstractDialectSymbolAllocator;

  /// This is a mapping from operation name to the operation info describing it.
  /// Entries are created for both registered and unregistered operation names.
  llvm::StringMap<OperationName::Impl> operations;

  /// A vector of operation info specifically for registered operations, in
  /// registration order.
  SmallVector<RegisteredOperationName> registeredOperations;

  /// A mutex used when accessing operation information.
  llvm::sys::SmartRWMutex<true> operationInfoMutex;

  //===--------------------------------------------------------------------===//
  // Affine uniquing
  //===--------------------------------------------------------------------===//

  // Affine allocator and mutex for thread safety.
  llvm::BumpPtrAllocator affineAllocator;
  llvm::sys::SmartRWMutex<true> affineMutex;

  // Affine map uniquing.
  using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
  AffineMapSet affineMaps;

  // Integer set uniquing.
  using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
  IntegerSets integerSets;

  // Affine expression uniquing.
  StorageUniquer affineUniquer;

  //===--------------------------------------------------------------------===//
  // Type uniquing
  //===--------------------------------------------------------------------===//

  /// Abstract type descriptions registered by dialects, keyed by TypeID.
  DenseMap<TypeID, AbstractType *> registeredTypes;
  StorageUniquer typeUniquer;

  /// Cached Type Instances.
  BFloat16Type bf16Ty;
  Float16Type f16Ty;
  Float32Type f32Ty;
  Float64Type f64Ty;
  Float80Type f80Ty;
  Float128Type f128Ty;
  IndexType indexTy;
  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  NoneType noneType;

  //===--------------------------------------------------------------------===//
  // Attribute uniquing
  //===--------------------------------------------------------------------===//

  /// Abstract attribute descriptions registered by dialects, keyed by TypeID.
  DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
  StorageUniquer attributeUniquer;

  /// Cached Attribute Instances.
  BoolAttr falseAttr, trueAttr;
  UnitAttr unitAttr;
  UnknownLoc unknownLocAttr;
  DictionaryAttr emptyDictionaryAttr;
  StringAttr emptyStringAttr;

  /// Map of string attributes that may reference a dialect, that are awaiting
  /// that dialect to be loaded. Entries for a namespace are drained when that
  /// dialect is loaded in MLIRContext::getOrLoadDialect().
  llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
  DenseMap<StringRef, SmallVector<StringAttrStorage *>>
      dialectReferencingStrAttrs;

public:
  /// Create the implementation; a thread pool is created (and owned) only when
  /// threading is enabled.
  MLIRContextImpl(bool threadingIsEnabled)
      : threadingIsEnabled(threadingIsEnabled) {
    if (threadingIsEnabled) {
      ownedThreadPool = std::make_unique<llvm::ThreadPool>();
      threadPool = ownedThreadPool.get();
    }
  }
  ~MLIRContextImpl() {
    // The abstract type/attribute objects were placement-new'd into
    // `abstractDialectSymbolAllocator` (see Dialect::addType/addAttribute),
    // which releases memory without running destructors, so run them here.
    for (auto typeMapping : registeredTypes)
      typeMapping.second->~AbstractType();
    for (auto attrMapping : registeredAttributes)
      attrMapping.second->~AbstractAttribute();
  }
};
} // end namespace mlir
351 
352 MLIRContext::MLIRContext(Threading setting)
353     : MLIRContext(DialectRegistry(), setting) {}
354 
/// Construct a context whose dialect registry is pre-populated with the
/// contents of `registry`. Multi-threading is enabled only when `setting`
/// requests it and threading is not globally disabled (LLVM built without
/// threads, or `--mlir-disable-threading`).
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
    : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
                               !isThreadingGloballyDisabled())) {
  // Initialize values based on the command line flags if they were provided.
  // The options only exist once something (e.g.
  // registerMLIRContextCLOptions()) has constructed them.
  if (clOptions.isConstructed()) {
    printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
    printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
  }

  // Pre-populate the registry.
  registry.appendTo(impl->dialectsRegistry);

  // Ensure the builtin dialect is always pre-loaded.
  getOrLoadDialect<BuiltinDialect>();

  // Initialize several common attributes and types to avoid the need to lock
  // the context when accessing them.

  //// Types.
  /// Floating-point Types.
  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
  impl->f16Ty = TypeUniquer::get<Float16Type>(this);
  impl->f32Ty = TypeUniquer::get<Float32Type>(this);
  impl->f64Ty = TypeUniquer::get<Float64Type>(this);
  impl->f80Ty = TypeUniquer::get<Float80Type>(this);
  impl->f128Ty = TypeUniquer::get<Float128Type>(this);
  /// Index Type.
  impl->indexTy = TypeUniquer::get<IndexType>(this);
  /// Integer Types.
  impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
  impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
  impl->int16Ty =
      TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
  impl->int32Ty =
      TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
  impl->int64Ty =
      TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
  impl->int128Ty =
      TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
  /// None Type.
  impl->noneType = TypeUniquer::get<NoneType>(this);

  //// Attributes.
  //// Note: These must be registered after the types as they may generate one
  //// of the above types internally.
  /// Unknown Location Attribute.
  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
  /// Bool Attributes (built on the cached i1 type above).
  impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
  impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
  /// Unit Attribute.
  impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
  /// The empty dictionary attribute.
  impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
  /// The empty string attribute.
  impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);

  // Register the affine storage objects with the uniquer. Affine expressions
  // use their own StorageUniquer, separate from types and attributes.
  impl->affineUniquer
      .registerParametricStorageType<AffineBinaryOpExprStorage>();
  impl->affineUniquer
      .registerParametricStorageType<AffineConstantExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
}
419 
420 MLIRContext::~MLIRContext() {}
421 
422 /// Copy the specified array of elements into memory managed by the provided
423 /// bump pointer allocator.  This assumes the elements are all PODs.
424 template <typename T>
425 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
426                                     ArrayRef<T> elements) {
427   auto result = allocator.Allocate<T>(elements.size());
428   std::uninitialized_copy(elements.begin(), elements.end(), result);
429   return ArrayRef<T>(result, elements.size());
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // Debugging
434 //===----------------------------------------------------------------------===//
435 
436 DebugActionManager &MLIRContext::getDebugActionManager() {
437   return getImpl().debugActionManager;
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // Diagnostic Handlers
442 //===----------------------------------------------------------------------===//
443 
444 /// Returns the diagnostic engine for this context.
445 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
446 
447 //===----------------------------------------------------------------------===//
448 // Dialect and Operation Registration
449 //===----------------------------------------------------------------------===//
450 
451 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
452   registry.appendTo(impl->dialectsRegistry);
453 
454   // For the already loaded dialects, register the interfaces immediately.
455   for (const auto &kvp : impl->loadedDialects)
456     registry.registerDelayedInterfaces(kvp.second.get());
457 }
458 
459 const DialectRegistry &MLIRContext::getDialectRegistry() {
460   return impl->dialectsRegistry;
461 }
462 
463 /// Return information about all registered IR dialects.
464 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
465   std::vector<Dialect *> result;
466   result.reserve(impl->loadedDialects.size());
467   for (auto &dialect : impl->loadedDialects)
468     result.push_back(dialect.second.get());
469   llvm::array_pod_sort(result.begin(), result.end(),
470                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
471                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
472                        });
473   return result;
474 }
475 std::vector<StringRef> MLIRContext::getAvailableDialects() {
476   std::vector<StringRef> result;
477   for (auto dialect : impl->dialectsRegistry.getDialectNames())
478     result.push_back(dialect);
479   return result;
480 }
481 
482 /// Get a registered IR dialect with the given namespace. If none is found,
483 /// then return nullptr.
484 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
485   // Dialects are sorted by name, so we can use binary search for lookup.
486   auto it = impl->loadedDialects.find(name);
487   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
488 }
489 
490 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
491   Dialect *dialect = getLoadedDialect(name);
492   if (dialect)
493     return dialect;
494   DialectAllocatorFunctionRef allocator =
495       impl->dialectsRegistry.getDialectAllocator(name);
496   return allocator ? allocator(this) : nullptr;
497 }
498 
499 /// Get a dialect for the provided namespace and TypeID: abort the program if a
500 /// dialect exist for this namespace with different TypeID. Returns a pointer to
501 /// the dialect owned by the context.
502 Dialect *
503 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
504                               function_ref<std::unique_ptr<Dialect>()> ctor) {
505   auto &impl = getImpl();
506   // Get the correct insertion position sorted by namespace.
507   std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
508 
509   if (!dialect) {
510     LLVM_DEBUG(llvm::dbgs()
511                << "Load new dialect in Context " << dialectNamespace << "\n");
512 #ifndef NDEBUG
513     if (impl.multiThreadedExecutionContext != 0)
514       llvm::report_fatal_error(
515           "Loading a dialect (" + dialectNamespace +
516           ") while in a multi-threaded execution context (maybe "
517           "the PassManager): this can indicate a "
518           "missing `dependentDialects` in a pass for example.");
519 #endif
520     dialect = ctor();
521     assert(dialect && "dialect ctor failed");
522 
523     // Refresh all the identifiers dialect field, this catches cases where a
524     // dialect may be loaded after identifier prefixed with this dialect name
525     // were already created.
526     auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
527     if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
528       for (StringAttrStorage *storage : stringAttrsIt->second)
529         storage->referencedDialect = dialect.get();
530       impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
531     }
532 
533     // Actually register the interfaces with delayed registration.
534     impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
535     return dialect.get();
536   }
537 
538   // Abort if dialect with namespace has already been registered.
539   if (dialect->getTypeID() != dialectID)
540     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
541                              "' has already been registered");
542 
543   return dialect.get();
544 }
545 
546 void MLIRContext::loadAllAvailableDialects() {
547   for (StringRef name : getAvailableDialects())
548     getOrLoadDialect(name);
549 }
550 
551 llvm::hash_code MLIRContext::getRegistryHash() {
552   llvm::hash_code hash(0);
553   // Factor in number of loaded dialects, attributes, operations, types.
554   hash = llvm::hash_combine(hash, impl->loadedDialects.size());
555   hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
556   hash = llvm::hash_combine(hash, impl->registeredOperations.size());
557   hash = llvm::hash_combine(hash, impl->registeredTypes.size());
558   return hash;
559 }
560 
561 bool MLIRContext::allowsUnregisteredDialects() {
562   return impl->allowUnregisteredDialects;
563 }
564 
565 void MLIRContext::allowUnregisteredDialects(bool allowing) {
566   impl->allowUnregisteredDialects = allowing;
567 }
568 
569 /// Return true if multi-threading is enabled by the context.
570 bool MLIRContext::isMultithreadingEnabled() {
571   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
572 }
573 
574 /// Set the flag specifying if multi-threading is disabled by the context.
575 void MLIRContext::disableMultithreading(bool disable) {
576   // This API can be overridden by the global debugging flag
577   // --mlir-disable-threading
578   if (isThreadingGloballyDisabled())
579     return;
580 
581   impl->threadingIsEnabled = !disable;
582 
583   // Update the threading mode for each of the uniquers.
584   impl->affineUniquer.disableMultithreading(disable);
585   impl->attributeUniquer.disableMultithreading(disable);
586   impl->typeUniquer.disableMultithreading(disable);
587 
588   // Destroy thread pool (stop all threads) if it is no longer needed, or create
589   // a new one if multithreading was re-enabled.
590   if (disable) {
591     // If the thread pool is owned, explicitly set it to nullptr to avoid
592     // keeping a dangling pointer around. If the thread pool is externally
593     // owned, we don't do anything.
594     if (impl->ownedThreadPool) {
595       assert(impl->threadPool);
596       impl->threadPool = nullptr;
597       impl->ownedThreadPool.reset();
598     }
599   } else if (!impl->threadPool) {
600     // The thread pool isn't externally provided.
601     assert(!impl->ownedThreadPool);
602     impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
603     impl->threadPool = impl->ownedThreadPool.get();
604   }
605 }
606 
607 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
608   assert(!isMultithreadingEnabled() &&
609          "expected multi-threading to be disabled when setting a ThreadPool");
610   impl->threadPool = &pool;
611   impl->ownedThreadPool.reset();
612   enableMultithreading();
613 }
614 
615 llvm::ThreadPool &MLIRContext::getThreadPool() {
616   assert(isMultithreadingEnabled() &&
617          "expected multi-threading to be enabled within the context");
618   assert(impl->threadPool &&
619          "multi-threading is enabled but threadpool not set");
620   return *impl->threadPool;
621 }
622 
623 void MLIRContext::enterMultiThreadedExecution() {
624 #ifndef NDEBUG
625   ++impl->multiThreadedExecutionContext;
626 #endif
627 }
628 void MLIRContext::exitMultiThreadedExecution() {
629 #ifndef NDEBUG
630   --impl->multiThreadedExecutionContext;
631 #endif
632 }
633 
634 /// Return true if we should attach the operation to diagnostics emitted via
635 /// Operation::emit.
636 bool MLIRContext::shouldPrintOpOnDiagnostic() {
637   return impl->printOpOnDiagnostic;
638 }
639 
640 /// Set the flag specifying if we should attach the operation to diagnostics
641 /// emitted via Operation::emit.
642 void MLIRContext::printOpOnDiagnostic(bool enable) {
643   impl->printOpOnDiagnostic = enable;
644 }
645 
646 /// Return true if we should attach the current stacktrace to diagnostics when
647 /// emitted.
648 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
649   return impl->printStackTraceOnDiagnostic;
650 }
651 
652 /// Set the flag specifying if we should attach the current stacktrace when
653 /// emitting diagnostics.
654 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
655   impl->printStackTraceOnDiagnostic = enable;
656 }
657 
658 /// Return information about all registered operations.  This isn't very
659 /// efficient, typically you should ask the operations about their properties
660 /// directly.
661 std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
662   // We just have the operations in a non-deterministic hash table order. Dump
663   // into a temporary array, then sort it by operation name to get a stable
664   // ordering.
665   std::vector<RegisteredOperationName> result(
666       impl->registeredOperations.begin(), impl->registeredOperations.end());
667   llvm::array_pod_sort(result.begin(), result.end(),
668                        [](const RegisteredOperationName *lhs,
669                           const RegisteredOperationName *rhs) {
670                          return lhs->getIdentifier().compare(
671                              rhs->getIdentifier());
672                        });
673 
674   return result;
675 }
676 
677 bool MLIRContext::isOperationRegistered(StringRef name) {
678   return OperationName(name, this).isRegistered();
679 }
680 
681 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
682   auto &impl = context->getImpl();
683   assert(impl.multiThreadedExecutionContext == 0 &&
684          "Registering a new type kind while in a multi-threaded execution "
685          "context");
686   auto *newInfo =
687       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
688           AbstractType(std::move(typeInfo));
689   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
690     llvm::report_fatal_error("Dialect Type already registered.");
691 }
692 
693 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
694   auto &impl = context->getImpl();
695   assert(impl.multiThreadedExecutionContext == 0 &&
696          "Registering a new attribute kind while in a multi-threaded execution "
697          "context");
698   auto *newInfo =
699       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
700           AbstractAttribute(std::move(attrInfo));
701   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
702     llvm::report_fatal_error("Dialect Attribute already registered.");
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // AbstractAttribute
707 //===----------------------------------------------------------------------===//
708 
709 /// Get the dialect that registered the attribute with the provided typeid.
710 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
711                                                    MLIRContext *context) {
712   const AbstractAttribute *abstract = lookupMutable(typeID, context);
713   if (!abstract)
714     llvm::report_fatal_error("Trying to create an Attribute that was not "
715                              "registered in this MLIRContext.");
716   return *abstract;
717 }
718 
719 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
720                                                     MLIRContext *context) {
721   auto &impl = context->getImpl();
722   auto it = impl.registeredAttributes.find(typeID);
723   if (it == impl.registeredAttributes.end())
724     return nullptr;
725   return it->second;
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // OperationName
730 //===----------------------------------------------------------------------===//
731 
732 OperationName::OperationName(StringRef name, MLIRContext *context) {
733   MLIRContextImpl &ctxImpl = context->getImpl();
734 
735   // Check for an existing name in read-only mode.
736   bool isMultithreadingEnabled = context->isMultithreadingEnabled();
737   if (isMultithreadingEnabled) {
738     llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
739     auto it = ctxImpl.operations.find(name);
740     if (it != ctxImpl.operations.end()) {
741       impl = &it->second;
742       return;
743     }
744   }
745 
746   // Acquire a writer-lock so that we can safely create the new instance.
747   ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
748 
749   auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
750   if (it.second)
751     it.first->second.name = StringAttr::get(context, name);
752   impl = &it.first->second;
753 }
754 
755 StringRef OperationName::getDialectNamespace() const {
756   if (Dialect *dialect = getDialect())
757     return dialect->getNamespace();
758   return getStringRef().split('.').first;
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // RegisteredOperationName
763 //===----------------------------------------------------------------------===//
764 
765 ParseResult
766 RegisteredOperationName::parseAssembly(OpAsmParser &parser,
767                                        OperationState &result) const {
768   return impl->parseAssemblyFn(parser, result);
769 }
770 
/// Register the operation `name` as belonging to `dialect`.
///
/// This installs the given type id, parse/print/verify/fold/canonicalization
/// hooks, interface map, trait query and attribute names into the context's
/// operation table. The process is aborted if an operation with the same name
/// has already been registered. This must not be called while the context is
/// inside a multi-threaded execution context (asserted below).
void RegisteredOperationName::insert(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<StringRef> attrNames) {
  MLIRContext *ctx = dialect.getContext();
  auto &ctxImpl = ctx->getImpl();
  assert(ctxImpl.multiThreadedExecutionContext == 0 &&
         "registering a new operation kind while in a multi-threaded execution "
         "context");

  // Register the attribute names of this operation.
  // The names are interned as StringAttrs in storage owned by the context's
  // allocator, so the resulting array lives as long as the context.
  MutableArrayRef<StringAttr> cachedAttrNames;
  if (!attrNames.empty()) {
    cachedAttrNames = MutableArrayRef<StringAttr>(
        ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
            attrNames.size()),
        attrNames.size());
    // Allocate() returns uninitialized storage, so construct each element in
    // place rather than assigning into it.
    for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
      new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
  }

  // Insert the operation info if it doesn't exist yet.
  // An entry may already exist (unregistered) if an OperationName for this
  // string was constructed earlier; in that case it is reused.
  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
  if (it.second)
    it.first->second.name = StringAttr::get(ctx, name);
  OperationName::Impl &impl = it.first->second;

  // Double registration is a hard error.
  if (impl.isRegistered()) {
    llvm::errs() << "error: operation named '" << name
                 << "' is already registered.\n";
    abort();
  }
  ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));

  // Update the registered info for this operation.
  impl.dialect = &dialect;
  impl.typeID = typeID;
  impl.interfaceMap = std::move(interfaceMap);
  impl.foldHookFn = std::move(foldHook);
  impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
  impl.hasTraitFn = std::move(hasTrait);
  impl.parseAssemblyFn = std::move(parseAssembly);
  impl.printAssemblyFn = std::move(printAssembly);
  impl.verifyInvariantsFn = std::move(verifyInvariants);
  impl.attributeNames = cachedAttrNames;
}
820 
821 //===----------------------------------------------------------------------===//
822 // AbstractType
823 //===----------------------------------------------------------------------===//
824 
825 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
826   const AbstractType *type = lookupMutable(typeID, context);
827   if (!type)
828     llvm::report_fatal_error(
829         "Trying to create a Type that was not registered in this MLIRContext.");
830   return *type;
831 }
832 
833 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
834   auto &impl = context->getImpl();
835   auto it = impl.registeredTypes.find(typeID);
836   if (it == impl.registeredTypes.end())
837     return nullptr;
838   return it->second;
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // Type uniquing
843 //===----------------------------------------------------------------------===//
844 
845 /// Returns the storage uniquer used for constructing type storage instances.
846 /// This should not be used directly.
847 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
848 
849 BFloat16Type BFloat16Type::get(MLIRContext *context) {
850   return context->getImpl().bf16Ty;
851 }
852 Float16Type Float16Type::get(MLIRContext *context) {
853   return context->getImpl().f16Ty;
854 }
855 Float32Type Float32Type::get(MLIRContext *context) {
856   return context->getImpl().f32Ty;
857 }
858 Float64Type Float64Type::get(MLIRContext *context) {
859   return context->getImpl().f64Ty;
860 }
861 Float80Type Float80Type::get(MLIRContext *context) {
862   return context->getImpl().f80Ty;
863 }
864 Float128Type Float128Type::get(MLIRContext *context) {
865   return context->getImpl().f128Ty;
866 }
867 
868 /// Get an instance of the IndexType.
869 IndexType IndexType::get(MLIRContext *context) {
870   return context->getImpl().indexTy;
871 }
872 
873 /// Return an existing integer type instance if one is cached within the
874 /// context.
875 static IntegerType
876 getCachedIntegerType(unsigned width,
877                      IntegerType::SignednessSemantics signedness,
878                      MLIRContext *context) {
879   if (signedness != IntegerType::Signless)
880     return IntegerType();
881 
882   switch (width) {
883   case 1:
884     return context->getImpl().int1Ty;
885   case 8:
886     return context->getImpl().int8Ty;
887   case 16:
888     return context->getImpl().int16Ty;
889   case 32:
890     return context->getImpl().int32Ty;
891   case 64:
892     return context->getImpl().int64Ty;
893   case 128:
894     return context->getImpl().int128Ty;
895   default:
896     return IntegerType();
897   }
898 }
899 
900 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
901                              IntegerType::SignednessSemantics signedness) {
902   if (auto cached = getCachedIntegerType(width, signedness, context))
903     return cached;
904   return Base::get(context, width, signedness);
905 }
906 
907 IntegerType
908 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
909                         MLIRContext *context, unsigned width,
910                         SignednessSemantics signedness) {
911   if (auto cached = getCachedIntegerType(width, signedness, context))
912     return cached;
913   return Base::getChecked(emitError, context, width, signedness);
914 }
915 
916 /// Get an instance of the NoneType.
917 NoneType NoneType::get(MLIRContext *context) {
918   if (NoneType cachedInst = context->getImpl().noneType)
919     return cachedInst;
920   // Note: May happen when initializing the singleton attributes of the builtin
921   // dialect.
922   return Base::get(context);
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // Attribute uniquing
927 //===----------------------------------------------------------------------===//
928 
929 /// Returns the storage uniquer used for constructing attribute storage
930 /// instances. This should not be used directly.
931 StorageUniquer &MLIRContext::getAttributeUniquer() {
932   return getImpl().attributeUniquer;
933 }
934 
935 /// Initialize the given attribute storage instance.
936 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
937                                                   MLIRContext *ctx,
938                                                   TypeID attrID) {
939   storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
940 
941   // If the attribute did not provide a type, then default to NoneType.
942   if (!storage->getType())
943     storage->setType(NoneType::get(ctx));
944 }
945 
946 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
947   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
948 }
949 
950 UnitAttr UnitAttr::get(MLIRContext *context) {
951   return context->getImpl().unitAttr;
952 }
953 
954 UnknownLoc UnknownLoc::get(MLIRContext *context) {
955   return context->getImpl().unknownLocAttr;
956 }
957 
958 /// Return empty dictionary.
959 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
960   return context->getImpl().emptyDictionaryAttr;
961 }
962 
963 void StringAttrStorage::initialize(MLIRContext *context) {
964   // Check for a dialect namespace prefix, if there isn't one we don't need to
965   // do any additional initialization.
966   auto dialectNamePair = value.split('.');
967   if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
968     return;
969 
970   // If one exists, we check to see if this dialect is loaded. If it is, we set
971   // the dialect now, if it isn't we record this storage for initialization
972   // later if the dialect ever gets loaded.
973   if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
974     return;
975 
976   MLIRContextImpl &impl = context->getImpl();
977   llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
978   impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
979 }
980 
981 /// Return an empty string.
982 StringAttr StringAttr::get(MLIRContext *context) {
983   return context->getImpl().emptyStringAttr;
984 }
985 
986 //===----------------------------------------------------------------------===//
987 // AffineMap uniquing
988 //===----------------------------------------------------------------------===//
989 
990 StorageUniquer &MLIRContext::getAffineUniquer() {
991   return getImpl().affineUniquer;
992 }
993 
/// Return the uniqued AffineMap with the given dimension count, symbol count
/// and result expressions, creating it if it does not exist yet.
///
/// Lookup and creation go through `safeGetOrCreate` (defined elsewhere in this
/// file) on `impl.affineMaps`, guarded by `impl.affineMutex`. New storage is
/// allocated from the context's affine allocator.
AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
                             ArrayRef<AffineExpr> results,
                             MLIRContext *context) {
  auto &impl = context->getImpl();
  auto key = std::make_tuple(dimCount, symbolCount, results);

  // Safely get or create an AffineMap instance.
  return safeGetOrCreate(
      impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
        auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();

        // Copy the results into the bump pointer.
        // This makes the stored map independent of the caller's array.
        // `results` is captured by reference and rebound here. `key` was
        // built earlier from the caller's array and is not affected.
        results = copyArrayRefInto(impl.affineAllocator, results);

        // Initialize the memory using placement new.
        new (res)
            detail::AffineMapStorage{dimCount, symbolCount, results, context};
        return AffineMap(res);
      });
}
1014 
1015 /// Check whether the arguments passed to the AffineMap::get() are consistent.
1016 /// This method checks whether the highest index of dimensional identifier
1017 /// present in result expressions is less than `dimCount` and the highest index
1018 /// of symbolic identifier present in result expressions is less than
1019 /// `symbolCount`.
1020 LLVM_NODISCARD static bool willBeValidAffineMap(unsigned dimCount,
1021                                                 unsigned symbolCount,
1022                                                 ArrayRef<AffineExpr> results) {
1023   int64_t maxDimPosition = -1;
1024   int64_t maxSymbolPosition = -1;
1025   getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
1026                      maxSymbolPosition);
1027   if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
1028     LLVM_DEBUG(
1029         llvm::dbgs()
1030         << "maximum dimensional identifier position in result expression must "
1031            "be less than `dimCount` and maximum symbolic identifier position "
1032            "in result expression must be less than `symbolCount`\n");
1033     return false;
1034   }
1035   return true;
1036 }
1037 
1038 AffineMap AffineMap::get(MLIRContext *context) {
1039   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
1040 }
1041 
1042 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1043                          MLIRContext *context) {
1044   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
1045 }
1046 
1047 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1048                          AffineExpr result) {
1049   assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
1050   return getImpl(dimCount, symbolCount, {result}, result.getContext());
1051 }
1052 
1053 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1054                          ArrayRef<AffineExpr> results, MLIRContext *context) {
1055   assert(willBeValidAffineMap(dimCount, symbolCount, results));
1056   return getImpl(dimCount, symbolCount, results, context);
1057 }
1058 
1059 //===----------------------------------------------------------------------===//
1060 // Integer Sets: these are allocated into the bump pointer, and are immutable.
// Unlike AffineMaps, these are uniqued only if they have fewer than
// IntegerSet::kUniquingThreshold constraints.
1062 //===----------------------------------------------------------------------===//
1063 
/// Return an IntegerSet with the given dimension and symbol counts and the
/// affine `constraints`, with one equality flag per constraint in `eqFlags`.
///
/// Both arrays must be non-empty and of equal length (asserted). Sets with
/// fewer than `IntegerSet::kUniquingThreshold` constraints are uniqued in the
/// context. Larger sets get fresh storage on every call.
IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
                           ArrayRef<AffineExpr> constraints,
                           ArrayRef<bool> eqFlags) {
  // The number of constraints can't be zero.
  assert(!constraints.empty());
  assert(constraints.size() == eqFlags.size());

  // The context is taken from the first constraint (hence the non-empty
  // requirement above).
  auto &impl = constraints[0].getContext()->getImpl();

  // A utility function to construct a new IntegerSetStorage instance.
  auto constructorFn = [&] {
    auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>();

    // Copy the results and equality flags into the bump pointer.
    // `constraints` and `eqFlags` are captured by reference and rebound to
    // the copies. In the uniqued path below, `key` was already built from the
    // caller's arrays before this runs.
    constraints = copyArrayRefInto(impl.affineAllocator, constraints);
    eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags);

    // Initialize the memory using placement new.
    new (res)
        detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
    return IntegerSet(res);
  };

  // If this instance is uniqued, then we handle it separately so that multiple
  // threads may simultaneously access existing instances.
  if (constraints.size() < IntegerSet::kUniquingThreshold) {
    auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
    return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
                           impl.threadingIsEnabled, constructorFn);
  }

  // Otherwise, acquire a writer-lock so that we can safely create the new
  // instance.
  // The lock is the same `affineMutex` used by the uniqued path, which also
  // allocates from `impl.affineAllocator`.
  ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
  return constructorFn();
}
1100 
1101 //===----------------------------------------------------------------------===//
1102 // StorageUniquerSupport
1103 //===----------------------------------------------------------------------===//
1104 
1105 /// Utility method to generate a callback that can be used to generate a
1106 /// diagnostic when checking the construction invariants of a storage object.
1107 /// This is defined out-of-line to avoid the need to include Location.h.
1108 llvm::unique_function<InFlightDiagnostic()>
1109 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1110   return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1111 }
1112 llvm::unique_function<InFlightDiagnostic()>
1113 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1114   return [=] { return emitError(loc); };
1115 }
1116