1 //===- MLIRContext.cpp - MLIR Type Classes --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/MLIRContext.h" 10 #include "AffineExprDetail.h" 11 #include "AffineMapDetail.h" 12 #include "AttributeDetail.h" 13 #include "IntegerSetDetail.h" 14 #include "TypeDetail.h" 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinDialect.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/Dialect.h" 21 #include "mlir/IR/Identifier.h" 22 #include "mlir/IR/IntegerSet.h" 23 #include "mlir/IR/Location.h" 24 #include "mlir/IR/OpImplementation.h" 25 #include "mlir/IR/Types.h" 26 #include "mlir/Support/DebugAction.h" 27 #include "llvm/ADT/DenseMap.h" 28 #include "llvm/ADT/DenseSet.h" 29 #include "llvm/ADT/SetVector.h" 30 #include "llvm/ADT/SmallString.h" 31 #include "llvm/ADT/StringSet.h" 32 #include "llvm/ADT/Twine.h" 33 #include "llvm/Support/Allocator.h" 34 #include "llvm/Support/CommandLine.h" 35 #include "llvm/Support/Debug.h" 36 #include "llvm/Support/RWMutex.h" 37 #include "llvm/Support/ThreadPool.h" 38 #include "llvm/Support/raw_ostream.h" 39 #include <memory> 40 41 #define DEBUG_TYPE "mlircontext" 42 43 using namespace mlir; 44 using namespace mlir::detail; 45 46 using llvm::hash_combine; 47 using llvm::hash_combine_range; 48 49 //===----------------------------------------------------------------------===// 50 // MLIRContext CommandLine Options 51 //===----------------------------------------------------------------------===// 52 53 namespace { 54 /// This struct contains command line options that can be used to initialize 55 /// various bits of an MLIRContext. 
This uses a struct wrapper to avoid the need 56 /// for global command line options. 57 struct MLIRContextOptions { 58 llvm::cl::opt<bool> disableThreading{ 59 "mlir-disable-threading", 60 llvm::cl::desc("Disabling multi-threading within MLIR")}; 61 62 llvm::cl::opt<bool> printOpOnDiagnostic{ 63 "mlir-print-op-on-diagnostic", 64 llvm::cl::desc("When a diagnostic is emitted on an operation, also print " 65 "the operation as an attached note"), 66 llvm::cl::init(true)}; 67 68 llvm::cl::opt<bool> printStackTraceOnDiagnostic{ 69 "mlir-print-stacktrace-on-diagnostic", 70 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace " 71 "as an attached note")}; 72 }; 73 } // end anonymous namespace 74 75 static llvm::ManagedStatic<MLIRContextOptions> clOptions; 76 77 /// Register a set of useful command-line options that can be used to configure 78 /// various flags within the MLIRContext. These flags are used when constructing 79 /// an MLIR context for initialization. 80 void mlir::registerMLIRContextCLOptions() { 81 // Make sure that the options struct has been initialized. 82 *clOptions; 83 } 84 85 //===----------------------------------------------------------------------===// 86 // Locking Utilities 87 //===----------------------------------------------------------------------===// 88 89 namespace { 90 /// Utility reader lock that takes a runtime flag that specifies if we really 91 /// need to lock. 92 struct ScopedReaderLock { 93 ScopedReaderLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 94 : mutex(shouldLock ? &mutexParam : nullptr) { 95 if (mutex) 96 mutex->lock_shared(); 97 } 98 ~ScopedReaderLock() { 99 if (mutex) 100 mutex->unlock_shared(); 101 } 102 llvm::sys::SmartRWMutex<true> *mutex; 103 }; 104 /// Utility writer lock that takes a runtime flag that specifies if we really 105 /// need to lock. 106 struct ScopedWriterLock { 107 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock) 108 : mutex(shouldLock ? 
&mutexParam : nullptr) { 109 if (mutex) 110 mutex->lock(); 111 } 112 ~ScopedWriterLock() { 113 if (mutex) 114 mutex->unlock(); 115 } 116 llvm::sys::SmartRWMutex<true> *mutex; 117 }; 118 } // end anonymous namespace. 119 120 //===----------------------------------------------------------------------===// 121 // AffineMap and IntegerSet hashing 122 //===----------------------------------------------------------------------===// 123 124 /// A utility function to safely get or create a uniqued instance within the 125 /// given set container. 126 template <typename ValueT, typename DenseInfoT, typename KeyT, 127 typename ConstructorFn> 128 static ValueT safeGetOrCreate(DenseSet<ValueT, DenseInfoT> &container, 129 KeyT &&key, llvm::sys::SmartRWMutex<true> &mutex, 130 bool threadingIsEnabled, 131 ConstructorFn &&constructorFn) { 132 // Check for an existing instance in read-only mode. 133 if (threadingIsEnabled) { 134 llvm::sys::SmartScopedReader<true> instanceLock(mutex); 135 auto it = container.find_as(key); 136 if (it != container.end()) 137 return *it; 138 } 139 140 // Acquire a writer-lock so that we can safely create the new instance. 141 ScopedWriterLock instanceLock(mutex, threadingIsEnabled); 142 143 // Check for an existing instance again here, because another writer thread 144 // may have already created one. Otherwise, construct a new instance. 145 auto existing = container.insert_as(ValueT(), key); 146 if (existing.second) 147 return *existing.first = constructorFn(); 148 return *existing.first; 149 } 150 151 namespace { 152 struct AffineMapKeyInfo : DenseMapInfo<AffineMap> { 153 // Affine maps are uniqued based on their dim/symbol counts and affine 154 // expressions. 
155 using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>>; 156 using DenseMapInfo<AffineMap>::isEqual; 157 158 static unsigned getHashValue(const AffineMap &key) { 159 return getHashValue( 160 KeyTy(key.getNumDims(), key.getNumSymbols(), key.getResults())); 161 } 162 163 static unsigned getHashValue(KeyTy key) { 164 return hash_combine( 165 std::get<0>(key), std::get<1>(key), 166 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end())); 167 } 168 169 static bool isEqual(const KeyTy &lhs, AffineMap rhs) { 170 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 171 return false; 172 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 173 rhs.getResults()); 174 } 175 }; 176 177 struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> { 178 // Integer sets are uniqued based on their dim/symbol counts, affine 179 // expressions appearing in the LHS of constraints, and eqFlags. 180 using KeyTy = 181 std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>; 182 using DenseMapInfo<IntegerSet>::isEqual; 183 184 static unsigned getHashValue(const IntegerSet &key) { 185 return getHashValue(KeyTy(key.getNumDims(), key.getNumSymbols(), 186 key.getConstraints(), key.getEqFlags())); 187 } 188 189 static unsigned getHashValue(KeyTy key) { 190 return hash_combine( 191 std::get<0>(key), std::get<1>(key), 192 hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()), 193 hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end())); 194 } 195 196 static bool isEqual(const KeyTy &lhs, IntegerSet rhs) { 197 if (rhs == getEmptyKey() || rhs == getTombstoneKey()) 198 return false; 199 return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(), 200 rhs.getConstraints(), rhs.getEqFlags()); 201 } 202 }; 203 } // end anonymous namespace. 
//===----------------------------------------------------------------------===//
// MLIRContextImpl
//===----------------------------------------------------------------------===//

namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
///
/// NOTE: member declaration order matters. `identifierAllocator` must be
/// declared before `identifiers`, which is constructed with a reference to it.
class MLIRContextImpl {
public:
  //===--------------------------------------------------------------------===//
  // Debugging
  //===--------------------------------------------------------------------===//

  /// An action manager for use within the context.
  DebugActionManager debugActionManager;

  //===--------------------------------------------------------------------===//
  // Identifier uniquing
  //===--------------------------------------------------------------------===//

  // Identifier allocator and mutex for thread safety. The mutex is only taken
  // by Identifier::get when multi-threading is enabled.
  llvm::BumpPtrAllocator identifierAllocator;
  llvm::sys::SmartRWMutex<true> identifierMutex;

  //===--------------------------------------------------------------------===//
  // Diagnostics
  //===--------------------------------------------------------------------===//
  DiagnosticEngine diagEngine;

  //===--------------------------------------------------------------------===//
  // Options
  //===--------------------------------------------------------------------===//

  /// In most cases, creating an operation in an unregistered dialect is not
  /// desired and indicates a misconfiguration of the compiler. This option
  /// allows such use cases to be detected.
  bool allowUnregisteredDialects = false;

  /// Enable support for multi-threading within MLIR.
  bool threadingIsEnabled = true;

  /// Track if we are currently executing in a threaded execution environment
  /// (like the pass-manager): this is only a debugging feature to help reduce
  /// the chances of data races on some context APIs. It is a nesting counter
  /// (incremented/decremented by enter/exitMultiThreadedExecution) and only
  /// exists in builds without NDEBUG.
#ifndef NDEBUG
  std::atomic<int> multiThreadedExecutionContext{0};
#endif

  /// If the operation should be attached to diagnostics printed via the
  /// Operation::emit methods.
  bool printOpOnDiagnostic = true;

  /// If the current stack trace should be attached when emitting diagnostics.
  bool printStackTraceOnDiagnostic = false;

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//

  /// The thread pool to use when processing MLIR tasks in parallel.
  llvm::ThreadPool threadPool;

  /// This is a map, keyed by namespace, of the dialects that are created
  /// referring to this context. The MLIRContext owns the objects. A null value
  /// marks a dialect whose constructor is currently running (see
  /// MLIRContext::getOrLoadDialect).
  DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
  /// Dialects (and delayed interfaces) that may be loaded on demand.
  DialectRegistry dialectsRegistry;

  /// This is a mapping from operation name to AbstractOperation for registered
  /// operations.
  llvm::StringMap<AbstractOperation> registeredOperations;

  /// Identifiers are uniqued by string value; the map entries (and thus the
  /// string data) are allocated from `identifierAllocator`. Each value records
  /// the loaded dialect named by the identifier's prefix, or the context when
  /// no such dialect is loaded.
  llvm::StringMap<PointerUnion<Dialect *, MLIRContext *>,
                  llvm::BumpPtrAllocator &>
      identifiers;

  /// An allocator used for AbstractAttribute and AbstractType objects.
  llvm::BumpPtrAllocator abstractDialectSymbolAllocator;

  //===--------------------------------------------------------------------===//
  // Affine uniquing
  //===--------------------------------------------------------------------===//

  // Affine allocator and mutex for thread safety.
  llvm::BumpPtrAllocator affineAllocator;
  llvm::sys::SmartRWMutex<true> affineMutex;

  // Affine map uniquing.
  using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
  AffineMapSet affineMaps;

  // Integer set uniquing.
  using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
  IntegerSets integerSets;

  // Affine expression uniquing.
  StorageUniquer affineUniquer;

  //===--------------------------------------------------------------------===//
  // Type uniquing
  //===--------------------------------------------------------------------===//

  /// Registered abstract types, keyed by TypeID. The pointees live in
  /// `abstractDialectSymbolAllocator` and are destroyed in ~MLIRContextImpl.
  DenseMap<TypeID, AbstractType *> registeredTypes;
  StorageUniquer typeUniquer;

  /// Cached Type Instances. These are filled in by the MLIRContext constructor
  /// so the corresponding `get` methods can return them without locking.
  BFloat16Type bf16Ty;
  Float16Type f16Ty;
  Float32Type f32Ty;
  Float64Type f64Ty;
  Float80Type f80Ty;
  Float128Type f128Ty;
  IndexType indexTy;
  IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
  NoneType noneType;

  //===--------------------------------------------------------------------===//
  // Attribute uniquing
  //===--------------------------------------------------------------------===//

  /// Registered abstract attributes, keyed by TypeID. The pointees live in
  /// `abstractDialectSymbolAllocator` and are destroyed in ~MLIRContextImpl.
  DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
  StorageUniquer attributeUniquer;

  /// Cached Attribute Instances. These are filled in by the MLIRContext
  /// constructor.
  BoolAttr falseAttr, trueAttr;
  UnitAttr unitAttr;
  UnknownLoc unknownLocAttr;
  DictionaryAttr emptyDictionaryAttr;
  StringAttr emptyStringAttr;

public:
  MLIRContextImpl() : identifiers(identifierAllocator) {}
  ~MLIRContextImpl() {
    // The abstract type/attribute objects were placement-new'ed into
    // `abstractDialectSymbolAllocator`, so only their destructors are run
    // here; the allocator releases the memory itself.
    for (auto typeMapping : registeredTypes)
      typeMapping.second->~AbstractType();
    for (auto attrMapping : registeredAttributes)
      attrMapping.second->~AbstractAttribute();
  }
};
} // end namespace mlir

/// Create a context with an empty dialect registry.
MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {}

/// Create a context whose registry is pre-populated from `registry`. The
/// builtin dialect is always loaded, and a set of common types and attributes
/// is cached.
MLIRContext::MLIRContext(const DialectRegistry &registry)
    : impl(new MLIRContextImpl) {
  // Initialize values based on the command line flags if they were provided.
  if (clOptions.isConstructed()) {
    disableMultithreading(clOptions->disableThreading);
    printOpOnDiagnostic(clOptions->printOpOnDiagnostic);
    printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
  }

  // Pre-populate the registry.
  registry.appendTo(impl->dialectsRegistry);

  // Ensure the builtin dialect is always pre-loaded. This runs before the
  // cached types below are set, so getters used during dialect initialization
  // must cope with a not-yet-cached instance (see NoneType::get).
  getOrLoadDialect<BuiltinDialect>();

  // Initialize several common attributes and types to avoid the need to lock
  // the context when accessing them.

  //// Types.
  /// Floating-point Types.
  impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
  impl->f16Ty = TypeUniquer::get<Float16Type>(this);
  impl->f32Ty = TypeUniquer::get<Float32Type>(this);
  impl->f64Ty = TypeUniquer::get<Float64Type>(this);
  impl->f80Ty = TypeUniquer::get<Float80Type>(this);
  impl->f128Ty = TypeUniquer::get<Float128Type>(this);
  /// Index Type.
  impl->indexTy = TypeUniquer::get<IndexType>(this);
  /// Integer Types. Only signless widths are cached.
  impl->int1Ty = TypeUniquer::get<IntegerType>(this, 1, IntegerType::Signless);
  impl->int8Ty = TypeUniquer::get<IntegerType>(this, 8, IntegerType::Signless);
  impl->int16Ty =
      TypeUniquer::get<IntegerType>(this, 16, IntegerType::Signless);
  impl->int32Ty =
      TypeUniquer::get<IntegerType>(this, 32, IntegerType::Signless);
  impl->int64Ty =
      TypeUniquer::get<IntegerType>(this, 64, IntegerType::Signless);
  impl->int128Ty =
      TypeUniquer::get<IntegerType>(this, 128, IntegerType::Signless);
  /// None Type.
  impl->noneType = TypeUniquer::get<NoneType>(this);

  //// Attributes.
  //// Note: These must be registered after the types as they may generate one
  //// of the above types internally.
  /// Unknown Location Attribute.
  impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(this);
  /// Bool Attributes.
  impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false);
  impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true);
  /// Unit Attribute.
  impl->unitAttr = AttributeUniquer::get<UnitAttr>(this);
  /// The empty dictionary attribute.
  impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this);
  /// The empty string attribute.
  impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(this);

  // Register the affine storage objects with the uniquer.
  impl->affineUniquer
      .registerParametricStorageType<AffineBinaryOpExprStorage>();
  impl->affineUniquer
      .registerParametricStorageType<AffineConstantExprStorage>();
  impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
}

MLIRContext::~MLIRContext() {}

/// Copy the specified array of elements into memory managed by the provided
/// bump pointer allocator. This assumes the elements are all PODs.
418 template <typename T> 419 static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator, 420 ArrayRef<T> elements) { 421 auto result = allocator.Allocate<T>(elements.size()); 422 std::uninitialized_copy(elements.begin(), elements.end(), result); 423 return ArrayRef<T>(result, elements.size()); 424 } 425 426 //===----------------------------------------------------------------------===// 427 // Debugging 428 //===----------------------------------------------------------------------===// 429 430 DebugActionManager &MLIRContext::getDebugActionManager() { 431 return getImpl().debugActionManager; 432 } 433 434 //===----------------------------------------------------------------------===// 435 // Diagnostic Handlers 436 //===----------------------------------------------------------------------===// 437 438 /// Returns the diagnostic engine for this context. 439 DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } 440 441 //===----------------------------------------------------------------------===// 442 // Dialect and Operation Registration 443 //===----------------------------------------------------------------------===// 444 445 void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { 446 registry.appendTo(impl->dialectsRegistry); 447 448 // For the already loaded dialects, register the interfaces immediately. 449 for (const auto &kvp : impl->loadedDialects) 450 registry.registerDelayedInterfaces(kvp.second.get()); 451 } 452 453 const DialectRegistry &MLIRContext::getDialectRegistry() { 454 return impl->dialectsRegistry; 455 } 456 457 /// Return information about all registered IR dialects. 
458 std::vector<Dialect *> MLIRContext::getLoadedDialects() { 459 std::vector<Dialect *> result; 460 result.reserve(impl->loadedDialects.size()); 461 for (auto &dialect : impl->loadedDialects) 462 result.push_back(dialect.second.get()); 463 llvm::array_pod_sort(result.begin(), result.end(), 464 [](Dialect *const *lhs, Dialect *const *rhs) -> int { 465 return (*lhs)->getNamespace() < (*rhs)->getNamespace(); 466 }); 467 return result; 468 } 469 std::vector<StringRef> MLIRContext::getAvailableDialects() { 470 std::vector<StringRef> result; 471 for (auto dialect : impl->dialectsRegistry.getDialectNames()) 472 result.push_back(dialect); 473 return result; 474 } 475 476 /// Get a registered IR dialect with the given namespace. If none is found, 477 /// then return nullptr. 478 Dialect *MLIRContext::getLoadedDialect(StringRef name) { 479 // Dialects are sorted by name, so we can use binary search for lookup. 480 auto it = impl->loadedDialects.find(name); 481 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr; 482 } 483 484 Dialect *MLIRContext::getOrLoadDialect(StringRef name) { 485 Dialect *dialect = getLoadedDialect(name); 486 if (dialect) 487 return dialect; 488 DialectAllocatorFunctionRef allocator = 489 impl->dialectsRegistry.getDialectAllocator(name); 490 return allocator ? allocator(this) : nullptr; 491 } 492 493 /// Get a dialect for the provided namespace and TypeID: abort the program if a 494 /// dialect exist for this namespace with different TypeID. Returns a pointer to 495 /// the dialect owned by the context. 496 Dialect * 497 MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, 498 function_ref<std::unique_ptr<Dialect>()> ctor) { 499 auto &impl = getImpl(); 500 // Get the correct insertion position sorted by namespace. 
501 std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace]; 502 503 if (!dialect) { 504 LLVM_DEBUG(llvm::dbgs() 505 << "Load new dialect in Context " << dialectNamespace << "\n"); 506 #ifndef NDEBUG 507 if (impl.multiThreadedExecutionContext != 0) 508 llvm::report_fatal_error( 509 "Loading a dialect (" + dialectNamespace + 510 ") while in a multi-threaded execution context (maybe " 511 "the PassManager): this can indicate a " 512 "missing `dependentDialects` in a pass for example."); 513 #endif 514 dialect = ctor(); 515 assert(dialect && "dialect ctor failed"); 516 517 // Refresh all the identifiers dialect field, this catches cases where a 518 // dialect may be loaded after identifier prefixed with this dialect name 519 // were already created. 520 llvm::SmallString<32> dialectPrefix(dialectNamespace); 521 dialectPrefix.push_back('.'); 522 for (auto &identifierEntry : impl.identifiers) 523 if (identifierEntry.second.is<MLIRContext *>() && 524 identifierEntry.first().startswith(dialectPrefix)) 525 identifierEntry.second = dialect.get(); 526 527 // Actually register the interfaces with delayed registration. 528 impl.dialectsRegistry.registerDelayedInterfaces(dialect.get()); 529 return dialect.get(); 530 } 531 532 // Abort if dialect with namespace has already been registered. 533 if (dialect->getTypeID() != dialectID) 534 llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace + 535 "' has already been registered"); 536 537 return dialect.get(); 538 } 539 540 void MLIRContext::loadAllAvailableDialects() { 541 for (StringRef name : getAvailableDialects()) 542 getOrLoadDialect(name); 543 } 544 545 llvm::hash_code MLIRContext::getRegistryHash() { 546 llvm::hash_code hash(0); 547 // Factor in number of loaded dialects, attributes, operations, types. 
548 hash = llvm::hash_combine(hash, impl->loadedDialects.size()); 549 hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); 550 hash = llvm::hash_combine(hash, impl->registeredOperations.size()); 551 hash = llvm::hash_combine(hash, impl->registeredTypes.size()); 552 return hash; 553 } 554 555 bool MLIRContext::allowsUnregisteredDialects() { 556 return impl->allowUnregisteredDialects; 557 } 558 559 void MLIRContext::allowUnregisteredDialects(bool allowing) { 560 impl->allowUnregisteredDialects = allowing; 561 } 562 563 /// Return true if multi-threading is enabled by the context. 564 bool MLIRContext::isMultithreadingEnabled() { 565 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); 566 } 567 568 /// Set the flag specifying if multi-threading is disabled by the context. 569 void MLIRContext::disableMultithreading(bool disable) { 570 impl->threadingIsEnabled = !disable; 571 572 // Update the threading mode for each of the uniquers. 573 impl->affineUniquer.disableMultithreading(disable); 574 impl->attributeUniquer.disableMultithreading(disable); 575 impl->typeUniquer.disableMultithreading(disable); 576 } 577 578 llvm::ThreadPool &MLIRContext::getThreadPool() { 579 assert(isMultithreadingEnabled() && 580 "expected multi-threading to be enabled within the context"); 581 return impl->threadPool; 582 } 583 584 void MLIRContext::enterMultiThreadedExecution() { 585 #ifndef NDEBUG 586 ++impl->multiThreadedExecutionContext; 587 #endif 588 } 589 void MLIRContext::exitMultiThreadedExecution() { 590 #ifndef NDEBUG 591 --impl->multiThreadedExecutionContext; 592 #endif 593 } 594 595 /// Return true if we should attach the operation to diagnostics emitted via 596 /// Operation::emit. 597 bool MLIRContext::shouldPrintOpOnDiagnostic() { 598 return impl->printOpOnDiagnostic; 599 } 600 601 /// Set the flag specifying if we should attach the operation to diagnostics 602 /// emitted via Operation::emit. 
603 void MLIRContext::printOpOnDiagnostic(bool enable) { 604 impl->printOpOnDiagnostic = enable; 605 } 606 607 /// Return true if we should attach the current stacktrace to diagnostics when 608 /// emitted. 609 bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { 610 return impl->printStackTraceOnDiagnostic; 611 } 612 613 /// Set the flag specifying if we should attach the current stacktrace when 614 /// emitting diagnostics. 615 void MLIRContext::printStackTraceOnDiagnostic(bool enable) { 616 impl->printStackTraceOnDiagnostic = enable; 617 } 618 619 /// Return information about all registered operations. This isn't very 620 /// efficient, typically you should ask the operations about their properties 621 /// directly. 622 std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() { 623 // We just have the operations in a non-deterministic hash table order. Dump 624 // into a temporary array, then sort it by operation name to get a stable 625 // ordering. 626 llvm::StringMap<AbstractOperation> ®isteredOps = 627 impl->registeredOperations; 628 629 std::vector<AbstractOperation *> result; 630 result.reserve(registeredOps.size()); 631 for (auto &elt : registeredOps) 632 result.push_back(&elt.second); 633 llvm::array_pod_sort( 634 result.begin(), result.end(), 635 [](AbstractOperation *const *lhs, AbstractOperation *const *rhs) { 636 return (*lhs)->name.compare((*rhs)->name); 637 }); 638 639 return result; 640 } 641 642 bool MLIRContext::isOperationRegistered(StringRef name) { 643 return impl->registeredOperations.count(name); 644 } 645 646 void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { 647 auto &impl = context->getImpl(); 648 assert(impl.multiThreadedExecutionContext == 0 && 649 "Registering a new type kind while in a multi-threaded execution " 650 "context"); 651 auto *newInfo = 652 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>()) 653 AbstractType(std::move(typeInfo)); 654 if (!impl.registeredTypes.insert({typeID, 
newInfo}).second) 655 llvm::report_fatal_error("Dialect Type already registered."); 656 } 657 658 void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { 659 auto &impl = context->getImpl(); 660 assert(impl.multiThreadedExecutionContext == 0 && 661 "Registering a new attribute kind while in a multi-threaded execution " 662 "context"); 663 auto *newInfo = 664 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>()) 665 AbstractAttribute(std::move(attrInfo)); 666 if (!impl.registeredAttributes.insert({typeID, newInfo}).second) 667 llvm::report_fatal_error("Dialect Attribute already registered."); 668 } 669 670 //===----------------------------------------------------------------------===// 671 // AbstractAttribute 672 //===----------------------------------------------------------------------===// 673 674 /// Get the dialect that registered the attribute with the provided typeid. 675 const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID, 676 MLIRContext *context) { 677 const AbstractAttribute *abstract = lookupMutable(typeID, context); 678 if (!abstract) 679 llvm::report_fatal_error("Trying to create an Attribute that was not " 680 "registered in this MLIRContext."); 681 return *abstract; 682 } 683 684 AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, 685 MLIRContext *context) { 686 auto &impl = context->getImpl(); 687 auto it = impl.registeredAttributes.find(typeID); 688 if (it == impl.registeredAttributes.end()) 689 return nullptr; 690 return it->second; 691 } 692 693 //===----------------------------------------------------------------------===// 694 // AbstractOperation 695 //===----------------------------------------------------------------------===// 696 697 ParseResult AbstractOperation::parseAssembly(OpAsmParser &parser, 698 OperationState &result) const { 699 return parseAssemblyFn(parser, result); 700 } 701 702 /// Look up the specified operation in the operation set and return a pointer 703 /// 
/// to it if present. Otherwise, return a null pointer.
AbstractOperation *AbstractOperation::lookupMutable(StringRef opName,
                                                    MLIRContext *context) {
  auto &impl = context->getImpl();
  auto it = impl.registeredOperations.find(opName);
  if (it != impl.registeredOperations.end())
    return &it->second;
  return nullptr;
}

/// Register a new operation kind `name`, belonging to `dialect`, in the
/// dialect's context. The names in `attrNames` are converted to Identifiers
/// once, here, and the resulting array is stored in the AbstractOperation.
/// Registering an operation name twice prints an error and aborts.
void AbstractOperation::insert(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<StringRef> attrNames) {
  MLIRContext *ctx = dialect.getContext();
  auto &impl = ctx->getImpl();
  assert(impl.multiThreadedExecutionContext == 0 &&
         "Registering a new operation kind while in a multi-threaded execution "
         "context");

  // Register the attribute names of this operation. The Identifier array is
  // carved out of the identifier allocator and each element is constructed in
  // place.
  MutableArrayRef<Identifier> cachedAttrNames;
  if (!attrNames.empty()) {
    cachedAttrNames = MutableArrayRef<Identifier>(
        impl.identifierAllocator.Allocate<Identifier>(attrNames.size()),
        attrNames.size());
    for (unsigned i : llvm::seq<unsigned>(0, attrNames.size()))
      new (&cachedAttrNames[i]) Identifier(Identifier::get(attrNames[i], ctx));
  }

  // Register the information for this operation.
  AbstractOperation opInfo(
      name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly),
      std::move(verifyInvariants), std::move(foldHook),
      std::move(getCanonicalizationPatterns), std::move(interfaceMap),
      std::move(hasTrait), cachedAttrNames);
  if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) {
    llvm::errs() << "error: operation named '" << name
                 << "' is already registered.\n";
    abort();
  }
}

/// Construct the description of an operation. The operation name is uniqued
/// as an Identifier in the dialect's context; all hooks are moved in.
AbstractOperation::AbstractOperation(
    StringRef name, Dialect &dialect, TypeID typeID,
    ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
    GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
    detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
    ArrayRef<Identifier> attrNames)
    : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
      typeID(typeID), interfaceMap(std::move(interfaceMap)),
      foldHookFn(std::move(foldHook)),
      getCanonicalizationPatternsFn(std::move(getCanonicalizationPatterns)),
      hasTraitFn(std::move(hasTrait)),
      parseAssemblyFn(std::move(parseAssembly)),
      printAssemblyFn(std::move(printAssembly)),
      verifyInvariantsFn(std::move(verifyInvariants)),
      attributeNames(attrNames) {}

//===----------------------------------------------------------------------===//
// AbstractType
//===----------------------------------------------------------------------===//

/// Get the AbstractType registered for `typeID` in `context`. An unregistered
/// TypeID is a fatal error.
const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
  const AbstractType *type = lookupMutable(typeID, context);
  if (!type)
    llvm::report_fatal_error(
        "Trying to create a Type that was not registered in this MLIRContext.");
  return *type;
}

/// Look up the AbstractType registered for `typeID` in `context`, returning
/// nullptr if there is none.
AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
  auto &impl = context->getImpl();
  auto it = impl.registeredTypes.find(typeID);
  if (it == impl.registeredTypes.end())
    return nullptr;
  return it->second;
}

//===----------------------------------------------------------------------===//
// Identifier uniquing
//===----------------------------------------------------------------------===//

/// Return an identifier for the specified string.
Identifier Identifier::get(const Twine &string, MLIRContext *context) {
  SmallString<32> tempStr;
  StringRef str = string.toStringRef(tempStr);

  // Validate the string before it is looked up or inserted. These are
  // asserts, so they are only checked in builds with assertions enabled.
  assert(!str.empty() && "Cannot create an empty identifier");
  assert(str.find('\0') == StringRef::npos &&
         "Cannot create an identifier with a nul character");

  // Compute the value stored with a new identifier. It is the loaded dialect
  // named by the text before the first '.' (the whole string if there is no
  // '.'), or the context when no such dialect is loaded.
  auto getDialectOrContext = [&]() {
    PointerUnion<Dialect *, MLIRContext *> dialectOrContext = context;
    auto dialectNamePair = str.split('.');
    if (!dialectNamePair.first.empty())
      if (Dialect *dialect = context->getLoadedDialect(dialectNamePair.first))
        dialectOrContext = dialect;
    return dialectOrContext;
  };

  auto &impl = context->getImpl();
  // Without multi-threading no locking is needed. The dialect is resolved
  // only when the string is inserted for the first time.
  if (!context->isMultithreadingEnabled()) {
    auto insertedIt = impl.identifiers.insert({str, nullptr});
    if (insertedIt.second)
      insertedIt.first->second = getDialectOrContext();
    return Identifier(&*insertedIt.first);
  }

  // Check for an existing identifier in read-only mode.
  {
    llvm::sys::SmartScopedReader<true> contextLock(impl.identifierMutex);
    auto it = impl.identifiers.find(str);
    if (it != impl.identifiers.end())
      return Identifier(&*it);
  }

  // Acquire a writer-lock so that we can safely create the new instance.
  // Another thread may have inserted the same string after the reader lock
  // was released; StringMap::insert keeps the existing entry in that case.
  llvm::sys::SmartScopedWriter<true> contextLock(impl.identifierMutex);
  auto it = impl.identifiers.insert({str, getDialectOrContext()}).first;
  return Identifier(&*it);
}

/// Return the dialect recorded for this identifier, or null if the identifier
/// is associated with the context instead.
Dialect *Identifier::getDialect() {
  return entry->second.dyn_cast<Dialect *>();
}

/// Return the context this identifier belongs to, either through its dialect
/// or directly.
MLIRContext *Identifier::getContext() {
  if (Dialect *dialect = getDialect())
    return dialect->getContext();
  return entry->second.get<MLIRContext *>();
}

//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//

/// Returns the storage uniquer used for constructing type storage instances.
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }

// The floating-point getters return the instances cached by the MLIRContext
// constructor.
BFloat16Type BFloat16Type::get(MLIRContext *context) {
  return context->getImpl().bf16Ty;
}
Float16Type Float16Type::get(MLIRContext *context) {
  return context->getImpl().f16Ty;
}
Float32Type Float32Type::get(MLIRContext *context) {
  return context->getImpl().f32Ty;
}
Float64Type Float64Type::get(MLIRContext *context) {
  return context->getImpl().f64Ty;
}
Float80Type Float80Type::get(MLIRContext *context) {
  return context->getImpl().f80Ty;
}
Float128Type Float128Type::get(MLIRContext *context) {
  return context->getImpl().f128Ty;
}

/// Get an instance of the IndexType.
IndexType IndexType::get(MLIRContext *context) {
  return context->getImpl().indexTy;
}

/// Return an existing integer type instance if one is cached within the
/// context.
877 static IntegerType 878 getCachedIntegerType(unsigned width, 879 IntegerType::SignednessSemantics signedness, 880 MLIRContext *context) { 881 if (signedness != IntegerType::Signless) 882 return IntegerType(); 883 884 switch (width) { 885 case 1: 886 return context->getImpl().int1Ty; 887 case 8: 888 return context->getImpl().int8Ty; 889 case 16: 890 return context->getImpl().int16Ty; 891 case 32: 892 return context->getImpl().int32Ty; 893 case 64: 894 return context->getImpl().int64Ty; 895 case 128: 896 return context->getImpl().int128Ty; 897 default: 898 return IntegerType(); 899 } 900 } 901 902 IntegerType IntegerType::get(MLIRContext *context, unsigned width, 903 IntegerType::SignednessSemantics signedness) { 904 if (auto cached = getCachedIntegerType(width, signedness, context)) 905 return cached; 906 return Base::get(context, width, signedness); 907 } 908 909 IntegerType 910 IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError, 911 MLIRContext *context, unsigned width, 912 SignednessSemantics signedness) { 913 if (auto cached = getCachedIntegerType(width, signedness, context)) 914 return cached; 915 return Base::getChecked(emitError, context, width, signedness); 916 } 917 918 /// Get an instance of the NoneType. 919 NoneType NoneType::get(MLIRContext *context) { 920 if (NoneType cachedInst = context->getImpl().noneType) 921 return cachedInst; 922 // Note: May happen when initializing the singleton attributes of the builtin 923 // dialect. 924 return Base::get(context); 925 } 926 927 //===----------------------------------------------------------------------===// 928 // Attribute uniquing 929 //===----------------------------------------------------------------------===// 930 931 /// Returns the storage uniquer used for constructing attribute storage 932 /// instances. This should not be used directly. 
933 StorageUniquer &MLIRContext::getAttributeUniquer() { 934 return getImpl().attributeUniquer; 935 } 936 937 /// Initialize the given attribute storage instance. 938 void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, 939 MLIRContext *ctx, 940 TypeID attrID) { 941 storage->initialize(AbstractAttribute::lookup(attrID, ctx)); 942 943 // If the attribute did not provide a type, then default to NoneType. 944 if (!storage->getType()) 945 storage->setType(NoneType::get(ctx)); 946 } 947 948 BoolAttr BoolAttr::get(MLIRContext *context, bool value) { 949 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; 950 } 951 952 UnitAttr UnitAttr::get(MLIRContext *context) { 953 return context->getImpl().unitAttr; 954 } 955 956 UnknownLoc UnknownLoc::get(MLIRContext *context) { 957 return context->getImpl().unknownLocAttr; 958 } 959 960 /// Return empty dictionary. 961 DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { 962 return context->getImpl().emptyDictionaryAttr; 963 } 964 965 /// Return an empty string. 966 StringAttr StringAttr::get(MLIRContext *context) { 967 return context->getImpl().emptyStringAttr; 968 } 969 970 //===----------------------------------------------------------------------===// 971 // AffineMap uniquing 972 //===----------------------------------------------------------------------===// 973 974 StorageUniquer &MLIRContext::getAffineUniquer() { 975 return getImpl().affineUniquer; 976 } 977 978 AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, 979 ArrayRef<AffineExpr> results, 980 MLIRContext *context) { 981 auto &impl = context->getImpl(); 982 auto key = std::make_tuple(dimCount, symbolCount, results); 983 984 // Safely get or create an AffineMap instance. 
985 return safeGetOrCreate( 986 impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] { 987 auto *res = impl.affineAllocator.Allocate<detail::AffineMapStorage>(); 988 989 // Copy the results into the bump pointer. 990 results = copyArrayRefInto(impl.affineAllocator, results); 991 992 // Initialize the memory using placement new. 993 new (res) 994 detail::AffineMapStorage{dimCount, symbolCount, results, context}; 995 return AffineMap(res); 996 }); 997 } 998 999 AffineMap AffineMap::get(MLIRContext *context) { 1000 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context); 1001 } 1002 1003 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1004 MLIRContext *context) { 1005 return getImpl(dimCount, symbolCount, /*results=*/{}, context); 1006 } 1007 1008 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1009 AffineExpr result) { 1010 return getImpl(dimCount, symbolCount, {result}, result.getContext()); 1011 } 1012 1013 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount, 1014 ArrayRef<AffineExpr> results, MLIRContext *context) { 1015 return getImpl(dimCount, symbolCount, results, context); 1016 } 1017 1018 //===----------------------------------------------------------------------===// 1019 // Integer Sets: these are allocated into the bump pointer, and are immutable. 1020 // Unlike AffineMap's, these are uniqued only if they are small. 1021 //===----------------------------------------------------------------------===// 1022 1023 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, 1024 ArrayRef<AffineExpr> constraints, 1025 ArrayRef<bool> eqFlags) { 1026 // The number of constraints can't be zero. 1027 assert(!constraints.empty()); 1028 assert(constraints.size() == eqFlags.size()); 1029 1030 auto &impl = constraints[0].getContext()->getImpl(); 1031 1032 // A utility function to construct a new IntegerSetStorage instance. 
1033 auto constructorFn = [&] { 1034 auto *res = impl.affineAllocator.Allocate<detail::IntegerSetStorage>(); 1035 1036 // Copy the results and equality flags into the bump pointer. 1037 constraints = copyArrayRefInto(impl.affineAllocator, constraints); 1038 eqFlags = copyArrayRefInto(impl.affineAllocator, eqFlags); 1039 1040 // Initialize the memory using placement new. 1041 new (res) 1042 detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags}; 1043 return IntegerSet(res); 1044 }; 1045 1046 // If this instance is uniqued, then we handle it separately so that multiple 1047 // threads may simultaneously access existing instances. 1048 if (constraints.size() < IntegerSet::kUniquingThreshold) { 1049 auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags); 1050 return safeGetOrCreate(impl.integerSets, key, impl.affineMutex, 1051 impl.threadingIsEnabled, constructorFn); 1052 } 1053 1054 // Otherwise, acquire a writer-lock so that we can safely create the new 1055 // instance. 1056 ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled); 1057 return constructorFn(); 1058 } 1059 1060 //===----------------------------------------------------------------------===// 1061 // StorageUniquerSupport 1062 //===----------------------------------------------------------------------===// 1063 1064 /// Utility method to generate a callback that can be used to generate a 1065 /// diagnostic when checking the construction invariants of a storage object. 1066 /// This is defined out-of-line to avoid the need to include Location.h. 1067 llvm::unique_function<InFlightDiagnostic()> 1068 mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { 1069 return [ctx] { return emitError(UnknownLoc::get(ctx)); }; 1070 } 1071 llvm::unique_function<InFlightDiagnostic()> 1072 mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { 1073 return [=] { return emitError(loc); }; 1074 } 1075