1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/MLIRContext.h" 10 #include "AffineExprDetail.h" 11 #include "AffineMapDetail.h" 12 #include "AttributeDetail.h" 13 #include "IntegerSetDetail.h" 14 #include "TypeDetail.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/IntegerSet.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/Types.h" 25 #include "mlir/Support/DebugAction.h" 26 #include "llvm/ADT/DenseMap.h" 27 #include "llvm/ADT/DenseSet.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/ADT/SmallString.h" 30 #include "llvm/ADT/StringSet.h" 31 #include "llvm/ADT/Twine.h" 32 #include "llvm/Support/Allocator.h" 33 #include "llvm/Support/CommandLine.h" 34 #include "llvm/Support/Debug.h" 35 #include "llvm/Support/Mutex.h" 36 #include "llvm/Support/RWMutex.h" 37 #include "llvm/Support/ThreadPool.h" 38 #include "llvm/Support/raw_ostream.h" 39 #include <memory> 40 41 #define DEBUG_TYPE "mlircontext" 42 43 using namespace mlir; 44 using namespace mlir::detail; 45 46 //===----------------------------------------------------------------------===// 47 // MLIRContext CommandLine Options 48 //===----------------------------------------------------------------------===// 49 50 namespace { 51 /// This struct contains command line options that can be used to initialize 52 /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need 53 /// for global command line options. 54 struct MLIRContextOptions { 55 llvm::cl::opt<bool> disableThreading{ 56 "mlir-disable-threading", 57 llvm::cl::desc("Disable multi-threading within MLIR, overrides any " 58 "further call to MLIRContext::enableMultiThreading()")}; 59 60 llvm::cl::opt<bool> printOpOnDiagnostic{ 61 "mlir-print-op-on-diagnostic", 62 llvm::cl::desc("When a diagnostic is emitted on an operation, also print " 63 "the operation as an attached note"), 64 llvm::cl::init(true)}; 65 66 llvm::cl::opt<bool> printStackTraceOnDiagnostic{ 67 "mlir-print-stacktrace-on-diagnostic", 68 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace " 69 "as an attached note")}; 70 }; 71 } // namespace 72 73 static llvm::ManagedStatic<MLIRContextOptions> clOptions; 74 75 static bool isThreadingGloballyDisabled() { 76 #if LLVM_ENABLE_THREADS != 0 77 return clOptions.isConstructed() && clOptions->disableThreading; 78 #else 79 return true; 80 #endif 81 } 82 83 /// Register a set of useful command-line options that can be used to configure 84 /// various flags within the MLIRContext. These flags are used when constructing 85 /// an MLIR context for initialization. 86 void mlir::registerMLIRContextCLOptions() { 87 // Make sure that the options struct has been initialized. 88 *clOptions; 89 } 90 91 //===----------------------------------------------------------------------===// 92 // Locking Utilities 93 //===----------------------------------------------------------------------===// 94 95 namespace { 96 /// Utility writer lock that takes a runtime flag that specifies if we really 97 /// need to lock. 98 struct ScopedWriterLock { 99 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 100 : mutex(shouldLock ? &mutexParam : nullptr) { 101 if (mutex) 102 mutex->lock(); 103 } 104 ~ScopedWriterLock() { 105 if (mutex) 106 mutex->unlock(); 107 } 108 llvm::sys::SmartRWMutex<true> *mutex; 109 }; 110 } // namespace 111 112 //===----------------------------------------------------------------------===// 113 // MLIRContextImpl 114 //===----------------------------------------------------------------------===// 115 116 namespace mlir { 117 /// This is the implementation of the MLIRContext class, using the pImpl idiom. 118 /// This class is completely private to this file, so everything is public. 119 class MLIRContextImpl { 120 public: 121 //===--------------------------------------------------------------------===// 122 // Debugging 123 //===--------------------------------------------------------------------===// 124 125 /// An action manager for use within the context. 126 DebugActionManager debugActionManager; 127 128 //===--------------------------------------------------------------------===// 129 // Diagnostics 130 //===--------------------------------------------------------------------===// 131 DiagnosticEngine diagEngine; 132 133 //===--------------------------------------------------------------------===// 134 // Options 135 //===--------------------------------------------------------------------===// 136 137 /// In most cases, creating operation in unregistered dialect is not desired 138 /// and indicate a misconfiguration of the compiler. This option enables to 139 /// detect such use cases 140 bool allowUnregisteredDialects = false; 141 142 /// Enable support for multi-threading within MLIR. 143 bool threadingIsEnabled = true; 144 145 /// Track if we are currently executing in a threaded execution environment 146 /// (like the pass-manager): this is only a debugging feature to help reducing 147 /// the chances of data races one some context APIs. 148 #ifndef NDEBUG 149 std::atomic<int> multiThreadedExecutionContext{0}; 150 #endif 151 152 /// If the operation should be attached to diagnostics printed via the 153 /// Operation::emit methods. 154 bool printOpOnDiagnostic = true; 155 156 /// If the current stack trace should be attached when emitting diagnostics. 157 bool printStackTraceOnDiagnostic = false; 158 159 //===--------------------------------------------------------------------===// 160 // Other 161 //===--------------------------------------------------------------------===// 162 163 /// This points to the ThreadPool used when processing MLIR tasks in parallel. 164 /// It can't be nullptr when multi-threading is enabled. Otherwise if 165 /// multi-threading is disabled, and the threadpool wasn't externally provided 166 /// using `setThreadPool`, this will be nullptr. 167 llvm::ThreadPool *threadPool = nullptr; 168 169 /// In case where the thread pool is owned by the context, this ensures 170 /// destruction with the context. 171 std::unique_ptr<llvm::ThreadPool> ownedThreadPool; 172 173 /// This is a list of dialects that are created referring to this context. 174 /// The MLIRContext owns the objects. 175 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects; 176 DialectRegistry dialectsRegistry; 177 178 /// An allocator used for AbstractAttribute and AbstractType objects. 179 llvm::BumpPtrAllocator abstractDialectSymbolAllocator; 180 181 /// This is a mapping from operation name to the operation info describing it. 182 llvm::StringMap<OperationName::Impl> operations; 183 184 /// A vector of operation info specifically for registered operations. 185 llvm::StringMap<RegisteredOperationName> registeredOperations; 186 187 /// This is a sorted container of registered operations for a deterministic 188 /// and efficient `getRegisteredOperations` implementation. 189 SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations; 190 191 /// A mutex used when accessing operation information. 192 llvm::sys::SmartRWMutex<true> operationInfoMutex; 193 194 //===--------------------------------------------------------------------===// 195 // Affine uniquing 196 //===--------------------------------------------------------------------===// 197 198 // Affine expression, map and integer set uniquing. 199 StorageUniquer affineUniquer; 200 201 //===--------------------------------------------------------------------===// 202 // Type uniquing 203 //===--------------------------------------------------------------------===// 204 205 DenseMap<TypeID, AbstractType *> registeredTypes; 206 StorageUniquer typeUniquer; 207 208 /// Cached Type Instances. 209 BFloat16Type bf16Ty; 210 Float16Type f16Ty; 211 Float32Type f32Ty; 212 Float64Type f64Ty; 213 Float80Type f80Ty; 214 Float128Type f128Ty; 215 IndexType indexTy; 216 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; 217 NoneType noneType; 218 219 //===--------------------------------------------------------------------===// 220 // Attribute uniquing 221 //===--------------------------------------------------------------------===// 222 223 DenseMap<TypeID, AbstractAttribute *> registeredAttributes; 224 StorageUniquer attributeUniquer; 225 226 /// Cached Attribute Instances. 227 BoolAttr falseAttr, trueAttr; 228 UnitAttr unitAttr; 229 UnknownLoc unknownLocAttr; 230 DictionaryAttr emptyDictionaryAttr; 231 StringAttr emptyStringAttr; 232 233 /// Map of string attributes that may reference a dialect, that are awaiting 234 /// that dialect to be loaded. 235 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex; 236 DenseMap<StringRef, SmallVector<StringAttrStorage *>> 237 dialectReferencingStrAttrs; 238 239 public: 240 MLIRContextImpl(bool threadingIsEnabled) 241 : threadingIsEnabled(threadingIsEnabled) { 242 if (threadingIsEnabled) { 243 ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 244 threadPool = ownedThreadPool.get(); 245 } 246 } 247 ~MLIRContextImpl() { 248 for (auto typeMapping : registeredTypes) 249 typeMapping.second->~AbstractType(); 250 for (auto attrMapping : registeredAttributes) 251 attrMapping.second->~AbstractAttribute(); 252 } 253 }; 254 } // namespace mlir 255 256 MLIRContext::MLIRContext(Threading setting) 257 : MLIRContext(DialectRegistry(), setting) {} 258 259 MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting) 260 : impl(new MLIRContextImpl(setting == Threading::ENABLED && 261 !isThreadingGloballyDisabled())) { 262 // Initialize values based on the command line flags if they were provided. 263 if (clOptions.isConstructed()) { 264 printOpOnDiagnostic(clOptions->printOpOnDiagnostic); 265 printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); 266 } 267 268 // Pre-populate the registry. 269 registry.appendTo(impl->dialectsRegistry); 270 271 // Ensure the builtin dialect is always pre-loaded. 272 getOrLoadDialect<BuiltinDialect>(); 273 274 // Initialize several common attributes and types to avoid the need to lock 275 // the context when accessing them. 276 277 //// Types. 278 /// Floating-point Types. 279 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this); 280 impl->f16Ty = TypeUniquer::get<Float16Type>(this); 281 impl->f32Ty = TypeUniquer::get<Float32Type>(this); 282 impl->f64Ty = TypeUniquer::get<Float64Type>(this); 283 impl->f80Ty = TypeUniquer::get<Float80Type>(this); 284 impl->f128Ty = TypeUniquer::get<Float128Type>(this); 285 /// Index Type. 286 impl->indexTy = TypeUniquer::get<IndexType>(this); 287 /// Integer Types. 288 impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless); 289 impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless); 290 impl->int16Ty = 291 TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless); 292 impl->int32Ty = 293 TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless); 294 impl->int64Ty = 295 TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless); 296 impl->int128Ty = 297 TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless); 298 /// None Type. 299 impl->noneType = TypeUniquer::get<NoneType>(this); 300 301 //// Attributes. 302 //// Note: These must be registered after the types as they may generate one 303 //// of the above types internally. 304 /// Unknown Location Attribute. 305 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this); 306 /// Bool Attributes. 307 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false); 308 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true); 309 /// Unit Attribute. 310 impl->unitAttr = AttributeUniquer::get<UnitAttr>(this); 311 /// The empty dictionary attribute. 312 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this); 313 /// The empty string attribute. 314 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this); 315 316 // Register the affine storage objects with the uniquer. 317 impl->affineUniquer 318 .registerParametricStorageType<AffineBinaryOpExprStorage>(); 319 impl->affineUniquer 320 .registerParametricStorageType<AffineConstantExprStorage>(); 321 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>(); 322 impl->affineUniquer.registerParametricStorageType<AffineMapStorage>(); 323 impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>(); 324 } 325 326 MLIRContext::~MLIRContext() = default; 327 328 /// Copy the specified array of elements into memory managed by the provided 329 /// bump pointer allocator. This assumes the elements are all PODs. 330 template <typename T> 331 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 332 ArrayRef<T> elements) { 333 auto result = allocator.Allocate<T>(elements.size()); 334 std::uninitialized_copy(elements.begin(), elements.end(), result); 335 return ArrayRef<T>(result, elements.size()); 336 } 337 338 //===----------------------------------------------------------------------===// 339 // Debugging 340 //===----------------------------------------------------------------------===// 341 342 DebugActionManager &MLIRContext::getDebugActionManager() { 343 return getImpl().debugActionManager; 344 } 345 346 //===----------------------------------------------------------------------===// 347 // Diagnostic Handlers 348 //===----------------------------------------------------------------------===// 349 350 /// Returns the diagnostic engine for this context. 351 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 352 353 //===----------------------------------------------------------------------===// 354 // Dialect and Operation Registration 355 //===----------------------------------------------------------------------===// 356 357 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 358 if (registry.isSubsetOf(impl->dialectsRegistry)) 359 return; 360 361 assert(impl->multiThreadedExecutionContext == 0 && 362 "appending to the MLIRContext dialect registry while in a " 363 "multi-threaded execution context"); 364 registry.appendTo(impl->dialectsRegistry); 365 366 // For the already loaded dialects, apply any possible extensions immediately. 367 registry.applyExtensions(this); 368 } 369 370 const DialectRegistry &MLIRContext::getDialectRegistry() { 371 return impl->dialectsRegistry; 372 } 373 374 /// Return information about all registered IR dialects. 375 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 376 std::vector<Dialect *> result; 377 result.reserve(impl->loadedDialects.size()); 378 for (auto &dialect : impl->loadedDialects) 379 result.push_back(dialect.second.get()); 380 llvm::array_pod_sort(result.begin(), result.end(), 381 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 382 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 383 }); 384 return result; 385 } 386 std::vector<StringRef> MLIRContext::getAvailableDialects() { 387 std::vector<StringRef> result; 388 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 389 result.push_back(dialect); 390 return result; 391 } 392 393 /// Get a registered IR dialect with the given namespace. If none is found, 394 /// then return nullptr. 395 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 396 // Dialects are sorted by name, so we can use binary search for lookup. 397 auto it = impl->loadedDialects.find(name); 398 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 399 } 400 401 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 402 Dialect *dialect = getLoadedDialect(name); 403 if (dialect) 404 return dialect; 405 DialectAllocatorFunctionRef allocator = 406 impl->dialectsRegistry.getDialectAllocator(name); 407 return allocator ? allocator(this) : nullptr; 408 } 409 410 /// Get a dialect for the provided namespace and TypeID: abort the program if a 411 /// dialect exist for this namespace with different TypeID. Returns a pointer to 412 /// the dialect owned by the context. 413 Dialect * 414 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 415 function_ref<std::unique_ptr<Dialect>()> ctor) { 416 auto &impl = getImpl(); 417 // Get the correct insertion position sorted by namespace. 418 auto dialectIt = impl.loadedDialects.find(dialectNamespace); 419 420 if (dialectIt == impl.loadedDialects.end()) { 421 LLVM_DEBUG(llvm::dbgs() 422 << "Load new dialect in Context " << dialectNamespace << "\n"); 423 #ifndef NDEBUG 424 if (impl.multiThreadedExecutionContext != 0) 425 llvm::report_fatal_error( 426 "Loading a dialect (" + dialectNamespace + 427 ") while in a multi-threaded execution context (maybe " 428 "the PassManager): this can indicate a " 429 "missing `dependentDialects` in a pass for example."); 430 #endif 431 std::unique_ptr<Dialect> &dialect = 432 impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second; 433 assert(dialect && "dialect ctor failed"); 434 435 // Refresh all the identifiers dialect field, this catches cases where a 436 // dialect may be loaded after identifier prefixed with this dialect name 437 // were already created. 438 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(dialectNamespace); 439 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) { 440 for (StringAttrStorage *storage : stringAttrsIt->second) 441 storage->referencedDialect = dialect.get(); 442 impl.dialectReferencingStrAttrs.erase(stringAttrsIt); 443 } 444 445 // Apply any extensions to this newly loaded dialect. 446 impl.dialectsRegistry.applyExtensions(dialect.get()); 447 return dialect.get(); 448 } 449 450 // Abort if dialect with namespace has already been registered. 451 std::unique_ptr<Dialect> &dialect = dialectIt->second; 452 if (dialect->getTypeID() != dialectID) 453 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 454 "' has already been registered"); 455 456 return dialect.get(); 457 } 458 459 void MLIRContext::loadAllAvailableDialects() { 460 for (StringRef name : getAvailableDialects()) 461 getOrLoadDialect(name); 462 } 463 464 llvm::hash_code MLIRContext::getRegistryHash() { 465 llvm::hash_code hash(0); 466 // Factor in number of loaded dialects, attributes, operations, types. 467 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 468 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 469 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 470 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 471 return hash; 472 } 473 474 bool MLIRContext::allowsUnregisteredDialects() { 475 return impl->allowUnregisteredDialects; 476 } 477 478 void MLIRContext::allowUnregisteredDialects(bool allowing) { 479 assert(impl->multiThreadedExecutionContext == 0 && 480 "changing MLIRContext `allow-unregistered-dialects` configuration " 481 "while in a multi-threaded execution context"); 482 impl->allowUnregisteredDialects = allowing; 483 } 484 485 /// Return true if multi-threading is enabled by the context. 486 bool MLIRContext::isMultithreadingEnabled() { 487 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 488 } 489 490 /// Set the flag specifying if multi-threading is disabled by the context. 491 void MLIRContext::disableMultithreading(bool disable) { 492 // This API can be overridden by the global debugging flag 493 // --mlir-disable-threading 494 if (isThreadingGloballyDisabled()) 495 return; 496 assert(impl->multiThreadedExecutionContext == 0 && 497 "changing MLIRContext `disable-threading` configuration while " 498 "in a multi-threaded execution context"); 499 500 impl->threadingIsEnabled = !disable; 501 502 // Update the threading mode for each of the uniquers. 503 impl->affineUniquer.disableMultithreading(disable); 504 impl->attributeUniquer.disableMultithreading(disable); 505 impl->typeUniquer.disableMultithreading(disable); 506 507 // Destroy thread pool (stop all threads) if it is no longer needed, or create 508 // a new one if multithreading was re-enabled. 509 if (disable) { 510 // If the thread pool is owned, explicitly set it to nullptr to avoid 511 // keeping a dangling pointer around. If the thread pool is externally 512 // owned, we don't do anything. 513 if (impl->ownedThreadPool) { 514 assert(impl->threadPool); 515 impl->threadPool = nullptr; 516 impl->ownedThreadPool.reset(); 517 } 518 } else if (!impl->threadPool) { 519 // The thread pool isn't externally provided. 520 assert(!impl->ownedThreadPool); 521 impl->ownedThreadPool = std::make_unique<llvm::ThreadPool>(); 522 impl->threadPool = impl->ownedThreadPool.get(); 523 } 524 } 525 526 void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { 527 assert(!isMultithreadingEnabled() && 528 "expected multi-threading to be disabled when setting a ThreadPool"); 529 impl->threadPool = &pool; 530 impl->ownedThreadPool.reset(); 531 enableMultithreading(); 532 } 533 534 unsigned MLIRContext::getNumThreads() { 535 if (isMultithreadingEnabled()) { 536 assert(impl->threadPool && 537 "multi-threading is enabled but threadpool not set"); 538 return impl->threadPool->getThreadCount(); 539 } 540 // No multithreading or active thread pool. Return 1 thread. 541 return 1; 542 } 543 544 llvm::ThreadPool &MLIRContext::getThreadPool() { 545 assert(isMultithreadingEnabled() && 546 "expected multi-threading to be enabled within the context"); 547 assert(impl->threadPool && 548 "multi-threading is enabled but threadpool not set"); 549 return *impl->threadPool; 550 } 551 552 void MLIRContext::enterMultiThreadedExecution() { 553 #ifndef NDEBUG 554 ++impl->multiThreadedExecutionContext; 555 #endif 556 } 557 void MLIRContext::exitMultiThreadedExecution() { 558 #ifndef NDEBUG 559 --impl->multiThreadedExecutionContext; 560 #endif 561 } 562 563 /// Return true if we should attach the operation to diagnostics emitted via 564 /// Operation::emit. 565 bool MLIRContext::shouldPrintOpOnDiagnostic() { 566 return impl->printOpOnDiagnostic; 567 } 568 569 /// Set the flag specifying if we should attach the operation to diagnostics 570 /// emitted via Operation::emit. 571 void MLIRContext::printOpOnDiagnostic(bool enable) { 572 assert(impl->multiThreadedExecutionContext == 0 && 573 "changing MLIRContext `print-op-on-diagnostic` configuration while in " 574 "a multi-threaded execution context"); 575 impl->printOpOnDiagnostic = enable; 576 } 577 578 /// Return true if we should attach the current stacktrace to diagnostics when 579 /// emitted. 580 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 581 return impl->printStackTraceOnDiagnostic; 582 } 583 584 /// Set the flag specifying if we should attach the current stacktrace when 585 /// emitting diagnostics. 586 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 587 assert(impl->multiThreadedExecutionContext == 0 && 588 "changing MLIRContext `print-stacktrace-on-diagnostic` configuration " 589 "while in a multi-threaded execution context"); 590 impl->printStackTraceOnDiagnostic = enable; 591 } 592 593 /// Return information about all registered operations. 594 ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() { 595 return impl->sortedRegisteredOperations; 596 } 597 598 bool MLIRContext::isOperationRegistered(StringRef name) { 599 return RegisteredOperationName::lookup(name, this).hasValue(); 600 } 601 602 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 603 auto &impl = context->getImpl(); 604 assert(impl.multiThreadedExecutionContext == 0 && 605 "Registering a new type kind while in a multi-threaded execution " 606 "context"); 607 auto *newInfo = 608 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 609 AbstractType(std::move(typeInfo)); 610 if (!impl.registeredTypes.insert({typeID, newInfo}).second) 611 llvm::report_fatal_error("Dialect Type already registered."); 612 } 613 614 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 615 auto &impl = context->getImpl(); 616 assert(impl.multiThreadedExecutionContext == 0 && 617 "Registering a new attribute kind while in a multi-threaded execution " 618 "context"); 619 auto *newInfo = 620 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 621 AbstractAttribute(std::move(attrInfo)); 622 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 623 llvm::report_fatal_error("Dialect Attribute already registered."); 624 } 625 626 //===----------------------------------------------------------------------===// 627 // AbstractAttribute 628 //===----------------------------------------------------------------------===// 629 630 /// Get the dialect that registered the attribute with the provided typeid. 631 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, 632 MLIRContext *context) { 633 const AbstractAttribute *abstract = lookupMutable(typeID, context); 634 if (!abstract) 635 llvm::report_fatal_error("Trying to create an Attribute that was not " 636 "registered in this MLIRContext."); 637 return *abstract; 638 } 639 640 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, 641 MLIRContext *context) { 642 auto &impl = context->getImpl(); 643 auto it = impl.registeredAttributes.find(typeID); 644 if (it == impl.registeredAttributes.end()) 645 return nullptr; 646 return it->second; 647 } 648 649 //===----------------------------------------------------------------------===// 650 // OperationName 651 //===----------------------------------------------------------------------===// 652 653 OperationName::OperationName(StringRef name, MLIRContext *context) { 654 MLIRContextImpl &ctxImpl = context->getImpl(); 655 656 // Check for an existing name in read-only mode. 657 bool isMultithreadingEnabled = context->isMultithreadingEnabled(); 658 if (isMultithreadingEnabled) { 659 // Check the registered info map first. In the overwhelmingly common case, 660 // the entry will be in here and it also removes the need to acquire any 661 // locks. 662 auto registeredIt = ctxImpl.registeredOperations.find(name); 663 if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) { 664 impl = registeredIt->second.impl; 665 return; 666 } 667 668 llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex); 669 auto it = ctxImpl.operations.find(name); 670 if (it != ctxImpl.operations.end()) { 671 impl = &it->second; 672 return; 673 } 674 } 675 676 // Acquire a writer-lock so that we can safely create the new instance. 677 ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled); 678 679 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)}); 680 if (it.second) 681 it.first->second.name = StringAttr::get(context, name); 682 impl = &it.first->second; 683 } 684 685 StringRef OperationName::getDialectNamespace() const { 686 if (Dialect *dialect = getDialect()) 687 return dialect->getNamespace(); 688 return getStringRef().split('.').first; 689 } 690 691 //===----------------------------------------------------------------------===// 692 // RegisteredOperationName 693 //===----------------------------------------------------------------------===// 694 695 Optional<RegisteredOperationName> 696 RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) { 697 auto &impl = ctx->getImpl(); 698 auto it = impl.registeredOperations.find(name); 699 if (it != impl.registeredOperations.end()) 700 return it->getValue(); 701 return llvm::None; 702 } 703 704 ParseResult 705 RegisteredOperationName::parseAssembly(OpAsmParser &parser, 706 OperationState &result) const { 707 return impl->parseAssemblyFn(parser, result); 708 } 709 710 void RegisteredOperationName::insert( 711 StringRef name, Dialect &dialect, TypeID typeID, 712 ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, 713 VerifyInvariantsFn &&verifyInvariants, 714 VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook, 715 GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, 716 detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, 717 ArrayRef<StringRef> attrNames) { 718 MLIRContext *ctx = dialect.getContext(); 719 auto &ctxImpl = ctx->getImpl(); 720 assert(ctxImpl.multiThreadedExecutionContext == 0 && 721 "registering a new operation kind while in a multi-threaded execution " 722 "context"); 723 724 // Register the attribute names of this operation. 725 MutableArrayRef<StringAttr> cachedAttrNames; 726 if (!attrNames.empty()) { 727 cachedAttrNames = MutableArrayRef<StringAttr>( 728 ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>( 729 attrNames.size()), 730 attrNames.size()); 731 for (unsigned i : llvm::seq<unsigned>(0, attrNames.size())) 732 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(ctx, attrNames[i])); 733 } 734 735 // Insert the operation info if it doesn't exist yet. 736 auto it = ctxImpl.operations.insert({name, OperationName::Impl(nullptr)}); 737 if (it.second) 738 it.first->second.name = StringAttr::get(ctx, name); 739 OperationName::Impl &impl = it.first->second; 740 741 if (impl.isRegistered()) { 742 llvm::errs() << "error: operation named '" << name 743 << "' is already registered.\n"; 744 abort(); 745 } 746 auto emplaced = ctxImpl.registeredOperations.try_emplace( 747 name, RegisteredOperationName(&impl)); 748 assert(emplaced.second && "operation name registration must be successful"); 749 750 // Add emplaced operation name to the sorted operations container. 751 RegisteredOperationName &value = emplaced.first->getValue(); 752 ctxImpl.sortedRegisteredOperations.insert( 753 llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value, 754 [](auto &lhs, auto &rhs) { 755 return lhs.getIdentifier().compare( 756 rhs.getIdentifier()); 757 }), 758 value); 759 760 // Update the registered info for this operation. 761 impl.dialect = &dialect; 762 impl.typeID = typeID; 763 impl.interfaceMap = std::move(interfaceMap); 764 impl.foldHookFn = std::move(foldHook); 765 impl.getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns); 766 impl.hasTraitFn = std::move(hasTrait); 767 impl.parseAssemblyFn = std::move(parseAssembly); 768 impl.printAssemblyFn = std::move(printAssembly); 769 impl.verifyInvariantsFn = std::move(verifyInvariants); 770 impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants); 771 impl.attributeNames = cachedAttrNames; 772 } 773 774 //===----------------------------------------------------------------------===// 775 // AbstractType 776 //===----------------------------------------------------------------------===// 777 778 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) { 779 const AbstractType *type = lookupMutable(typeID, context); 780 if (!type) 781 llvm::report_fatal_error( 782 "Trying to create a Type that was not registered in this MLIRContext."); 783 return *type; 784 } 785 786 AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { 787 auto &impl = context->getImpl(); 788 auto it = impl.registeredTypes.find(typeID); 789 if (it == impl.registeredTypes.end()) 790 return nullptr; 791 return it->second; 792 } 793 794 //===----------------------------------------------------------------------===// 795 // Type uniquing 796 //===----------------------------------------------------------------------===// 797 798 /// Returns the storage uniquer used for constructing type storage instances. 799 /// This should not be used directly. 800 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } 801 802 BFloat16Type BFloat16Type::get(MLIRContext *context) { 803 return context->getImpl().bf16Ty; 804 } 805 Float16Type Float16Type::get(MLIRContext *context) { 806 return context->getImpl().f16Ty; 807 } 808 Float32Type Float32Type::get(MLIRContext *context) { 809 return context->getImpl().f32Ty; 810 } 811 Float64Type Float64Type::get(MLIRContext *context) { 812 return context->getImpl().f64Ty; 813 } 814 Float80Type Float80Type::get(MLIRContext *context) { 815 return context->getImpl().f80Ty; 816 } 817 Float128Type Float128Type::get(MLIRContext *context) { 818 return context->getImpl().f128Ty; 819 } 820 821 /// Get an instance of the IndexType. 822 IndexType IndexType::get(MLIRContext *context) { 823 return context->getImpl().indexTy; 824 } 825 826 /// Return an existing integer type instance if one is cached within the 827 /// context. 828 static IntegerType 829 getCachedIntegerType(unsigned width, 830 IntegerType::SignednessSemantics signedness, 831 MLIRContext *context) { 832 if (signedness != IntegerType::Signless) 833 return IntegerType(); 834 835 switch (width) { 836 case 1: 837 return context->getImpl().int1Ty; 838 case 8: 839 return context->getImpl().int8Ty; 840 case 16: 841 return context->getImpl().int16Ty; 842 case 32: 843 return context->getImpl().int32Ty; 844 case 64: 845 return context->getImpl().int64Ty; 846 case 128: 847 return context->getImpl().int128Ty; 848 default: 849 return IntegerType(); 850 } 851 } 852 853 IntegerType IntegerType::get(MLIRContext *context, unsigned width, 854 IntegerType::SignednessSemantics signedness) { 855 if (auto cached = getCachedIntegerType(width, signedness, context)) 856 return cached; 857 return Base::get(context, width, signedness); 858 } 859 860 IntegerType 861 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError, 862 MLIRContext *context, unsigned width, 863 SignednessSemantics signedness) { 864 if (auto cached = getCachedIntegerType(width, signedness, context)) 865 return cached; 866 return Base::getChecked(emitError, context, width, signedness); 867 } 868 869 /// Get an instance of the NoneType. 870 NoneType NoneType::get(MLIRContext *context) { 871 if (NoneType cachedInst = context->getImpl().noneType) 872 return cachedInst; 873 // Note: May happen when initializing the singleton attributes of the builtin 874 // dialect. 875 return Base::get(context); 876 } 877 878 //===----------------------------------------------------------------------===// 879 // Attribute uniquing 880 //===----------------------------------------------------------------------===// 881 882 /// Returns the storage uniquer used for constructing attribute storage 883 /// instances. This should not be used directly. 884 StorageUniquer &MLIRContext::getAttributeUniquer() { 885 return getImpl().attributeUniquer; 886 } 887 888 /// Initialize the given attribute storage instance. 889 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, 890 MLIRContext *ctx, 891 TypeID attrID) { 892 storage->initializeAbstractAttribute(AbstractAttribute::lookup(attrID, ctx)); 893 894 // If the attribute did not provide a type, then default to NoneType. 895 if (!storage->getType()) 896 storage->setType(NoneType::get(ctx)); 897 } 898 899 BoolAttr BoolAttr::get(MLIRContext *context, bool value) { 900 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; 901 } 902 903 UnitAttr UnitAttr::get(MLIRContext *context) { 904 return context->getImpl().unitAttr; 905 } 906 907 UnknownLoc UnknownLoc::get(MLIRContext *context) { 908 return context->getImpl().unknownLocAttr; 909 } 910 911 /// Return empty dictionary. 912 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { 913 return context->getImpl().emptyDictionaryAttr; 914 } 915 916 void StringAttrStorage::initialize(MLIRContext *context) { 917 // Check for a dialect namespace prefix, if there isn't one we don't need to 918 // do any additional initialization. 919 auto dialectNamePair = value.split('.'); 920 if (dialectNamePair.first.empty() || dialectNamePair.second.empty()) 921 return; 922 923 // If one exists, we check to see if this dialect is loaded. If it is, we set 924 // the dialect now, if it isn't we record this storage for initialization 925 // later if the dialect ever gets loaded. 926 if ((referencedDialect = context->getLoadedDialect(dialectNamePair.first))) 927 return; 928 929 MLIRContextImpl &impl = context->getImpl(); 930 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex); 931 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(this); 932 } 933 934 /// Return an empty string. 935 StringAttr StringAttr::get(MLIRContext *context) { 936 return context->getImpl().emptyStringAttr; 937 } 938 939 //===----------------------------------------------------------------------===// 940 // AffineMap uniquing 941 //===----------------------------------------------------------------------===// 942 943 StorageUniquer &MLIRContext::getAffineUniquer() { 944 return getImpl().affineUniquer; 945 } 946 947 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, 948 ArrayRef<AffineExpr> results, 949 MLIRContext *context) { 950 auto &impl = context->getImpl(); 951 auto *storage = impl.affineUniquer.get<AffineMapStorage>( 952 [&](AffineMapStorage *storage) { storage->context = context; }, dimCount, 953 symbolCount, results); 954 return AffineMap(storage); 955 } 956 957 /// Check whether the arguments passed to the AffineMap::get() are consistent. 958 /// This method checks whether the highest index of dimensional identifier 959 /// present in result expressions is less than `dimCount` and the highest index 960 /// of symbolic identifier present in result expressions is less than 961 /// `symbolCount`. 962 LLVM_ATTRIBUTE_UNUSED static bool 963 willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, 964 ArrayRef<AffineExpr> results) { 965 int64_t maxDimPosition = -1; 966 int64_t maxSymbolPosition = -1; 967 getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition, 968 maxSymbolPosition); 969 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) { 970 LLVM_DEBUG( 971 llvm::dbgs() 972 << "maximum dimensional identifier position in result expression must " 973 "be less than `dimCount` and maximum symbolic identifier position " 974 "in result expression must be less than `symbolCount`\n"); 975 return false; 976 } 977 return true; 978 } 979 980 AffineMap AffineMap::get(MLIRContext *context) { 981 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); 982 } 983 984 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 985 MLIRContext *context) { 986 return getImpl(dimCount, symbolCount, /*results=*/{}, context); 987 } 988 989 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 990 AffineExpr result) { 991 assert(willBeValidAffineMap(dimCount, symbolCount, {result})); 992 return getImpl(dimCount, symbolCount, {result}, result.getContext()); 993 } 994 995 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 996 ArrayRef<AffineExpr> results, MLIRContext *context) { 997 assert(willBeValidAffineMap(dimCount, symbolCount, results)); 998 return getImpl(dimCount, symbolCount, results, context); 999 } 1000 1001 //===----------------------------------------------------------------------===// 1002 // Integer Sets: these are allocated into the bump pointer, and are immutable. 1003 // Unlike AffineMap's, these are uniqued only if they are small. 1004 //===----------------------------------------------------------------------===// 1005 1006 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, 1007 ArrayRef<AffineExpr> constraints, 1008 ArrayRef<bool> eqFlags) { 1009 // The number of constraints can't be zero. 1010 assert(!constraints.empty()); 1011 assert(constraints.size() == eqFlags.size()); 1012 1013 auto &impl = constraints[0].getContext()->getImpl(); 1014 auto *storage = impl.affineUniquer.get<IntegerSetStorage>( 1015 [](IntegerSetStorage *) {}, dimCount, symbolCount, constraints, eqFlags); 1016 return IntegerSet(storage); 1017 } 1018 1019 //===----------------------------------------------------------------------===// 1020 // StorageUniquerSupport 1021 //===----------------------------------------------------------------------===// 1022 1023 /// Utility method to generate a callback that can be used to generate a 1024 /// diagnostic when checking the construction invariants of a storage object. 1025 /// This is defined out-of-line to avoid the need to include Location.h. 1026 llvm::unique_function<InFlightDiagnostic()> 1027 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { 1028 return [ctx] { return emitError(UnknownLoc::get(ctx)); }; 1029 } 1030 llvm::unique_function<InFlightDiagnostic()> 1031 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { 1032 return [=] { return emitError(loc); }; 1033 } 1034