1 //===- NodePrinter.cpp ----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/AST/Context.h" 10 #include "mlir/Tools/PDLL/AST/Nodes.h" 11 #include "llvm/ADT/StringExtras.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 #include "llvm/Support/SaveAndRestore.h" 14 #include "llvm/Support/ScopedPrinter.h" 15 16 using namespace mlir; 17 using namespace mlir::pdll::ast; 18 19 //===----------------------------------------------------------------------===// 20 // NodePrinter 21 //===----------------------------------------------------------------------===// 22 23 namespace { 24 class NodePrinter { 25 public: 26 NodePrinter(raw_ostream &os) : os(os) {} 27 28 /// Print the given type to the stream. 29 void print(Type type); 30 31 /// Print the given node to the stream. 32 void print(const Node *node); 33 34 private: 35 /// Print a range containing children of a node. 36 template <typename RangeT, 37 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value> 38 * = nullptr> 39 void printChildren(RangeT &&range) { 40 if (llvm::empty(range)) 41 return; 42 43 // Print the first N-1 elements with a prefix of "|-". 44 auto it = std::begin(range); 45 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it) 46 print(*it); 47 48 // Print the last element. 49 elementIndentStack.back() = true; 50 print(*it); 51 } 52 template <typename RangeT, typename... OthersT, 53 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value> 54 * = nullptr> 55 void printChildren(RangeT &&range, OthersT &&...others) { 56 printChildren(ArrayRef<const Node *>({range, others...})); 57 } 58 /// Print a range containing children of a node, nesting the children under 59 /// the given label. 60 template <typename RangeT> 61 void printChildren(StringRef label, RangeT &&range) { 62 if (llvm::empty(range)) 63 return; 64 elementIndentStack.reserve(elementIndentStack.size() + 1); 65 llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true); 66 67 printIndent(); 68 os << label << "`\n"; 69 elementIndentStack.push_back(/*isLastElt*/ false); 70 printChildren(std::forward<RangeT>(range)); 71 elementIndentStack.pop_back(); 72 } 73 74 /// Print the given derived node to the stream. 75 void printImpl(const CompoundStmt *stmt); 76 void printImpl(const EraseStmt *stmt); 77 void printImpl(const LetStmt *stmt); 78 void printImpl(const ReplaceStmt *stmt); 79 void printImpl(const ReturnStmt *stmt); 80 void printImpl(const RewriteStmt *stmt); 81 82 void printImpl(const AttributeExpr *expr); 83 void printImpl(const CallExpr *expr); 84 void printImpl(const DeclRefExpr *expr); 85 void printImpl(const MemberAccessExpr *expr); 86 void printImpl(const OperationExpr *expr); 87 void printImpl(const TupleExpr *expr); 88 void printImpl(const TypeExpr *expr); 89 90 void printImpl(const AttrConstraintDecl *decl); 91 void printImpl(const OpConstraintDecl *decl); 92 void printImpl(const TypeConstraintDecl *decl); 93 void printImpl(const TypeRangeConstraintDecl *decl); 94 void printImpl(const UserConstraintDecl *decl); 95 void printImpl(const ValueConstraintDecl *decl); 96 void printImpl(const ValueRangeConstraintDecl *decl); 97 void printImpl(const NamedAttributeDecl *decl); 98 void printImpl(const OpNameDecl *decl); 99 void printImpl(const PatternDecl *decl); 100 void printImpl(const UserRewriteDecl *decl); 101 void printImpl(const VariableDecl *decl); 102 void printImpl(const Module *module); 103 104 /// Print the current indent stack. 105 void printIndent() { 106 if (elementIndentStack.empty()) 107 return; 108 109 for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back()) 110 os << (isLastElt ? " " : " |"); 111 os << (elementIndentStack.back() ? " `" : " |"); 112 } 113 114 /// The raw output stream. 115 raw_ostream &os; 116 117 /// A stack of indents and a flag indicating if the current element being 118 /// printed at that indent is the last element. 119 SmallVector<bool> elementIndentStack; 120 }; 121 } // namespace 122 123 void NodePrinter::print(Type type) { 124 // Protect against invalid inputs. 125 if (!type) { 126 os << "Type<NULL>"; 127 return; 128 } 129 130 TypeSwitch<Type>(type) 131 .Case([&](AttributeType) { os << "Attr"; }) 132 .Case([&](ConstraintType) { os << "Constraint"; }) 133 .Case([&](OperationType type) { 134 os << "Op"; 135 if (Optional<StringRef> name = type.getName()) 136 os << "<" << *name << ">"; 137 }) 138 .Case([&](RangeType type) { 139 print(type.getElementType()); 140 os << "Range"; 141 }) 142 .Case([&](RewriteType) { os << "Rewrite"; }) 143 .Case([&](TupleType type) { 144 os << "Tuple<"; 145 llvm::interleaveComma( 146 llvm::zip(type.getElementNames(), type.getElementTypes()), os, 147 [&](auto it) { 148 if (!std::get<0>(it).empty()) 149 os << std::get<0>(it) << ": "; 150 this->print(std::get<1>(it)); 151 }); 152 os << ">"; 153 }) 154 .Case([&](TypeType) { os << "Type"; }) 155 .Case([&](ValueType) { os << "Value"; }) 156 .Default([](Type) { llvm_unreachable("unknown AST type"); }); 157 } 158 159 void NodePrinter::print(const Node *node) { 160 printIndent(); 161 os << "-"; 162 163 elementIndentStack.push_back(/*isLastElt*/ false); 164 TypeSwitch<const Node *>(node) 165 .Case< 166 // Statements. 167 const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt, 168 const ReturnStmt, const RewriteStmt, 169 170 // Expressions. 171 const AttributeExpr, const CallExpr, const DeclRefExpr, 172 const MemberAccessExpr, const OperationExpr, const TupleExpr, 173 const TypeExpr, 174 175 // Decls. 176 const AttrConstraintDecl, const OpConstraintDecl, 177 const TypeConstraintDecl, const TypeRangeConstraintDecl, 178 const UserConstraintDecl, const ValueConstraintDecl, 179 const ValueRangeConstraintDecl, const NamedAttributeDecl, 180 const OpNameDecl, const PatternDecl, const UserRewriteDecl, 181 const VariableDecl, 182 183 const Module>([&](auto derivedNode) { this->printImpl(derivedNode); }) 184 .Default([](const Node *) { llvm_unreachable("unknown AST node"); }); 185 elementIndentStack.pop_back(); 186 } 187 188 void NodePrinter::printImpl(const CompoundStmt *stmt) { 189 os << "CompoundStmt " << stmt << "\n"; 190 printChildren(stmt->getChildren()); 191 } 192 193 void NodePrinter::printImpl(const EraseStmt *stmt) { 194 os << "EraseStmt " << stmt << "\n"; 195 printChildren(stmt->getRootOpExpr()); 196 } 197 198 void NodePrinter::printImpl(const LetStmt *stmt) { 199 os << "LetStmt " << stmt << "\n"; 200 printChildren(stmt->getVarDecl()); 201 } 202 203 void NodePrinter::printImpl(const ReplaceStmt *stmt) { 204 os << "ReplaceStmt " << stmt << "\n"; 205 printChildren(stmt->getRootOpExpr()); 206 printChildren("ReplValues", stmt->getReplExprs()); 207 } 208 209 void NodePrinter::printImpl(const ReturnStmt *stmt) { 210 os << "ReturnStmt " << stmt << "\n"; 211 printChildren(stmt->getResultExpr()); 212 } 213 214 void NodePrinter::printImpl(const RewriteStmt *stmt) { 215 os << "RewriteStmt " << stmt << "\n"; 216 printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody()); 217 } 218 219 void NodePrinter::printImpl(const AttributeExpr *expr) { 220 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 221 } 222 223 void NodePrinter::printImpl(const CallExpr *expr) { 224 os << "CallExpr " << expr << " Type<"; 225 print(expr->getType()); 226 os << ">\n"; 227 printChildren(expr->getCallableExpr()); 228 printChildren("Arguments", expr->getArguments()); 229 } 230 231 void NodePrinter::printImpl(const DeclRefExpr *expr) { 232 os << "DeclRefExpr " << expr << " Type<"; 233 print(expr->getType()); 234 os << ">\n"; 235 printChildren(expr->getDecl()); 236 } 237 238 void NodePrinter::printImpl(const MemberAccessExpr *expr) { 239 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName() 240 << "> Type<"; 241 print(expr->getType()); 242 os << ">\n"; 243 printChildren(expr->getParentExpr()); 244 } 245 246 void NodePrinter::printImpl(const OperationExpr *expr) { 247 os << "OperationExpr " << expr << " Type<"; 248 print(expr->getType()); 249 os << ">\n"; 250 251 printChildren(expr->getNameDecl()); 252 printChildren("Operands", expr->getOperands()); 253 printChildren("Result Types", expr->getResultTypes()); 254 printChildren("Attributes", expr->getAttributes()); 255 } 256 257 void NodePrinter::printImpl(const TupleExpr *expr) { 258 os << "TupleExpr " << expr << " Type<"; 259 print(expr->getType()); 260 os << ">\n"; 261 262 printChildren(expr->getElements()); 263 } 264 265 void NodePrinter::printImpl(const TypeExpr *expr) { 266 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 267 } 268 269 void NodePrinter::printImpl(const AttrConstraintDecl *decl) { 270 os << "AttrConstraintDecl " << decl << "\n"; 271 if (const auto *typeExpr = decl->getTypeExpr()) 272 printChildren(typeExpr); 273 } 274 275 void NodePrinter::printImpl(const OpConstraintDecl *decl) { 276 os << "OpConstraintDecl " << decl << "\n"; 277 printChildren(decl->getNameDecl()); 278 } 279 280 void NodePrinter::printImpl(const TypeConstraintDecl *decl) { 281 os << "TypeConstraintDecl " << decl << "\n"; 282 } 283 284 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) { 285 os << "TypeRangeConstraintDecl " << decl << "\n"; 286 } 287 288 void NodePrinter::printImpl(const UserConstraintDecl *decl) { 289 os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName() 290 << "> ResultType<" << decl->getResultType() << ">"; 291 if (Optional<StringRef> codeBlock = decl->getCodeBlock()) { 292 os << " Code<"; 293 llvm::printEscapedString(*codeBlock, os); 294 os << ">"; 295 } 296 os << "\n"; 297 printChildren("Inputs", decl->getInputs()); 298 printChildren("Results", decl->getResults()); 299 if (const CompoundStmt *body = decl->getBody()) 300 printChildren(body); 301 } 302 303 void NodePrinter::printImpl(const ValueConstraintDecl *decl) { 304 os << "ValueConstraintDecl " << decl << "\n"; 305 if (const auto *typeExpr = decl->getTypeExpr()) 306 printChildren(typeExpr); 307 } 308 309 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) { 310 os << "ValueRangeConstraintDecl " << decl << "\n"; 311 if (const auto *typeExpr = decl->getTypeExpr()) 312 printChildren(typeExpr); 313 } 314 315 void NodePrinter::printImpl(const NamedAttributeDecl *decl) { 316 os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName() 317 << ">\n"; 318 printChildren(decl->getValue()); 319 } 320 321 void NodePrinter::printImpl(const OpNameDecl *decl) { 322 os << "OpNameDecl " << decl; 323 if (Optional<StringRef> name = decl->getName()) 324 os << " Name<" << name << ">"; 325 os << "\n"; 326 } 327 328 void NodePrinter::printImpl(const PatternDecl *decl) { 329 os << "PatternDecl " << decl; 330 if (const Name *name = decl->getName()) 331 os << " Name<" << name->getName() << ">"; 332 if (Optional<uint16_t> benefit = decl->getBenefit()) 333 os << " Benefit<" << *benefit << ">"; 334 if (decl->hasBoundedRewriteRecursion()) 335 os << " Recursion"; 336 337 os << "\n"; 338 printChildren(decl->getBody()); 339 } 340 341 void NodePrinter::printImpl(const UserRewriteDecl *decl) { 342 os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName() 343 << "> ResultType<" << decl->getResultType() << ">"; 344 if (Optional<StringRef> codeBlock = decl->getCodeBlock()) { 345 os << " Code<"; 346 llvm::printEscapedString(*codeBlock, os); 347 os << ">"; 348 } 349 os << "\n"; 350 printChildren("Inputs", decl->getInputs()); 351 printChildren("Results", decl->getResults()); 352 if (const CompoundStmt *body = decl->getBody()) 353 printChildren(body); 354 } 355 356 void NodePrinter::printImpl(const VariableDecl *decl) { 357 os << "VariableDecl " << decl << " Name<" << decl->getName().getName() 358 << "> Type<"; 359 print(decl->getType()); 360 os << ">\n"; 361 if (Expr *initExpr = decl->getInitExpr()) 362 printChildren(initExpr); 363 364 auto constraints = 365 llvm::map_range(decl->getConstraints(), 366 [](const ConstraintRef &ref) { return ref.constraint; }); 367 printChildren("Constraints", constraints); 368 } 369 370 void NodePrinter::printImpl(const Module *module) { 371 os << "Module " << module << "\n"; 372 printChildren(module->getChildren()); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // Entry point 377 //===----------------------------------------------------------------------===// 378 379 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); } 380 381 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); } 382