//===- CPPGen.cpp ---------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This files contains a PDLL generator that outputs C++ code that defines PDLL // patterns as individual C++ PDLPatternModules for direct use in native code, // and also defines any native constraints whose bodies were defined in PDLL. // //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/CodeGen/CPPGen.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Tools/PDLL/AST/Nodes.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::pdll; //===----------------------------------------------------------------------===// // CodeGen //===----------------------------------------------------------------------===// namespace { class CodeGen { public: CodeGen(raw_ostream &os) : os(os) {} /// Generate C++ code for the given PDL pattern module. void generate(const ast::Module &astModule, ModuleOp module); private: void generate(pdl::PatternOp pattern, StringRef patternName, StringSet<> &nativeFunctions); /// Generate C++ code for all user defined constraints and rewrites with /// native code. void generateConstraintAndRewrites(const ast::Module &astModule, ModuleOp module, StringSet<> &nativeFunctions); void generate(const ast::UserConstraintDecl *decl, StringSet<> &nativeFunctions); void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions); void generateConstraintOrRewrite(StringRef name, bool isConstraint, ArrayRef inputs, StringRef codeBlock, StringSet<> &nativeFunctions); /// The stream to output to. raw_ostream &os; }; } // namespace void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { SetVector, StringSet<>> patternNames; StringSet<> nativeFunctions; // Generate code for any native functions within the module. generateConstraintAndRewrites(astModule, module, nativeFunctions); os << "namespace {\n"; std::string basePatternName = "GeneratedPDLLPattern"; int patternIndex = 0; for (pdl::PatternOp pattern : module.getOps()) { // If the pattern has a name, use that. Otherwise, generate a unique name. if (Optional patternName = pattern.sym_name()) { patternNames.insert(patternName->str()); } else { std::string name; do { name = (basePatternName + Twine(patternIndex++)).str(); } while (!patternNames.insert(name)); } generate(pattern, patternNames.back(), nativeFunctions); } os << "} // end namespace\n\n"; // Emit function to add the generated matchers to the pattern list. os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" "::mlir::RewritePatternSet &patterns) {\n"; for (const auto &name : patternNames) os << " patterns.add<" << name << ">(patterns.getContext());\n"; os << "}\n"; } void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName, StringSet<> &nativeFunctions) { const char *patternClassStartStr = R"( struct {0} : ::mlir::PDLPatternModule {{ {0}(::mlir::MLIRContext *context) : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( )"; os << llvm::formatv(patternClassStartStr, patternName); os << "R\"mlir("; pattern->print(os, OpPrintingFlags().enableDebugInfo()); os << "\n )mlir\", context)) {\n"; // Register any native functions used within the pattern. StringSet<> registeredNativeFunctions; auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) { if (!nativeFunctions.count(fnName) || !registeredNativeFunctions.insert(fnName).second) return; os << " register" << fnType << "Function(\"" << fnName << "\", " << fnName << "PDLFn);\n"; }; pattern.walk([&](Operation *op) { if (auto constraintOp = dyn_cast(op)) checkRegisterNativeFn(constraintOp.name(), "Constraint"); else if (auto rewriteOp = dyn_cast(op)) checkRegisterNativeFn(rewriteOp.name(), "Rewrite"); }); os << " }\n};\n\n"; } void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule, ModuleOp module, StringSet<> &nativeFunctions) { // First check to see which constraints and rewrites are actually referenced // in the module. StringSet<> usedFns; module.walk([&](Operation *op) { TypeSwitch(op) .Case( [&](auto op) { usedFns.insert(op.name()); }); }); for (const ast::Decl *decl : astModule.getChildren()) { TypeSwitch(decl) .Case( [&](const auto *decl) { // We only generate code for inline native decls that have been // referenced. if (decl->getCodeBlock() && usedFns.contains(decl->getName().getName())) this->generate(decl, nativeFunctions); }); } } void CodeGen::generate(const ast::UserConstraintDecl *decl, StringSet<> &nativeFunctions) { return generateConstraintOrRewrite(decl->getName().getName(), /*isConstraint=*/true, decl->getInputs(), *decl->getCodeBlock(), nativeFunctions); } void CodeGen::generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions) { return generateConstraintOrRewrite(decl->getName().getName(), /*isConstraint=*/false, decl->getInputs(), *decl->getCodeBlock(), nativeFunctions); } void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint, ArrayRef inputs, StringRef codeBlock, StringSet<> &nativeFunctions) { nativeFunctions.insert(name); // TODO: Should there be something explicit for handling optionality? auto getCppType = [&](ast::Type type) -> StringRef { return llvm::TypeSwitch(type) .Case([&](ast::AttributeType) { return "::mlir::Attribute"; }) .Case([&](ast::OperationType) { // TODO: Allow using the derived Op class when possible. return "::mlir::Operation *"; }) .Case([&](ast::TypeType) { return "::mlir::Type"; }) .Case([&](ast::ValueType) { return "::mlir::Value"; }) .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; }) .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; }); }; os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name << "PDLFn(::mlir::PatternRewriter &rewriter, " << (isConstraint ? "" : "::mlir::PDLResultList &results, ") << "::llvm::ArrayRef<::mlir::PDLValue> values) {\n"; const char *argumentInitStr = R"( {0} {1} = {{}; if (values[{2}]) {1} = values[{2}].cast<{0}>(); (void){1}; )"; for (const auto &it : llvm::enumerate(inputs)) { const ast::VariableDecl *input = it.value(); os << llvm::formatv(argumentInitStr, getCppType(input->getType()), input->getName().getName(), it.index()); } os << " " << codeBlock.trim() << "\n}\n"; } //===----------------------------------------------------------------------===// // CPPGen //===----------------------------------------------------------------------===// void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, raw_ostream &os) { CodeGen codegen(os); codegen.generate(astModule, module); }