1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Support/DebugAction.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Support/Allocator.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/Mutex.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/ThreadPool.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40 
41 #define DEBUG_TYPE "mlircontext"
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 //===----------------------------------------------------------------------===//
47 // MLIRContext CommandLine Options
48 //===----------------------------------------------------------------------===//
49 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options.
///
/// The struct is only instantiated through the `clOptions` ManagedStatic;
/// readers check `clOptions.isConstructed()` before consulting any option.
struct MLIRContextOptions {
  /// `--mlir-disable-threading`: when set, new contexts are created with
  /// threading disabled and `MLIRContext::disableMultithreading` becomes a
  /// no-op (see `isThreadingGloballyDisabled`).
  llvm::cl::opt<bool> disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
                     "further call to MLIRContext::enableMultiThreading()")};

  /// `--mlir-print-op-on-diagnostic` (initialized to true): copied into each
  /// new context through `MLIRContext::printOpOnDiagnostic`.
  llvm::cl::opt<bool> printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  /// `--mlir-print-stacktrace-on-diagnostic`: copied into each new context
  /// through `MLIRContext::printStackTraceOnDiagnostic`.
  llvm::cl::opt<bool> printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // namespace
72 
/// Storage for the MLIRContext command line options. It is constructed on
/// first dereference (e.g. by `registerMLIRContextCLOptions()`); until then
/// the options are treated as absent, which is why readers test
/// `isConstructed()` first.
static llvm::ManagedStatic<MLIRContextOptions> clOptions;
74 
/// Returns true when multi-threading must stay off for every context. That is
/// the case either when LLVM was built without thread support, or when the
/// command line options were registered and `--mlir-disable-threading` is set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
82 
83 /// Register a set of useful command-line options that can be used to configure
84 /// various flags within the MLIRContext. These flags are used when constructing
85 /// an MLIR context for initialization.
86 void mlir::registerMLIRContextCLOptions() {
87   // Make sure that the options struct has been initialized.
88   *clOptions;
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Locking Utilities
93 //===----------------------------------------------------------------------===//
94 
95 namespace {
96 /// Utility writer lock that takes a runtime flag that specifies if we really
97 /// need to lock.
98 struct ScopedWriterLock {
99   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
100       : mutex(shouldLock ? &mutexParam : nullptr) {
101     if (mutex)
102       mutex->lock();
103   }
104   ~ScopedWriterLock() {
105     if (mutex)
106       mutex->unlock();
107   }
108   llvm::sys::SmartRWMutex<true> *mutex;
109 };
110 } // namespace
111 
112 //===----------------------------------------------------------------------===//
113 // MLIRContextImpl
114 //===----------------------------------------------------------------------===//
115 
116 namespace mlir {
117 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
118 /// This class is completely private to this file, so everything is public.
119 class MLIRContextImpl {
120 public:
121   //===--------------------------------------------------------------------===//
122   // Debugging
123   //===--------------------------------------------------------------------===//
124 
125   /// An action manager for use within the context.
126   DebugActionManager debugActionManager;
127 
128   //===--------------------------------------------------------------------===//
129   // Diagnostics
130   //===--------------------------------------------------------------------===//
131   DiagnosticEngine diagEngine;
132 
133   //===--------------------------------------------------------------------===//
134   // Options
135   //===--------------------------------------------------------------------===//
136 
137   /// In most cases, creating operation in unregistered dialect is not desired
138   /// and indicate a misconfiguration of the compiler. This option enables to
139   /// detect such use cases
140   bool allowUnregisteredDialects = false;
141 
142   /// Enable support for multi-threading within MLIR.
143   bool threadingIsEnabled = true;
144 
145   /// Track if we are currently executing in a threaded execution environment
146   /// (like the pass-manager): this is only a debugging feature to help reducing
147   /// the chances of data races one some context APIs.
148 #ifndef NDEBUG
149   std::atomic<int> multiThreadedExecutionContext{0};
150 #endif
151 
152   /// If the operation should be attached to diagnostics printed via the
153   /// Operation::emit methods.
154   bool printOpOnDiagnostic = true;
155 
156   /// If the current stack trace should be attached when emitting diagnostics.
157   bool printStackTraceOnDiagnostic = false;
158 
159   //===--------------------------------------------------------------------===//
160   // Other
161   //===--------------------------------------------------------------------===//
162 
163   /// This points to the ThreadPool used when processing MLIR tasks in parallel.
164   /// It can't be nullptr when multi-threading is enabled. Otherwise if
165   /// multi-threading is disabled, and the threadpool wasn't externally provided
166   /// using `setThreadPool`, this will be nullptr.
167   llvm::ThreadPool *threadPool = nullptr;
168 
169   /// In case where the thread pool is owned by the context, this ensures
170   /// destruction with the context.
171   std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
172 
173   /// This is a list of dialects that are created referring to this context.
174   /// The MLIRContext owns the objects.
175   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
176   DialectRegistry dialectsRegistry;
177 
178   /// An allocator used for AbstractAttribute and AbstractType objects.
179   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
180 
181   /// This is a mapping from operation name to the operation info describing it.
182   llvm::StringMap<OperationName::Impl> operations;
183 
184   /// A vector of operation info specifically for registered operations.
185   llvm::StringMap<RegisteredOperationName> registeredOperations;
186 
187   /// This is a sorted container of registered operations for a deterministic
188   /// and efficient `getRegisteredOperations` implementation.
189   SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
190 
191   /// A mutex used when accessing operation information.
192   llvm::sys::SmartRWMutex<true> operationInfoMutex;
193 
194   //===--------------------------------------------------------------------===//
195   // Affine uniquing
196   //===--------------------------------------------------------------------===//
197 
198   // Affine expression, map and integer set uniquing.
199   StorageUniquer affineUniquer;
200 
201   //===--------------------------------------------------------------------===//
202   // Type uniquing
203   //===--------------------------------------------------------------------===//
204 
205   DenseMap<TypeID, AbstractType *> registeredTypes;
206   StorageUniquer typeUniquer;
207 
208   /// Cached Type Instances.
209   BFloat16Type bf16Ty;
210   Float16Type f16Ty;
211   Float32Type f32Ty;
212   Float64Type f64Ty;
213   Float80Type f80Ty;
214   Float128Type f128Ty;
215   IndexType indexTy;
216   IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
217   NoneType noneType;
218 
219   //===--------------------------------------------------------------------===//
220   // Attribute uniquing
221   //===--------------------------------------------------------------------===//
222 
223   DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
224   StorageUniquer attributeUniquer;
225 
226   /// Cached Attribute Instances.
227   BoolAttr falseAttr, trueAttr;
228   UnitAttr unitAttr;
229   UnknownLoc unknownLocAttr;
230   DictionaryAttr emptyDictionaryAttr;
231   StringAttr emptyStringAttr;
232 
233   /// Map of string attributes that may reference a dialect, that are awaiting
234   /// that dialect to be loaded.
235   llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
236   DenseMap<StringRef, SmallVector<StringAttrStorage *>>
237       dialectReferencingStrAttrs;
238 
239 public:
240   MLIRContextImpl(bool threadingIsEnabled)
241       : threadingIsEnabled(threadingIsEnabled) {
242     if (threadingIsEnabled) {
243       ownedThreadPool = std::make_unique<llvm::ThreadPool>();
244       threadPool = ownedThreadPool.get();
245     }
246   }
247   ~MLIRContextImpl() {
248     for (auto typeMapping : registeredTypes)
249       typeMapping.second->~AbstractType();
250     for (auto attrMapping : registeredAttributes)
251       attrMapping.second->~AbstractAttribute();
252   }
253 };
254 } // namespace mlir
255 
256 MLIRContext::MLIRContext(Threading setting)
257     : MLIRContext(DialectRegistry(), setting) {}
258 
/// Creates a context whose dialect registry is pre-populated from `registry`.
/// Threading is enabled only when `setting` is Threading::ENABLED and threading
/// is not globally disabled. The builtin dialect is always loaded. A set of
/// frequently used builtin types and attributes is then created up front and
/// cached in the implementation.
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
    : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
                               !isThreadingGloballyDisabled())) {
  // Initialize values based on the command line flags if they were provided.
  if (clOptions.isConstructed()) {
    printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
    printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
  }
  // NOTE(review): when threading is disabled via `setting`, the storage
  // uniquers are not switched to single-threaded mode here (unlike
  // `disableMultithreading`); this looks harmless but is worth confirming.

  // Pre-populate the registry.
  registry.appendTo(impl->dialectsRegistry);

  // Ensure the builtin dialect is always pre-loaded.
  // NOTE(review): presumably this must precede the type/attribute creation
  // below because those builtin kinds are registered by the builtin dialect —
  // confirm before reordering.
  getOrLoadDialect<BuiltinDialect>();

  // Initialize several common attributes and types to avoid the need to lock
  // the context when accessing them.

  //// Types.
  /// Floating-point Types.
  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
  impl->f16Ty = TypeUniquer::get<Float16Type>(this);
  impl->f32Ty = TypeUniquer::get<Float32Type>(this);
  impl->f64Ty = TypeUniquer::get<Float64Type>(this);
  impl->f80Ty = TypeUniquer::get<Float80Type>(this);
  impl->f128Ty = TypeUniquer::get<Float128Type>(this);
  /// Index Type.
  impl->indexTy = TypeUniquer::get<IndexType>(this);
  /// Integer Types.
  impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
  impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
  impl->int16Ty =
      TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
  impl->int32Ty =
      TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
  impl->int64Ty =
      TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
  impl->int128Ty =
      TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
  /// None Type.
  impl->noneType = TypeUniquer::get<NoneType>(this);

  //// Attributes.
  //// Note: These must be registered after the types as they may generate one
  //// of the above types internally (e.g. the bool attributes use int1Ty).
  /// Unknown Location Attribute.
  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
  /// Bool Attributes.
  impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
  impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
  /// Unit Attribute.
  impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
  /// The empty dictionary attribute.
  impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
  /// The empty string attribute.
  impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);

  // Register the affine storage objects with the uniquer.
  impl->affineUniquer
      .registerParametricStorageType<AffineBinaryOpExprStorage>();
  impl->affineUniquer
      .registerParametricStorageType<AffineConstantExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineMapStorage>();
  impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
}
325 
/// Defined out of line in this file, where MLIRContextImpl is a complete type,
/// so that the implementation allocated in the constructor can be destroyed
/// here. NOTE(review): assumes `impl` is a smart pointer (e.g.
/// std::unique_ptr) declared in MLIRContext.h — confirm there.
MLIRContext::~MLIRContext() = default;
327 
328 /// Copy the specified array of elements into memory managed by the provided
329 /// bump pointer allocator.  This assumes the elements are all PODs.
330 template <typename T>
331 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
332                                     ArrayRef<T> elements) {
333   auto result = allocator.Allocate<T>(elements.size());
334   std::uninitialized_copy(elements.begin(), elements.end(), result);
335   return ArrayRef<T>(result, elements.size());
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // Debugging
340 //===----------------------------------------------------------------------===//
341 
342 DebugActionManager &MLIRContext::getDebugActionManager() {
343   return getImpl().debugActionManager;
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // Diagnostic Handlers
348 //===----------------------------------------------------------------------===//
349 
350 /// Returns the diagnostic engine for this context.
351 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
352 
353 //===----------------------------------------------------------------------===//
354 // Dialect and Operation Registration
355 //===----------------------------------------------------------------------===//
356 
357 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
358   registry.appendTo(impl->dialectsRegistry);
359 
360   // For the already loaded dialects, register the interfaces immediately.
361   for (const auto &kvp : impl->loadedDialects)
362     registry.registerDelayedInterfaces(kvp.second.get());
363 }
364 
365 const DialectRegistry &MLIRContext::getDialectRegistry() {
366   return impl->dialectsRegistry;
367 }
368 
369 /// Return information about all registered IR dialects.
370 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
371   std::vector<Dialect *> result;
372   result.reserve(impl->loadedDialects.size());
373   for (auto &dialect : impl->loadedDialects)
374     result.push_back(dialect.second.get());
375   llvm::array_pod_sort(result.begin(), result.end(),
376                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
377                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
378                        });
379   return result;
380 }
381 std::vector<StringRef> MLIRContext::getAvailableDialects() {
382   std::vector<StringRef> result;
383   for (auto dialect : impl->dialectsRegistry.getDialectNames())
384     result.push_back(dialect);
385   return result;
386 }
387 
388 /// Get a registered IR dialect with the given namespace. If none is found,
389 /// then return nullptr.
390 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
391   // Dialects are sorted by name, so we can use binary search for lookup.
392   auto it = impl->loadedDialects.find(name);
393   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
394 }
395 
396 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
397   Dialect *dialect = getLoadedDialect(name);
398   if (dialect)
399     return dialect;
400   DialectAllocatorFunctionRef allocator =
401       impl->dialectsRegistry.getDialectAllocator(name);
402   return allocator ? allocator(this) : nullptr;
403 }
404 
405 /// Get a dialect for the provided namespace and TypeID: abort the program if a
406 /// dialect exist for this namespace with different TypeID. Returns a pointer to
407 /// the dialect owned by the context.
408 Dialect *
409 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
410                               function_ref<std::unique_ptr<Dialect>()> ctor) {
411   auto &impl = getImpl();
412   // Get the correct insertion position sorted by namespace.
413   auto dialectIt = impl.loadedDialects.find(dialectNamespace);
414 
415   if (dialectIt == impl.loadedDialects.end()) {
416     LLVM_DEBUG(llvm::dbgs()
417                << "Load new dialect in Context " << dialectNamespace << "\n");
418 #ifndef NDEBUG
419     if (impl.multiThreadedExecutionContext != 0)
420       llvm::report_fatal_error(
421           "Loading a dialect (" + dialectNamespace +
422           ") while in a multi-threaded execution context (maybe "
423           "the PassManager): this can indicate a "
424           "missing `dependentDialects` in a pass for example.");
425 #endif
426     std::unique_ptr<Dialect> &dialect =
427         impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
428     assert(dialect && "dialect ctor failed");
429 
430     // Refresh all the identifiers dialect field, this catches cases where a
431     // dialect may be loaded after identifier prefixed with this dialect name
432     // were already created.
433     auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
434     if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
435       for (StringAttrStorage *storage : stringAttrsIt->second)
436         storage->referencedDialect = dialect.get();
437       impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
438     }
439 
440     // Actually register the interfaces with delayed registration.
441     impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
442     return dialect.get();
443   }
444 
445   // Abort if dialect with namespace has already been registered.
446   std::unique_ptr<Dialect> &dialect = dialectIt->second;
447   if (dialect->getTypeID() != dialectID)
448     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
449                              "' has already been registered");
450 
451   return dialect.get();
452 }
453 
454 void MLIRContext::loadAllAvailableDialects() {
455   for (StringRef name : getAvailableDialects())
456     getOrLoadDialect(name);
457 }
458 
459 llvm::hash_code MLIRContext::getRegistryHash() {
460   llvm::hash_code hash(0);
461   // Factor in number of loaded dialects, attributes, operations, types.
462   hash = llvm::hash_combine(hash, impl->loadedDialects.size());
463   hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
464   hash = llvm::hash_combine(hash, impl->registeredOperations.size());
465   hash = llvm::hash_combine(hash, impl->registeredTypes.size());
466   return hash;
467 }
468 
469 bool MLIRContext::allowsUnregisteredDialects() {
470   return impl->allowUnregisteredDialects;
471 }
472 
473 void MLIRContext::allowUnregisteredDialects(bool allowing) {
474   impl->allowUnregisteredDialects = allowing;
475 }
476 
477 /// Return true if multi-threading is enabled by the context.
478 bool MLIRContext::isMultithreadingEnabled() {
479   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
480 }
481 
482 /// Set the flag specifying if multi-threading is disabled by the context.
483 void MLIRContext::disableMultithreading(bool disable) {
484   // This API can be overridden by the global debugging flag
485   // --mlir-disable-threading
486   if (isThreadingGloballyDisabled())
487     return;
488 
489   impl->threadingIsEnabled = !disable;
490 
491   // Update the threading mode for each of the uniquers.
492   impl->affineUniquer.disableMultithreading(disable);
493   impl->attributeUniquer.disableMultithreading(disable);
494   impl->typeUniquer.disableMultithreading(disable);
495 
496   // Destroy thread pool (stop all threads) if it is no longer needed, or create
497   // a new one if multithreading was re-enabled.
498   if (disable) {
499     // If the thread pool is owned, explicitly set it to nullptr to avoid
500     // keeping a dangling pointer around. If the thread pool is externally
501     // owned, we don't do anything.
502     if (impl->ownedThreadPool) {
503       assert(impl->threadPool);
504       impl->threadPool = nullptr;
505       impl->ownedThreadPool.reset();
506     }
507   } else if (!impl->threadPool) {
508     // The thread pool isn't externally provided.
509     assert(!impl->ownedThreadPool);
510     impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
511     impl->threadPool = impl->ownedThreadPool.get();
512   }
513 }
514 
515 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
516   assert(!isMultithreadingEnabled() &&
517          "expected multi-threading to be disabled when setting a ThreadPool");
518   impl->threadPool = &pool;
519   impl->ownedThreadPool.reset();
520   enableMultithreading();
521 }
522 
523 unsigned MLIRContext::getNumThreads() {
524   if (isMultithreadingEnabled()) {
525     assert(impl->threadPool &&
526            "multi-threading is enabled but threadpool not set");
527     return impl->threadPool->getThreadCount();
528   }
529   // No multithreading or active thread pool. Return 1 thread.
530   return 1;
531 }
532 
533 llvm::ThreadPool &MLIRContext::getThreadPool() {
534   assert(isMultithreadingEnabled() &&
535          "expected multi-threading to be enabled within the context");
536   assert(impl->threadPool &&
537          "multi-threading is enabled but threadpool not set");
538   return *impl->threadPool;
539 }
540 
541 void MLIRContext::enterMultiThreadedExecution() {
542 #ifndef NDEBUG
543   ++impl->multiThreadedExecutionContext;
544 #endif
545 }
546 void MLIRContext::exitMultiThreadedExecution() {
547 #ifndef NDEBUG
548   --impl->multiThreadedExecutionContext;
549 #endif
550 }
551 
552 /// Return true if we should attach the operation to diagnostics emitted via
553 /// Operation::emit.
554 bool MLIRContext::shouldPrintOpOnDiagnostic() {
555   return impl->printOpOnDiagnostic;
556 }
557 
558 /// Set the flag specifying if we should attach the operation to diagnostics
559 /// emitted via Operation::emit.
560 void MLIRContext::printOpOnDiagnostic(bool enable) {
561   impl->printOpOnDiagnostic = enable;
562 }
563 
564 /// Return true if we should attach the current stacktrace to diagnostics when
565 /// emitted.
566 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
567   return impl->printStackTraceOnDiagnostic;
568 }
569 
570 /// Set the flag specifying if we should attach the current stacktrace when
571 /// emitting diagnostics.
572 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
573   impl->printStackTraceOnDiagnostic = enable;
574 }
575 
576 /// Return information about all registered operations.
577 ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
578   return impl->sortedRegisteredOperations;
579 }
580 
581 bool MLIRContext::isOperationRegistered(StringRef name) {
582   return RegisteredOperationName::lookup(name, this).hasValue();
583 }
584 
585 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
586   auto &impl = context->getImpl();
587   assert(impl.multiThreadedExecutionContext == 0 &&
588          "Registering a new type kind while in a multi-threaded execution "
589          "context");
590   auto *newInfo =
591       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
592           AbstractType(std::move(typeInfo));
593   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
594     llvm::report_fatal_error("Dialect Type already registered.");
595 }
596 
597 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
598   auto &impl = context->getImpl();
599   assert(impl.multiThreadedExecutionContext == 0 &&
600          "Registering a new attribute kind while in a multi-threaded execution "
601          "context");
602   auto *newInfo =
603       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
604           AbstractAttribute(std::move(attrInfo));
605   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
606     llvm::report_fatal_error("Dialect Attribute already registered.");
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // AbstractAttribute
611 //===----------------------------------------------------------------------===//
612 
613 /// Get the dialect that registered the attribute with the provided typeid.
614 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
615                                                    MLIRContext *context) {
616   const AbstractAttribute *abstract = lookupMutable(typeID, context);
617   if (!abstract)
618     llvm::report_fatal_error("Trying to create an Attribute that was not "
619                              "registered in this MLIRContext.");
620   return *abstract;
621 }
622 
623 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
624                                                     MLIRContext *context) {
625   auto &impl = context->getImpl();
626   auto it = impl.registeredAttributes.find(typeID);
627   if (it == impl.registeredAttributes.end())
628     return nullptr;
629   return it->second;
630 }
631 
632 //===----------------------------------------------------------------------===//
633 // OperationName
634 //===----------------------------------------------------------------------===//
635 
/// Binds this OperationName to the context-uniqued `OperationName::Impl` for
/// `name`, creating the entry if it does not exist yet. With multi-threading
/// enabled, registered operations are found without locking. Other names are
/// looked up under a reader lock and only created under a writer lock. With
/// multi-threading disabled, no lock is taken at all.
OperationName::OperationName(StringRef name, MLIRContext *context) {
  MLIRContextImpl &ctxImpl = context->getImpl();

  // Check for an existing name in read-only mode.
  bool isMultithreadingEnabled = context->isMultithreadingEnabled();
  if (isMultithreadingEnabled) {
    // Check the registered info map first. In the overwhelmingly common case,
    // the entry will be in here and it also removes the need to acquire any
    // locks.
    // NOTE(review): this unlocked read relies on operations not being
    // registered concurrently; RegisteredOperationName::insert asserts it is
    // not called inside a multi-threaded execution context.
    auto registeredIt = ctxImpl.registeredOperations.find(name);
    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
      impl = registeredIt->second.impl;
      return;
    }

    llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
    auto it = ctxImpl.operations.find(name);
    if (it != ctxImpl.operations.end()) {
      impl = &it->second;
      return;
    }
  }

  // Acquire a writer-lock so that we can safely create the new instance. The
  // lock is skipped when multi-threading is disabled. `insert` leaves an
  // existing entry untouched, which covers another thread having created it
  // after the reader lock above was released.
  ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);

  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
  // Only a freshly inserted entry still needs its name attribute.
  if (it.second)
    it.first->second.name = StringAttr::get(context, name);
  impl = &it.first->second;
}
667 
668 StringRef OperationName::getDialectNamespace() const {
669   if (Dialect *dialect = getDialect())
670     return dialect->getNamespace();
671   return getStringRef().split('.').first;
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // RegisteredOperationName
676 //===----------------------------------------------------------------------===//
677 
678 Optional<RegisteredOperationName>
679 RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
680   auto &impl = ctx->getImpl();
681   auto it = impl.registeredOperations.find(name);
682   if (it != impl.registeredOperations.end())
683     return it->getValue();
684   return llvm::None;
685 }
686 
687 ParseResult
688 RegisteredOperationName::parseAssembly(OpAsmParser &parser,
689                                        OperationState &result) const {
690   return impl->parseAssemblyFn(parser, result);
691 }
692 
693 void RegisteredOperationName::insert(
694     StringRef name, Dialect &dialect, TypeID typeID,
695     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
696     VerifyInvariantsFn &&verifyInvariants,
697     VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
698     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
699     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
700     ArrayRef<StringRef> attrNames) {
701   MLIRContext *ctx = dialect.getContext();
702   auto &ctxImpl = ctx->getImpl();
703   assert(ctxImpl.multiThreadedExecutionContext == 0 &&
704          "registering a new operation kind while in a multi-threaded execution "
705          "context");
706 
707   // Register the attribute names of this operation.
708   MutableArrayRef<StringAttr> cachedAttrNames;
709   if (!attrNames.empty()) {
710     cachedAttrNames = MutableArrayRef<StringAttr>(
711         ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
712             attrNames.size()),
713         attrNames.size());
714     for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
715       new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
716   }
717 
718   // Insert the operation info if it doesn't exist yet.
719   auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
720   if (it.second)
721     it.first->second.name = StringAttr::get(ctx, name);
722   OperationName::Impl &impl = it.first->second;
723 
724   if (impl.isRegistered()) {
725     llvm::errs() << "error: operation named '" << name
726                  << "' is already registered.\n";
727     abort();
728   }
729   auto emplaced = ctxImpl.registeredOperations.try_emplace(
730       name, RegisteredOperationName(&impl));
731   assert(emplaced.second && "operation name registration must be successful");
732 
733   // Add emplaced operation name to the sorted operations container.
734   RegisteredOperationName &value = emplaced.first->getValue();
735   ctxImpl.sortedRegisteredOperations.insert(
736       llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
737                         [](auto &lhs, auto &rhs) {
738                           return lhs.getIdentifier().compare(
739                               rhs.getIdentifier());
740                         }),
741       value);
742 
743   // Update the registered info for this operation.
744   impl.dialect = &dialect;
745   impl.typeID = typeID;
746   impl.interfaceMap = std::move(interfaceMap);
747   impl.foldHookFn = std::move(foldHook);
748   impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
749   impl.hasTraitFn = std::move(hasTrait);
750   impl.parseAssemblyFn = std::move(parseAssembly);
751   impl.printAssemblyFn = std::move(printAssembly);
752   impl.verifyInvariantsFn = std::move(verifyInvariants);
753   impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
754   impl.attributeNames = cachedAttrNames;
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // AbstractType
759 //===----------------------------------------------------------------------===//
760 
761 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
762   const AbstractType *type = lookupMutable(typeID, context);
763   if (!type)
764     llvm::report_fatal_error(
765         "Trying to create a Type that was not registered in this MLIRContext.");
766   return *type;
767 }
768 
769 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
770   auto &impl = context->getImpl();
771   auto it = impl.registeredTypes.find(typeID);
772   if (it == impl.registeredTypes.end())
773     return nullptr;
774   return it->second;
775 }
776 
777 //===----------------------------------------------------------------------===//
778 // Type uniquing
779 //===----------------------------------------------------------------------===//
780 
781 /// Returns the storage uniquer used for constructing type storage instances.
782 /// This should not be used directly.
783 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
784 
785 BFloat16Type BFloat16Type::get(MLIRContext *context) {
786   return context->getImpl().bf16Ty;
787 }
788 Float16Type Float16Type::get(MLIRContext *context) {
789   return context->getImpl().f16Ty;
790 }
791 Float32Type Float32Type::get(MLIRContext *context) {
792   return context->getImpl().f32Ty;
793 }
794 Float64Type Float64Type::get(MLIRContext *context) {
795   return context->getImpl().f64Ty;
796 }
797 Float80Type Float80Type::get(MLIRContext *context) {
798   return context->getImpl().f80Ty;
799 }
800 Float128Type Float128Type::get(MLIRContext *context) {
801   return context->getImpl().f128Ty;
802 }
803 
804 /// Get an instance of the IndexType.
805 IndexType IndexType::get(MLIRContext *context) {
806   return context->getImpl().indexTy;
807 }
808 
809 /// Return an existing integer type instance if one is cached within the
810 /// context.
811 static IntegerType
812 getCachedIntegerType(unsigned width,
813                      IntegerType::SignednessSemantics signedness,
814                      MLIRContext *context) {
815   if (signedness != IntegerType::Signless)
816     return IntegerType();
817 
818   switch (width) {
819   case 1:
820     return context->getImpl().int1Ty;
821   case 8:
822     return context->getImpl().int8Ty;
823   case 16:
824     return context->getImpl().int16Ty;
825   case 32:
826     return context->getImpl().int32Ty;
827   case 64:
828     return context->getImpl().int64Ty;
829   case 128:
830     return context->getImpl().int128Ty;
831   default:
832     return IntegerType();
833   }
834 }
835 
836 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
837                              IntegerType::SignednessSemantics signedness) {
838   if (auto cached = getCachedIntegerType(width, signedness, context))
839     return cached;
840   return Base::get(context, width, signedness);
841 }
842 
843 IntegerType
844 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
845                         MLIRContext *context, unsigned width,
846                         SignednessSemantics signedness) {
847   if (auto cached = getCachedIntegerType(width, signedness, context))
848     return cached;
849   return Base::getChecked(emitError, context, width, signedness);
850 }
851 
852 /// Get an instance of the NoneType.
853 NoneType NoneType::get(MLIRContext *context) {
854   if (NoneType cachedInst = context->getImpl().noneType)
855     return cachedInst;
856   // Note: May happen when initializing the singleton attributes of the builtin
857   // dialect.
858   return Base::get(context);
859 }
860 
861 //===----------------------------------------------------------------------===//
862 // Attribute uniquing
863 //===----------------------------------------------------------------------===//
864 
865 /// Returns the storage uniquer used for constructing attribute storage
866 /// instances. This should not be used directly.
867 StorageUniquer &MLIRContext::getAttributeUniquer() {
868   return getImpl().attributeUniquer;
869 }
870 
871 /// Initialize the given attribute storage instance.
872 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
873                                                   MLIRContext *ctx,
874                                                   TypeID attrID) {
875   storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
876 
877   // If the attribute did not provide a type, then default to NoneType.
878   if (!storage->getType())
879     storage->setType(NoneType::get(ctx));
880 }
881 
882 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
883   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
884 }
885 
886 UnitAttr UnitAttr::get(MLIRContext *context) {
887   return context->getImpl().unitAttr;
888 }
889 
890 UnknownLoc UnknownLoc::get(MLIRContext *context) {
891   return context->getImpl().unknownLocAttr;
892 }
893 
894 /// Return empty dictionary.
895 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
896   return context->getImpl().emptyDictionaryAttr;
897 }
898 
899 void StringAttrStorage::initialize(MLIRContext *context) {
900   // Check for a dialect namespace prefix, if there isn't one we don't need to
901   // do any additional initialization.
902   auto dialectNamePair = value.split('.');
903   if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
904     return;
905 
906   // If one exists, we check to see if this dialect is loaded. If it is, we set
907   // the dialect now, if it isn't we record this storage for initialization
908   // later if the dialect ever gets loaded.
909   if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
910     return;
911 
912   MLIRContextImpl &impl = context->getImpl();
913   llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
914   impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
915 }
916 
917 /// Return an empty string.
918 StringAttr StringAttr::get(MLIRContext *context) {
919   return context->getImpl().emptyStringAttr;
920 }
921 
922 //===----------------------------------------------------------------------===//
923 // AffineMap uniquing
924 //===----------------------------------------------------------------------===//
925 
926 StorageUniquer &MLIRContext::getAffineUniquer() {
927   return getImpl().affineUniquer;
928 }
929 
930 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
931                              ArrayRef<AffineExpr> results,
932                              MLIRContext *context) {
933   auto &impl = context->getImpl();
934   auto *storage = impl.affineUniquer.get<AffineMapStorage>(
935       [&](AffineMapStorage *storage) { storage->context = context; }, dimCount,
936       symbolCount, results);
937   return AffineMap(storage);
938 }
939 
940 /// Check whether the arguments passed to the AffineMap::get() are consistent.
941 /// This method checks whether the highest index of dimensional identifier
942 /// present in result expressions is less than `dimCount` and the highest index
943 /// of symbolic identifier present in result expressions is less than
944 /// `symbolCount`.
945 LLVM_ATTRIBUTE_UNUSED static bool
946 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
947                      ArrayRef<AffineExpr> results) {
948   int64_t maxDimPosition = -1;
949   int64_t maxSymbolPosition = -1;
950   getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
951                      maxSymbolPosition);
952   if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
953     LLVM_DEBUG(
954         llvm::dbgs()
955         << "maximum dimensional identifier position in result expression must "
956            "be less than `dimCount` and maximum symbolic identifier position "
957            "in result expression must be less than `symbolCount`\n");
958     return false;
959   }
960   return true;
961 }
962 
963 AffineMap AffineMap::get(MLIRContext *context) {
964   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
965 }
966 
967 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
968                          MLIRContext *context) {
969   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
970 }
971 
972 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
973                          AffineExpr result) {
974   assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
975   return getImpl(dimCount, symbolCount, {result}, result.getContext());
976 }
977 
978 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
979                          ArrayRef<AffineExpr> results, MLIRContext *context) {
980   assert(willBeValidAffineMap(dimCount, symbolCount, results));
981   return getImpl(dimCount, symbolCount, results, context);
982 }
983 
984 //===----------------------------------------------------------------------===//
// Integer Sets: these are allocated into the bump pointer, and are immutable.
// Like AffineMaps, they are uniqued through the context's affine uniquer.
987 //===----------------------------------------------------------------------===//
988 
989 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
990                            ArrayRef<AffineExpr> constraints,
991                            ArrayRef<bool> eqFlags) {
992   // The number of constraints can't be zero.
993   assert(!constraints.empty());
994   assert(constraints.size() == eqFlags.size());
995 
996   auto &impl = constraints[0].getContext()->getImpl();
997   auto *storage = impl.affineUniquer.get<IntegerSetStorage>(
998       [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags);
999   return IntegerSet(storage);
1000 }
1001 
1002 //===----------------------------------------------------------------------===//
1003 // StorageUniquerSupport
1004 //===----------------------------------------------------------------------===//
1005 
1006 /// Utility method to generate a callback that can be used to generate a
1007 /// diagnostic when checking the construction invariants of a storage object.
1008 /// This is defined out-of-line to avoid the need to include Location.h.
1009 llvm::unique_function<InFlightDiagnostic()>
1010 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1011   return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1012 }
1013 llvm::unique_function<InFlightDiagnostic()>
1014 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1015   return [=] { return emitError(loc); };
1016 }
1017