1 //===- Dialect.cpp - Dialect implementation -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Dialect.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "mlir/IR/DialectInterface.h" 13 #include "mlir/IR/MLIRContext.h" 14 #include "mlir/IR/Operation.h" 15 #include "llvm/ADT/MapVector.h" 16 #include "llvm/ADT/Twine.h" 17 #include "llvm/Support/ManagedStatic.h" 18 #include "llvm/Support/Regex.h" 19 20 using namespace mlir; 21 using namespace detail; 22 23 DialectAsmParser::~DialectAsmParser() {} 24 25 //===----------------------------------------------------------------------===// 26 // Dialect Registration (DEPRECATED) 27 //===----------------------------------------------------------------------===// 28 29 /// Registry for all dialect allocation functions. 30 static llvm::ManagedStatic<DialectRegistry> dialectRegistry; 31 DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; } 32 33 // Note: deprecated, will be removed soon. 34 static bool isGlobalDialectRegistryEnabledFlag = false; 35 void mlir::enableGlobalDialectRegistry(bool enable) { 36 isGlobalDialectRegistryEnabledFlag = enable; 37 } 38 bool mlir::isGlobalDialectRegistryEnabled() { 39 return isGlobalDialectRegistryEnabledFlag; 40 } 41 42 void mlir::registerAllDialects(MLIRContext *context) { 43 dialectRegistry->appendTo(context->getDialectRegistry()); 44 } 45 46 Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) { 47 auto it = registry.find(name.str()); 48 if (it == registry.end()) 49 return nullptr; 50 return it->second.second(context); 51 } 52 53 void DialectRegistry::insert(TypeID typeID, StringRef name, 54 DialectAllocatorFunction ctor) { 55 auto inserted = registry.insert( 56 std::make_pair(std::string(name), std::make_pair(typeID, ctor))); 57 if (!inserted.second && inserted.first->second.first != typeID) { 58 llvm::report_fatal_error( 59 "Trying to register different dialects for the same namespace: " + 60 name); 61 } 62 } 63 64 //===----------------------------------------------------------------------===// 65 // Dialect 66 //===----------------------------------------------------------------------===// 67 68 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id) 69 : name(name), dialectID(id), context(context) { 70 assert(isValidNamespace(name) && "invalid dialect namespace"); 71 } 72 73 Dialect::~Dialect() {} 74 75 /// Verify an attribute from this dialect on the argument at 'argIndex' for 76 /// the region at 'regionIndex' on the given operation. Returns failure if 77 /// the verification failed, success otherwise. This hook may optionally be 78 /// invoked from any operation containing a region. 79 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned, 80 NamedAttribute) { 81 return success(); 82 } 83 84 /// Verify an attribute from this dialect on the result at 'resultIndex' for 85 /// the region at 'regionIndex' on the given operation. Returns failure if 86 /// the verification failed, success otherwise. This hook may optionally be 87 /// invoked from any operation containing a region. 88 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, 89 unsigned, NamedAttribute) { 90 return success(); 91 } 92 93 /// Parse an attribute registered to this dialect. 94 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const { 95 parser.emitError(parser.getNameLoc()) 96 << "dialect '" << getNamespace() 97 << "' provides no attribute parsing hook"; 98 return Attribute(); 99 } 100 101 /// Parse a type registered to this dialect. 102 Type Dialect::parseType(DialectAsmParser &parser) const { 103 // If this dialect allows unknown types, then represent this with OpaqueType. 104 if (allowsUnknownTypes()) { 105 auto ns = Identifier::get(getNamespace(), getContext()); 106 return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext()); 107 } 108 109 parser.emitError(parser.getNameLoc()) 110 << "dialect '" << getNamespace() << "' provides no type parsing hook"; 111 return Type(); 112 } 113 114 /// Utility function that returns if the given string is a valid dialect 115 /// namespace. 116 bool Dialect::isValidNamespace(StringRef str) { 117 if (str.empty()) 118 return true; 119 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$"); 120 return dialectNameRegex.match(str); 121 } 122 123 /// Register a set of dialect interfaces with this dialect instance. 124 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { 125 auto it = registeredInterfaces.try_emplace(interface->getID(), 126 std::move(interface)); 127 (void)it; 128 assert(it.second && "interface kind has already been registered"); 129 } 130 131 //===----------------------------------------------------------------------===// 132 // Dialect Interface 133 //===----------------------------------------------------------------------===// 134 135 DialectInterface::~DialectInterface() {} 136 137 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase( 138 MLIRContext *ctx, TypeID interfaceKind) { 139 for (auto *dialect : ctx->getLoadedDialects()) { 140 if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) { 141 interfaces.insert(interface); 142 orderedInterfaces.push_back(interface); 143 } 144 } 145 } 146 147 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {} 148 149 /// Get the interface for the dialect of given operation, or null if one 150 /// is not registered. 151 const DialectInterface * 152 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const { 153 return getInterfaceFor(op->getDialect()); 154 } 155