1 //===- CPPGen.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This files contains a PDLL generator that outputs C++ code that defines PDLL 10 // patterns as individual C++ PDLPatternModules for direct use in native code, 11 // and also defines any native constraints whose bodies were defined in PDLL. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/PDL/IR/PDLOps.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringSet.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Support/ErrorHandling.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 using namespace mlir::pdll; 29 30 //===----------------------------------------------------------------------===// 31 // CodeGen 32 //===----------------------------------------------------------------------===// 33 34 namespace { 35 class CodeGen { 36 public: 37 CodeGen(raw_ostream &os) : os(os) {} 38 39 /// Generate C++ code for the given PDL pattern module. 40 void generate(const ast::Module &astModule, ModuleOp module); 41 42 private: 43 void generate(pdl::PatternOp pattern, StringRef patternName, 44 StringSet<> &nativeFunctions); 45 46 /// Generate C++ code for all user defined constraints and rewrites with 47 /// native code. 48 void generateConstraintAndRewrites(const ast::Module &astModule, 49 ModuleOp module, 50 StringSet<> &nativeFunctions); 51 void generate(const ast::UserConstraintDecl *decl, 52 StringSet<> &nativeFunctions); 53 void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions); 54 void generateConstraintOrRewrite(StringRef name, bool isConstraint, 55 ArrayRef<ast::VariableDecl *> inputs, 56 StringRef codeBlock, 57 StringSet<> &nativeFunctions); 58 59 /// The stream to output to. 60 raw_ostream &os; 61 }; 62 } // namespace 63 64 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { 65 SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames; 66 StringSet<> nativeFunctions; 67 68 // Generate code for any native functions within the module. 69 generateConstraintAndRewrites(astModule, module, nativeFunctions); 70 71 os << "namespace {\n"; 72 std::string basePatternName = "GeneratedPDLLPattern"; 73 int patternIndex = 0; 74 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 75 // If the pattern has a name, use that. Otherwise, generate a unique name. 76 if (Optional<StringRef> patternName = pattern.sym_name()) { 77 patternNames.insert(patternName->str()); 78 } else { 79 std::string name; 80 do { 81 name = (basePatternName + Twine(patternIndex++)).str(); 82 } while (!patternNames.insert(name)); 83 } 84 85 generate(pattern, patternNames.back(), nativeFunctions); 86 } 87 os << "} // end namespace\n\n"; 88 89 // Emit function to add the generated matchers to the pattern list. 90 os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" 91 "::mlir::RewritePatternSet &patterns) {\n"; 92 for (const auto &name : patternNames) 93 os << " patterns.add<" << name << ">(patterns.getContext());\n"; 94 os << "}\n"; 95 } 96 97 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName, 98 StringSet<> &nativeFunctions) { 99 const char *patternClassStartStr = R"( 100 struct {0} : ::mlir::PDLPatternModule {{ 101 {0}(::mlir::MLIRContext *context) 102 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( 103 )"; 104 os << llvm::formatv(patternClassStartStr, patternName); 105 106 os << "R\"mlir("; 107 pattern->print(os, OpPrintingFlags().enableDebugInfo()); 108 os << "\n )mlir\", context)) {\n"; 109 110 // Register any native functions used within the pattern. 111 StringSet<> registeredNativeFunctions; 112 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) { 113 if (!nativeFunctions.count(fnName) || 114 !registeredNativeFunctions.insert(fnName).second) 115 return; 116 os << " register" << fnType << "Function(\"" << fnName << "\", " 117 << fnName << "PDLFn);\n"; 118 }; 119 pattern.walk([&](Operation *op) { 120 if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op)) 121 checkRegisterNativeFn(constraintOp.name(), "Constraint"); 122 else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op)) 123 checkRegisterNativeFn(rewriteOp.name(), "Rewrite"); 124 }); 125 os << " }\n};\n\n"; 126 } 127 128 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule, 129 ModuleOp module, 130 StringSet<> &nativeFunctions) { 131 // First check to see which constraints and rewrites are actually referenced 132 // in the module. 133 StringSet<> usedFns; 134 module.walk([&](Operation *op) { 135 TypeSwitch<Operation *>(op) 136 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>( 137 [&](auto op) { usedFns.insert(op.name()); }); 138 }); 139 140 for (const ast::Decl *decl : astModule.getChildren()) { 141 TypeSwitch<const ast::Decl *>(decl) 142 .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>( 143 [&](const auto *decl) { 144 // We only generate code for inline native decls that have been 145 // referenced. 146 if (decl->getCodeBlock() && 147 usedFns.contains(decl->getName().getName())) 148 this->generate(decl, nativeFunctions); 149 }); 150 } 151 } 152 153 void CodeGen::generate(const ast::UserConstraintDecl *decl, 154 StringSet<> &nativeFunctions) { 155 return generateConstraintOrRewrite(decl->getName().getName(), 156 /*isConstraint=*/true, decl->getInputs(), 157 *decl->getCodeBlock(), nativeFunctions); 158 } 159 160 void CodeGen::generate(const ast::UserRewriteDecl *decl, 161 StringSet<> &nativeFunctions) { 162 return generateConstraintOrRewrite(decl->getName().getName(), 163 /*isConstraint=*/false, decl->getInputs(), 164 *decl->getCodeBlock(), nativeFunctions); 165 } 166 167 void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint, 168 ArrayRef<ast::VariableDecl *> inputs, 169 StringRef codeBlock, 170 StringSet<> &nativeFunctions) { 171 nativeFunctions.insert(name); 172 173 // TODO: Should there be something explicit for handling optionality? 174 auto getCppType = [&](ast::Type type) -> StringRef { 175 return llvm::TypeSwitch<ast::Type, StringRef>(type) 176 .Case([&](ast::AttributeType) { return "::mlir::Attribute"; }) 177 .Case([&](ast::OperationType) { 178 // TODO: Allow using the derived Op class when possible. 179 return "::mlir::Operation *"; 180 }) 181 .Case([&](ast::TypeType) { return "::mlir::Type"; }) 182 .Case([&](ast::ValueType) { return "::mlir::Value"; }) 183 .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; }) 184 .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; }); 185 }; 186 187 // FIXME: We currently do not have a modeling for the "constant params" 188 // support PDL provides. We should either figure out a modeling for this, or 189 // refactor the support within PDL to be something a bit more reasonable for 190 // what we need as a frontend. 191 os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name 192 << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, " 193 "::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter" 194 << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n"; 195 196 const char *argumentInitStr = R"( 197 {0} {1} = {{}; 198 if (values[{2}]) 199 {1} = values[{2}].cast<{0}>(); 200 (void){1}; 201 )"; 202 for (const auto &it : llvm::enumerate(inputs)) { 203 const ast::VariableDecl *input = it.value(); 204 os << llvm::formatv(argumentInitStr, getCppType(input->getType()), 205 input->getName().getName(), it.index()); 206 } 207 208 os << " " << codeBlock.trim() << "\n}\n"; 209 } 210 211 //===----------------------------------------------------------------------===// 212 // CPPGen 213 //===----------------------------------------------------------------------===// 214 215 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, 216 raw_ostream &os) { 217 CodeGen codegen(os); 218 codegen.generate(astModule, module); 219 } 220