1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DialectGen uses the description of dialects to generate C++ definitions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/CodeGenHelpers.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Interfaces.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "mlir/TableGen/Trait.h"
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 #include "llvm/TableGen/TableGenBackend.h"
28 
29 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
30 
31 using namespace mlir;
32 using namespace mlir::tblgen;
33 
34 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
35 llvm::cl::opt<std::string>
36     selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
37                     llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
38 
/// Utility iterator used for filtering records for a specific dialect.
namespace {
// Walks an ArrayRef of TableGen records and skips every record for which the
// stored predicate returns false. The predicate is type-erased in a
// std::function so filterForDialect can build it from a capturing lambda.
// The iterator does not own the records: the array it walks must outlive it.
using DialectFilterIterator =
    llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
                          std::function<bool(const llvm::Record *)>>;
} // namespace
45 
46 /// Given a set of records for a T, filter the ones that correspond to
47 /// the given dialect.
48 template <typename T>
49 static iterator_range<DialectFilterIterator>
50 filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
51   auto filterFn = [&](const llvm::Record *record) {
52     return T(record).getDialect() == dialect;
53   };
54   return {DialectFilterIterator(records.begin(), records.end(), filterFn),
55           DialectFilterIterator(records.end(), records.end(), filterFn)};
56 }
57 
58 static Optional<Dialect>
59 findSelectedDialect(ArrayRef<const llvm::Record *> dialectDefs) {
60   // Select the dialect to gen for.
61   if (dialectDefs.size() == 1 && selectedDialect.getNumOccurrences() == 0) {
62     return Dialect(dialectDefs.front());
63   }
64 
65   if (selectedDialect.getNumOccurrences() == 0) {
66     llvm::errs() << "when more than 1 dialect is present, one must be selected "
67                     "via '-dialect'\n";
68     return llvm::None;
69   }
70 
71   const auto *dialectIt =
72       llvm::find_if(dialectDefs, [](const llvm::Record *def) {
73         return Dialect(def).getName() == selectedDialect;
74       });
75   if (dialectIt == dialectDefs.end()) {
76     llvm::errs() << "selected dialect with '-dialect' does not exist\n";
77     return llvm::None;
78   }
79   return Dialect(*dialectIt);
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // GEN: Dialect declarations
84 //===----------------------------------------------------------------------===//
85 
/// The code block for the start of a dialect class declaration.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class ("Dialect" or "ExtensibleDialect").
///
/// The class body is intentionally left open: emitDialectDecl appends the
/// optional hook declarations and any extra class declaration, then closes the
/// class with "};".
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {
    return ::llvm::StringLiteral("{1}");
  }
)";

/// Registration for a single dependent dialect. emitDialectDef instantiates
/// this once per entry of Dialect::getDependentDialects() and splices the
/// results into placeholder {1} of dialectConstructorStr (defined further
/// below), i.e. into the constructor body ahead of the call to initialize().
///
/// {0}: The dependent dialect's C++ class (used as a template argument).
const char *const dialectRegistrationTemplate = R"(
    getContext()->getOrLoadDialect<{0}>();
)";
109 
/// The code block for the attribute parser/printer hooks. Declared only when at
/// least one attribute record belongs to the dialect and the dialect uses the
/// default attribute printer/parser.
static const char *const attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";

/// The code block for the type parser/printer hooks. Declared only when at
/// least one type record belongs to the dialect and the dialect uses the
/// default type printer/parser.
static const char *const typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";

/// The code block for the canonicalization pattern registration hook
/// (declared when Dialect::hasCanonicalizer() is true).
static const char *const canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";

/// The code block for the constant materializer hook (declared when
/// Dialect::hasConstantMaterializer() is true).
static const char *const constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";

// NOTE(review): the three verifier hooks below emit their text with a
// four-space indent, unlike the two-space hooks above. This only affects the
// formatting of the generated header.

/// The code block for the operation attribute verifier hook (declared when
/// Dialect::hasOperationAttrVerify() is true).
static const char *const opAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op.
    ::mlir::LogicalResult verifyOperationAttribute(
        ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";

/// The code block for the region argument attribute verifier hook (declared
/// when Dialect::hasRegionArgAttrVerify() is true).
static const char *const regionArgAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region argument.
    ::mlir::LogicalResult verifyRegionArgAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
        ::mlir::NamedAttribute attribute) override;
)";

/// The code block for the region result attribute verifier hook (declared
/// when Dialect::hasRegionResultAttrVerify() is true).
static const char *const regionResultAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region result.
    ::mlir::LogicalResult verifyRegionResultAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
        ::mlir::NamedAttribute attribute) override;
)";
173 
/// The code block for the op interface fallback hook (declared when
/// Dialect::hasOperationInterfaceFallback() is true).
///
/// MLIR types are spelled with a leading `::`, like every other hook emitted
/// by this file, so the generated declaration cannot be captured by a nested
/// `mlir` namespace inside the dialect's own C++ namespace.
static const char *const operationInterfaceFallbackDecl = R"(
    /// Provides a hook for op interface.
    void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                      ::mlir::OperationName opName) override;
)";
180 
181 /// Generate the declaration for the given dialect class.
182 static void
183 emitDialectDecl(Dialect &dialect,
184                 const iterator_range<DialectFilterIterator> &dialectAttrs,
185                 const iterator_range<DialectFilterIterator> &dialectTypes,
186                 raw_ostream &os) {
187   // Emit all nested namespaces.
188   {
189     NamespaceEmitter nsEmitter(os, dialect);
190 
191     // Emit the start of the decl.
192     std::string cppName = dialect.getCppClassName();
193     StringRef superClassName =
194         dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
195     os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
196                         superClassName);
197 
198     // Check for any attributes/types registered to this dialect.  If there are,
199     // add the hooks for parsing/printing.
200     if (!dialectAttrs.empty() && dialect.useDefaultAttributePrinterParser())
201       os << attrParserDecl;
202     if (!dialectTypes.empty() && dialect.useDefaultTypePrinterParser())
203       os << typeParserDecl;
204 
205     // Add the decls for the various features of the dialect.
206     if (dialect.hasCanonicalizer())
207       os << canonicalizerDecl;
208     if (dialect.hasConstantMaterializer())
209       os << constantMaterializerDecl;
210     if (dialect.hasOperationAttrVerify())
211       os << opAttrVerifierDecl;
212     if (dialect.hasRegionArgAttrVerify())
213       os << regionArgAttrVerifierDecl;
214     if (dialect.hasRegionResultAttrVerify())
215       os << regionResultAttrVerifierDecl;
216     if (dialect.hasOperationInterfaceFallback())
217       os << operationInterfaceFallbackDecl;
218     if (llvm::Optional<StringRef> extraDecl =
219             dialect.getExtraClassDeclaration())
220       os << *extraDecl;
221 
222     // End the dialect decl.
223     os << "};\n";
224   }
225   if (!dialect.getCppNamespace().empty())
226     os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
227        << "::" << dialect.getCppClassName() << ")\n";
228 }
229 
230 static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
231                              raw_ostream &os) {
232   emitSourceFileHeader("Dialect Declarations", os);
233 
234   auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
235   if (dialectDefs.empty())
236     return false;
237 
238   Optional<Dialect> dialect = findSelectedDialect(dialectDefs);
239   if (!dialect)
240     return true;
241   auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr");
242   auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
243   emitDialectDecl(*dialect, filterForDialect<Attribute>(attrDefs, *dialect),
244                   filterForDialect<Type>(typeDefs, *dialect), os);
245   return false;
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // GEN: Dialect definitions
250 //===----------------------------------------------------------------------===//
251 
/// The code block to generate a dialect constructor definition.
///
/// {0}: The name of the dialect class.
/// {1}: initialization code that is emitted in the ctor body before calling
///      initialize() (the dependent-dialect registrations).
/// {2}: The dialect parent class ("Dialect" or "ExtensibleDialect").
///
/// `{{` is formatv's escape for a literal `{` (the opening brace of the
/// constructor body).
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
  {1}
  initialize();
}
)";

/// The code block to generate a default destructor definition. Emitted only
/// when the dialect does not declare a non-default destructor.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
273 
274 static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
275   std::string cppClassName = dialect.getCppClassName();
276 
277   // Emit the TypeID explicit specializations to have a single symbol def.
278   if (!dialect.getCppNamespace().empty())
279     os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
280        << "::" << cppClassName << ")\n";
281 
282   // Emit all nested namespaces.
283   NamespaceEmitter nsEmitter(os, dialect);
284 
285   /// Build the list of dependent dialects.
286   std::string dependentDialectRegistrations;
287   {
288     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
289     for (StringRef dependentDialect : dialect.getDependentDialects())
290       dialectsOs << llvm::formatv(dialectRegistrationTemplate,
291                                   dependentDialect);
292   }
293 
294   // Emit the constructor and destructor.
295   StringRef superClassName =
296       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
297   os << llvm::formatv(dialectConstructorStr, cppClassName,
298                       dependentDialectRegistrations, superClassName);
299   if (!dialect.hasNonDefaultDestructor())
300     os << llvm::formatv(dialectDestructorStr, cppClassName);
301 }
302 
303 static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
304                             raw_ostream &os) {
305   emitSourceFileHeader("Dialect Definitions", os);
306 
307   auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
308   if (dialectDefs.empty())
309     return false;
310 
311   Optional<Dialect> dialect = findSelectedDialect(dialectDefs);
312   if (!dialect)
313     return true;
314   emitDialectDef(*dialect, os);
315   return false;
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // GEN: Dialect registration hooks
320 //===----------------------------------------------------------------------===//
321 
322 static mlir::GenRegistration
323     genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
324                     [](const llvm::RecordKeeper &records, raw_ostream &os) {
325                       return emitDialectDecls(records, os);
326                     });
327 
328 static mlir::GenRegistration
329     genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
330                    [](const llvm::RecordKeeper &records, raw_ostream &os) {
331                      return emitDialectDefs(records, os);
332                    });
333