1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringRef.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
28 using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
29 
30 //===----------------------------------------------------------------------===//
31 // DialectExtension
32 //===----------------------------------------------------------------------===//
33 
34 /// This class represents an opaque dialect extension. It contains a set of
35 /// required dialects and an application function. The required dialects control
36 /// when the extension is applied, i.e. the extension is applied when all
37 /// required dialects are loaded. The application function can be used to attach
38 /// additional functionality to attributes, dialects, operations, types, etc.,
39 /// and may also load additional necessary dialects.
40 class DialectExtensionBase {
41 public:
42   virtual ~DialectExtensionBase();
43 
44   /// Return the dialects that our required by this extension to be loaded
45   /// before applying.
getRequiredDialects()46   ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
47 
48   /// Apply this extension to the given context and the required dialects.
49   virtual void apply(MLIRContext *context,
50                      MutableArrayRef<Dialect *> dialects) const = 0;
51 
52   /// Return a copy of this extension.
53   virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
54 
55 protected:
56   /// Initialize the extension with a set of required dialects. Note that there
57   /// should always be at least one affected dialect.
DialectExtensionBase(ArrayRef<StringRef> dialectNames)58   DialectExtensionBase(ArrayRef<StringRef> dialectNames)
59       : dialectNames(dialectNames.begin(), dialectNames.end()) {
60     assert(!dialectNames.empty() && "expected at least one affected dialect");
61   }
62 
63 private:
64   /// The names of the dialects affected by this extension.
65   SmallVector<StringRef> dialectNames;
66 };
67 
68 /// This class represents a dialect extension anchored on the given set of
69 /// dialects. When all of the specified dialects have been loaded, the
70 /// application function of this extension will be executed.
71 template <typename DerivedT, typename... DialectsT>
72 class DialectExtension : public DialectExtensionBase {
73 public:
74   /// Applies this extension to the given context and set of required dialects.
75   virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
76 
77   /// Return a copy of this extension.
clone()78   std::unique_ptr<DialectExtensionBase> clone() const final {
79     return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
80   }
81 
82 protected:
DialectExtension()83   DialectExtension()
84       : DialectExtensionBase(
85             ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
86 
87   /// Override the base apply method to allow providing the exact dialect types.
apply(MLIRContext * context,MutableArrayRef<Dialect * > dialects)88   void apply(MLIRContext *context,
89              MutableArrayRef<Dialect *> dialects) const final {
90     unsigned dialectIdx = 0;
91     auto derivedDialects = std::tuple<DialectsT *...>{
92         static_cast<DialectsT *>(dialects[dialectIdx++])...};
93     llvm::apply_tuple(
94         [&](DialectsT *...dialect) { apply(context, dialect...); },
95         derivedDialects);
96   }
97 };
98 
99 //===----------------------------------------------------------------------===//
100 // DialectRegistry
101 //===----------------------------------------------------------------------===//
102 
103 /// The DialectRegistry maps a dialect namespace to a constructor for the
104 /// matching dialect. This allows for decoupling the list of dialects
105 /// "available" from the dialects loaded in the Context. The parser in
106 /// particular will lazily load dialects in the Context as operations are
107 /// encountered.
108 class DialectRegistry {
109   using MapTy =
110       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
111 
112 public:
113   explicit DialectRegistry();
114 
115   template <typename ConcreteDialect>
insert()116   void insert() {
117     insert(TypeID::get<ConcreteDialect>(),
118            ConcreteDialect::getDialectNamespace(),
119            static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
120              // Just allocate the dialect, the context
121              // takes ownership of it.
122              return ctx->getOrLoadDialect<ConcreteDialect>();
123            })));
124   }
125 
126   template <typename ConcreteDialect, typename OtherDialect,
127             typename... MoreDialects>
insert()128   void insert() {
129     insert<ConcreteDialect>();
130     insert<OtherDialect, MoreDialects...>();
131   }
132 
133   /// Add a new dialect constructor to the registry. The constructor must be
134   /// calling MLIRContext::getOrLoadDialect in order for the context to take
135   /// ownership of the dialect and for delayed interface registration to happen.
136   void insert(TypeID typeID, StringRef name,
137               const DialectAllocatorFunction &ctor);
138 
139   /// Return an allocation function for constructing the dialect identified by
140   /// its namespace, or nullptr if the namespace is not in this registry.
141   DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
142 
143   // Register all dialects available in the current registry with the registry
144   // in the provided context.
appendTo(DialectRegistry & destination)145   void appendTo(DialectRegistry &destination) const {
146     for (const auto &nameAndRegistrationIt : registry)
147       destination.insert(nameAndRegistrationIt.second.first,
148                          nameAndRegistrationIt.first,
149                          nameAndRegistrationIt.second.second);
150     // Merge the extensions.
151     for (const auto &extension : extensions)
152       destination.extensions.push_back(extension->clone());
153   }
154 
155   /// Return the names of dialects known to this registry.
getDialectNames()156   auto getDialectNames() const {
157     return llvm::map_range(
158         registry,
159         [](const MapTy::value_type &item) -> StringRef { return item.first; });
160   }
161 
162   /// Apply any held extensions that require the given dialect. Users are not
163   /// expected to call this directly.
164   void applyExtensions(Dialect *dialect) const;
165 
166   /// Apply any applicable extensions to the given context. Users are not
167   /// expected to call this directly.
168   void applyExtensions(MLIRContext *ctx) const;
169 
170   /// Add the given extension to the registry.
addExtension(std::unique_ptr<DialectExtensionBase> extension)171   void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
172     extensions.push_back(std::move(extension));
173   }
174 
175   /// Add the given extensions to the registry.
176   template <typename... ExtensionsT>
addExtensions()177   void addExtensions() {
178     (void)std::initializer_list<int>{
179         (addExtension(std::make_unique<ExtensionsT>()), 0)...};
180   }
181 
182   /// Add an extension function that requires the given dialects.
183   /// Note: This bare functor overload is provided in addition to the
184   /// std::function variant to enable dialect type deduction, e.g.:
185   ///  registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
186   ///
187   /// is equivalent to:
188   ///  registry.addExtension<MyDialect>(
189   ///     [](MLIRContext *ctx, MyDialect *dialect){ ... }
190   ///  )
191   template <typename... DialectsT>
addExtension(void (* extensionFn)(MLIRContext *,DialectsT * ...))192   void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
193     addExtension<DialectsT...>(
194         std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
195   }
196   template <typename... DialectsT>
197   void
addExtension(std::function<void (MLIRContext *,DialectsT * ...)> extensionFn)198   addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
199     using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
200 
201     struct Extension : public DialectExtension<Extension, DialectsT...> {
202       Extension(const Extension &) = default;
203       Extension(ExtensionFnT extensionFn)
204           : extensionFn(std::move(extensionFn)) {}
205       ~Extension() override = default;
206 
207       void apply(MLIRContext *context, DialectsT *...dialects) const final {
208         extensionFn(context, dialects...);
209       }
210       ExtensionFnT extensionFn;
211     };
212     addExtension(std::make_unique<Extension>(std::move(extensionFn)));
213   }
214 
215   /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
216   /// contains all of the components of this registry.
217   bool isSubsetOf(const DialectRegistry &rhs) const;
218 
219 private:
220   MapTy registry;
221   std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
222 };
223 
224 } // namespace mlir
225 
226 #endif // MLIR_IR_DIALECTREGISTRY_H
227