1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Support/DebugAction.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Support/Allocator.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/Mutex.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/ThreadPool.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40 
41 #define DEBUG_TYPE "mlircontext"
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 using llvm::hash_combine;
47 using llvm::hash_combine_range;
48 
49 //===----------------------------------------------------------------------===//
50 // MLIRContext CommandLine Options
51 //===----------------------------------------------------------------------===//
52 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options: the options only exist once an instance of
/// this struct has been constructed (see registerMLIRContextCLOptions).
struct MLIRContextOptions {
  /// `--mlir-disable-threading`: when set, new contexts are created with
  /// multi-threading disabled and MLIRContext::disableMultithreading() becomes
  /// a no-op.
  llvm::cl::opt<bool> disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
                     "further call to MLIRContext::enableMultiThreading()")};

  /// `--mlir-print-op-on-diagnostic`: copied into each new context through
  /// MLIRContext::printOpOnDiagnostic(); defaults to true.
  llvm::cl::opt<bool> printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  /// `--mlir-print-stacktrace-on-diagnostic`: copied into each new context
  /// through MLIRContext::printStackTraceOnDiagnostic(). No `cl::init` is
  /// given, so cl::opt's default value is used.
  llvm::cl::opt<bool> printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // end anonymous namespace
75 
/// Lazily-constructed storage for the MLIRContext command line options. It is
/// constructed by registerMLIRContextCLOptions(); readers check
/// `isConstructed()` before touching it so that unregistered options are
/// simply ignored.
static llvm::ManagedStatic<MLIRContextOptions> clOptions;
77 
/// Returns true when multi-threading is off for every context: either LLVM was
/// built without thread support, or `--mlir-disable-threading` was registered
/// and set on the command line.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  // Options that were never registered cannot have been set.
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
85 
86 /// Register a set of useful command-line options that can be used to configure
87 /// various flags within the MLIRContext. These flags are used when constructing
88 /// an MLIR context for initialization.
89 void mlir::registerMLIRContextCLOptions() {
90   // Make sure that the options struct has been initialized.
91   *clOptions;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Locking Utilities
96 //===----------------------------------------------------------------------===//
97 
98 namespace {
99 /// Utility writer lock that takes a runtime flag that specifies if we really
100 /// need to lock.
101 struct ScopedWriterLock {
102   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
103       : mutex(shouldLock ? &mutexParam : nullptr) {
104     if (mutex)
105       mutex->lock();
106   }
107   ~ScopedWriterLock() {
108     if (mutex)
109       mutex->unlock();
110   }
111   llvm::sys::SmartRWMutex<true> *mutex;
112 };
113 } // end anonymous namespace.
114 
115 //===----------------------------------------------------------------------===//
116 // MLIRContextImpl
117 //===----------------------------------------------------------------------===//
118 
119 namespace mlir {
120 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
121 /// This class is completely private to this file, so everything is public.
122 class MLIRContextImpl {
123 public:
124   //===--------------------------------------------------------------------===//
125   // Debugging
126   //===--------------------------------------------------------------------===//
127 
128   /// An action manager for use within the context.
129   DebugActionManager debugActionManager;
130 
131   //===--------------------------------------------------------------------===//
132   // Diagnostics
133   //===--------------------------------------------------------------------===//
134   DiagnosticEngine diagEngine;
135 
136   //===--------------------------------------------------------------------===//
137   // Options
138   //===--------------------------------------------------------------------===//
139 
140   /// In most cases, creating operation in unregistered dialect is not desired
141   /// and indicate a misconfiguration of the compiler. This option enables to
142   /// detect such use cases
143   bool allowUnregisteredDialects = false;
144 
145   /// Enable support for multi-threading within MLIR.
146   bool threadingIsEnabled = true;
147 
148   /// Track if we are currently executing in a threaded execution environment
149   /// (like the pass-manager): this is only a debugging feature to help reducing
150   /// the chances of data races one some context APIs.
151 #ifndef NDEBUG
152   std::atomic<int> multiThreadedExecutionContext{0};
153 #endif
154 
155   /// If the operation should be attached to diagnostics printed via the
156   /// Operation::emit methods.
157   bool printOpOnDiagnostic = true;
158 
159   /// If the current stack trace should be attached when emitting diagnostics.
160   bool printStackTraceOnDiagnostic = false;
161 
162   //===--------------------------------------------------------------------===//
163   // Other
164   //===--------------------------------------------------------------------===//
165 
166   /// This points to the ThreadPool used when processing MLIR tasks in parallel.
167   /// It can't be nullptr when multi-threading is enabled. Otherwise if
168   /// multi-threading is disabled, and the threadpool wasn't externally provided
169   /// using `setThreadPool`, this will be nullptr.
170   llvm::ThreadPool *threadPool = nullptr;
171 
172   /// In case where the thread pool is owned by the context, this ensures
173   /// destruction with the context.
174   std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
175 
176   /// This is a list of dialects that are created referring to this context.
177   /// The MLIRContext owns the objects.
178   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
179   DialectRegistry dialectsRegistry;
180 
181   /// An allocator used for AbstractAttribute and AbstractType objects.
182   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
183 
184   /// This is a mapping from operation name to the operation info describing it.
185   llvm::StringMap<OperationName::Impl> operations;
186 
187   /// A vector of operation info specifically for registered operations.
188   SmallVector<RegisteredOperationName> registeredOperations;
189 
190   /// A mutex used when accessing operation information.
191   llvm::sys::SmartRWMutex<true> operationInfoMutex;
192 
193   //===--------------------------------------------------------------------===//
194   // Affine uniquing
195   //===--------------------------------------------------------------------===//
196 
197   // Affine expression, map and integer set uniquing.
198   StorageUniquer affineUniquer;
199 
200   //===--------------------------------------------------------------------===//
201   // Type uniquing
202   //===--------------------------------------------------------------------===//
203 
204   DenseMap<TypeID, AbstractType *> registeredTypes;
205   StorageUniquer typeUniquer;
206 
207   /// Cached Type Instances.
208   BFloat16Type bf16Ty;
209   Float16Type f16Ty;
210   Float32Type f32Ty;
211   Float64Type f64Ty;
212   Float80Type f80Ty;
213   Float128Type f128Ty;
214   IndexType indexTy;
215   IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
216   NoneType noneType;
217 
218   //===--------------------------------------------------------------------===//
219   // Attribute uniquing
220   //===--------------------------------------------------------------------===//
221 
222   DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
223   StorageUniquer attributeUniquer;
224 
225   /// Cached Attribute Instances.
226   BoolAttr falseAttr, trueAttr;
227   UnitAttr unitAttr;
228   UnknownLoc unknownLocAttr;
229   DictionaryAttr emptyDictionaryAttr;
230   StringAttr emptyStringAttr;
231 
232   /// Map of string attributes that may reference a dialect, that are awaiting
233   /// that dialect to be loaded.
234   llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
235   DenseMap<StringRef, SmallVector<StringAttrStorage *>>
236       dialectReferencingStrAttrs;
237 
238 public:
239   MLIRContextImpl(bool threadingIsEnabled)
240       : threadingIsEnabled(threadingIsEnabled) {
241     if (threadingIsEnabled) {
242       ownedThreadPool = std::make_unique<llvm::ThreadPool>();
243       threadPool = ownedThreadPool.get();
244     }
245   }
246   ~MLIRContextImpl() {
247     for (auto typeMapping : registeredTypes)
248       typeMapping.second->~AbstractType();
249     for (auto attrMapping : registeredAttributes)
250       attrMapping.second->~AbstractAttribute();
251   }
252 };
253 } // end namespace mlir
254 
/// Create a context that starts with an empty dialect registry; this delegates
/// to the registry-taking constructor.
MLIRContext::MLIRContext(Threading setting)
    : MLIRContext(DialectRegistry(), setting) {}
257 
/// Create a context whose dialect registry starts out with the contents of
/// `registry`. Multi-threading is enabled only if `setting` is
/// Threading::ENABLED and threading has not been disabled globally (see
/// isThreadingGloballyDisabled). The builtin dialect is always loaded, and a
/// set of common builtin types and attributes is created up front and cached.
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
    : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
                               !isThreadingGloballyDisabled())) {
  // Initialize values based on the command line flags if they were provided.
  if (clOptions.isConstructed()) {
    printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
    printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
  }

  // Pre-populate the registry.
  registry.appendTo(impl->dialectsRegistry);

  // Ensure the builtin dialect is always pre-loaded. This happens before the
  // cached types/attributes below are created; presumably the builtin dialect
  // is what registers those kinds with this context (NOTE(review): confirm).
  getOrLoadDialect<BuiltinDialect>();

  // Initialize several common attributes and types to avoid the need to lock
  // the context when accessing them.

  //// Types.
  /// Floating-point Types.
  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
  impl->f16Ty = TypeUniquer::get<Float16Type>(this);
  impl->f32Ty = TypeUniquer::get<Float32Type>(this);
  impl->f64Ty = TypeUniquer::get<Float64Type>(this);
  impl->f80Ty = TypeUniquer::get<Float80Type>(this);
  impl->f128Ty = TypeUniquer::get<Float128Type>(this);
  /// Index Type.
  impl->indexTy = TypeUniquer::get<IndexType>(this);
  /// Integer Types (signless only; see getCachedIntegerType).
  impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
  impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
  impl->int16Ty =
      TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
  impl->int32Ty =
      TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
  impl->int64Ty =
      TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
  impl->int128Ty =
      TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
  /// None Type.
  impl->noneType = TypeUniquer::get<NoneType>(this);

  //// Attributes.
  //// Note: These must be registered after the types as they may generate one
  //// of the above types internally (e.g. the bool attributes use int1Ty).
  /// Unknown Location Attribute.
  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
  /// Bool Attributes.
  impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
  impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
  /// Unit Attribute.
  impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
  /// The empty dictionary attribute.
  impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
  /// The empty string attribute.
  impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);

  // Register the affine storage objects with the uniquer.
  impl->affineUniquer
      .registerParametricStorageType<AffineBinaryOpExprStorage>();
  impl->affineUniquer
      .registerParametricStorageType<AffineConstantExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineMapStorage>();
  impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
}
324 
325 MLIRContext::~MLIRContext() {}
326 
327 /// Copy the specified array of elements into memory managed by the provided
328 /// bump pointer allocator.  This assumes the elements are all PODs.
329 template <typename T>
330 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
331                                     ArrayRef<T> elements) {
332   auto result = allocator.Allocate<T>(elements.size());
333   std::uninitialized_copy(elements.begin(), elements.end(), result);
334   return ArrayRef<T>(result, elements.size());
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // Debugging
339 //===----------------------------------------------------------------------===//
340 
341 DebugActionManager &MLIRContext::getDebugActionManager() {
342   return getImpl().debugActionManager;
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // Diagnostic Handlers
347 //===----------------------------------------------------------------------===//
348 
349 /// Returns the diagnostic engine for this context.
350 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
351 
352 //===----------------------------------------------------------------------===//
353 // Dialect and Operation Registration
354 //===----------------------------------------------------------------------===//
355 
356 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
357   registry.appendTo(impl->dialectsRegistry);
358 
359   // For the already loaded dialects, register the interfaces immediately.
360   for (const auto &kvp : impl->loadedDialects)
361     registry.registerDelayedInterfaces(kvp.second.get());
362 }
363 
364 const DialectRegistry &MLIRContext::getDialectRegistry() {
365   return impl->dialectsRegistry;
366 }
367 
368 /// Return information about all registered IR dialects.
369 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
370   std::vector<Dialect *> result;
371   result.reserve(impl->loadedDialects.size());
372   for (auto &dialect : impl->loadedDialects)
373     result.push_back(dialect.second.get());
374   llvm::array_pod_sort(result.begin(), result.end(),
375                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
376                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
377                        });
378   return result;
379 }
380 std::vector<StringRef> MLIRContext::getAvailableDialects() {
381   std::vector<StringRef> result;
382   for (auto dialect : impl->dialectsRegistry.getDialectNames())
383     result.push_back(dialect);
384   return result;
385 }
386 
387 /// Get a registered IR dialect with the given namespace. If none is found,
388 /// then return nullptr.
389 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
390   // Dialects are sorted by name, so we can use binary search for lookup.
391   auto it = impl->loadedDialects.find(name);
392   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
393 }
394 
395 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
396   Dialect *dialect = getLoadedDialect(name);
397   if (dialect)
398     return dialect;
399   DialectAllocatorFunctionRef allocator =
400       impl->dialectsRegistry.getDialectAllocator(name);
401   return allocator ? allocator(this) : nullptr;
402 }
403 
404 /// Get a dialect for the provided namespace and TypeID: abort the program if a
405 /// dialect exist for this namespace with different TypeID. Returns a pointer to
406 /// the dialect owned by the context.
407 Dialect *
408 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
409                               function_ref<std::unique_ptr<Dialect>()> ctor) {
410   auto &impl = getImpl();
411   // Get the correct insertion position sorted by namespace.
412   std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
413 
414   if (!dialect) {
415     LLVM_DEBUG(llvm::dbgs()
416                << "Load new dialect in Context " << dialectNamespace << "\n");
417 #ifndef NDEBUG
418     if (impl.multiThreadedExecutionContext != 0)
419       llvm::report_fatal_error(
420           "Loading a dialect (" + dialectNamespace +
421           ") while in a multi-threaded execution context (maybe "
422           "the PassManager): this can indicate a "
423           "missing `dependentDialects` in a pass for example.");
424 #endif
425     dialect = ctor();
426     assert(dialect && "dialect ctor failed");
427 
428     // Refresh all the identifiers dialect field, this catches cases where a
429     // dialect may be loaded after identifier prefixed with this dialect name
430     // were already created.
431     auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
432     if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
433       for (StringAttrStorage *storage : stringAttrsIt->second)
434         storage->referencedDialect = dialect.get();
435       impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
436     }
437 
438     // Actually register the interfaces with delayed registration.
439     impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
440     return dialect.get();
441   }
442 
443   // Abort if dialect with namespace has already been registered.
444   if (dialect->getTypeID() != dialectID)
445     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
446                              "' has already been registered");
447 
448   return dialect.get();
449 }
450 
451 void MLIRContext::loadAllAvailableDialects() {
452   for (StringRef name : getAvailableDialects())
453     getOrLoadDialect(name);
454 }
455 
456 llvm::hash_code MLIRContext::getRegistryHash() {
457   llvm::hash_code hash(0);
458   // Factor in number of loaded dialects, attributes, operations, types.
459   hash = llvm::hash_combine(hash, impl->loadedDialects.size());
460   hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
461   hash = llvm::hash_combine(hash, impl->registeredOperations.size());
462   hash = llvm::hash_combine(hash, impl->registeredTypes.size());
463   return hash;
464 }
465 
466 bool MLIRContext::allowsUnregisteredDialects() {
467   return impl->allowUnregisteredDialects;
468 }
469 
470 void MLIRContext::allowUnregisteredDialects(bool allowing) {
471   impl->allowUnregisteredDialects = allowing;
472 }
473 
474 /// Return true if multi-threading is enabled by the context.
475 bool MLIRContext::isMultithreadingEnabled() {
476   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
477 }
478 
479 /// Set the flag specifying if multi-threading is disabled by the context.
480 void MLIRContext::disableMultithreading(bool disable) {
481   // This API can be overridden by the global debugging flag
482   // --mlir-disable-threading
483   if (isThreadingGloballyDisabled())
484     return;
485 
486   impl->threadingIsEnabled = !disable;
487 
488   // Update the threading mode for each of the uniquers.
489   impl->affineUniquer.disableMultithreading(disable);
490   impl->attributeUniquer.disableMultithreading(disable);
491   impl->typeUniquer.disableMultithreading(disable);
492 
493   // Destroy thread pool (stop all threads) if it is no longer needed, or create
494   // a new one if multithreading was re-enabled.
495   if (disable) {
496     // If the thread pool is owned, explicitly set it to nullptr to avoid
497     // keeping a dangling pointer around. If the thread pool is externally
498     // owned, we don't do anything.
499     if (impl->ownedThreadPool) {
500       assert(impl->threadPool);
501       impl->threadPool = nullptr;
502       impl->ownedThreadPool.reset();
503     }
504   } else if (!impl->threadPool) {
505     // The thread pool isn't externally provided.
506     assert(!impl->ownedThreadPool);
507     impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
508     impl->threadPool = impl->ownedThreadPool.get();
509   }
510 }
511 
512 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
513   assert(!isMultithreadingEnabled() &&
514          "expected multi-threading to be disabled when setting a ThreadPool");
515   impl->threadPool = &pool;
516   impl->ownedThreadPool.reset();
517   enableMultithreading();
518 }
519 
520 llvm::ThreadPool &MLIRContext::getThreadPool() {
521   assert(isMultithreadingEnabled() &&
522          "expected multi-threading to be enabled within the context");
523   assert(impl->threadPool &&
524          "multi-threading is enabled but threadpool not set");
525   return *impl->threadPool;
526 }
527 
528 void MLIRContext::enterMultiThreadedExecution() {
529 #ifndef NDEBUG
530   ++impl->multiThreadedExecutionContext;
531 #endif
532 }
533 void MLIRContext::exitMultiThreadedExecution() {
534 #ifndef NDEBUG
535   --impl->multiThreadedExecutionContext;
536 #endif
537 }
538 
539 /// Return true if we should attach the operation to diagnostics emitted via
540 /// Operation::emit.
541 bool MLIRContext::shouldPrintOpOnDiagnostic() {
542   return impl->printOpOnDiagnostic;
543 }
544 
545 /// Set the flag specifying if we should attach the operation to diagnostics
546 /// emitted via Operation::emit.
547 void MLIRContext::printOpOnDiagnostic(bool enable) {
548   impl->printOpOnDiagnostic = enable;
549 }
550 
551 /// Return true if we should attach the current stacktrace to diagnostics when
552 /// emitted.
553 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
554   return impl->printStackTraceOnDiagnostic;
555 }
556 
557 /// Set the flag specifying if we should attach the current stacktrace when
558 /// emitting diagnostics.
559 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
560   impl->printStackTraceOnDiagnostic = enable;
561 }
562 
563 /// Return information about all registered operations.  This isn't very
564 /// efficient, typically you should ask the operations about their properties
565 /// directly.
566 std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
567   // We just have the operations in a non-deterministic hash table order. Dump
568   // into a temporary array, then sort it by operation name to get a stable
569   // ordering.
570   std::vector<RegisteredOperationName> result(
571       impl->registeredOperations.begin(), impl->registeredOperations.end());
572   llvm::array_pod_sort(result.begin(), result.end(),
573                        [](const RegisteredOperationName *lhs,
574                           const RegisteredOperationName *rhs) {
575                          return lhs->getIdentifier().compare(
576                              rhs->getIdentifier());
577                        });
578 
579   return result;
580 }
581 
582 bool MLIRContext::isOperationRegistered(StringRef name) {
583   return OperationName(name, this).isRegistered();
584 }
585 
586 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
587   auto &impl = context->getImpl();
588   assert(impl.multiThreadedExecutionContext == 0 &&
589          "Registering a new type kind while in a multi-threaded execution "
590          "context");
591   auto *newInfo =
592       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
593           AbstractType(std::move(typeInfo));
594   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
595     llvm::report_fatal_error("Dialect Type already registered.");
596 }
597 
598 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
599   auto &impl = context->getImpl();
600   assert(impl.multiThreadedExecutionContext == 0 &&
601          "Registering a new attribute kind while in a multi-threaded execution "
602          "context");
603   auto *newInfo =
604       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
605           AbstractAttribute(std::move(attrInfo));
606   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
607     llvm::report_fatal_error("Dialect Attribute already registered.");
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // AbstractAttribute
612 //===----------------------------------------------------------------------===//
613 
614 /// Get the dialect that registered the attribute with the provided typeid.
615 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
616                                                    MLIRContext *context) {
617   const AbstractAttribute *abstract = lookupMutable(typeID, context);
618   if (!abstract)
619     llvm::report_fatal_error("Trying to create an Attribute that was not "
620                              "registered in this MLIRContext.");
621   return *abstract;
622 }
623 
624 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
625                                                     MLIRContext *context) {
626   auto &impl = context->getImpl();
627   auto it = impl.registeredAttributes.find(typeID);
628   if (it == impl.registeredAttributes.end())
629     return nullptr;
630   return it->second;
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // OperationName
635 //===----------------------------------------------------------------------===//
636 
/// Find or create the uniqued OperationName::Impl entry for `name` in
/// `context`. A new entry is inserted as `OperationName::Impl(nullptr)` and
/// then gets its `name` attribute set; RegisteredOperationName::insert fills in
/// the remaining fields when the operation is registered.
OperationName::OperationName(StringRef name, MLIRContext *context) {
  MLIRContextImpl &ctxImpl = context->getImpl();

  // Check for an existing name in read-only mode. This shared-lock fast path
  // is only taken when multi-threading is enabled; otherwise the insertion
  // below handles both the "exists" and "new" cases.
  bool isMultithreadingEnabled = context->isMultithreadingEnabled();
  if (isMultithreadingEnabled) {
    llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
    auto it = ctxImpl.operations.find(name);
    if (it != ctxImpl.operations.end()) {
      impl = &it->second;
      return;
    }
  }

  // Acquire a writer-lock so that we can safely create the new instance. The
  // lock is only actually taken when multi-threading is enabled.
  ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);

  // insert() leaves an existing entry untouched (e.g. one created by another
  // thread between the reader and writer locks); only a freshly inserted entry
  // needs its name attribute initialized.
  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
  if (it.second)
    it.first->second.name = StringAttr::get(context, name);
  impl = &it.first->second;
}
659 
660 StringRef OperationName::getDialectNamespace() const {
661   if (Dialect *dialect = getDialect())
662     return dialect->getNamespace();
663   return getStringRef().split('.').first;
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // RegisteredOperationName
668 //===----------------------------------------------------------------------===//
669 
670 ParseResult
671 RegisteredOperationName::parseAssembly(OpAsmParser &parser,
672                                        OperationState &result) const {
673   return impl->parseAssemblyFn(parser, result);
674 }
675 
/// Register an operation named `name` with `dialect`: find or create the
/// OperationName::Impl entry for that name and fill it in with the provided
/// TypeID, hooks, interfaces and cached attribute names. Prints an error and
/// aborts if an operation with this name is already registered. Must not be
/// called during multi-threaded execution (only asserted in debug builds).
void RegisteredOperationName::insert(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<StringRef> attrNames) {
  MLIRContext *ctx = dialect.getContext();
  auto &ctxImpl = ctx->getImpl();
  // NOTE(review): unlike the OperationName constructor, `operations` is
  // modified below without taking `operationInfoMutex`; this relies on
  // registration never racing with other users of the context.
  assert(ctxImpl.multiThreadedExecutionContext == 0 &&
         "registering a new operation kind while in a multi-threaded execution "
         "context");

  // Register the attribute names of this operation. The StringAttr array is
  // placement-constructed in the context's abstractDialectSymbolAllocator.
  MutableArrayRef<StringAttr> cachedAttrNames;
  if (!attrNames.empty()) {
    cachedAttrNames = MutableArrayRef<StringAttr>(
        ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
            attrNames.size()),
        attrNames.size());
    for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
      new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
  }

  // Insert the operation info if it doesn't exist yet. An unregistered entry
  // may already exist if an OperationName was created for this name earlier.
  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
  if (it.second)
    it.first->second.name = StringAttr::get(ctx, name);
  OperationName::Impl &impl = it.first->second;

  if (impl.isRegistered()) {
    llvm::errs() << "error: operation named '" << name
                 << "' is already registered.\n";
    abort();
  }
  // The RegisteredOperationName refers to `impl` by pointer, so the fields
  // assigned below are visible through it.
  ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));

  // Update the registered info for this operation.
  impl.dialect = &dialect;
  impl.typeID = typeID;
  impl.interfaceMap = std::move(interfaceMap);
  impl.foldHookFn = std::move(foldHook);
  impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
  impl.hasTraitFn = std::move(hasTrait);
  impl.parseAssemblyFn = std::move(parseAssembly);
  impl.printAssemblyFn = std::move(printAssembly);
  impl.verifyInvariantsFn = std::move(verifyInvariants);
  impl.attributeNames = cachedAttrNames;
}
725 
726 //===----------------------------------------------------------------------===//
727 // AbstractType
728 //===----------------------------------------------------------------------===//
729 
730 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
731   const AbstractType *type = lookupMutable(typeID, context);
732   if (!type)
733     llvm::report_fatal_error(
734         "Trying to create a Type that was not registered in this MLIRContext.");
735   return *type;
736 }
737 
738 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
739   auto &impl = context->getImpl();
740   auto it = impl.registeredTypes.find(typeID);
741   if (it == impl.registeredTypes.end())
742     return nullptr;
743   return it->second;
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // Type uniquing
748 //===----------------------------------------------------------------------===//
749 
750 /// Returns the storage uniquer used for constructing type storage instances.
751 /// This should not be used directly.
752 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
753 
754 BFloat16Type BFloat16Type::get(MLIRContext *context) {
755   return context->getImpl().bf16Ty;
756 }
757 Float16Type Float16Type::get(MLIRContext *context) {
758   return context->getImpl().f16Ty;
759 }
760 Float32Type Float32Type::get(MLIRContext *context) {
761   return context->getImpl().f32Ty;
762 }
763 Float64Type Float64Type::get(MLIRContext *context) {
764   return context->getImpl().f64Ty;
765 }
766 Float80Type Float80Type::get(MLIRContext *context) {
767   return context->getImpl().f80Ty;
768 }
769 Float128Type Float128Type::get(MLIRContext *context) {
770   return context->getImpl().f128Ty;
771 }
772 
773 /// Get an instance of the IndexType.
774 IndexType IndexType::get(MLIRContext *context) {
775   return context->getImpl().indexTy;
776 }
777 
778 /// Return an existing integer type instance if one is cached within the
779 /// context.
780 static IntegerType
781 getCachedIntegerType(unsigned width,
782                      IntegerType::SignednessSemantics signedness,
783                      MLIRContext *context) {
784   if (signedness != IntegerType::Signless)
785     return IntegerType();
786 
787   switch (width) {
788   case 1:
789     return context->getImpl().int1Ty;
790   case 8:
791     return context->getImpl().int8Ty;
792   case 16:
793     return context->getImpl().int16Ty;
794   case 32:
795     return context->getImpl().int32Ty;
796   case 64:
797     return context->getImpl().int64Ty;
798   case 128:
799     return context->getImpl().int128Ty;
800   default:
801     return IntegerType();
802   }
803 }
804 
805 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
806                              IntegerType::SignednessSemantics signedness) {
807   if (auto cached = getCachedIntegerType(width, signedness, context))
808     return cached;
809   return Base::get(context, width, signedness);
810 }
811 
812 IntegerType
813 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
814                         MLIRContext *context, unsigned width,
815                         SignednessSemantics signedness) {
816   if (auto cached = getCachedIntegerType(width, signedness, context))
817     return cached;
818   return Base::getChecked(emitError, context, width, signedness);
819 }
820 
821 /// Get an instance of the NoneType.
822 NoneType NoneType::get(MLIRContext *context) {
823   if (NoneType cachedInst = context->getImpl().noneType)
824     return cachedInst;
825   // Note: May happen when initializing the singleton attributes of the builtin
826   // dialect.
827   return Base::get(context);
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // Attribute uniquing
832 //===----------------------------------------------------------------------===//
833 
834 /// Returns the storage uniquer used for constructing attribute storage
835 /// instances. This should not be used directly.
836 StorageUniquer &MLIRContext::getAttributeUniquer() {
837   return getImpl().attributeUniquer;
838 }
839 
840 /// Initialize the given attribute storage instance.
841 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
842                                                   MLIRContext *ctx,
843                                                   TypeID attrID) {
844   storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
845 
846   // If the attribute did not provide a type, then default to NoneType.
847   if (!storage->getType())
848     storage->setType(NoneType::get(ctx));
849 }
850 
851 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
852   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
853 }
854 
855 UnitAttr UnitAttr::get(MLIRContext *context) {
856   return context->getImpl().unitAttr;
857 }
858 
859 UnknownLoc UnknownLoc::get(MLIRContext *context) {
860   return context->getImpl().unknownLocAttr;
861 }
862 
863 /// Return empty dictionary.
864 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
865   return context->getImpl().emptyDictionaryAttr;
866 }
867 
868 void StringAttrStorage::initialize(MLIRContext *context) {
869   // Check for a dialect namespace prefix, if there isn't one we don't need to
870   // do any additional initialization.
871   auto dialectNamePair = value.split('.');
872   if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
873     return;
874 
875   // If one exists, we check to see if this dialect is loaded. If it is, we set
876   // the dialect now, if it isn't we record this storage for initialization
877   // later if the dialect ever gets loaded.
878   if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
879     return;
880 
881   MLIRContextImpl &impl = context->getImpl();
882   llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
883   impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
884 }
885 
886 /// Return an empty string.
887 StringAttr StringAttr::get(MLIRContext *context) {
888   return context->getImpl().emptyStringAttr;
889 }
890 
891 //===----------------------------------------------------------------------===//
892 // AffineMap uniquing
893 //===----------------------------------------------------------------------===//
894 
895 StorageUniquer &MLIRContext::getAffineUniquer() {
896   return getImpl().affineUniquer;
897 }
898 
899 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
900                              ArrayRef<AffineExpr> results,
901                              MLIRContext *context) {
902   auto &impl = context->getImpl();
903   auto *storage = impl.affineUniquer.get<AffineMapStorage>(
904       [&](AffineMapStorage *storage) { storage->context = context; }, dimCount,
905       symbolCount, results);
906   return AffineMap(storage);
907 }
908 
909 /// Check whether the arguments passed to the AffineMap::get() are consistent.
910 /// This method checks whether the highest index of dimensional identifier
911 /// present in result expressions is less than `dimCount` and the highest index
912 /// of symbolic identifier present in result expressions is less than
913 /// `symbolCount`.
914 LLVM_ATTRIBUTE_UNUSED static bool
915 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
916                      ArrayRef<AffineExpr> results) {
917   int64_t maxDimPosition = -1;
918   int64_t maxSymbolPosition = -1;
919   getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
920                      maxSymbolPosition);
921   if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
922     LLVM_DEBUG(
923         llvm::dbgs()
924         << "maximum dimensional identifier position in result expression must "
925            "be less than `dimCount` and maximum symbolic identifier position "
926            "in result expression must be less than `symbolCount`\n");
927     return false;
928   }
929   return true;
930 }
931 
932 AffineMap AffineMap::get(MLIRContext *context) {
933   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
934 }
935 
936 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
937                          MLIRContext *context) {
938   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
939 }
940 
941 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
942                          AffineExpr result) {
943   assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
944   return getImpl(dimCount, symbolCount, {result}, result.getContext());
945 }
946 
947 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
948                          ArrayRef<AffineExpr> results, MLIRContext *context) {
949   assert(willBeValidAffineMap(dimCount, symbolCount, results));
950   return getImpl(dimCount, symbolCount, results, context);
951 }
952 
953 //===----------------------------------------------------------------------===//
// Integer Sets: these are immutable and, like AffineMaps, are uniqued by the
// context's affine storage uniquer regardless of their size.
956 //===----------------------------------------------------------------------===//
957 
958 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
959                            ArrayRef<AffineExpr> constraints,
960                            ArrayRef<bool> eqFlags) {
961   // The number of constraints can't be zero.
962   assert(!constraints.empty());
963   assert(constraints.size() == eqFlags.size());
964 
965   auto &impl = constraints[0].getContext()->getImpl();
966   auto *storage = impl.affineUniquer.get<IntegerSetStorage>(
967       [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags);
968   return IntegerSet(storage);
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // StorageUniquerSupport
973 //===----------------------------------------------------------------------===//
974 
975 /// Utility method to generate a callback that can be used to generate a
976 /// diagnostic when checking the construction invariants of a storage object.
977 /// This is defined out-of-line to avoid the need to include Location.h.
978 llvm::unique_function<InFlightDiagnostic()>
979 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
980   return [ctx] { return emitError(UnknownLoc::get(ctx)); };
981 }
982 llvm::unique_function<InFlightDiagnostic()>
983 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
984   return [=] { return emitError(loc); };
985 }
986