19ad64a5cSRiver Riddle //===- CPPGen.cpp ---------------------------------------------------------===//
29ad64a5cSRiver Riddle //
39ad64a5cSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49ad64a5cSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
59ad64a5cSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69ad64a5cSRiver Riddle //
79ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
89ad64a5cSRiver Riddle //
// This file contains a PDLL generator that outputs C++ code that defines PDLL
// patterns as individual C++ PDLPatternModules for direct use in native code,
// and also defines any native constraints and rewrites whose bodies were
// defined in PDLL.
129ad64a5cSRiver Riddle //
139ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
149ad64a5cSRiver Riddle
159ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
169ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
179ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLOps.h"
189ad64a5cSRiver Riddle #include "mlir/IR/BuiltinOps.h"
199ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/AST/Nodes.h"
20*1c2edb02SRiver Riddle #include "mlir/Tools/PDLL/ODS/Operation.h"
219ad64a5cSRiver Riddle #include "llvm/ADT/SmallString.h"
229ad64a5cSRiver Riddle #include "llvm/ADT/StringExtras.h"
239ad64a5cSRiver Riddle #include "llvm/ADT/StringSet.h"
249ad64a5cSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
259ad64a5cSRiver Riddle #include "llvm/Support/ErrorHandling.h"
269ad64a5cSRiver Riddle #include "llvm/Support/FormatVariadic.h"
279ad64a5cSRiver Riddle
289ad64a5cSRiver Riddle using namespace mlir;
299ad64a5cSRiver Riddle using namespace mlir::pdll;
309ad64a5cSRiver Riddle
319ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
329ad64a5cSRiver Riddle // CodeGen
339ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
349ad64a5cSRiver Riddle
359ad64a5cSRiver Riddle namespace {
369ad64a5cSRiver Riddle class CodeGen {
379ad64a5cSRiver Riddle public:
CodeGen(raw_ostream & os)389ad64a5cSRiver Riddle CodeGen(raw_ostream &os) : os(os) {}
399ad64a5cSRiver Riddle
409ad64a5cSRiver Riddle /// Generate C++ code for the given PDL pattern module.
419ad64a5cSRiver Riddle void generate(const ast::Module &astModule, ModuleOp module);
429ad64a5cSRiver Riddle
439ad64a5cSRiver Riddle private:
449ad64a5cSRiver Riddle void generate(pdl::PatternOp pattern, StringRef patternName,
459ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
469ad64a5cSRiver Riddle
479ad64a5cSRiver Riddle /// Generate C++ code for all user defined constraints and rewrites with
489ad64a5cSRiver Riddle /// native code.
499ad64a5cSRiver Riddle void generateConstraintAndRewrites(const ast::Module &astModule,
509ad64a5cSRiver Riddle ModuleOp module,
519ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
529ad64a5cSRiver Riddle void generate(const ast::UserConstraintDecl *decl,
539ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
549ad64a5cSRiver Riddle void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
55*1c2edb02SRiver Riddle void generateConstraintOrRewrite(const ast::CallableDecl *decl,
56*1c2edb02SRiver Riddle bool isConstraint,
579ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
589ad64a5cSRiver Riddle
59*1c2edb02SRiver Riddle /// Return the native name for the type of the given type.
60*1c2edb02SRiver Riddle StringRef getNativeTypeName(ast::Type type);
61*1c2edb02SRiver Riddle
62*1c2edb02SRiver Riddle /// Return the native name for the type of the given variable decl.
63*1c2edb02SRiver Riddle StringRef getNativeTypeName(ast::VariableDecl *decl);
64*1c2edb02SRiver Riddle
659ad64a5cSRiver Riddle /// The stream to output to.
669ad64a5cSRiver Riddle raw_ostream &os;
679ad64a5cSRiver Riddle };
689ad64a5cSRiver Riddle } // namespace
699ad64a5cSRiver Riddle
generate(const ast::Module & astModule,ModuleOp module)709ad64a5cSRiver Riddle void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
719ad64a5cSRiver Riddle SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
729ad64a5cSRiver Riddle StringSet<> nativeFunctions;
739ad64a5cSRiver Riddle
749ad64a5cSRiver Riddle // Generate code for any native functions within the module.
759ad64a5cSRiver Riddle generateConstraintAndRewrites(astModule, module, nativeFunctions);
769ad64a5cSRiver Riddle
779ad64a5cSRiver Riddle os << "namespace {\n";
789ad64a5cSRiver Riddle std::string basePatternName = "GeneratedPDLLPattern";
799ad64a5cSRiver Riddle int patternIndex = 0;
809ad64a5cSRiver Riddle for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
819ad64a5cSRiver Riddle // If the pattern has a name, use that. Otherwise, generate a unique name.
829ad64a5cSRiver Riddle if (Optional<StringRef> patternName = pattern.sym_name()) {
839ad64a5cSRiver Riddle patternNames.insert(patternName->str());
849ad64a5cSRiver Riddle } else {
859ad64a5cSRiver Riddle std::string name;
869ad64a5cSRiver Riddle do {
879ad64a5cSRiver Riddle name = (basePatternName + Twine(patternIndex++)).str();
889ad64a5cSRiver Riddle } while (!patternNames.insert(name));
899ad64a5cSRiver Riddle }
909ad64a5cSRiver Riddle
919ad64a5cSRiver Riddle generate(pattern, patternNames.back(), nativeFunctions);
929ad64a5cSRiver Riddle }
939ad64a5cSRiver Riddle os << "} // end namespace\n\n";
949ad64a5cSRiver Riddle
959ad64a5cSRiver Riddle // Emit function to add the generated matchers to the pattern list.
969ad64a5cSRiver Riddle os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
979ad64a5cSRiver Riddle "::mlir::RewritePatternSet &patterns) {\n";
989ad64a5cSRiver Riddle for (const auto &name : patternNames)
999ad64a5cSRiver Riddle os << " patterns.add<" << name << ">(patterns.getContext());\n";
1009ad64a5cSRiver Riddle os << "}\n";
1019ad64a5cSRiver Riddle }
1029ad64a5cSRiver Riddle
generate(pdl::PatternOp pattern,StringRef patternName,StringSet<> & nativeFunctions)1039ad64a5cSRiver Riddle void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
1049ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
1059ad64a5cSRiver Riddle const char *patternClassStartStr = R"(
1069ad64a5cSRiver Riddle struct {0} : ::mlir::PDLPatternModule {{
1079ad64a5cSRiver Riddle {0}(::mlir::MLIRContext *context)
1089ad64a5cSRiver Riddle : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
1099ad64a5cSRiver Riddle )";
1109ad64a5cSRiver Riddle os << llvm::formatv(patternClassStartStr, patternName);
1119ad64a5cSRiver Riddle
1129ad64a5cSRiver Riddle os << "R\"mlir(";
1139ad64a5cSRiver Riddle pattern->print(os, OpPrintingFlags().enableDebugInfo());
1149ad64a5cSRiver Riddle os << "\n )mlir\", context)) {\n";
1159ad64a5cSRiver Riddle
1169ad64a5cSRiver Riddle // Register any native functions used within the pattern.
1179ad64a5cSRiver Riddle StringSet<> registeredNativeFunctions;
1189ad64a5cSRiver Riddle auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
1199ad64a5cSRiver Riddle if (!nativeFunctions.count(fnName) ||
1209ad64a5cSRiver Riddle !registeredNativeFunctions.insert(fnName).second)
1219ad64a5cSRiver Riddle return;
1229ad64a5cSRiver Riddle os << " register" << fnType << "Function(\"" << fnName << "\", "
1239ad64a5cSRiver Riddle << fnName << "PDLFn);\n";
1249ad64a5cSRiver Riddle };
1259ad64a5cSRiver Riddle pattern.walk([&](Operation *op) {
1269ad64a5cSRiver Riddle if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
1279ad64a5cSRiver Riddle checkRegisterNativeFn(constraintOp.name(), "Constraint");
1289ad64a5cSRiver Riddle else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
1299ad64a5cSRiver Riddle checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
1309ad64a5cSRiver Riddle });
1319ad64a5cSRiver Riddle os << " }\n};\n\n";
1329ad64a5cSRiver Riddle }
1339ad64a5cSRiver Riddle
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)1349ad64a5cSRiver Riddle void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
1359ad64a5cSRiver Riddle ModuleOp module,
1369ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
1379ad64a5cSRiver Riddle // First check to see which constraints and rewrites are actually referenced
1389ad64a5cSRiver Riddle // in the module.
1399ad64a5cSRiver Riddle StringSet<> usedFns;
1409ad64a5cSRiver Riddle module.walk([&](Operation *op) {
1419ad64a5cSRiver Riddle TypeSwitch<Operation *>(op)
1429ad64a5cSRiver Riddle .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
1439ad64a5cSRiver Riddle [&](auto op) { usedFns.insert(op.name()); });
1449ad64a5cSRiver Riddle });
1459ad64a5cSRiver Riddle
1469ad64a5cSRiver Riddle for (const ast::Decl *decl : astModule.getChildren()) {
1479ad64a5cSRiver Riddle TypeSwitch<const ast::Decl *>(decl)
1489ad64a5cSRiver Riddle .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
1499ad64a5cSRiver Riddle [&](const auto *decl) {
1509ad64a5cSRiver Riddle // We only generate code for inline native decls that have been
1519ad64a5cSRiver Riddle // referenced.
1529ad64a5cSRiver Riddle if (decl->getCodeBlock() &&
1539ad64a5cSRiver Riddle usedFns.contains(decl->getName().getName()))
1549ad64a5cSRiver Riddle this->generate(decl, nativeFunctions);
1559ad64a5cSRiver Riddle });
1569ad64a5cSRiver Riddle }
1579ad64a5cSRiver Riddle }
1589ad64a5cSRiver Riddle
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)1599ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserConstraintDecl *decl,
1609ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
161*1c2edb02SRiver Riddle return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
162*1c2edb02SRiver Riddle /*isConstraint=*/true, nativeFunctions);
1639ad64a5cSRiver Riddle }
1649ad64a5cSRiver Riddle
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)1659ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserRewriteDecl *decl,
1669ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
167*1c2edb02SRiver Riddle return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
168*1c2edb02SRiver Riddle /*isConstraint=*/false, nativeFunctions);
1699ad64a5cSRiver Riddle }
1709ad64a5cSRiver Riddle
getNativeTypeName(ast::Type type)171*1c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::Type type) {
1729ad64a5cSRiver Riddle return llvm::TypeSwitch<ast::Type, StringRef>(type)
1739ad64a5cSRiver Riddle .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
174*1c2edb02SRiver Riddle .Case([&](ast::OperationType opType) -> StringRef {
175*1c2edb02SRiver Riddle // Use the derived Op class when available.
176*1c2edb02SRiver Riddle if (const auto *odsOp = opType.getODSOperation())
177*1c2edb02SRiver Riddle return odsOp->getNativeClassName();
1789ad64a5cSRiver Riddle return "::mlir::Operation *";
1799ad64a5cSRiver Riddle })
1809ad64a5cSRiver Riddle .Case([&](ast::TypeType) { return "::mlir::Type"; })
1819ad64a5cSRiver Riddle .Case([&](ast::ValueType) { return "::mlir::Value"; })
1829ad64a5cSRiver Riddle .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
1839ad64a5cSRiver Riddle .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
1849ad64a5cSRiver Riddle }
1859ad64a5cSRiver Riddle
getNativeTypeName(ast::VariableDecl * decl)186*1c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
187*1c2edb02SRiver Riddle // Try to extract a type name from the variable's constraints.
188*1c2edb02SRiver Riddle for (ast::ConstraintRef &cst : decl->getConstraints()) {
189*1c2edb02SRiver Riddle if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
190*1c2edb02SRiver Riddle if (Optional<StringRef> name = userCst->getNativeInputType(0))
191*1c2edb02SRiver Riddle return *name;
192*1c2edb02SRiver Riddle return getNativeTypeName(userCst->getInputs()[0]);
193*1c2edb02SRiver Riddle }
194*1c2edb02SRiver Riddle }
195*1c2edb02SRiver Riddle
196*1c2edb02SRiver Riddle // Otherwise, use the type of the variable.
197*1c2edb02SRiver Riddle return getNativeTypeName(decl->getType());
198*1c2edb02SRiver Riddle }
199*1c2edb02SRiver Riddle
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)200*1c2edb02SRiver Riddle void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
201*1c2edb02SRiver Riddle bool isConstraint,
202*1c2edb02SRiver Riddle StringSet<> &nativeFunctions) {
203*1c2edb02SRiver Riddle StringRef name = decl->getName()->getName();
204*1c2edb02SRiver Riddle nativeFunctions.insert(name);
205*1c2edb02SRiver Riddle
206*1c2edb02SRiver Riddle os << "static ";
207*1c2edb02SRiver Riddle
208*1c2edb02SRiver Riddle // TODO: Work out a proper modeling for "optionality".
209*1c2edb02SRiver Riddle
210*1c2edb02SRiver Riddle // Emit the result type.
211*1c2edb02SRiver Riddle // If this is a constraint, we always return a LogicalResult.
212*1c2edb02SRiver Riddle // TODO: This will need to change if we allow Constraints to return values as
213*1c2edb02SRiver Riddle // well.
214*1c2edb02SRiver Riddle if (isConstraint) {
215*1c2edb02SRiver Riddle os << "::mlir::LogicalResult";
216*1c2edb02SRiver Riddle } else {
217*1c2edb02SRiver Riddle // Otherwise, generate a type based on the results of the callable.
218*1c2edb02SRiver Riddle // If the callable has explicit results, use those to build the result.
219*1c2edb02SRiver Riddle // Otherwise, use the type of the callable.
220*1c2edb02SRiver Riddle ArrayRef<ast::VariableDecl *> results = decl->getResults();
221*1c2edb02SRiver Riddle if (results.empty()) {
222*1c2edb02SRiver Riddle os << "void";
223*1c2edb02SRiver Riddle } else if (results.size() == 1) {
224*1c2edb02SRiver Riddle os << getNativeTypeName(results[0]);
225*1c2edb02SRiver Riddle } else {
226*1c2edb02SRiver Riddle os << "std::tuple<";
227*1c2edb02SRiver Riddle llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
228*1c2edb02SRiver Riddle os << getNativeTypeName(result);
229*1c2edb02SRiver Riddle });
230*1c2edb02SRiver Riddle os << ">";
231*1c2edb02SRiver Riddle }
232*1c2edb02SRiver Riddle }
233*1c2edb02SRiver Riddle
234*1c2edb02SRiver Riddle os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
235*1c2edb02SRiver Riddle if (!decl->getInputs().empty()) {
236*1c2edb02SRiver Riddle os << ", ";
237*1c2edb02SRiver Riddle llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
238*1c2edb02SRiver Riddle os << getNativeTypeName(input) << " " << input->getName().getName();
239*1c2edb02SRiver Riddle });
240*1c2edb02SRiver Riddle }
241*1c2edb02SRiver Riddle os << ") {\n";
242*1c2edb02SRiver Riddle os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
2439ad64a5cSRiver Riddle }
2449ad64a5cSRiver Riddle
2459ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2469ad64a5cSRiver Riddle // CPPGen
2479ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2489ad64a5cSRiver Riddle
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)2499ad64a5cSRiver Riddle void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
2509ad64a5cSRiver Riddle raw_ostream &os) {
2519ad64a5cSRiver Riddle CodeGen codegen(os);
2529ad64a5cSRiver Riddle codegen.generate(astModule, module);
2539ad64a5cSRiver Riddle }
254