1 //===- CPPGen.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains a PDLL generator that outputs C++ code that defines PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code,
// and also defines any native constraints and rewrites whose bodies were
// defined in PDLL.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLOps.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/ODS/Operation.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 using namespace mlir;
29 using namespace mlir::pdll;
30 
31 //===----------------------------------------------------------------------===//
32 // CodeGen
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 class CodeGen {
37 public:
CodeGen(raw_ostream & os)38   CodeGen(raw_ostream &os) : os(os) {}
39 
40   /// Generate C++ code for the given PDL pattern module.
41   void generate(const ast::Module &astModule, ModuleOp module);
42 
43 private:
44   void generate(pdl::PatternOp pattern, StringRef patternName,
45                 StringSet<> &nativeFunctions);
46 
47   /// Generate C++ code for all user defined constraints and rewrites with
48   /// native code.
49   void generateConstraintAndRewrites(const ast::Module &astModule,
50                                      ModuleOp module,
51                                      StringSet<> &nativeFunctions);
52   void generate(const ast::UserConstraintDecl *decl,
53                 StringSet<> &nativeFunctions);
54   void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
55   void generateConstraintOrRewrite(const ast::CallableDecl *decl,
56                                    bool isConstraint,
57                                    StringSet<> &nativeFunctions);
58 
59   /// Return the native name for the type of the given type.
60   StringRef getNativeTypeName(ast::Type type);
61 
62   /// Return the native name for the type of the given variable decl.
63   StringRef getNativeTypeName(ast::VariableDecl *decl);
64 
65   /// The stream to output to.
66   raw_ostream &os;
67 };
68 } // namespace
69 
generate(const ast::Module & astModule,ModuleOp module)70 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
71   SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
72   StringSet<> nativeFunctions;
73 
74   // Generate code for any native functions within the module.
75   generateConstraintAndRewrites(astModule, module, nativeFunctions);
76 
77   os << "namespace {\n";
78   std::string basePatternName = "GeneratedPDLLPattern";
79   int patternIndex = 0;
80   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
81     // If the pattern has a name, use that. Otherwise, generate a unique name.
82     if (Optional<StringRef> patternName = pattern.sym_name()) {
83       patternNames.insert(patternName->str());
84     } else {
85       std::string name;
86       do {
87         name = (basePatternName + Twine(patternIndex++)).str();
88       } while (!patternNames.insert(name));
89     }
90 
91     generate(pattern, patternNames.back(), nativeFunctions);
92   }
93   os << "} // end namespace\n\n";
94 
95   // Emit function to add the generated matchers to the pattern list.
96   os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
97         "::mlir::RewritePatternSet &patterns) {\n";
98   for (const auto &name : patternNames)
99     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
100   os << "}\n";
101 }
102 
generate(pdl::PatternOp pattern,StringRef patternName,StringSet<> & nativeFunctions)103 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
104                        StringSet<> &nativeFunctions) {
105   const char *patternClassStartStr = R"(
106 struct {0} : ::mlir::PDLPatternModule {{
107   {0}(::mlir::MLIRContext *context)
108     : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
109 )";
110   os << llvm::formatv(patternClassStartStr, patternName);
111 
112   os << "R\"mlir(";
113   pattern->print(os, OpPrintingFlags().enableDebugInfo());
114   os << "\n    )mlir\", context)) {\n";
115 
116   // Register any native functions used within the pattern.
117   StringSet<> registeredNativeFunctions;
118   auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
119     if (!nativeFunctions.count(fnName) ||
120         !registeredNativeFunctions.insert(fnName).second)
121       return;
122     os << "    register" << fnType << "Function(\"" << fnName << "\", "
123        << fnName << "PDLFn);\n";
124   };
125   pattern.walk([&](Operation *op) {
126     if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
127       checkRegisterNativeFn(constraintOp.name(), "Constraint");
128     else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
129       checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
130   });
131   os << "  }\n};\n\n";
132 }
133 
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)134 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
135                                             ModuleOp module,
136                                             StringSet<> &nativeFunctions) {
137   // First check to see which constraints and rewrites are actually referenced
138   // in the module.
139   StringSet<> usedFns;
140   module.walk([&](Operation *op) {
141     TypeSwitch<Operation *>(op)
142         .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
143             [&](auto op) { usedFns.insert(op.name()); });
144   });
145 
146   for (const ast::Decl *decl : astModule.getChildren()) {
147     TypeSwitch<const ast::Decl *>(decl)
148         .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
149             [&](const auto *decl) {
150               // We only generate code for inline native decls that have been
151               // referenced.
152               if (decl->getCodeBlock() &&
153                   usedFns.contains(decl->getName().getName()))
154                 this->generate(decl, nativeFunctions);
155             });
156   }
157 }
158 
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)159 void CodeGen::generate(const ast::UserConstraintDecl *decl,
160                        StringSet<> &nativeFunctions) {
161   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
162                                      /*isConstraint=*/true, nativeFunctions);
163 }
164 
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)165 void CodeGen::generate(const ast::UserRewriteDecl *decl,
166                        StringSet<> &nativeFunctions) {
167   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
168                                      /*isConstraint=*/false, nativeFunctions);
169 }
170 
getNativeTypeName(ast::Type type)171 StringRef CodeGen::getNativeTypeName(ast::Type type) {
172   return llvm::TypeSwitch<ast::Type, StringRef>(type)
173       .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
174       .Case([&](ast::OperationType opType) -> StringRef {
175         // Use the derived Op class when available.
176         if (const auto *odsOp = opType.getODSOperation())
177           return odsOp->getNativeClassName();
178         return "::mlir::Operation *";
179       })
180       .Case([&](ast::TypeType) { return "::mlir::Type"; })
181       .Case([&](ast::ValueType) { return "::mlir::Value"; })
182       .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
183       .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
184 }
185 
getNativeTypeName(ast::VariableDecl * decl)186 StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
187   // Try to extract a type name from the variable's constraints.
188   for (ast::ConstraintRef &cst : decl->getConstraints()) {
189     if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
190       if (Optional<StringRef> name = userCst->getNativeInputType(0))
191         return *name;
192       return getNativeTypeName(userCst->getInputs()[0]);
193     }
194   }
195 
196   // Otherwise, use the type of the variable.
197   return getNativeTypeName(decl->getType());
198 }
199 
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)200 void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
201                                           bool isConstraint,
202                                           StringSet<> &nativeFunctions) {
203   StringRef name = decl->getName()->getName();
204   nativeFunctions.insert(name);
205 
206   os << "static ";
207 
208   // TODO: Work out a proper modeling for "optionality".
209 
210   // Emit the result type.
211   // If this is a constraint, we always return a LogicalResult.
212   // TODO: This will need to change if we allow Constraints to return values as
213   // well.
214   if (isConstraint) {
215     os << "::mlir::LogicalResult";
216   } else {
217     // Otherwise, generate a type based on the results of the callable.
218     // If the callable has explicit results, use those to build the result.
219     // Otherwise, use the type of the callable.
220     ArrayRef<ast::VariableDecl *> results = decl->getResults();
221     if (results.empty()) {
222       os << "void";
223     } else if (results.size() == 1) {
224       os << getNativeTypeName(results[0]);
225     } else {
226       os << "std::tuple<";
227       llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
228         os << getNativeTypeName(result);
229       });
230       os << ">";
231     }
232   }
233 
234   os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
235   if (!decl->getInputs().empty()) {
236     os << ", ";
237     llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
238       os << getNativeTypeName(input) << " " << input->getName().getName();
239     });
240   }
241   os << ") {\n";
242   os << "  " << decl->getCodeBlock()->trim() << "\n}\n\n";
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // CPPGen
247 //===----------------------------------------------------------------------===//
248 
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)249 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
250                                   raw_ostream &os) {
251   CodeGen codegen(os);
252   codegen.generate(astModule, module);
253 }
254