1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/AST/Context.h"
10 #include "mlir/Tools/PDLL/AST/Nodes.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15
16 using namespace mlir;
17 using namespace mlir::pdll::ast;
18
19 //===----------------------------------------------------------------------===//
20 // NodePrinter
21 //===----------------------------------------------------------------------===//
22
23 namespace {
24 class NodePrinter {
25 public:
NodePrinter(raw_ostream & os)26 NodePrinter(raw_ostream &os) : os(os) {}
27
28 /// Print the given type to the stream.
29 void print(Type type);
30
31 /// Print the given node to the stream.
32 void print(const Node *node);
33
34 private:
35 /// Print a range containing children of a node.
36 template <typename RangeT,
37 std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
38 * = nullptr>
printChildren(RangeT && range)39 void printChildren(RangeT &&range) {
40 if (llvm::empty(range))
41 return;
42
43 // Print the first N-1 elements with a prefix of "|-".
44 auto it = std::begin(range);
45 for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
46 print(*it);
47
48 // Print the last element.
49 elementIndentStack.back() = true;
50 print(*it);
51 }
52 template <typename RangeT, typename... OthersT,
53 std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
54 * = nullptr>
printChildren(RangeT && range,OthersT &&...others)55 void printChildren(RangeT &&range, OthersT &&...others) {
56 printChildren(ArrayRef<const Node *>({range, others...}));
57 }
58 /// Print a range containing children of a node, nesting the children under
59 /// the given label.
60 template <typename RangeT>
printChildren(StringRef label,RangeT && range)61 void printChildren(StringRef label, RangeT &&range) {
62 if (llvm::empty(range))
63 return;
64 elementIndentStack.reserve(elementIndentStack.size() + 1);
65 llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
66
67 printIndent();
68 os << label << "`\n";
69 elementIndentStack.push_back(/*isLastElt*/ false);
70 printChildren(std::forward<RangeT>(range));
71 elementIndentStack.pop_back();
72 }
73
74 /// Print the given derived node to the stream.
75 void printImpl(const CompoundStmt *stmt);
76 void printImpl(const EraseStmt *stmt);
77 void printImpl(const LetStmt *stmt);
78 void printImpl(const ReplaceStmt *stmt);
79 void printImpl(const ReturnStmt *stmt);
80 void printImpl(const RewriteStmt *stmt);
81
82 void printImpl(const AttributeExpr *expr);
83 void printImpl(const CallExpr *expr);
84 void printImpl(const DeclRefExpr *expr);
85 void printImpl(const MemberAccessExpr *expr);
86 void printImpl(const OperationExpr *expr);
87 void printImpl(const TupleExpr *expr);
88 void printImpl(const TypeExpr *expr);
89
90 void printImpl(const AttrConstraintDecl *decl);
91 void printImpl(const OpConstraintDecl *decl);
92 void printImpl(const TypeConstraintDecl *decl);
93 void printImpl(const TypeRangeConstraintDecl *decl);
94 void printImpl(const UserConstraintDecl *decl);
95 void printImpl(const ValueConstraintDecl *decl);
96 void printImpl(const ValueRangeConstraintDecl *decl);
97 void printImpl(const NamedAttributeDecl *decl);
98 void printImpl(const OpNameDecl *decl);
99 void printImpl(const PatternDecl *decl);
100 void printImpl(const UserRewriteDecl *decl);
101 void printImpl(const VariableDecl *decl);
102 void printImpl(const Module *module);
103
104 /// Print the current indent stack.
printIndent()105 void printIndent() {
106 if (elementIndentStack.empty())
107 return;
108
109 for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
110 os << (isLastElt ? " " : " |");
111 os << (elementIndentStack.back() ? " `" : " |");
112 }
113
114 /// The raw output stream.
115 raw_ostream &os;
116
117 /// A stack of indents and a flag indicating if the current element being
118 /// printed at that indent is the last element.
119 SmallVector<bool> elementIndentStack;
120 };
121 } // namespace
122
print(Type type)123 void NodePrinter::print(Type type) {
124 // Protect against invalid inputs.
125 if (!type) {
126 os << "Type<NULL>";
127 return;
128 }
129
130 TypeSwitch<Type>(type)
131 .Case([&](AttributeType) { os << "Attr"; })
132 .Case([&](ConstraintType) { os << "Constraint"; })
133 .Case([&](OperationType type) {
134 os << "Op";
135 if (Optional<StringRef> name = type.getName())
136 os << "<" << *name << ">";
137 })
138 .Case([&](RangeType type) {
139 print(type.getElementType());
140 os << "Range";
141 })
142 .Case([&](RewriteType) { os << "Rewrite"; })
143 .Case([&](TupleType type) {
144 os << "Tuple<";
145 llvm::interleaveComma(
146 llvm::zip(type.getElementNames(), type.getElementTypes()), os,
147 [&](auto it) {
148 if (!std::get<0>(it).empty())
149 os << std::get<0>(it) << ": ";
150 this->print(std::get<1>(it));
151 });
152 os << ">";
153 })
154 .Case([&](TypeType) { os << "Type"; })
155 .Case([&](ValueType) { os << "Value"; })
156 .Default([](Type) { llvm_unreachable("unknown AST type"); });
157 }
158
print(const Node * node)159 void NodePrinter::print(const Node *node) {
160 printIndent();
161 os << "-";
162
163 elementIndentStack.push_back(/*isLastElt*/ false);
164 TypeSwitch<const Node *>(node)
165 .Case<
166 // Statements.
167 const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
168 const ReturnStmt, const RewriteStmt,
169
170 // Expressions.
171 const AttributeExpr, const CallExpr, const DeclRefExpr,
172 const MemberAccessExpr, const OperationExpr, const TupleExpr,
173 const TypeExpr,
174
175 // Decls.
176 const AttrConstraintDecl, const OpConstraintDecl,
177 const TypeConstraintDecl, const TypeRangeConstraintDecl,
178 const UserConstraintDecl, const ValueConstraintDecl,
179 const ValueRangeConstraintDecl, const NamedAttributeDecl,
180 const OpNameDecl, const PatternDecl, const UserRewriteDecl,
181 const VariableDecl,
182
183 const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
184 .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
185 elementIndentStack.pop_back();
186 }
187
printImpl(const CompoundStmt * stmt)188 void NodePrinter::printImpl(const CompoundStmt *stmt) {
189 os << "CompoundStmt " << stmt << "\n";
190 printChildren(stmt->getChildren());
191 }
192
printImpl(const EraseStmt * stmt)193 void NodePrinter::printImpl(const EraseStmt *stmt) {
194 os << "EraseStmt " << stmt << "\n";
195 printChildren(stmt->getRootOpExpr());
196 }
197
printImpl(const LetStmt * stmt)198 void NodePrinter::printImpl(const LetStmt *stmt) {
199 os << "LetStmt " << stmt << "\n";
200 printChildren(stmt->getVarDecl());
201 }
202
printImpl(const ReplaceStmt * stmt)203 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
204 os << "ReplaceStmt " << stmt << "\n";
205 printChildren(stmt->getRootOpExpr());
206 printChildren("ReplValues", stmt->getReplExprs());
207 }
208
printImpl(const ReturnStmt * stmt)209 void NodePrinter::printImpl(const ReturnStmt *stmt) {
210 os << "ReturnStmt " << stmt << "\n";
211 printChildren(stmt->getResultExpr());
212 }
213
printImpl(const RewriteStmt * stmt)214 void NodePrinter::printImpl(const RewriteStmt *stmt) {
215 os << "RewriteStmt " << stmt << "\n";
216 printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
217 }
218
printImpl(const AttributeExpr * expr)219 void NodePrinter::printImpl(const AttributeExpr *expr) {
220 os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
221 }
222
printImpl(const CallExpr * expr)223 void NodePrinter::printImpl(const CallExpr *expr) {
224 os << "CallExpr " << expr << " Type<";
225 print(expr->getType());
226 os << ">\n";
227 printChildren(expr->getCallableExpr());
228 printChildren("Arguments", expr->getArguments());
229 }
230
printImpl(const DeclRefExpr * expr)231 void NodePrinter::printImpl(const DeclRefExpr *expr) {
232 os << "DeclRefExpr " << expr << " Type<";
233 print(expr->getType());
234 os << ">\n";
235 printChildren(expr->getDecl());
236 }
237
printImpl(const MemberAccessExpr * expr)238 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
239 os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
240 << "> Type<";
241 print(expr->getType());
242 os << ">\n";
243 printChildren(expr->getParentExpr());
244 }
245
printImpl(const OperationExpr * expr)246 void NodePrinter::printImpl(const OperationExpr *expr) {
247 os << "OperationExpr " << expr << " Type<";
248 print(expr->getType());
249 os << ">\n";
250
251 printChildren(expr->getNameDecl());
252 printChildren("Operands", expr->getOperands());
253 printChildren("Result Types", expr->getResultTypes());
254 printChildren("Attributes", expr->getAttributes());
255 }
256
printImpl(const TupleExpr * expr)257 void NodePrinter::printImpl(const TupleExpr *expr) {
258 os << "TupleExpr " << expr << " Type<";
259 print(expr->getType());
260 os << ">\n";
261
262 printChildren(expr->getElements());
263 }
264
printImpl(const TypeExpr * expr)265 void NodePrinter::printImpl(const TypeExpr *expr) {
266 os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
267 }
268
printImpl(const AttrConstraintDecl * decl)269 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
270 os << "AttrConstraintDecl " << decl << "\n";
271 if (const auto *typeExpr = decl->getTypeExpr())
272 printChildren(typeExpr);
273 }
274
printImpl(const OpConstraintDecl * decl)275 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
276 os << "OpConstraintDecl " << decl << "\n";
277 printChildren(decl->getNameDecl());
278 }
279
printImpl(const TypeConstraintDecl * decl)280 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
281 os << "TypeConstraintDecl " << decl << "\n";
282 }
283
printImpl(const TypeRangeConstraintDecl * decl)284 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
285 os << "TypeRangeConstraintDecl " << decl << "\n";
286 }
287
printImpl(const UserConstraintDecl * decl)288 void NodePrinter::printImpl(const UserConstraintDecl *decl) {
289 os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
290 << "> ResultType<" << decl->getResultType() << ">";
291 if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
292 os << " Code<";
293 llvm::printEscapedString(*codeBlock, os);
294 os << ">";
295 }
296 os << "\n";
297 printChildren("Inputs", decl->getInputs());
298 printChildren("Results", decl->getResults());
299 if (const CompoundStmt *body = decl->getBody())
300 printChildren(body);
301 }
302
printImpl(const ValueConstraintDecl * decl)303 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
304 os << "ValueConstraintDecl " << decl << "\n";
305 if (const auto *typeExpr = decl->getTypeExpr())
306 printChildren(typeExpr);
307 }
308
printImpl(const ValueRangeConstraintDecl * decl)309 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
310 os << "ValueRangeConstraintDecl " << decl << "\n";
311 if (const auto *typeExpr = decl->getTypeExpr())
312 printChildren(typeExpr);
313 }
314
printImpl(const NamedAttributeDecl * decl)315 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
316 os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
317 << ">\n";
318 printChildren(decl->getValue());
319 }
320
printImpl(const OpNameDecl * decl)321 void NodePrinter::printImpl(const OpNameDecl *decl) {
322 os << "OpNameDecl " << decl;
323 if (Optional<StringRef> name = decl->getName())
324 os << " Name<" << name << ">";
325 os << "\n";
326 }
327
printImpl(const PatternDecl * decl)328 void NodePrinter::printImpl(const PatternDecl *decl) {
329 os << "PatternDecl " << decl;
330 if (const Name *name = decl->getName())
331 os << " Name<" << name->getName() << ">";
332 if (Optional<uint16_t> benefit = decl->getBenefit())
333 os << " Benefit<" << *benefit << ">";
334 if (decl->hasBoundedRewriteRecursion())
335 os << " Recursion";
336
337 os << "\n";
338 printChildren(decl->getBody());
339 }
340
printImpl(const UserRewriteDecl * decl)341 void NodePrinter::printImpl(const UserRewriteDecl *decl) {
342 os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
343 << "> ResultType<" << decl->getResultType() << ">";
344 if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
345 os << " Code<";
346 llvm::printEscapedString(*codeBlock, os);
347 os << ">";
348 }
349 os << "\n";
350 printChildren("Inputs", decl->getInputs());
351 printChildren("Results", decl->getResults());
352 if (const CompoundStmt *body = decl->getBody())
353 printChildren(body);
354 }
355
printImpl(const VariableDecl * decl)356 void NodePrinter::printImpl(const VariableDecl *decl) {
357 os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
358 << "> Type<";
359 print(decl->getType());
360 os << ">\n";
361 if (Expr *initExpr = decl->getInitExpr())
362 printChildren(initExpr);
363
364 auto constraints =
365 llvm::map_range(decl->getConstraints(),
366 [](const ConstraintRef &ref) { return ref.constraint; });
367 printChildren("Constraints", constraints);
368 }
369
printImpl(const Module * module)370 void NodePrinter::printImpl(const Module *module) {
371 os << "Module " << module << "\n";
372 printChildren(module->getChildren());
373 }
374
375 //===----------------------------------------------------------------------===//
376 // Entry point
377 //===----------------------------------------------------------------------===//
378
print(raw_ostream & os) const379 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
380
print(raw_ostream & os) const381 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
382