1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DialectGen uses the description of dialects to generate C++ definitions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "DialectGenUtilities.h"
14 #include "mlir/TableGen/Class.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Trait.h"
21 #include "llvm/ADT/Optional.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Signals.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Record.h"
28 #include "llvm/TableGen/TableGenBackend.h"
29 
// Tag under which any LLVM_DEBUG output of this backend is reported; it names
// this dialect generator rather than the op-definition generator.
#define DEBUG_TYPE "mlir-tblgen-dialectgen"
31 
32 using namespace mlir;
33 using namespace mlir::tblgen;
34 
/// Option category grouping the flags of the -gen-dialect-* backends.
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
/// `-dialect=<name>`: the dialect to generate code for. It is read by
/// findDialectToGenerate and is mandatory when the input defines more than one
/// dialect.
// NOTE(review): this option is not `static`, presumably so other translation
// units can refer to it (confirm before narrowing its linkage). It also sets
// CommaSeparated although it holds a single value.
llvm::cl::opt<std::string>
    selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                    llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
39 
/// Iterator over an ArrayRef of records that only visits the records accepted
/// by its stored predicate; filterForDialect uses it to select the records
/// belonging to a single dialect.
namespace {
using DialectFilterIterator =
    llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
                          std::function<bool(const llvm::Record *)>>;
} // namespace
46 
47 /// Given a set of records for a T, filter the ones that correspond to
48 /// the given dialect.
49 template <typename T>
50 static iterator_range<DialectFilterIterator>
filterForDialect(ArrayRef<llvm::Record * > records,Dialect & dialect)51 filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
52   auto filterFn = [&](const llvm::Record *record) {
53     return T(record).getDialect() == dialect;
54   };
55   return {DialectFilterIterator(records.begin(), records.end(), filterFn),
56           DialectFilterIterator(records.end(), records.end(), filterFn)};
57 }
58 
findDialectToGenerate(ArrayRef<Dialect> dialects)59 Optional<Dialect> tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
60   if (dialects.empty()) {
61     llvm::errs() << "no dialect was found\n";
62     return llvm::None;
63   }
64 
65   // Select the dialect to gen for.
66   if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
67     return dialects.front();
68 
69   if (selectedDialect.getNumOccurrences() == 0) {
70     llvm::errs() << "when more than 1 dialect is present, one must be selected "
71                     "via '-dialect'\n";
72     return llvm::None;
73   }
74 
75   const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) {
76     return dialect.getName() == selectedDialect;
77   });
78   if (dialectIt == dialects.end()) {
79     llvm::errs() << "selected dialect with '-dialect' does not exist\n";
80     return llvm::None;
81   }
82   return *dialectIt;
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // GEN: Dialect declarations
87 //===----------------------------------------------------------------------===//
88 
/// The code block for the start of a dialect class declaration, expanded with
/// llvm::formatv. Literal opening braces are escaped as `{{` so they can never
/// be mistaken for the start of a replacement field.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {{
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
    return ::llvm::StringLiteral("{1}");
  }
)";
106 
/// Loading of a single dependent dialect. One formatted copy per dependent
/// dialect is concatenated and spliced into the constructor body (the `{1}`
/// placeholder of dialectConstructorStr), so dependent dialects are loaded
/// before initialize() runs.
///
/// {0}: The C++ name of the dependent dialect class.
static const char *const dialectRegistrationTemplate = R"(
    getContext()->getOrLoadDialect<{0}>();
)";
112 
/// The code block for the attribute parser/printer hooks. Emitted into the
/// dialect class when `Dialect::useDefaultAttributePrinterParser()` is true.
static const char *const attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";
123 
/// The code block for the type parser/printer hooks. Emitted into the dialect
/// class when `Dialect::useDefaultTypePrinterParser()` is true.
static const char *const typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";
133 
/// The code block for the canonicalization pattern registration hook. Emitted
/// into the dialect class when `Dialect::hasCanonicalizer()` is true.
static const char *const canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";
140 
/// The code block for the constant materializer hook. Emitted into the dialect
/// class when `Dialect::hasConstantMaterializer()` is true.
static const char *const constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";
150 
/// The code block for the operation attribute verifier hook. Emitted into the
/// dialect class when `Dialect::hasOperationAttrVerify()` is true.
static const char *const opAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op.
    ::mlir::LogicalResult verifyOperationAttribute(
        ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";
158 
/// The code block for the region argument attribute verifier hook. Emitted
/// into the dialect class when `Dialect::hasRegionArgAttrVerify()` is true.
static const char *const regionArgAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region argument.
    ::mlir::LogicalResult verifyRegionArgAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
        ::mlir::NamedAttribute attribute) override;
)";
167 
/// The code block for the region result attribute verifier hook. Emitted into
/// the dialect class when `Dialect::hasRegionResultAttrVerify()` is true.
static const char *const regionResultAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region result.
    ::mlir::LogicalResult verifyRegionResultAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
        ::mlir::NamedAttribute attribute) override;
)";
176 
/// The code block for the op interface fallback hook. Emitted into the dialect
/// class when `Dialect::hasOperationInterfaceFallback()` is true.
///
/// This text is written inside the dialect's own C++ namespaces. The MLIR
/// types are therefore spelled with a leading `::`, like every other generated
/// declaration, so a namespace named `mlir` nested in the dialect's namespace
/// cannot capture the lookup.
static const char *const operationInterfaceFallbackDecl = R"(
    /// Provides a hook for op interface.
    void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                      ::mlir::OperationName opName) override;
)";
183 
184 /// Generate the declaration for the given dialect class.
emitDialectDecl(Dialect & dialect,raw_ostream & os)185 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
186   // Emit all nested namespaces.
187   {
188     NamespaceEmitter nsEmitter(os, dialect);
189 
190     // Emit the start of the decl.
191     std::string cppName = dialect.getCppClassName();
192     StringRef superClassName =
193         dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
194     os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
195                         superClassName);
196 
197     // If the dialect requested the default attribute printer and parser, emit
198     // the declarations for the hooks.
199     if (dialect.useDefaultAttributePrinterParser())
200       os << attrParserDecl;
201     // If the dialect requested the default type printer and parser, emit the
202     // delcarations for the hooks.
203     if (dialect.useDefaultTypePrinterParser())
204       os << typeParserDecl;
205 
206     // Add the decls for the various features of the dialect.
207     if (dialect.hasCanonicalizer())
208       os << canonicalizerDecl;
209     if (dialect.hasConstantMaterializer())
210       os << constantMaterializerDecl;
211     if (dialect.hasOperationAttrVerify())
212       os << opAttrVerifierDecl;
213     if (dialect.hasRegionArgAttrVerify())
214       os << regionArgAttrVerifierDecl;
215     if (dialect.hasRegionResultAttrVerify())
216       os << regionResultAttrVerifierDecl;
217     if (dialect.hasOperationInterfaceFallback())
218       os << operationInterfaceFallbackDecl;
219     if (llvm::Optional<StringRef> extraDecl =
220             dialect.getExtraClassDeclaration())
221       os << *extraDecl;
222 
223     // End the dialect decl.
224     os << "};\n";
225   }
226   if (!dialect.getCppNamespace().empty())
227     os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
228        << "::" << dialect.getCppClassName() << ")\n";
229 }
230 
emitDialectDecls(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)231 static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
232                              raw_ostream &os) {
233   emitSourceFileHeader("Dialect Declarations", os);
234 
235   auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
236   if (dialectDefs.empty())
237     return false;
238 
239   SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
240   Optional<Dialect> dialect = findDialectToGenerate(dialects);
241   if (!dialect)
242     return true;
243   emitDialectDecl(*dialect, os);
244   return false;
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // GEN: Dialect definitions
249 //===----------------------------------------------------------------------===//
250 
/// The code block to generate a dialect constructor definition, expanded with
/// llvm::formatv (`{{` yields a literal `{`).
///
/// {0}: The name of the dialect class.
/// {1}: Initialization code emitted in the ctor body before initialize() is
///      called (the concatenated dialectRegistrationTemplate expansions).
/// {2}: The dialect parent class.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
  {1}
  initialize();
}
)";
264 
/// The code block to generate a defaulted destructor definition. It is only
/// emitted when the dialect does not request a non-default destructor.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
272 
emitDialectDef(Dialect & dialect,raw_ostream & os)273 static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
274   std::string cppClassName = dialect.getCppClassName();
275 
276   // Emit the TypeID explicit specializations to have a single symbol def.
277   if (!dialect.getCppNamespace().empty())
278     os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
279        << "::" << cppClassName << ")\n";
280 
281   // Emit all nested namespaces.
282   NamespaceEmitter nsEmitter(os, dialect);
283 
284   /// Build the list of dependent dialects.
285   std::string dependentDialectRegistrations;
286   {
287     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
288     for (StringRef dependentDialect : dialect.getDependentDialects())
289       dialectsOs << llvm::formatv(dialectRegistrationTemplate,
290                                   dependentDialect);
291   }
292 
293   // Emit the constructor and destructor.
294   StringRef superClassName =
295       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
296   os << llvm::formatv(dialectConstructorStr, cppClassName,
297                       dependentDialectRegistrations, superClassName);
298   if (!dialect.hasNonDefaultDestructor())
299     os << llvm::formatv(dialectDestructorStr, cppClassName);
300 }
301 
emitDialectDefs(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)302 static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
303                             raw_ostream &os) {
304   emitSourceFileHeader("Dialect Definitions", os);
305 
306   auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
307   if (dialectDefs.empty())
308     return false;
309 
310   SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
311   Optional<Dialect> dialect = findDialectToGenerate(dialects);
312   if (!dialect)
313     return true;
314   emitDialectDef(*dialect, os);
315   return false;
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // GEN: Dialect registration hooks
320 //===----------------------------------------------------------------------===//
321 
322 static mlir::GenRegistration
323     genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
__anon58492d6f0402(const llvm::RecordKeeper &records, raw_ostream &os) 324                     [](const llvm::RecordKeeper &records, raw_ostream &os) {
325                       return emitDialectDecls(records, os);
326                     });
327 
328 static mlir::GenRegistration
329     genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
__anon58492d6f0502(const llvm::RecordKeeper &records, raw_ostream &os) 330                    [](const llvm::RecordKeeper &records, raw_ostream &os) {
331                      return emitDialectDefs(records, os);
332                    });
333