1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines functionality for registring and extending dialects. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_DIALECTREGISTRY_H 14 #define MLIR_IR_DIALECTREGISTRY_H 15 16 #include "mlir/IR/MLIRContext.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/StringRef.h" 20 21 #include <map> 22 #include <tuple> 23 24 namespace mlir { 25 class Dialect; 26 27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>; 28 using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>; 29 30 //===----------------------------------------------------------------------===// 31 // DialectExtension 32 //===----------------------------------------------------------------------===// 33 34 /// This class represents an opaque dialect extension. It contains a set of 35 /// required dialects and an application function. The required dialects control 36 /// when the extension is applied, i.e. the extension is applied when all 37 /// required dialects are loaded. The application function can be used to attach 38 /// additional functionality to attributes, dialects, operations, types, etc., 39 /// and may also load additional necessary dialects. 40 class DialectExtensionBase { 41 public: 42 virtual ~DialectExtensionBase(); 43 44 /// Return the dialects that our required by this extension to be loaded 45 /// before applying. getRequiredDialects()46 ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; } 47 48 /// Apply this extension to the given context and the required dialects. 49 virtual void apply(MLIRContext *context, 50 MutableArrayRef<Dialect *> dialects) const = 0; 51 52 /// Return a copy of this extension. 53 virtual std::unique_ptr<DialectExtensionBase> clone() const = 0; 54 55 protected: 56 /// Initialize the extension with a set of required dialects. Note that there 57 /// should always be at least one affected dialect. DialectExtensionBase(ArrayRef<StringRef> dialectNames)58 DialectExtensionBase(ArrayRef<StringRef> dialectNames) 59 : dialectNames(dialectNames.begin(), dialectNames.end()) { 60 assert(!dialectNames.empty() && "expected at least one affected dialect"); 61 } 62 63 private: 64 /// The names of the dialects affected by this extension. 65 SmallVector<StringRef> dialectNames; 66 }; 67 68 /// This class represents a dialect extension anchored on the given set of 69 /// dialects. When all of the specified dialects have been loaded, the 70 /// application function of this extension will be executed. 71 template <typename DerivedT, typename... DialectsT> 72 class DialectExtension : public DialectExtensionBase { 73 public: 74 /// Applies this extension to the given context and set of required dialects. 75 virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0; 76 77 /// Return a copy of this extension. clone()78 std::unique_ptr<DialectExtensionBase> clone() const final { 79 return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this)); 80 } 81 82 protected: DialectExtension()83 DialectExtension() 84 : DialectExtensionBase( 85 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {} 86 87 /// Override the base apply method to allow providing the exact dialect types. apply(MLIRContext * context,MutableArrayRef<Dialect * > dialects)88 void apply(MLIRContext *context, 89 MutableArrayRef<Dialect *> dialects) const final { 90 unsigned dialectIdx = 0; 91 auto derivedDialects = std::tuple<DialectsT *...>{ 92 static_cast<DialectsT *>(dialects[dialectIdx++])...}; 93 llvm::apply_tuple( 94 [&](DialectsT *...dialect) { apply(context, dialect...); }, 95 derivedDialects); 96 } 97 }; 98 99 //===----------------------------------------------------------------------===// 100 // DialectRegistry 101 //===----------------------------------------------------------------------===// 102 103 /// The DialectRegistry maps a dialect namespace to a constructor for the 104 /// matching dialect. This allows for decoupling the list of dialects 105 /// "available" from the dialects loaded in the Context. The parser in 106 /// particular will lazily load dialects in the Context as operations are 107 /// encountered. 108 class DialectRegistry { 109 using MapTy = 110 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>; 111 112 public: 113 explicit DialectRegistry(); 114 115 template <typename ConcreteDialect> insert()116 void insert() { 117 insert(TypeID::get<ConcreteDialect>(), 118 ConcreteDialect::getDialectNamespace(), 119 static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) { 120 // Just allocate the dialect, the context 121 // takes ownership of it. 122 return ctx->getOrLoadDialect<ConcreteDialect>(); 123 }))); 124 } 125 126 template <typename ConcreteDialect, typename OtherDialect, 127 typename... MoreDialects> insert()128 void insert() { 129 insert<ConcreteDialect>(); 130 insert<OtherDialect, MoreDialects...>(); 131 } 132 133 /// Add a new dialect constructor to the registry. The constructor must be 134 /// calling MLIRContext::getOrLoadDialect in order for the context to take 135 /// ownership of the dialect and for delayed interface registration to happen. 136 void insert(TypeID typeID, StringRef name, 137 const DialectAllocatorFunction &ctor); 138 139 /// Return an allocation function for constructing the dialect identified by 140 /// its namespace, or nullptr if the namespace is not in this registry. 141 DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const; 142 143 // Register all dialects available in the current registry with the registry 144 // in the provided context. appendTo(DialectRegistry & destination)145 void appendTo(DialectRegistry &destination) const { 146 for (const auto &nameAndRegistrationIt : registry) 147 destination.insert(nameAndRegistrationIt.second.first, 148 nameAndRegistrationIt.first, 149 nameAndRegistrationIt.second.second); 150 // Merge the extensions. 151 for (const auto &extension : extensions) 152 destination.extensions.push_back(extension->clone()); 153 } 154 155 /// Return the names of dialects known to this registry. getDialectNames()156 auto getDialectNames() const { 157 return llvm::map_range( 158 registry, 159 [](const MapTy::value_type &item) -> StringRef { return item.first; }); 160 } 161 162 /// Apply any held extensions that require the given dialect. Users are not 163 /// expected to call this directly. 164 void applyExtensions(Dialect *dialect) const; 165 166 /// Apply any applicable extensions to the given context. Users are not 167 /// expected to call this directly. 168 void applyExtensions(MLIRContext *ctx) const; 169 170 /// Add the given extension to the registry. addExtension(std::unique_ptr<DialectExtensionBase> extension)171 void addExtension(std::unique_ptr<DialectExtensionBase> extension) { 172 extensions.push_back(std::move(extension)); 173 } 174 175 /// Add the given extensions to the registry. 176 template <typename... ExtensionsT> addExtensions()177 void addExtensions() { 178 (void)std::initializer_list<int>{ 179 (addExtension(std::make_unique<ExtensionsT>()), 0)...}; 180 } 181 182 /// Add an extension function that requires the given dialects. 183 /// Note: This bare functor overload is provided in addition to the 184 /// std::function variant to enable dialect type deduction, e.g.: 185 /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... }) 186 /// 187 /// is equivalent to: 188 /// registry.addExtension<MyDialect>( 189 /// [](MLIRContext *ctx, MyDialect *dialect){ ... } 190 /// ) 191 template <typename... DialectsT> addExtension(void (* extensionFn)(MLIRContext *,DialectsT * ...))192 void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { 193 addExtension<DialectsT...>( 194 std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn)); 195 } 196 template <typename... DialectsT> 197 void addExtension(std::function<void (MLIRContext *,DialectsT * ...)> extensionFn)198 addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) { 199 using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>; 200 201 struct Extension : public DialectExtension<Extension, DialectsT...> { 202 Extension(const Extension &) = default; 203 Extension(ExtensionFnT extensionFn) 204 : extensionFn(std::move(extensionFn)) {} 205 ~Extension() override = default; 206 207 void apply(MLIRContext *context, DialectsT *...dialects) const final { 208 extensionFn(context, dialects...); 209 } 210 ExtensionFnT extensionFn; 211 }; 212 addExtension(std::make_unique<Extension>(std::move(extensionFn))); 213 } 214 215 /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' 216 /// contains all of the components of this registry. 217 bool isSubsetOf(const DialectRegistry &rhs) const; 218 219 private: 220 MapTy registry; 221 std::vector<std::unique_ptr<DialectExtensionBase>> extensions; 222 }; 223 224 } // namespace mlir 225 226 #endif // MLIR_IR_DIALECTREGISTRY_H 227