1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Support/DebugAction.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Support/Allocator.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/Mutex.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/ThreadPool.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40 
41 #define DEBUG_TYPE "mlircontext"
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 using llvm::hash_combine;
47 using llvm::hash_combine_range;
48 
49 //===----------------------------------------------------------------------===//
50 // MLIRContext CommandLine Options
51 //===----------------------------------------------------------------------===//
52 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options: the options only exist (and are only
/// registered with llvm::cl) once the struct below is constructed.
struct MLIRContextOptions {
  // Read by isThreadingGloballyDisabled(); when set, contexts are created
  // without threading and disableMultithreading(false) becomes a no-op.
  llvm::cl::opt<bool> disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
                     "further call to MLIRContext::enableMultiThreading()")};

  // Copied into each newly constructed context via
  // MLIRContext::printOpOnDiagnostic(); defaults to true.
  llvm::cl::opt<bool> printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  // Copied into each newly constructed context via
  // MLIRContext::printStackTraceOnDiagnostic().
  llvm::cl::opt<bool> printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // end anonymous namespace

// Lazily constructed holder for the options above. Code in this file checks
// `clOptions.isConstructed()` before reading it, so the flags only take effect
// after registerMLIRContextCLOptions() (or some other dereference) ran.
static llvm::ManagedStatic<MLIRContextOptions> clOptions;
77 
/// Returns true when threading must stay off regardless of per-context
/// settings: either LLVM was built without thread support, or the
/// `--mlir-disable-threading` option has been registered and is set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  // Without registered options there is no flag to honor.
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
85 
86 /// Register a set of useful command-line options that can be used to configure
87 /// various flags within the MLIRContext. These flags are used when constructing
88 /// an MLIR context for initialization.
89 void mlir::registerMLIRContextCLOptions() {
90   // Make sure that the options struct has been initialized.
91   *clOptions;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Locking Utilities
96 //===----------------------------------------------------------------------===//
97 
98 namespace {
99 /// Utility reader lock that takes a runtime flag that specifies if we really
100 /// need to lock.
101 struct ScopedReaderLock {
102   ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
103       : mutex(shouldLock ? &mutexParam : nullptr) {
104     if (mutex)
105       mutex->lock_shared();
106   }
107   ~ScopedReaderLock() {
108     if (mutex)
109       mutex->unlock_shared();
110   }
111   llvm::sys::SmartRWMutex<true> *mutex;
112 };
113 /// Utility writer lock that takes a runtime flag that specifies if we really
114 /// need to lock.
115 struct ScopedWriterLock {
116   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
117       : mutex(shouldLock ? &mutexParam : nullptr) {
118     if (mutex)
119       mutex->lock();
120   }
121   ~ScopedWriterLock() {
122     if (mutex)
123       mutex->unlock();
124   }
125   llvm::sys::SmartRWMutex<true> *mutex;
126 };
127 } // end anonymous namespace.
128 
129 //===----------------------------------------------------------------------===//
130 // AffineMap and IntegerSet hashing
131 //===----------------------------------------------------------------------===//
132 
133 /// A utility function to safely get or create a uniqued instance within the
134 /// given set container.
135 template <typename ValueT, typename DenseInfoT, typename KeyT,
136           typename ConstructorFn>
137 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
138                               KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
139                               bool threadingIsEnabled,
140                               ConstructorFn &&constructorFn) {
141   // Check for an existing instance in read-only mode.
142   if (threadingIsEnabled) {
143     llvm::sys::SmartScopedReader<true> instanceLock(mutex);
144     auto it = container.find_as(key);
145     if (it != container.end())
146       return *it;
147   }
148 
149   // Acquire a writer-lock so that we can safely create the new instance.
150   ScopedWriterLock instanceLock(mutex, threadingIsEnabled);
151 
152   // Check for an existing instance again here, because another writer thread
153   // may have already created one. Otherwise, construct a new instance.
154   auto existing = container.insert_as(ValueT(), key);
155   if (existing.second)
156     return *existing.first = constructorFn();
157   return *existing.first;
158 }
159 
160 namespace {
161 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
162   // Affine maps are uniqued based on their dim/symbol counts and affine
163   // expressions.
164   using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>;
165   using DenseMapInfo<AffineMap>::isEqual;
166 
167   static unsigned getHashValue(const AffineMap &key) {
168     return getHashValue(
169         KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults()));
170   }
171 
172   static unsigned getHashValue(KeyTy key) {
173     return hash_combine(
174         std::get<0>(key), std::get<1>(key),
175         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
176   }
177 
178   static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
179     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
180       return false;
181     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
182                                   rhs.getResults());
183   }
184 };
185 
186 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
187   // Integer sets are uniqued based on their dim/symbol counts, affine
188   // expressions appearing in the LHS of constraints, and eqFlags.
189   using KeyTy =
190       std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>;
191   using DenseMapInfo<IntegerSet>::isEqual;
192 
193   static unsigned getHashValue(const IntegerSet &key) {
194     return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(),
195                               key.getConstraints(), key.getEqFlags()));
196   }
197 
198   static unsigned getHashValue(KeyTy key) {
199     return hash_combine(
200         std::get<0>(key), std::get<1>(key),
201         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
202         hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
203   }
204 
205   static bool isEqual(const KeyTy &lhs, IntegerSet rhs) {
206     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
207       return false;
208     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
209                                   rhs.getConstraints(), rhs.getEqFlags());
210   }
211 };
212 } // end anonymous namespace.
213 
214 //===----------------------------------------------------------------------===//
215 // MLIRContextImpl
216 //===----------------------------------------------------------------------===//
217 
218 namespace mlir {
219 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
220 /// This class is completely private to this file, so everything is public.
221 class MLIRContextImpl {
222 public:
223   //===--------------------------------------------------------------------===//
224   // Debugging
225   //===--------------------------------------------------------------------===//
226 
227   /// An action manager for use within the context.
228   DebugActionManager debugActionManager;
229 
230   //===--------------------------------------------------------------------===//
231   // Diagnostics
232   //===--------------------------------------------------------------------===//
233   DiagnosticEngine diagEngine;
234 
235   //===--------------------------------------------------------------------===//
236   // Options
237   //===--------------------------------------------------------------------===//
238 
239   /// In most cases, creating operation in unregistered dialect is not desired
240   /// and indicate a misconfiguration of the compiler. This option enables to
241   /// detect such use cases
242   bool allowUnregisteredDialects = false;
243 
244   /// Enable support for multi-threading within MLIR.
245   bool threadingIsEnabled = true;
246 
247   /// Track if we are currently executing in a threaded execution environment
248   /// (like the pass-manager): this is only a debugging feature to help reducing
249   /// the chances of data races one some context APIs.
250 #ifndef NDEBUG
251   std::atomic<int> multiThreadedExecutionContext{0};
252 #endif
253 
254   /// If the operation should be attached to diagnostics printed via the
255   /// Operation::emit methods.
256   bool printOpOnDiagnostic = true;
257 
258   /// If the current stack trace should be attached when emitting diagnostics.
259   bool printStackTraceOnDiagnostic = false;
260 
261   //===--------------------------------------------------------------------===//
262   // Other
263   //===--------------------------------------------------------------------===//
264 
265   /// This points to the ThreadPool used when processing MLIR tasks in parallel.
266   /// It can't be nullptr when multi-threading is enabled. Otherwise if
267   /// multi-threading is disabled, and the threadpool wasn't externally provided
268   /// using `setThreadPool`, this will be nullptr.
269   llvm::ThreadPool *threadPool = nullptr;
270 
271   /// In case where the thread pool is owned by the context, this ensures
272   /// destruction with the context.
273   std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
274 
275   /// This is a list of dialects that are created referring to this context.
276   /// The MLIRContext owns the objects.
277   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
278   DialectRegistry dialectsRegistry;
279 
280   /// This is a mapping from operation name to AbstractOperation for registered
281   /// operations.
282   llvm::StringMap<AbstractOperation> registeredOperations;
283 
284   /// An allocator used for AbstractAttribute and AbstractType objects.
285   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
286 
287   //===--------------------------------------------------------------------===//
288   // Affine uniquing
289   //===--------------------------------------------------------------------===//
290 
291   // Affine allocator and mutex for thread safety.
292   llvm::BumpPtrAllocator affineAllocator;
293   llvm::sys::SmartRWMutex<true> affineMutex;
294 
295   // Affine map uniquing.
296   using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
297   AffineMapSet affineMaps;
298 
299   // Integer set uniquing.
300   using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
301   IntegerSets integerSets;
302 
303   // Affine expression uniquing.
304   StorageUniquer affineUniquer;
305 
306   //===--------------------------------------------------------------------===//
307   // Type uniquing
308   //===--------------------------------------------------------------------===//
309 
310   DenseMap<TypeID, AbstractType *> registeredTypes;
311   StorageUniquer typeUniquer;
312 
313   /// Cached Type Instances.
314   BFloat16Type bf16Ty;
315   Float16Type f16Ty;
316   Float32Type f32Ty;
317   Float64Type f64Ty;
318   Float80Type f80Ty;
319   Float128Type f128Ty;
320   IndexType indexTy;
321   IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
322   NoneType noneType;
323 
324   //===--------------------------------------------------------------------===//
325   // Attribute uniquing
326   //===--------------------------------------------------------------------===//
327 
328   DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
329   StorageUniquer attributeUniquer;
330 
331   /// Cached Attribute Instances.
332   BoolAttr falseAttr, trueAttr;
333   UnitAttr unitAttr;
334   UnknownLoc unknownLocAttr;
335   DictionaryAttr emptyDictionaryAttr;
336   StringAttr emptyStringAttr;
337 
338   /// Map of string attributes that may reference a dialect, that are awaiting
339   /// that dialect to be loaded.
340   llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
341   DenseMap<StringRef, SmallVector<StringAttrStorage *>>
342       dialectReferencingStrAttrs;
343 
344 public:
345   MLIRContextImpl(bool threadingIsEnabled)
346       : threadingIsEnabled(threadingIsEnabled) {
347     if (threadingIsEnabled) {
348       ownedThreadPool = std::make_unique<llvm::ThreadPool>();
349       threadPool = ownedThreadPool.get();
350     }
351   }
352   ~MLIRContextImpl() {
353     for (auto typeMapping : registeredTypes)
354       typeMapping.second->~AbstractType();
355     for (auto attrMapping : registeredAttributes)
356       attrMapping.second->~AbstractAttribute();
357   }
358 };
359 } // end namespace mlir
360 
361 MLIRContext::MLIRContext(Threading setting)
362     : MLIRContext(DialectRegistry(), setting) {}
363 
364 MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
365     : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
366                                !isThreadingGloballyDisabled())) {
367   // Initialize values based on the command line flags if they were provided.
368   if (clOptions.isConstructed()) {
369     printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
370     printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
371   }
372 
373   // Pre-populate the registry.
374   registry.appendTo(impl->dialectsRegistry);
375 
376   // Ensure the builtin dialect is always pre-loaded.
377   getOrLoadDialect<BuiltinDialect>();
378 
379   // Initialize several common attributes and types to avoid the need to lock
380   // the context when accessing them.
381 
382   //// Types.
383   /// Floating-point Types.
384   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
385   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
386   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
387   impl->f64Ty = TypeUniquer::get<Float64Type>(this);
388   impl->f80Ty = TypeUniquer::get<Float80Type>(this);
389   impl->f128Ty = TypeUniquer::get<Float128Type>(this);
390   /// Index Type.
391   impl->indexTy = TypeUniquer::get<IndexType>(this);
392   /// Integer Types.
393   impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
394   impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
395   impl->int16Ty =
396       TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
397   impl->int32Ty =
398       TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
399   impl->int64Ty =
400       TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
401   impl->int128Ty =
402       TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
403   /// None Type.
404   impl->noneType = TypeUniquer::get<NoneType>(this);
405 
406   //// Attributes.
407   //// Note: These must be registered after the types as they may generate one
408   //// of the above types internally.
409   /// Unknown Location Attribute.
410   impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
411   /// Bool Attributes.
412   impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
413   impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
414   /// Unit Attribute.
415   impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
416   /// The empty dictionary attribute.
417   impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
418   /// The empty string attribute.
419   impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);
420 
421   // Register the affine storage objects with the uniquer.
422   impl->affineUniquer
423       .registerParametricStorageType<AffineBinaryOpExprStorage>();
424   impl->affineUniquer
425       .registerParametricStorageType<AffineConstantExprStorage>();
426   impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
427 }
428 
429 MLIRContext::~MLIRContext() {}
430 
431 /// Copy the specified array of elements into memory managed by the provided
432 /// bump pointer allocator.  This assumes the elements are all PODs.
433 template <typename T>
434 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
435                                     ArrayRef<T> elements) {
436   auto result = allocator.Allocate<T>(elements.size());
437   std::uninitialized_copy(elements.begin(), elements.end(), result);
438   return ArrayRef<T>(result, elements.size());
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // Debugging
443 //===----------------------------------------------------------------------===//
444 
445 DebugActionManager &MLIRContext::getDebugActionManager() {
446   return getImpl().debugActionManager;
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // Diagnostic Handlers
451 //===----------------------------------------------------------------------===//
452 
453 /// Returns the diagnostic engine for this context.
454 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
455 
456 //===----------------------------------------------------------------------===//
457 // Dialect and Operation Registration
458 //===----------------------------------------------------------------------===//
459 
460 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
461   registry.appendTo(impl->dialectsRegistry);
462 
463   // For the already loaded dialects, register the interfaces immediately.
464   for (const auto &kvp : impl->loadedDialects)
465     registry.registerDelayedInterfaces(kvp.second.get());
466 }
467 
468 const DialectRegistry &MLIRContext::getDialectRegistry() {
469   return impl->dialectsRegistry;
470 }
471 
472 /// Return information about all registered IR dialects.
473 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
474   std::vector<Dialect *> result;
475   result.reserve(impl->loadedDialects.size());
476   for (auto &dialect : impl->loadedDialects)
477     result.push_back(dialect.second.get());
478   llvm::array_pod_sort(result.begin(), result.end(),
479                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
480                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
481                        });
482   return result;
483 }
484 std::vector<StringRef> MLIRContext::getAvailableDialects() {
485   std::vector<StringRef> result;
486   for (auto dialect : impl->dialectsRegistry.getDialectNames())
487     result.push_back(dialect);
488   return result;
489 }
490 
491 /// Get a registered IR dialect with the given namespace. If none is found,
492 /// then return nullptr.
493 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
494   // Dialects are sorted by name, so we can use binary search for lookup.
495   auto it = impl->loadedDialects.find(name);
496   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
497 }
498 
499 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
500   Dialect *dialect = getLoadedDialect(name);
501   if (dialect)
502     return dialect;
503   DialectAllocatorFunctionRef allocator =
504       impl->dialectsRegistry.getDialectAllocator(name);
505   return allocator ? allocator(this) : nullptr;
506 }
507 
508 /// Get a dialect for the provided namespace and TypeID: abort the program if a
509 /// dialect exist for this namespace with different TypeID. Returns a pointer to
510 /// the dialect owned by the context.
511 Dialect *
512 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
513                               function_ref<std::unique_ptr<Dialect>()> ctor) {
514   auto &impl = getImpl();
515   // Get the correct insertion position sorted by namespace.
516   std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
517 
518   if (!dialect) {
519     LLVM_DEBUG(llvm::dbgs()
520                << "Load new dialect in Context " << dialectNamespace << "\n");
521 #ifndef NDEBUG
522     if (impl.multiThreadedExecutionContext != 0)
523       llvm::report_fatal_error(
524           "Loading a dialect (" + dialectNamespace +
525           ") while in a multi-threaded execution context (maybe "
526           "the PassManager): this can indicate a "
527           "missing `dependentDialects` in a pass for example.");
528 #endif
529     dialect = ctor();
530     assert(dialect && "dialect ctor failed");
531 
532     // Refresh all the identifiers dialect field, this catches cases where a
533     // dialect may be loaded after identifier prefixed with this dialect name
534     // were already created.
535     auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
536     if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
537       for (StringAttrStorage *storage : stringAttrsIt->second)
538         storage->referencedDialect = dialect.get();
539       impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
540     }
541 
542     // Actually register the interfaces with delayed registration.
543     impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
544     return dialect.get();
545   }
546 
547   // Abort if dialect with namespace has already been registered.
548   if (dialect->getTypeID() != dialectID)
549     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
550                              "' has already been registered");
551 
552   return dialect.get();
553 }
554 
555 void MLIRContext::loadAllAvailableDialects() {
556   for (StringRef name : getAvailableDialects())
557     getOrLoadDialect(name);
558 }
559 
560 llvm::hash_code MLIRContext::getRegistryHash() {
561   llvm::hash_code hash(0);
562   // Factor in number of loaded dialects, attributes, operations, types.
563   hash = llvm::hash_combine(hash, impl->loadedDialects.size());
564   hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
565   hash = llvm::hash_combine(hash, impl->registeredOperations.size());
566   hash = llvm::hash_combine(hash, impl->registeredTypes.size());
567   return hash;
568 }
569 
570 bool MLIRContext::allowsUnregisteredDialects() {
571   return impl->allowUnregisteredDialects;
572 }
573 
574 void MLIRContext::allowUnregisteredDialects(bool allowing) {
575   impl->allowUnregisteredDialects = allowing;
576 }
577 
578 /// Return true if multi-threading is enabled by the context.
579 bool MLIRContext::isMultithreadingEnabled() {
580   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
581 }
582 
583 /// Set the flag specifying if multi-threading is disabled by the context.
584 void MLIRContext::disableMultithreading(bool disable) {
585   // This API can be overridden by the global debugging flag
586   // --mlir-disable-threading
587   if (isThreadingGloballyDisabled())
588     return;
589 
590   impl->threadingIsEnabled = !disable;
591 
592   // Update the threading mode for each of the uniquers.
593   impl->affineUniquer.disableMultithreading(disable);
594   impl->attributeUniquer.disableMultithreading(disable);
595   impl->typeUniquer.disableMultithreading(disable);
596 
597   // Destroy thread pool (stop all threads) if it is no longer needed, or create
598   // a new one if multithreading was re-enabled.
599   if (disable) {
600     // If the thread pool is owned, explicitly set it to nullptr to avoid
601     // keeping a dangling pointer around. If the thread pool is externally
602     // owned, we don't do anything.
603     if (impl->ownedThreadPool) {
604       assert(impl->threadPool);
605       impl->threadPool = nullptr;
606       impl->ownedThreadPool.reset();
607     }
608   } else if (!impl->threadPool) {
609     // The thread pool isn't externally provided.
610     assert(!impl->ownedThreadPool);
611     impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
612     impl->threadPool = impl->ownedThreadPool.get();
613   }
614 }
615 
616 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
617   assert(!isMultithreadingEnabled() &&
618          "expected multi-threading to be disabled when setting a ThreadPool");
619   impl->threadPool = &pool;
620   impl->ownedThreadPool.reset();
621   enableMultithreading();
622 }
623 
624 llvm::ThreadPool &MLIRContext::getThreadPool() {
625   assert(isMultithreadingEnabled() &&
626          "expected multi-threading to be enabled within the context");
627   assert(impl->threadPool &&
628          "multi-threading is enabled but threadpool not set");
629   return *impl->threadPool;
630 }
631 
632 void MLIRContext::enterMultiThreadedExecution() {
633 #ifndef NDEBUG
634   ++impl->multiThreadedExecutionContext;
635 #endif
636 }
637 void MLIRContext::exitMultiThreadedExecution() {
638 #ifndef NDEBUG
639   --impl->multiThreadedExecutionContext;
640 #endif
641 }
642 
643 /// Return true if we should attach the operation to diagnostics emitted via
644 /// Operation::emit.
645 bool MLIRContext::shouldPrintOpOnDiagnostic() {
646   return impl->printOpOnDiagnostic;
647 }
648 
649 /// Set the flag specifying if we should attach the operation to diagnostics
650 /// emitted via Operation::emit.
651 void MLIRContext::printOpOnDiagnostic(bool enable) {
652   impl->printOpOnDiagnostic = enable;
653 }
654 
655 /// Return true if we should attach the current stacktrace to diagnostics when
656 /// emitted.
657 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
658   return impl->printStackTraceOnDiagnostic;
659 }
660 
661 /// Set the flag specifying if we should attach the current stacktrace when
662 /// emitting diagnostics.
663 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
664   impl->printStackTraceOnDiagnostic = enable;
665 }
666 
667 /// Return information about all registered operations.  This isn't very
668 /// efficient, typically you should ask the operations about their properties
669 /// directly.
670 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
671   // We just have the operations in a non-deterministic hash table order. Dump
672   // into a temporary array, then sort it by operation name to get a stable
673   // ordering.
674   llvm::StringMap<AbstractOperation> &registeredOps =
675       impl->registeredOperations;
676 
677   std::vector<AbstractOperation *> result;
678   result.reserve(registeredOps.size());
679   for (auto &elt : registeredOps)
680     result.push_back(&elt.second);
681   llvm::array_pod_sort(
682       result.begin(), result.end(),
683       [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
684         return (*lhs)->name.compare((*rhs)->name);
685       });
686 
687   return result;
688 }
689 
690 bool MLIRContext::isOperationRegistered(StringRef name) {
691   return impl->registeredOperations.count(name);
692 }
693 
694 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
695   auto &impl = context->getImpl();
696   assert(impl.multiThreadedExecutionContext == 0 &&
697          "Registering a new type kind while in a multi-threaded execution "
698          "context");
699   auto *newInfo =
700       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
701           AbstractType(std::move(typeInfo));
702   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
703     llvm::report_fatal_error("Dialect Type already registered.");
704 }
705 
706 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
707   auto &impl = context->getImpl();
708   assert(impl.multiThreadedExecutionContext == 0 &&
709          "Registering a new attribute kind while in a multi-threaded execution "
710          "context");
711   auto *newInfo =
712       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
713           AbstractAttribute(std::move(attrInfo));
714   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
715     llvm::report_fatal_error("Dialect Attribute already registered.");
716 }
717 
718 //===----------------------------------------------------------------------===//
719 // AbstractAttribute
720 //===----------------------------------------------------------------------===//
721 
722 /// Get the dialect that registered the attribute with the provided typeid.
723 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
724                                                    MLIRContext *context) {
725   const AbstractAttribute *abstract = lookupMutable(typeID, context);
726   if (!abstract)
727     llvm::report_fatal_error("Trying to create an Attribute that was not "
728                              "registered in this MLIRContext.");
729   return *abstract;
730 }
731 
732 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
733                                                     MLIRContext *context) {
734   auto &impl = context->getImpl();
735   auto it = impl.registeredAttributes.find(typeID);
736   if (it == impl.registeredAttributes.end())
737     return nullptr;
738   return it->second;
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // AbstractOperation
743 //===----------------------------------------------------------------------===//
744 
745 ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser,
746                                              OperationState &result) const {
747   return parseAssemblyFn(parser, result);
748 }
749 
750 /// Look up the specified operation in the operation set and return a pointer
751 /// to it if present. Otherwise, return a null pointer.
752 AbstractOperation *AbstractOperation::lookupMutable(StringRef opName,
753                                                     MLIRContext *context) {
754   auto &impl = context->getImpl();
755   auto it = impl.registeredOperations.find(opName);
756   if (it != impl.registeredOperations.end())
757     return &it->second;
758   return nullptr;
759 }
760 
/// Registers a new operation kind named `name` that belongs to `dialect`,
/// recording its hooks, interface map, trait query and attribute names in the
/// dialect's context.
///
/// Each entry of `attrNames` is turned into a StringAttr once, here, and the
/// resulting array lives in memory taken from the context's
/// `abstractDialectSymbolAllocator`. Registering an operation name that is
/// already registered prints an error to stderr and aborts the process.
void AbstractOperation::insert(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<StringRef> attrNames) {
  MLIRContext *ctx = dialect.getContext();
  auto &impl = ctx->getImpl();
  // This function mutates `registeredOperations` without taking a lock, so
  // registration must not happen during multi-threaded execution (checked
  // only in builds where `assert` is enabled).
  assert(impl.multiThreadedExecutionContext == 0 &&
         "Registering a new operation kind while in a multi-threaded execution "
         "context");

  // Register the attribute names of this operation.
  // The array is raw allocator memory, so every element is constructed in
  // place with placement new instead of being assigned.
  MutableArrayRef<StringAttr> cachedAttrNames;
  if (!attrNames.empty()) {
    cachedAttrNames = MutableArrayRef<StringAttr>(
        impl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
            attrNames.size()),
        attrNames.size());
    for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
      new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
  }

  // Register the information for this operation.
  AbstractOperation opInfo(
      name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly),
      std::move(verifyInvariants), std::move(foldHook),
      std::move(getCanonicalizationPatterns), std::move(interfaceMap),
      std::move(hasTrait), cachedAttrNames);
  // `insert(...).second` is false when the name was already present; a
  // duplicate registration is treated as a fatal programming error.
  // NOTE(review): the attribute-name storage above has already been allocated
  // by the time a duplicate is detected; harmless only because we abort.
  if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
    llvm::errs() << "error: operation named '" << name
                 << "' is already registered.\n";
    abort();
  }
}
797 
/// Constructs the abstract description of an operation kind. The operation
/// name is uniqued as a StringAttr in the dialect's context, and every hook
/// callback plus the interface map is moved into the new instance.
/// `attrNames` is stored in `attributeNames`.
/// NOTE(review): if `attributeNames` is a non-owning view (e.g. ArrayRef), the
/// caller must keep the backing storage alive; `insert()` allocates it from
/// the context's allocator for that reason. Confirm against the header.
AbstractOperation::AbstractOperation(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<StringAttr> attrNames)
    : name(StringAttr::get(dialect.getContext(), name)), dialect(dialect),
      typeID(typeID), interfaceMap(std::move(interfaceMap)),
      foldHookFn(std::move(foldHook)),
      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
      hasTraitFn(std::move(hasTrait)),
      parseAssemblyFn(std::move(parseAssembly)),
      printAssemblyFn(std::move(printAssembly)),
      verifyInvariantsFn(std::move(verifyInvariants)),
      attributeNames(attrNames) {}
814 
815 //===----------------------------------------------------------------------===//
816 // AbstractType
817 //===----------------------------------------------------------------------===//
818 
819 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
820   const AbstractType *type = lookupMutable(typeID, context);
821   if (!type)
822     llvm::report_fatal_error(
823         "Trying to create a Type that was not registered in this MLIRContext.");
824   return *type;
825 }
826 
827 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
828   auto &impl = context->getImpl();
829   auto it = impl.registeredTypes.find(typeID);
830   if (it == impl.registeredTypes.end())
831     return nullptr;
832   return it->second;
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // Type uniquing
837 //===----------------------------------------------------------------------===//
838 
839 /// Returns the storage uniquer used for constructing type storage instances.
840 /// This should not be used directly.
841 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
842 
843 BFloat16Type BFloat16Type::get(MLIRContext *context) {
844   return context->getImpl().bf16Ty;
845 }
846 Float16Type Float16Type::get(MLIRContext *context) {
847   return context->getImpl().f16Ty;
848 }
849 Float32Type Float32Type::get(MLIRContext *context) {
850   return context->getImpl().f32Ty;
851 }
852 Float64Type Float64Type::get(MLIRContext *context) {
853   return context->getImpl().f64Ty;
854 }
855 Float80Type Float80Type::get(MLIRContext *context) {
856   return context->getImpl().f80Ty;
857 }
858 Float128Type Float128Type::get(MLIRContext *context) {
859   return context->getImpl().f128Ty;
860 }
861 
862 /// Get an instance of the IndexType.
863 IndexType IndexType::get(MLIRContext *context) {
864   return context->getImpl().indexTy;
865 }
866 
867 /// Return an existing integer type instance if one is cached within the
868 /// context.
869 static IntegerType
870 getCachedIntegerType(unsigned width,
871                      IntegerType::SignednessSemantics signedness,
872                      MLIRContext *context) {
873   if (signedness != IntegerType::Signless)
874     return IntegerType();
875 
876   switch (width) {
877   case 1:
878     return context->getImpl().int1Ty;
879   case 8:
880     return context->getImpl().int8Ty;
881   case 16:
882     return context->getImpl().int16Ty;
883   case 32:
884     return context->getImpl().int32Ty;
885   case 64:
886     return context->getImpl().int64Ty;
887   case 128:
888     return context->getImpl().int128Ty;
889   default:
890     return IntegerType();
891   }
892 }
893 
894 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
895                              IntegerType::SignednessSemantics signedness) {
896   if (auto cached = getCachedIntegerType(width, signedness, context))
897     return cached;
898   return Base::get(context, width, signedness);
899 }
900 
901 IntegerType
902 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
903                         MLIRContext *context, unsigned width,
904                         SignednessSemantics signedness) {
905   if (auto cached = getCachedIntegerType(width, signedness, context))
906     return cached;
907   return Base::getChecked(emitError, context, width, signedness);
908 }
909 
910 /// Get an instance of the NoneType.
911 NoneType NoneType::get(MLIRContext *context) {
912   if (NoneType cachedInst = context->getImpl().noneType)
913     return cachedInst;
914   // Note: May happen when initializing the singleton attributes of the builtin
915   // dialect.
916   return Base::get(context);
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // Attribute uniquing
921 //===----------------------------------------------------------------------===//
922 
923 /// Returns the storage uniquer used for constructing attribute storage
924 /// instances. This should not be used directly.
925 StorageUniquer &MLIRContext::getAttributeUniquer() {
926   return getImpl().attributeUniquer;
927 }
928 
929 /// Initialize the given attribute storage instance.
930 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
931                                                   MLIRContext *ctx,
932                                                   TypeID attrID) {
933   storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
934 
935   // If the attribute did not provide a type, then default to NoneType.
936   if (!storage->getType())
937     storage->setType(NoneType::get(ctx));
938 }
939 
940 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
941   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
942 }
943 
944 UnitAttr UnitAttr::get(MLIRContext *context) {
945   return context->getImpl().unitAttr;
946 }
947 
948 UnknownLoc UnknownLoc::get(MLIRContext *context) {
949   return context->getImpl().unknownLocAttr;
950 }
951 
952 /// Return empty dictionary.
953 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
954   return context->getImpl().emptyDictionaryAttr;
955 }
956 
/// Finishes initialization of a newly created StringAttr storage. When the
/// string contains a '.' with non-empty text on both sides of the first one,
/// the text before that '.' is treated as a dialect namespace. If that dialect
/// is already loaded in `context`, it is recorded in `referencedDialect`.
/// Otherwise this storage is queued in `dialectReferencingStrAttrs` under the
/// namespace so that it can be resolved later.
void StringAttrStorage::initialize(MLIRContext *context) {
  // Check for a dialect namespace prefix, if there isn't one we don't need to
  // do any additional initialization.
  auto dialectNamePair = value.split('.');
  if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
    return;

  // If one exists, we check to see if this dialect is loaded. If it is, we set
  // the dialect now, if it isn't we record this storage for initialization
  // later if the dialect ever gets loaded.
  // The assignment inside the condition is intentional: `referencedDialect`
  // receives the loaded dialect, or nullptr when it is not loaded.
  if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
    return;

  // The pending-string table is shared context state, so it is only modified
  // while holding `dialectRefStrAttrMutex`.
  // NOTE(review): the loaded-dialect check above happens before this lock is
  // taken; confirm the dialect-loading path also handles strings enqueued
  // right after it drains this table.
  MLIRContextImpl &impl = context->getImpl();
  llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
  impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
}
974 
975 /// Return an empty string.
976 StringAttr StringAttr::get(MLIRContext *context) {
977   return context->getImpl().emptyStringAttr;
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // AffineMap uniquing
982 //===----------------------------------------------------------------------===//
983 
984 StorageUniquer &MLIRContext::getAffineUniquer() {
985   return getImpl().affineUniquer;
986 }
987 
988 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
989                              ArrayRef<AffineExpr> results,
990                              MLIRContext *context) {
991   auto &impl = context->getImpl();
992   auto key = std::make_tuple(dimCount, symbolCount, results);
993 
994   // Safely get or create an AffineMap instance.
995   return safeGetOrCreate(
996       impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
997         auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
998 
999         // Copy the results into the bump pointer.
1000         results = copyArrayRefInto(impl.affineAllocator, results);
1001 
1002         // Initialize the memory using placement new.
1003         new (res)
1004             detail::AffineMapStorage{dimCount, symbolCount, results, context};
1005         return AffineMap(res);
1006       });
1007 }
1008 
1009 AffineMap AffineMap::get(MLIRContext *context) {
1010   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
1011 }
1012 
1013 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1014                          MLIRContext *context) {
1015   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
1016 }
1017 
1018 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1019                          AffineExpr result) {
1020   return getImpl(dimCount, symbolCount, {result}, result.getContext());
1021 }
1022 
1023 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1024                          ArrayRef<AffineExpr> results, MLIRContext *context) {
1025   return getImpl(dimCount, symbolCount, results, context);
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
// Integer Sets: these are allocated in the bump pointer allocator and are
// immutable. Unlike AffineMaps, these are uniqued only if they are small.
1031 //===----------------------------------------------------------------------===//
1032 
/// Returns an integer set over `dimCount` dimensions and `symbolCount`
/// symbols defined by `constraints`. `eqFlags` must have one entry per
/// constraint (presumably marking equality vs. inequality constraints).
/// The context is taken from the first constraint, so at least one constraint
/// is required. Sets with fewer than `IntegerSet::kUniquingThreshold`
/// constraints are uniqued in the context; larger sets are freshly allocated
/// on every call.
IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
                           ArrayRef<AffineExpr> constraints,
                           ArrayRef<bool> eqFlags) {
  // The number of constraints can't be zero.
  assert(!constraints.empty());
  assert(constraints.size() == eqFlags.size());

  auto &impl = constraints[0].getContext()->getImpl();

  // A utility function to construct a new IntegerSetStorage instance.
  auto constructorFn = [&] {
    auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>();

    // Copy the constraints and equality flags into the bump pointer allocator;
    // rebinding the captured parameters makes the storage refer to these
    // allocator-owned copies instead of the caller's arrays.
    constraints = copyArrayRefInto(impl.affineAllocator, constraints);
    eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags);

    // Initialize the memory using placement new.
    new (res)
        detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
    return IntegerSet(res);
  };

  // If this instance is uniqued, then we handle it separately so that multiple
  // threads may simultaneously access existing instances.
  if (constraints.size() < IntegerSet::kUniquingThreshold) {
    auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
    return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
                           impl.threadingIsEnabled, constructorFn);
  }

  // Otherwise, acquire a writer-lock so that we can safely create the new
  // instance.
  ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
  return constructorFn();
}
1069 
1070 //===----------------------------------------------------------------------===//
1071 // StorageUniquerSupport
1072 //===----------------------------------------------------------------------===//
1073 
1074 /// Utility method to generate a callback that can be used to generate a
1075 /// diagnostic when checking the construction invariants of a storage object.
1076 /// This is defined out-of-line to avoid the need to include Location.h.
1077 llvm::unique_function<InFlightDiagnostic()>
1078 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1079   return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1080 }
1081 llvm::unique_function<InFlightDiagnostic()>
1082 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1083   return [=] { return emitError(loc); };
1084 }
1085