1 //===- NodePrinter.cpp ----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/AST/Context.h" 10 #include "mlir/Tools/PDLL/AST/Nodes.h" 11 #include "llvm/ADT/StringExtras.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 #include "llvm/Support/SaveAndRestore.h" 14 #include "llvm/Support/ScopedPrinter.h" 15 16 using namespace mlir; 17 using namespace mlir::pdll::ast; 18 19 //===----------------------------------------------------------------------===// 20 // NodePrinter 21 //===----------------------------------------------------------------------===// 22 23 namespace { 24 class NodePrinter { 25 public: 26 NodePrinter(raw_ostream &os) : os(os) {} 27 28 /// Print the given type to the stream. 29 void print(Type type); 30 31 /// Print the given node to the stream. 32 void print(const Node *node); 33 34 private: 35 /// Print a range containing children of a node. 36 template <typename RangeT, 37 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value> 38 * = nullptr> 39 void printChildren(RangeT &&range) { 40 if (llvm::empty(range)) 41 return; 42 43 // Print the first N-1 elements with a prefix of "|-". 44 auto it = std::begin(range); 45 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it) 46 print(*it); 47 48 // Print the last element. 49 elementIndentStack.back() = true; 50 print(*it); 51 } 52 template <typename RangeT, typename... OthersT, 53 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value> 54 * = nullptr> 55 void printChildren(RangeT &&range, OthersT &&...others) { 56 printChildren(ArrayRef<const Node *>({range, others...})); 57 } 58 /// Print a range containing children of a node, nesting the children under 59 /// the given label. 60 template <typename RangeT> 61 void printChildren(StringRef label, RangeT &&range) { 62 if (llvm::empty(range)) 63 return; 64 elementIndentStack.reserve(elementIndentStack.size() + 1); 65 llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true); 66 67 printIndent(); 68 os << label << "`\n"; 69 elementIndentStack.push_back(/*isLastElt*/ false); 70 printChildren(std::forward<RangeT>(range)); 71 elementIndentStack.pop_back(); 72 } 73 74 /// Print the given derived node to the stream. 75 void printImpl(const CompoundStmt *stmt); 76 void printImpl(const EraseStmt *stmt); 77 void printImpl(const LetStmt *stmt); 78 void printImpl(const ReplaceStmt *stmt); 79 void printImpl(const RewriteStmt *stmt); 80 81 void printImpl(const AttributeExpr *expr); 82 void printImpl(const DeclRefExpr *expr); 83 void printImpl(const MemberAccessExpr *expr); 84 void printImpl(const OperationExpr *expr); 85 void printImpl(const TupleExpr *expr); 86 void printImpl(const TypeExpr *expr); 87 88 void printImpl(const AttrConstraintDecl *decl); 89 void printImpl(const OpConstraintDecl *decl); 90 void printImpl(const TypeConstraintDecl *decl); 91 void printImpl(const TypeRangeConstraintDecl *decl); 92 void printImpl(const ValueConstraintDecl *decl); 93 void printImpl(const ValueRangeConstraintDecl *decl); 94 void printImpl(const NamedAttributeDecl *decl); 95 void printImpl(const OpNameDecl *decl); 96 void printImpl(const PatternDecl *decl); 97 void printImpl(const VariableDecl *decl); 98 void printImpl(const Module *module); 99 100 /// Print the current indent stack. 101 void printIndent() { 102 if (elementIndentStack.empty()) 103 return; 104 105 for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back()) 106 os << (isLastElt ? " " : " |"); 107 os << (elementIndentStack.back() ? " `" : " |"); 108 } 109 110 /// The raw output stream. 111 raw_ostream &os; 112 113 /// A stack of indents and a flag indicating if the current element being 114 /// printed at that indent is the last element. 115 SmallVector<bool> elementIndentStack; 116 }; 117 } // namespace 118 119 void NodePrinter::print(Type type) { 120 // Protect against invalid inputs. 121 if (!type) { 122 os << "Type<NULL>"; 123 return; 124 } 125 126 TypeSwitch<Type>(type) 127 .Case([&](AttributeType) { os << "Attr"; }) 128 .Case([&](ConstraintType) { os << "Constraint"; }) 129 .Case([&](OperationType type) { 130 os << "Op"; 131 if (Optional<StringRef> name = type.getName()) 132 os << "<" << *name << ">"; 133 }) 134 .Case([&](RangeType type) { 135 print(type.getElementType()); 136 os << "Range"; 137 }) 138 .Case([&](TupleType type) { 139 os << "Tuple<"; 140 llvm::interleaveComma( 141 llvm::zip(type.getElementNames(), type.getElementTypes()), os, 142 [&](auto it) { 143 if (!std::get<0>(it).empty()) 144 os << std::get<0>(it) << ": "; 145 this->print(std::get<1>(it)); 146 }); 147 os << ">"; 148 }) 149 .Case([&](TypeType) { os << "Type"; }) 150 .Case([&](ValueType) { os << "Value"; }) 151 .Default([](Type) { llvm_unreachable("unknown AST type"); }); 152 } 153 154 void NodePrinter::print(const Node *node) { 155 printIndent(); 156 os << "-"; 157 158 elementIndentStack.push_back(/*isLastElt*/ false); 159 TypeSwitch<const Node *>(node) 160 .Case< 161 // Statements. 162 const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt, 163 const RewriteStmt, 164 165 // Expressions. 166 const AttributeExpr, const DeclRefExpr, const MemberAccessExpr, 167 const OperationExpr, const TupleExpr, const TypeExpr, 168 169 // Decls. 170 const AttrConstraintDecl, const OpConstraintDecl, 171 const TypeConstraintDecl, const TypeRangeConstraintDecl, 172 const ValueConstraintDecl, const ValueRangeConstraintDecl, 173 const NamedAttributeDecl, const OpNameDecl, const PatternDecl, 174 const VariableDecl, 175 176 const Module>([&](auto derivedNode) { this->printImpl(derivedNode); }) 177 .Default([](const Node *) { llvm_unreachable("unknown AST node"); }); 178 elementIndentStack.pop_back(); 179 } 180 181 void NodePrinter::printImpl(const CompoundStmt *stmt) { 182 os << "CompoundStmt " << stmt << "\n"; 183 printChildren(stmt->getChildren()); 184 } 185 186 void NodePrinter::printImpl(const EraseStmt *stmt) { 187 os << "EraseStmt " << stmt << "\n"; 188 printChildren(stmt->getRootOpExpr()); 189 } 190 191 void NodePrinter::printImpl(const LetStmt *stmt) { 192 os << "LetStmt " << stmt << "\n"; 193 printChildren(stmt->getVarDecl()); 194 } 195 196 void NodePrinter::printImpl(const ReplaceStmt *stmt) { 197 os << "ReplaceStmt " << stmt << "\n"; 198 printChildren(stmt->getRootOpExpr()); 199 printChildren("ReplValues", stmt->getReplExprs()); 200 } 201 202 void NodePrinter::printImpl(const RewriteStmt *stmt) { 203 os << "RewriteStmt " << stmt << "\n"; 204 printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody()); 205 } 206 207 void NodePrinter::printImpl(const AttributeExpr *expr) { 208 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 209 } 210 211 void NodePrinter::printImpl(const DeclRefExpr *expr) { 212 os << "DeclRefExpr " << expr << " Type<"; 213 print(expr->getType()); 214 os << ">\n"; 215 printChildren(expr->getDecl()); 216 } 217 218 void NodePrinter::printImpl(const MemberAccessExpr *expr) { 219 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName() 220 << "> Type<"; 221 print(expr->getType()); 222 os << ">\n"; 223 printChildren(expr->getParentExpr()); 224 } 225 226 void NodePrinter::printImpl(const OperationExpr *expr) { 227 os << "OperationExpr " << expr << " Type<"; 228 print(expr->getType()); 229 os << ">\n"; 230 231 printChildren(expr->getNameDecl()); 232 printChildren("Operands", expr->getOperands()); 233 printChildren("Result Types", expr->getResultTypes()); 234 printChildren("Attributes", expr->getAttributes()); 235 } 236 237 void NodePrinter::printImpl(const TupleExpr *expr) { 238 os << "TupleExpr " << expr << " Type<"; 239 print(expr->getType()); 240 os << ">\n"; 241 242 printChildren(expr->getElements()); 243 } 244 245 void NodePrinter::printImpl(const TypeExpr *expr) { 246 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 247 } 248 249 void NodePrinter::printImpl(const AttrConstraintDecl *decl) { 250 os << "AttrConstraintDecl " << decl << "\n"; 251 if (const auto *typeExpr = decl->getTypeExpr()) 252 printChildren(typeExpr); 253 } 254 255 void NodePrinter::printImpl(const OpConstraintDecl *decl) { 256 os << "OpConstraintDecl " << decl << "\n"; 257 printChildren(decl->getNameDecl()); 258 } 259 260 void NodePrinter::printImpl(const TypeConstraintDecl *decl) { 261 os << "TypeConstraintDecl " << decl << "\n"; 262 } 263 264 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) { 265 os << "TypeRangeConstraintDecl " << decl << "\n"; 266 } 267 268 void NodePrinter::printImpl(const ValueConstraintDecl *decl) { 269 os << "ValueConstraintDecl " << decl << "\n"; 270 if (const auto *typeExpr = decl->getTypeExpr()) 271 printChildren(typeExpr); 272 } 273 274 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) { 275 os << "ValueRangeConstraintDecl " << decl << "\n"; 276 if (const auto *typeExpr = decl->getTypeExpr()) 277 printChildren(typeExpr); 278 } 279 280 void NodePrinter::printImpl(const NamedAttributeDecl *decl) { 281 os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName() 282 << ">\n"; 283 printChildren(decl->getValue()); 284 } 285 286 void NodePrinter::printImpl(const OpNameDecl *decl) { 287 os << "OpNameDecl " << decl; 288 if (Optional<StringRef> name = decl->getName()) 289 os << " Name<" << name << ">"; 290 os << "\n"; 291 } 292 293 void NodePrinter::printImpl(const PatternDecl *decl) { 294 os << "PatternDecl " << decl; 295 if (const Name *name = decl->getName()) 296 os << " Name<" << name->getName() << ">"; 297 if (Optional<uint16_t> benefit = decl->getBenefit()) 298 os << " Benefit<" << *benefit << ">"; 299 if (decl->hasBoundedRewriteRecursion()) 300 os << " Recursion"; 301 302 os << "\n"; 303 printChildren(decl->getBody()); 304 } 305 306 void NodePrinter::printImpl(const VariableDecl *decl) { 307 os << "VariableDecl " << decl << " Name<" << decl->getName().getName() 308 << "> Type<"; 309 print(decl->getType()); 310 os << ">\n"; 311 if (Expr *initExpr = decl->getInitExpr()) 312 printChildren(initExpr); 313 314 auto constraints = 315 llvm::map_range(decl->getConstraints(), 316 [](const ConstraintRef &ref) { return ref.constraint; }); 317 printChildren("Constraints", constraints); 318 } 319 320 void NodePrinter::printImpl(const Module *module) { 321 os << "Module " << module << "\n"; 322 printChildren(module->getChildren()); 323 } 324 325 //===----------------------------------------------------------------------===// 326 // Entry point 327 //===----------------------------------------------------------------------===// 328 329 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); } 330 331 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); } 332