//===- CPPGen.cpp ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a PDLL generator that outputs C++ code that defines PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code,
// and also defines any native constraints and rewrites whose bodies were
// defined in PDLL.
//
//===----------------------------------------------------------------------===//
149ad64a5cSRiver Riddle 
159ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
169ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
179ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLOps.h"
189ad64a5cSRiver Riddle #include "mlir/IR/BuiltinOps.h"
199ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/AST/Nodes.h"
20*1c2edb02SRiver Riddle #include "mlir/Tools/PDLL/ODS/Operation.h"
219ad64a5cSRiver Riddle #include "llvm/ADT/SmallString.h"
229ad64a5cSRiver Riddle #include "llvm/ADT/StringExtras.h"
239ad64a5cSRiver Riddle #include "llvm/ADT/StringSet.h"
249ad64a5cSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
259ad64a5cSRiver Riddle #include "llvm/Support/ErrorHandling.h"
269ad64a5cSRiver Riddle #include "llvm/Support/FormatVariadic.h"
279ad64a5cSRiver Riddle 
289ad64a5cSRiver Riddle using namespace mlir;
299ad64a5cSRiver Riddle using namespace mlir::pdll;
309ad64a5cSRiver Riddle 
319ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
329ad64a5cSRiver Riddle // CodeGen
339ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
349ad64a5cSRiver Riddle 
359ad64a5cSRiver Riddle namespace {
369ad64a5cSRiver Riddle class CodeGen {
379ad64a5cSRiver Riddle public:
CodeGen(raw_ostream & os)389ad64a5cSRiver Riddle   CodeGen(raw_ostream &os) : os(os) {}
399ad64a5cSRiver Riddle 
409ad64a5cSRiver Riddle   /// Generate C++ code for the given PDL pattern module.
419ad64a5cSRiver Riddle   void generate(const ast::Module &astModule, ModuleOp module);
429ad64a5cSRiver Riddle 
439ad64a5cSRiver Riddle private:
449ad64a5cSRiver Riddle   void generate(pdl::PatternOp pattern, StringRef patternName,
459ad64a5cSRiver Riddle                 StringSet<> &nativeFunctions);
469ad64a5cSRiver Riddle 
479ad64a5cSRiver Riddle   /// Generate C++ code for all user defined constraints and rewrites with
489ad64a5cSRiver Riddle   /// native code.
499ad64a5cSRiver Riddle   void generateConstraintAndRewrites(const ast::Module &astModule,
509ad64a5cSRiver Riddle                                      ModuleOp module,
519ad64a5cSRiver Riddle                                      StringSet<> &nativeFunctions);
529ad64a5cSRiver Riddle   void generate(const ast::UserConstraintDecl *decl,
539ad64a5cSRiver Riddle                 StringSet<> &nativeFunctions);
549ad64a5cSRiver Riddle   void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
55*1c2edb02SRiver Riddle   void generateConstraintOrRewrite(const ast::CallableDecl *decl,
56*1c2edb02SRiver Riddle                                    bool isConstraint,
579ad64a5cSRiver Riddle                                    StringSet<> &nativeFunctions);
589ad64a5cSRiver Riddle 
59*1c2edb02SRiver Riddle   /// Return the native name for the type of the given type.
60*1c2edb02SRiver Riddle   StringRef getNativeTypeName(ast::Type type);
61*1c2edb02SRiver Riddle 
62*1c2edb02SRiver Riddle   /// Return the native name for the type of the given variable decl.
63*1c2edb02SRiver Riddle   StringRef getNativeTypeName(ast::VariableDecl *decl);
64*1c2edb02SRiver Riddle 
659ad64a5cSRiver Riddle   /// The stream to output to.
669ad64a5cSRiver Riddle   raw_ostream &os;
679ad64a5cSRiver Riddle };
689ad64a5cSRiver Riddle } // namespace
699ad64a5cSRiver Riddle 
generate(const ast::Module & astModule,ModuleOp module)709ad64a5cSRiver Riddle void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
719ad64a5cSRiver Riddle   SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
729ad64a5cSRiver Riddle   StringSet<> nativeFunctions;
739ad64a5cSRiver Riddle 
749ad64a5cSRiver Riddle   // Generate code for any native functions within the module.
759ad64a5cSRiver Riddle   generateConstraintAndRewrites(astModule, module, nativeFunctions);
769ad64a5cSRiver Riddle 
779ad64a5cSRiver Riddle   os << "namespace {\n";
789ad64a5cSRiver Riddle   std::string basePatternName = "GeneratedPDLLPattern";
799ad64a5cSRiver Riddle   int patternIndex = 0;
809ad64a5cSRiver Riddle   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
819ad64a5cSRiver Riddle     // If the pattern has a name, use that. Otherwise, generate a unique name.
829ad64a5cSRiver Riddle     if (Optional<StringRef> patternName = pattern.sym_name()) {
839ad64a5cSRiver Riddle       patternNames.insert(patternName->str());
849ad64a5cSRiver Riddle     } else {
859ad64a5cSRiver Riddle       std::string name;
869ad64a5cSRiver Riddle       do {
879ad64a5cSRiver Riddle         name = (basePatternName + Twine(patternIndex++)).str();
889ad64a5cSRiver Riddle       } while (!patternNames.insert(name));
899ad64a5cSRiver Riddle     }
909ad64a5cSRiver Riddle 
919ad64a5cSRiver Riddle     generate(pattern, patternNames.back(), nativeFunctions);
929ad64a5cSRiver Riddle   }
939ad64a5cSRiver Riddle   os << "} // end namespace\n\n";
949ad64a5cSRiver Riddle 
959ad64a5cSRiver Riddle   // Emit function to add the generated matchers to the pattern list.
969ad64a5cSRiver Riddle   os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
979ad64a5cSRiver Riddle         "::mlir::RewritePatternSet &patterns) {\n";
989ad64a5cSRiver Riddle   for (const auto &name : patternNames)
999ad64a5cSRiver Riddle     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
1009ad64a5cSRiver Riddle   os << "}\n";
1019ad64a5cSRiver Riddle }
1029ad64a5cSRiver Riddle 
generate(pdl::PatternOp pattern,StringRef patternName,StringSet<> & nativeFunctions)1039ad64a5cSRiver Riddle void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
1049ad64a5cSRiver Riddle                        StringSet<> &nativeFunctions) {
1059ad64a5cSRiver Riddle   const char *patternClassStartStr = R"(
1069ad64a5cSRiver Riddle struct {0} : ::mlir::PDLPatternModule {{
1079ad64a5cSRiver Riddle   {0}(::mlir::MLIRContext *context)
1089ad64a5cSRiver Riddle     : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
1099ad64a5cSRiver Riddle )";
1109ad64a5cSRiver Riddle   os << llvm::formatv(patternClassStartStr, patternName);
1119ad64a5cSRiver Riddle 
1129ad64a5cSRiver Riddle   os << "R\"mlir(";
1139ad64a5cSRiver Riddle   pattern->print(os, OpPrintingFlags().enableDebugInfo());
1149ad64a5cSRiver Riddle   os << "\n    )mlir\", context)) {\n";
1159ad64a5cSRiver Riddle 
1169ad64a5cSRiver Riddle   // Register any native functions used within the pattern.
1179ad64a5cSRiver Riddle   StringSet<> registeredNativeFunctions;
1189ad64a5cSRiver Riddle   auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
1199ad64a5cSRiver Riddle     if (!nativeFunctions.count(fnName) ||
1209ad64a5cSRiver Riddle         !registeredNativeFunctions.insert(fnName).second)
1219ad64a5cSRiver Riddle       return;
1229ad64a5cSRiver Riddle     os << "    register" << fnType << "Function(\"" << fnName << "\", "
1239ad64a5cSRiver Riddle        << fnName << "PDLFn);\n";
1249ad64a5cSRiver Riddle   };
1259ad64a5cSRiver Riddle   pattern.walk([&](Operation *op) {
1269ad64a5cSRiver Riddle     if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
1279ad64a5cSRiver Riddle       checkRegisterNativeFn(constraintOp.name(), "Constraint");
1289ad64a5cSRiver Riddle     else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
1299ad64a5cSRiver Riddle       checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
1309ad64a5cSRiver Riddle   });
1319ad64a5cSRiver Riddle   os << "  }\n};\n\n";
1329ad64a5cSRiver Riddle }
1339ad64a5cSRiver Riddle 
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)1349ad64a5cSRiver Riddle void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
1359ad64a5cSRiver Riddle                                             ModuleOp module,
1369ad64a5cSRiver Riddle                                             StringSet<> &nativeFunctions) {
1379ad64a5cSRiver Riddle   // First check to see which constraints and rewrites are actually referenced
1389ad64a5cSRiver Riddle   // in the module.
1399ad64a5cSRiver Riddle   StringSet<> usedFns;
1409ad64a5cSRiver Riddle   module.walk([&](Operation *op) {
1419ad64a5cSRiver Riddle     TypeSwitch<Operation *>(op)
1429ad64a5cSRiver Riddle         .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
1439ad64a5cSRiver Riddle             [&](auto op) { usedFns.insert(op.name()); });
1449ad64a5cSRiver Riddle   });
1459ad64a5cSRiver Riddle 
1469ad64a5cSRiver Riddle   for (const ast::Decl *decl : astModule.getChildren()) {
1479ad64a5cSRiver Riddle     TypeSwitch<const ast::Decl *>(decl)
1489ad64a5cSRiver Riddle         .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
1499ad64a5cSRiver Riddle             [&](const auto *decl) {
1509ad64a5cSRiver Riddle               // We only generate code for inline native decls that have been
1519ad64a5cSRiver Riddle               // referenced.
1529ad64a5cSRiver Riddle               if (decl->getCodeBlock() &&
1539ad64a5cSRiver Riddle                   usedFns.contains(decl->getName().getName()))
1549ad64a5cSRiver Riddle                 this->generate(decl, nativeFunctions);
1559ad64a5cSRiver Riddle             });
1569ad64a5cSRiver Riddle   }
1579ad64a5cSRiver Riddle }
1589ad64a5cSRiver Riddle 
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)1599ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserConstraintDecl *decl,
1609ad64a5cSRiver Riddle                        StringSet<> &nativeFunctions) {
161*1c2edb02SRiver Riddle   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
162*1c2edb02SRiver Riddle                                      /*isConstraint=*/true, nativeFunctions);
1639ad64a5cSRiver Riddle }
1649ad64a5cSRiver Riddle 
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)1659ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserRewriteDecl *decl,
1669ad64a5cSRiver Riddle                        StringSet<> &nativeFunctions) {
167*1c2edb02SRiver Riddle   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
168*1c2edb02SRiver Riddle                                      /*isConstraint=*/false, nativeFunctions);
1699ad64a5cSRiver Riddle }
1709ad64a5cSRiver Riddle 
getNativeTypeName(ast::Type type)171*1c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::Type type) {
1729ad64a5cSRiver Riddle   return llvm::TypeSwitch<ast::Type, StringRef>(type)
1739ad64a5cSRiver Riddle       .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
174*1c2edb02SRiver Riddle       .Case([&](ast::OperationType opType) -> StringRef {
175*1c2edb02SRiver Riddle         // Use the derived Op class when available.
176*1c2edb02SRiver Riddle         if (const auto *odsOp = opType.getODSOperation())
177*1c2edb02SRiver Riddle           return odsOp->getNativeClassName();
1789ad64a5cSRiver Riddle         return "::mlir::Operation *";
1799ad64a5cSRiver Riddle       })
1809ad64a5cSRiver Riddle       .Case([&](ast::TypeType) { return "::mlir::Type"; })
1819ad64a5cSRiver Riddle       .Case([&](ast::ValueType) { return "::mlir::Value"; })
1829ad64a5cSRiver Riddle       .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
1839ad64a5cSRiver Riddle       .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
1849ad64a5cSRiver Riddle }
1859ad64a5cSRiver Riddle 
getNativeTypeName(ast::VariableDecl * decl)186*1c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
187*1c2edb02SRiver Riddle   // Try to extract a type name from the variable's constraints.
188*1c2edb02SRiver Riddle   for (ast::ConstraintRef &cst : decl->getConstraints()) {
189*1c2edb02SRiver Riddle     if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
190*1c2edb02SRiver Riddle       if (Optional<StringRef> name = userCst->getNativeInputType(0))
191*1c2edb02SRiver Riddle         return *name;
192*1c2edb02SRiver Riddle       return getNativeTypeName(userCst->getInputs()[0]);
193*1c2edb02SRiver Riddle     }
194*1c2edb02SRiver Riddle   }
195*1c2edb02SRiver Riddle 
196*1c2edb02SRiver Riddle   // Otherwise, use the type of the variable.
197*1c2edb02SRiver Riddle   return getNativeTypeName(decl->getType());
198*1c2edb02SRiver Riddle }
199*1c2edb02SRiver Riddle 
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)200*1c2edb02SRiver Riddle void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
201*1c2edb02SRiver Riddle                                           bool isConstraint,
202*1c2edb02SRiver Riddle                                           StringSet<> &nativeFunctions) {
203*1c2edb02SRiver Riddle   StringRef name = decl->getName()->getName();
204*1c2edb02SRiver Riddle   nativeFunctions.insert(name);
205*1c2edb02SRiver Riddle 
206*1c2edb02SRiver Riddle   os << "static ";
207*1c2edb02SRiver Riddle 
208*1c2edb02SRiver Riddle   // TODO: Work out a proper modeling for "optionality".
209*1c2edb02SRiver Riddle 
210*1c2edb02SRiver Riddle   // Emit the result type.
211*1c2edb02SRiver Riddle   // If this is a constraint, we always return a LogicalResult.
212*1c2edb02SRiver Riddle   // TODO: This will need to change if we allow Constraints to return values as
213*1c2edb02SRiver Riddle   // well.
214*1c2edb02SRiver Riddle   if (isConstraint) {
215*1c2edb02SRiver Riddle     os << "::mlir::LogicalResult";
216*1c2edb02SRiver Riddle   } else {
217*1c2edb02SRiver Riddle     // Otherwise, generate a type based on the results of the callable.
218*1c2edb02SRiver Riddle     // If the callable has explicit results, use those to build the result.
219*1c2edb02SRiver Riddle     // Otherwise, use the type of the callable.
220*1c2edb02SRiver Riddle     ArrayRef<ast::VariableDecl *> results = decl->getResults();
221*1c2edb02SRiver Riddle     if (results.empty()) {
222*1c2edb02SRiver Riddle       os << "void";
223*1c2edb02SRiver Riddle     } else if (results.size() == 1) {
224*1c2edb02SRiver Riddle       os << getNativeTypeName(results[0]);
225*1c2edb02SRiver Riddle     } else {
226*1c2edb02SRiver Riddle       os << "std::tuple<";
227*1c2edb02SRiver Riddle       llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
228*1c2edb02SRiver Riddle         os << getNativeTypeName(result);
229*1c2edb02SRiver Riddle       });
230*1c2edb02SRiver Riddle       os << ">";
231*1c2edb02SRiver Riddle     }
232*1c2edb02SRiver Riddle   }
233*1c2edb02SRiver Riddle 
234*1c2edb02SRiver Riddle   os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
235*1c2edb02SRiver Riddle   if (!decl->getInputs().empty()) {
236*1c2edb02SRiver Riddle     os << ", ";
237*1c2edb02SRiver Riddle     llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
238*1c2edb02SRiver Riddle       os << getNativeTypeName(input) << " " << input->getName().getName();
239*1c2edb02SRiver Riddle     });
240*1c2edb02SRiver Riddle   }
241*1c2edb02SRiver Riddle   os << ") {\n";
242*1c2edb02SRiver Riddle   os << "  " << decl->getCodeBlock()->trim() << "\n}\n\n";
2439ad64a5cSRiver Riddle }
2449ad64a5cSRiver Riddle 
2459ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2469ad64a5cSRiver Riddle // CPPGen
2479ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2489ad64a5cSRiver Riddle 
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)2499ad64a5cSRiver Riddle void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
2509ad64a5cSRiver Riddle                                   raw_ostream &os) {
2519ad64a5cSRiver Riddle   CodeGen codegen(os);
2529ad64a5cSRiver Riddle   codegen.generate(astModule, module);
2539ad64a5cSRiver Riddle }
254