1 //===- NodePrinter.cpp ----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/AST/Context.h" 10 #include "mlir/Tools/PDLL/AST/Nodes.h" 11 #include "llvm/ADT/StringExtras.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 #include "llvm/Support/SaveAndRestore.h" 14 #include "llvm/Support/ScopedPrinter.h" 15 16 using namespace mlir; 17 using namespace mlir::pdll::ast; 18 19 //===----------------------------------------------------------------------===// 20 // NodePrinter 21 //===----------------------------------------------------------------------===// 22 23 namespace { 24 class NodePrinter { 25 public: 26 NodePrinter(raw_ostream &os) : os(os) {} 27 28 /// Print the given type to the stream. 29 void print(Type type); 30 31 /// Print the given node to the stream. 32 void print(const Node *node); 33 34 private: 35 /// Print a range containing children of a node. 36 template <typename RangeT, 37 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value> 38 * = nullptr> 39 void printChildren(RangeT &&range) { 40 if (llvm::empty(range)) 41 return; 42 43 // Print the first N-1 elements with a prefix of "|-". 44 auto it = std::begin(range); 45 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it) 46 print(*it); 47 48 // Print the last element. 49 elementIndentStack.back() = true; 50 print(*it); 51 } 52 template <typename RangeT, typename... OthersT, 53 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value> 54 * = nullptr> 55 void printChildren(RangeT &&range, OthersT &&...others) { 56 printChildren(ArrayRef<const Node *>({range, others...})); 57 } 58 /// Print a range containing children of a node, nesting the children under 59 /// the given label. 60 template <typename RangeT> 61 void printChildren(StringRef label, RangeT &&range) { 62 if (llvm::empty(range)) 63 return; 64 elementIndentStack.reserve(elementIndentStack.size() + 1); 65 llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true); 66 67 printIndent(); 68 os << label << "`\n"; 69 elementIndentStack.push_back(/*isLastElt*/ false); 70 printChildren(std::forward<RangeT>(range)); 71 elementIndentStack.pop_back(); 72 } 73 74 /// Print the given derived node to the stream. 75 void printImpl(const CompoundStmt *stmt); 76 void printImpl(const EraseStmt *stmt); 77 void printImpl(const LetStmt *stmt); 78 void printImpl(const ReplaceStmt *stmt); 79 80 void printImpl(const AttributeExpr *expr); 81 void printImpl(const DeclRefExpr *expr); 82 void printImpl(const MemberAccessExpr *expr); 83 void printImpl(const OperationExpr *expr); 84 void printImpl(const TupleExpr *expr); 85 void printImpl(const TypeExpr *expr); 86 87 void printImpl(const AttrConstraintDecl *decl); 88 void printImpl(const OpConstraintDecl *decl); 89 void printImpl(const TypeConstraintDecl *decl); 90 void printImpl(const TypeRangeConstraintDecl *decl); 91 void printImpl(const ValueConstraintDecl *decl); 92 void printImpl(const ValueRangeConstraintDecl *decl); 93 void printImpl(const NamedAttributeDecl *decl); 94 void printImpl(const OpNameDecl *decl); 95 void printImpl(const PatternDecl *decl); 96 void printImpl(const VariableDecl *decl); 97 void printImpl(const Module *module); 98 99 /// Print the current indent stack. 100 void printIndent() { 101 if (elementIndentStack.empty()) 102 return; 103 104 for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back()) 105 os << (isLastElt ? " " : " |"); 106 os << (elementIndentStack.back() ? " `" : " |"); 107 } 108 109 /// The raw output stream. 110 raw_ostream &os; 111 112 /// A stack of indents and a flag indicating if the current element being 113 /// printed at that indent is the last element. 114 SmallVector<bool> elementIndentStack; 115 }; 116 } // namespace 117 118 void NodePrinter::print(Type type) { 119 // Protect against invalid inputs. 120 if (!type) { 121 os << "Type<NULL>"; 122 return; 123 } 124 125 TypeSwitch<Type>(type) 126 .Case([&](AttributeType) { os << "Attr"; }) 127 .Case([&](ConstraintType) { os << "Constraint"; }) 128 .Case([&](OperationType type) { 129 os << "Op"; 130 if (Optional<StringRef> name = type.getName()) 131 os << "<" << *name << ">"; 132 }) 133 .Case([&](RangeType type) { 134 print(type.getElementType()); 135 os << "Range"; 136 }) 137 .Case([&](TupleType type) { 138 os << "Tuple<"; 139 llvm::interleaveComma( 140 llvm::zip(type.getElementNames(), type.getElementTypes()), os, 141 [&](auto it) { 142 if (!std::get<0>(it).empty()) 143 os << std::get<0>(it) << ": "; 144 this->print(std::get<1>(it)); 145 }); 146 os << ">"; 147 }) 148 .Case([&](TypeType) { os << "Type"; }) 149 .Case([&](ValueType) { os << "Value"; }) 150 .Default([](Type) { llvm_unreachable("unknown AST type"); }); 151 } 152 153 void NodePrinter::print(const Node *node) { 154 printIndent(); 155 os << "-"; 156 157 elementIndentStack.push_back(/*isLastElt*/ false); 158 TypeSwitch<const Node *>(node) 159 .Case< 160 // Statements. 161 const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt, 162 163 // Expressions. 164 const AttributeExpr, const DeclRefExpr, const MemberAccessExpr, 165 const OperationExpr, const TupleExpr, const TypeExpr, 166 167 // Decls. 168 const AttrConstraintDecl, const OpConstraintDecl, 169 const TypeConstraintDecl, const TypeRangeConstraintDecl, 170 const ValueConstraintDecl, const ValueRangeConstraintDecl, 171 const NamedAttributeDecl, const OpNameDecl, const PatternDecl, 172 const VariableDecl, 173 174 const Module>([&](auto derivedNode) { this->printImpl(derivedNode); }) 175 .Default([](const Node *) { llvm_unreachable("unknown AST node"); }); 176 elementIndentStack.pop_back(); 177 } 178 179 void NodePrinter::printImpl(const CompoundStmt *stmt) { 180 os << "CompoundStmt " << stmt << "\n"; 181 printChildren(stmt->getChildren()); 182 } 183 184 void NodePrinter::printImpl(const EraseStmt *stmt) { 185 os << "EraseStmt " << stmt << "\n"; 186 printChildren(stmt->getRootOpExpr()); 187 } 188 189 void NodePrinter::printImpl(const LetStmt *stmt) { 190 os << "LetStmt " << stmt << "\n"; 191 printChildren(stmt->getVarDecl()); 192 } 193 194 void NodePrinter::printImpl(const ReplaceStmt *stmt) { 195 os << "ReplaceStmt " << stmt << "\n"; 196 printChildren(stmt->getRootOpExpr()); 197 printChildren("ReplValues", stmt->getReplExprs()); 198 } 199 200 void NodePrinter::printImpl(const AttributeExpr *expr) { 201 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 202 } 203 204 void NodePrinter::printImpl(const DeclRefExpr *expr) { 205 os << "DeclRefExpr " << expr << " Type<"; 206 print(expr->getType()); 207 os << ">\n"; 208 printChildren(expr->getDecl()); 209 } 210 211 void NodePrinter::printImpl(const MemberAccessExpr *expr) { 212 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName() 213 << "> Type<"; 214 print(expr->getType()); 215 os << ">\n"; 216 printChildren(expr->getParentExpr()); 217 } 218 219 void NodePrinter::printImpl(const OperationExpr *expr) { 220 os << "OperationExpr " << expr << " Type<"; 221 print(expr->getType()); 222 os << ">\n"; 223 224 printChildren(expr->getNameDecl()); 225 printChildren("Operands", expr->getOperands()); 226 printChildren("Result Types", expr->getResultTypes()); 227 printChildren("Attributes", expr->getAttributes()); 228 } 229 230 void NodePrinter::printImpl(const TupleExpr *expr) { 231 os << "TupleExpr " << expr << " Type<"; 232 print(expr->getType()); 233 os << ">\n"; 234 235 printChildren(expr->getElements()); 236 } 237 238 void NodePrinter::printImpl(const TypeExpr *expr) { 239 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 240 } 241 242 void NodePrinter::printImpl(const AttrConstraintDecl *decl) { 243 os << "AttrConstraintDecl " << decl << "\n"; 244 if (const auto *typeExpr = decl->getTypeExpr()) 245 printChildren(typeExpr); 246 } 247 248 void NodePrinter::printImpl(const OpConstraintDecl *decl) { 249 os << "OpConstraintDecl " << decl << "\n"; 250 printChildren(decl->getNameDecl()); 251 } 252 253 void NodePrinter::printImpl(const TypeConstraintDecl *decl) { 254 os << "TypeConstraintDecl " << decl << "\n"; 255 } 256 257 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) { 258 os << "TypeRangeConstraintDecl " << decl << "\n"; 259 } 260 261 void NodePrinter::printImpl(const ValueConstraintDecl *decl) { 262 os << "ValueConstraintDecl " << decl << "\n"; 263 if (const auto *typeExpr = decl->getTypeExpr()) 264 printChildren(typeExpr); 265 } 266 267 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) { 268 os << "ValueRangeConstraintDecl " << decl << "\n"; 269 if (const auto *typeExpr = decl->getTypeExpr()) 270 printChildren(typeExpr); 271 } 272 273 void NodePrinter::printImpl(const NamedAttributeDecl *decl) { 274 os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName() 275 << ">\n"; 276 printChildren(decl->getValue()); 277 } 278 279 void NodePrinter::printImpl(const OpNameDecl *decl) { 280 os << "OpNameDecl " << decl; 281 if (Optional<StringRef> name = decl->getName()) 282 os << " Name<" << name << ">"; 283 os << "\n"; 284 } 285 286 void NodePrinter::printImpl(const PatternDecl *decl) { 287 os << "PatternDecl " << decl; 288 if (const Name *name = decl->getName()) 289 os << " Name<" << name->getName() << ">"; 290 if (Optional<uint16_t> benefit = decl->getBenefit()) 291 os << " Benefit<" << *benefit << ">"; 292 if (decl->hasBoundedRewriteRecursion()) 293 os << " Recursion"; 294 295 os << "\n"; 296 printChildren(decl->getBody()); 297 } 298 299 void NodePrinter::printImpl(const VariableDecl *decl) { 300 os << "VariableDecl " << decl << " Name<" << decl->getName().getName() 301 << "> Type<"; 302 print(decl->getType()); 303 os << ">\n"; 304 if (Expr *initExpr = decl->getInitExpr()) 305 printChildren(initExpr); 306 307 auto constraints = 308 llvm::map_range(decl->getConstraints(), 309 [](const ConstraintRef &ref) { return ref.constraint; }); 310 printChildren("Constraints", constraints); 311 } 312 313 void NodePrinter::printImpl(const Module *module) { 314 os << "Module " << module << "\n"; 315 printChildren(module->getChildren()); 316 } 317 318 //===----------------------------------------------------------------------===// 319 // Entry point 320 //===----------------------------------------------------------------------===// 321 322 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); } 323 324 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); } 325