1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DialectGen uses the description of dialects to generate C++ definitions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "DialectGenUtilities.h"
14 #include "mlir/TableGen/Class.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Trait.h"
21 #include "llvm/ADT/Optional.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Signals.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Record.h"
28 #include "llvm/TableGen/TableGenBackend.h"
29
30 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
31
32 using namespace mlir;
33 using namespace mlir::tblgen;
34
// Option category under which the dialect generator flags are grouped.
35 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
// The '-dialect' option: names the dialect to generate for. It is read by
// findDialectToGenerate() below, which compares it against each dialect name.
// Unlike dialectGenCat this object is not `static` — presumably it is declared
// in DialectGenUtilities.h for other generators; TODO confirm before changing
// its linkage.
// NOTE(review): llvm::cl::CommaSeparated on a scalar cl::opt<std::string>
// looks unintended, since only one name is ever matched — confirm.
36 llvm::cl::opt<std::string>
37 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
38 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
39
40 /// Utility iterator used for filtering records for a specific dialect.
/// It adapts an ArrayRef<llvm::Record *> iterator with llvm::filter_iterator;
/// the predicate is a type-erased std::function, so any callable taking a
/// `const llvm::Record *` and returning bool can serve as the filter.
41 namespace {
42 using DialectFilterIterator =
43 llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
44 std::function<bool(const llvm::Record *)>>;
45 } // namespace
46
47 /// Given a set of records for a T, filter the ones that correspond to
48 /// the given dialect.
49 template <typename T>
50 static iterator_range<DialectFilterIterator>
filterForDialect(ArrayRef<llvm::Record * > records,Dialect & dialect)51 filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
52 auto filterFn = [&](const llvm::Record *record) {
53 return T(record).getDialect() == dialect;
54 };
55 return {DialectFilterIterator(records.begin(), records.end(), filterFn),
56 DialectFilterIterator(records.end(), records.end(), filterFn)};
57 }
58
findDialectToGenerate(ArrayRef<Dialect> dialects)59 Optional<Dialect> tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
60 if (dialects.empty()) {
61 llvm::errs() << "no dialect was found\n";
62 return llvm::None;
63 }
64
65 // Select the dialect to gen for.
66 if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
67 return dialects.front();
68
69 if (selectedDialect.getNumOccurrences() == 0) {
70 llvm::errs() << "when more than 1 dialect is present, one must be selected "
71 "via '-dialect'\n";
72 return llvm::None;
73 }
74
75 const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) {
76 return dialect.getName() == selectedDialect;
77 });
78 if (dialectIt == dialects.end()) {
79 llvm::errs() << "selected dialect with '-dialect' does not exist\n";
80 return llvm::None;
81 }
82 return *dialectIt;
83 }
84
85 //===----------------------------------------------------------------------===//
86 // GEN: Dialect declarations
87 //===----------------------------------------------------------------------===//
88
89 /// The code block for the start of a dialect class declaration.
/// Used as an llvm::formatv template by emitDialectDecl(). The class body is
/// intentionally left open: the optional hook declarations are appended after
/// it and emitDialectDecl() writes the closing "};" itself.
90 ///
91 /// {0}: The name of the dialect class.
92 /// {1}: The dialect namespace.
93 /// {2}: The dialect parent class.
/// ({2} is passed without the `::mlir::` prefix, which the template adds;
/// emitDialectDecl() supplies either "ExtensibleDialect" or "Dialect".)
94 static const char *const dialectDeclBeginStr = R"(
95 class {0} : public ::mlir::{2} {
96 explicit {0}(::mlir::MLIRContext *context);
97
98 void initialize();
99 friend class ::mlir::MLIRContext;
100 public:
101 ~{0}() override;
102 static constexpr ::llvm::StringLiteral getDialectNamespace() {
103 return ::llvm::StringLiteral("{1}");
104 }
105 )";
106
107 /// Registration for a single dependent dialect: formatted once per dependent
108 /// dialect and inserted into the body of the generated constructor.
/// The constructor template itself is dialectConstructorStr, defined further
/// down in this file; emitDialectDef() concatenates the expansions of this
/// template and passes them as that template's {1} argument.
///
/// {0}: The dependent dialect as yielded by Dialect::getDependentDialects();
///      it becomes the template argument of getOrLoadDialect<>.
// NOTE(review): unlike the other templates this constant is not marked
// `static`; confirm it is not declared elsewhere before adding it.
109 const char *const dialectRegistrationTemplate = R"(
110 getContext()->getOrLoadDialect<{0}>();
111 )";
112
113 /// The code block for the attribute parser/printer hooks.
/// Written verbatim (no formatv substitution) by emitDialectDecl() when
/// dialect.useDefaultAttributePrinterParser() is true.
114 static const char *const attrParserDecl = R"(
115 /// Parse an attribute registered to this dialect.
116 ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
117 ::mlir::Type type) const override;
118
119 /// Print an attribute registered to this dialect.
120 void printAttribute(::mlir::Attribute attr,
121 ::mlir::DialectAsmPrinter &os) const override;
122 )";
123
124 /// The code block for the type parser/printer hooks.
/// Written verbatim (no formatv substitution) by emitDialectDecl() when
/// dialect.useDefaultTypePrinterParser() is true.
125 static const char *const typeParserDecl = R"(
126 /// Parse a type registered to this dialect.
127 ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
128
129 /// Print a type registered to this dialect.
130 void printType(::mlir::Type type,
131 ::mlir::DialectAsmPrinter &os) const override;
132 )";
133
134 /// The code block for the canonicalization pattern registration hook.
/// Written verbatim by emitDialectDecl() when dialect.hasCanonicalizer() is
/// true.
135 static const char *const canonicalizerDecl = R"(
136 /// Register canonicalization patterns.
137 void getCanonicalizationPatterns(
138 ::mlir::RewritePatternSet &results) const override;
139 )";
140
141 /// The code block for the constant materializer hook.
/// Written verbatim by emitDialectDecl() when
/// dialect.hasConstantMaterializer() is true.
142 static const char *const constantMaterializerDecl = R"(
143 /// Materialize a single constant operation from a given attribute value with
144 /// the desired resultant type.
145 ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
146 ::mlir::Attribute value,
147 ::mlir::Type type,
148 ::mlir::Location loc) override;
149 )";
150
151 /// The code block for the operation attribute verifier hook.
/// Written verbatim by emitDialectDecl() when dialect.hasOperationAttrVerify()
/// is true.
152 static const char *const opAttrVerifierDecl = R"(
153 /// Provides a hook for verifying dialect attributes attached to the given
154 /// op.
155 ::mlir::LogicalResult verifyOperationAttribute(
156 ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
157 )";
158
159 /// The code block for the region argument attribute verifier hook.
/// Written verbatim by emitDialectDecl() when dialect.hasRegionArgAttrVerify()
/// is true.
160 static const char *const regionArgAttrVerifierDecl = R"(
161 /// Provides a hook for verifying dialect attributes attached to the given
162 /// op's region argument.
163 ::mlir::LogicalResult verifyRegionArgAttribute(
164 ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
165 ::mlir::NamedAttribute attribute) override;
166 )";
167
168 /// The code block for the region result attribute verifier hook.
/// Written verbatim by emitDialectDecl() when
/// dialect.hasRegionResultAttrVerify() is true.
169 static const char *const regionResultAttrVerifierDecl = R"(
170 /// Provides a hook for verifying dialect attributes attached to the given
171 /// op's region result.
172 ::mlir::LogicalResult verifyRegionResultAttribute(
173 ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
174 ::mlir::NamedAttribute attribute) override;
175 )";
176
/// The code block for the op interface fallback hook.
///
/// Written verbatim by emitDialectDecl() when
/// dialect.hasOperationInterfaceFallback() is true. The parameter types are
/// spelled with a leading `::mlir::`, like every other hook template in this
/// file, so that name lookup inside the dialect's own (possibly nested)
/// namespace cannot resolve `mlir` to an unrelated inner namespace.
static const char *const operationInterfaceFallbackDecl = R"(
  /// Provides a hook for op interface.
  void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                    ::mlir::OperationName opName) override;
)";
183
184 /// Generate the declaration for the given dialect class.
emitDialectDecl(Dialect & dialect,raw_ostream & os)185 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
186 // Emit all nested namespaces.
187 {
188 NamespaceEmitter nsEmitter(os, dialect);
189
190 // Emit the start of the decl.
191 std::string cppName = dialect.getCppClassName();
192 StringRef superClassName =
193 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
194 os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
195 superClassName);
196
197 // If the dialect requested the default attribute printer and parser, emit
198 // the declarations for the hooks.
199 if (dialect.useDefaultAttributePrinterParser())
200 os << attrParserDecl;
201 // If the dialect requested the default type printer and parser, emit the
202 // delcarations for the hooks.
203 if (dialect.useDefaultTypePrinterParser())
204 os << typeParserDecl;
205
206 // Add the decls for the various features of the dialect.
207 if (dialect.hasCanonicalizer())
208 os << canonicalizerDecl;
209 if (dialect.hasConstantMaterializer())
210 os << constantMaterializerDecl;
211 if (dialect.hasOperationAttrVerify())
212 os << opAttrVerifierDecl;
213 if (dialect.hasRegionArgAttrVerify())
214 os << regionArgAttrVerifierDecl;
215 if (dialect.hasRegionResultAttrVerify())
216 os << regionResultAttrVerifierDecl;
217 if (dialect.hasOperationInterfaceFallback())
218 os << operationInterfaceFallbackDecl;
219 if (llvm::Optional<StringRef> extraDecl =
220 dialect.getExtraClassDeclaration())
221 os << *extraDecl;
222
223 // End the dialect decl.
224 os << "};\n";
225 }
226 if (!dialect.getCppNamespace().empty())
227 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
228 << "::" << dialect.getCppClassName() << ")\n";
229 }
230
emitDialectDecls(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)231 static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
232 raw_ostream &os) {
233 emitSourceFileHeader("Dialect Declarations", os);
234
235 auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
236 if (dialectDefs.empty())
237 return false;
238
239 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
240 Optional<Dialect> dialect = findDialectToGenerate(dialects);
241 if (!dialect)
242 return true;
243 emitDialectDecl(*dialect, os);
244 return false;
245 }
246
247 //===----------------------------------------------------------------------===//
248 // GEN: Dialect definitions
249 //===----------------------------------------------------------------------===//
250
251 /// The code block to generate a dialect constructor definition.
/// Used as an llvm::formatv template by emitDialectDef(). The "{{" at the end
/// of the initializer line is formatv's escape for a literal '{'.
252 ///
253 /// {0}: The name of the dialect class.
254 /// {1}: initialization code that is emitted in the ctor body before calling
255 /// initialize().
/// emitDialectDef() passes the concatenated dialectRegistrationTemplate
/// expansions here, one per dependent dialect.
256 /// {2}: The dialect parent class.
257 static const char *const dialectConstructorStr = R"(
258 {0}::{0}(::mlir::MLIRContext *context)
259 : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
260 {1}
261 initialize();
262 }
263 )";
264
265 /// The code block to generate a default destructor definition.
/// Used as an llvm::formatv template by emitDialectDef(), which only emits it
/// when dialect.hasNonDefaultDestructor() is false.
266 ///
267 /// {0}: The name of the dialect class.
268 static const char *const dialectDestructorStr = R"(
269 {0}::~{0}() = default;
270
271 )";
272
emitDialectDef(Dialect & dialect,raw_ostream & os)273 static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
274 std::string cppClassName = dialect.getCppClassName();
275
276 // Emit the TypeID explicit specializations to have a single symbol def.
277 if (!dialect.getCppNamespace().empty())
278 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
279 << "::" << cppClassName << ")\n";
280
281 // Emit all nested namespaces.
282 NamespaceEmitter nsEmitter(os, dialect);
283
284 /// Build the list of dependent dialects.
285 std::string dependentDialectRegistrations;
286 {
287 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
288 for (StringRef dependentDialect : dialect.getDependentDialects())
289 dialectsOs << llvm::formatv(dialectRegistrationTemplate,
290 dependentDialect);
291 }
292
293 // Emit the constructor and destructor.
294 StringRef superClassName =
295 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
296 os << llvm::formatv(dialectConstructorStr, cppClassName,
297 dependentDialectRegistrations, superClassName);
298 if (!dialect.hasNonDefaultDestructor())
299 os << llvm::formatv(dialectDestructorStr, cppClassName);
300 }
301
emitDialectDefs(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)302 static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
303 raw_ostream &os) {
304 emitSourceFileHeader("Dialect Definitions", os);
305
306 auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
307 if (dialectDefs.empty())
308 return false;
309
310 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
311 Optional<Dialect> dialect = findDialectToGenerate(dialects);
312 if (!dialect)
313 return true;
314 emitDialectDef(*dialect, os);
315 return false;
316 }
317
318 //===----------------------------------------------------------------------===//
319 // GEN: Dialect registration hooks
320 //===----------------------------------------------------------------------===//
321
322 static mlir::GenRegistration
323 genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
__anon58492d6f0402(const llvm::RecordKeeper &records, raw_ostream &os) 324 [](const llvm::RecordKeeper &records, raw_ostream &os) {
325 return emitDialectDecls(records, os);
326 });
327
328 static mlir::GenRegistration
329 genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
__anon58492d6f0502(const llvm::RecordKeeper &records, raw_ostream &os) 330 [](const llvm::RecordKeeper &records, raw_ostream &os) {
331 return emitDialectDefs(records, os);
332 });
333