1 //===- NodePrinter.cpp ----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/AST/Context.h" 10 #include "mlir/Tools/PDLL/AST/Nodes.h" 11 #include "llvm/ADT/StringExtras.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 #include "llvm/Support/SaveAndRestore.h" 14 #include "llvm/Support/ScopedPrinter.h" 15 16 using namespace mlir; 17 using namespace mlir::pdll::ast; 18 19 //===----------------------------------------------------------------------===// 20 // NodePrinter 21 //===----------------------------------------------------------------------===// 22 23 namespace { 24 class NodePrinter { 25 public: 26 NodePrinter(raw_ostream &os) : os(os) {} 27 28 /// Print the given type to the stream. 29 void print(Type type); 30 31 /// Print the given node to the stream. 32 void print(const Node *node); 33 34 private: 35 /// Print a range containing children of a node. 36 template <typename RangeT, 37 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value> 38 * = nullptr> 39 void printChildren(RangeT &&range) { 40 if (llvm::empty(range)) 41 return; 42 43 // Print the first N-1 elements with a prefix of "|-". 44 auto it = std::begin(range); 45 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it) 46 print(*it); 47 48 // Print the last element. 49 elementIndentStack.back() = true; 50 print(*it); 51 } 52 template <typename RangeT, typename... OthersT, 53 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value> 54 * = nullptr> 55 void printChildren(RangeT &&range, OthersT &&...others) { 56 printChildren(ArrayRef<const Node *>({range, others...})); 57 } 58 /// Print a range containing children of a node, nesting the children under 59 /// the given label. 60 template <typename RangeT> 61 void printChildren(StringRef label, RangeT &&range) { 62 if (llvm::empty(range)) 63 return; 64 elementIndentStack.reserve(elementIndentStack.size() + 1); 65 llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true); 66 67 printIndent(); 68 os << label << "`\n"; 69 elementIndentStack.push_back(/*isLastElt*/ false); 70 printChildren(std::forward<RangeT>(range)); 71 elementIndentStack.pop_back(); 72 } 73 74 /// Print the given derived node to the stream. 75 void printImpl(const CompoundStmt *stmt); 76 void printImpl(const EraseStmt *stmt); 77 void printImpl(const LetStmt *stmt); 78 79 void printImpl(const AttributeExpr *expr); 80 void printImpl(const DeclRefExpr *expr); 81 void printImpl(const MemberAccessExpr *expr); 82 void printImpl(const TypeExpr *expr); 83 84 void printImpl(const AttrConstraintDecl *decl); 85 void printImpl(const OpConstraintDecl *decl); 86 void printImpl(const TypeConstraintDecl *decl); 87 void printImpl(const TypeRangeConstraintDecl *decl); 88 void printImpl(const ValueConstraintDecl *decl); 89 void printImpl(const ValueRangeConstraintDecl *decl); 90 void printImpl(const OpNameDecl *decl); 91 void printImpl(const PatternDecl *decl); 92 void printImpl(const VariableDecl *decl); 93 void printImpl(const Module *module); 94 95 /// Print the current indent stack. 96 void printIndent() { 97 if (elementIndentStack.empty()) 98 return; 99 100 for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back()) 101 os << (isLastElt ? " " : " |"); 102 os << (elementIndentStack.back() ? " `" : " |"); 103 } 104 105 /// The raw output stream. 106 raw_ostream &os; 107 108 /// A stack of indents and a flag indicating if the current element being 109 /// printed at that indent is the last element. 110 SmallVector<bool> elementIndentStack; 111 }; 112 } // namespace 113 114 void NodePrinter::print(Type type) { 115 // Protect against invalid inputs. 116 if (!type) { 117 os << "Type<NULL>"; 118 return; 119 } 120 121 TypeSwitch<Type>(type) 122 .Case([&](AttributeType) { os << "Attr"; }) 123 .Case([&](ConstraintType) { os << "Constraint"; }) 124 .Case([&](OperationType type) { 125 os << "Op"; 126 if (Optional<StringRef> name = type.getName()) 127 os << "<" << *name << ">"; 128 }) 129 .Case([&](RangeType type) { 130 print(type.getElementType()); 131 os << "Range"; 132 }) 133 .Case([&](TypeType) { os << "Type"; }) 134 .Case([&](ValueType) { os << "Value"; }) 135 .Default([](Type) { llvm_unreachable("unknown AST type"); }); 136 } 137 138 void NodePrinter::print(const Node *node) { 139 printIndent(); 140 os << "-"; 141 142 elementIndentStack.push_back(/*isLastElt*/ false); 143 TypeSwitch<const Node *>(node) 144 .Case< 145 // Statements. 146 const CompoundStmt, const EraseStmt, const LetStmt, 147 148 // Expressions. 149 const AttributeExpr, const DeclRefExpr, const MemberAccessExpr, 150 const TypeExpr, 151 152 // Decls. 153 const AttrConstraintDecl, const OpConstraintDecl, 154 const TypeConstraintDecl, const TypeRangeConstraintDecl, 155 const ValueConstraintDecl, const ValueRangeConstraintDecl, 156 const OpNameDecl, const PatternDecl, const VariableDecl, 157 158 const Module>([&](auto derivedNode) { this->printImpl(derivedNode); }) 159 .Default([](const Node *) { llvm_unreachable("unknown AST node"); }); 160 elementIndentStack.pop_back(); 161 } 162 163 void NodePrinter::printImpl(const CompoundStmt *stmt) { 164 os << "CompoundStmt " << stmt << "\n"; 165 printChildren(stmt->getChildren()); 166 } 167 168 void NodePrinter::printImpl(const EraseStmt *stmt) { 169 os << "EraseStmt " << stmt << "\n"; 170 printChildren(stmt->getRootOpExpr()); 171 } 172 173 void NodePrinter::printImpl(const LetStmt *stmt) { 174 os << "LetStmt " << stmt << "\n"; 175 printChildren(stmt->getVarDecl()); 176 } 177 178 void NodePrinter::printImpl(const AttributeExpr *expr) { 179 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 180 } 181 182 void NodePrinter::printImpl(const DeclRefExpr *expr) { 183 os << "DeclRefExpr " << expr << " Type<"; 184 print(expr->getType()); 185 os << ">\n"; 186 printChildren(expr->getDecl()); 187 } 188 189 void NodePrinter::printImpl(const MemberAccessExpr *expr) { 190 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName() 191 << "> Type<"; 192 print(expr->getType()); 193 os << ">\n"; 194 printChildren(expr->getParentExpr()); 195 } 196 197 void NodePrinter::printImpl(const TypeExpr *expr) { 198 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n"; 199 } 200 201 void NodePrinter::printImpl(const AttrConstraintDecl *decl) { 202 os << "AttrConstraintDecl " << decl << "\n"; 203 if (const auto *typeExpr = decl->getTypeExpr()) 204 printChildren(typeExpr); 205 } 206 207 void NodePrinter::printImpl(const OpConstraintDecl *decl) { 208 os << "OpConstraintDecl " << decl << "\n"; 209 printChildren(decl->getNameDecl()); 210 } 211 212 void NodePrinter::printImpl(const TypeConstraintDecl *decl) { 213 os << "TypeConstraintDecl " << decl << "\n"; 214 } 215 216 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) { 217 os << "TypeRangeConstraintDecl " << decl << "\n"; 218 } 219 220 void NodePrinter::printImpl(const ValueConstraintDecl *decl) { 221 os << "ValueConstraintDecl " << decl << "\n"; 222 if (const auto *typeExpr = decl->getTypeExpr()) 223 printChildren(typeExpr); 224 } 225 226 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) { 227 os << "ValueRangeConstraintDecl " << decl << "\n"; 228 if (const auto *typeExpr = decl->getTypeExpr()) 229 printChildren(typeExpr); 230 } 231 232 void NodePrinter::printImpl(const OpNameDecl *decl) { 233 os << "OpNameDecl " << decl; 234 if (Optional<StringRef> name = decl->getName()) 235 os << " Name<" << name << ">"; 236 os << "\n"; 237 } 238 239 void NodePrinter::printImpl(const PatternDecl *decl) { 240 os << "PatternDecl " << decl; 241 if (const Name *name = decl->getName()) 242 os << " Name<" << name->getName() << ">"; 243 if (Optional<uint16_t> benefit = decl->getBenefit()) 244 os << " Benefit<" << *benefit << ">"; 245 if (decl->hasBoundedRewriteRecursion()) 246 os << " Recursion"; 247 248 os << "\n"; 249 printChildren(decl->getBody()); 250 } 251 252 void NodePrinter::printImpl(const VariableDecl *decl) { 253 os << "VariableDecl " << decl << " Name<" << decl->getName().getName() 254 << "> Type<"; 255 print(decl->getType()); 256 os << ">\n"; 257 if (Expr *initExpr = decl->getInitExpr()) 258 printChildren(initExpr); 259 260 auto constraints = 261 llvm::map_range(decl->getConstraints(), 262 [](const ConstraintRef &ref) { return ref.constraint; }); 263 printChildren("Constraints", constraints); 264 } 265 266 void NodePrinter::printImpl(const Module *module) { 267 os << "Module " << module << "\n"; 268 printChildren(module->getChildren()); 269 } 270 271 //===----------------------------------------------------------------------===// 272 // Entry point 273 //===----------------------------------------------------------------------===// 274 275 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); } 276 277 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); } 278