1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/MLIRContext.h" 10 #include "AffineExprDetail.h" 11 #include "AffineMapDetail.h" 12 #include "AttributeDetail.h" 13 #include "IntegerSetDetail.h" 14 #include "TypeDetail.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/Types.h" 25 #include "mlir/Support/DebugAction.h" 26 #include "llvm/ADT/DenseMap.h" 27 #include "llvm/ADT/DenseSet.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/ADT/SmallString.h" 30 #include "llvm/ADT/StringSet.h" 31 #include "llvm/ADT/Twine.h" 32 #include "llvm/Support/Allocator.h" 33 #include "llvm/Support/CommandLine.h" 34 #include "llvm/Support/Debug.h" 35 #include "llvm/Support/Mutex.h" 36 #include "llvm/Support/RWMutex.h" 37 #include "llvm/Support/ThreadPool.h" 38 #include "llvm/Support/raw_ostream.h" 39 #include <memory> 40 41 #define DEBUG_TYPE "mlircontext" 42 43 using namespace mlir; 44 using namespace mlir::detail; 45 46 using llvm::hash_combine; 47 using llvm::hash_combine_range; 48 49 //===----------------------------------------------------------------------===// 50 // MLIRContext CommandLine Options 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 /// This struct contains command line options that can be used to initialize 55 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need 56 /// for global command line options. 57 struct MLIRContextOptions { 58 llvm::cl::opt<bool> disableThreading{ 59 "mlir-disable-threading", 60 llvm::cl::desc("Disable multi-threading within MLIR, overrides any " 61 "further call to MLIRContext::enableMultiThreading()")}; 62 63 llvm::cl::opt<bool> printOpOnDiagnostic{ 64 "mlir-print-op-on-diagnostic", 65 llvm::cl::desc("When a diagnostic is emitted on an operation, also print " 66 "the operation as an attached note"), 67 llvm::cl::init(true)}; 68 69 llvm::cl::opt<bool> printStackTraceOnDiagnostic{ 70 "mlir-print-stacktrace-on-diagnostic", 71 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace " 72 "as an attached note")}; 73 }; 74 } // end anonymous namespace 75 76 static llvm::ManagedStatic<MLIRContextOptions> clOptions; 77 78 static bool isThreadingGloballyDisabled() { 79 #if LLVM_ENABLE_THREADS != 0 80 return clOptions.isConstructed() && clOptions->disableThreading; 81 #else 82 return true; 83 #endif 84 } 85 86 /// Register a set of useful command-line options that can be used to configure 87 /// various flags within the MLIRContext. These flags are used when constructing 88 /// an MLIR context for initialization. 89 void mlir::registerMLIRContextCLOptions() { 90 // Make sure that the options struct has been initialized. 91 *clOptions; 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Locking Utilities 96 //===----------------------------------------------------------------------===// 97 98 namespace { 99 /// Utility writer lock that takes a runtime flag that specifies if we really 100 /// need to lock. 101 struct ScopedWriterLock { 102 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 103 : mutex(shouldLock ? &mutexParam : nullptr) { 104 if (mutex) 105 mutex->lock(); 106 } 107 ~ScopedWriterLock() { 108 if (mutex) 109 mutex->unlock(); 110 } 111 llvm::sys::SmartRWMutex<true> *mutex; 112 }; 113 } // end anonymous namespace. 114 115 //===----------------------------------------------------------------------===// 116 // MLIRContextImpl 117 //===----------------------------------------------------------------------===// 118 119 namespace mlir { 120 /// This is the implementation of the MLIRContext class, using the pImpl idiom. 121 /// This class is completely private to this file, so everything is public. 122 class MLIRContextImpl { 123 public: 124 //===--------------------------------------------------------------------===// 125 // Debugging 126 //===--------------------------------------------------------------------===// 127 128 /// An action manager for use within the context. 129 DebugActionManager debugActionManager; 130 131 //===--------------------------------------------------------------------===// 132 // Diagnostics 133 //===--------------------------------------------------------------------===// 134 DiagnosticEngine diagEngine; 135 136 //===--------------------------------------------------------------------===// 137 // Options 138 //===--------------------------------------------------------------------===// 139 140 /// In most cases, creating operation in unregistered dialect is not desired 141 /// and indicate a misconfiguration of the compiler. This option enables to 142 /// detect such use cases 143 bool allowUnregisteredDialects = false; 144 145 /// Enable support for multi-threading within MLIR. 146 bool threadingIsEnabled = true; 147 148 /// Track if we are currently executing in a threaded execution environment 149 /// (like the pass-manager): this is only a debugging feature to help reducing 150 /// the chances of data races one some context APIs. 151 #ifndef NDEBUG 152 std::atomic<int> multiThreadedExecutionContext{0}; 153 #endif 154 155 /// If the operation should be attached to diagnostics printed via the 156 /// Operation::emit methods. 157 bool printOpOnDiagnostic = true; 158 159 /// If the current stack trace should be attached when emitting diagnostics. 160 bool printStackTraceOnDiagnostic = false; 161 162 //===--------------------------------------------------------------------===// 163 // Other 164 //===--------------------------------------------------------------------===// 165 166 /// This points to the ThreadPool used when processing MLIR tasks in parallel. 167 /// It can't be nullptr when multi-threading is enabled. Otherwise if 168 /// multi-threading is disabled, and the threadpool wasn't externally provided 169 /// using `setThreadPool`, this will be nullptr. 170 llvm::ThreadPool *threadPool = nullptr; 171 172 /// In case where the thread pool is owned by the context, this ensures 173 /// destruction with the context. 174 std::unique_ptr<llvm::ThreadPool> ownedThreadPool; 175 176 /// This is a list of dialects that are created referring to this context. 177 /// The MLIRContext owns the objects. 178 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects; 179 DialectRegistry dialectsRegistry; 180 181 /// An allocator used for AbstractAttribute and AbstractType objects. 182 llvm::BumpPtrAllocator abstractDialectSymbolAllocator; 183 184 /// This is a mapping from operation name to the operation info describing it. 185 llvm::StringMap<OperationName::Impl> operations; 186 187 /// A vector of operation info specifically for registered operations. 188 SmallVector<RegisteredOperationName> registeredOperations; 189 190 /// A mutex used when accessing operation information. 191 llvm::sys::SmartRWMutex<true> operationInfoMutex; 192 193 //===--------------------------------------------------------------------===// 194 // Affine uniquing 195 //===--------------------------------------------------------------------===// 196 197 // Affine expression, map and integer set uniquing. 198 StorageUniquer affineUniquer; 199 200 //===--------------------------------------------------------------------===// 201 // Type uniquing 202 //===--------------------------------------------------------------------===// 203 204 DenseMap<TypeID, AbstractType *> registeredTypes; 205 StorageUniquer typeUniquer; 206 207 /// Cached Type Instances. 208 BFloat16Type bf16Ty; 209 Float16Type f16Ty; 210 Float32Type f32Ty; 211 Float64Type f64Ty; 212 Float80Type f80Ty; 213 Float128Type f128Ty; 214 IndexType indexTy; 215 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 216 NoneType noneType; 217 218 //===--------------------------------------------------------------------===// 219 // Attribute uniquing 220 //===--------------------------------------------------------------------===// 221 222 DenseMap<TypeID, AbstractAttribute *> registeredAttributes; 223 StorageUniquer attributeUniquer; 224 225 /// Cached Attribute Instances. 226 BoolAttr falseAttr, trueAttr; 227 UnitAttr unitAttr; 228 UnknownLoc unknownLocAttr; 229 DictionaryAttr emptyDictionaryAttr; 230 StringAttr emptyStringAttr; 231 232 /// Map of string attributes that may reference a dialect, that are awaiting 233 /// that dialect to be loaded. 234 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex; 235 DenseMap<StringRef, SmallVector<StringAttrStorage *>> 236 dialectReferencingStrAttrs; 237 238 public: 239 MLIRContextImpl(bool threadingIsEnabled) 240 : threadingIsEnabled(threadingIsEnabled) { 241 if (threadingIsEnabled) { 242 ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 243 threadPool = ownedThreadPool.get(); 244 } 245 } 246 ~MLIRContextImpl() { 247 for (auto typeMapping : registeredTypes) 248 typeMapping.second->~AbstractType(); 249 for (auto attrMapping : registeredAttributes) 250 attrMapping.second->~AbstractAttribute(); 251 } 252 }; 253 } // end namespace mlir 254 255 MLIRContext::MLIRContext(Threading setting) 256 : MLIRContext(DialectRegistry(), setting) {} 257 258 MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) 259 : impl(new MLIRContextImpl(setting == Threading::ENABLED && 260 !isThreadingGloballyDisabled())) { 261 // Initialize values based on the command line flags if they were provided. 262 if (clOptions.isConstructed()) { 263 printOpOnDiagnostic(clOptions->printOpOnDiagnostic); 264 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); 265 } 266 267 // Pre-populate the registry. 268 registry.appendTo(impl->dialectsRegistry); 269 270 // Ensure the builtin dialect is always pre-loaded. 271 getOrLoadDialect<BuiltinDialect>(); 272 273 // Initialize several common attributes and types to avoid the need to lock 274 // the context when accessing them. 275 276 //// Types. 277 /// Floating-point Types. 278 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); 279 impl->f16Ty = TypeUniquer::get<Float16Type>(this); 280 impl->f32Ty = TypeUniquer::get<Float32Type>(this); 281 impl->f64Ty = TypeUniquer::get<Float64Type>(this); 282 impl->f80Ty = TypeUniquer::get<Float80Type>(this); 283 impl->f128Ty = TypeUniquer::get<Float128Type>(this); 284 /// Index Type. 285 impl->indexTy = TypeUniquer::get<IndexType>(this); 286 /// Integer Types. 287 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless); 288 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless); 289 impl->int16Ty = 290 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless); 291 impl->int32Ty = 292 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless); 293 impl->int64Ty = 294 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless); 295 impl->int128Ty = 296 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless); 297 /// None Type. 298 impl->noneType = TypeUniquer::get<NoneType>(this); 299 300 //// Attributes. 301 //// Note: These must be registered after the types as they may generate one 302 //// of the above types internally. 303 /// Unknown Location Attribute. 304 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this); 305 /// Bool Attributes. 306 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false); 307 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true); 308 /// Unit Attribute. 309 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this); 310 /// The empty dictionary attribute. 311 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this); 312 /// The empty string attribute. 313 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this); 314 315 // Register the affine storage objects with the uniquer. 316 impl->affineUniquer 317 .registerParametricStorageType<AffineBinaryOpExprStorage>(); 318 impl->affineUniquer 319 .registerParametricStorageType<AffineConstantExprStorage>(); 320 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>(); 321 impl->affineUniquer.registerParametricStorageType<AffineMapStorage>(); 322 impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>(); 323 } 324 325 MLIRContext::~MLIRContext() {} 326 327 /// Copy the specified array of elements into memory managed by the provided 328 /// bump pointer allocator. This assumes the elements are all PODs. 329 template <typename T> 330 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 331 ArrayRef<T> elements) { 332 auto result = allocator.Allocate<T>(elements.size()); 333 std::uninitialized_copy(elements.begin(), elements.end(), result); 334 return ArrayRef<T>(result, elements.size()); 335 } 336 337 //===----------------------------------------------------------------------===// 338 // Debugging 339 //===----------------------------------------------------------------------===// 340 341 DebugActionManager &MLIRContext::getDebugActionManager() { 342 return getImpl().debugActionManager; 343 } 344 345 //===----------------------------------------------------------------------===// 346 // Diagnostic Handlers 347 //===----------------------------------------------------------------------===// 348 349 /// Returns the diagnostic engine for this context. 350 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 351 352 //===----------------------------------------------------------------------===// 353 // Dialect and Operation Registration 354 //===----------------------------------------------------------------------===// 355 356 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 357 registry.appendTo(impl->dialectsRegistry); 358 359 // For the already loaded dialects, register the interfaces immediately. 360 for (const auto &kvp : impl->loadedDialects) 361 registry.registerDelayedInterfaces(kvp.second.get()); 362 } 363 364 const DialectRegistry &MLIRContext::getDialectRegistry() { 365 return impl->dialectsRegistry; 366 } 367 368 /// Return information about all registered IR dialects. 369 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 370 std::vector<Dialect *> result; 371 result.reserve(impl->loadedDialects.size()); 372 for (auto &dialect : impl->loadedDialects) 373 result.push_back(dialect.second.get()); 374 llvm::array_pod_sort(result.begin(), result.end(), 375 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 376 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 377 }); 378 return result; 379 } 380 std::vector<StringRef> MLIRContext::getAvailableDialects() { 381 std::vector<StringRef> result; 382 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 383 result.push_back(dialect); 384 return result; 385 } 386 387 /// Get a registered IR dialect with the given namespace. If none is found, 388 /// then return nullptr. 389 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 390 // Dialects are sorted by name, so we can use binary search for lookup. 391 auto it = impl->loadedDialects.find(name); 392 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 393 } 394 395 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 396 Dialect *dialect = getLoadedDialect(name); 397 if (dialect) 398 return dialect; 399 DialectAllocatorFunctionRef allocator = 400 impl->dialectsRegistry.getDialectAllocator(name); 401 return allocator ? allocator(this) : nullptr; 402 } 403 404 /// Get a dialect for the provided namespace and TypeID: abort the program if a 405 /// dialect exist for this namespace with different TypeID. Returns a pointer to 406 /// the dialect owned by the context. 407 Dialect * 408 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 409 function_ref<std::unique_ptr<Dialect>()> ctor) { 410 auto &impl = getImpl(); 411 // Get the correct insertion position sorted by namespace. 412 std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; 413 414 if (!dialect) { 415 LLVM_DEBUG(llvm::dbgs() 416 << "Load new dialect in Context " << dialectNamespace << "\n"); 417 #ifndef NDEBUG 418 if (impl.multiThreadedExecutionContext != 0) 419 llvm::report_fatal_error( 420 "Loading a dialect (" + dialectNamespace + 421 ") while in a multi-threaded execution context (maybe " 422 "the PassManager): this can indicate a " 423 "missing `dependentDialects` in a pass for example."); 424 #endif 425 dialect = ctor(); 426 assert(dialect && "dialect ctor failed"); 427 428 // Refresh all the identifiers dialect field, this catches cases where a 429 // dialect may be loaded after identifier prefixed with this dialect name 430 // were already created. 431 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace); 432 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) { 433 for (StringAttrStorage *storage : stringAttrsIt->second) 434 storage->referencedDialect = dialect.get(); 435 impl.dialectReferencingStrAttrs.erase(stringAttrsIt); 436 } 437 438 // Actually register the interfaces with delayed registration. 439 impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); 440 return dialect.get(); 441 } 442 443 // Abort if dialect with namespace has already been registered. 444 if (dialect->getTypeID() != dialectID) 445 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 446 "' has already been registered"); 447 448 return dialect.get(); 449 } 450 451 void MLIRContext::loadAllAvailableDialects() { 452 for (StringRef name : getAvailableDialects()) 453 getOrLoadDialect(name); 454 } 455 456 llvm::hash_code MLIRContext::getRegistryHash() { 457 llvm::hash_code hash(0); 458 // Factor in number of loaded dialects, attributes, operations, types. 459 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 460 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 461 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 462 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 463 return hash; 464 } 465 466 bool MLIRContext::allowsUnregisteredDialects() { 467 return impl->allowUnregisteredDialects; 468 } 469 470 void MLIRContext::allowUnregisteredDialects(bool allowing) { 471 impl->allowUnregisteredDialects = allowing; 472 } 473 474 /// Return true if multi-threading is enabled by the context. 475 bool MLIRContext::isMultithreadingEnabled() { 476 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 477 } 478 479 /// Set the flag specifying if multi-threading is disabled by the context. 480 void MLIRContext::disableMultithreading(bool disable) { 481 // This API can be overridden by the global debugging flag 482 // --mlir-disable-threading 483 if (isThreadingGloballyDisabled()) 484 return; 485 486 impl->threadingIsEnabled = !disable; 487 488 // Update the threading mode for each of the uniquers. 489 impl->affineUniquer.disableMultithreading(disable); 490 impl->attributeUniquer.disableMultithreading(disable); 491 impl->typeUniquer.disableMultithreading(disable); 492 493 // Destroy thread pool (stop all threads) if it is no longer needed, or create 494 // a new one if multithreading was re-enabled. 495 if (disable) { 496 // If the thread pool is owned, explicitly set it to nullptr to avoid 497 // keeping a dangling pointer around. If the thread pool is externally 498 // owned, we don't do anything. 499 if (impl->ownedThreadPool) { 500 assert(impl->threadPool); 501 impl->threadPool = nullptr; 502 impl->ownedThreadPool.reset(); 503 } 504 } else if (!impl->threadPool) { 505 // The thread pool isn't externally provided. 506 assert(!impl->ownedThreadPool); 507 impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 508 impl->threadPool = impl->ownedThreadPool.get(); 509 } 510 } 511 512 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { 513 assert(!isMultithreadingEnabled() && 514 "expected multi-threading to be disabled when setting a ThreadPool"); 515 impl->threadPool = &pool; 516 impl->ownedThreadPool.reset(); 517 enableMultithreading(); 518 } 519 520 llvm::ThreadPool &MLIRContext::getThreadPool() { 521 assert(isMultithreadingEnabled() && 522 "expected multi-threading to be enabled within the context"); 523 assert(impl->threadPool && 524 "multi-threading is enabled but threadpool not set"); 525 return *impl->threadPool; 526 } 527 528 void MLIRContext::enterMultiThreadedExecution() { 529 #ifndef NDEBUG 530 ++impl->multiThreadedExecutionContext; 531 #endif 532 } 533 void MLIRContext::exitMultiThreadedExecution() { 534 #ifndef NDEBUG 535 --impl->multiThreadedExecutionContext; 536 #endif 537 } 538 539 /// Return true if we should attach the operation to diagnostics emitted via 540 /// Operation::emit. 541 bool MLIRContext::shouldPrintOpOnDiagnostic() { 542 return impl->printOpOnDiagnostic; 543 } 544 545 /// Set the flag specifying if we should attach the operation to diagnostics 546 /// emitted via Operation::emit. 547 void MLIRContext::printOpOnDiagnostic(bool enable) { 548 impl->printOpOnDiagnostic = enable; 549 } 550 551 /// Return true if we should attach the current stacktrace to diagnostics when 552 /// emitted. 553 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 554 return impl->printStackTraceOnDiagnostic; 555 } 556 557 /// Set the flag specifying if we should attach the current stacktrace when 558 /// emitting diagnostics. 559 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 560 impl->printStackTraceOnDiagnostic = enable; 561 } 562 563 /// Return information about all registered operations. This isn't very 564 /// efficient, typically you should ask the operations about their properties 565 /// directly. 566 std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() { 567 // We just have the operations in a non-deterministic hash table order. Dump 568 // into a temporary array, then sort it by operation name to get a stable 569 // ordering. 570 std::vector<RegisteredOperationName> result( 571 impl->registeredOperations.begin(), impl->registeredOperations.end()); 572 llvm::array_pod_sort(result.begin(), result.end(), 573 [](const RegisteredOperationName *lhs, 574 const RegisteredOperationName *rhs) { 575 return lhs->getIdentifier().compare( 576 rhs->getIdentifier()); 577 }); 578 579 return result; 580 } 581 582 bool MLIRContext::isOperationRegistered(StringRef name) { 583 return OperationName(name, this).isRegistered(); 584 } 585 586 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 587 auto &impl = context->getImpl(); 588 assert(impl.multiThreadedExecutionContext == 0 && 589 "Registering a new type kind while in a multi-threaded execution " 590 "context"); 591 auto *newInfo = 592 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 593 AbstractType(std::move(typeInfo)); 594 if (!impl.registeredTypes.insert({typeID, newInfo}).second) 595 llvm::report_fatal_error("Dialect Type already registered."); 596 } 597 598 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 599 auto &impl = context->getImpl(); 600 assert(impl.multiThreadedExecutionContext == 0 && 601 "Registering a new attribute kind while in a multi-threaded execution " 602 "context"); 603 auto *newInfo = 604 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 605 AbstractAttribute(std::move(attrInfo)); 606 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 607 llvm::report_fatal_error("Dialect Attribute already registered."); 608 } 609 610 //===----------------------------------------------------------------------===// 611 // AbstractAttribute 612 //===----------------------------------------------------------------------===// 613 614 /// Get the dialect that registered the attribute with the provided typeid. 615 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, 616 MLIRContext *context) { 617 const AbstractAttribute *abstract = lookupMutable(typeID, context); 618 if (!abstract) 619 llvm::report_fatal_error("Trying to create an Attribute that was not " 620 "registered in this MLIRContext."); 621 return *abstract; 622 } 623 624 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, 625 MLIRContext *context) { 626 auto &impl = context->getImpl(); 627 auto it = impl.registeredAttributes.find(typeID); 628 if (it == impl.registeredAttributes.end()) 629 return nullptr; 630 return it->second; 631 } 632 633 //===----------------------------------------------------------------------===// 634 // OperationName 635 //===----------------------------------------------------------------------===// 636 637 OperationName::OperationName(StringRef name, MLIRContext *context) { 638 MLIRContextImpl &ctxImpl = context->getImpl(); 639 640 // Check for an existing name in read-only mode. 641 bool isMultithreadingEnabled = context->isMultithreadingEnabled(); 642 if (isMultithreadingEnabled) { 643 llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex); 644 auto it = ctxImpl.operations.find(name); 645 if (it != ctxImpl.operations.end()) { 646 impl = &it->second; 647 return; 648 } 649 } 650 651 // Acquire a writer-lock so that we can safely create the new instance. 652 ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled); 653 654 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)}); 655 if (it.second) 656 it.first->second.name = StringAttr::get(context, name); 657 impl = &it.first->second; 658 } 659 660 StringRef OperationName::getDialectNamespace() const { 661 if (Dialect *dialect = getDialect()) 662 return dialect->getNamespace(); 663 return getStringRef().split('.').first; 664 } 665 666 //===----------------------------------------------------------------------===// 667 // RegisteredOperationName 668 //===----------------------------------------------------------------------===// 669 670 ParseResult 671 RegisteredOperationName::parseAssembly(OpAsmParser &parser, 672 OperationState &result) const { 673 return impl->parseAssemblyFn(parser, result); 674 } 675 676 void RegisteredOperationName::insert( 677 StringRef name, Dialect &dialect, TypeID typeID, 678 ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, 679 VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, 680 GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, 681 detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, 682 ArrayRef<StringRef> attrNames) { 683 MLIRContext *ctx = dialect.getContext(); 684 auto &ctxImpl = ctx->getImpl(); 685 assert(ctxImpl.multiThreadedExecutionContext == 0 && 686 "registering a new operation kind while in a multi-threaded execution " 687 "context"); 688 689 // Register the attribute names of this operation. 690 MutableArrayRef<StringAttr> cachedAttrNames; 691 if (!attrNames.empty()) { 692 cachedAttrNames = MutableArrayRef<StringAttr>( 693 ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>( 694 attrNames.size()), 695 attrNames.size()); 696 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size())) 697 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i])); 698 } 699 700 // Insert the operation info if it doesn't exist yet. 701 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)}); 702 if (it.second) 703 it.first->second.name = StringAttr::get(ctx, name); 704 OperationName::Impl &impl = it.first->second; 705 706 if (impl.isRegistered()) { 707 llvm::errs() << "error: operation named '" << name 708 << "' is already registered.\n"; 709 abort(); 710 } 711 ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl)); 712 713 // Update the registered info for this operation. 714 impl.dialect = &dialect; 715 impl.typeID = typeID; 716 impl.interfaceMap = std::move(interfaceMap); 717 impl.foldHookFn = std::move(foldHook); 718 impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns); 719 impl.hasTraitFn = std::move(hasTrait); 720 impl.parseAssemblyFn = std::move(parseAssembly); 721 impl.printAssemblyFn = std::move(printAssembly); 722 impl.verifyInvariantsFn = std::move(verifyInvariants); 723 impl.attributeNames = cachedAttrNames; 724 } 725 726 //===----------------------------------------------------------------------===// 727 // AbstractType 728 //===----------------------------------------------------------------------===// 729 730 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { 731 const AbstractType *type = lookupMutable(typeID, context); 732 if (!type) 733 llvm::report_fatal_error( 734 "Trying to create a Type that was not registered in this MLIRContext."); 735 return *type; 736 } 737 738 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { 739 auto &impl = context->getImpl(); 740 auto it = impl.registeredTypes.find(typeID); 741 if (it == impl.registeredTypes.end()) 742 return nullptr; 743 return it->second; 744 } 745 746 //===----------------------------------------------------------------------===// 747 // Type uniquing 748 //===----------------------------------------------------------------------===// 749 750 /// Returns the storage uniquer used for constructing type storage instances. 751 /// This should not be used directly. 752 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } 753 754 BFloat16Type BFloat16Type::get(MLIRContext *context) { 755 return context->getImpl().bf16Ty; 756 } 757 Float16Type Float16Type::get(MLIRContext *context) { 758 return context->getImpl().f16Ty; 759 } 760 Float32Type Float32Type::get(MLIRContext *context) { 761 return context->getImpl().f32Ty; 762 } 763 Float64Type Float64Type::get(MLIRContext *context) { 764 return context->getImpl().f64Ty; 765 } 766 Float80Type Float80Type::get(MLIRContext *context) { 767 return context->getImpl().f80Ty; 768 } 769 Float128Type Float128Type::get(MLIRContext *context) { 770 return context->getImpl().f128Ty; 771 } 772 773 /// Get an instance of the IndexType. 774 IndexType IndexType::get(MLIRContext *context) { 775 return context->getImpl().indexTy; 776 } 777 778 /// Return an existing integer type instance if one is cached within the 779 /// context. 780 static IntegerType 781 getCachedIntegerType(unsigned width, 782 IntegerType::SignednessSemantics signedness, 783 MLIRContext *context) { 784 if (signedness != IntegerType::Signless) 785 return IntegerType(); 786 787 switch (width) { 788 case 1: 789 return context->getImpl().int1Ty; 790 case 8: 791 return context->getImpl().int8Ty; 792 case 16: 793 return context->getImpl().int16Ty; 794 case 32: 795 return context->getImpl().int32Ty; 796 case 64: 797 return context->getImpl().int64Ty; 798 case 128: 799 return context->getImpl().int128Ty; 800 default: 801 return IntegerType(); 802 } 803 } 804 805 IntegerType IntegerType::get(MLIRContext *context, unsigned width, 806 IntegerType::SignednessSemantics signedness) { 807 if (auto cached = getCachedIntegerType(width, signedness, context)) 808 return cached; 809 return Base::get(context, width, signedness); 810 } 811 812 IntegerType 813 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError, 814 MLIRContext *context, unsigned width, 815 SignednessSemantics signedness) { 816 if (auto cached = getCachedIntegerType(width, signedness, context)) 817 return cached; 818 return Base::getChecked(emitError, context, width, signedness); 819 } 820 821 /// Get an instance of the NoneType. 822 NoneType NoneType::get(MLIRContext *context) { 823 if (NoneType cachedInst = context->getImpl().noneType) 824 return cachedInst; 825 // Note: May happen when initializing the singleton attributes of the builtin 826 // dialect. 827 return Base::get(context); 828 } 829 830 //===----------------------------------------------------------------------===// 831 // Attribute uniquing 832 //===----------------------------------------------------------------------===// 833 834 /// Returns the storage uniquer used for constructing attribute storage 835 /// instances. This should not be used directly. 836 StorageUniquer &MLIRContext::getAttributeUniquer() { 837 return getImpl().attributeUniquer; 838 } 839 840 /// Initialize the given attribute storage instance. 841 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, 842 MLIRContext *ctx, 843 TypeID attrID) { 844 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); 845 846 // If the attribute did not provide a type, then default to NoneType. 847 if (!storage->getType()) 848 storage->setType(NoneType::get(ctx)); 849 } 850 851 BoolAttr BoolAttr::get(MLIRContext *context, bool value) { 852 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; 853 } 854 855 UnitAttr UnitAttr::get(MLIRContext *context) { 856 return context->getImpl().unitAttr; 857 } 858 859 UnknownLoc UnknownLoc::get(MLIRContext *context) { 860 return context->getImpl().unknownLocAttr; 861 } 862 863 /// Return empty dictionary. 864 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { 865 return context->getImpl().emptyDictionaryAttr; 866 } 867 868 void StringAttrStorage::initialize(MLIRContext *context) { 869 // Check for a dialect namespace prefix, if there isn't one we don't need to 870 // do any additional initialization. 871 auto dialectNamePair = value.split('.'); 872 if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) 873 return; 874 875 // If one exists, we check to see if this dialect is loaded. If it is, we set 876 // the dialect now, if it isn't we record this storage for initialization 877 // later if the dialect ever gets loaded. 878 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) 879 return; 880 881 MLIRContextImpl &impl = context->getImpl(); 882 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex); 883 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); 884 } 885 886 /// Return an empty string. 887 StringAttr StringAttr::get(MLIRContext *context) { 888 return context->getImpl().emptyStringAttr; 889 } 890 891 //===----------------------------------------------------------------------===// 892 // AffineMap uniquing 893 //===----------------------------------------------------------------------===// 894 895 StorageUniquer &MLIRContext::getAffineUniquer() { 896 return getImpl().affineUniquer; 897 } 898 899 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, 900 ArrayRef<AffineExpr> results, 901 MLIRContext *context) { 902 auto &impl = context->getImpl(); 903 auto *storage = impl.affineUniquer.get<AffineMapStorage>( 904 [&](AffineMapStorage *storage) { storage->context = context; }, dimCount, 905 symbolCount, results); 906 return AffineMap(storage); 907 } 908 909 /// Check whether the arguments passed to the AffineMap::get() are consistent. 910 /// This method checks whether the highest index of dimensional identifier 911 /// present in result expressions is less than `dimCount` and the highest index 912 /// of symbolic identifier present in result expressions is less than 913 /// `symbolCount`. 914 LLVM_ATTRIBUTE_UNUSED static bool 915 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, 916 ArrayRef<AffineExpr> results) { 917 int64_t maxDimPosition = -1; 918 int64_t maxSymbolPosition = -1; 919 getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition, 920 maxSymbolPosition); 921 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) { 922 LLVM_DEBUG( 923 llvm::dbgs() 924 << "maximum dimensional identifier position in result expression must " 925 "be less than `dimCount` and maximum symbolic identifier position " 926 "in result expression must be less than `symbolCount`\n"); 927 return false; 928 } 929 return true; 930 } 931 932 AffineMap AffineMap::get(MLIRContext *context) { 933 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); 934 } 935 936 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 937 MLIRContext *context) { 938 return getImpl(dimCount, symbolCount, /*results=*/{}, context); 939 } 940 941 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 942 AffineExpr result) { 943 assert(willBeValidAffineMap(dimCount, symbolCount, {result})); 944 return getImpl(dimCount, symbolCount, {result}, result.getContext()); 945 } 946 947 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 948 ArrayRef<AffineExpr> results, MLIRContext *context) { 949 assert(willBeValidAffineMap(dimCount, symbolCount, results)); 950 return getImpl(dimCount, symbolCount, results, context); 951 } 952 953 //===----------------------------------------------------------------------===// 954 // Integer Sets: these are allocated into the bump pointer, and are immutable. 955 // Unlike AffineMap's, these are uniqued only if they are small. 956 //===----------------------------------------------------------------------===// 957 958 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, 959 ArrayRef<AffineExpr> constraints, 960 ArrayRef<bool> eqFlags) { 961 // The number of constraints can't be zero. 962 assert(!constraints.empty()); 963 assert(constraints.size() == eqFlags.size()); 964 965 auto &impl = constraints[0].getContext()->getImpl(); 966 auto *storage = impl.affineUniquer.get<IntegerSetStorage>( 967 [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags); 968 return IntegerSet(storage); 969 } 970 971 //===----------------------------------------------------------------------===// 972 // StorageUniquerSupport 973 //===----------------------------------------------------------------------===// 974 975 /// Utility method to generate a callback that can be used to generate a 976 /// diagnostic when checking the construction invariants of a storage object. 977 /// This is defined out-of-line to avoid the need to include Location.h. 978 llvm::unique_function<InFlightDiagnostic()> 979 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { 980 return [ctx] { return emitError(UnknownLoc::get(ctx)); }; 981 } 982 llvm::unique_function<InFlightDiagnostic()> 983 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { 984 return [=] { return emitError(loc); }; 985 } 986