1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/MLIRContext.h" 10 #include "AffineExprDetail.h" 11 #include "AffineMapDetail.h" 12 #include "AttributeDetail.h" 13 #include "IntegerSetDetail.h" 14 #include "TypeDetail.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/Types.h" 25 #include "mlir/Support/DebugAction.h" 26 #include "llvm/ADT/DenseMap.h" 27 #include "llvm/ADT/DenseSet.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/ADT/SmallString.h" 30 #include "llvm/ADT/StringSet.h" 31 #include "llvm/ADT/Twine.h" 32 #include "llvm/Support/Allocator.h" 33 #include "llvm/Support/CommandLine.h" 34 #include "llvm/Support/Debug.h" 35 #include "llvm/Support/Mutex.h" 36 #include "llvm/Support/RWMutex.h" 37 #include "llvm/Support/ThreadPool.h" 38 #include "llvm/Support/raw_ostream.h" 39 #include <memory> 40 41 #define DEBUG_TYPE "mlircontext" 42 43 using namespace mlir; 44 using namespace mlir::detail; 45 46 using llvm::hash_combine; 47 using llvm::hash_combine_range; 48 49 //===----------------------------------------------------------------------===// 50 // MLIRContext CommandLine Options 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 /// This struct contains command line options that can be used to initialize 55 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need 56 /// for global command line options. 57 struct MLIRContextOptions { 58 llvm::cl::opt<bool> disableThreading{ 59 "mlir-disable-threading", 60 llvm::cl::desc("Disable multi-threading within MLIR, overrides any " 61 "further call to MLIRContext::enableMultiThreading()")}; 62 63 llvm::cl::opt<bool> printOpOnDiagnostic{ 64 "mlir-print-op-on-diagnostic", 65 llvm::cl::desc("When a diagnostic is emitted on an operation, also print " 66 "the operation as an attached note"), 67 llvm::cl::init(true)}; 68 69 llvm::cl::opt<bool> printStackTraceOnDiagnostic{ 70 "mlir-print-stacktrace-on-diagnostic", 71 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace " 72 "as an attached note")}; 73 }; 74 } // end anonymous namespace 75 76 static llvm::ManagedStatic<MLIRContextOptions> clOptions; 77 78 static bool isThreadingGloballyDisabled() { 79 #if LLVM_ENABLE_THREADS != 0 80 return clOptions.isConstructed() && clOptions->disableThreading; 81 #else 82 return true; 83 #endif 84 } 85 86 /// Register a set of useful command-line options that can be used to configure 87 /// various flags within the MLIRContext. These flags are used when constructing 88 /// an MLIR context for initialization. 89 void mlir::registerMLIRContextCLOptions() { 90 // Make sure that the options struct has been initialized. 91 *clOptions; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Locking Utilities 96 //===----------------------------------------------------------------------===// 97 98 namespace { 99 /// Utility writer lock that takes a runtime flag that specifies if we really 100 /// need to lock. 101 struct ScopedWriterLock { 102 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 103 : mutex(shouldLock ? &mutexParam : nullptr) { 104 if (mutex) 105 mutex->lock(); 106 } 107 ~ScopedWriterLock() { 108 if (mutex) 109 mutex->unlock(); 110 } 111 llvm::sys::SmartRWMutex<true> *mutex; 112 }; 113 } // end anonymous namespace. 114 115 //===----------------------------------------------------------------------===// 116 // AffineMap and IntegerSet hashing 117 //===----------------------------------------------------------------------===// 118 119 /// A utility function to safely get or create a uniqued instance within the 120 /// given set container. 121 template <typename ValueT, typename DenseInfoT, typename KeyT, 122 typename ConstructorFn> 123 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container, 124 KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex, 125 bool threadingIsEnabled, 126 ConstructorFn &&constructorFn) { 127 // Check for an existing instance in read-only mode. 128 if (threadingIsEnabled) { 129 llvm::sys::SmartScopedReader<true> instanceLock(mutex); 130 auto it = container.find_as(key); 131 if (it != container.end()) 132 return *it; 133 } 134 135 // Acquire a writer-lock so that we can safely create the new instance. 136 ScopedWriterLock instanceLock(mutex, threadingIsEnabled); 137 138 // Check for an existing instance again here, because another writer thread 139 // may have already created one. Otherwise, construct a new instance. 140 auto existing = container.insert_as(ValueT(), key); 141 if (existing.second) 142 return *existing.first = constructorFn(); 143 return *existing.first; 144 } 145 146 namespace { 147 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> { 148 // Affine maps are uniqued based on their dim/symbol counts and affine 149 // expressions. 150 using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>; 151 using DenseMapInfo<AffineMap>::isEqual; 152 153 static unsigned getHashValue(const AffineMap &key) { 154 return getHashValue( 155 KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults())); 156 } 157 158 static unsigned getHashValue(KeyTy key) { 159 return hash_combine( 160 std::get<0>(key), std::get<1>(key), 161 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end())); 162 } 163 164 static bool isEqual(const KeyTy &lhs, AffineMap rhs) { 165 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 166 return false; 167 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 168 rhs.getResults()); 169 } 170 }; 171 172 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> { 173 // Integer sets are uniqued based on their dim/symbol counts, affine 174 // expressions appearing in the LHS of constraints, and eqFlags. 175 using KeyTy = 176 std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>; 177 using DenseMapInfo<IntegerSet>::isEqual; 178 179 static unsigned getHashValue(const IntegerSet &key) { 180 return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(), 181 key.getConstraints(), key.getEqFlags())); 182 } 183 184 static unsigned getHashValue(KeyTy key) { 185 return hash_combine( 186 std::get<0>(key), std::get<1>(key), 187 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), 188 hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end())); 189 } 190 191 static bool isEqual(const KeyTy &lhs, IntegerSet rhs) { 192 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 193 return false; 194 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 195 rhs.getConstraints(), rhs.getEqFlags()); 196 } 197 }; 198 } // end anonymous namespace. 199 200 //===----------------------------------------------------------------------===// 201 // MLIRContextImpl 202 //===----------------------------------------------------------------------===// 203 204 namespace mlir { 205 /// This is the implementation of the MLIRContext class, using the pImpl idiom. 206 /// This class is completely private to this file, so everything is public. 207 class MLIRContextImpl { 208 public: 209 //===--------------------------------------------------------------------===// 210 // Debugging 211 //===--------------------------------------------------------------------===// 212 213 /// An action manager for use within the context. 214 DebugActionManager debugActionManager; 215 216 //===--------------------------------------------------------------------===// 217 // Diagnostics 218 //===--------------------------------------------------------------------===// 219 DiagnosticEngine diagEngine; 220 221 //===--------------------------------------------------------------------===// 222 // Options 223 //===--------------------------------------------------------------------===// 224 225 /// In most cases, creating operation in unregistered dialect is not desired 226 /// and indicate a misconfiguration of the compiler. This option enables to 227 /// detect such use cases 228 bool allowUnregisteredDialects = false; 229 230 /// Enable support for multi-threading within MLIR. 231 bool threadingIsEnabled = true; 232 233 /// Track if we are currently executing in a threaded execution environment 234 /// (like the pass-manager): this is only a debugging feature to help reducing 235 /// the chances of data races one some context APIs. 236 #ifndef NDEBUG 237 std::atomic<int> multiThreadedExecutionContext{0}; 238 #endif 239 240 /// If the operation should be attached to diagnostics printed via the 241 /// Operation::emit methods. 242 bool printOpOnDiagnostic = true; 243 244 /// If the current stack trace should be attached when emitting diagnostics. 245 bool printStackTraceOnDiagnostic = false; 246 247 //===--------------------------------------------------------------------===// 248 // Other 249 //===--------------------------------------------------------------------===// 250 251 /// This points to the ThreadPool used when processing MLIR tasks in parallel. 252 /// It can't be nullptr when multi-threading is enabled. Otherwise if 253 /// multi-threading is disabled, and the threadpool wasn't externally provided 254 /// using `setThreadPool`, this will be nullptr. 255 llvm::ThreadPool *threadPool = nullptr; 256 257 /// In case where the thread pool is owned by the context, this ensures 258 /// destruction with the context. 259 std::unique_ptr<llvm::ThreadPool> ownedThreadPool; 260 261 /// This is a list of dialects that are created referring to this context. 262 /// The MLIRContext owns the objects. 263 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects; 264 DialectRegistry dialectsRegistry; 265 266 /// An allocator used for AbstractAttribute and AbstractType objects. 267 llvm::BumpPtrAllocator abstractDialectSymbolAllocator; 268 269 /// This is a mapping from operation name to the operation info describing it. 270 llvm::StringMap<OperationName::Impl> operations; 271 272 /// A vector of operation info specifically for registered operations. 273 SmallVector<RegisteredOperationName> registeredOperations; 274 275 /// A mutex used when accessing operation information. 276 llvm::sys::SmartRWMutex<true> operationInfoMutex; 277 278 //===--------------------------------------------------------------------===// 279 // Affine uniquing 280 //===--------------------------------------------------------------------===// 281 282 // Affine allocator and mutex for thread safety. 283 llvm::BumpPtrAllocator affineAllocator; 284 llvm::sys::SmartRWMutex<true> affineMutex; 285 286 // Affine map uniquing. 287 using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>; 288 AffineMapSet affineMaps; 289 290 // Integer set uniquing. 291 using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>; 292 IntegerSets integerSets; 293 294 // Affine expression uniquing. 295 StorageUniquer affineUniquer; 296 297 //===--------------------------------------------------------------------===// 298 // Type uniquing 299 //===--------------------------------------------------------------------===// 300 301 DenseMap<TypeID, AbstractType *> registeredTypes; 302 StorageUniquer typeUniquer; 303 304 /// Cached Type Instances. 305 BFloat16Type bf16Ty; 306 Float16Type f16Ty; 307 Float32Type f32Ty; 308 Float64Type f64Ty; 309 Float80Type f80Ty; 310 Float128Type f128Ty; 311 IndexType indexTy; 312 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 313 NoneType noneType; 314 315 //===--------------------------------------------------------------------===// 316 // Attribute uniquing 317 //===--------------------------------------------------------------------===// 318 319 DenseMap<TypeID, AbstractAttribute *> registeredAttributes; 320 StorageUniquer attributeUniquer; 321 322 /// Cached Attribute Instances. 323 BoolAttr falseAttr, trueAttr; 324 UnitAttr unitAttr; 325 UnknownLoc unknownLocAttr; 326 DictionaryAttr emptyDictionaryAttr; 327 StringAttr emptyStringAttr; 328 329 /// Map of string attributes that may reference a dialect, that are awaiting 330 /// that dialect to be loaded. 331 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex; 332 DenseMap<StringRef, SmallVector<StringAttrStorage *>> 333 dialectReferencingStrAttrs; 334 335 public: 336 MLIRContextImpl(bool threadingIsEnabled) 337 : threadingIsEnabled(threadingIsEnabled) { 338 if (threadingIsEnabled) { 339 ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 340 threadPool = ownedThreadPool.get(); 341 } 342 } 343 ~MLIRContextImpl() { 344 for (auto typeMapping : registeredTypes) 345 typeMapping.second->~AbstractType(); 346 for (auto attrMapping : registeredAttributes) 347 attrMapping.second->~AbstractAttribute(); 348 } 349 }; 350 } // end namespace mlir 351 352 MLIRContext::MLIRContext(Threading setting) 353 : MLIRContext(DialectRegistry(), setting) {} 354 355 MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) 356 : impl(new MLIRContextImpl(setting == Threading::ENABLED && 357 !isThreadingGloballyDisabled())) { 358 // Initialize values based on the command line flags if they were provided. 359 if (clOptions.isConstructed()) { 360 printOpOnDiagnostic(clOptions->printOpOnDiagnostic); 361 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); 362 } 363 364 // Pre-populate the registry. 365 registry.appendTo(impl->dialectsRegistry); 366 367 // Ensure the builtin dialect is always pre-loaded. 368 getOrLoadDialect<BuiltinDialect>(); 369 370 // Initialize several common attributes and types to avoid the need to lock 371 // the context when accessing them. 372 373 //// Types. 374 /// Floating-point Types. 375 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); 376 impl->f16Ty = TypeUniquer::get<Float16Type>(this); 377 impl->f32Ty = TypeUniquer::get<Float32Type>(this); 378 impl->f64Ty = TypeUniquer::get<Float64Type>(this); 379 impl->f80Ty = TypeUniquer::get<Float80Type>(this); 380 impl->f128Ty = TypeUniquer::get<Float128Type>(this); 381 /// Index Type. 382 impl->indexTy = TypeUniquer::get<IndexType>(this); 383 /// Integer Types. 384 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless); 385 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless); 386 impl->int16Ty = 387 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless); 388 impl->int32Ty = 389 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless); 390 impl->int64Ty = 391 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless); 392 impl->int128Ty = 393 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless); 394 /// None Type. 395 impl->noneType = TypeUniquer::get<NoneType>(this); 396 397 //// Attributes. 398 //// Note: These must be registered after the types as they may generate one 399 //// of the above types internally. 400 /// Unknown Location Attribute. 401 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this); 402 /// Bool Attributes. 403 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false); 404 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true); 405 /// Unit Attribute. 406 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this); 407 /// The empty dictionary attribute. 408 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this); 409 /// The empty string attribute. 410 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this); 411 412 // Register the affine storage objects with the uniquer. 413 impl->affineUniquer 414 .registerParametricStorageType<AffineBinaryOpExprStorage>(); 415 impl->affineUniquer 416 .registerParametricStorageType<AffineConstantExprStorage>(); 417 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>(); 418 } 419 420 MLIRContext::~MLIRContext() {} 421 422 /// Copy the specified array of elements into memory managed by the provided 423 /// bump pointer allocator. This assumes the elements are all PODs. 424 template <typename T> 425 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 426 ArrayRef<T> elements) { 427 auto result = allocator.Allocate<T>(elements.size()); 428 std::uninitialized_copy(elements.begin(), elements.end(), result); 429 return ArrayRef<T>(result, elements.size()); 430 } 431 432 //===----------------------------------------------------------------------===// 433 // Debugging 434 //===----------------------------------------------------------------------===// 435 436 DebugActionManager &MLIRContext::getDebugActionManager() { 437 return getImpl().debugActionManager; 438 } 439 440 //===----------------------------------------------------------------------===// 441 // Diagnostic Handlers 442 //===----------------------------------------------------------------------===// 443 444 /// Returns the diagnostic engine for this context. 445 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 446 447 //===----------------------------------------------------------------------===// 448 // Dialect and Operation Registration 449 //===----------------------------------------------------------------------===// 450 451 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 452 registry.appendTo(impl->dialectsRegistry); 453 454 // For the already loaded dialects, register the interfaces immediately. 455 for (const auto &kvp : impl->loadedDialects) 456 registry.registerDelayedInterfaces(kvp.second.get()); 457 } 458 459 const DialectRegistry &MLIRContext::getDialectRegistry() { 460 return impl->dialectsRegistry; 461 } 462 463 /// Return information about all registered IR dialects. 464 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 465 std::vector<Dialect *> result; 466 result.reserve(impl->loadedDialects.size()); 467 for (auto &dialect : impl->loadedDialects) 468 result.push_back(dialect.second.get()); 469 llvm::array_pod_sort(result.begin(), result.end(), 470 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 471 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 472 }); 473 return result; 474 } 475 std::vector<StringRef> MLIRContext::getAvailableDialects() { 476 std::vector<StringRef> result; 477 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 478 result.push_back(dialect); 479 return result; 480 } 481 482 /// Get a registered IR dialect with the given namespace. If none is found, 483 /// then return nullptr. 484 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 485 // Dialects are sorted by name, so we can use binary search for lookup. 486 auto it = impl->loadedDialects.find(name); 487 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 488 } 489 490 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 491 Dialect *dialect = getLoadedDialect(name); 492 if (dialect) 493 return dialect; 494 DialectAllocatorFunctionRef allocator = 495 impl->dialectsRegistry.getDialectAllocator(name); 496 return allocator ? allocator(this) : nullptr; 497 } 498 499 /// Get a dialect for the provided namespace and TypeID: abort the program if a 500 /// dialect exist for this namespace with different TypeID. Returns a pointer to 501 /// the dialect owned by the context. 502 Dialect * 503 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 504 function_ref<std::unique_ptr<Dialect>()> ctor) { 505 auto &impl = getImpl(); 506 // Get the correct insertion position sorted by namespace. 507 std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; 508 509 if (!dialect) { 510 LLVM_DEBUG(llvm::dbgs() 511 << "Load new dialect in Context " << dialectNamespace << "\n"); 512 #ifndef NDEBUG 513 if (impl.multiThreadedExecutionContext != 0) 514 llvm::report_fatal_error( 515 "Loading a dialect (" + dialectNamespace + 516 ") while in a multi-threaded execution context (maybe " 517 "the PassManager): this can indicate a " 518 "missing `dependentDialects` in a pass for example."); 519 #endif 520 dialect = ctor(); 521 assert(dialect && "dialect ctor failed"); 522 523 // Refresh all the identifiers dialect field, this catches cases where a 524 // dialect may be loaded after identifier prefixed with this dialect name 525 // were already created. 526 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace); 527 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) { 528 for (StringAttrStorage *storage : stringAttrsIt->second) 529 storage->referencedDialect = dialect.get(); 530 impl.dialectReferencingStrAttrs.erase(stringAttrsIt); 531 } 532 533 // Actually register the interfaces with delayed registration. 534 impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); 535 return dialect.get(); 536 } 537 538 // Abort if dialect with namespace has already been registered. 539 if (dialect->getTypeID() != dialectID) 540 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 541 "' has already been registered"); 542 543 return dialect.get(); 544 } 545 546 void MLIRContext::loadAllAvailableDialects() { 547 for (StringRef name : getAvailableDialects()) 548 getOrLoadDialect(name); 549 } 550 551 llvm::hash_code MLIRContext::getRegistryHash() { 552 llvm::hash_code hash(0); 553 // Factor in number of loaded dialects, attributes, operations, types. 554 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 555 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 556 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 557 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 558 return hash; 559 } 560 561 bool MLIRContext::allowsUnregisteredDialects() { 562 return impl->allowUnregisteredDialects; 563 } 564 565 void MLIRContext::allowUnregisteredDialects(bool allowing) { 566 impl->allowUnregisteredDialects = allowing; 567 } 568 569 /// Return true if multi-threading is enabled by the context. 570 bool MLIRContext::isMultithreadingEnabled() { 571 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 572 } 573 574 /// Set the flag specifying if multi-threading is disabled by the context. 575 void MLIRContext::disableMultithreading(bool disable) { 576 // This API can be overridden by the global debugging flag 577 // --mlir-disable-threading 578 if (isThreadingGloballyDisabled()) 579 return; 580 581 impl->threadingIsEnabled = !disable; 582 583 // Update the threading mode for each of the uniquers. 584 impl->affineUniquer.disableMultithreading(disable); 585 impl->attributeUniquer.disableMultithreading(disable); 586 impl->typeUniquer.disableMultithreading(disable); 587 588 // Destroy thread pool (stop all threads) if it is no longer needed, or create 589 // a new one if multithreading was re-enabled. 590 if (disable) { 591 // If the thread pool is owned, explicitly set it to nullptr to avoid 592 // keeping a dangling pointer around. If the thread pool is externally 593 // owned, we don't do anything. 594 if (impl->ownedThreadPool) { 595 assert(impl->threadPool); 596 impl->threadPool = nullptr; 597 impl->ownedThreadPool.reset(); 598 } 599 } else if (!impl->threadPool) { 600 // The thread pool isn't externally provided. 601 assert(!impl->ownedThreadPool); 602 impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 603 impl->threadPool = impl->ownedThreadPool.get(); 604 } 605 } 606 607 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { 608 assert(!isMultithreadingEnabled() && 609 "expected multi-threading to be disabled when setting a ThreadPool"); 610 impl->threadPool = &pool; 611 impl->ownedThreadPool.reset(); 612 enableMultithreading(); 613 } 614 615 llvm::ThreadPool &MLIRContext::getThreadPool() { 616 assert(isMultithreadingEnabled() && 617 "expected multi-threading to be enabled within the context"); 618 assert(impl->threadPool && 619 "multi-threading is enabled but threadpool not set"); 620 return *impl->threadPool; 621 } 622 623 void MLIRContext::enterMultiThreadedExecution() { 624 #ifndef NDEBUG 625 ++impl->multiThreadedExecutionContext; 626 #endif 627 } 628 void MLIRContext::exitMultiThreadedExecution() { 629 #ifndef NDEBUG 630 --impl->multiThreadedExecutionContext; 631 #endif 632 } 633 634 /// Return true if we should attach the operation to diagnostics emitted via 635 /// Operation::emit. 636 bool MLIRContext::shouldPrintOpOnDiagnostic() { 637 return impl->printOpOnDiagnostic; 638 } 639 640 /// Set the flag specifying if we should attach the operation to diagnostics 641 /// emitted via Operation::emit. 642 void MLIRContext::printOpOnDiagnostic(bool enable) { 643 impl->printOpOnDiagnostic = enable; 644 } 645 646 /// Return true if we should attach the current stacktrace to diagnostics when 647 /// emitted. 648 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 649 return impl->printStackTraceOnDiagnostic; 650 } 651 652 /// Set the flag specifying if we should attach the current stacktrace when 653 /// emitting diagnostics. 654 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 655 impl->printStackTraceOnDiagnostic = enable; 656 } 657 658 /// Return information about all registered operations. This isn't very 659 /// efficient, typically you should ask the operations about their properties 660 /// directly. 661 std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() { 662 // We just have the operations in a non-deterministic hash table order. Dump 663 // into a temporary array, then sort it by operation name to get a stable 664 // ordering. 665 std::vector<RegisteredOperationName> result( 666 impl->registeredOperations.begin(), impl->registeredOperations.end()); 667 llvm::array_pod_sort(result.begin(), result.end(), 668 [](const RegisteredOperationName *lhs, 669 const RegisteredOperationName *rhs) { 670 return lhs->getIdentifier().compare( 671 rhs->getIdentifier()); 672 }); 673 674 return result; 675 } 676 677 bool MLIRContext::isOperationRegistered(StringRef name) { 678 return OperationName(name, this).isRegistered(); 679 } 680 681 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 682 auto &impl = context->getImpl(); 683 assert(impl.multiThreadedExecutionContext == 0 && 684 "Registering a new type kind while in a multi-threaded execution " 685 "context"); 686 auto *newInfo = 687 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 688 AbstractType(std::move(typeInfo)); 689 if (!impl.registeredTypes.insert({typeID, newInfo}).second) 690 llvm::report_fatal_error("Dialect Type already registered."); 691 } 692 693 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 694 auto &impl = context->getImpl(); 695 assert(impl.multiThreadedExecutionContext == 0 && 696 "Registering a new attribute kind while in a multi-threaded execution " 697 "context"); 698 auto *newInfo = 699 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 700 AbstractAttribute(std::move(attrInfo)); 701 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 702 llvm::report_fatal_error("Dialect Attribute already registered."); 703 } 704 705 //===----------------------------------------------------------------------===// 706 // AbstractAttribute 707 //===----------------------------------------------------------------------===// 708 709 /// Get the dialect that registered the attribute with the provided typeid. 710 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, 711 MLIRContext *context) { 712 const AbstractAttribute *abstract = lookupMutable(typeID, context); 713 if (!abstract) 714 llvm::report_fatal_error("Trying to create an Attribute that was not " 715 "registered in this MLIRContext."); 716 return *abstract; 717 } 718 719 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, 720 MLIRContext *context) { 721 auto &impl = context->getImpl(); 722 auto it = impl.registeredAttributes.find(typeID); 723 if (it == impl.registeredAttributes.end()) 724 return nullptr; 725 return it->second; 726 } 727 728 //===----------------------------------------------------------------------===// 729 // OperationName 730 //===----------------------------------------------------------------------===// 731 732 OperationName::OperationName(StringRef name, MLIRContext *context) { 733 MLIRContextImpl &ctxImpl = context->getImpl(); 734 735 // Check for an existing name in read-only mode. 736 bool isMultithreadingEnabled = context->isMultithreadingEnabled(); 737 if (isMultithreadingEnabled) { 738 llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex); 739 auto it = ctxImpl.operations.find(name); 740 if (it != ctxImpl.operations.end()) { 741 impl = &it->second; 742 return; 743 } 744 } 745 746 // Acquire a writer-lock so that we can safely create the new instance. 747 ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled); 748 749 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)}); 750 if (it.second) 751 it.first->second.name = StringAttr::get(context, name); 752 impl = &it.first->second; 753 } 754 755 StringRef OperationName::getDialectNamespace() const { 756 if (Dialect *dialect = getDialect()) 757 return dialect->getNamespace(); 758 return getStringRef().split('.').first; 759 } 760 761 //===----------------------------------------------------------------------===// 762 // RegisteredOperationName 763 //===----------------------------------------------------------------------===// 764 765 ParseResult 766 RegisteredOperationName::parseAssembly(OpAsmParser &parser, 767 OperationState &result) const { 768 return impl->parseAssemblyFn(parser, result); 769 } 770 771 void RegisteredOperationName::insert( 772 StringRef name, Dialect &dialect, TypeID typeID, 773 ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, 774 VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, 775 GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, 776 detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, 777 ArrayRef<StringRef> attrNames) { 778 MLIRContext *ctx = dialect.getContext(); 779 auto &ctxImpl = ctx->getImpl(); 780 assert(ctxImpl.multiThreadedExecutionContext == 0 && 781 "registering a new operation kind while in a multi-threaded execution " 782 "context"); 783 784 // Register the attribute names of this operation. 785 MutableArrayRef<StringAttr> cachedAttrNames; 786 if (!attrNames.empty()) { 787 cachedAttrNames = MutableArrayRef<StringAttr>( 788 ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>( 789 attrNames.size()), 790 attrNames.size()); 791 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size())) 792 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i])); 793 } 794 795 // Insert the operation info if it doesn't exist yet. 796 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)}); 797 if (it.second) 798 it.first->second.name = StringAttr::get(ctx, name); 799 OperationName::Impl &impl = it.first->second; 800 801 if (impl.isRegistered()) { 802 llvm::errs() << "error: operation named '" << name 803 << "' is already registered.\n"; 804 abort(); 805 } 806 ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl)); 807 808 // Update the registered info for this operation. 809 impl.dialect = &dialect; 810 impl.typeID = typeID; 811 impl.interfaceMap = std::move(interfaceMap); 812 impl.foldHookFn = std::move(foldHook); 813 impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns); 814 impl.hasTraitFn = std::move(hasTrait); 815 impl.parseAssemblyFn = std::move(parseAssembly); 816 impl.printAssemblyFn = std::move(printAssembly); 817 impl.verifyInvariantsFn = std::move(verifyInvariants); 818 impl.attributeNames = cachedAttrNames; 819 } 820 821 //===----------------------------------------------------------------------===// 822 // AbstractType 823 //===----------------------------------------------------------------------===// 824 825 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { 826 const AbstractType *type = lookupMutable(typeID, context); 827 if (!type) 828 llvm::report_fatal_error( 829 "Trying to create a Type that was not registered in this MLIRContext."); 830 return *type; 831 } 832 833 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { 834 auto &impl = context->getImpl(); 835 auto it = impl.registeredTypes.find(typeID); 836 if (it == impl.registeredTypes.end()) 837 return nullptr; 838 return it->second; 839 } 840 841 //===----------------------------------------------------------------------===// 842 // Type uniquing 843 //===----------------------------------------------------------------------===// 844 845 /// Returns the storage uniquer used for constructing type storage instances. 846 /// This should not be used directly. 847 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } 848 849 BFloat16Type BFloat16Type::get(MLIRContext *context) { 850 return context->getImpl().bf16Ty; 851 } 852 Float16Type Float16Type::get(MLIRContext *context) { 853 return context->getImpl().f16Ty; 854 } 855 Float32Type Float32Type::get(MLIRContext *context) { 856 return context->getImpl().f32Ty; 857 } 858 Float64Type Float64Type::get(MLIRContext *context) { 859 return context->getImpl().f64Ty; 860 } 861 Float80Type Float80Type::get(MLIRContext *context) { 862 return context->getImpl().f80Ty; 863 } 864 Float128Type Float128Type::get(MLIRContext *context) { 865 return context->getImpl().f128Ty; 866 } 867 868 /// Get an instance of the IndexType. 869 IndexType IndexType::get(MLIRContext *context) { 870 return context->getImpl().indexTy; 871 } 872 873 /// Return an existing integer type instance if one is cached within the 874 /// context. 875 static IntegerType 876 getCachedIntegerType(unsigned width, 877 IntegerType::SignednessSemantics signedness, 878 MLIRContext *context) { 879 if (signedness != IntegerType::Signless) 880 return IntegerType(); 881 882 switch (width) { 883 case 1: 884 return context->getImpl().int1Ty; 885 case 8: 886 return context->getImpl().int8Ty; 887 case 16: 888 return context->getImpl().int16Ty; 889 case 32: 890 return context->getImpl().int32Ty; 891 case 64: 892 return context->getImpl().int64Ty; 893 case 128: 894 return context->getImpl().int128Ty; 895 default: 896 return IntegerType(); 897 } 898 } 899 900 IntegerType IntegerType::get(MLIRContext *context, unsigned width, 901 IntegerType::SignednessSemantics signedness) { 902 if (auto cached = getCachedIntegerType(width, signedness, context)) 903 return cached; 904 return Base::get(context, width, signedness); 905 } 906 907 IntegerType 908 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError, 909 MLIRContext *context, unsigned width, 910 SignednessSemantics signedness) { 911 if (auto cached = getCachedIntegerType(width, signedness, context)) 912 return cached; 913 return Base::getChecked(emitError, context, width, signedness); 914 } 915 916 /// Get an instance of the NoneType. 917 NoneType NoneType::get(MLIRContext *context) { 918 if (NoneType cachedInst = context->getImpl().noneType) 919 return cachedInst; 920 // Note: May happen when initializing the singleton attributes of the builtin 921 // dialect. 922 return Base::get(context); 923 } 924 925 //===----------------------------------------------------------------------===// 926 // Attribute uniquing 927 //===----------------------------------------------------------------------===// 928 929 /// Returns the storage uniquer used for constructing attribute storage 930 /// instances. This should not be used directly. 931 StorageUniquer &MLIRContext::getAttributeUniquer() { 932 return getImpl().attributeUniquer; 933 } 934 935 /// Initialize the given attribute storage instance. 936 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, 937 MLIRContext *ctx, 938 TypeID attrID) { 939 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); 940 941 // If the attribute did not provide a type, then default to NoneType. 942 if (!storage->getType()) 943 storage->setType(NoneType::get(ctx)); 944 } 945 946 BoolAttr BoolAttr::get(MLIRContext *context, bool value) { 947 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; 948 } 949 950 UnitAttr UnitAttr::get(MLIRContext *context) { 951 return context->getImpl().unitAttr; 952 } 953 954 UnknownLoc UnknownLoc::get(MLIRContext *context) { 955 return context->getImpl().unknownLocAttr; 956 } 957 958 /// Return empty dictionary. 959 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { 960 return context->getImpl().emptyDictionaryAttr; 961 } 962 963 void StringAttrStorage::initialize(MLIRContext *context) { 964 // Check for a dialect namespace prefix, if there isn't one we don't need to 965 // do any additional initialization. 966 auto dialectNamePair = value.split('.'); 967 if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) 968 return; 969 970 // If one exists, we check to see if this dialect is loaded. If it is, we set 971 // the dialect now, if it isn't we record this storage for initialization 972 // later if the dialect ever gets loaded. 973 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) 974 return; 975 976 MLIRContextImpl &impl = context->getImpl(); 977 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex); 978 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); 979 } 980 981 /// Return an empty string. 982 StringAttr StringAttr::get(MLIRContext *context) { 983 return context->getImpl().emptyStringAttr; 984 } 985 986 //===----------------------------------------------------------------------===// 987 // AffineMap uniquing 988 //===----------------------------------------------------------------------===// 989 990 StorageUniquer &MLIRContext::getAffineUniquer() { 991 return getImpl().affineUniquer; 992 } 993 994 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, 995 ArrayRef<AffineExpr> results, 996 MLIRContext *context) { 997 auto &impl = context->getImpl(); 998 auto key = std::make_tuple(dimCount, symbolCount, results); 999 1000 // Safely get or create an AffineMap instance. 1001 return safeGetOrCreate( 1002 impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] { 1003 auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>(); 1004 1005 // Copy the results into the bump pointer. 1006 results = copyArrayRefInto(impl.affineAllocator, results); 1007 1008 // Initialize the memory using placement new. 1009 new (res) 1010 detail::AffineMapStorage{dimCount, symbolCount, results, context}; 1011 return AffineMap(res); 1012 }); 1013 } 1014 1015 /// Check whether the arguments passed to the AffineMap::get() are consistent. 1016 /// This method checks whether the highest index of dimensional identifier 1017 /// present in result expressions is less than `dimCount` and the highest index 1018 /// of symbolic identifier present in result expressions is less than 1019 /// `symbolCount`. 1020 LLVM_NODISCARD static bool willBeValidAffineMap(unsigned dimCount, 1021 unsigned symbolCount, 1022 ArrayRef<AffineExpr> results) { 1023 int64_t maxDimPosition = -1; 1024 int64_t maxSymbolPosition = -1; 1025 getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition, 1026 maxSymbolPosition); 1027 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) { 1028 LLVM_DEBUG( 1029 llvm::dbgs() 1030 << "maximum dimensional identifier position in result expression must " 1031 "be less than `dimCount` and maximum symbolic identifier position " 1032 "in result expression must be less than `symbolCount`\n"); 1033 return false; 1034 } 1035 return true; 1036 } 1037 1038 AffineMap AffineMap::get(MLIRContext *context) { 1039 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); 1040 } 1041 1042 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1043 MLIRContext *context) { 1044 return getImpl(dimCount, symbolCount, /*results=*/{}, context); 1045 } 1046 1047 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1048 AffineExpr result) { 1049 assert(willBeValidAffineMap(dimCount, symbolCount, {result})); 1050 return getImpl(dimCount, symbolCount, {result}, result.getContext()); 1051 } 1052 1053 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1054 ArrayRef<AffineExpr> results, MLIRContext *context) { 1055 assert(willBeValidAffineMap(dimCount, symbolCount, results)); 1056 return getImpl(dimCount, symbolCount, results, context); 1057 } 1058 1059 //===----------------------------------------------------------------------===// 1060 // Integer Sets: these are allocated into the bump pointer, and are immutable. 1061 // Unlike AffineMap's, these are uniqued only if they are small. 1062 //===----------------------------------------------------------------------===// 1063 1064 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, 1065 ArrayRef<AffineExpr> constraints, 1066 ArrayRef<bool> eqFlags) { 1067 // The number of constraints can't be zero. 1068 assert(!constraints.empty()); 1069 assert(constraints.size() == eqFlags.size()); 1070 1071 auto &impl = constraints[0].getContext()->getImpl(); 1072 1073 // A utility function to construct a new IntegerSetStorage instance. 1074 auto constructorFn = [&] { 1075 auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>(); 1076 1077 // Copy the results and equality flags into the bump pointer. 1078 constraints = copyArrayRefInto(impl.affineAllocator, constraints); 1079 eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags); 1080 1081 // Initialize the memory using placement new. 1082 new (res) 1083 detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags}; 1084 return IntegerSet(res); 1085 }; 1086 1087 // If this instance is uniqued, then we handle it separately so that multiple 1088 // threads may simultaneously access existing instances. 1089 if (constraints.size() < IntegerSet::kUniquingThreshold) { 1090 auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags); 1091 return safeGetOrCreate(impl.integerSets, key, impl.affineMutex, 1092 impl.threadingIsEnabled, constructorFn); 1093 } 1094 1095 // Otherwise, acquire a writer-lock so that we can safely create the new 1096 // instance. 1097 ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled); 1098 return constructorFn(); 1099 } 1100 1101 //===----------------------------------------------------------------------===// 1102 // StorageUniquerSupport 1103 //===----------------------------------------------------------------------===// 1104 1105 /// Utility method to generate a callback that can be used to generate a 1106 /// diagnostic when checking the construction invariants of a storage object. 1107 /// This is defined out-of-line to avoid the need to include Location.h. 1108 llvm::unique_function<InFlightDiagnostic()> 1109 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { 1110 return [ctx] { return emitError(UnknownLoc::get(ctx)); }; 1111 } 1112 llvm::unique_function<InFlightDiagnostic()> 1113 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { 1114 return [=] { return emitError(loc); }; 1115 } 1116