1 //===- CPPGen.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains a PDLL generator that outputs C++ code defining PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code.
// It also defines any native constraints and rewrites whose bodies were
// written in PDLL.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLOps.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/ODS/Operation.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27
28 using namespace mlir;
29 using namespace mlir::pdll;
30
31 //===----------------------------------------------------------------------===//
32 // CodeGen
33 //===----------------------------------------------------------------------===//
34
35 namespace {
36 class CodeGen {
37 public:
CodeGen(raw_ostream & os)38 CodeGen(raw_ostream &os) : os(os) {}
39
40 /// Generate C++ code for the given PDL pattern module.
41 void generate(const ast::Module &astModule, ModuleOp module);
42
43 private:
44 void generate(pdl::PatternOp pattern, StringRef patternName,
45 StringSet<> &nativeFunctions);
46
47 /// Generate C++ code for all user defined constraints and rewrites with
48 /// native code.
49 void generateConstraintAndRewrites(const ast::Module &astModule,
50 ModuleOp module,
51 StringSet<> &nativeFunctions);
52 void generate(const ast::UserConstraintDecl *decl,
53 StringSet<> &nativeFunctions);
54 void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
55 void generateConstraintOrRewrite(const ast::CallableDecl *decl,
56 bool isConstraint,
57 StringSet<> &nativeFunctions);
58
59 /// Return the native name for the type of the given type.
60 StringRef getNativeTypeName(ast::Type type);
61
62 /// Return the native name for the type of the given variable decl.
63 StringRef getNativeTypeName(ast::VariableDecl *decl);
64
65 /// The stream to output to.
66 raw_ostream &os;
67 };
68 } // namespace
69
generate(const ast::Module & astModule,ModuleOp module)70 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
71 SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
72 StringSet<> nativeFunctions;
73
74 // Generate code for any native functions within the module.
75 generateConstraintAndRewrites(astModule, module, nativeFunctions);
76
77 os << "namespace {\n";
78 std::string basePatternName = "GeneratedPDLLPattern";
79 int patternIndex = 0;
80 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
81 // If the pattern has a name, use that. Otherwise, generate a unique name.
82 if (Optional<StringRef> patternName = pattern.sym_name()) {
83 patternNames.insert(patternName->str());
84 } else {
85 std::string name;
86 do {
87 name = (basePatternName + Twine(patternIndex++)).str();
88 } while (!patternNames.insert(name));
89 }
90
91 generate(pattern, patternNames.back(), nativeFunctions);
92 }
93 os << "} // end namespace\n\n";
94
95 // Emit function to add the generated matchers to the pattern list.
96 os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
97 "::mlir::RewritePatternSet &patterns) {\n";
98 for (const auto &name : patternNames)
99 os << " patterns.add<" << name << ">(patterns.getContext());\n";
100 os << "}\n";
101 }
102
/// Emit a `::mlir::PDLPatternModule` subclass named `patternName` for a single
/// PDL pattern.
///
/// The generated constructor parses the textual form of `pattern`, which is
/// printed with debug info and embedded as a C++ raw string literal. It then
/// registers each native constraint/rewrite function used by the pattern,
/// provided that function is listed in `nativeFunctions`, i.e. it was
/// generated from PDLL code.
void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
                       StringSet<> &nativeFunctions) {
  // Opening of the generated class, up to the start of the embedded MLIR
  // source string. `{0}` is substituted with the class name; `{{` is formatv's
  // escape for a literal `{`. The class is closed by the final `os <<` below.
  const char *patternClassStartStr = R"(
struct {0} : ::mlir::PDLPatternModule {{
{0}(::mlir::MLIRContext *context)
: ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
)";
  os << llvm::formatv(patternClassStartStr, patternName);

  // Embed the pattern IR as a raw string literal delimited by `mlir`.
  // NOTE(review): if the printed IR ever contained the character sequence
  // `)mlir"`, the generated raw string would terminate early.
  os << "R\"mlir(";
  pattern->print(os, OpPrintingFlags().enableDebugInfo());
  os << "\n )mlir\", context)) {\n";

  // Register any native functions used within the pattern.
  // A function is registered only if it was generated from PDLL (present in
  // `nativeFunctions`), and at most once per pattern even if used repeatedly.
  StringSet<> registeredNativeFunctions;
  auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
    if (!nativeFunctions.count(fnName) ||
        !registeredNativeFunctions.insert(fnName).second)
      return;
    os << " register" << fnType << "Function(\"" << fnName << "\", "
       << fnName << "PDLFn);\n";
  };
  // Registration statements are emitted in the order the ops are visited.
  pattern.walk([&](Operation *op) {
    if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
      checkRegisterNativeFn(constraintOp.name(), "Constraint");
    else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
      checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
  });
  // Close the constructor body and the struct.
  os << " }\n};\n\n";
}
133
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)134 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
135 ModuleOp module,
136 StringSet<> &nativeFunctions) {
137 // First check to see which constraints and rewrites are actually referenced
138 // in the module.
139 StringSet<> usedFns;
140 module.walk([&](Operation *op) {
141 TypeSwitch<Operation *>(op)
142 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
143 [&](auto op) { usedFns.insert(op.name()); });
144 });
145
146 for (const ast::Decl *decl : astModule.getChildren()) {
147 TypeSwitch<const ast::Decl *>(decl)
148 .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
149 [&](const auto *decl) {
150 // We only generate code for inline native decls that have been
151 // referenced.
152 if (decl->getCodeBlock() &&
153 usedFns.contains(decl->getName().getName()))
154 this->generate(decl, nativeFunctions);
155 });
156 }
157 }
158
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)159 void CodeGen::generate(const ast::UserConstraintDecl *decl,
160 StringSet<> &nativeFunctions) {
161 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
162 /*isConstraint=*/true, nativeFunctions);
163 }
164
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)165 void CodeGen::generate(const ast::UserRewriteDecl *decl,
166 StringSet<> &nativeFunctions) {
167 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
168 /*isConstraint=*/false, nativeFunctions);
169 }
170
getNativeTypeName(ast::Type type)171 StringRef CodeGen::getNativeTypeName(ast::Type type) {
172 return llvm::TypeSwitch<ast::Type, StringRef>(type)
173 .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
174 .Case([&](ast::OperationType opType) -> StringRef {
175 // Use the derived Op class when available.
176 if (const auto *odsOp = opType.getODSOperation())
177 return odsOp->getNativeClassName();
178 return "::mlir::Operation *";
179 })
180 .Case([&](ast::TypeType) { return "::mlir::Type"; })
181 .Case([&](ast::ValueType) { return "::mlir::Value"; })
182 .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
183 .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
184 }
185
getNativeTypeName(ast::VariableDecl * decl)186 StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
187 // Try to extract a type name from the variable's constraints.
188 for (ast::ConstraintRef &cst : decl->getConstraints()) {
189 if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
190 if (Optional<StringRef> name = userCst->getNativeInputType(0))
191 return *name;
192 return getNativeTypeName(userCst->getInputs()[0]);
193 }
194 }
195
196 // Otherwise, use the type of the variable.
197 return getNativeTypeName(decl->getType());
198 }
199
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)200 void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
201 bool isConstraint,
202 StringSet<> &nativeFunctions) {
203 StringRef name = decl->getName()->getName();
204 nativeFunctions.insert(name);
205
206 os << "static ";
207
208 // TODO: Work out a proper modeling for "optionality".
209
210 // Emit the result type.
211 // If this is a constraint, we always return a LogicalResult.
212 // TODO: This will need to change if we allow Constraints to return values as
213 // well.
214 if (isConstraint) {
215 os << "::mlir::LogicalResult";
216 } else {
217 // Otherwise, generate a type based on the results of the callable.
218 // If the callable has explicit results, use those to build the result.
219 // Otherwise, use the type of the callable.
220 ArrayRef<ast::VariableDecl *> results = decl->getResults();
221 if (results.empty()) {
222 os << "void";
223 } else if (results.size() == 1) {
224 os << getNativeTypeName(results[0]);
225 } else {
226 os << "std::tuple<";
227 llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
228 os << getNativeTypeName(result);
229 });
230 os << ">";
231 }
232 }
233
234 os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
235 if (!decl->getInputs().empty()) {
236 os << ", ";
237 llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
238 os << getNativeTypeName(input) << " " << input->getName().getName();
239 });
240 }
241 os << ") {\n";
242 os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
243 }
244
245 //===----------------------------------------------------------------------===//
246 // CPPGen
247 //===----------------------------------------------------------------------===//
248
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)249 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
250 raw_ostream &os) {
251 CodeGen codegen(os);
252 codegen.generate(astModule, module);
253 }
254