1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Support/DebugAction.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Support/Allocator.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/Mutex.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/ThreadPool.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40 
41 #define DEBUG_TYPE "mlircontext"
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 //===----------------------------------------------------------------------===//
47 // MLIRContext CommandLine Options
48 //===----------------------------------------------------------------------===//
49 
50 namespace {
51 /// This struct contains command line options that can be used to initialize
52 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
53 /// for global command line options.
54 struct MLIRContextOptions {
55   llvm::cl::opt<bool> disableThreading{
56       "mlir-disable-threading",
57       llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
58                      "further call to MLIRContext::enableMultiThreading()")};
59 
60   llvm::cl::opt<bool> printOpOnDiagnostic{
61       "mlir-print-op-on-diagnostic",
62       llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
63                      "the operation as an attached note"),
64       llvm::cl::init(true)};
65 
66   llvm::cl::opt<bool> printStackTraceOnDiagnostic{
67       "mlir-print-stacktrace-on-diagnostic",
68       llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
69                      "as an attached note")};
70 };
71 } // namespace
72 
73 static llvm::ManagedStatic<MLIRContextOptions> clOptions;
74 
/// Returns true when threading must stay off for every context. This is the
/// case when LLVM was built without thread support, or when the
/// `--mlir-disable-threading` option has been registered and set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  // Without registered command-line options nothing can force threading off.
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
82 
83 /// Register a set of useful command-line options that can be used to configure
84 /// various flags within the MLIRContext. These flags are used when constructing
85 /// an MLIR context for initialization.
registerMLIRContextCLOptions()86 void mlir::registerMLIRContextCLOptions() {
87   // Make sure that the options struct has been initialized.
88   *clOptions;
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Locking Utilities
93 //===----------------------------------------------------------------------===//
94 
95 namespace {
96 /// Utility writer lock that takes a runtime flag that specifies if we really
97 /// need to lock.
98 struct ScopedWriterLock {
ScopedWriterLock__anonc2df5d7c0211::ScopedWriterLock99   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
100       : mutex(shouldLock ? &mutexParam : nullptr) {
101     if (mutex)
102       mutex->lock();
103   }
~ScopedWriterLock__anonc2df5d7c0211::ScopedWriterLock104   ~ScopedWriterLock() {
105     if (mutex)
106       mutex->unlock();
107   }
108   llvm::sys::SmartRWMutex<true> *mutex;
109 };
110 } // namespace
111 
112 //===----------------------------------------------------------------------===//
113 // MLIRContextImpl
114 //===----------------------------------------------------------------------===//
115 
116 namespace mlir {
117 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
118 /// This class is completely private to this file, so everything is public.
119 class MLIRContextImpl {
120 public:
121   //===--------------------------------------------------------------------===//
122   // Debugging
123   //===--------------------------------------------------------------------===//
124 
125   /// An action manager for use within the context.
126   DebugActionManager debugActionManager;
127 
128   //===--------------------------------------------------------------------===//
129   // Diagnostics
130   //===--------------------------------------------------------------------===//
131   DiagnosticEngine diagEngine;
132 
133   //===--------------------------------------------------------------------===//
134   // Options
135   //===--------------------------------------------------------------------===//
136 
137   /// In most cases, creating operation in unregistered dialect is not desired
138   /// and indicate a misconfiguration of the compiler. This option enables to
139   /// detect such use cases
140   bool allowUnregisteredDialects = false;
141 
142   /// Enable support for multi-threading within MLIR.
143   bool threadingIsEnabled = true;
144 
145   /// Track if we are currently executing in a threaded execution environment
146   /// (like the pass-manager): this is only a debugging feature to help reducing
147   /// the chances of data races one some context APIs.
148 #ifndef NDEBUG
149   std::atomic<int> multiThreadedExecutionContext{0};
150 #endif
151 
152   /// If the operation should be attached to diagnostics printed via the
153   /// Operation::emit methods.
154   bool printOpOnDiagnostic = true;
155 
156   /// If the current stack trace should be attached when emitting diagnostics.
157   bool printStackTraceOnDiagnostic = false;
158 
159   //===--------------------------------------------------------------------===//
160   // Other
161   //===--------------------------------------------------------------------===//
162 
163   /// This points to the ThreadPool used when processing MLIR tasks in parallel.
164   /// It can't be nullptr when multi-threading is enabled. Otherwise if
165   /// multi-threading is disabled, and the threadpool wasn't externally provided
166   /// using `setThreadPool`, this will be nullptr.
167   llvm::ThreadPool *threadPool = nullptr;
168 
169   /// In case where the thread pool is owned by the context, this ensures
170   /// destruction with the context.
171   std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
172 
173   /// This is a list of dialects that are created referring to this context.
174   /// The MLIRContext owns the objects.
175   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
176   DialectRegistry dialectsRegistry;
177 
178   /// An allocator used for AbstractAttribute and AbstractType objects.
179   llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
180 
181   /// This is a mapping from operation name to the operation info describing it.
182   llvm::StringMap<OperationName::Impl> operations;
183 
184   /// A vector of operation info specifically for registered operations.
185   llvm::StringMap<RegisteredOperationName> registeredOperations;
186 
187   /// This is a sorted container of registered operations for a deterministic
188   /// and efficient `getRegisteredOperations` implementation.
189   SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
190 
191   /// A mutex used when accessing operation information.
192   llvm::sys::SmartRWMutex<true> operationInfoMutex;
193 
194   //===--------------------------------------------------------------------===//
195   // Affine uniquing
196   //===--------------------------------------------------------------------===//
197 
198   // Affine expression, map and integer set uniquing.
199   StorageUniquer affineUniquer;
200 
201   //===--------------------------------------------------------------------===//
202   // Type uniquing
203   //===--------------------------------------------------------------------===//
204 
205   DenseMap<TypeID, AbstractType *> registeredTypes;
206   StorageUniquer typeUniquer;
207 
208   /// Cached Type Instances.
209   BFloat16Type bf16Ty;
210   Float16Type f16Ty;
211   Float32Type f32Ty;
212   Float64Type f64Ty;
213   Float80Type f80Ty;
214   Float128Type f128Ty;
215   IndexType indexTy;
216   IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
217   NoneType noneType;
218 
219   //===--------------------------------------------------------------------===//
220   // Attribute uniquing
221   //===--------------------------------------------------------------------===//
222 
223   DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
224   StorageUniquer attributeUniquer;
225 
226   /// Cached Attribute Instances.
227   BoolAttr falseAttr, trueAttr;
228   UnitAttr unitAttr;
229   UnknownLoc unknownLocAttr;
230   DictionaryAttr emptyDictionaryAttr;
231   StringAttr emptyStringAttr;
232 
233   /// Map of string attributes that may reference a dialect, that are awaiting
234   /// that dialect to be loaded.
235   llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
236   DenseMap<StringRef, SmallVector<StringAttrStorage *>>
237       dialectReferencingStrAttrs;
238 
239 public:
MLIRContextImpl(bool threadingIsEnabled)240   MLIRContextImpl(bool threadingIsEnabled)
241       : threadingIsEnabled(threadingIsEnabled) {
242     if (threadingIsEnabled) {
243       ownedThreadPool = std::make_unique<llvm::ThreadPool>();
244       threadPool = ownedThreadPool.get();
245     }
246   }
~MLIRContextImpl()247   ~MLIRContextImpl() {
248     for (auto typeMapping : registeredTypes)
249       typeMapping.second->~AbstractType();
250     for (auto attrMapping : registeredAttributes)
251       attrMapping.second->~AbstractAttribute();
252   }
253 };
254 } // namespace mlir
255 
MLIRContext(Threading setting)256 MLIRContext::MLIRContext(Threading setting)
257     : MLIRContext(DialectRegistry(), setting) {}
258 
MLIRContext(const DialectRegistry & registry,Threading setting)259 MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
260     : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
261                                !isThreadingGloballyDisabled())) {
262   // Initialize values based on the command line flags if they were provided.
263   if (clOptions.isConstructed()) {
264     printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
265     printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
266   }
267 
268   // Pre-populate the registry.
269   registry.appendTo(impl->dialectsRegistry);
270 
271   // Ensure the builtin dialect is always pre-loaded.
272   getOrLoadDialect<BuiltinDialect>();
273 
274   // Initialize several common attributes and types to avoid the need to lock
275   // the context when accessing them.
276 
277   //// Types.
278   /// Floating-point Types.
279   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
280   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
281   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
282   impl->f64Ty = TypeUniquer::get<Float64Type>(this);
283   impl->f80Ty = TypeUniquer::get<Float80Type>(this);
284   impl->f128Ty = TypeUniquer::get<Float128Type>(this);
285   /// Index Type.
286   impl->indexTy = TypeUniquer::get<IndexType>(this);
287   /// Integer Types.
288   impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
289   impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
290   impl->int16Ty =
291       TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
292   impl->int32Ty =
293       TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
294   impl->int64Ty =
295       TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
296   impl->int128Ty =
297       TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
298   /// None Type.
299   impl->noneType = TypeUniquer::get<NoneType>(this);
300 
301   //// Attributes.
302   //// Note: These must be registered after the types as they may generate one
303   //// of the above types internally.
304   /// Unknown Location Attribute.
305   impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
306   /// Bool Attributes.
307   impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
308   impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
309   /// Unit Attribute.
310   impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
311   /// The empty dictionary attribute.
312   impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
313   /// The empty string attribute.
314   impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);
315 
316   // Register the affine storage objects with the uniquer.
317   impl->affineUniquer
318       .registerParametricStorageType<AffineBinaryOpExprStorage>();
319   impl->affineUniquer
320       .registerParametricStorageType<AffineConstantExprStorage>();
321   impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
322   impl->affineUniquer.registerParametricStorageType<AffineMapStorage>();
323   impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
324 }
325 
// Defaulted out-of-line here, presumably so that MLIRContextImpl (defined only
// in this file) is a complete type where `impl` is destroyed.
// NOTE(review): assumes `impl` is a smart pointer declared in the header.
MLIRContext::~MLIRContext() = default;
327 
328 /// Copy the specified array of elements into memory managed by the provided
329 /// bump pointer allocator.  This assumes the elements are all PODs.
330 template <typename T>
copyArrayRefInto(llvm::BumpPtrAllocator & allocator,ArrayRef<T> elements)331 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
332                                     ArrayRef<T> elements) {
333   auto result = allocator.Allocate<T>(elements.size());
334   std::uninitialized_copy(elements.begin(), elements.end(), result);
335   return ArrayRef<T>(result, elements.size());
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // Debugging
340 //===----------------------------------------------------------------------===//
341 
getDebugActionManager()342 DebugActionManager &MLIRContext::getDebugActionManager() {
343   return getImpl().debugActionManager;
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // Diagnostic Handlers
348 //===----------------------------------------------------------------------===//
349 
350 /// Returns the diagnostic engine for this context.
getDiagEngine()351 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
352 
353 //===----------------------------------------------------------------------===//
354 // Dialect and Operation Registration
355 //===----------------------------------------------------------------------===//
356 
appendDialectRegistry(const DialectRegistry & registry)357 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
358   if (registry.isSubsetOf(impl->dialectsRegistry))
359     return;
360 
361   assert(impl->multiThreadedExecutionContext == 0 &&
362          "appending to the MLIRContext dialect registry while in a "
363          "multi-threaded execution context");
364   registry.appendTo(impl->dialectsRegistry);
365 
366   // For the already loaded dialects, apply any possible extensions immediately.
367   registry.applyExtensions(this);
368 }
369 
getDialectRegistry()370 const DialectRegistry &MLIRContext::getDialectRegistry() {
371   return impl->dialectsRegistry;
372 }
373 
374 /// Return information about all registered IR dialects.
getLoadedDialects()375 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
376   std::vector<Dialect *> result;
377   result.reserve(impl->loadedDialects.size());
378   for (auto &dialect : impl->loadedDialects)
379     result.push_back(dialect.second.get());
380   llvm::array_pod_sort(result.begin(), result.end(),
381                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
382                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
383                        });
384   return result;
385 }
getAvailableDialects()386 std::vector<StringRef> MLIRContext::getAvailableDialects() {
387   std::vector<StringRef> result;
388   for (auto dialect : impl->dialectsRegistry.getDialectNames())
389     result.push_back(dialect);
390   return result;
391 }
392 
393 /// Get a registered IR dialect with the given namespace. If none is found,
394 /// then return nullptr.
getLoadedDialect(StringRef name)395 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
396   // Dialects are sorted by name, so we can use binary search for lookup.
397   auto it = impl->loadedDialects.find(name);
398   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
399 }
400 
getOrLoadDialect(StringRef name)401 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
402   Dialect *dialect = getLoadedDialect(name);
403   if (dialect)
404     return dialect;
405   DialectAllocatorFunctionRef allocator =
406       impl->dialectsRegistry.getDialectAllocator(name);
407   return allocator ? allocator(this) : nullptr;
408 }
409 
410 /// Get a dialect for the provided namespace and TypeID: abort the program if a
411 /// dialect exist for this namespace with different TypeID. Returns a pointer to
412 /// the dialect owned by the context.
413 Dialect *
getOrLoadDialect(StringRef dialectNamespace,TypeID dialectID,function_ref<std::unique_ptr<Dialect> ()> ctor)414 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
415                               function_ref<std::unique_ptr<Dialect>()> ctor) {
416   auto &impl = getImpl();
417   // Get the correct insertion position sorted by namespace.
418   auto dialectIt = impl.loadedDialects.find(dialectNamespace);
419 
420   if (dialectIt == impl.loadedDialects.end()) {
421     LLVM_DEBUG(llvm::dbgs()
422                << "Load new dialect in Context " << dialectNamespace << "\n");
423 #ifndef NDEBUG
424     if (impl.multiThreadedExecutionContext != 0)
425       llvm::report_fatal_error(
426           "Loading a dialect (" + dialectNamespace +
427           ") while in a multi-threaded execution context (maybe "
428           "the PassManager): this can indicate a "
429           "missing `dependentDialects` in a pass for example.");
430 #endif
431     std::unique_ptr<Dialect> &dialect =
432         impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
433     assert(dialect && "dialect ctor failed");
434 
435     // Refresh all the identifiers dialect field, this catches cases where a
436     // dialect may be loaded after identifier prefixed with this dialect name
437     // were already created.
438     auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
439     if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
440       for (StringAttrStorage *storage : stringAttrsIt->second)
441         storage->referencedDialect = dialect.get();
442       impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
443     }
444 
445     // Apply any extensions to this newly loaded dialect.
446     impl.dialectsRegistry.applyExtensions(dialect.get());
447     return dialect.get();
448   }
449 
450   // Abort if dialect with namespace has already been registered.
451   std::unique_ptr<Dialect> &dialect = dialectIt->second;
452   if (dialect->getTypeID() != dialectID)
453     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
454                              "' has already been registered");
455 
456   return dialect.get();
457 }
458 
loadAllAvailableDialects()459 void MLIRContext::loadAllAvailableDialects() {
460   for (StringRef name : getAvailableDialects())
461     getOrLoadDialect(name);
462 }
463 
getRegistryHash()464 llvm::hash_code MLIRContext::getRegistryHash() {
465   llvm::hash_code hash(0);
466   // Factor in number of loaded dialects, attributes, operations, types.
467   hash = llvm::hash_combine(hash, impl->loadedDialects.size());
468   hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
469   hash = llvm::hash_combine(hash, impl->registeredOperations.size());
470   hash = llvm::hash_combine(hash, impl->registeredTypes.size());
471   return hash;
472 }
473 
allowsUnregisteredDialects()474 bool MLIRContext::allowsUnregisteredDialects() {
475   return impl->allowUnregisteredDialects;
476 }
477 
allowUnregisteredDialects(bool allowing)478 void MLIRContext::allowUnregisteredDialects(bool allowing) {
479   assert(impl->multiThreadedExecutionContext == 0 &&
480          "changing MLIRContext `allow-unregistered-dialects` configuration "
481          "while in a multi-threaded execution context");
482   impl->allowUnregisteredDialects = allowing;
483 }
484 
485 /// Return true if multi-threading is enabled by the context.
isMultithreadingEnabled()486 bool MLIRContext::isMultithreadingEnabled() {
487   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
488 }
489 
490 /// Set the flag specifying if multi-threading is disabled by the context.
disableMultithreading(bool disable)491 void MLIRContext::disableMultithreading(bool disable) {
492   // This API can be overridden by the global debugging flag
493   // --mlir-disable-threading
494   if (isThreadingGloballyDisabled())
495     return;
496   assert(impl->multiThreadedExecutionContext == 0 &&
497          "changing MLIRContext `disable-threading` configuration while "
498          "in a multi-threaded execution context");
499 
500   impl->threadingIsEnabled = !disable;
501 
502   // Update the threading mode for each of the uniquers.
503   impl->affineUniquer.disableMultithreading(disable);
504   impl->attributeUniquer.disableMultithreading(disable);
505   impl->typeUniquer.disableMultithreading(disable);
506 
507   // Destroy thread pool (stop all threads) if it is no longer needed, or create
508   // a new one if multithreading was re-enabled.
509   if (disable) {
510     // If the thread pool is owned, explicitly set it to nullptr to avoid
511     // keeping a dangling pointer around. If the thread pool is externally
512     // owned, we don't do anything.
513     if (impl->ownedThreadPool) {
514       assert(impl->threadPool);
515       impl->threadPool = nullptr;
516       impl->ownedThreadPool.reset();
517     }
518   } else if (!impl->threadPool) {
519     // The thread pool isn't externally provided.
520     assert(!impl->ownedThreadPool);
521     impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
522     impl->threadPool = impl->ownedThreadPool.get();
523   }
524 }
525 
setThreadPool(llvm::ThreadPool & pool)526 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
527   assert(!isMultithreadingEnabled() &&
528          "expected multi-threading to be disabled when setting a ThreadPool");
529   impl->threadPool = &pool;
530   impl->ownedThreadPool.reset();
531   enableMultithreading();
532 }
533 
getNumThreads()534 unsigned MLIRContext::getNumThreads() {
535   if (isMultithreadingEnabled()) {
536     assert(impl->threadPool &&
537            "multi-threading is enabled but threadpool not set");
538     return impl->threadPool->getThreadCount();
539   }
540   // No multithreading or active thread pool. Return 1 thread.
541   return 1;
542 }
543 
getThreadPool()544 llvm::ThreadPool &MLIRContext::getThreadPool() {
545   assert(isMultithreadingEnabled() &&
546          "expected multi-threading to be enabled within the context");
547   assert(impl->threadPool &&
548          "multi-threading is enabled but threadpool not set");
549   return *impl->threadPool;
550 }
551 
enterMultiThreadedExecution()552 void MLIRContext::enterMultiThreadedExecution() {
553 #ifndef NDEBUG
554   ++impl->multiThreadedExecutionContext;
555 #endif
556 }
exitMultiThreadedExecution()557 void MLIRContext::exitMultiThreadedExecution() {
558 #ifndef NDEBUG
559   --impl->multiThreadedExecutionContext;
560 #endif
561 }
562 
563 /// Return true if we should attach the operation to diagnostics emitted via
564 /// Operation::emit.
shouldPrintOpOnDiagnostic()565 bool MLIRContext::shouldPrintOpOnDiagnostic() {
566   return impl->printOpOnDiagnostic;
567 }
568 
569 /// Set the flag specifying if we should attach the operation to diagnostics
570 /// emitted via Operation::emit.
printOpOnDiagnostic(bool enable)571 void MLIRContext::printOpOnDiagnostic(bool enable) {
572   assert(impl->multiThreadedExecutionContext == 0 &&
573          "changing MLIRContext `print-op-on-diagnostic` configuration while in "
574          "a multi-threaded execution context");
575   impl->printOpOnDiagnostic = enable;
576 }
577 
578 /// Return true if we should attach the current stacktrace to diagnostics when
579 /// emitted.
shouldPrintStackTraceOnDiagnostic()580 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
581   return impl->printStackTraceOnDiagnostic;
582 }
583 
584 /// Set the flag specifying if we should attach the current stacktrace when
585 /// emitting diagnostics.
printStackTraceOnDiagnostic(bool enable)586 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
587   assert(impl->multiThreadedExecutionContext == 0 &&
588          "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
589          "while in a multi-threaded execution context");
590   impl->printStackTraceOnDiagnostic = enable;
591 }
592 
593 /// Return information about all registered operations.
getRegisteredOperations()594 ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
595   return impl->sortedRegisteredOperations;
596 }
597 
isOperationRegistered(StringRef name)598 bool MLIRContext::isOperationRegistered(StringRef name) {
599   return RegisteredOperationName::lookup(name, this).has_value();
600 }
601 
addType(TypeID typeID,AbstractType && typeInfo)602 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
603   auto &impl = context->getImpl();
604   assert(impl.multiThreadedExecutionContext == 0 &&
605          "Registering a new type kind while in a multi-threaded execution "
606          "context");
607   auto *newInfo =
608       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
609           AbstractType(std::move(typeInfo));
610   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
611     llvm::report_fatal_error("Dialect Type already registered.");
612 }
613 
addAttribute(TypeID typeID,AbstractAttribute && attrInfo)614 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
615   auto &impl = context->getImpl();
616   assert(impl.multiThreadedExecutionContext == 0 &&
617          "Registering a new attribute kind while in a multi-threaded execution "
618          "context");
619   auto *newInfo =
620       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
621           AbstractAttribute(std::move(attrInfo));
622   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
623     llvm::report_fatal_error("Dialect Attribute already registered.");
624 }
625 
626 //===----------------------------------------------------------------------===//
627 // AbstractAttribute
628 //===----------------------------------------------------------------------===//
629 
630 /// Get the dialect that registered the attribute with the provided typeid.
lookup(TypeID typeID,MLIRContext * context)631 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
632                                                    MLIRContext *context) {
633   const AbstractAttribute *abstract = lookupMutable(typeID, context);
634   if (!abstract)
635     llvm::report_fatal_error("Trying to create an Attribute that was not "
636                              "registered in this MLIRContext.");
637   return *abstract;
638 }
639 
lookupMutable(TypeID typeID,MLIRContext * context)640 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
641                                                     MLIRContext *context) {
642   auto &impl = context->getImpl();
643   auto it = impl.registeredAttributes.find(typeID);
644   if (it == impl.registeredAttributes.end())
645     return nullptr;
646   return it->second;
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // OperationName
651 //===----------------------------------------------------------------------===//
652 
/// Looks up, or creates, the OperationName::Impl entry for `name` in
/// `context`. When multi-threading is enabled, existing entries are found
/// first without a lock (registered operations) and then under a reader lock
/// (all operations). Only if both lookups miss is a writer lock taken to
/// insert a new entry. When multi-threading is disabled, no locking is done.
OperationName::OperationName(StringRef name, MLIRContext *context) {
  MLIRContextImpl &ctxImpl = context->getImpl();

  // Check for an existing name in read-only mode.
  bool isMultithreadingEnabled = context->isMultithreadingEnabled();
  if (isMultithreadingEnabled) {
    // Check the registered info map first. In the overwhelmingly common case,
    // the entry will be in here and it also removes the need to acquire any
    // locks.
    // NOTE(review): the unlocked read presumably relies on operations not
    // being registered concurrently (RegisteredOperationName::insert asserts
    // it is not called in a multi-threaded execution context).
    auto registeredIt = ctxImpl.registeredOperations.find(name);
    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
      impl = registeredIt->second.impl;
      return;
    }

    // The reader lock is released at the end of this scope, before the writer
    // lock below is taken.
    llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
    auto it = ctxImpl.operations.find(name);
    if (it != ctxImpl.operations.end()) {
      impl = &it->second;
      return;
    }
  }

  // Acquire a writer-lock so that we can safely create the new instance.
  ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);

  // `insert` leaves an existing entry untouched (it.second is false), which
  // covers another thread having inserted `name` between the reader and
  // writer locks, as well as the single-threaded path that skipped lookups.
  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
  if (it.second)
    it.first->second.name = StringAttr::get(context, name);
  impl = &it.first->second;
}
684 
getDialectNamespace() const685 StringRef OperationName::getDialectNamespace() const {
686   if (Dialect *dialect = getDialect())
687     return dialect->getNamespace();
688   return getStringRef().split('.').first;
689 }
690 
691 //===----------------------------------------------------------------------===//
692 // RegisteredOperationName
693 //===----------------------------------------------------------------------===//
694 
695 Optional<RegisteredOperationName>
lookup(StringRef name,MLIRContext * ctx)696 RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
697   auto &impl = ctx->getImpl();
698   auto it = impl.registeredOperations.find(name);
699   if (it != impl.registeredOperations.end())
700     return it->getValue();
701   return llvm::None;
702 }
703 
704 ParseResult
parseAssembly(OpAsmParser & parser,OperationState & result) const705 RegisteredOperationName::parseAssembly(OpAsmParser &parser,
706                                        OperationState &result) const {
707   return impl->parseAssemblyFn(parser, result);
708 }
709 
populateDefaultAttrs(NamedAttrList & attrs) const710 void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
711   impl->populateDefaultAttrsFn(*this, attrs);
712 }
713 
insert(StringRef name,Dialect & dialect,TypeID typeID,ParseAssemblyFn && parseAssembly,PrintAssemblyFn && printAssembly,VerifyInvariantsFn && verifyInvariants,VerifyRegionInvariantsFn && verifyRegionInvariants,FoldHookFn && foldHook,GetCanonicalizationPatternsFn && getCanonicalizationPatterns,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTrait,ArrayRef<StringRef> attrNames,PopulateDefaultAttrsFn && populateDefaultAttrs)714 void RegisteredOperationName::insert(
715     StringRef name, Dialect &dialect, TypeID typeID,
716     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
717     VerifyInvariantsFn &&verifyInvariants,
718     VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
719     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
720     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
721     ArrayRef<StringRef> attrNames,
722     PopulateDefaultAttrsFn &&populateDefaultAttrs) {
723   MLIRContext *ctx = dialect.getContext();
724   auto &ctxImpl = ctx->getImpl();
725   assert(ctxImpl.multiThreadedExecutionContext == 0 &&
726          "registering a new operation kind while in a multi-threaded execution "
727          "context");
728 
729   // Register the attribute names of this operation.
730   MutableArrayRef<StringAttr> cachedAttrNames;
731   if (!attrNames.empty()) {
732     cachedAttrNames = MutableArrayRef<StringAttr>(
733         ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
734             attrNames.size()),
735         attrNames.size());
736     for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
737       new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
738   }
739 
740   // Insert the operation info if it doesn't exist yet.
741   auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
742   if (it.second)
743     it.first->second.name = StringAttr::get(ctx, name);
744   OperationName::Impl &impl = it.first->second;
745 
746   if (impl.isRegistered()) {
747     llvm::errs() << "error: operation named '" << name
748                  << "' is already registered.\n";
749     abort();
750   }
751   auto emplaced = ctxImpl.registeredOperations.try_emplace(
752       name, RegisteredOperationName(&impl));
753   assert(emplaced.second && "operation name registration must be successful");
754 
755   // Add emplaced operation name to the sorted operations container.
756   RegisteredOperationName &value = emplaced.first->getValue();
757   ctxImpl.sortedRegisteredOperations.insert(
758       llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
759                         [](auto &lhs, auto &rhs) {
760                           return lhs.getIdentifier().compare(
761                               rhs.getIdentifier());
762                         }),
763       value);
764 
765   // Update the registered info for this operation.
766   impl.dialect = &dialect;
767   impl.typeID = typeID;
768   impl.interfaceMap = std::move(interfaceMap);
769   impl.foldHookFn = std::move(foldHook);
770   impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
771   impl.hasTraitFn = std::move(hasTrait);
772   impl.parseAssemblyFn = std::move(parseAssembly);
773   impl.printAssemblyFn = std::move(printAssembly);
774   impl.verifyInvariantsFn = std::move(verifyInvariants);
775   impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
776   impl.attributeNames = cachedAttrNames;
777   impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
778 }
779 
780 //===----------------------------------------------------------------------===//
781 // AbstractType
782 //===----------------------------------------------------------------------===//
783 
lookup(TypeID typeID,MLIRContext * context)784 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
785   const AbstractType *type = lookupMutable(typeID, context);
786   if (!type)
787     llvm::report_fatal_error(
788         "Trying to create a Type that was not registered in this MLIRContext.");
789   return *type;
790 }
791 
lookupMutable(TypeID typeID,MLIRContext * context)792 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
793   auto &impl = context->getImpl();
794   auto it = impl.registeredTypes.find(typeID);
795   if (it == impl.registeredTypes.end())
796     return nullptr;
797   return it->second;
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // Type uniquing
802 //===----------------------------------------------------------------------===//
803 
804 /// Returns the storage uniquer used for constructing type storage instances.
805 /// This should not be used directly.
getTypeUniquer()806 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
807 
get(MLIRContext * context)808 BFloat16Type BFloat16Type::get(MLIRContext *context) {
809   return context->getImpl().bf16Ty;
810 }
get(MLIRContext * context)811 Float16Type Float16Type::get(MLIRContext *context) {
812   return context->getImpl().f16Ty;
813 }
get(MLIRContext * context)814 Float32Type Float32Type::get(MLIRContext *context) {
815   return context->getImpl().f32Ty;
816 }
get(MLIRContext * context)817 Float64Type Float64Type::get(MLIRContext *context) {
818   return context->getImpl().f64Ty;
819 }
get(MLIRContext * context)820 Float80Type Float80Type::get(MLIRContext *context) {
821   return context->getImpl().f80Ty;
822 }
get(MLIRContext * context)823 Float128Type Float128Type::get(MLIRContext *context) {
824   return context->getImpl().f128Ty;
825 }
826 
827 /// Get an instance of the IndexType.
get(MLIRContext * context)828 IndexType IndexType::get(MLIRContext *context) {
829   return context->getImpl().indexTy;
830 }
831 
832 /// Return an existing integer type instance if one is cached within the
833 /// context.
834 static IntegerType
getCachedIntegerType(unsigned width,IntegerType::SignednessSemantics signedness,MLIRContext * context)835 getCachedIntegerType(unsigned width,
836                      IntegerType::SignednessSemantics signedness,
837                      MLIRContext *context) {
838   if (signedness != IntegerType::Signless)
839     return IntegerType();
840 
841   switch (width) {
842   case 1:
843     return context->getImpl().int1Ty;
844   case 8:
845     return context->getImpl().int8Ty;
846   case 16:
847     return context->getImpl().int16Ty;
848   case 32:
849     return context->getImpl().int32Ty;
850   case 64:
851     return context->getImpl().int64Ty;
852   case 128:
853     return context->getImpl().int128Ty;
854   default:
855     return IntegerType();
856   }
857 }
858 
get(MLIRContext * context,unsigned width,IntegerType::SignednessSemantics signedness)859 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
860                              IntegerType::SignednessSemantics signedness) {
861   if (auto cached = getCachedIntegerType(width, signedness, context))
862     return cached;
863   return Base::get(context, width, signedness);
864 }
865 
866 IntegerType
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,unsigned width,SignednessSemantics signedness)867 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
868                         MLIRContext *context, unsigned width,
869                         SignednessSemantics signedness) {
870   if (auto cached = getCachedIntegerType(width, signedness, context))
871     return cached;
872   return Base::getChecked(emitError, context, width, signedness);
873 }
874 
875 /// Get an instance of the NoneType.
get(MLIRContext * context)876 NoneType NoneType::get(MLIRContext *context) {
877   if (NoneType cachedInst = context->getImpl().noneType)
878     return cachedInst;
879   // Note: May happen when initializing the singleton attributes of the builtin
880   // dialect.
881   return Base::get(context);
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // Attribute uniquing
886 //===----------------------------------------------------------------------===//
887 
888 /// Returns the storage uniquer used for constructing attribute storage
889 /// instances. This should not be used directly.
getAttributeUniquer()890 StorageUniquer &MLIRContext::getAttributeUniquer() {
891   return getImpl().attributeUniquer;
892 }
893 
894 /// Initialize the given attribute storage instance.
initializeAttributeStorage(AttributeStorage * storage,MLIRContext * ctx,TypeID attrID)895 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
896                                                   MLIRContext *ctx,
897                                                   TypeID attrID) {
898   storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
899 
900   // If the attribute did not provide a type, then default to NoneType.
901   if (!storage->getType())
902     storage->setType(NoneType::get(ctx));
903 }
904 
get(MLIRContext * context,bool value)905 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
906   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
907 }
908 
get(MLIRContext * context)909 UnitAttr UnitAttr::get(MLIRContext *context) {
910   return context->getImpl().unitAttr;
911 }
912 
get(MLIRContext * context)913 UnknownLoc UnknownLoc::get(MLIRContext *context) {
914   return context->getImpl().unknownLocAttr;
915 }
916 
917 /// Return empty dictionary.
getEmpty(MLIRContext * context)918 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
919   return context->getImpl().emptyDictionaryAttr;
920 }
921 
initialize(MLIRContext * context)922 void StringAttrStorage::initialize(MLIRContext *context) {
923   // Check for a dialect namespace prefix, if there isn't one we don't need to
924   // do any additional initialization.
925   auto dialectNamePair = value.split('.');
926   if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
927     return;
928 
929   // If one exists, we check to see if this dialect is loaded. If it is, we set
930   // the dialect now, if it isn't we record this storage for initialization
931   // later if the dialect ever gets loaded.
932   if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
933     return;
934 
935   MLIRContextImpl &impl = context->getImpl();
936   llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
937   impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
938 }
939 
940 /// Return an empty string.
get(MLIRContext * context)941 StringAttr StringAttr::get(MLIRContext *context) {
942   return context->getImpl().emptyStringAttr;
943 }
944 
945 //===----------------------------------------------------------------------===//
946 // AffineMap uniquing
947 //===----------------------------------------------------------------------===//
948 
getAffineUniquer()949 StorageUniquer &MLIRContext::getAffineUniquer() {
950   return getImpl().affineUniquer;
951 }
952 
getImpl(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results,MLIRContext * context)953 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
954                              ArrayRef<AffineExpr> results,
955                              MLIRContext *context) {
956   auto &impl = context->getImpl();
957   auto *storage = impl.affineUniquer.get<AffineMapStorage>(
958       [&](AffineMapStorage *storage) { storage->context = context; }, dimCount,
959       symbolCount, results);
960   return AffineMap(storage);
961 }
962 
963 /// Check whether the arguments passed to the AffineMap::get() are consistent.
964 /// This method checks whether the highest index of dimensional identifier
965 /// present in result expressions is less than `dimCount` and the highest index
966 /// of symbolic identifier present in result expressions is less than
967 /// `symbolCount`.
968 LLVM_ATTRIBUTE_UNUSED static bool
willBeValidAffineMap(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results)969 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
970                      ArrayRef<AffineExpr> results) {
971   int64_t maxDimPosition = -1;
972   int64_t maxSymbolPosition = -1;
973   getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
974                      maxSymbolPosition);
975   if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
976     LLVM_DEBUG(
977         llvm::dbgs()
978         << "maximum dimensional identifier position in result expression must "
979            "be less than `dimCount` and maximum symbolic identifier position "
980            "in result expression must be less than `symbolCount`\n");
981     return false;
982   }
983   return true;
984 }
985 
get(MLIRContext * context)986 AffineMap AffineMap::get(MLIRContext *context) {
987   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
988 }
989 
get(unsigned dimCount,unsigned symbolCount,MLIRContext * context)990 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
991                          MLIRContext *context) {
992   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
993 }
994 
get(unsigned dimCount,unsigned symbolCount,AffineExpr result)995 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
996                          AffineExpr result) {
997   assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
998   return getImpl(dimCount, symbolCount, {result}, result.getContext());
999 }
1000 
get(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results,MLIRContext * context)1001 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1002                          ArrayRef<AffineExpr> results, MLIRContext *context) {
1003   assert(willBeValidAffineMap(dimCount, symbolCount, results));
1004   return getImpl(dimCount, symbolCount, results, context);
1005 }
1006 
1007 //===----------------------------------------------------------------------===//
// Integer Sets: like AffineMaps, these are uniqued in the context's affine
// storage uniquer and are immutable.
1010 //===----------------------------------------------------------------------===//
1011 
get(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> constraints,ArrayRef<bool> eqFlags)1012 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
1013                            ArrayRef<AffineExpr> constraints,
1014                            ArrayRef<bool> eqFlags) {
1015   // The number of constraints can't be zero.
1016   assert(!constraints.empty());
1017   assert(constraints.size() == eqFlags.size());
1018 
1019   auto &impl = constraints[0].getContext()->getImpl();
1020   auto *storage = impl.affineUniquer.get<IntegerSetStorage>(
1021       [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags);
1022   return IntegerSet(storage);
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // StorageUniquerSupport
1027 //===----------------------------------------------------------------------===//
1028 
1029 /// Utility method to generate a callback that can be used to generate a
1030 /// diagnostic when checking the construction invariants of a storage object.
1031 /// This is defined out-of-line to avoid the need to include Location.h.
1032 llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(MLIRContext * ctx)1033 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1034   return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1035 }
1036 llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(const Location & loc)1037 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1038   return [=] { return emitError(loc); };
1039 }
1040