1 //===- CPPGen.cpp ---------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This files contains a PDLL generator that outputs C++ code that defines PDLL 10 // patterns as individual C++ PDLPatternModules for direct use in native code, 11 // and also defines any native constraints whose bodies were defined in PDLL. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/PDL/IR/PDLOps.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/Tools/PDLL/AST/Nodes.h" 20 #include "mlir/Tools/PDLL/ODS/Operation.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringSet.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/Support/ErrorHandling.h" 26 #include "llvm/Support/FormatVariadic.h" 27 28 using namespace mlir; 29 using namespace mlir::pdll; 30 31 //===----------------------------------------------------------------------===// 32 // CodeGen 33 //===----------------------------------------------------------------------===// 34 35 namespace { 36 class CodeGen { 37 public: 38 CodeGen(raw_ostream &os) : os(os) {} 39 40 /// Generate C++ code for the given PDL pattern module. 41 void generate(const ast::Module &astModule, ModuleOp module); 42 43 private: 44 void generate(pdl::PatternOp pattern, StringRef patternName, 45 StringSet<> &nativeFunctions); 46 47 /// Generate C++ code for all user defined constraints and rewrites with 48 /// native code. 49 void generateConstraintAndRewrites(const ast::Module &astModule, 50 ModuleOp module, 51 StringSet<> &nativeFunctions); 52 void generate(const ast::UserConstraintDecl *decl, 53 StringSet<> &nativeFunctions); 54 void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions); 55 void generateConstraintOrRewrite(const ast::CallableDecl *decl, 56 bool isConstraint, 57 StringSet<> &nativeFunctions); 58 59 /// Return the native name for the type of the given type. 60 StringRef getNativeTypeName(ast::Type type); 61 62 /// Return the native name for the type of the given variable decl. 63 StringRef getNativeTypeName(ast::VariableDecl *decl); 64 65 /// The stream to output to. 66 raw_ostream &os; 67 }; 68 } // namespace 69 70 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { 71 SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames; 72 StringSet<> nativeFunctions; 73 74 // Generate code for any native functions within the module. 75 generateConstraintAndRewrites(astModule, module, nativeFunctions); 76 77 os << "namespace {\n"; 78 std::string basePatternName = "GeneratedPDLLPattern"; 79 int patternIndex = 0; 80 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 81 // If the pattern has a name, use that. Otherwise, generate a unique name. 82 if (Optional<StringRef> patternName = pattern.sym_name()) { 83 patternNames.insert(patternName->str()); 84 } else { 85 std::string name; 86 do { 87 name = (basePatternName + Twine(patternIndex++)).str(); 88 } while (!patternNames.insert(name)); 89 } 90 91 generate(pattern, patternNames.back(), nativeFunctions); 92 } 93 os << "} // end namespace\n\n"; 94 95 // Emit function to add the generated matchers to the pattern list. 96 os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" 97 "::mlir::RewritePatternSet &patterns) {\n"; 98 for (const auto &name : patternNames) 99 os << " patterns.add<" << name << ">(patterns.getContext());\n"; 100 os << "}\n"; 101 } 102 103 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName, 104 StringSet<> &nativeFunctions) { 105 const char *patternClassStartStr = R"( 106 struct {0} : ::mlir::PDLPatternModule {{ 107 {0}(::mlir::MLIRContext *context) 108 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( 109 )"; 110 os << llvm::formatv(patternClassStartStr, patternName); 111 112 os << "R\"mlir("; 113 pattern->print(os, OpPrintingFlags().enableDebugInfo()); 114 os << "\n )mlir\", context)) {\n"; 115 116 // Register any native functions used within the pattern. 117 StringSet<> registeredNativeFunctions; 118 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) { 119 if (!nativeFunctions.count(fnName) || 120 !registeredNativeFunctions.insert(fnName).second) 121 return; 122 os << " register" << fnType << "Function(\"" << fnName << "\", " 123 << fnName << "PDLFn);\n"; 124 }; 125 pattern.walk([&](Operation *op) { 126 if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op)) 127 checkRegisterNativeFn(constraintOp.name(), "Constraint"); 128 else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op)) 129 checkRegisterNativeFn(rewriteOp.name(), "Rewrite"); 130 }); 131 os << " }\n};\n\n"; 132 } 133 134 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule, 135 ModuleOp module, 136 StringSet<> &nativeFunctions) { 137 // First check to see which constraints and rewrites are actually referenced 138 // in the module. 139 StringSet<> usedFns; 140 module.walk([&](Operation *op) { 141 TypeSwitch<Operation *>(op) 142 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>( 143 [&](auto op) { usedFns.insert(op.name()); }); 144 }); 145 146 for (const ast::Decl *decl : astModule.getChildren()) { 147 TypeSwitch<const ast::Decl *>(decl) 148 .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>( 149 [&](const auto *decl) { 150 // We only generate code for inline native decls that have been 151 // referenced. 152 if (decl->getCodeBlock() && 153 usedFns.contains(decl->getName().getName())) 154 this->generate(decl, nativeFunctions); 155 }); 156 } 157 } 158 159 void CodeGen::generate(const ast::UserConstraintDecl *decl, 160 StringSet<> &nativeFunctions) { 161 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl), 162 /*isConstraint=*/true, nativeFunctions); 163 } 164 165 void CodeGen::generate(const ast::UserRewriteDecl *decl, 166 StringSet<> &nativeFunctions) { 167 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl), 168 /*isConstraint=*/false, nativeFunctions); 169 } 170 171 StringRef CodeGen::getNativeTypeName(ast::Type type) { 172 return llvm::TypeSwitch<ast::Type, StringRef>(type) 173 .Case([&](ast::AttributeType) { return "::mlir::Attribute"; }) 174 .Case([&](ast::OperationType opType) -> StringRef { 175 // Use the derived Op class when available. 176 if (const auto *odsOp = opType.getODSOperation()) 177 return odsOp->getNativeClassName(); 178 return "::mlir::Operation *"; 179 }) 180 .Case([&](ast::TypeType) { return "::mlir::Type"; }) 181 .Case([&](ast::ValueType) { return "::mlir::Value"; }) 182 .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; }) 183 .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; }); 184 } 185 186 StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) { 187 // Try to extract a type name from the variable's constraints. 188 for (ast::ConstraintRef &cst : decl->getConstraints()) { 189 if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) { 190 if (Optional<StringRef> name = userCst->getNativeInputType(0)) 191 return *name; 192 return getNativeTypeName(userCst->getInputs()[0]); 193 } 194 } 195 196 // Otherwise, use the type of the variable. 197 return getNativeTypeName(decl->getType()); 198 } 199 200 void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl, 201 bool isConstraint, 202 StringSet<> &nativeFunctions) { 203 StringRef name = decl->getName()->getName(); 204 nativeFunctions.insert(name); 205 206 os << "static "; 207 208 // TODO: Work out a proper modeling for "optionality". 209 210 // Emit the result type. 211 // If this is a constraint, we always return a LogicalResult. 212 // TODO: This will need to change if we allow Constraints to return values as 213 // well. 214 if (isConstraint) { 215 os << "::mlir::LogicalResult"; 216 } else { 217 // Otherwise, generate a type based on the results of the callable. 218 // If the callable has explicit results, use those to build the result. 219 // Otherwise, use the type of the callable. 220 ArrayRef<ast::VariableDecl *> results = decl->getResults(); 221 if (results.empty()) { 222 os << "void"; 223 } else if (results.size() == 1) { 224 os << getNativeTypeName(results[0]); 225 } else { 226 os << "std::tuple<"; 227 llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) { 228 os << getNativeTypeName(result); 229 }); 230 os << ">"; 231 } 232 } 233 234 os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter"; 235 if (!decl->getInputs().empty()) { 236 os << ", "; 237 llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) { 238 os << getNativeTypeName(input) << " " << input->getName().getName(); 239 }); 240 } 241 os << ") {\n"; 242 os << " " << decl->getCodeBlock()->trim() << "\n}\n\n"; 243 } 244 245 //===----------------------------------------------------------------------===// 246 // CPPGen 247 //===----------------------------------------------------------------------===// 248 249 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, 250 raw_ostream &os) { 251 CodeGen codegen(os); 252 codegen.generate(astModule, module); 253 } 254