1 //===- CPPGen.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains a PDLL generator that outputs C++ code that defines PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code,
// and also defines any native constraints and rewrites whose bodies were
// defined in PDLL.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLOps.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 using namespace mlir::pdll;
29 
30 //===----------------------------------------------------------------------===//
31 // CodeGen
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 class CodeGen {
36 public:
37   CodeGen(raw_ostream &os) : os(os) {}
38 
39   /// Generate C++ code for the given PDL pattern module.
40   void generate(const ast::Module &astModule, ModuleOp module);
41 
42 private:
43   void generate(pdl::PatternOp pattern, StringRef patternName,
44                 StringSet<> &nativeFunctions);
45 
46   /// Generate C++ code for all user defined constraints and rewrites with
47   /// native code.
48   void generateConstraintAndRewrites(const ast::Module &astModule,
49                                      ModuleOp module,
50                                      StringSet<> &nativeFunctions);
51   void generate(const ast::UserConstraintDecl *decl,
52                 StringSet<> &nativeFunctions);
53   void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
54   void generateConstraintOrRewrite(StringRef name, bool isConstraint,
55                                    ArrayRef<ast::VariableDecl *> inputs,
56                                    StringRef codeBlock,
57                                    StringSet<> &nativeFunctions);
58 
59   /// The stream to output to.
60   raw_ostream &os;
61 };
62 } // namespace
63 
64 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
65   SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
66   StringSet<> nativeFunctions;
67 
68   // Generate code for any native functions within the module.
69   generateConstraintAndRewrites(astModule, module, nativeFunctions);
70 
71   os << "namespace {\n";
72   std::string basePatternName = "GeneratedPDLLPattern";
73   int patternIndex = 0;
74   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
75     // If the pattern has a name, use that. Otherwise, generate a unique name.
76     if (Optional<StringRef> patternName = pattern.sym_name()) {
77       patternNames.insert(patternName->str());
78     } else {
79       std::string name;
80       do {
81         name = (basePatternName + Twine(patternIndex++)).str();
82       } while (!patternNames.insert(name));
83     }
84 
85     generate(pattern, patternNames.back(), nativeFunctions);
86   }
87   os << "} // end namespace\n\n";
88 
89   // Emit function to add the generated matchers to the pattern list.
90   os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
91         "::mlir::RewritePatternSet &patterns) {\n";
92   for (const auto &name : patternNames)
93     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
94   os << "}\n";
95 }
96 
97 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
98                        StringSet<> &nativeFunctions) {
99   const char *patternClassStartStr = R"(
100 struct {0} : ::mlir::PDLPatternModule {{
101   {0}(::mlir::MLIRContext *context)
102     : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
103 )";
104   os << llvm::formatv(patternClassStartStr, patternName);
105 
106   os << "R\"mlir(";
107   pattern->print(os, OpPrintingFlags().enableDebugInfo());
108   os << "\n    )mlir\", context)) {\n";
109 
110   // Register any native functions used within the pattern.
111   StringSet<> registeredNativeFunctions;
112   auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
113     if (!nativeFunctions.count(fnName) ||
114         !registeredNativeFunctions.insert(fnName).second)
115       return;
116     os << "    register" << fnType << "Function(\"" << fnName << "\", "
117        << fnName << "PDLFn);\n";
118   };
119   pattern.walk([&](Operation *op) {
120     if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
121       checkRegisterNativeFn(constraintOp.name(), "Constraint");
122     else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
123       checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
124   });
125   os << "  }\n};\n\n";
126 }
127 
128 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
129                                             ModuleOp module,
130                                             StringSet<> &nativeFunctions) {
131   // First check to see which constraints and rewrites are actually referenced
132   // in the module.
133   StringSet<> usedFns;
134   module.walk([&](Operation *op) {
135     TypeSwitch<Operation *>(op)
136         .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
137             [&](auto op) { usedFns.insert(op.name()); });
138   });
139 
140   for (const ast::Decl *decl : astModule.getChildren()) {
141     TypeSwitch<const ast::Decl *>(decl)
142         .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
143             [&](const auto *decl) {
144               // We only generate code for inline native decls that have been
145               // referenced.
146               if (decl->getCodeBlock() &&
147                   usedFns.contains(decl->getName().getName()))
148                 this->generate(decl, nativeFunctions);
149             });
150   }
151 }
152 
153 void CodeGen::generate(const ast::UserConstraintDecl *decl,
154                        StringSet<> &nativeFunctions) {
155   return generateConstraintOrRewrite(decl->getName().getName(),
156                                      /*isConstraint=*/true, decl->getInputs(),
157                                      *decl->getCodeBlock(), nativeFunctions);
158 }
159 
160 void CodeGen::generate(const ast::UserRewriteDecl *decl,
161                        StringSet<> &nativeFunctions) {
162   return generateConstraintOrRewrite(decl->getName().getName(),
163                                      /*isConstraint=*/false, decl->getInputs(),
164                                      *decl->getCodeBlock(), nativeFunctions);
165 }
166 
167 void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
168                                           ArrayRef<ast::VariableDecl *> inputs,
169                                           StringRef codeBlock,
170                                           StringSet<> &nativeFunctions) {
171   nativeFunctions.insert(name);
172 
173   // TODO: Should there be something explicit for handling optionality?
174   auto getCppType = [&](ast::Type type) -> StringRef {
175     return llvm::TypeSwitch<ast::Type, StringRef>(type)
176         .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
177         .Case([&](ast::OperationType) {
178           // TODO: Allow using the derived Op class when possible.
179           return "::mlir::Operation *";
180         })
181         .Case([&](ast::TypeType) { return "::mlir::Type"; })
182         .Case([&](ast::ValueType) { return "::mlir::Value"; })
183         .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
184         .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
185   };
186 
187   // FIXME: We currently do not have a modeling for the "constant params"
188   // support PDL provides. We should either figure out a modeling for this, or
189   // refactor the support within PDL to be something a bit more reasonable for
190   // what we need as a frontend.
191   os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
192      << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
193         "::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter"
194      << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
195 
196   const char *argumentInitStr = R"(
197   {0} {1} = {{};
198   if (values[{2}])
199     {1} = values[{2}].cast<{0}>();
200   (void){1};
201 )";
202   for (const auto &it : llvm::enumerate(inputs)) {
203     const ast::VariableDecl *input = it.value();
204     os << llvm::formatv(argumentInitStr, getCppType(input->getType()),
205                         input->getName().getName(), it.index());
206   }
207 
208   os << "  " << codeBlock.trim() << "\n}\n";
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // CPPGen
213 //===----------------------------------------------------------------------===//
214 
215 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
216                                   raw_ostream &os) {
217   CodeGen codegen(os);
218   codegen.generate(astModule, module);
219 }
220