//===- MLIRContext.cpp - MLIRContext Implementation -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/MLIRContext.h"
#include "AffineExprDetail.h"
#include "AffineMapDetail.h"
#include "AttributeDetail.h"
#include "IntegerSetDetail.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/DebugAction.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>

#define DEBUG_TYPE "mlircontext"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// MLIRContext CommandLine Options
//===----------------------------------------------------------------------===//

namespace {
/// This struct contains command line options that can be used to initialize
/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
/// for global command line options.
struct MLIRContextOptions {
  /// Backs `--mlir-disable-threading`; consulted by
  /// isThreadingGloballyDisabled().
  llvm::cl::opt<bool> disableThreading{
      "mlir-disable-threading",
      llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
                     "further call to MLIRContext::enableMultiThreading()")};

  /// Backs `--mlir-print-op-on-diagnostic` (defaults to true).
  llvm::cl::opt<bool> printOpOnDiagnostic{
      "mlir-print-op-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
                     "the operation as an attached note"),
      llvm::cl::init(true)};

  /// Backs `--mlir-print-stacktrace-on-diagnostic`.
  llvm::cl::opt<bool> printStackTraceOnDiagnostic{
      "mlir-print-stacktrace-on-diagnostic",
      llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
                     "as an attached note")};
};
} // namespace

/// Lazily constructed holder of the options above. Until it is constructed
/// (see registerMLIRContextCLOptions), `clOptions.isConstructed()` is false and
/// the command-line flags are ignored by this file.
static llvm::ManagedStatic<MLIRContextOptions> clOptions;

/// Returns true if threading is disabled process-wide: either LLVM was built
/// without thread support (LLVM_ENABLE_THREADS == 0), or the command-line
/// options were registered and `--mlir-disable-threading` is set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS != 0
  return clOptions.isConstructed() && clOptions->disableThreading;
#else
  return true;
#endif
}

/// Register a set of useful command-line options that can be used to configure
/// various flags within the MLIRContext. These flags are used when constructing
/// an MLIR context for initialization.
void mlir::registerMLIRContextCLOptions() {
  // Make sure that the options struct has been initialized. Dereferencing the
  // ManagedStatic forces its construction.
  *clOptions;
}

//===----------------------------------------------------------------------===//
// Locking Utilities
//===----------------------------------------------------------------------===//

namespace {
/// Utility writer lock that takes a runtime flag that specifies if we really
/// need to lock. When `shouldLock` is false, this object does nothing.
/// NOTE(review): copy operations are not deleted; copying an instance that
/// holds the lock would unlock the mutex twice.
struct ScopedWriterLock {
  ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
      : mutex(shouldLock ? &mutexParam : nullptr) {
    if (mutex)
      mutex->lock();
  }
  ~ScopedWriterLock() {
    if (mutex)
      mutex->unlock();
  }
  /// The held mutex, or nullptr when locking was not requested.
  llvm::sys::SmartRWMutex<true> *mutex;
};
} // namespace

//===----------------------------------------------------------------------===//
// MLIRContextImpl
//===----------------------------------------------------------------------===//

namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
/// Note: members are constructed in declaration order and destroyed in reverse.
class MLIRContextImpl {
public:
  //===--------------------------------------------------------------------===//
  // Debugging
  //===--------------------------------------------------------------------===//

  /// An action manager for use within the context.
  DebugActionManager debugActionManager;

  //===--------------------------------------------------------------------===//
  // Diagnostics
  //===--------------------------------------------------------------------===//
  DiagnosticEngine diagEngine;

  //===--------------------------------------------------------------------===//
  // Options
  //===--------------------------------------------------------------------===//

  /// In most cases, creating an operation in an unregistered dialect is not
  /// desired and indicates a misconfiguration of the compiler. This option
  /// enables detecting such use cases.
  bool allowUnregisteredDialects = false;

  /// Enable support for multi-threading within MLIR.
  bool threadingIsEnabled = true;

  /// Track if we are currently executing in a threaded execution environment
  /// (like the pass-manager): this is only a debugging feature to help reduce
  /// the chances of data races on some context APIs. It is a nesting counter
  /// incremented/decremented by MLIRContext::enter/exitMultiThreadedExecution.
#ifndef NDEBUG
  std::atomic<int> multiThreadedExecutionContext{0};
#endif

  /// If the operation should be attached to diagnostics printed via the
  /// Operation::emit methods.
  bool printOpOnDiagnostic = true;

  /// If the current stack trace should be attached when emitting diagnostics.
  bool printStackTraceOnDiagnostic = false;

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//

  /// This points to the ThreadPool used when processing MLIR tasks in parallel.
  /// It can't be nullptr when multi-threading is enabled. Otherwise if
  /// multi-threading is disabled, and the threadpool wasn't externally provided
  /// using `setThreadPool`, this will be nullptr.
  llvm::ThreadPool *threadPool = nullptr;

  /// In case where the thread pool is owned by the context, this ensures
  /// destruction with the context.
  std::unique_ptr<llvm::ThreadPool> ownedThreadPool;

  /// This is a map of the dialects loaded in this context, keyed by namespace.
  /// The MLIRContext owns the objects. Keys are non-owning StringRefs —
  /// presumably the dialects' static namespace strings; TODO confirm.
  DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
  /// Registry of dialects that are available to be loaded on demand.
  DialectRegistry dialectsRegistry;

  /// An allocator used for AbstractAttribute and AbstractType objects (and the
  /// cached attribute-name arrays of registered operations).
  llvm::BumpPtrAllocator abstractDialectSymbolAllocator;

  /// This is a mapping from operation name to the operation info describing it.
  llvm::StringMap<OperationName::Impl> operations;

  /// A map from operation name to RegisteredOperationName, containing only the
  /// registered operations (entries point into `operations`).
  llvm::StringMap<RegisteredOperationName> registeredOperations;

  /// A mutex used when accessing operation information.
  llvm::sys::SmartRWMutex<true> operationInfoMutex;

  //===--------------------------------------------------------------------===//
  // Affine uniquing
  //===--------------------------------------------------------------------===//

  // Affine expression, map and integer set uniquing.
  StorageUniquer affineUniquer;

  //===--------------------------------------------------------------------===//
  // Type uniquing
  //===--------------------------------------------------------------------===//

  /// Abstract descriptors of registered types, keyed by the type's TypeID.
  DenseMap<TypeID, AbstractType *> registeredTypes;
  StorageUniquer typeUniquer;

  /// Cached Type Instances. These are populated once in the MLIRContext
  /// constructor so they can be returned without locking the context.
  BFloat16Type bf16Ty;
  Float16Type f16Ty;
  Float32Type f32Ty;
  Float64Type f64Ty;
  Float80Type f80Ty;
  Float128Type f128Ty;
  IndexType indexTy;
  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  NoneType noneType;

  //===--------------------------------------------------------------------===//
  // Attribute uniquing
  //===--------------------------------------------------------------------===//

  /// Abstract descriptors of registered attributes, keyed by TypeID.
  DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
  StorageUniquer attributeUniquer;

  /// Cached Attribute Instances (populated in the MLIRContext constructor).
  BoolAttr falseAttr, trueAttr;
  UnitAttr unitAttr;
  UnknownLoc unknownLocAttr;
  DictionaryAttr emptyDictionaryAttr;
  StringAttr emptyStringAttr;

  /// Map of string attributes that may reference a dialect, that are awaiting
  /// that dialect to be loaded. The mutex guards appends to the map.
  llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
  DenseMap<StringRef, SmallVector<StringAttrStorage *>>
      dialectReferencingStrAttrs;

public:
  /// When threading is enabled, the context creates and owns a default
  /// thread pool.
  MLIRContextImpl(bool threadingIsEnabled)
      : threadingIsEnabled(threadingIsEnabled) {
    if (threadingIsEnabled) {
      ownedThreadPool = std::make_unique<llvm::ThreadPool>();
      threadPool = ownedThreadPool.get();
    }
  }
  /// The abstract type/attribute descriptors were placement-new'ed into
  /// `abstractDialectSymbolAllocator` (see Dialect::addType/addAttribute), so
  /// their destructors are run explicitly here.
  ~MLIRContextImpl() {
    for (auto typeMapping : registeredTypes)
      typeMapping.second->~AbstractType();
    for (auto attrMapping : registeredAttributes)
      attrMapping.second->~AbstractAttribute();
  }
};
} // namespace mlir

MLIRContext::MLIRContext(Threading setting)
    : MLIRContext(DialectRegistry(), setting) {}

/// Threading is enabled only if requested by `setting` AND not disabled
/// globally (build configuration or `--mlir-disable-threading`).
MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
    : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
                               !isThreadingGloballyDisabled())) {
  // Initialize values based on the command line flags if they were provided.
  if (clOptions.isConstructed()) {
    printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
    printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
  }

  // Pre-populate the registry.
  registry.appendTo(impl->dialectsRegistry);

  // Ensure the builtin dialect is always pre-loaded.
  getOrLoadDialect<BuiltinDialect>();

  // Initialize several common attributes and types to avoid the need to lock
  // the context when accessing them.

  //// Types.
  /// Floating-point Types.
  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
  impl->f16Ty = TypeUniquer::get<Float16Type>(this);
  impl->f32Ty = TypeUniquer::get<Float32Type>(this);
  impl->f64Ty = TypeUniquer::get<Float64Type>(this);
  impl->f80Ty = TypeUniquer::get<Float80Type>(this);
  impl->f128Ty = TypeUniquer::get<Float128Type>(this);
  /// Index Type.
  impl->indexTy = TypeUniquer::get<IndexType>(this);
  /// Integer Types (only signless variants are cached).
  impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
  impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
  impl->int16Ty =
      TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
  impl->int32Ty =
      TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
  impl->int64Ty =
      TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
  impl->int128Ty =
      TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
  /// None Type.
  impl->noneType = TypeUniquer::get<NoneType>(this);

  //// Attributes.
  //// Note: These must be registered after the types as they may generate one
  //// of the above types internally.
  /// Unknown Location Attribute.
  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
  /// Bool Attributes (built on the cached i1 type).
  impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
  impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
  /// Unit Attribute.
  impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
  /// The empty dictionary attribute.
  impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
  /// The empty string attribute.
  impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);

  // Register the affine storage objects with the uniquer.
  impl->affineUniquer
      .registerParametricStorageType<AffineBinaryOpExprStorage>();
  impl->affineUniquer
      .registerParametricStorageType<AffineConstantExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineMapStorage>();
  impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
}

MLIRContext::~MLIRContext() = default;

/// Copy the specified array of elements into memory managed by the provided
/// bump pointer allocator.
This assumes the elements are all PODs. 326 template <typename T> 327 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 328 ArrayRef<T> elements) { 329 auto result = allocator.Allocate<T>(elements.size()); 330 std::uninitialized_copy(elements.begin(), elements.end(), result); 331 return ArrayRef<T>(result, elements.size()); 332 } 333 334 //===----------------------------------------------------------------------===// 335 // Debugging 336 //===----------------------------------------------------------------------===// 337 338 DebugActionManager &MLIRContext::getDebugActionManager() { 339 return getImpl().debugActionManager; 340 } 341 342 //===----------------------------------------------------------------------===// 343 // Diagnostic Handlers 344 //===----------------------------------------------------------------------===// 345 346 /// Returns the diagnostic engine for this context. 347 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 348 349 //===----------------------------------------------------------------------===// 350 // Dialect and Operation Registration 351 //===----------------------------------------------------------------------===// 352 353 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 354 registry.appendTo(impl->dialectsRegistry); 355 356 // For the already loaded dialects, register the interfaces immediately. 357 for (const auto &kvp : impl->loadedDialects) 358 registry.registerDelayedInterfaces(kvp.second.get()); 359 } 360 361 const DialectRegistry &MLIRContext::getDialectRegistry() { 362 return impl->dialectsRegistry; 363 } 364 365 /// Return information about all registered IR dialects. 
366 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 367 std::vector<Dialect *> result; 368 result.reserve(impl->loadedDialects.size()); 369 for (auto &dialect : impl->loadedDialects) 370 result.push_back(dialect.second.get()); 371 llvm::array_pod_sort(result.begin(), result.end(), 372 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 373 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 374 }); 375 return result; 376 } 377 std::vector<StringRef> MLIRContext::getAvailableDialects() { 378 std::vector<StringRef> result; 379 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 380 result.push_back(dialect); 381 return result; 382 } 383 384 /// Get a registered IR dialect with the given namespace. If none is found, 385 /// then return nullptr. 386 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 387 // Dialects are sorted by name, so we can use binary search for lookup. 388 auto it = impl->loadedDialects.find(name); 389 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 390 } 391 392 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 393 Dialect *dialect = getLoadedDialect(name); 394 if (dialect) 395 return dialect; 396 DialectAllocatorFunctionRef allocator = 397 impl->dialectsRegistry.getDialectAllocator(name); 398 return allocator ? allocator(this) : nullptr; 399 } 400 401 /// Get a dialect for the provided namespace and TypeID: abort the program if a 402 /// dialect exist for this namespace with different TypeID. Returns a pointer to 403 /// the dialect owned by the context. 404 Dialect * 405 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 406 function_ref<std::unique_ptr<Dialect>()> ctor) { 407 auto &impl = getImpl(); 408 // Get the correct insertion position sorted by namespace. 
409 auto dialectIt = impl.loadedDialects.find(dialectNamespace); 410 411 if (dialectIt == impl.loadedDialects.end()) { 412 LLVM_DEBUG(llvm::dbgs() 413 << "Load new dialect in Context " << dialectNamespace << "\n"); 414 #ifndef NDEBUG 415 if (impl.multiThreadedExecutionContext != 0) 416 llvm::report_fatal_error( 417 "Loading a dialect (" + dialectNamespace + 418 ") while in a multi-threaded execution context (maybe " 419 "the PassManager): this can indicate a " 420 "missing `dependentDialects` in a pass for example."); 421 #endif 422 std::unique_ptr<Dialect> &dialect = 423 impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second; 424 assert(dialect && "dialect ctor failed"); 425 426 // Refresh all the identifiers dialect field, this catches cases where a 427 // dialect may be loaded after identifier prefixed with this dialect name 428 // were already created. 429 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace); 430 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) { 431 for (StringAttrStorage *storage : stringAttrsIt->second) 432 storage->referencedDialect = dialect.get(); 433 impl.dialectReferencingStrAttrs.erase(stringAttrsIt); 434 } 435 436 // Actually register the interfaces with delayed registration. 437 impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); 438 return dialect.get(); 439 } 440 441 // Abort if dialect with namespace has already been registered. 
442 std::unique_ptr<Dialect> &dialect = dialectIt->second; 443 if (dialect->getTypeID() != dialectID) 444 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 445 "' has already been registered"); 446 447 return dialect.get(); 448 } 449 450 void MLIRContext::loadAllAvailableDialects() { 451 for (StringRef name : getAvailableDialects()) 452 getOrLoadDialect(name); 453 } 454 455 llvm::hash_code MLIRContext::getRegistryHash() { 456 llvm::hash_code hash(0); 457 // Factor in number of loaded dialects, attributes, operations, types. 458 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 459 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 460 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 461 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 462 return hash; 463 } 464 465 bool MLIRContext::allowsUnregisteredDialects() { 466 return impl->allowUnregisteredDialects; 467 } 468 469 void MLIRContext::allowUnregisteredDialects(bool allowing) { 470 impl->allowUnregisteredDialects = allowing; 471 } 472 473 /// Return true if multi-threading is enabled by the context. 474 bool MLIRContext::isMultithreadingEnabled() { 475 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 476 } 477 478 /// Set the flag specifying if multi-threading is disabled by the context. 479 void MLIRContext::disableMultithreading(bool disable) { 480 // This API can be overridden by the global debugging flag 481 // --mlir-disable-threading 482 if (isThreadingGloballyDisabled()) 483 return; 484 485 impl->threadingIsEnabled = !disable; 486 487 // Update the threading mode for each of the uniquers. 488 impl->affineUniquer.disableMultithreading(disable); 489 impl->attributeUniquer.disableMultithreading(disable); 490 impl->typeUniquer.disableMultithreading(disable); 491 492 // Destroy thread pool (stop all threads) if it is no longer needed, or create 493 // a new one if multithreading was re-enabled. 
494 if (disable) { 495 // If the thread pool is owned, explicitly set it to nullptr to avoid 496 // keeping a dangling pointer around. If the thread pool is externally 497 // owned, we don't do anything. 498 if (impl->ownedThreadPool) { 499 assert(impl->threadPool); 500 impl->threadPool = nullptr; 501 impl->ownedThreadPool.reset(); 502 } 503 } else if (!impl->threadPool) { 504 // The thread pool isn't externally provided. 505 assert(!impl->ownedThreadPool); 506 impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 507 impl->threadPool = impl->ownedThreadPool.get(); 508 } 509 } 510 511 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { 512 assert(!isMultithreadingEnabled() && 513 "expected multi-threading to be disabled when setting a ThreadPool"); 514 impl->threadPool = &pool; 515 impl->ownedThreadPool.reset(); 516 enableMultithreading(); 517 } 518 519 unsigned MLIRContext::getNumThreads() { 520 if (isMultithreadingEnabled()) { 521 assert(impl->threadPool && 522 "multi-threading is enabled but threadpool not set"); 523 return impl->threadPool->getThreadCount(); 524 } 525 // No multithreading or active thread pool. Return 1 thread. 526 return 1; 527 } 528 529 llvm::ThreadPool &MLIRContext::getThreadPool() { 530 assert(isMultithreadingEnabled() && 531 "expected multi-threading to be enabled within the context"); 532 assert(impl->threadPool && 533 "multi-threading is enabled but threadpool not set"); 534 return *impl->threadPool; 535 } 536 537 void MLIRContext::enterMultiThreadedExecution() { 538 #ifndef NDEBUG 539 ++impl->multiThreadedExecutionContext; 540 #endif 541 } 542 void MLIRContext::exitMultiThreadedExecution() { 543 #ifndef NDEBUG 544 --impl->multiThreadedExecutionContext; 545 #endif 546 } 547 548 /// Return true if we should attach the operation to diagnostics emitted via 549 /// Operation::emit. 
550 bool MLIRContext::shouldPrintOpOnDiagnostic() { 551 return impl->printOpOnDiagnostic; 552 } 553 554 /// Set the flag specifying if we should attach the operation to diagnostics 555 /// emitted via Operation::emit. 556 void MLIRContext::printOpOnDiagnostic(bool enable) { 557 impl->printOpOnDiagnostic = enable; 558 } 559 560 /// Return true if we should attach the current stacktrace to diagnostics when 561 /// emitted. 562 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 563 return impl->printStackTraceOnDiagnostic; 564 } 565 566 /// Set the flag specifying if we should attach the current stacktrace when 567 /// emitting diagnostics. 568 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 569 impl->printStackTraceOnDiagnostic = enable; 570 } 571 572 /// Return information about all registered operations. This isn't very 573 /// efficient, typically you should ask the operations about their properties 574 /// directly. 575 std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() { 576 // We just have the operations in a non-deterministic hash table order. Dump 577 // into a temporary array, then sort it by operation name to get a stable 578 // ordering. 
579 auto unwrappedNames = llvm::make_second_range(impl->registeredOperations); 580 std::vector<RegisteredOperationName> result(unwrappedNames.begin(), 581 unwrappedNames.end()); 582 llvm::array_pod_sort(result.begin(), result.end(), 583 [](const RegisteredOperationName *lhs, 584 const RegisteredOperationName *rhs) { 585 return lhs->getIdentifier().compare( 586 rhs->getIdentifier()); 587 }); 588 589 return result; 590 } 591 592 bool MLIRContext::isOperationRegistered(StringRef name) { 593 return RegisteredOperationName::lookup(name, this).hasValue(); 594 } 595 596 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 597 auto &impl = context->getImpl(); 598 assert(impl.multiThreadedExecutionContext == 0 && 599 "Registering a new type kind while in a multi-threaded execution " 600 "context"); 601 auto *newInfo = 602 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 603 AbstractType(std::move(typeInfo)); 604 if (!impl.registeredTypes.insert({typeID, newInfo}).second) 605 llvm::report_fatal_error("Dialect Type already registered."); 606 } 607 608 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 609 auto &impl = context->getImpl(); 610 assert(impl.multiThreadedExecutionContext == 0 && 611 "Registering a new attribute kind while in a multi-threaded execution " 612 "context"); 613 auto *newInfo = 614 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 615 AbstractAttribute(std::move(attrInfo)); 616 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 617 llvm::report_fatal_error("Dialect Attribute already registered."); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // AbstractAttribute 622 //===----------------------------------------------------------------------===// 623 624 /// Get the dialect that registered the attribute with the provided typeid. 
/// Fatal error (report_fatal_error) if nothing is registered for `typeID`.
const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
                                                   MLIRContext *context) {
  const AbstractAttribute *abstract = lookupMutable(typeID, context);
  if (!abstract)
    llvm::report_fatal_error("Trying to create an Attribute that was not "
                             "registered in this MLIRContext.");
  return *abstract;
}

/// Return the mutable AbstractAttribute registered for `typeID`, or nullptr if
/// none is registered in `context`.
AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
                                                    MLIRContext *context) {
  auto &impl = context->getImpl();
  auto it = impl.registeredAttributes.find(typeID);
  if (it == impl.registeredAttributes.end())
    return nullptr;
  return it->second;
}

//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//

/// Look up (or create) the operation info for `name`. Lookup proceeds in three
/// tiers when multi-threading is enabled: a lock-free probe of the registered
/// operations, a reader-locked probe of all known operations, and finally a
/// writer-locked insertion. Without multi-threading no locks are taken.
OperationName::OperationName(StringRef name, MLIRContext *context) {
  MLIRContextImpl &ctxImpl = context->getImpl();

  // Check for an existing name in read-only mode.
  bool isMultithreadingEnabled = context->isMultithreadingEnabled();
  if (isMultithreadingEnabled) {
    // Check the registered info map first. In the overwhelmingly common case,
    // the entry will be in here and it also removes the need to acquire any
    // locks. This relies on operations never being registered concurrently
    // (RegisteredOperationName::insert asserts it is not called during
    // multi-threaded execution).
    auto registeredIt = ctxImpl.registeredOperations.find(name);
    if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
      impl = registeredIt->second.impl;
      return;
    }

    llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
    auto it = ctxImpl.operations.find(name);
    if (it != ctxImpl.operations.end()) {
      impl = &it->second;
      return;
    }
  }

  // Acquire a writer-lock so that we can safely create the new instance.
  // Another thread may have inserted `name` since the reader lock was released;
  // `insert` then just returns the existing entry.
  ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);

  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
  if (it.second)
    it.first->second.name = StringAttr::get(context, name);
  impl = &it.first->second;
}

/// Return the namespace of the owning dialect, or the text before the first '.'
/// when no dialect is attached to this operation name.
StringRef OperationName::getDialectNamespace() const {
  if (Dialect *dialect = getDialect())
    return dialect->getNamespace();
  return getStringRef().split('.').first;
}

//===----------------------------------------------------------------------===//
// RegisteredOperationName
//===----------------------------------------------------------------------===//

/// Return the registered operation with the given name, or None if `name` is
/// not registered in `ctx`. No lock is taken.
Optional<RegisteredOperationName>
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
  auto &impl = ctx->getImpl();
  auto it = impl.registeredOperations.find(name);
  if (it != impl.registeredOperations.end())
    return it->getValue();
  return llvm::None;
}

/// Forward to the operation's registered assembly-parser hook.
ParseResult
RegisteredOperationName::parseAssembly(OpAsmParser &parser,
                                       OperationState &result) const {
  return impl->parseAssemblyFn(parser, result);
}

/// Register an operation named `name` belonging to `dialect`. Registering a
/// name twice prints an error to llvm::errs() and aborts the process. Must not
/// be called during multi-threaded execution (asserted below), which is why
/// no lock is taken on the operation maps here.
void RegisteredOperationName::insert(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<StringRef> attrNames) {
  MLIRContext *ctx = dialect.getContext();
  auto &ctxImpl = ctx->getImpl();
  assert(ctxImpl.multiThreadedExecutionContext == 0 &&
         "registering a new operation kind while in a multi-threaded execution "
         "context");

  // Register the attribute names of this operation. The StringAttrs are
  // placement-constructed into the context's symbol allocator so the array
  // lives as long as the context.
  MutableArrayRef<StringAttr> cachedAttrNames;
  if (!attrNames.empty()) {
    cachedAttrNames = MutableArrayRef<StringAttr>(
        ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
            attrNames.size()),
        attrNames.size());
    for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
      new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i]));
  }

  // Insert the operation info if it doesn't exist yet. An entry may already
  // exist if the name was used (unregistered) before the dialect registered it.
  auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)});
  if (it.second)
    it.first->second.name = StringAttr::get(ctx, name);
  OperationName::Impl &impl = it.first->second;

  if (impl.isRegistered()) {
    llvm::errs() << "error: operation named '" << name
                 << "' is already registered.\n";
    abort();
  }
  ctxImpl.registeredOperations.try_emplace(name,
                                           RegisteredOperationName(&impl));

  // Update the registered info for this operation.
  impl.dialect = &dialect;
  impl.typeID = typeID;
  impl.interfaceMap = std::move(interfaceMap);
  impl.foldHookFn = std::move(foldHook);
  impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
  impl.hasTraitFn = std::move(hasTrait);
  impl.parseAssemblyFn = std::move(parseAssembly);
  impl.printAssemblyFn = std::move(printAssembly);
  impl.verifyInvariantsFn = std::move(verifyInvariants);
  impl.attributeNames = cachedAttrNames;
}

//===----------------------------------------------------------------------===//
// AbstractType
//===----------------------------------------------------------------------===//

/// Return the AbstractType registered for `typeID`; fatal error if the type
/// was not registered in `context`.
const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
  const AbstractType *type = lookupMutable(typeID, context);
  if (!type)
    llvm::report_fatal_error(
        "Trying to create a Type that was not registered in this MLIRContext.");
  return *type;
}

/// Return the mutable AbstractType registered for `typeID`, or nullptr if none
/// is registered in `context`.
AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
  auto &impl = context->getImpl();
  auto it = impl.registeredTypes.find(typeID);
  if (it == impl.registeredTypes.end())
    return nullptr;
  return it->second;
}

//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//

/// Returns the storage uniquer used for constructing type storage instances.
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }

// The builtin floating-point getters below return the instances cached in the
// MLIRContext constructor, so they never touch the uniquer.
BFloat16Type BFloat16Type::get(MLIRContext *context) {
  return context->getImpl().bf16Ty;
}
Float16Type Float16Type::get(MLIRContext *context) {
  return context->getImpl().f16Ty;
}
Float32Type Float32Type::get(MLIRContext *context) {
  return context->getImpl().f32Ty;
}
Float64Type Float64Type::get(MLIRContext *context) {
  return context->getImpl().f64Ty;
}
Float80Type Float80Type::get(MLIRContext *context) {
  return context->getImpl().f80Ty;
}
Float128Type Float128Type::get(MLIRContext *context) {
  return context->getImpl().f128Ty;
}

/// Get an instance of the IndexType.
IndexType IndexType::get(MLIRContext *context) {
  return context->getImpl().indexTy;
}

/// Return an existing integer type instance if one is cached within the
/// context.
809 static IntegerType 810 getCachedIntegerType(unsigned width, 811 IntegerType::SignednessSemantics signedness, 812 MLIRContext *context) { 813 if (signedness != IntegerType::Signless) 814 return IntegerType(); 815 816 switch (width) { 817 case 1: 818 return context->getImpl().int1Ty; 819 case 8: 820 return context->getImpl().int8Ty; 821 case 16: 822 return context->getImpl().int16Ty; 823 case 32: 824 return context->getImpl().int32Ty; 825 case 64: 826 return context->getImpl().int64Ty; 827 case 128: 828 return context->getImpl().int128Ty; 829 default: 830 return IntegerType(); 831 } 832 } 833 834 IntegerType IntegerType::get(MLIRContext *context, unsigned width, 835 IntegerType::SignednessSemantics signedness) { 836 if (auto cached = getCachedIntegerType(width, signedness, context)) 837 return cached; 838 return Base::get(context, width, signedness); 839 } 840 841 IntegerType 842 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError, 843 MLIRContext *context, unsigned width, 844 SignednessSemantics signedness) { 845 if (auto cached = getCachedIntegerType(width, signedness, context)) 846 return cached; 847 return Base::getChecked(emitError, context, width, signedness); 848 } 849 850 /// Get an instance of the NoneType. 851 NoneType NoneType::get(MLIRContext *context) { 852 if (NoneType cachedInst = context->getImpl().noneType) 853 return cachedInst; 854 // Note: May happen when initializing the singleton attributes of the builtin 855 // dialect. 856 return Base::get(context); 857 } 858 859 //===----------------------------------------------------------------------===// 860 // Attribute uniquing 861 //===----------------------------------------------------------------------===// 862 863 /// Returns the storage uniquer used for constructing attribute storage 864 /// instances. This should not be used directly. 
865 StorageUniquer &MLIRContext::getAttributeUniquer() { 866 return getImpl().attributeUniquer; 867 } 868 869 /// Initialize the given attribute storage instance. 870 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, 871 MLIRContext *ctx, 872 TypeID attrID) { 873 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); 874 875 // If the attribute did not provide a type, then default to NoneType. 876 if (!storage->getType()) 877 storage->setType(NoneType::get(ctx)); 878 } 879 880 BoolAttr BoolAttr::get(MLIRContext *context, bool value) { 881 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; 882 } 883 884 UnitAttr UnitAttr::get(MLIRContext *context) { 885 return context->getImpl().unitAttr; 886 } 887 888 UnknownLoc UnknownLoc::get(MLIRContext *context) { 889 return context->getImpl().unknownLocAttr; 890 } 891 892 /// Return empty dictionary. 893 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { 894 return context->getImpl().emptyDictionaryAttr; 895 } 896 897 void StringAttrStorage::initialize(MLIRContext *context) { 898 // Check for a dialect namespace prefix, if there isn't one we don't need to 899 // do any additional initialization. 900 auto dialectNamePair = value.split('.'); 901 if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) 902 return; 903 904 // If one exists, we check to see if this dialect is loaded. If it is, we set 905 // the dialect now, if it isn't we record this storage for initialization 906 // later if the dialect ever gets loaded. 907 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) 908 return; 909 910 MLIRContextImpl &impl = context->getImpl(); 911 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex); 912 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); 913 } 914 915 /// Return an empty string. 
916 StringAttr StringAttr::get(MLIRContext *context) { 917 return context->getImpl().emptyStringAttr; 918 } 919 920 //===----------------------------------------------------------------------===// 921 // AffineMap uniquing 922 //===----------------------------------------------------------------------===// 923 924 StorageUniquer &MLIRContext::getAffineUniquer() { 925 return getImpl().affineUniquer; 926 } 927 928 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, 929 ArrayRef<AffineExpr> results, 930 MLIRContext *context) { 931 auto &impl = context->getImpl(); 932 auto *storage = impl.affineUniquer.get<AffineMapStorage>( 933 [&](AffineMapStorage *storage) { storage->context = context; }, dimCount, 934 symbolCount, results); 935 return AffineMap(storage); 936 } 937 938 /// Check whether the arguments passed to the AffineMap::get() are consistent. 939 /// This method checks whether the highest index of dimensional identifier 940 /// present in result expressions is less than `dimCount` and the highest index 941 /// of symbolic identifier present in result expressions is less than 942 /// `symbolCount`. 
943 LLVM_ATTRIBUTE_UNUSED static bool 944 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, 945 ArrayRef<AffineExpr> results) { 946 int64_t maxDimPosition = -1; 947 int64_t maxSymbolPosition = -1; 948 getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition, 949 maxSymbolPosition); 950 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) { 951 LLVM_DEBUG( 952 llvm::dbgs() 953 << "maximum dimensional identifier position in result expression must " 954 "be less than `dimCount` and maximum symbolic identifier position " 955 "in result expression must be less than `symbolCount`\n"); 956 return false; 957 } 958 return true; 959 } 960 961 AffineMap AffineMap::get(MLIRContext *context) { 962 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); 963 } 964 965 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 966 MLIRContext *context) { 967 return getImpl(dimCount, symbolCount, /*results=*/{}, context); 968 } 969 970 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 971 AffineExpr result) { 972 assert(willBeValidAffineMap(dimCount, symbolCount, {result})); 973 return getImpl(dimCount, symbolCount, {result}, result.getContext()); 974 } 975 976 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 977 ArrayRef<AffineExpr> results, MLIRContext *context) { 978 assert(willBeValidAffineMap(dimCount, symbolCount, results)); 979 return getImpl(dimCount, symbolCount, results, context); 980 } 981 982 //===----------------------------------------------------------------------===// 983 // Integer Sets: these are allocated into the bump pointer, and are immutable. 984 // Unlike AffineMap's, these are uniqued only if they are small. 
985 //===----------------------------------------------------------------------===// 986 987 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, 988 ArrayRef<AffineExpr> constraints, 989 ArrayRef<bool> eqFlags) { 990 // The number of constraints can't be zero. 991 assert(!constraints.empty()); 992 assert(constraints.size() == eqFlags.size()); 993 994 auto &impl = constraints[0].getContext()->getImpl(); 995 auto *storage = impl.affineUniquer.get<IntegerSetStorage>( 996 [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags); 997 return IntegerSet(storage); 998 } 999 1000 //===----------------------------------------------------------------------===// 1001 // StorageUniquerSupport 1002 //===----------------------------------------------------------------------===// 1003 1004 /// Utility method to generate a callback that can be used to generate a 1005 /// diagnostic when checking the construction invariants of a storage object. 1006 /// This is defined out-of-line to avoid the need to include Location.h. 1007 llvm::unique_function<InFlightDiagnostic()> 1008 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { 1009 return [ctx] { return emitError(UnknownLoc::get(ctx)); }; 1010 } 1011 llvm::unique_function<InFlightDiagnostic()> 1012 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { 1013 return [=] { return emitError(loc); }; 1014 } 1015