1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/MLIRContext.h" 10 #include "AffineExprDetail.h" 11 #include "AffineMapDetail.h" 12 #include "AttributeDetail.h" 13 #include "IntegerSetDetail.h" 14 #include "TypeDetail.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/Identifier.h" 22 #include "mlir/IR/IntegerSet.h" 23 #include "mlir/IR/Location.h" 24 #include "mlir/IR/OpImplementation.h" 25 #include "mlir/IR/Types.h" 26 #include "mlir/Support/DebugAction.h" 27 #include "llvm/ADT/DenseMap.h" 28 #include "llvm/ADT/DenseSet.h" 29 #include "llvm/ADT/SetVector.h" 30 #include "llvm/ADT/SmallString.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/ADT/Twine.h" 33 #include "llvm/Support/Allocator.h" 34 #include "llvm/Support/CommandLine.h" 35 #include "llvm/Support/Debug.h" 36 #include "llvm/Support/RWMutex.h" 37 #include "llvm/Support/ThreadPool.h" 38 #include "llvm/Support/raw_ostream.h" 39 #include <memory> 40 41 #define DEBUG_TYPE "mlircontext" 42 43 using namespace mlir; 44 using namespace mlir::detail; 45 46 using llvm::hash_combine; 47 using llvm::hash_combine_range; 48 49 //===----------------------------------------------------------------------===// 50 // MLIRContext CommandLine Options 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 /// This struct contains command line options that can be used to initialize 55 /// various bits of an MLIRContext. 
This uses a struct wrapper to avoid the need 56 /// for global command line options. 57 struct MLIRContextOptions { 58 llvm::cl::opt<bool> disableThreading{ 59 "mlir-disable-threading", 60 llvm::cl::desc("Disable multi-threading within MLIR, overrides any " 61 "further call to MLIRContext::enableMultiThreading()")}; 62 63 llvm::cl::opt<bool> printOpOnDiagnostic{ 64 "mlir-print-op-on-diagnostic", 65 llvm::cl::desc("When a diagnostic is emitted on an operation, also print " 66 "the operation as an attached note"), 67 llvm::cl::init(true)}; 68 69 llvm::cl::opt<bool> printStackTraceOnDiagnostic{ 70 "mlir-print-stacktrace-on-diagnostic", 71 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace " 72 "as an attached note")}; 73 }; 74 } // end anonymous namespace 75 76 static llvm::ManagedStatic<MLIRContextOptions> clOptions; 77 78 static bool isThreadingGloballyDisabled() { 79 #if LLVM_ENABLE_THREADS != 0 80 return clOptions.isConstructed() && clOptions->disableThreading; 81 #else 82 return true; 83 #endif 84 } 85 86 /// Register a set of useful command-line options that can be used to configure 87 /// various flags within the MLIRContext. These flags are used when constructing 88 /// an MLIR context for initialization. 89 void mlir::registerMLIRContextCLOptions() { 90 // Make sure that the options struct has been initialized. 91 *clOptions; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Locking Utilities 96 //===----------------------------------------------------------------------===// 97 98 namespace { 99 /// Utility reader lock that takes a runtime flag that specifies if we really 100 /// need to lock. 101 struct ScopedReaderLock { 102 ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 103 : mutex(shouldLock ? 
&mutexParam : nullptr) { 104 if (mutex) 105 mutex->lock_shared(); 106 } 107 ~ScopedReaderLock() { 108 if (mutex) 109 mutex->unlock_shared(); 110 } 111 llvm::sys::SmartRWMutex<true> *mutex; 112 }; 113 /// Utility writer lock that takes a runtime flag that specifies if we really 114 /// need to lock. 115 struct ScopedWriterLock { 116 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 117 : mutex(shouldLock ? &mutexParam : nullptr) { 118 if (mutex) 119 mutex->lock(); 120 } 121 ~ScopedWriterLock() { 122 if (mutex) 123 mutex->unlock(); 124 } 125 llvm::sys::SmartRWMutex<true> *mutex; 126 }; 127 } // end anonymous namespace. 128 129 //===----------------------------------------------------------------------===// 130 // AffineMap and IntegerSet hashing 131 //===----------------------------------------------------------------------===// 132 133 /// A utility function to safely get or create a uniqued instance within the 134 /// given set container. 135 template <typename ValueT, typename DenseInfoT, typename KeyT, 136 typename ConstructorFn> 137 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container, 138 KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex, 139 bool threadingIsEnabled, 140 ConstructorFn &&constructorFn) { 141 // Check for an existing instance in read-only mode. 142 if (threadingIsEnabled) { 143 llvm::sys::SmartScopedReader<true> instanceLock(mutex); 144 auto it = container.find_as(key); 145 if (it != container.end()) 146 return *it; 147 } 148 149 // Acquire a writer-lock so that we can safely create the new instance. 150 ScopedWriterLock instanceLock(mutex, threadingIsEnabled); 151 152 // Check for an existing instance again here, because another writer thread 153 // may have already created one. Otherwise, construct a new instance. 
154 auto existing = container.insert_as(ValueT(), key); 155 if (existing.second) 156 return *existing.first = constructorFn(); 157 return *existing.first; 158 } 159 160 namespace { 161 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> { 162 // Affine maps are uniqued based on their dim/symbol counts and affine 163 // expressions. 164 using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>; 165 using DenseMapInfo<AffineMap>::isEqual; 166 167 static unsigned getHashValue(const AffineMap &key) { 168 return getHashValue( 169 KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults())); 170 } 171 172 static unsigned getHashValue(KeyTy key) { 173 return hash_combine( 174 std::get<0>(key), std::get<1>(key), 175 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end())); 176 } 177 178 static bool isEqual(const KeyTy &lhs, AffineMap rhs) { 179 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 180 return false; 181 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 182 rhs.getResults()); 183 } 184 }; 185 186 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> { 187 // Integer sets are uniqued based on their dim/symbol counts, affine 188 // expressions appearing in the LHS of constraints, and eqFlags. 
189 using KeyTy = 190 std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>; 191 using DenseMapInfo<IntegerSet>::isEqual; 192 193 static unsigned getHashValue(const IntegerSet &key) { 194 return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(), 195 key.getConstraints(), key.getEqFlags())); 196 } 197 198 static unsigned getHashValue(KeyTy key) { 199 return hash_combine( 200 std::get<0>(key), std::get<1>(key), 201 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), 202 hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end())); 203 } 204 205 static bool isEqual(const KeyTy &lhs, IntegerSet rhs) { 206 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 207 return false; 208 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 209 rhs.getConstraints(), rhs.getEqFlags()); 210 } 211 }; 212 } // end anonymous namespace. 213 214 //===----------------------------------------------------------------------===// 215 // MLIRContextImpl 216 //===----------------------------------------------------------------------===// 217 218 namespace mlir { 219 /// This is the implementation of the MLIRContext class, using the pImpl idiom. 220 /// This class is completely private to this file, so everything is public. 221 class MLIRContextImpl { 222 public: 223 //===--------------------------------------------------------------------===// 224 // Debugging 225 //===--------------------------------------------------------------------===// 226 227 /// An action manager for use within the context. 228 DebugActionManager debugActionManager; 229 230 //===--------------------------------------------------------------------===// 231 // Identifier uniquing 232 //===--------------------------------------------------------------------===// 233 234 // Identifier allocator and mutex for thread safety. 
235 llvm::BumpPtrAllocator identifierAllocator; 236 llvm::sys::SmartRWMutex<true> identifierMutex; 237 238 //===--------------------------------------------------------------------===// 239 // Diagnostics 240 //===--------------------------------------------------------------------===// 241 DiagnosticEngine diagEngine; 242 243 //===--------------------------------------------------------------------===// 244 // Options 245 //===--------------------------------------------------------------------===// 246 247 /// In most cases, creating operation in unregistered dialect is not desired 248 /// and indicate a misconfiguration of the compiler. This option enables to 249 /// detect such use cases 250 bool allowUnregisteredDialects = false; 251 252 /// Enable support for multi-threading within MLIR. 253 bool threadingIsEnabled = true; 254 255 /// Track if we are currently executing in a threaded execution environment 256 /// (like the pass-manager): this is only a debugging feature to help reducing 257 /// the chances of data races one some context APIs. 258 #ifndef NDEBUG 259 std::atomic<int> multiThreadedExecutionContext{0}; 260 #endif 261 262 /// If the operation should be attached to diagnostics printed via the 263 /// Operation::emit methods. 264 bool printOpOnDiagnostic = true; 265 266 /// If the current stack trace should be attached when emitting diagnostics. 267 bool printStackTraceOnDiagnostic = false; 268 269 //===--------------------------------------------------------------------===// 270 // Other 271 //===--------------------------------------------------------------------===// 272 273 /// This points to the ThreadPool used when processing MLIR tasks in parallel. 274 /// It can't be nullptr when multi-threading is enabled. Otherwise if 275 /// multi-threading is disabled, and the threadpool wasn't externally provided 276 /// using `setThreadPool`, this will be nullptr. 
277 llvm::ThreadPool *threadPool = nullptr; 278 279 /// In case where the thread pool is owned by the context, this ensures 280 /// destruction with the context. 281 std::unique_ptr<llvm::ThreadPool> ownedThreadPool; 282 283 /// This is a list of dialects that are created referring to this context. 284 /// The MLIRContext owns the objects. 285 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects; 286 DialectRegistry dialectsRegistry; 287 288 /// This is a mapping from operation name to AbstractOperation for registered 289 /// operations. 290 llvm::StringMap<AbstractOperation> registeredOperations; 291 292 /// Identifiers are uniqued by string value and use the internal string set 293 /// for storage. 294 llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>, 295 llvm::BumpPtrAllocator &> 296 identifiers; 297 298 /// An allocator used for AbstractAttribute and AbstractType objects. 299 llvm::BumpPtrAllocator abstractDialectSymbolAllocator; 300 301 //===--------------------------------------------------------------------===// 302 // Affine uniquing 303 //===--------------------------------------------------------------------===// 304 305 // Affine allocator and mutex for thread safety. 306 llvm::BumpPtrAllocator affineAllocator; 307 llvm::sys::SmartRWMutex<true> affineMutex; 308 309 // Affine map uniquing. 310 using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>; 311 AffineMapSet affineMaps; 312 313 // Integer set uniquing. 314 using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>; 315 IntegerSets integerSets; 316 317 // Affine expression uniquing. 318 StorageUniquer affineUniquer; 319 320 //===--------------------------------------------------------------------===// 321 // Type uniquing 322 //===--------------------------------------------------------------------===// 323 324 DenseMap<TypeID, AbstractType *> registeredTypes; 325 StorageUniquer typeUniquer; 326 327 /// Cached Type Instances. 
328 BFloat16Type bf16Ty; 329 Float16Type f16Ty; 330 Float32Type f32Ty; 331 Float64Type f64Ty; 332 Float80Type f80Ty; 333 Float128Type f128Ty; 334 IndexType indexTy; 335 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 336 NoneType noneType; 337 338 //===--------------------------------------------------------------------===// 339 // Attribute uniquing 340 //===--------------------------------------------------------------------===// 341 342 DenseMap<TypeID, AbstractAttribute *> registeredAttributes; 343 StorageUniquer attributeUniquer; 344 345 /// Cached Attribute Instances. 346 BoolAttr falseAttr, trueAttr; 347 UnitAttr unitAttr; 348 UnknownLoc unknownLocAttr; 349 DictionaryAttr emptyDictionaryAttr; 350 StringAttr emptyStringAttr; 351 352 public: 353 MLIRContextImpl(bool threadingIsEnabled) 354 : threadingIsEnabled(threadingIsEnabled), 355 identifiers(identifierAllocator) { 356 if (threadingIsEnabled) { 357 ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 358 threadPool = ownedThreadPool.get(); 359 } 360 } 361 ~MLIRContextImpl() { 362 for (auto typeMapping : registeredTypes) 363 typeMapping.second->~AbstractType(); 364 for (auto attrMapping : registeredAttributes) 365 attrMapping.second->~AbstractAttribute(); 366 } 367 }; 368 } // end namespace mlir 369 370 MLIRContext::MLIRContext(Threading setting) 371 : MLIRContext(DialectRegistry(), setting) {} 372 373 MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) 374 : impl(new MLIRContextImpl(setting == Threading::ENABLED && 375 !isThreadingGloballyDisabled())) { 376 // Initialize values based on the command line flags if they were provided. 377 if (clOptions.isConstructed()) { 378 printOpOnDiagnostic(clOptions->printOpOnDiagnostic); 379 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); 380 } 381 382 // Pre-populate the registry. 383 registry.appendTo(impl->dialectsRegistry); 384 385 // Ensure the builtin dialect is always pre-loaded. 
386 getOrLoadDialect<BuiltinDialect>(); 387 388 // Initialize several common attributes and types to avoid the need to lock 389 // the context when accessing them. 390 391 //// Types. 392 /// Floating-point Types. 393 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); 394 impl->f16Ty = TypeUniquer::get<Float16Type>(this); 395 impl->f32Ty = TypeUniquer::get<Float32Type>(this); 396 impl->f64Ty = TypeUniquer::get<Float64Type>(this); 397 impl->f80Ty = TypeUniquer::get<Float80Type>(this); 398 impl->f128Ty = TypeUniquer::get<Float128Type>(this); 399 /// Index Type. 400 impl->indexTy = TypeUniquer::get<IndexType>(this); 401 /// Integer Types. 402 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless); 403 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless); 404 impl->int16Ty = 405 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless); 406 impl->int32Ty = 407 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless); 408 impl->int64Ty = 409 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless); 410 impl->int128Ty = 411 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless); 412 /// None Type. 413 impl->noneType = TypeUniquer::get<NoneType>(this); 414 415 //// Attributes. 416 //// Note: These must be registered after the types as they may generate one 417 //// of the above types internally. 418 /// Unknown Location Attribute. 419 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this); 420 /// Bool Attributes. 421 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false); 422 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true); 423 /// Unit Attribute. 424 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this); 425 /// The empty dictionary attribute. 426 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this); 427 /// The empty string attribute. 
428 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this); 429 430 // Register the affine storage objects with the uniquer. 431 impl->affineUniquer 432 .registerParametricStorageType<AffineBinaryOpExprStorage>(); 433 impl->affineUniquer 434 .registerParametricStorageType<AffineConstantExprStorage>(); 435 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>(); 436 } 437 438 MLIRContext::~MLIRContext() {} 439 440 /// Copy the specified array of elements into memory managed by the provided 441 /// bump pointer allocator. This assumes the elements are all PODs. 442 template <typename T> 443 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 444 ArrayRef<T> elements) { 445 auto result = allocator.Allocate<T>(elements.size()); 446 std::uninitialized_copy(elements.begin(), elements.end(), result); 447 return ArrayRef<T>(result, elements.size()); 448 } 449 450 //===----------------------------------------------------------------------===// 451 // Debugging 452 //===----------------------------------------------------------------------===// 453 454 DebugActionManager &MLIRContext::getDebugActionManager() { 455 return getImpl().debugActionManager; 456 } 457 458 //===----------------------------------------------------------------------===// 459 // Diagnostic Handlers 460 //===----------------------------------------------------------------------===// 461 462 /// Returns the diagnostic engine for this context. 463 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 464 465 //===----------------------------------------------------------------------===// 466 // Dialect and Operation Registration 467 //===----------------------------------------------------------------------===// 468 469 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 470 registry.appendTo(impl->dialectsRegistry); 471 472 // For the already loaded dialects, register the interfaces immediately. 
473 for (const auto &kvp : impl->loadedDialects) 474 registry.registerDelayedInterfaces(kvp.second.get()); 475 } 476 477 const DialectRegistry &MLIRContext::getDialectRegistry() { 478 return impl->dialectsRegistry; 479 } 480 481 /// Return information about all registered IR dialects. 482 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 483 std::vector<Dialect *> result; 484 result.reserve(impl->loadedDialects.size()); 485 for (auto &dialect : impl->loadedDialects) 486 result.push_back(dialect.second.get()); 487 llvm::array_pod_sort(result.begin(), result.end(), 488 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 489 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 490 }); 491 return result; 492 } 493 std::vector<StringRef> MLIRContext::getAvailableDialects() { 494 std::vector<StringRef> result; 495 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 496 result.push_back(dialect); 497 return result; 498 } 499 500 /// Get a registered IR dialect with the given namespace. If none is found, 501 /// then return nullptr. 502 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 503 // Dialects are sorted by name, so we can use binary search for lookup. 504 auto it = impl->loadedDialects.find(name); 505 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 506 } 507 508 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 509 Dialect *dialect = getLoadedDialect(name); 510 if (dialect) 511 return dialect; 512 DialectAllocatorFunctionRef allocator = 513 impl->dialectsRegistry.getDialectAllocator(name); 514 return allocator ? allocator(this) : nullptr; 515 } 516 517 /// Get a dialect for the provided namespace and TypeID: abort the program if a 518 /// dialect exist for this namespace with different TypeID. Returns a pointer to 519 /// the dialect owned by the context. 
520 Dialect * 521 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 522 function_ref<std::unique_ptr<Dialect>()> ctor) { 523 auto &impl = getImpl(); 524 // Get the correct insertion position sorted by namespace. 525 std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; 526 527 if (!dialect) { 528 LLVM_DEBUG(llvm::dbgs() 529 << "Load new dialect in Context " << dialectNamespace << "\n"); 530 #ifndef NDEBUG 531 if (impl.multiThreadedExecutionContext != 0) 532 llvm::report_fatal_error( 533 "Loading a dialect (" + dialectNamespace + 534 ") while in a multi-threaded execution context (maybe " 535 "the PassManager): this can indicate a " 536 "missing `dependentDialects` in a pass for example."); 537 #endif 538 dialect = ctor(); 539 assert(dialect && "dialect ctor failed"); 540 541 // Refresh all the identifiers dialect field, this catches cases where a 542 // dialect may be loaded after identifier prefixed with this dialect name 543 // were already created. 544 llvm::SmallString<32> dialectPrefix(dialectNamespace); 545 dialectPrefix.push_back('.'); 546 for (auto &identifierEntry : impl.identifiers) 547 if (identifierEntry.second.is<MLIRContext *>() && 548 identifierEntry.first().startswith(dialectPrefix)) 549 identifierEntry.second = dialect.get(); 550 551 // Actually register the interfaces with delayed registration. 552 impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); 553 return dialect.get(); 554 } 555 556 // Abort if dialect with namespace has already been registered. 
557 if (dialect->getTypeID() != dialectID) 558 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 559 "' has already been registered"); 560 561 return dialect.get(); 562 } 563 564 void MLIRContext::loadAllAvailableDialects() { 565 for (StringRef name : getAvailableDialects()) 566 getOrLoadDialect(name); 567 } 568 569 llvm::hash_code MLIRContext::getRegistryHash() { 570 llvm::hash_code hash(0); 571 // Factor in number of loaded dialects, attributes, operations, types. 572 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 573 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 574 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 575 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 576 return hash; 577 } 578 579 bool MLIRContext::allowsUnregisteredDialects() { 580 return impl->allowUnregisteredDialects; 581 } 582 583 void MLIRContext::allowUnregisteredDialects(bool allowing) { 584 impl->allowUnregisteredDialects = allowing; 585 } 586 587 /// Return true if multi-threading is enabled by the context. 588 bool MLIRContext::isMultithreadingEnabled() { 589 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 590 } 591 592 /// Set the flag specifying if multi-threading is disabled by the context. 593 void MLIRContext::disableMultithreading(bool disable) { 594 // This API can be overridden by the global debugging flag 595 // --mlir-disable-threading 596 if (isThreadingGloballyDisabled()) 597 return; 598 599 impl->threadingIsEnabled = !disable; 600 601 // Update the threading mode for each of the uniquers. 602 impl->affineUniquer.disableMultithreading(disable); 603 impl->attributeUniquer.disableMultithreading(disable); 604 impl->typeUniquer.disableMultithreading(disable); 605 606 // Destroy thread pool (stop all threads) if it is no longer needed, or create 607 // a new one if multithreading was re-enabled. 
608 if (disable) { 609 // If the thread pool is owned, explicitly set it to nullptr to avoid 610 // keeping a dangling pointer around. If the thread pool is externally 611 // owned, we don't do anything. 612 if (impl->ownedThreadPool) { 613 assert(impl->threadPool); 614 impl->threadPool = nullptr; 615 impl->ownedThreadPool.reset(); 616 } 617 } else if (!impl->threadPool) { 618 // The thread pool isn't externally provided. 619 assert(!impl->ownedThreadPool); 620 impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 621 impl->threadPool = impl->ownedThreadPool.get(); 622 } 623 } 624 625 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { 626 assert(!isMultithreadingEnabled() && 627 "expected multi-threading to be disabled when setting a ThreadPool"); 628 impl->threadPool = &pool; 629 impl->ownedThreadPool.reset(); 630 enableMultithreading(); 631 } 632 633 llvm::ThreadPool &MLIRContext::getThreadPool() { 634 assert(isMultithreadingEnabled() && 635 "expected multi-threading to be enabled within the context"); 636 assert(impl->threadPool && 637 "multi-threading is enabled but threadpool not set"); 638 return *impl->threadPool; 639 } 640 641 void MLIRContext::enterMultiThreadedExecution() { 642 #ifndef NDEBUG 643 ++impl->multiThreadedExecutionContext; 644 #endif 645 } 646 void MLIRContext::exitMultiThreadedExecution() { 647 #ifndef NDEBUG 648 --impl->multiThreadedExecutionContext; 649 #endif 650 } 651 652 /// Return true if we should attach the operation to diagnostics emitted via 653 /// Operation::emit. 654 bool MLIRContext::shouldPrintOpOnDiagnostic() { 655 return impl->printOpOnDiagnostic; 656 } 657 658 /// Set the flag specifying if we should attach the operation to diagnostics 659 /// emitted via Operation::emit. 660 void MLIRContext::printOpOnDiagnostic(bool enable) { 661 impl->printOpOnDiagnostic = enable; 662 } 663 664 /// Return true if we should attach the current stacktrace to diagnostics when 665 /// emitted. 
666 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 667 return impl->printStackTraceOnDiagnostic; 668 } 669 670 /// Set the flag specifying if we should attach the current stacktrace when 671 /// emitting diagnostics. 672 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 673 impl->printStackTraceOnDiagnostic = enable; 674 } 675 676 /// Return information about all registered operations. This isn't very 677 /// efficient, typically you should ask the operations about their properties 678 /// directly. 679 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() { 680 // We just have the operations in a non-deterministic hash table order. Dump 681 // into a temporary array, then sort it by operation name to get a stable 682 // ordering. 683 llvm::StringMap<AbstractOperation> ®isteredOps = 684 impl->registeredOperations; 685 686 std::vector<AbstractOperation *> result; 687 result.reserve(registeredOps.size()); 688 for (auto &elt : registeredOps) 689 result.push_back(&elt.second); 690 llvm::array_pod_sort( 691 result.begin(), result.end(), 692 [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) { 693 return (*lhs)->name.compare((*rhs)->name); 694 }); 695 696 return result; 697 } 698 699 bool MLIRContext::isOperationRegistered(StringRef name) { 700 return impl->registeredOperations.count(name); 701 } 702 703 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 704 auto &impl = context->getImpl(); 705 assert(impl.multiThreadedExecutionContext == 0 && 706 "Registering a new type kind while in a multi-threaded execution " 707 "context"); 708 auto *newInfo = 709 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 710 AbstractType(std::move(typeInfo)); 711 if (!impl.registeredTypes.insert({typeID, newInfo}).second) 712 llvm::report_fatal_error("Dialect Type already registered."); 713 } 714 715 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 716 auto &impl = context->getImpl(); 
717 assert(impl.multiThreadedExecutionContext == 0 && 718 "Registering a new attribute kind while in a multi-threaded execution " 719 "context"); 720 auto *newInfo = 721 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 722 AbstractAttribute(std::move(attrInfo)); 723 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 724 llvm::report_fatal_error("Dialect Attribute already registered."); 725 } 726 727 //===----------------------------------------------------------------------===// 728 // AbstractAttribute 729 //===----------------------------------------------------------------------===// 730 731 /// Get the dialect that registered the attribute with the provided typeid. 732 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, 733 MLIRContext *context) { 734 const AbstractAttribute *abstract = lookupMutable(typeID, context); 735 if (!abstract) 736 llvm::report_fatal_error("Trying to create an Attribute that was not " 737 "registered in this MLIRContext."); 738 return *abstract; 739 } 740 741 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, 742 MLIRContext *context) { 743 auto &impl = context->getImpl(); 744 auto it = impl.registeredAttributes.find(typeID); 745 if (it == impl.registeredAttributes.end()) 746 return nullptr; 747 return it->second; 748 } 749 750 //===----------------------------------------------------------------------===// 751 // AbstractOperation 752 //===----------------------------------------------------------------------===// 753 754 ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser, 755 OperationState &result) const { 756 return parseAssemblyFn(parser, result); 757 } 758 759 /// Look up the specified operation in the operation set and return a pointer 760 /// to it if present. Otherwise, return a null pointer. 
761 AbstractOperation *AbstractOperation::lookupMutable(StringRef opName, 762 MLIRContext *context) { 763 auto &impl = context->getImpl(); 764 auto it = impl.registeredOperations.find(opName); 765 if (it != impl.registeredOperations.end()) 766 return &it->second; 767 return nullptr; 768 } 769 770 void AbstractOperation::insert( 771 StringRef name, Dialect &dialect, TypeID typeID, 772 ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, 773 VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, 774 GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, 775 detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, 776 ArrayRef<StringRef> attrNames) { 777 MLIRContext *ctx = dialect.getContext(); 778 auto &impl = ctx->getImpl(); 779 assert(impl.multiThreadedExecutionContext == 0 && 780 "Registering a new operation kind while in a multi-threaded execution " 781 "context"); 782 783 // Register the attribute names of this operation. 784 MutableArrayRef<Identifier> cachedAttrNames; 785 if (!attrNames.empty()) { 786 cachedAttrNames = MutableArrayRef<Identifier>( 787 impl.identifierAllocator.Allocate<Identifier>(attrNames.size()), 788 attrNames.size()); 789 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size())) 790 new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx)); 791 } 792 793 // Register the information for this operation. 
// Build the AbstractOperation description and register it under its name.
// Registering an operation name a second time is a fatal error: a message is
// written to llvm::errs() and the process aborts.
  AbstractOperation opInfo(
      name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly),
      std::move(verifyInvariants), std::move(foldHook),
      std::move(getCanonicalizationPatterns), std::move(interfaceMap),
      std::move(hasTrait), cachedAttrNames);
  if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
    llvm::errs() << "error: operation named '" << name
                 << "' is already registered.\n";
    abort();
  }
}

/// Construct the description of a registered operation.
///
/// The operation name is uniqued as an Identifier in the context of `dialect`.
/// The hook callables (parse/print/verify/fold/canonicalization patterns/
/// hasTrait) and the interface map are moved into the new object. `attrNames`
/// is stored as given. NOTE(review): presumably `attributeNames` is a
/// non-owning view, so the storage behind `attrNames` must outlive this object.
/// Confirm against the member declaration.
AbstractOperation::AbstractOperation(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<Identifier> attrNames)
    // Inside each initializer argument, the unqualified names (`name`,
    // `interfaceMap`, ...) refer to the constructor parameters, not the members.
    : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
      typeID(typeID), interfaceMap(std::move(interfaceMap)),
      foldHookFn(std::move(foldHook)),
      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
      hasTraitFn(std::move(hasTrait)),
      parseAssemblyFn(std::move(parseAssembly)),
      printAssemblyFn(std::move(printAssembly)),
      verifyInvariantsFn(std::move(verifyInvariants)),
      attributeNames(attrNames) {}

//===----------------------------------------------------------------------===//
// AbstractType
//===----------------------------------------------------------------------===//

/// Return the AbstractType registered in `context` for `typeID`.
/// Calls llvm::report_fatal_error if no such type has been registered.
const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
  const AbstractType *type = lookupMutable(typeID, context);
  if (!type)
    llvm::report_fatal_error(
        "Trying to create a Type that was not registered in this MLIRContext.");
  return *type;
}

/// Return a mutable pointer to the AbstractType registered in `context` for
/// `typeID`, or nullptr if the type is not registered.
AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
  auto &impl = context->getImpl();
  auto it = impl.registeredTypes.find(typeID);
  if (it == impl.registeredTypes.end())
    return nullptr;
  return it->second;
}

//===----------------------------------------------------------------------===//
// Identifier uniquing
//===----------------------------------------------------------------------===//

/// Return an identifier for the specified string.
///
/// Identifiers are uniqued in the context's identifier table. Each entry also
/// records an owner: a loaded Dialect when the string's prefix (text before the
/// first '.') names one, and the MLIRContext otherwise.
Identifier Identifier::get(const Twine &string, MLIRContext *context) {
  SmallString<32> tempStr;
  StringRef str = string.toStringRef(tempStr);

  // Check invariants up front, before any lookup in the identifier table.
  // These are asserts, so they are compiled out in builds with NDEBUG defined.
  assert(!str.empty() && "Cannot create an empty identifier");
  assert(str.find('\0') == StringRef::npos &&
         "Cannot create an identifier with a nul character");

  // Compute the owner of a new entry. StringRef::split('.') splits at the first
  // '.'. If there is no '.', the whole string is treated as the prefix. When
  // the prefix is non-empty and matches a dialect already loaded in `context`,
  // that dialect becomes the owner. Otherwise the context is the owner.
  auto getDialectOrContext = [&]() {
    PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
    auto dialectNamePair = str.split('.');
    if (!dialectNamePair.first.empty())
      if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
        dialectOrContext = dialect;
    return dialectOrContext;
  };

  auto &impl = context->getImpl();
  if (!context->isMultithreadingEnabled()) {
    // Single-threaded: no locking. Insert a null placeholder first so the
    // owner is only computed when the entry is actually new.
    auto insertedIt = impl.identifiers.insert({str, nullptr});
    if (insertedIt.second)
      insertedIt.first->second = getDialectOrContext();
    return Identifier(&*insertedIt.first);
  }

  // Check for an existing identifier in read-only mode.
  {
    llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
    auto it = impl.identifiers.find(str);
    if (it != impl.identifiers.end())
      return Identifier(&*it);
  }

  // Acquire a writer-lock so that we can safely create the new instance.
  // Another thread may have inserted `str` between releasing the reader lock
  // and acquiring this one. In that case `insert` leaves the existing entry
  // unchanged and returns it, so both threads get the same Identifier.
  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
  auto it = impl.identifiers.insert({str, getDialectOrContext()}).first;
  return Identifier(&*it);
}

/// Return the dialect that owns this identifier, or nullptr if the identifier
/// is owned directly by the context.
Dialect *Identifier::getDialect() {
  return entry->second.dyn_cast<Dialect *>();
}

/// Return the context this identifier was created in. If a dialect owns the
/// identifier, the dialect's context is returned.
MLIRContext *Identifier::getContext() {
  if (Dialect *dialect = getDialect())
    return dialect->getContext();
  return entry->second.get<MLIRContext *>();
}

//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//

/// Returns the storage uniquer used for constructing type storage instances.
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }

// The floating-point type getters below return instances cached on the
// context implementation rather than going through the type uniquer.
BFloat16Type BFloat16Type::get(MLIRContext *context) {
  return context->getImpl().bf16Ty;
}
Float16Type Float16Type::get(MLIRContext *context) {
  return context->getImpl().f16Ty;
}
Float32Type Float32Type::get(MLIRContext *context) {
  return context->getImpl().f32Ty;
}
Float64Type Float64Type::get(MLIRContext *context) {
  return context->getImpl().f64Ty;
}
Float80Type Float80Type::get(MLIRContext *context) {
  return context->getImpl().f80Ty;
}
Float128Type Float128Type::get(MLIRContext *context) {
  return context->getImpl().f128Ty;
}

/// Get an instance of the IndexType (cached on the context implementation).
IndexType IndexType::get(MLIRContext *context) {
  return context->getImpl().indexTy;
}

/// Return an existing integer type instance if one is cached within the
/// context. Only signless integers of width 1, 8, 16, 32, 64 or 128 are cached.
/// For any other width or signedness a null IntegerType is returned, and the
/// caller must fall back to the type uniquer.
static IntegerType
getCachedIntegerType(unsigned width,
                     IntegerType::SignednessSemantics signedness,
                     MLIRContext *context) {
  if (signedness != IntegerType::Signless)
    return IntegerType();

  switch (width) {
  case 1:
    return context->getImpl().int1Ty;
  case 8:
    return context->getImpl().int8Ty;
  case 16:
    return context->getImpl().int16Ty;
  case 32:
    return context->getImpl().int32Ty;
  case 64:
    return context->getImpl().int64Ty;
  case 128:
    return context->getImpl().int128Ty;
  default:
    return IntegerType();
  }
}

/// Get an integer type of the given width and signedness. Cached instances are
/// returned directly; otherwise the type is uniqued via Base::get.
IntegerType IntegerType::get(MLIRContext *context, unsigned width,
                             IntegerType::SignednessSemantics signedness) {
  if (auto cached = getCachedIntegerType(width, signedness, context))
    return cached;
  return Base::get(context, width, signedness);
}

/// Like IntegerType::get, but routes uncached construction through
/// Base::getChecked so that invariant failures are reported via `emitError`.
/// Cached instances skip the check.
IntegerType
IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                        MLIRContext *context, unsigned width,
                        SignednessSemantics signedness) {
  if (auto cached = getCachedIntegerType(width, signedness, context))
    return cached;
  return Base::getChecked(emitError, context, width, signedness);
}

/// Get an instance of the NoneType.
NoneType NoneType::get(MLIRContext *context) {
  if (NoneType cachedInst = context->getImpl().noneType)
    return cachedInst;
  // Note: May happen when initializing the singleton attributes of the builtin
  // dialect.
  return Base::get(context);
}

//===----------------------------------------------------------------------===//
// Attribute uniquing
//===----------------------------------------------------------------------===//

/// Returns the storage uniquer used for constructing attribute storage
/// instances. This should not be used directly.
StorageUniquer &MLIRContext::getAttributeUniquer() {
  return getImpl().attributeUniquer;
}

/// Initialize the given attribute storage instance.
/// The storage is bound to the AbstractAttribute returned by
/// AbstractAttribute::lookup for `attrID`.
void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
                                                  MLIRContext *ctx,
                                                  TypeID attrID) {
  storage->initialize(AbstractAttribute::lookup(attrID, ctx));

  // If the attribute did not provide a type, then default to NoneType.
  if (!storage->getType())
    storage->setType(NoneType::get(ctx));
}

/// Return the context's cached `true` or `false` attribute.
BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
  return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
}

/// Return the context's cached unit attribute.
UnitAttr UnitAttr::get(MLIRContext *context) {
  return context->getImpl().unitAttr;
}

/// Return the context's cached unknown location.
UnknownLoc UnknownLoc::get(MLIRContext *context) {
  return context->getImpl().unknownLocAttr;
}

/// Return empty dictionary.
DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
  return context->getImpl().emptyDictionaryAttr;
}

/// Return an empty string.
StringAttr StringAttr::get(MLIRContext *context) {
  return context->getImpl().emptyStringAttr;
}

//===----------------------------------------------------------------------===//
// AffineMap uniquing
//===----------------------------------------------------------------------===//

/// Returns the storage uniquer used for affine constructs.
StorageUniquer &MLIRContext::getAffineUniquer() {
  return getImpl().affineUniquer;
}

/// Get or create the uniqued AffineMap for (dimCount, symbolCount, results).
/// NOTE(review): safeGetOrCreate is not visible here. Presumably it returns an
/// existing instance for `key` or creates one with the callback, locking
/// `affineMutex` only when threading is enabled. Confirm.
AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
                             ArrayRef<AffineExpr> results,
                             MLIRContext *context) {
  auto &impl = context->getImpl();
  // The key holds its own copy of the caller's `results` ArrayRef. The
  // reassignment of `results` inside the callback below does not affect it.
  auto key = std::make_tuple(dimCount, symbolCount, results);

  // Safely get or create an AffineMap instance.
  return safeGetOrCreate(
      impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] {
        auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>();

        // Copy the results into the bump pointer allocator so the stored map
        // does not reference the caller's memory.
        results = copyArrayRefInto(impl.affineAllocator, results);

        // Initialize the memory using placement new.
        new (res)
            detail::AffineMapStorage{dimCount, symbolCount, results, context};
        return AffineMap(res);
      });
}

/// Return the map with no dimensions, no symbols and no results.
AffineMap AffineMap::get(MLIRContext *context) {
  return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
}

/// Return a map with the given dimension/symbol counts and no results.
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                         MLIRContext *context) {
  return getImpl(dimCount, symbolCount, /*results=*/{}, context);
}

/// Return a single-result map. The context is taken from `result`.
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                         AffineExpr result) {
  return getImpl(dimCount, symbolCount, {result}, result.getContext());
}

/// Return a map with the given dimension/symbol counts and result expressions.
AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                         ArrayRef<AffineExpr> results, MLIRContext *context) {
  return getImpl(dimCount, symbolCount, results, context);
}

//===----------------------------------------------------------------------===//
// Integer Sets: these are allocated in the bump pointer allocator and are
// immutable. Unlike AffineMaps, they are uniqued only if they are small.
//===----------------------------------------------------------------------===//

/// Get an IntegerSet with the given constraints. `eqFlags[i]` is paired with
/// `constraints[i]`. The context is taken from the first constraint. Sets with
/// fewer than IntegerSet::kUniquingThreshold constraints are uniqued. Larger
/// sets are freshly allocated on every call.
IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
                           ArrayRef<AffineExpr> constraints,
                           ArrayRef<bool> eqFlags) {
  // The number of constraints can't be zero.
  assert(!constraints.empty());
  assert(constraints.size() == eqFlags.size());

  auto &impl = constraints[0].getContext()->getImpl();

  // A utility function to construct a new IntegerSetStorage instance.
  auto constructorFn = [&] {
    auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>();

    // Copy the results and equality flags into the bump pointer allocator.
    constraints = copyArrayRefInto(impl.affineAllocator, constraints);
    eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags);

    // Initialize the memory using placement new.
    new (res)
        detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
    return IntegerSet(res);
  };

  // If this instance is uniqued, then we handle it separately so that multiple
  // threads may simultaneously access existing instances.
  if (constraints.size() < IntegerSet::kUniquingThreshold) {
    auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
    return safeGetOrCreate(impl.integerSets, key, impl.affineMutex,
                           impl.threadingIsEnabled, constructorFn);
  }

  // Otherwise, acquire a writer-lock so that we can safely create the new
  // instance. The lock is taken only when threading is enabled, because
  // impl.threadingIsEnabled is passed to ScopedWriterLock.
  ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled);
  return constructorFn();
}

//===----------------------------------------------------------------------===//
// StorageUniquerSupport
//===----------------------------------------------------------------------===//

/// Utility method to generate a callback that can be used to generate a
/// diagnostic when checking the construction invariants of a storage object.
/// This is defined out-of-line to avoid the need to include Location.h.
/// This overload reports the error at the context's UnknownLoc.
llvm::unique_function<InFlightDiagnostic()>
mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
  return [ctx] { return emitError(UnknownLoc::get(ctx)); };
}
/// Same as above, but reports the error at `loc`. The lambda captures a copy
/// of `loc`.
llvm::unique_function<InFlightDiagnostic()>
mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
  return [=] { return emitError(loc); };
}