1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/Identifier.h"
22 #include "mlir/IR/IntegerSet.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/IR/Types.h"
26 #include "mlir/Support/DebugAction.h"
27 #include "mlir/Support/ThreadLocalCache.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/DenseSet.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/StringSet.h"
33 #include "llvm/ADT/Twine.h"
34 #include "llvm/Support/Allocator.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/RWMutex.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40 
41 #define DEBUG_TYPE "mlircontext"
42 
43 using namespace mlir;
44 using namespace mlir::detail;
45 
46 using llvm::hash_combine;
47 using llvm::hash_combine_range;
48 
49 //===----------------------------------------------------------------------===//
50 // MLIRContext CommandLine Options
51 //===----------------------------------------------------------------------===//
52 
namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options: the options only exist once an instance of
/// this struct has been constructed (see `clOptions` below).
struct MLIRContextOptions {
  /// `-mlir-disable-threading`: when set, contexts created afterwards start with
  /// multi-threading disabled. No `cl::init` is given, so it starts as false.
  llvm::cl::opt<bool> disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disabling multi-threading within MLIR")};

  /// `-mlir-print-op-on-diagnostic`: attach the operation to diagnostics
  /// emitted on it. Enabled by default (`cl::init(true)`).
  llvm::cl::opt<bool> printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  /// `-mlir-print-stacktrace-on-diagnostic`: attach the current stack trace to
  /// emitted diagnostics. No `cl::init` is given, so it starts as false.
  llvm::cl::opt<bool> printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // end anonymous namespace
74 
75 static llvm::ManagedStatic<MLIRContextOptions> clOptions;
76 
77 /// Register a set of useful command-line options that can be used to configure
78 /// various flags within the MLIRContext. These flags are used when constructing
79 /// an MLIR context for initialization.
80 void mlir::registerMLIRContextCLOptions() {
81   // Make sure that the options struct has been initialized.
82   *clOptions;
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Locking Utilities
87 //===----------------------------------------------------------------------===//
88 
89 namespace {
90 /// Utility reader lock that takes a runtime flag that specifies if we really
91 /// need to lock.
92 struct ScopedReaderLock {
93   ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
94       : mutex(shouldLock ? &mutexParam : nullptr) {
95     if (mutex)
96       mutex->lock_shared();
97   }
98   ~ScopedReaderLock() {
99     if (mutex)
100       mutex->unlock_shared();
101   }
102   llvm::sys::SmartRWMutex<true> *mutex;
103 };
104 /// Utility writer lock that takes a runtime flag that specifies if we really
105 /// need to lock.
106 struct ScopedWriterLock {
107   ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
108       : mutex(shouldLock ? &mutexParam : nullptr) {
109     if (mutex)
110       mutex->lock();
111   }
112   ~ScopedWriterLock() {
113     if (mutex)
114       mutex->unlock();
115   }
116   llvm::sys::SmartRWMutex<true> *mutex;
117 };
118 } // end anonymous namespace.
119 
120 //===----------------------------------------------------------------------===//
121 // AffineMap and IntegerSet hashing
122 //===----------------------------------------------------------------------===//
123 
124 /// A utility function to safely get or create a uniqued instance within the
125 /// given set container.
126 template <typename ValueT, typename DenseInfoT, typename KeyT,
127           typename ConstructorFn>
128 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container,
129                               KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex,
130                               bool threadingIsEnabled,
131                               ConstructorFn &&constructorFn) {
132   // Check for an existing instance in read-only mode.
133   if (threadingIsEnabled) {
134     llvm::sys::SmartScopedReader<true> instanceLock(mutex);
135     auto it = container.find_as(key);
136     if (it != container.end())
137       return *it;
138   }
139 
140   // Acquire a writer-lock so that we can safely create the new instance.
141   ScopedWriterLock instanceLock(mutex, threadingIsEnabled);
142 
143   // Check for an existing instance again here, because another writer thread
144   // may have already created one. Otherwise, construct a new instance.
145   auto existing = container.insert_as(ValueT(), key);
146   if (existing.second)
147     return *existing.first = constructorFn();
148   return *existing.first;
149 }
150 
151 namespace {
152 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
153   // Affine maps are uniqued based on their dim/symbol counts and affine
154   // expressions.
155   using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>;
156   using DenseMapInfo<AffineMap>::isEqual;
157 
158   static unsigned getHashValue(const AffineMap &key) {
159     return getHashValue(
160         KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults()));
161   }
162 
163   static unsigned getHashValue(KeyTy key) {
164     return hash_combine(
165         std::get<0>(key), std::get<1>(key),
166         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
167   }
168 
169   static bool isEqual(const KeyTy &lhs, AffineMap rhs) {
170     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
171       return false;
172     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
173                                   rhs.getResults());
174   }
175 };
176 
177 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
178   // Integer sets are uniqued based on their dim/symbol counts, affine
179   // expressions appearing in the LHS of constraints, and eqFlags.
180   using KeyTy =
181       std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>;
182   using DenseMapInfo<IntegerSet>::isEqual;
183 
184   static unsigned getHashValue(const IntegerSet &key) {
185     return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(),
186                               key.getConstraints(), key.getEqFlags()));
187   }
188 
189   static unsigned getHashValue(KeyTy key) {
190     return hash_combine(
191         std::get<0>(key), std::get<1>(key),
192         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
193         hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
194   }
195 
196   static bool isEqual(const KeyTy &lhs, IntegerSet rhs) {
197     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
198       return false;
199     return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
200                                   rhs.getConstraints(), rhs.getEqFlags());
201   }
202 };
203 } // end anonymous namespace.
204 
205 //===----------------------------------------------------------------------===//
206 // MLIRContextImpl
207 //===----------------------------------------------------------------------===//
208 
namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
///
/// NOTE: member declaration order matters. Members are constructed top to
/// bottom and destroyed in reverse order, and `identifiers` is constructed with
/// a reference to `identifierAllocator`, which must therefore be declared
/// before it.
class MLIRContextImpl {
public:
  //===--------------------------------------------------------------------===//
  // Debugging
  //===--------------------------------------------------------------------===//

  /// An action manager for use within the context.
  DebugActionManager debugActionManager;

  //===--------------------------------------------------------------------===//
  // Identifier uniquing
  //===--------------------------------------------------------------------===//

  // Identifier allocator and mutex for thread safety. The mutex guards
  // `identifiers` when multi-threading is enabled.
  llvm::BumpPtrAllocator identifierAllocator;
  llvm::sys::SmartRWMutex<true> identifierMutex;

  //===--------------------------------------------------------------------===//
  // Diagnostics
  //===--------------------------------------------------------------------===//
  DiagnosticEngine diagEngine;

  //===--------------------------------------------------------------------===//
  // Options
  //===--------------------------------------------------------------------===//

  /// In most cases, creating an operation in an unregistered dialect is not
  /// desired and indicates a misconfiguration of the compiler. This option
  /// makes it possible to detect such use cases.
  bool allowUnregisteredDialects = false;

  /// Enable support for multi-threading within MLIR.
  bool threadingIsEnabled = true;

  /// Track if we are currently executing in a threaded execution environment
  /// (like the pass-manager): this is only a debugging feature to help reduce
  /// the chances of data races on some context APIs. The counter only exists
  /// in assertion-enabled (!NDEBUG) builds.
#ifndef NDEBUG
  std::atomic<int> multiThreadedExecutionContext{0};
#endif

  /// If the operation should be attached to diagnostics printed via the
  /// Operation::emit methods.
  bool printOpOnDiagnostic = true;

  /// If the current stack trace should be attached when emitting diagnostics.
  bool printStackTraceOnDiagnostic = false;

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//

  /// Map from dialect namespace to the dialects loaded in this context (a hash
  /// map, not kept sorted). The MLIRContext owns the objects.
  DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
  /// Dialects available for on-demand loading, together with interfaces whose
  /// registration is delayed until the corresponding dialect is loaded.
  DialectRegistry dialectsRegistry;

  /// This is a mapping from operation name to AbstractOperation for registered
  /// operations.
  llvm::StringMap<AbstractOperation> registeredOperations;

  /// Identifiers are uniqued by string value and use the internal string set
  /// for storage. Each entry maps to the loaded dialect named by the text
  /// before the first '.', or to the context when no such dialect is loaded.
  /// Entries are allocated from `identifierAllocator`.
  llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
                  llvm::BumpPtrAllocator &>
      identifiers;
  /// A thread local cache of identifiers to reduce lock contention. The cached
  /// pointers refer to entries of `identifiers`.
  ThreadLocalCache<llvm::StringMap<
      llvm::StringMapEntry<PointerUnion<Dialect *, MLIRContext *>> *>>
      localIdentifierCache;

  /// An allocator used for AbstractAttribute and AbstractType objects. Those
  /// objects are placement-constructed here and destroyed explicitly in the
  /// destructor below.
  llvm::BumpPtrAllocator abstractDialectSymbolAllocator;

  //===--------------------------------------------------------------------===//
  // Affine uniquing
  //===--------------------------------------------------------------------===//

  // Affine allocator and mutex for thread safety.
  llvm::BumpPtrAllocator affineAllocator;
  llvm::sys::SmartRWMutex<true> affineMutex;

  // Affine map uniquing.
  using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
  AffineMapSet affineMaps;

  // Integer set uniquing.
  using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
  IntegerSets integerSets;

  // Affine expression uniquing.
  StorageUniquer affineUniquer;

  //===--------------------------------------------------------------------===//
  // Type uniquing
  //===--------------------------------------------------------------------===//

  DenseMap<TypeID, const AbstractType *> registeredTypes;
  StorageUniquer typeUniquer;

  /// Cached Type Instances, populated by the MLIRContext constructor so they
  /// can be returned without locking.
  BFloat16Type bf16Ty;
  Float16Type f16Ty;
  Float32Type f32Ty;
  Float64Type f64Ty;
  Float80Type f80Ty;
  Float128Type f128Ty;
  IndexType indexTy;
  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  NoneType noneType;

  //===--------------------------------------------------------------------===//
  // Attribute uniquing
  //===--------------------------------------------------------------------===//

  DenseMap<TypeID, const AbstractAttribute *> registeredAttributes;
  StorageUniquer attributeUniquer;

  /// Cached Attribute Instances, populated by the MLIRContext constructor so
  /// they can be returned without locking.
  BoolAttr falseAttr, trueAttr;
  UnitAttr unitAttr;
  UnknownLoc unknownLocAttr;
  DictionaryAttr emptyDictionaryAttr;
  StringAttr emptyStringAttr;

public:
  MLIRContextImpl() : identifiers(identifierAllocator) {}
  ~MLIRContextImpl() {
    // The abstract type/attribute objects live in
    // `abstractDialectSymbolAllocator`. The bump allocator only releases the
    // memory, so their destructors are run explicitly here.
    for (auto typeMapping : registeredTypes)
      typeMapping.second->~AbstractType();
    for (auto attrMapping : registeredAttributes)
      attrMapping.second->~AbstractAttribute();
  }
};
} // end namespace mlir
347 
348 MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {}
349 
350 MLIRContext::MLIRContext(const DialectRegistry &registry)
351     : impl(new MLIRContextImpl) {
352   // Initialize values based on the command line flags if they were provided.
353   if (clOptions.isConstructed()) {
354     disableMultithreading(clOptions->disableThreading);
355     printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
356     printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
357   }
358 
359   // Ensure the builtin dialect is always pre-loaded.
360   getOrLoadDialect<BuiltinDialect>();
361 
362   // Pre-populate the registry.
363   registry.appendTo(impl->dialectsRegistry);
364 
365   // Initialize several common attributes and types to avoid the need to lock
366   // the context when accessing them.
367 
368   //// Types.
369   /// Floating-point Types.
370   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
371   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
372   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
373   impl->f64Ty = TypeUniquer::get<Float64Type>(this);
374   impl->f80Ty = TypeUniquer::get<Float80Type>(this);
375   impl->f128Ty = TypeUniquer::get<Float128Type>(this);
376   /// Index Type.
377   impl->indexTy = TypeUniquer::get<IndexType>(this);
378   /// Integer Types.
379   impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
380   impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
381   impl->int16Ty =
382       TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
383   impl->int32Ty =
384       TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
385   impl->int64Ty =
386       TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
387   impl->int128Ty =
388       TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
389   /// None Type.
390   impl->noneType = TypeUniquer::get<NoneType>(this);
391 
392   //// Attributes.
393   //// Note: These must be registered after the types as they may generate one
394   //// of the above types internally.
395   /// Unknown Location Attribute.
396   impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
397   /// Bool Attributes.
398   impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
399   impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
400   /// Unit Attribute.
401   impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
402   /// The empty dictionary attribute.
403   impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
404   /// The empty string attribute.
405   impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);
406 
407   // Register the affine storage objects with the uniquer.
408   impl->affineUniquer
409       .registerParametricStorageType<AffineBinaryOpExprStorage>();
410   impl->affineUniquer
411       .registerParametricStorageType<AffineConstantExprStorage>();
412   impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
413 }
414 
415 MLIRContext::~MLIRContext() {}
416 
417 /// Copy the specified array of elements into memory managed by the provided
418 /// bump pointer allocator.  This assumes the elements are all PODs.
419 template <typename T>
420 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
421                                     ArrayRef<T> elements) {
422   auto result = allocator.Allocate<T>(elements.size());
423   std::uninitialized_copy(elements.begin(), elements.end(), result);
424   return ArrayRef<T>(result, elements.size());
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // Debugging
429 //===----------------------------------------------------------------------===//
430 
431 DebugActionManager &MLIRContext::getDebugActionManager() {
432   return getImpl().debugActionManager;
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // Diagnostic Handlers
437 //===----------------------------------------------------------------------===//
438 
439 /// Returns the diagnostic engine for this context.
440 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
441 
442 //===----------------------------------------------------------------------===//
443 // Dialect and Operation Registration
444 //===----------------------------------------------------------------------===//
445 
446 void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
447   registry.appendTo(impl->dialectsRegistry);
448 
449   // For the already loaded dialects, register the interfaces immediately.
450   for (const auto &kvp : impl->loadedDialects)
451     registry.registerDelayedInterfaces(kvp.second.get());
452 }
453 
454 const DialectRegistry &MLIRContext::getDialectRegistry() {
455   return impl->dialectsRegistry;
456 }
457 
458 /// Return information about all registered IR dialects.
459 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
460   std::vector<Dialect *> result;
461   result.reserve(impl->loadedDialects.size());
462   for (auto &dialect : impl->loadedDialects)
463     result.push_back(dialect.second.get());
464   llvm::array_pod_sort(result.begin(), result.end(),
465                        [](Dialect *const *lhs, Dialect *const *rhs) -> int {
466                          return (*lhs)->getNamespace() < (*rhs)->getNamespace();
467                        });
468   return result;
469 }
470 std::vector<StringRef> MLIRContext::getAvailableDialects() {
471   std::vector<StringRef> result;
472   for (auto dialect : impl->dialectsRegistry.getDialectNames())
473     result.push_back(dialect);
474   return result;
475 }
476 
477 /// Get a registered IR dialect with the given namespace. If none is found,
478 /// then return nullptr.
479 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
480   // Dialects are sorted by name, so we can use binary search for lookup.
481   auto it = impl->loadedDialects.find(name);
482   return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
483 }
484 
485 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
486   Dialect *dialect = getLoadedDialect(name);
487   if (dialect)
488     return dialect;
489   DialectAllocatorFunctionRef allocator =
490       impl->dialectsRegistry.getDialectAllocator(name);
491   return allocator ? allocator(this) : nullptr;
492 }
493 
494 /// Get a dialect for the provided namespace and TypeID: abort the program if a
495 /// dialect exist for this namespace with different TypeID. Returns a pointer to
496 /// the dialect owned by the context.
497 Dialect *
498 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
499                               function_ref<std::unique_ptr<Dialect>()> ctor) {
500   auto &impl = getImpl();
501   // Get the correct insertion position sorted by namespace.
502   std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
503 
504   if (!dialect) {
505     LLVM_DEBUG(llvm::dbgs()
506                << "Load new dialect in Context " << dialectNamespace << "\n");
507 #ifndef NDEBUG
508     if (impl.multiThreadedExecutionContext != 0)
509       llvm::report_fatal_error(
510           "Loading a dialect (" + dialectNamespace +
511           ") while in a multi-threaded execution context (maybe "
512           "the PassManager): this can indicate a "
513           "missing `dependentDialects` in a pass for example.");
514 #endif
515     dialect = ctor();
516     assert(dialect && "dialect ctor failed");
517 
518     // Refresh all the identifiers dialect field, this catches cases where a
519     // dialect may be loaded after identifier prefixed with this dialect name
520     // were already created.
521     llvm::SmallString<32> dialectPrefix(dialectNamespace);
522     dialectPrefix.push_back('.');
523     for (auto &identifierEntry : impl.identifiers)
524       if (identifierEntry.second.is<MLIRContext *>() &&
525           identifierEntry.first().startswith(dialectPrefix))
526         identifierEntry.second = dialect.get();
527 
528     // Actually register the interfaces with delayed registration.
529     impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
530     return dialect.get();
531   }
532 
533   // Abort if dialect with namespace has already been registered.
534   if (dialect->getTypeID() != dialectID)
535     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
536                              "' has already been registered");
537 
538   return dialect.get();
539 }
540 
541 void MLIRContext::loadAllAvailableDialects() {
542   for (StringRef name : getAvailableDialects())
543     getOrLoadDialect(name);
544 }
545 
546 llvm::hash_code MLIRContext::getRegistryHash() {
547   llvm::hash_code hash(0);
548   // Factor in number of loaded dialects, attributes, operations, types.
549   hash = llvm::hash_combine(hash, impl->loadedDialects.size());
550   hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
551   hash = llvm::hash_combine(hash, impl->registeredOperations.size());
552   hash = llvm::hash_combine(hash, impl->registeredTypes.size());
553   return hash;
554 }
555 
556 bool MLIRContext::allowsUnregisteredDialects() {
557   return impl->allowUnregisteredDialects;
558 }
559 
560 void MLIRContext::allowUnregisteredDialects(bool allowing) {
561   impl->allowUnregisteredDialects = allowing;
562 }
563 
564 /// Return true if multi-threading is enabled by the context.
565 bool MLIRContext::isMultithreadingEnabled() {
566   return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
567 }
568 
569 /// Set the flag specifying if multi-threading is disabled by the context.
570 void MLIRContext::disableMultithreading(bool disable) {
571   impl->threadingIsEnabled = !disable;
572 
573   // Update the threading mode for each of the uniquers.
574   impl->affineUniquer.disableMultithreading(disable);
575   impl->attributeUniquer.disableMultithreading(disable);
576   impl->typeUniquer.disableMultithreading(disable);
577 }
578 
579 void MLIRContext::enterMultiThreadedExecution() {
580 #ifndef NDEBUG
581   ++impl->multiThreadedExecutionContext;
582 #endif
583 }
584 void MLIRContext::exitMultiThreadedExecution() {
585 #ifndef NDEBUG
586   --impl->multiThreadedExecutionContext;
587 #endif
588 }
589 
590 /// Return true if we should attach the operation to diagnostics emitted via
591 /// Operation::emit.
592 bool MLIRContext::shouldPrintOpOnDiagnostic() {
593   return impl->printOpOnDiagnostic;
594 }
595 
596 /// Set the flag specifying if we should attach the operation to diagnostics
597 /// emitted via Operation::emit.
598 void MLIRContext::printOpOnDiagnostic(bool enable) {
599   impl->printOpOnDiagnostic = enable;
600 }
601 
602 /// Return true if we should attach the current stacktrace to diagnostics when
603 /// emitted.
604 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
605   return impl->printStackTraceOnDiagnostic;
606 }
607 
608 /// Set the flag specifying if we should attach the current stacktrace when
609 /// emitting diagnostics.
610 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
611   impl->printStackTraceOnDiagnostic = enable;
612 }
613 
614 /// Return information about all registered operations.  This isn't very
615 /// efficient, typically you should ask the operations about their properties
616 /// directly.
617 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
618   // We just have the operations in a non-deterministic hash table order. Dump
619   // into a temporary array, then sort it by operation name to get a stable
620   // ordering.
621   llvm::StringMap<AbstractOperation> &registeredOps =
622       impl->registeredOperations;
623 
624   std::vector<AbstractOperation *> result;
625   result.reserve(registeredOps.size());
626   for (auto &elt : registeredOps)
627     result.push_back(&elt.second);
628   llvm::array_pod_sort(
629       result.begin(), result.end(),
630       [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
631         return (*lhs)->name.compare((*rhs)->name);
632       });
633 
634   return result;
635 }
636 
637 bool MLIRContext::isOperationRegistered(StringRef name) {
638   return impl->registeredOperations.count(name);
639 }
640 
641 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
642   auto &impl = context->getImpl();
643   assert(impl.multiThreadedExecutionContext == 0 &&
644          "Registering a new type kind while in a multi-threaded execution "
645          "context");
646   auto *newInfo =
647       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
648           AbstractType(std::move(typeInfo));
649   if (!impl.registeredTypes.insert({typeID, newInfo}).second)
650     llvm::report_fatal_error("Dialect Type already registered.");
651 }
652 
653 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
654   auto &impl = context->getImpl();
655   assert(impl.multiThreadedExecutionContext == 0 &&
656          "Registering a new attribute kind while in a multi-threaded execution "
657          "context");
658   auto *newInfo =
659       new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
660           AbstractAttribute(std::move(attrInfo));
661   if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
662     llvm::report_fatal_error("Dialect Attribute already registered.");
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // AbstractAttribute
667 //===----------------------------------------------------------------------===//
668 
669 /// Get the dialect that registered the attribute with the provided typeid.
670 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
671                                                    MLIRContext *context) {
672   auto &impl = context->getImpl();
673   auto it = impl.registeredAttributes.find(typeID);
674   if (it == impl.registeredAttributes.end())
675     llvm::report_fatal_error("Trying to create an Attribute that was not "
676                              "registered in this MLIRContext.");
677   return *it->second;
678 }
679 
680 //===----------------------------------------------------------------------===//
681 // AbstractOperation
682 //===----------------------------------------------------------------------===//
683 
684 ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser,
685                                              OperationState &result) const {
686   return parseAssemblyFn(parser, result);
687 }
688 
689 /// Look up the specified operation in the operation set and return a pointer
690 /// to it if present. Otherwise, return a null pointer.
691 const AbstractOperation *AbstractOperation::lookup(StringRef opName,
692                                                    MLIRContext *context) {
693   auto &impl = context->getImpl();
694   auto it = impl.registeredOperations.find(opName);
695   if (it != impl.registeredOperations.end())
696     return &it->second;
697   return nullptr;
698 }
699 
700 void AbstractOperation::insert(
701     StringRef name, Dialect &dialect, TypeID typeID,
702     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
703     VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
704     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
705     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) {
706   AbstractOperation opInfo(name, dialect, typeID, std::move(parseAssembly),
707                            std::move(printAssembly),
708                            std::move(verifyInvariants), std::move(foldHook),
709                            std::move(getCanonicalizationPatterns),
710                            std::move(interfaceMap), std::move(hasTrait));
711 
712   auto &impl = dialect.getContext()->getImpl();
713   assert(impl.multiThreadedExecutionContext == 0 &&
714          "Registering a new operation kind while in a multi-threaded execution "
715          "context");
716   if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
717     llvm::errs() << "error: operation named '" << name
718                  << "' is already registered.\n";
719     abort();
720   }
721 }
722 
/// Construct an AbstractOperation. `name` is interned as an Identifier in the
/// dialect's context, and the hooks and interface map are moved in.
/// NOTE: members are always initialized in their declaration order (in the
/// class definition, not shown here), regardless of the order written below.
/// Keep this list in that order to avoid -Wreorder warnings.
AbstractOperation::AbstractOperation(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait)
    : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
      typeID(typeID), interfaceMap(std::move(interfaceMap)),
      foldHookFn(std::move(foldHook)),
      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
      hasTraitFn(std::move(hasTrait)),
      parseAssemblyFn(std::move(parseAssembly)),
      printAssemblyFn(std::move(printAssembly)),
      verifyInvariantsFn(std::move(verifyInvariants)) {}
737 
738 //===----------------------------------------------------------------------===//
739 // AbstractType
740 //===----------------------------------------------------------------------===//
741 
742 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
743   auto &impl = context->getImpl();
744   auto it = impl.registeredTypes.find(typeID);
745   if (it == impl.registeredTypes.end())
746     llvm::report_fatal_error(
747         "Trying to create a Type that was not registered in this MLIRContext.");
748   return *it->second;
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // Identifier uniquing
753 //===----------------------------------------------------------------------===//
754 
/// Return an identifier for the specified string.
///
/// The string is interned in the context's identifier table, so repeated calls
/// with equal strings yield the same table entry. A newly created entry
/// records the loaded dialect whose name equals the text before the first '.'
/// of the string. If there is no such loaded dialect, it records the context
/// itself.
///
/// When multithreading is disabled, the table is accessed directly without
/// locking. Otherwise the lookup proceeds in three steps:
///   1. this thread's local cache;
///   2. the shared table under a reader lock;
///   3. insertion into the shared table under a writer lock.
Identifier Identifier::get(const Twine &string, MLIRContext *context) {
  // Flatten the Twine; `tempStr` provides backing storage when needed.
  SmallString<32> tempStr;
  StringRef str = string.toStringRef(tempStr);

  // Check invariants up front. These asserts run on every call, before any
  // lookup in the identifier table (not only for newly created entries).
  assert(!str.empty() && "Cannot create an empty identifier");
  assert(str.find('\0') == StringRef::npos &&
         "Cannot create an identifier with a nul character");

  // Compute the owner recorded for a new entry. It is invoked only on the
  // paths that may create an entry.
  // NOTE(review): StringRef::split returns the whole string as `.first` when
  // there is no '.', so a dot-less identifier equal to a loaded dialect's
  // name is associated with that dialect. Confirm that this is intended.
  auto getDialectOrContext = [&]() {
    PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
    auto dialectNamePair = str.split('.');
    if (!dialectNamePair.first.empty())
      if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
        dialectOrContext = dialect;
    return dialectOrContext;
  };

  auto &impl = context->getImpl();
  if (!context->isMultithreadingEnabled()) {
    // Single-threaded: insert a placeholder and fill in the owner only if the
    // entry was newly created.
    auto insertedIt = impl.identifiers.insert({str, nullptr});
    if (insertedIt.second)
      insertedIt.first->second = getDialectOrContext();
    return Identifier(&*insertedIt.first);
  }

  // Check for an existing instance in the local cache. `localEntry` is a
  // reference to the cache slot for `str`, so assigning to it below also
  // populates this thread's cache.
  auto *&localEntry = (*impl.localIdentifierCache)[str];
  if (localEntry)
    return Identifier(localEntry);

  // Check for an existing identifier in read-only mode.
  {
    llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
    auto it = impl.identifiers.find(str);
    if (it != impl.identifiers.end()) {
      localEntry = &*it;
      return Identifier(localEntry);
    }
  }

  // Acquire a writer-lock so that we can safely create the new instance.
  // Another thread may have inserted `str` after the reader lock was
  // released. insert() does not overwrite an existing key, and `.first`
  // refers to whichever entry is in the table.
  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
  auto it = impl.identifiers.insert({str, getDialectOrContext()}).first;
  localEntry = &*it;
  return Identifier(localEntry);
}
805 
806 Dialect *Identifier::getDialect() {
807   return entry->second.dyn_cast<Dialect *>();
808 }
809 
810 MLIRContext *Identifier::getContext() {
811   if (Dialect *dialect = getDialect())
812     return dialect->getContext();
813   return entry->second.get<MLIRContext *>();
814 }
815 
816 //===----------------------------------------------------------------------===//
817 // Type uniquing
818 //===----------------------------------------------------------------------===//
819 
820 /// Returns the storage uniquer used for constructing type storage instances.
821 /// This should not be used directly.
822 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
823 
824 BFloat16Type BFloat16Type::get(MLIRContext *context) {
825   return context->getImpl().bf16Ty;
826 }
827 Float16Type Float16Type::get(MLIRContext *context) {
828   return context->getImpl().f16Ty;
829 }
830 Float32Type Float32Type::get(MLIRContext *context) {
831   return context->getImpl().f32Ty;
832 }
833 Float64Type Float64Type::get(MLIRContext *context) {
834   return context->getImpl().f64Ty;
835 }
836 Float80Type Float80Type::get(MLIRContext *context) {
837   return context->getImpl().f80Ty;
838 }
839 Float128Type Float128Type::get(MLIRContext *context) {
840   return context->getImpl().f128Ty;
841 }
842 
843 /// Get an instance of the IndexType.
844 IndexType IndexType::get(MLIRContext *context) {
845   return context->getImpl().indexTy;
846 }
847 
848 /// Return an existing integer type instance if one is cached within the
849 /// context.
850 static IntegerType
851 getCachedIntegerType(unsigned width,
852                      IntegerType::SignednessSemantics signedness,
853                      MLIRContext *context) {
854   if (signedness != IntegerType::Signless)
855     return IntegerType();
856 
857   switch (width) {
858   case 1:
859     return context->getImpl().int1Ty;
860   case 8:
861     return context->getImpl().int8Ty;
862   case 16:
863     return context->getImpl().int16Ty;
864   case 32:
865     return context->getImpl().int32Ty;
866   case 64:
867     return context->getImpl().int64Ty;
868   case 128:
869     return context->getImpl().int128Ty;
870   default:
871     return IntegerType();
872   }
873 }
874 
875 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
876                              IntegerType::SignednessSemantics signedness) {
877   if (auto cached = getCachedIntegerType(width, signedness, context))
878     return cached;
879   return Base::get(context, width, signedness);
880 }
881 
882 IntegerType
883 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
884                         MLIRContext *context, unsigned width,
885                         SignednessSemantics signedness) {
886   if (auto cached = getCachedIntegerType(width, signedness, context))
887     return cached;
888   return Base::getChecked(emitError, context, width, signedness);
889 }
890 
891 /// Get an instance of the NoneType.
892 NoneType NoneType::get(MLIRContext *context) {
893   if (NoneType cachedInst = context->getImpl().noneType)
894     return cachedInst;
895   // Note: May happen when initializing the singleton attributes of the builtin
896   // dialect.
897   return Base::get(context);
898 }
899 
900 //===----------------------------------------------------------------------===//
901 // Attribute uniquing
902 //===----------------------------------------------------------------------===//
903 
904 /// Returns the storage uniquer used for constructing attribute storage
905 /// instances. This should not be used directly.
906 StorageUniquer &MLIRContext::getAttributeUniquer() {
907   return getImpl().attributeUniquer;
908 }
909 
910 /// Initialize the given attribute storage instance.
911 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
912                                                   MLIRContext *ctx,
913                                                   TypeID attrID) {
914   storage->initialize(AbstractAttribute::lookup(attrID, ctx));
915 
916   // If the attribute did not provide a type, then default to NoneType.
917   if (!storage->getType())
918     storage->setType(NoneType::get(ctx));
919 }
920 
921 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
922   return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
923 }
924 
925 UnitAttr UnitAttr::get(MLIRContext *context) {
926   return context->getImpl().unitAttr;
927 }
928 
929 UnknownLoc UnknownLoc::get(MLIRContext *context) {
930   return context->getImpl().unknownLocAttr;
931 }
932 
933 /// Return empty dictionary.
934 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
935   return context->getImpl().emptyDictionaryAttr;
936 }
937 
938 /// Return an empty string.
939 StringAttr StringAttr::get(MLIRContext *context) {
940   return context->getImpl().emptyStringAttr;
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // AffineMap uniquing
945 //===----------------------------------------------------------------------===//
946 
947 StorageUniquer &MLIRContext::getAffineUniquer() {
948   return getImpl().affineUniquer;
949 }
950 
951 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
952                              ArrayRef<AffineExpr> results,
953                              MLIRContext *context) {
954   auto &impl = context->getImpl();
955   auto key = std::make_tuple(dimCount, symbolCount, results);
956 
957   // Safely get or create an AffineMap instance.
958   return safeGetOrCreate(
959       impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
960         auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();
961 
962         // Copy the results into the bump pointer.
963         results = copyArrayRefInto(impl.affineAllocator, results);
964 
965         // Initialize the memory using placement new.
966         new (res)
967             detail::AffineMapStorage{dimCount, symbolCount, results, context};
968         return AffineMap(res);
969       });
970 }
971 
972 AffineMap AffineMap::get(MLIRContext *context) {
973   return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
974 }
975 
976 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
977                          MLIRContext *context) {
978   return getImpl(dimCount, symbolCount, /*results=*/{}, context);
979 }
980 
981 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
982                          AffineExpr result) {
983   return getImpl(dimCount, symbolCount, {result}, result.getContext());
984 }
985 
986 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
987                          ArrayRef<AffineExpr> results, MLIRContext *context) {
988   return getImpl(dimCount, symbolCount, results, context);
989 }
990 
991 //===----------------------------------------------------------------------===//
// Integer Sets: these are allocated in the context's bump-pointer allocator
// and are immutable. Unlike AffineMaps, they are uniqued only if they are
// small.
994 //===----------------------------------------------------------------------===//
995 
996 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
997                            ArrayRef<AffineExpr> constraints,
998                            ArrayRef<bool> eqFlags) {
999   // The number of constraints can't be zero.
1000   assert(!constraints.empty());
1001   assert(constraints.size() == eqFlags.size());
1002 
1003   auto &impl = constraints[0].getContext()->getImpl();
1004 
1005   // A utility function to construct a new IntegerSetStorage instance.
1006   auto constructorFn = [&] {
1007     auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>();
1008 
1009     // Copy the results and equality flags into the bump pointer.
1010     constraints = copyArrayRefInto(impl.affineAllocator, constraints);
1011     eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags);
1012 
1013     // Initialize the memory using placement new.
1014     new (res)
1015         detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
1016     return IntegerSet(res);
1017   };
1018 
1019   // If this instance is uniqued, then we handle it separately so that multiple
1020   // threads may simultaneously access existing instances.
1021   if (constraints.size() < IntegerSet::kUniquingThreshold) {
1022     auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
1023     return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
1024                            impl.threadingIsEnabled, constructorFn);
1025   }
1026 
1027   // Otherwise, acquire a writer-lock so that we can safely create the new
1028   // instance.
1029   ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
1030   return constructorFn();
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // StorageUniquerSupport
1035 //===----------------------------------------------------------------------===//
1036 
1037 /// Utility method to generate a callback that can be used to generate a
1038 /// diagnostic when checking the construction invariants of a storage object.
1039 /// This is defined out-of-line to avoid the need to include Location.h.
1040 llvm::unique_function<InFlightDiagnostic()>
1041 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1042   return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1043 }
1044 llvm::unique_function<InFlightDiagnostic()>
1045 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1046   return [=] { return emitError(loc); };
1047 }
1048