1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/MLIRContext.h" 10 #include "AffineExprDetail.h" 11 #include "AffineMapDetail.h" 12 #include "AttributeDetail.h" 13 #include "IntegerSetDetail.h" 14 #include "TypeDetail.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/Types.h" 25 #include "mlir/Support/DebugAction.h" 26 #include "llvm/ADT/DenseMap.h" 27 #include "llvm/ADT/DenseSet.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/ADT/SmallString.h" 30 #include "llvm/ADT/StringSet.h" 31 #include "llvm/ADT/Twine.h" 32 #include "llvm/Support/Allocator.h" 33 #include "llvm/Support/CommandLine.h" 34 #include "llvm/Support/Debug.h" 35 #include "llvm/Support/Mutex.h" 36 #include "llvm/Support/RWMutex.h" 37 #include "llvm/Support/ThreadPool.h" 38 #include "llvm/Support/raw_ostream.h" 39 #include <memory> 40 41 #define DEBUG_TYPE "mlircontext" 42 43 using namespace mlir; 44 using namespace mlir::detail; 45 46 using llvm::hash_combine; 47 using llvm::hash_combine_range; 48 49 //===----------------------------------------------------------------------===// 50 // MLIRContext CommandLine Options 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 /// This struct contains command line options that can be used to initialize 55 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need 56 /// for global command line options. 57 struct MLIRContextOptions { 58 llvm::cl::opt<bool> disableThreading{ 59 "mlir-disable-threading", 60 llvm::cl::desc("Disable multi-threading within MLIR, overrides any " 61 "further call to MLIRContext::enableMultiThreading()")}; 62 63 llvm::cl::opt<bool> printOpOnDiagnostic{ 64 "mlir-print-op-on-diagnostic", 65 llvm::cl::desc("When a diagnostic is emitted on an operation, also print " 66 "the operation as an attached note"), 67 llvm::cl::init(true)}; 68 69 llvm::cl::opt<bool> printStackTraceOnDiagnostic{ 70 "mlir-print-stacktrace-on-diagnostic", 71 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace " 72 "as an attached note")}; 73 }; 74 } // end anonymous namespace 75 76 static llvm::ManagedStatic<MLIRContextOptions> clOptions; 77 78 static bool isThreadingGloballyDisabled() { 79 #if LLVM_ENABLE_THREADS != 0 80 return clOptions.isConstructed() && clOptions->disableThreading; 81 #else 82 return true; 83 #endif 84 } 85 86 /// Register a set of useful command-line options that can be used to configure 87 /// various flags within the MLIRContext. These flags are used when constructing 88 /// an MLIR context for initialization. 89 void mlir::registerMLIRContextCLOptions() { 90 // Make sure that the options struct has been initialized. 91 *clOptions; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Locking Utilities 96 //===----------------------------------------------------------------------===// 97 98 namespace { 99 /// Utility reader lock that takes a runtime flag that specifies if we really 100 /// need to lock. 101 struct ScopedReaderLock { 102 ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 103 : mutex(shouldLock ? &mutexParam : nullptr) { 104 if (mutex) 105 mutex->lock_shared(); 106 } 107 ~ScopedReaderLock() { 108 if (mutex) 109 mutex->unlock_shared(); 110 } 111 llvm::sys::SmartRWMutex<true> *mutex; 112 }; 113 /// Utility writer lock that takes a runtime flag that specifies if we really 114 /// need to lock. 115 struct ScopedWriterLock { 116 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 117 : mutex(shouldLock ? &mutexParam : nullptr) { 118 if (mutex) 119 mutex->lock(); 120 } 121 ~ScopedWriterLock() { 122 if (mutex) 123 mutex->unlock(); 124 } 125 llvm::sys::SmartRWMutex<true> *mutex; 126 }; 127 } // end anonymous namespace. 128 129 //===----------------------------------------------------------------------===// 130 // AffineMap and IntegerSet hashing 131 //===----------------------------------------------------------------------===// 132 133 /// A utility function to safely get or create a uniqued instance within the 134 /// given set container. 135 template <typename ValueT, typename DenseInfoT, typename KeyT, 136 typename ConstructorFn> 137 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container, 138 KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex, 139 bool threadingIsEnabled, 140 ConstructorFn &&constructorFn) { 141 // Check for an existing instance in read-only mode. 142 if (threadingIsEnabled) { 143 llvm::sys::SmartScopedReader<true> instanceLock(mutex); 144 auto it = container.find_as(key); 145 if (it != container.end()) 146 return *it; 147 } 148 149 // Acquire a writer-lock so that we can safely create the new instance. 150 ScopedWriterLock instanceLock(mutex, threadingIsEnabled); 151 152 // Check for an existing instance again here, because another writer thread 153 // may have already created one. Otherwise, construct a new instance. 154 auto existing = container.insert_as(ValueT(), key); 155 if (existing.second) 156 return *existing.first = constructorFn(); 157 return *existing.first; 158 } 159 160 namespace { 161 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> { 162 // Affine maps are uniqued based on their dim/symbol counts and affine 163 // expressions. 164 using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>; 165 using DenseMapInfo<AffineMap>::isEqual; 166 167 static unsigned getHashValue(const AffineMap &key) { 168 return getHashValue( 169 KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults())); 170 } 171 172 static unsigned getHashValue(KeyTy key) { 173 return hash_combine( 174 std::get<0>(key), std::get<1>(key), 175 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end())); 176 } 177 178 static bool isEqual(const KeyTy &lhs, AffineMap rhs) { 179 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 180 return false; 181 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 182 rhs.getResults()); 183 } 184 }; 185 186 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> { 187 // Integer sets are uniqued based on their dim/symbol counts, affine 188 // expressions appearing in the LHS of constraints, and eqFlags. 189 using KeyTy = 190 std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>; 191 using DenseMapInfo<IntegerSet>::isEqual; 192 193 static unsigned getHashValue(const IntegerSet &key) { 194 return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(), 195 key.getConstraints(), key.getEqFlags())); 196 } 197 198 static unsigned getHashValue(KeyTy key) { 199 return hash_combine( 200 std::get<0>(key), std::get<1>(key), 201 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), 202 hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end())); 203 } 204 205 static bool isEqual(const KeyTy &lhs, IntegerSet rhs) { 206 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 207 return false; 208 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 209 rhs.getConstraints(), rhs.getEqFlags()); 210 } 211 }; 212 } // end anonymous namespace. 213 214 //===----------------------------------------------------------------------===// 215 // MLIRContextImpl 216 //===----------------------------------------------------------------------===// 217 218 namespace mlir { 219 /// This is the implementation of the MLIRContext class, using the pImpl idiom. 220 /// This class is completely private to this file, so everything is public. 221 class MLIRContextImpl { 222 public: 223 //===--------------------------------------------------------------------===// 224 // Debugging 225 //===--------------------------------------------------------------------===// 226 227 /// An action manager for use within the context. 228 DebugActionManager debugActionManager; 229 230 //===--------------------------------------------------------------------===// 231 // Diagnostics 232 //===--------------------------------------------------------------------===// 233 DiagnosticEngine diagEngine; 234 235 //===--------------------------------------------------------------------===// 236 // Options 237 //===--------------------------------------------------------------------===// 238 239 /// In most cases, creating operation in unregistered dialect is not desired 240 /// and indicate a misconfiguration of the compiler. This option enables to 241 /// detect such use cases 242 bool allowUnregisteredDialects = false; 243 244 /// Enable support for multi-threading within MLIR. 245 bool threadingIsEnabled = true; 246 247 /// Track if we are currently executing in a threaded execution environment 248 /// (like the pass-manager): this is only a debugging feature to help reducing 249 /// the chances of data races one some context APIs. 250 #ifndef NDEBUG 251 std::atomic<int> multiThreadedExecutionContext{0}; 252 #endif 253 254 /// If the operation should be attached to diagnostics printed via the 255 /// Operation::emit methods. 256 bool printOpOnDiagnostic = true; 257 258 /// If the current stack trace should be attached when emitting diagnostics. 259 bool printStackTraceOnDiagnostic = false; 260 261 //===--------------------------------------------------------------------===// 262 // Other 263 //===--------------------------------------------------------------------===// 264 265 /// This points to the ThreadPool used when processing MLIR tasks in parallel. 266 /// It can't be nullptr when multi-threading is enabled. Otherwise if 267 /// multi-threading is disabled, and the threadpool wasn't externally provided 268 /// using `setThreadPool`, this will be nullptr. 269 llvm::ThreadPool *threadPool = nullptr; 270 271 /// In case where the thread pool is owned by the context, this ensures 272 /// destruction with the context. 273 std::unique_ptr<llvm::ThreadPool> ownedThreadPool; 274 275 /// This is a list of dialects that are created referring to this context. 276 /// The MLIRContext owns the objects. 277 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects; 278 DialectRegistry dialectsRegistry; 279 280 /// This is a mapping from operation name to AbstractOperation for registered 281 /// operations. 282 llvm::StringMap<AbstractOperation> registeredOperations; 283 284 /// An allocator used for AbstractAttribute and AbstractType objects. 285 llvm::BumpPtrAllocator abstractDialectSymbolAllocator; 286 287 //===--------------------------------------------------------------------===// 288 // Affine uniquing 289 //===--------------------------------------------------------------------===// 290 291 // Affine allocator and mutex for thread safety. 292 llvm::BumpPtrAllocator affineAllocator; 293 llvm::sys::SmartRWMutex<true> affineMutex; 294 295 // Affine map uniquing. 296 using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>; 297 AffineMapSet affineMaps; 298 299 // Integer set uniquing. 300 using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>; 301 IntegerSets integerSets; 302 303 // Affine expression uniquing. 304 StorageUniquer affineUniquer; 305 306 //===--------------------------------------------------------------------===// 307 // Type uniquing 308 //===--------------------------------------------------------------------===// 309 310 DenseMap<TypeID, AbstractType *> registeredTypes; 311 StorageUniquer typeUniquer; 312 313 /// Cached Type Instances. 314 BFloat16Type bf16Ty; 315 Float16Type f16Ty; 316 Float32Type f32Ty; 317 Float64Type f64Ty; 318 Float80Type f80Ty; 319 Float128Type f128Ty; 320 IndexType indexTy; 321 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 322 NoneType noneType; 323 324 //===--------------------------------------------------------------------===// 325 // Attribute uniquing 326 //===--------------------------------------------------------------------===// 327 328 DenseMap<TypeID, AbstractAttribute *> registeredAttributes; 329 StorageUniquer attributeUniquer; 330 331 /// Cached Attribute Instances. 332 BoolAttr falseAttr, trueAttr; 333 UnitAttr unitAttr; 334 UnknownLoc unknownLocAttr; 335 DictionaryAttr emptyDictionaryAttr; 336 StringAttr emptyStringAttr; 337 338 /// Map of string attributes that may reference a dialect, that are awaiting 339 /// that dialect to be loaded. 340 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex; 341 DenseMap<StringRef, SmallVector<StringAttrStorage *>> 342 dialectReferencingStrAttrs; 343 344 public: 345 MLIRContextImpl(bool threadingIsEnabled) 346 : threadingIsEnabled(threadingIsEnabled) { 347 if (threadingIsEnabled) { 348 ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 349 threadPool = ownedThreadPool.get(); 350 } 351 } 352 ~MLIRContextImpl() { 353 for (auto typeMapping : registeredTypes) 354 typeMapping.second->~AbstractType(); 355 for (auto attrMapping : registeredAttributes) 356 attrMapping.second->~AbstractAttribute(); 357 } 358 }; 359 } // end namespace mlir 360 361 MLIRContext::MLIRContext(Threading setting) 362 : MLIRContext(DialectRegistry(), setting) {} 363 364 MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) 365 : impl(new MLIRContextImpl(setting == Threading::ENABLED && 366 !isThreadingGloballyDisabled())) { 367 // Initialize values based on the command line flags if they were provided. 368 if (clOptions.isConstructed()) { 369 printOpOnDiagnostic(clOptions->printOpOnDiagnostic); 370 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); 371 } 372 373 // Pre-populate the registry. 374 registry.appendTo(impl->dialectsRegistry); 375 376 // Ensure the builtin dialect is always pre-loaded. 377 getOrLoadDialect<BuiltinDialect>(); 378 379 // Initialize several common attributes and types to avoid the need to lock 380 // the context when accessing them. 381 382 //// Types. 383 /// Floating-point Types. 384 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); 385 impl->f16Ty = TypeUniquer::get<Float16Type>(this); 386 impl->f32Ty = TypeUniquer::get<Float32Type>(this); 387 impl->f64Ty = TypeUniquer::get<Float64Type>(this); 388 impl->f80Ty = TypeUniquer::get<Float80Type>(this); 389 impl->f128Ty = TypeUniquer::get<Float128Type>(this); 390 /// Index Type. 391 impl->indexTy = TypeUniquer::get<IndexType>(this); 392 /// Integer Types. 393 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless); 394 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless); 395 impl->int16Ty = 396 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless); 397 impl->int32Ty = 398 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless); 399 impl->int64Ty = 400 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless); 401 impl->int128Ty = 402 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless); 403 /// None Type. 404 impl->noneType = TypeUniquer::get<NoneType>(this); 405 406 //// Attributes. 407 //// Note: These must be registered after the types as they may generate one 408 //// of the above types internally. 409 /// Unknown Location Attribute. 410 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this); 411 /// Bool Attributes. 412 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false); 413 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true); 414 /// Unit Attribute. 415 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this); 416 /// The empty dictionary attribute. 417 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this); 418 /// The empty string attribute. 419 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this); 420 421 // Register the affine storage objects with the uniquer. 422 impl->affineUniquer 423 .registerParametricStorageType<AffineBinaryOpExprStorage>(); 424 impl->affineUniquer 425 .registerParametricStorageType<AffineConstantExprStorage>(); 426 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>(); 427 } 428 429 MLIRContext::~MLIRContext() {} 430 431 /// Copy the specified array of elements into memory managed by the provided 432 /// bump pointer allocator. This assumes the elements are all PODs. 433 template <typename T> 434 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 435 ArrayRef<T> elements) { 436 auto result = allocator.Allocate<T>(elements.size()); 437 std::uninitialized_copy(elements.begin(), elements.end(), result); 438 return ArrayRef<T>(result, elements.size()); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // Debugging 443 //===----------------------------------------------------------------------===// 444 445 DebugActionManager &MLIRContext::getDebugActionManager() { 446 return getImpl().debugActionManager; 447 } 448 449 //===----------------------------------------------------------------------===// 450 // Diagnostic Handlers 451 //===----------------------------------------------------------------------===// 452 453 /// Returns the diagnostic engine for this context. 454 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 455 456 //===----------------------------------------------------------------------===// 457 // Dialect and Operation Registration 458 //===----------------------------------------------------------------------===// 459 460 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 461 registry.appendTo(impl->dialectsRegistry); 462 463 // For the already loaded dialects, register the interfaces immediately. 464 for (const auto &kvp : impl->loadedDialects) 465 registry.registerDelayedInterfaces(kvp.second.get()); 466 } 467 468 const DialectRegistry &MLIRContext::getDialectRegistry() { 469 return impl->dialectsRegistry; 470 } 471 472 /// Return information about all registered IR dialects. 473 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 474 std::vector<Dialect *> result; 475 result.reserve(impl->loadedDialects.size()); 476 for (auto &dialect : impl->loadedDialects) 477 result.push_back(dialect.second.get()); 478 llvm::array_pod_sort(result.begin(), result.end(), 479 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 480 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 481 }); 482 return result; 483 } 484 std::vector<StringRef> MLIRContext::getAvailableDialects() { 485 std::vector<StringRef> result; 486 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 487 result.push_back(dialect); 488 return result; 489 } 490 491 /// Get a registered IR dialect with the given namespace. If none is found, 492 /// then return nullptr. 493 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 494 // Dialects are sorted by name, so we can use binary search for lookup. 495 auto it = impl->loadedDialects.find(name); 496 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 497 } 498 499 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 500 Dialect *dialect = getLoadedDialect(name); 501 if (dialect) 502 return dialect; 503 DialectAllocatorFunctionRef allocator = 504 impl->dialectsRegistry.getDialectAllocator(name); 505 return allocator ? allocator(this) : nullptr; 506 } 507 508 /// Get a dialect for the provided namespace and TypeID: abort the program if a 509 /// dialect exist for this namespace with different TypeID. Returns a pointer to 510 /// the dialect owned by the context. 511 Dialect * 512 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 513 function_ref<std::unique_ptr<Dialect>()> ctor) { 514 auto &impl = getImpl(); 515 // Get the correct insertion position sorted by namespace. 516 std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; 517 518 if (!dialect) { 519 LLVM_DEBUG(llvm::dbgs() 520 << "Load new dialect in Context " << dialectNamespace << "\n"); 521 #ifndef NDEBUG 522 if (impl.multiThreadedExecutionContext != 0) 523 llvm::report_fatal_error( 524 "Loading a dialect (" + dialectNamespace + 525 ") while in a multi-threaded execution context (maybe " 526 "the PassManager): this can indicate a " 527 "missing `dependentDialects` in a pass for example."); 528 #endif 529 dialect = ctor(); 530 assert(dialect && "dialect ctor failed"); 531 532 // Refresh all the identifiers dialect field, this catches cases where a 533 // dialect may be loaded after identifier prefixed with this dialect name 534 // were already created. 535 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace); 536 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) { 537 for (StringAttrStorage *storage : stringAttrsIt->second) 538 storage->referencedDialect = dialect.get(); 539 impl.dialectReferencingStrAttrs.erase(stringAttrsIt); 540 } 541 542 // Actually register the interfaces with delayed registration. 543 impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); 544 return dialect.get(); 545 } 546 547 // Abort if dialect with namespace has already been registered. 548 if (dialect->getTypeID() != dialectID) 549 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 550 "' has already been registered"); 551 552 return dialect.get(); 553 } 554 555 void MLIRContext::loadAllAvailableDialects() { 556 for (StringRef name : getAvailableDialects()) 557 getOrLoadDialect(name); 558 } 559 560 llvm::hash_code MLIRContext::getRegistryHash() { 561 llvm::hash_code hash(0); 562 // Factor in number of loaded dialects, attributes, operations, types. 563 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 564 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 565 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 566 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 567 return hash; 568 } 569 570 bool MLIRContext::allowsUnregisteredDialects() { 571 return impl->allowUnregisteredDialects; 572 } 573 574 void MLIRContext::allowUnregisteredDialects(bool allowing) { 575 impl->allowUnregisteredDialects = allowing; 576 } 577 578 /// Return true if multi-threading is enabled by the context. 579 bool MLIRContext::isMultithreadingEnabled() { 580 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 581 } 582 583 /// Set the flag specifying if multi-threading is disabled by the context. 584 void MLIRContext::disableMultithreading(bool disable) { 585 // This API can be overridden by the global debugging flag 586 // --mlir-disable-threading 587 if (isThreadingGloballyDisabled()) 588 return; 589 590 impl->threadingIsEnabled = !disable; 591 592 // Update the threading mode for each of the uniquers. 593 impl->affineUniquer.disableMultithreading(disable); 594 impl->attributeUniquer.disableMultithreading(disable); 595 impl->typeUniquer.disableMultithreading(disable); 596 597 // Destroy thread pool (stop all threads) if it is no longer needed, or create 598 // a new one if multithreading was re-enabled. 599 if (disable) { 600 // If the thread pool is owned, explicitly set it to nullptr to avoid 601 // keeping a dangling pointer around. If the thread pool is externally 602 // owned, we don't do anything. 603 if (impl->ownedThreadPool) { 604 assert(impl->threadPool); 605 impl->threadPool = nullptr; 606 impl->ownedThreadPool.reset(); 607 } 608 } else if (!impl->threadPool) { 609 // The thread pool isn't externally provided. 610 assert(!impl->ownedThreadPool); 611 impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 612 impl->threadPool = impl->ownedThreadPool.get(); 613 } 614 } 615 616 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { 617 assert(!isMultithreadingEnabled() && 618 "expected multi-threading to be disabled when setting a ThreadPool"); 619 impl->threadPool = &pool; 620 impl->ownedThreadPool.reset(); 621 enableMultithreading(); 622 } 623 624 llvm::ThreadPool &MLIRContext::getThreadPool() { 625 assert(isMultithreadingEnabled() && 626 "expected multi-threading to be enabled within the context"); 627 assert(impl->threadPool && 628 "multi-threading is enabled but threadpool not set"); 629 return *impl->threadPool; 630 } 631 632 void MLIRContext::enterMultiThreadedExecution() { 633 #ifndef NDEBUG 634 ++impl->multiThreadedExecutionContext; 635 #endif 636 } 637 void MLIRContext::exitMultiThreadedExecution() { 638 #ifndef NDEBUG 639 --impl->multiThreadedExecutionContext; 640 #endif 641 } 642 643 /// Return true if we should attach the operation to diagnostics emitted via 644 /// Operation::emit. 645 bool MLIRContext::shouldPrintOpOnDiagnostic() { 646 return impl->printOpOnDiagnostic; 647 } 648 649 /// Set the flag specifying if we should attach the operation to diagnostics 650 /// emitted via Operation::emit. 651 void MLIRContext::printOpOnDiagnostic(bool enable) { 652 impl->printOpOnDiagnostic = enable; 653 } 654 655 /// Return true if we should attach the current stacktrace to diagnostics when 656 /// emitted. 657 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 658 return impl->printStackTraceOnDiagnostic; 659 } 660 661 /// Set the flag specifying if we should attach the current stacktrace when 662 /// emitting diagnostics. 663 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 664 impl->printStackTraceOnDiagnostic = enable; 665 } 666 667 /// Return information about all registered operations. This isn't very 668 /// efficient, typically you should ask the operations about their properties 669 /// directly. 670 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() { 671 // We just have the operations in a non-deterministic hash table order. Dump 672 // into a temporary array, then sort it by operation name to get a stable 673 // ordering. 674 llvm::StringMap<AbstractOperation> ®isteredOps = 675 impl->registeredOperations; 676 677 std::vector<AbstractOperation *> result; 678 result.reserve(registeredOps.size()); 679 for (auto &elt : registeredOps) 680 result.push_back(&elt.second); 681 llvm::array_pod_sort( 682 result.begin(), result.end(), 683 [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) { 684 return (*lhs)->name.compare((*rhs)->name); 685 }); 686 687 return result; 688 } 689 690 bool MLIRContext::isOperationRegistered(StringRef name) { 691 return impl->registeredOperations.count(name); 692 } 693 694 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 695 auto &impl = context->getImpl(); 696 assert(impl.multiThreadedExecutionContext == 0 && 697 "Registering a new type kind while in a multi-threaded execution " 698 "context"); 699 auto *newInfo = 700 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 701 AbstractType(std::move(typeInfo)); 702 if (!impl.registeredTypes.insert({typeID, newInfo}).second) 703 llvm::report_fatal_error("Dialect Type already registered."); 704 } 705 706 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 707 auto &impl = context->getImpl(); 708 assert(impl.multiThreadedExecutionContext == 0 && 709 "Registering a new attribute kind while in a multi-threaded execution " 710 "context"); 711 auto *newInfo = 712 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 713 AbstractAttribute(std::move(attrInfo)); 714 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 715 llvm::report_fatal_error("Dialect Attribute already registered."); 716 } 717 718 //===----------------------------------------------------------------------===// 719 // AbstractAttribute 720 //===----------------------------------------------------------------------===// 721 722 /// Get the dialect that registered the attribute with the provided typeid. 723 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, 724 MLIRContext *context) { 725 const AbstractAttribute *abstract = lookupMutable(typeID, context); 726 if (!abstract) 727 llvm::report_fatal_error("Trying to create an Attribute that was not " 728 "registered in this MLIRContext."); 729 return *abstract; 730 } 731 732 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, 733 MLIRContext *context) { 734 auto &impl = context->getImpl(); 735 auto it = impl.registeredAttributes.find(typeID); 736 if (it == impl.registeredAttributes.end()) 737 return nullptr; 738 return it->second; 739 } 740 741 //===----------------------------------------------------------------------===// 742 // AbstractOperation 743 //===----------------------------------------------------------------------===// 744 745 ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser, 746 OperationState &result) const { 747 return parseAssemblyFn(parser, result); 748 } 749 750 /// Look up the specified operation in the operation set and return a pointer 751 /// to it if present. Otherwise, return a null pointer. 752 AbstractOperation *AbstractOperation::lookupMutable(StringRef opName, 753 MLIRContext *context) { 754 auto &impl = context->getImpl(); 755 auto it = impl.registeredOperations.find(opName); 756 if (it != impl.registeredOperations.end()) 757 return &it->second; 758 return nullptr; 759 } 760 761 void AbstractOperation::insert( 762 StringRef name, Dialect &dialect, TypeID typeID, 763 ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, 764 VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, 765 GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, 766 detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, 767 ArrayRef<StringRef> attrNames) { 768 MLIRContext *ctx = dialect.getContext(); 769 auto &impl = ctx->getImpl(); 770 assert(impl.multiThreadedExecutionContext == 0 && 771 "Registering a new operation kind while in a multi-threaded execution " 772 "context"); 773 774 // Register the attribute names of this operation. 775 MutableArrayRef<StringAttr> cachedAttrNames; 776 if (!attrNames.empty()) { 777 cachedAttrNames = MutableArrayRef<StringAttr>( 778 impl.abstractDialectSymbolAllocator.Allocate<StringAttr>( 779 attrNames.size()), 780 attrNames.size()); 781 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size())) 782 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i])); 783 } 784 785 // Register the information for this operation. 786 AbstractOperation opInfo( 787 name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly), 788 std::move(verifyInvariants), std::move(foldHook), 789 std::move(getCanonicalizationPatterns), std::move(interfaceMap), 790 std::move(hasTrait), cachedAttrNames); 791 if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) { 792 llvm::errs() << "error: operation named '" << name 793 << "' is already registered.\n"; 794 abort(); 795 } 796 } 797 798 AbstractOperation::AbstractOperation( 799 StringRef name, Dialect &dialect, TypeID typeID, 800 ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, 801 VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, 802 GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, 803 detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, 804 ArrayRef<StringAttr> attrNames) 805 : name(StringAttr::get(dialect.getContext(), name)), dialect(dialect), 806 typeID(typeID), interfaceMap(std::move(interfaceMap)), 807 foldHookFn(std::move(foldHook)), 808 getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)), 809 hasTraitFn(std::move(hasTrait)), 810 parseAssemblyFn(std::move(parseAssembly)), 811 printAssemblyFn(std::move(printAssembly)), 812 verifyInvariantsFn(std::move(verifyInvariants)), 813 attributeNames(attrNames) {} 814 815 //===----------------------------------------------------------------------===// 816 // AbstractType 817 //===----------------------------------------------------------------------===// 818 819 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { 820 const AbstractType *type = lookupMutable(typeID, context); 821 if (!type) 822 llvm::report_fatal_error( 823 "Trying to create a Type that was not registered in this MLIRContext."); 824 return *type; 825 } 826 827 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { 828 auto &impl = context->getImpl(); 829 auto it = impl.registeredTypes.find(typeID); 830 if (it == impl.registeredTypes.end()) 831 return nullptr; 832 return it->second; 833 } 834 835 //===----------------------------------------------------------------------===// 836 // Type uniquing 837 //===----------------------------------------------------------------------===// 838 839 /// Returns the storage uniquer used for constructing type storage instances. 840 /// This should not be used directly. 841 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } 842 843 BFloat16Type BFloat16Type::get(MLIRContext *context) { 844 return context->getImpl().bf16Ty; 845 } 846 Float16Type Float16Type::get(MLIRContext *context) { 847 return context->getImpl().f16Ty; 848 } 849 Float32Type Float32Type::get(MLIRContext *context) { 850 return context->getImpl().f32Ty; 851 } 852 Float64Type Float64Type::get(MLIRContext *context) { 853 return context->getImpl().f64Ty; 854 } 855 Float80Type Float80Type::get(MLIRContext *context) { 856 return context->getImpl().f80Ty; 857 } 858 Float128Type Float128Type::get(MLIRContext *context) { 859 return context->getImpl().f128Ty; 860 } 861 862 /// Get an instance of the IndexType. 863 IndexType IndexType::get(MLIRContext *context) { 864 return context->getImpl().indexTy; 865 } 866 867 /// Return an existing integer type instance if one is cached within the 868 /// context. 869 static IntegerType 870 getCachedIntegerType(unsigned width, 871 IntegerType::SignednessSemantics signedness, 872 MLIRContext *context) { 873 if (signedness != IntegerType::Signless) 874 return IntegerType(); 875 876 switch (width) { 877 case 1: 878 return context->getImpl().int1Ty; 879 case 8: 880 return context->getImpl().int8Ty; 881 case 16: 882 return context->getImpl().int16Ty; 883 case 32: 884 return context->getImpl().int32Ty; 885 case 64: 886 return context->getImpl().int64Ty; 887 case 128: 888 return context->getImpl().int128Ty; 889 default: 890 return IntegerType(); 891 } 892 } 893 894 IntegerType IntegerType::get(MLIRContext *context, unsigned width, 895 IntegerType::SignednessSemantics signedness) { 896 if (auto cached = getCachedIntegerType(width, signedness, context)) 897 return cached; 898 return Base::get(context, width, signedness); 899 } 900 901 IntegerType 902 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError, 903 MLIRContext *context, unsigned width, 904 SignednessSemantics signedness) { 905 if (auto cached = getCachedIntegerType(width, signedness, context)) 906 return cached; 907 return Base::getChecked(emitError, context, width, signedness); 908 } 909 910 /// Get an instance of the NoneType. 911 NoneType NoneType::get(MLIRContext *context) { 912 if (NoneType cachedInst = context->getImpl().noneType) 913 return cachedInst; 914 // Note: May happen when initializing the singleton attributes of the builtin 915 // dialect. 916 return Base::get(context); 917 } 918 919 //===----------------------------------------------------------------------===// 920 // Attribute uniquing 921 //===----------------------------------------------------------------------===// 922 923 /// Returns the storage uniquer used for constructing attribute storage 924 /// instances. This should not be used directly. 925 StorageUniquer &MLIRContext::getAttributeUniquer() { 926 return getImpl().attributeUniquer; 927 } 928 929 /// Initialize the given attribute storage instance. 930 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, 931 MLIRContext *ctx, 932 TypeID attrID) { 933 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); 934 935 // If the attribute did not provide a type, then default to NoneType. 936 if (!storage->getType()) 937 storage->setType(NoneType::get(ctx)); 938 } 939 940 BoolAttr BoolAttr::get(MLIRContext *context, bool value) { 941 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; 942 } 943 944 UnitAttr UnitAttr::get(MLIRContext *context) { 945 return context->getImpl().unitAttr; 946 } 947 948 UnknownLoc UnknownLoc::get(MLIRContext *context) { 949 return context->getImpl().unknownLocAttr; 950 } 951 952 /// Return empty dictionary. 953 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { 954 return context->getImpl().emptyDictionaryAttr; 955 } 956 957 void StringAttrStorage::initialize(MLIRContext *context) { 958 // Check for a dialect namespace prefix, if there isn't one we don't need to 959 // do any additional initialization. 960 auto dialectNamePair = value.split('.'); 961 if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) 962 return; 963 964 // If one exists, we check to see if this dialect is loaded. If it is, we set 965 // the dialect now, if it isn't we record this storage for initialization 966 // later if the dialect ever gets loaded. 967 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) 968 return; 969 970 MLIRContextImpl &impl = context->getImpl(); 971 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex); 972 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); 973 } 974 975 /// Return an empty string. 976 StringAttr StringAttr::get(MLIRContext *context) { 977 return context->getImpl().emptyStringAttr; 978 } 979 980 //===----------------------------------------------------------------------===// 981 // AffineMap uniquing 982 //===----------------------------------------------------------------------===// 983 984 StorageUniquer &MLIRContext::getAffineUniquer() { 985 return getImpl().affineUniquer; 986 } 987 988 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, 989 ArrayRef<AffineExpr> results, 990 MLIRContext *context) { 991 auto &impl = context->getImpl(); 992 auto key = std::make_tuple(dimCount, symbolCount, results); 993 994 // Safely get or create an AffineMap instance. 995 return safeGetOrCreate( 996 impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] { 997 auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>(); 998 999 // Copy the results into the bump pointer. 1000 results = copyArrayRefInto(impl.affineAllocator, results); 1001 1002 // Initialize the memory using placement new. 1003 new (res) 1004 detail::AffineMapStorage{dimCount, symbolCount, results, context}; 1005 return AffineMap(res); 1006 }); 1007 } 1008 1009 AffineMap AffineMap::get(MLIRContext *context) { 1010 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); 1011 } 1012 1013 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1014 MLIRContext *context) { 1015 return getImpl(dimCount, symbolCount, /*results=*/{}, context); 1016 } 1017 1018 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1019 AffineExpr result) { 1020 return getImpl(dimCount, symbolCount, {result}, result.getContext()); 1021 } 1022 1023 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1024 ArrayRef<AffineExpr> results, MLIRContext *context) { 1025 return getImpl(dimCount, symbolCount, results, context); 1026 } 1027 1028 //===----------------------------------------------------------------------===// 1029 // Integer Sets: these are allocated into the bump pointer, and are immutable. 1030 // Unlike AffineMap's, these are uniqued only if they are small. 1031 //===----------------------------------------------------------------------===// 1032 1033 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, 1034 ArrayRef<AffineExpr> constraints, 1035 ArrayRef<bool> eqFlags) { 1036 // The number of constraints can't be zero. 1037 assert(!constraints.empty()); 1038 assert(constraints.size() == eqFlags.size()); 1039 1040 auto &impl = constraints[0].getContext()->getImpl(); 1041 1042 // A utility function to construct a new IntegerSetStorage instance. 1043 auto constructorFn = [&] { 1044 auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>(); 1045 1046 // Copy the results and equality flags into the bump pointer. 1047 constraints = copyArrayRefInto(impl.affineAllocator, constraints); 1048 eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags); 1049 1050 // Initialize the memory using placement new. 1051 new (res) 1052 detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags}; 1053 return IntegerSet(res); 1054 }; 1055 1056 // If this instance is uniqued, then we handle it separately so that multiple 1057 // threads may simultaneously access existing instances. 1058 if (constraints.size() < IntegerSet::kUniquingThreshold) { 1059 auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags); 1060 return safeGetOrCreate(impl.integerSets, key, impl.affineMutex, 1061 impl.threadingIsEnabled, constructorFn); 1062 } 1063 1064 // Otherwise, acquire a writer-lock so that we can safely create the new 1065 // instance. 1066 ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled); 1067 return constructorFn(); 1068 } 1069 1070 //===----------------------------------------------------------------------===// 1071 // StorageUniquerSupport 1072 //===----------------------------------------------------------------------===// 1073 1074 /// Utility method to generate a callback that can be used to generate a 1075 /// diagnostic when checking the construction invariants of a storage object. 1076 /// This is defined out-of-line to avoid the need to include Location.h. 1077 llvm::unique_function<InFlightDiagnostic()> 1078 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { 1079 return [ctx] { return emitError(UnknownLoc::get(ctx)); }; 1080 } 1081 llvm::unique_function<InFlightDiagnostic()> 1082 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { 1083 return [=] { return emitError(loc); }; 1084 } 1085