1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/MLIRContext.h"
10 #include "AffineExprDetail.h"
11 #include "AffineMapDetail.h"
12 #include "AttributeDetail.h"
13 #include "IntegerSetDetail.h"
14 #include "TypeDetail.h"
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinDialect.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/Location.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Support/DebugAction.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallString.h"
30 #include "llvm/ADT/StringSet.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Support/Allocator.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/Mutex.h"
36 #include "llvm/Support/RWMutex.h"
37 #include "llvm/Support/ThreadPool.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <memory>
40
41 #define DEBUG_TYPE "mlircontext"
42
43 using namespace mlir;
44 using namespace mlir::detail;
45
46 //===----------------------------------------------------------------------===//
47 // MLIRContext CommandLine Options
48 //===----------------------------------------------------------------------===//
49
50 namespace {
51 /// This struct contains command line options that can be used to initialize
52 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
53 /// for global command line options.
54 struct MLIRContextOptions {
55 llvm::cl::opt<bool> disableThreading{
56 "mlir-disable-threading",
57 llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
58 "further call to MLIRContext::enableMultiThreading()")};
59
60 llvm::cl::opt<bool> printOpOnDiagnostic{
61 "mlir-print-op-on-diagnostic",
62 llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
63 "the operation as an attached note"),
64 llvm::cl::init(true)};
65
66 llvm::cl::opt<bool> printStackTraceOnDiagnostic{
67 "mlir-print-stacktrace-on-diagnostic",
68 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
69 "as an attached note")};
70 };
71 } // namespace
72
73 static llvm::ManagedStatic<MLIRContextOptions> clOptions;
74
/// Returns true when threading must be off for every context: either LLVM was
/// built without thread support, or `--mlir-disable-threading` was given.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  // Without registered command-line options the flag cannot have been set.
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
82
83 /// Register a set of useful command-line options that can be used to configure
84 /// various flags within the MLIRContext. These flags are used when constructing
85 /// an MLIR context for initialization.
registerMLIRContextCLOptions()86 void mlir::registerMLIRContextCLOptions() {
87 // Make sure that the options struct has been initialized.
88 *clOptions;
89 }
90
91 //===----------------------------------------------------------------------===//
92 // Locking Utilities
93 //===----------------------------------------------------------------------===//
94
95 namespace {
96 /// Utility writer lock that takes a runtime flag that specifies if we really
97 /// need to lock.
98 struct ScopedWriterLock {
ScopedWriterLock__anonc2df5d7c0211::ScopedWriterLock99 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
100 : mutex(shouldLock ? &mutexParam : nullptr) {
101 if (mutex)
102 mutex->lock();
103 }
~ScopedWriterLock__anonc2df5d7c0211::ScopedWriterLock104 ~ScopedWriterLock() {
105 if (mutex)
106 mutex->unlock();
107 }
108 llvm::sys::SmartRWMutex<true> *mutex;
109 };
110 } // namespace
111
112 //===----------------------------------------------------------------------===//
113 // MLIRContextImpl
114 //===----------------------------------------------------------------------===//
115
116 namespace mlir {
117 /// This is the implementation of the MLIRContext class, using the pImpl idiom.
118 /// This class is completely private to this file, so everything is public.
119 class MLIRContextImpl {
120 public:
121 //===--------------------------------------------------------------------===//
122 // Debugging
123 //===--------------------------------------------------------------------===//
124
125 /// An action manager for use within the context.
126 DebugActionManager debugActionManager;
127
128 //===--------------------------------------------------------------------===//
129 // Diagnostics
130 //===--------------------------------------------------------------------===//
131 DiagnosticEngine diagEngine;
132
133 //===--------------------------------------------------------------------===//
134 // Options
135 //===--------------------------------------------------------------------===//
136
137 /// In most cases, creating operation in unregistered dialect is not desired
138 /// and indicate a misconfiguration of the compiler. This option enables to
139 /// detect such use cases
140 bool allowUnregisteredDialects = false;
141
142 /// Enable support for multi-threading within MLIR.
143 bool threadingIsEnabled = true;
144
145 /// Track if we are currently executing in a threaded execution environment
146 /// (like the pass-manager): this is only a debugging feature to help reducing
147 /// the chances of data races one some context APIs.
148 #ifndef NDEBUG
149 std::atomic<int> multiThreadedExecutionContext{0};
150 #endif
151
152 /// If the operation should be attached to diagnostics printed via the
153 /// Operation::emit methods.
154 bool printOpOnDiagnostic = true;
155
156 /// If the current stack trace should be attached when emitting diagnostics.
157 bool printStackTraceOnDiagnostic = false;
158
159 //===--------------------------------------------------------------------===//
160 // Other
161 //===--------------------------------------------------------------------===//
162
163 /// This points to the ThreadPool used when processing MLIR tasks in parallel.
164 /// It can't be nullptr when multi-threading is enabled. Otherwise if
165 /// multi-threading is disabled, and the threadpool wasn't externally provided
166 /// using `setThreadPool`, this will be nullptr.
167 llvm::ThreadPool *threadPool = nullptr;
168
169 /// In case where the thread pool is owned by the context, this ensures
170 /// destruction with the context.
171 std::unique_ptr<llvm::ThreadPool> ownedThreadPool;
172
173 /// This is a list of dialects that are created referring to this context.
174 /// The MLIRContext owns the objects.
175 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
176 DialectRegistry dialectsRegistry;
177
178 /// An allocator used for AbstractAttribute and AbstractType objects.
179 llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
180
181 /// This is a mapping from operation name to the operation info describing it.
182 llvm::StringMap<OperationName::Impl> operations;
183
184 /// A vector of operation info specifically for registered operations.
185 llvm::StringMap<RegisteredOperationName> registeredOperations;
186
187 /// This is a sorted container of registered operations for a deterministic
188 /// and efficient `getRegisteredOperations` implementation.
189 SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
190
191 /// A mutex used when accessing operation information.
192 llvm::sys::SmartRWMutex<true> operationInfoMutex;
193
194 //===--------------------------------------------------------------------===//
195 // Affine uniquing
196 //===--------------------------------------------------------------------===//
197
198 // Affine expression, map and integer set uniquing.
199 StorageUniquer affineUniquer;
200
201 //===--------------------------------------------------------------------===//
202 // Type uniquing
203 //===--------------------------------------------------------------------===//
204
205 DenseMap<TypeID, AbstractType *> registeredTypes;
206 StorageUniquer typeUniquer;
207
208 /// Cached Type Instances.
209 BFloat16Type bf16Ty;
210 Float16Type f16Ty;
211 Float32Type f32Ty;
212 Float64Type f64Ty;
213 Float80Type f80Ty;
214 Float128Type f128Ty;
215 IndexType indexTy;
216 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
217 NoneType noneType;
218
219 //===--------------------------------------------------------------------===//
220 // Attribute uniquing
221 //===--------------------------------------------------------------------===//
222
223 DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
224 StorageUniquer attributeUniquer;
225
226 /// Cached Attribute Instances.
227 BoolAttr falseAttr, trueAttr;
228 UnitAttr unitAttr;
229 UnknownLoc unknownLocAttr;
230 DictionaryAttr emptyDictionaryAttr;
231 StringAttr emptyStringAttr;
232
233 /// Map of string attributes that may reference a dialect, that are awaiting
234 /// that dialect to be loaded.
235 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
236 DenseMap<StringRef, SmallVector<StringAttrStorage *>>
237 dialectReferencingStrAttrs;
238
239 public:
MLIRContextImpl(bool threadingIsEnabled)240 MLIRContextImpl(bool threadingIsEnabled)
241 : threadingIsEnabled(threadingIsEnabled) {
242 if (threadingIsEnabled) {
243 ownedThreadPool = std::make_unique<llvm::ThreadPool>();
244 threadPool = ownedThreadPool.get();
245 }
246 }
~MLIRContextImpl()247 ~MLIRContextImpl() {
248 for (auto typeMapping : registeredTypes)
249 typeMapping.second->~AbstractType();
250 for (auto attrMapping : registeredAttributes)
251 attrMapping.second->~AbstractAttribute();
252 }
253 };
254 } // namespace mlir
255
MLIRContext(Threading setting)256 MLIRContext::MLIRContext(Threading setting)
257 : MLIRContext(DialectRegistry(), setting) {}
258
MLIRContext(const DialectRegistry & registry,Threading setting)259 MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting)
260 : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
261 !isThreadingGloballyDisabled())) {
262 // Initialize values based on the command line flags if they were provided.
263 if (clOptions.isConstructed()) {
264 printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
265 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
266 }
267
268 // Pre-populate the registry.
269 registry.appendTo(impl->dialectsRegistry);
270
271 // Ensure the builtin dialect is always pre-loaded.
272 getOrLoadDialect<BuiltinDialect>();
273
274 // Initialize several common attributes and types to avoid the need to lock
275 // the context when accessing them.
276
277 //// Types.
278 /// Floating-point Types.
279 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
280 impl->f16Ty = TypeUniquer::get<Float16Type>(this);
281 impl->f32Ty = TypeUniquer::get<Float32Type>(this);
282 impl->f64Ty = TypeUniquer::get<Float64Type>(this);
283 impl->f80Ty = TypeUniquer::get<Float80Type>(this);
284 impl->f128Ty = TypeUniquer::get<Float128Type>(this);
285 /// Index Type.
286 impl->indexTy = TypeUniquer::get<IndexType>(this);
287 /// Integer Types.
288 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
289 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
290 impl->int16Ty =
291 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
292 impl->int32Ty =
293 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
294 impl->int64Ty =
295 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
296 impl->int128Ty =
297 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
298 /// None Type.
299 impl->noneType = TypeUniquer::get<NoneType>(this);
300
301 //// Attributes.
302 //// Note: These must be registered after the types as they may generate one
303 //// of the above types internally.
304 /// Unknown Location Attribute.
305 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
306 /// Bool Attributes.
307 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
308 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
309 /// Unit Attribute.
310 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
311 /// The empty dictionary attribute.
312 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
313 /// The empty string attribute.
314 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);
315
316 // Register the affine storage objects with the uniquer.
317 impl->affineUniquer
318 .registerParametricStorageType<AffineBinaryOpExprStorage>();
319 impl->affineUniquer
320 .registerParametricStorageType<AffineConstantExprStorage>();
321 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
322 impl->affineUniquer.registerParametricStorageType<AffineMapStorage>();
323 impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
324 }
325
326 MLIRContext::~MLIRContext() = default;
327
328 /// Copy the specified array of elements into memory managed by the provided
329 /// bump pointer allocator. This assumes the elements are all PODs.
330 template <typename T>
copyArrayRefInto(llvm::BumpPtrAllocator & allocator,ArrayRef<T> elements)331 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
332 ArrayRef<T> elements) {
333 auto result = allocator.Allocate<T>(elements.size());
334 std::uninitialized_copy(elements.begin(), elements.end(), result);
335 return ArrayRef<T>(result, elements.size());
336 }
337
338 //===----------------------------------------------------------------------===//
339 // Debugging
340 //===----------------------------------------------------------------------===//
341
getDebugActionManager()342 DebugActionManager &MLIRContext::getDebugActionManager() {
343 return getImpl().debugActionManager;
344 }
345
346 //===----------------------------------------------------------------------===//
347 // Diagnostic Handlers
348 //===----------------------------------------------------------------------===//
349
350 /// Returns the diagnostic engine for this context.
getDiagEngine()351 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
352
353 //===----------------------------------------------------------------------===//
354 // Dialect and Operation Registration
355 //===----------------------------------------------------------------------===//
356
appendDialectRegistry(const DialectRegistry & registry)357 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) {
358 if (registry.isSubsetOf(impl->dialectsRegistry))
359 return;
360
361 assert(impl->multiThreadedExecutionContext == 0 &&
362 "appending to the MLIRContext dialect registry while in a "
363 "multi-threaded execution context");
364 registry.appendTo(impl->dialectsRegistry);
365
366 // For the already loaded dialects, apply any possible extensions immediately.
367 registry.applyExtensions(this);
368 }
369
getDialectRegistry()370 const DialectRegistry &MLIRContext::getDialectRegistry() {
371 return impl->dialectsRegistry;
372 }
373
374 /// Return information about all registered IR dialects.
getLoadedDialects()375 std::vector<Dialect *> MLIRContext::getLoadedDialects() {
376 std::vector<Dialect *> result;
377 result.reserve(impl->loadedDialects.size());
378 for (auto &dialect : impl->loadedDialects)
379 result.push_back(dialect.second.get());
380 llvm::array_pod_sort(result.begin(), result.end(),
381 [](Dialect *const *lhs, Dialect *const *rhs) -> int {
382 return (*lhs)->getNamespace() < (*rhs)->getNamespace();
383 });
384 return result;
385 }
getAvailableDialects()386 std::vector<StringRef> MLIRContext::getAvailableDialects() {
387 std::vector<StringRef> result;
388 for (auto dialect : impl->dialectsRegistry.getDialectNames())
389 result.push_back(dialect);
390 return result;
391 }
392
393 /// Get a registered IR dialect with the given namespace. If none is found,
394 /// then return nullptr.
getLoadedDialect(StringRef name)395 Dialect *MLIRContext::getLoadedDialect(StringRef name) {
396 // Dialects are sorted by name, so we can use binary search for lookup.
397 auto it = impl->loadedDialects.find(name);
398 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
399 }
400
getOrLoadDialect(StringRef name)401 Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
402 Dialect *dialect = getLoadedDialect(name);
403 if (dialect)
404 return dialect;
405 DialectAllocatorFunctionRef allocator =
406 impl->dialectsRegistry.getDialectAllocator(name);
407 return allocator ? allocator(this) : nullptr;
408 }
409
410 /// Get a dialect for the provided namespace and TypeID: abort the program if a
411 /// dialect exist for this namespace with different TypeID. Returns a pointer to
412 /// the dialect owned by the context.
413 Dialect *
getOrLoadDialect(StringRef dialectNamespace,TypeID dialectID,function_ref<std::unique_ptr<Dialect> ()> ctor)414 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
415 function_ref<std::unique_ptr<Dialect>()> ctor) {
416 auto &impl = getImpl();
417 // Get the correct insertion position sorted by namespace.
418 auto dialectIt = impl.loadedDialects.find(dialectNamespace);
419
420 if (dialectIt == impl.loadedDialects.end()) {
421 LLVM_DEBUG(llvm::dbgs()
422 << "Load new dialect in Context " << dialectNamespace << "\n");
423 #ifndef NDEBUG
424 if (impl.multiThreadedExecutionContext != 0)
425 llvm::report_fatal_error(
426 "Loading a dialect (" + dialectNamespace +
427 ") while in a multi-threaded execution context (maybe "
428 "the PassManager): this can indicate a "
429 "missing `dependentDialects` in a pass for example.");
430 #endif
431 std::unique_ptr<Dialect> &dialect =
432 impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second;
433 assert(dialect && "dialect ctor failed");
434
435 // Refresh all the identifiers dialect field, this catches cases where a
436 // dialect may be loaded after identifier prefixed with this dialect name
437 // were already created.
438 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace);
439 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
440 for (StringAttrStorage *storage : stringAttrsIt->second)
441 storage->referencedDialect = dialect.get();
442 impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
443 }
444
445 // Apply any extensions to this newly loaded dialect.
446 impl.dialectsRegistry.applyExtensions(dialect.get());
447 return dialect.get();
448 }
449
450 // Abort if dialect with namespace has already been registered.
451 std::unique_ptr<Dialect> &dialect = dialectIt->second;
452 if (dialect->getTypeID() != dialectID)
453 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
454 "' has already been registered");
455
456 return dialect.get();
457 }
458
loadAllAvailableDialects()459 void MLIRContext::loadAllAvailableDialects() {
460 for (StringRef name : getAvailableDialects())
461 getOrLoadDialect(name);
462 }
463
getRegistryHash()464 llvm::hash_code MLIRContext::getRegistryHash() {
465 llvm::hash_code hash(0);
466 // Factor in number of loaded dialects, attributes, operations, types.
467 hash = llvm::hash_combine(hash, impl->loadedDialects.size());
468 hash = llvm::hash_combine(hash, impl->registeredAttributes.size());
469 hash = llvm::hash_combine(hash, impl->registeredOperations.size());
470 hash = llvm::hash_combine(hash, impl->registeredTypes.size());
471 return hash;
472 }
473
allowsUnregisteredDialects()474 bool MLIRContext::allowsUnregisteredDialects() {
475 return impl->allowUnregisteredDialects;
476 }
477
allowUnregisteredDialects(bool allowing)478 void MLIRContext::allowUnregisteredDialects(bool allowing) {
479 assert(impl->multiThreadedExecutionContext == 0 &&
480 "changing MLIRContext `allow-unregistered-dialects` configuration "
481 "while in a multi-threaded execution context");
482 impl->allowUnregisteredDialects = allowing;
483 }
484
485 /// Return true if multi-threading is enabled by the context.
isMultithreadingEnabled()486 bool MLIRContext::isMultithreadingEnabled() {
487 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
488 }
489
490 /// Set the flag specifying if multi-threading is disabled by the context.
disableMultithreading(bool disable)491 void MLIRContext::disableMultithreading(bool disable) {
492 // This API can be overridden by the global debugging flag
493 // --mlir-disable-threading
494 if (isThreadingGloballyDisabled())
495 return;
496 assert(impl->multiThreadedExecutionContext == 0 &&
497 "changing MLIRContext `disable-threading` configuration while "
498 "in a multi-threaded execution context");
499
500 impl->threadingIsEnabled = !disable;
501
502 // Update the threading mode for each of the uniquers.
503 impl->affineUniquer.disableMultithreading(disable);
504 impl->attributeUniquer.disableMultithreading(disable);
505 impl->typeUniquer.disableMultithreading(disable);
506
507 // Destroy thread pool (stop all threads) if it is no longer needed, or create
508 // a new one if multithreading was re-enabled.
509 if (disable) {
510 // If the thread pool is owned, explicitly set it to nullptr to avoid
511 // keeping a dangling pointer around. If the thread pool is externally
512 // owned, we don't do anything.
513 if (impl->ownedThreadPool) {
514 assert(impl->threadPool);
515 impl->threadPool = nullptr;
516 impl->ownedThreadPool.reset();
517 }
518 } else if (!impl->threadPool) {
519 // The thread pool isn't externally provided.
520 assert(!impl->ownedThreadPool);
521 impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>();
522 impl->threadPool = impl->ownedThreadPool.get();
523 }
524 }
525
setThreadPool(llvm::ThreadPool & pool)526 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
527 assert(!isMultithreadingEnabled() &&
528 "expected multi-threading to be disabled when setting a ThreadPool");
529 impl->threadPool = &pool;
530 impl->ownedThreadPool.reset();
531 enableMultithreading();
532 }
533
getNumThreads()534 unsigned MLIRContext::getNumThreads() {
535 if (isMultithreadingEnabled()) {
536 assert(impl->threadPool &&
537 "multi-threading is enabled but threadpool not set");
538 return impl->threadPool->getThreadCount();
539 }
540 // No multithreading or active thread pool. Return 1 thread.
541 return 1;
542 }
543
getThreadPool()544 llvm::ThreadPool &MLIRContext::getThreadPool() {
545 assert(isMultithreadingEnabled() &&
546 "expected multi-threading to be enabled within the context");
547 assert(impl->threadPool &&
548 "multi-threading is enabled but threadpool not set");
549 return *impl->threadPool;
550 }
551
enterMultiThreadedExecution()552 void MLIRContext::enterMultiThreadedExecution() {
553 #ifndef NDEBUG
554 ++impl->multiThreadedExecutionContext;
555 #endif
556 }
exitMultiThreadedExecution()557 void MLIRContext::exitMultiThreadedExecution() {
558 #ifndef NDEBUG
559 --impl->multiThreadedExecutionContext;
560 #endif
561 }
562
563 /// Return true if we should attach the operation to diagnostics emitted via
564 /// Operation::emit.
shouldPrintOpOnDiagnostic()565 bool MLIRContext::shouldPrintOpOnDiagnostic() {
566 return impl->printOpOnDiagnostic;
567 }
568
569 /// Set the flag specifying if we should attach the operation to diagnostics
570 /// emitted via Operation::emit.
printOpOnDiagnostic(bool enable)571 void MLIRContext::printOpOnDiagnostic(bool enable) {
572 assert(impl->multiThreadedExecutionContext == 0 &&
573 "changing MLIRContext `print-op-on-diagnostic` configuration while in "
574 "a multi-threaded execution context");
575 impl->printOpOnDiagnostic = enable;
576 }
577
578 /// Return true if we should attach the current stacktrace to diagnostics when
579 /// emitted.
shouldPrintStackTraceOnDiagnostic()580 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
581 return impl->printStackTraceOnDiagnostic;
582 }
583
584 /// Set the flag specifying if we should attach the current stacktrace when
585 /// emitting diagnostics.
printStackTraceOnDiagnostic(bool enable)586 void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
587 assert(impl->multiThreadedExecutionContext == 0 &&
588 "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
589 "while in a multi-threaded execution context");
590 impl->printStackTraceOnDiagnostic = enable;
591 }
592
593 /// Return information about all registered operations.
getRegisteredOperations()594 ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
595 return impl->sortedRegisteredOperations;
596 }
597
isOperationRegistered(StringRef name)598 bool MLIRContext::isOperationRegistered(StringRef name) {
599 return RegisteredOperationName::lookup(name, this).has_value();
600 }
601
addType(TypeID typeID,AbstractType && typeInfo)602 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
603 auto &impl = context->getImpl();
604 assert(impl.multiThreadedExecutionContext == 0 &&
605 "Registering a new type kind while in a multi-threaded execution "
606 "context");
607 auto *newInfo =
608 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
609 AbstractType(std::move(typeInfo));
610 if (!impl.registeredTypes.insert({typeID, newInfo}).second)
611 llvm::report_fatal_error("Dialect Type already registered.");
612 }
613
addAttribute(TypeID typeID,AbstractAttribute && attrInfo)614 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
615 auto &impl = context->getImpl();
616 assert(impl.multiThreadedExecutionContext == 0 &&
617 "Registering a new attribute kind while in a multi-threaded execution "
618 "context");
619 auto *newInfo =
620 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
621 AbstractAttribute(std::move(attrInfo));
622 if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
623 llvm::report_fatal_error("Dialect Attribute already registered.");
624 }
625
626 //===----------------------------------------------------------------------===//
627 // AbstractAttribute
628 //===----------------------------------------------------------------------===//
629
630 /// Get the dialect that registered the attribute with the provided typeid.
lookup(TypeID typeID,MLIRContext * context)631 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
632 MLIRContext *context) {
633 const AbstractAttribute *abstract = lookupMutable(typeID, context);
634 if (!abstract)
635 llvm::report_fatal_error("Trying to create an Attribute that was not "
636 "registered in this MLIRContext.");
637 return *abstract;
638 }
639
lookupMutable(TypeID typeID,MLIRContext * context)640 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
641 MLIRContext *context) {
642 auto &impl = context->getImpl();
643 auto it = impl.registeredAttributes.find(typeID);
644 if (it == impl.registeredAttributes.end())
645 return nullptr;
646 return it->second;
647 }
648
649 //===----------------------------------------------------------------------===//
650 // OperationName
651 //===----------------------------------------------------------------------===//
652
OperationName(StringRef name,MLIRContext * context)653 OperationName::OperationName(StringRef name, MLIRContext *context) {
654 MLIRContextImpl &ctxImpl = context->getImpl();
655
656 // Check for an existing name in read-only mode.
657 bool isMultithreadingEnabled = context->isMultithreadingEnabled();
658 if (isMultithreadingEnabled) {
659 // Check the registered info map first. In the overwhelmingly common case,
660 // the entry will be in here and it also removes the need to acquire any
661 // locks.
662 auto registeredIt = ctxImpl.registeredOperations.find(name);
663 if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
664 impl = registeredIt->second.impl;
665 return;
666 }
667
668 llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
669 auto it = ctxImpl.operations.find(name);
670 if (it != ctxImpl.operations.end()) {
671 impl = &it->second;
672 return;
673 }
674 }
675
676 // Acquire a writer-lock so that we can safely create the new instance.
677 ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
678
679 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
680 if (it.second)
681 it.first->second.name = StringAttr::get(context, name);
682 impl = &it.first->second;
683 }
684
getDialectNamespace() const685 StringRef OperationName::getDialectNamespace() const {
686 if (Dialect *dialect = getDialect())
687 return dialect->getNamespace();
688 return getStringRef().split('.').first;
689 }
690
691 //===----------------------------------------------------------------------===//
692 // RegisteredOperationName
693 //===----------------------------------------------------------------------===//
694
695 Optional<RegisteredOperationName>
lookup(StringRef name,MLIRContext * ctx)696 RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
697 auto &impl = ctx->getImpl();
698 auto it = impl.registeredOperations.find(name);
699 if (it != impl.registeredOperations.end())
700 return it->getValue();
701 return llvm::None;
702 }
703
704 ParseResult
parseAssembly(OpAsmParser & parser,OperationState & result) const705 RegisteredOperationName::parseAssembly(OpAsmParser &parser,
706 OperationState &result) const {
707 return impl->parseAssemblyFn(parser, result);
708 }
709
populateDefaultAttrs(NamedAttrList & attrs) const710 void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
711 impl->populateDefaultAttrsFn(*this, attrs);
712 }
713
insert(StringRef name,Dialect & dialect,TypeID typeID,ParseAssemblyFn && parseAssembly,PrintAssemblyFn && printAssembly,VerifyInvariantsFn && verifyInvariants,VerifyRegionInvariantsFn && verifyRegionInvariants,FoldHookFn && foldHook,GetCanonicalizationPatternsFn && getCanonicalizationPatterns,detail::InterfaceMap && interfaceMap,HasTraitFn && hasTrait,ArrayRef<StringRef> attrNames,PopulateDefaultAttrsFn && populateDefaultAttrs)714 void RegisteredOperationName::insert(
715 StringRef name, Dialect &dialect, TypeID typeID,
716 ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
717 VerifyInvariantsFn &&verifyInvariants,
718 VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
719 GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
720 detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
721 ArrayRef<StringRef> attrNames,
722 PopulateDefaultAttrsFn &&populateDefaultAttrs) {
723 MLIRContext *ctx = dialect.getContext();
724 auto &ctxImpl = ctx->getImpl();
725 assert(ctxImpl.multiThreadedExecutionContext == 0 &&
726 "registering a new operation kind while in a multi-threaded execution "
727 "context");
728
729 // Register the attribute names of this operation.
730 MutableArrayRef<StringAttr> cachedAttrNames;
731 if (!attrNames.empty()) {
732 cachedAttrNames = MutableArrayRef<StringAttr>(
733 ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
734 attrNames.size()),
735 attrNames.size());
736 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
737 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
738 }
739
740 // Insert the operation info if it doesn't exist yet.
741 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
742 if (it.second)
743 it.first->second.name = StringAttr::get(ctx, name);
744 OperationName::Impl &impl = it.first->second;
745
746 if (impl.isRegistered()) {
747 llvm::errs() << "error: operation named '" << name
748 << "' is already registered.\n";
749 abort();
750 }
751 auto emplaced = ctxImpl.registeredOperations.try_emplace(
752 name, RegisteredOperationName(&impl));
753 assert(emplaced.second && "operation name registration must be successful");
754
755 // Add emplaced operation name to the sorted operations container.
756 RegisteredOperationName &value = emplaced.first->getValue();
757 ctxImpl.sortedRegisteredOperations.insert(
758 llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
759 [](auto &lhs, auto &rhs) {
760 return lhs.getIdentifier().compare(
761 rhs.getIdentifier());
762 }),
763 value);
764
765 // Update the registered info for this operation.
766 impl.dialect = &dialect;
767 impl.typeID = typeID;
768 impl.interfaceMap = std::move(interfaceMap);
769 impl.foldHookFn = std::move(foldHook);
770 impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
771 impl.hasTraitFn = std::move(hasTrait);
772 impl.parseAssemblyFn = std::move(parseAssembly);
773 impl.printAssemblyFn = std::move(printAssembly);
774 impl.verifyInvariantsFn = std::move(verifyInvariants);
775 impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
776 impl.attributeNames = cachedAttrNames;
777 impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
778 }
779
780 //===----------------------------------------------------------------------===//
781 // AbstractType
782 //===----------------------------------------------------------------------===//
783
lookup(TypeID typeID,MLIRContext * context)784 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
785 const AbstractType *type = lookupMutable(typeID, context);
786 if (!type)
787 llvm::report_fatal_error(
788 "Trying to create a Type that was not registered in this MLIRContext.");
789 return *type;
790 }
791
lookupMutable(TypeID typeID,MLIRContext * context)792 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
793 auto &impl = context->getImpl();
794 auto it = impl.registeredTypes.find(typeID);
795 if (it == impl.registeredTypes.end())
796 return nullptr;
797 return it->second;
798 }
799
800 //===----------------------------------------------------------------------===//
801 // Type uniquing
802 //===----------------------------------------------------------------------===//
803
804 /// Returns the storage uniquer used for constructing type storage instances.
805 /// This should not be used directly.
getTypeUniquer()806 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
807
get(MLIRContext * context)808 BFloat16Type BFloat16Type::get(MLIRContext *context) {
809 return context->getImpl().bf16Ty;
810 }
get(MLIRContext * context)811 Float16Type Float16Type::get(MLIRContext *context) {
812 return context->getImpl().f16Ty;
813 }
get(MLIRContext * context)814 Float32Type Float32Type::get(MLIRContext *context) {
815 return context->getImpl().f32Ty;
816 }
get(MLIRContext * context)817 Float64Type Float64Type::get(MLIRContext *context) {
818 return context->getImpl().f64Ty;
819 }
get(MLIRContext * context)820 Float80Type Float80Type::get(MLIRContext *context) {
821 return context->getImpl().f80Ty;
822 }
get(MLIRContext * context)823 Float128Type Float128Type::get(MLIRContext *context) {
824 return context->getImpl().f128Ty;
825 }
826
827 /// Get an instance of the IndexType.
get(MLIRContext * context)828 IndexType IndexType::get(MLIRContext *context) {
829 return context->getImpl().indexTy;
830 }
831
832 /// Return an existing integer type instance if one is cached within the
833 /// context.
834 static IntegerType
getCachedIntegerType(unsigned width,IntegerType::SignednessSemantics signedness,MLIRContext * context)835 getCachedIntegerType(unsigned width,
836 IntegerType::SignednessSemantics signedness,
837 MLIRContext *context) {
838 if (signedness != IntegerType::Signless)
839 return IntegerType();
840
841 switch (width) {
842 case 1:
843 return context->getImpl().int1Ty;
844 case 8:
845 return context->getImpl().int8Ty;
846 case 16:
847 return context->getImpl().int16Ty;
848 case 32:
849 return context->getImpl().int32Ty;
850 case 64:
851 return context->getImpl().int64Ty;
852 case 128:
853 return context->getImpl().int128Ty;
854 default:
855 return IntegerType();
856 }
857 }
858
get(MLIRContext * context,unsigned width,IntegerType::SignednessSemantics signedness)859 IntegerType IntegerType::get(MLIRContext *context, unsigned width,
860 IntegerType::SignednessSemantics signedness) {
861 if (auto cached = getCachedIntegerType(width, signedness, context))
862 return cached;
863 return Base::get(context, width, signedness);
864 }
865
866 IntegerType
getChecked(function_ref<InFlightDiagnostic ()> emitError,MLIRContext * context,unsigned width,SignednessSemantics signedness)867 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
868 MLIRContext *context, unsigned width,
869 SignednessSemantics signedness) {
870 if (auto cached = getCachedIntegerType(width, signedness, context))
871 return cached;
872 return Base::getChecked(emitError, context, width, signedness);
873 }
874
875 /// Get an instance of the NoneType.
get(MLIRContext * context)876 NoneType NoneType::get(MLIRContext *context) {
877 if (NoneType cachedInst = context->getImpl().noneType)
878 return cachedInst;
879 // Note: May happen when initializing the singleton attributes of the builtin
880 // dialect.
881 return Base::get(context);
882 }
883
884 //===----------------------------------------------------------------------===//
885 // Attribute uniquing
886 //===----------------------------------------------------------------------===//
887
888 /// Returns the storage uniquer used for constructing attribute storage
889 /// instances. This should not be used directly.
getAttributeUniquer()890 StorageUniquer &MLIRContext::getAttributeUniquer() {
891 return getImpl().attributeUniquer;
892 }
893
894 /// Initialize the given attribute storage instance.
initializeAttributeStorage(AttributeStorage * storage,MLIRContext * ctx,TypeID attrID)895 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
896 MLIRContext *ctx,
897 TypeID attrID) {
898 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx));
899
900 // If the attribute did not provide a type, then default to NoneType.
901 if (!storage->getType())
902 storage->setType(NoneType::get(ctx));
903 }
904
get(MLIRContext * context,bool value)905 BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
906 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
907 }
908
get(MLIRContext * context)909 UnitAttr UnitAttr::get(MLIRContext *context) {
910 return context->getImpl().unitAttr;
911 }
912
get(MLIRContext * context)913 UnknownLoc UnknownLoc::get(MLIRContext *context) {
914 return context->getImpl().unknownLocAttr;
915 }
916
917 /// Return empty dictionary.
getEmpty(MLIRContext * context)918 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
919 return context->getImpl().emptyDictionaryAttr;
920 }
921
initialize(MLIRContext * context)922 void StringAttrStorage::initialize(MLIRContext *context) {
923 // Check for a dialect namespace prefix, if there isn't one we don't need to
924 // do any additional initialization.
925 auto dialectNamePair = value.split('.');
926 if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
927 return;
928
929 // If one exists, we check to see if this dialect is loaded. If it is, we set
930 // the dialect now, if it isn't we record this storage for initialization
931 // later if the dialect ever gets loaded.
932 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first)))
933 return;
934
935 MLIRContextImpl &impl = context->getImpl();
936 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
937 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this);
938 }
939
940 /// Return an empty string.
get(MLIRContext * context)941 StringAttr StringAttr::get(MLIRContext *context) {
942 return context->getImpl().emptyStringAttr;
943 }
944
945 //===----------------------------------------------------------------------===//
946 // AffineMap uniquing
947 //===----------------------------------------------------------------------===//
948
getAffineUniquer()949 StorageUniquer &MLIRContext::getAffineUniquer() {
950 return getImpl().affineUniquer;
951 }
952
getImpl(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results,MLIRContext * context)953 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
954 ArrayRef<AffineExpr> results,
955 MLIRContext *context) {
956 auto &impl = context->getImpl();
957 auto *storage = impl.affineUniquer.get<AffineMapStorage>(
958 [&](AffineMapStorage *storage) { storage->context = context; }, dimCount,
959 symbolCount, results);
960 return AffineMap(storage);
961 }
962
963 /// Check whether the arguments passed to the AffineMap::get() are consistent.
964 /// This method checks whether the highest index of dimensional identifier
965 /// present in result expressions is less than `dimCount` and the highest index
966 /// of symbolic identifier present in result expressions is less than
967 /// `symbolCount`.
968 LLVM_ATTRIBUTE_UNUSED static bool
willBeValidAffineMap(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results)969 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
970 ArrayRef<AffineExpr> results) {
971 int64_t maxDimPosition = -1;
972 int64_t maxSymbolPosition = -1;
973 getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition,
974 maxSymbolPosition);
975 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
976 LLVM_DEBUG(
977 llvm::dbgs()
978 << "maximum dimensional identifier position in result expression must "
979 "be less than `dimCount` and maximum symbolic identifier position "
980 "in result expression must be less than `symbolCount`\n");
981 return false;
982 }
983 return true;
984 }
985
get(MLIRContext * context)986 AffineMap AffineMap::get(MLIRContext *context) {
987 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
988 }
989
get(unsigned dimCount,unsigned symbolCount,MLIRContext * context)990 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
991 MLIRContext *context) {
992 return getImpl(dimCount, symbolCount, /*results=*/{}, context);
993 }
994
get(unsigned dimCount,unsigned symbolCount,AffineExpr result)995 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
996 AffineExpr result) {
997 assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
998 return getImpl(dimCount, symbolCount, {result}, result.getContext());
999 }
1000
get(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> results,MLIRContext * context)1001 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1002 ArrayRef<AffineExpr> results, MLIRContext *context) {
1003 assert(willBeValidAffineMap(dimCount, symbolCount, results));
1004 return getImpl(dimCount, symbolCount, results, context);
1005 }
1006
1007 //===----------------------------------------------------------------------===//
// Integer Sets: like AffineMaps, these are immutable and are always uniqued
// through the context's affine storage uniquer.
1010 //===----------------------------------------------------------------------===//
1011
get(unsigned dimCount,unsigned symbolCount,ArrayRef<AffineExpr> constraints,ArrayRef<bool> eqFlags)1012 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
1013 ArrayRef<AffineExpr> constraints,
1014 ArrayRef<bool> eqFlags) {
1015 // The number of constraints can't be zero.
1016 assert(!constraints.empty());
1017 assert(constraints.size() == eqFlags.size());
1018
1019 auto &impl = constraints[0].getContext()->getImpl();
1020 auto *storage = impl.affineUniquer.get<IntegerSetStorage>(
1021 [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags);
1022 return IntegerSet(storage);
1023 }
1024
1025 //===----------------------------------------------------------------------===//
1026 // StorageUniquerSupport
1027 //===----------------------------------------------------------------------===//
1028
1029 /// Utility method to generate a callback that can be used to generate a
1030 /// diagnostic when checking the construction invariants of a storage object.
1031 /// This is defined out-of-line to avoid the need to include Location.h.
1032 llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(MLIRContext * ctx)1033 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1034 return [ctx] { return emitError(UnknownLoc::get(ctx)); };
1035 }
1036 llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(const Location & loc)1037 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1038 return [=] { return emitError(loc); };
1039 }
1040