1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Context.h"
10 #include "mlir/Tools/PDLL/AST/Nodes.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll::ast;
18 
19 //===----------------------------------------------------------------------===//
20 // NodePrinter
21 //===----------------------------------------------------------------------===//
22 
23 namespace {
24 class NodePrinter {
25 public:
NodePrinter(raw_ostream & os)26   NodePrinter(raw_ostream &os) : os(os) {}
27 
28   /// Print the given type to the stream.
29   void print(Type type);
30 
31   /// Print the given node to the stream.
32   void print(const Node *node);
33 
34 private:
35   /// Print a range containing children of a node.
36   template <typename RangeT,
37             std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
38                 * = nullptr>
printChildren(RangeT && range)39   void printChildren(RangeT &&range) {
40     if (llvm::empty(range))
41       return;
42 
43     // Print the first N-1 elements with a prefix of "|-".
44     auto it = std::begin(range);
45     for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
46       print(*it);
47 
48     // Print the last element.
49     elementIndentStack.back() = true;
50     print(*it);
51   }
52   template <typename RangeT, typename... OthersT,
53             std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
54                 * = nullptr>
printChildren(RangeT && range,OthersT &&...others)55   void printChildren(RangeT &&range, OthersT &&...others) {
56     printChildren(ArrayRef<const Node *>({range, others...}));
57   }
58   /// Print a range containing children of a node, nesting the children under
59   /// the given label.
60   template <typename RangeT>
printChildren(StringRef label,RangeT && range)61   void printChildren(StringRef label, RangeT &&range) {
62     if (llvm::empty(range))
63       return;
64     elementIndentStack.reserve(elementIndentStack.size() + 1);
65     llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
66 
67     printIndent();
68     os << label << "`\n";
69     elementIndentStack.push_back(/*isLastElt*/ false);
70     printChildren(std::forward<RangeT>(range));
71     elementIndentStack.pop_back();
72   }
73 
74   /// Print the given derived node to the stream.
75   void printImpl(const CompoundStmt *stmt);
76   void printImpl(const EraseStmt *stmt);
77   void printImpl(const LetStmt *stmt);
78   void printImpl(const ReplaceStmt *stmt);
79   void printImpl(const ReturnStmt *stmt);
80   void printImpl(const RewriteStmt *stmt);
81 
82   void printImpl(const AttributeExpr *expr);
83   void printImpl(const CallExpr *expr);
84   void printImpl(const DeclRefExpr *expr);
85   void printImpl(const MemberAccessExpr *expr);
86   void printImpl(const OperationExpr *expr);
87   void printImpl(const TupleExpr *expr);
88   void printImpl(const TypeExpr *expr);
89 
90   void printImpl(const AttrConstraintDecl *decl);
91   void printImpl(const OpConstraintDecl *decl);
92   void printImpl(const TypeConstraintDecl *decl);
93   void printImpl(const TypeRangeConstraintDecl *decl);
94   void printImpl(const UserConstraintDecl *decl);
95   void printImpl(const ValueConstraintDecl *decl);
96   void printImpl(const ValueRangeConstraintDecl *decl);
97   void printImpl(const NamedAttributeDecl *decl);
98   void printImpl(const OpNameDecl *decl);
99   void printImpl(const PatternDecl *decl);
100   void printImpl(const UserRewriteDecl *decl);
101   void printImpl(const VariableDecl *decl);
102   void printImpl(const Module *module);
103 
104   /// Print the current indent stack.
printIndent()105   void printIndent() {
106     if (elementIndentStack.empty())
107       return;
108 
109     for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
110       os << (isLastElt ? "  " : " |");
111     os << (elementIndentStack.back() ? " `" : " |");
112   }
113 
114   /// The raw output stream.
115   raw_ostream &os;
116 
117   /// A stack of indents and a flag indicating if the current element being
118   /// printed at that indent is the last element.
119   SmallVector<bool> elementIndentStack;
120 };
121 } // namespace
122 
print(Type type)123 void NodePrinter::print(Type type) {
124   // Protect against invalid inputs.
125   if (!type) {
126     os << "Type<NULL>";
127     return;
128   }
129 
130   TypeSwitch<Type>(type)
131       .Case([&](AttributeType) { os << "Attr"; })
132       .Case([&](ConstraintType) { os << "Constraint"; })
133       .Case([&](OperationType type) {
134         os << "Op";
135         if (Optional<StringRef> name = type.getName())
136           os << "<" << *name << ">";
137       })
138       .Case([&](RangeType type) {
139         print(type.getElementType());
140         os << "Range";
141       })
142       .Case([&](RewriteType) { os << "Rewrite"; })
143       .Case([&](TupleType type) {
144         os << "Tuple<";
145         llvm::interleaveComma(
146             llvm::zip(type.getElementNames(), type.getElementTypes()), os,
147             [&](auto it) {
148               if (!std::get<0>(it).empty())
149                 os << std::get<0>(it) << ": ";
150               this->print(std::get<1>(it));
151             });
152         os << ">";
153       })
154       .Case([&](TypeType) { os << "Type"; })
155       .Case([&](ValueType) { os << "Value"; })
156       .Default([](Type) { llvm_unreachable("unknown AST type"); });
157 }
158 
print(const Node * node)159 void NodePrinter::print(const Node *node) {
160   printIndent();
161   os << "-";
162 
163   elementIndentStack.push_back(/*isLastElt*/ false);
164   TypeSwitch<const Node *>(node)
165       .Case<
166           // Statements.
167           const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
168           const ReturnStmt, const RewriteStmt,
169 
170           // Expressions.
171           const AttributeExpr, const CallExpr, const DeclRefExpr,
172           const MemberAccessExpr, const OperationExpr, const TupleExpr,
173           const TypeExpr,
174 
175           // Decls.
176           const AttrConstraintDecl, const OpConstraintDecl,
177           const TypeConstraintDecl, const TypeRangeConstraintDecl,
178           const UserConstraintDecl, const ValueConstraintDecl,
179           const ValueRangeConstraintDecl, const NamedAttributeDecl,
180           const OpNameDecl, const PatternDecl, const UserRewriteDecl,
181           const VariableDecl,
182 
183           const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
184       .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
185   elementIndentStack.pop_back();
186 }
187 
printImpl(const CompoundStmt * stmt)188 void NodePrinter::printImpl(const CompoundStmt *stmt) {
189   os << "CompoundStmt " << stmt << "\n";
190   printChildren(stmt->getChildren());
191 }
192 
printImpl(const EraseStmt * stmt)193 void NodePrinter::printImpl(const EraseStmt *stmt) {
194   os << "EraseStmt " << stmt << "\n";
195   printChildren(stmt->getRootOpExpr());
196 }
197 
printImpl(const LetStmt * stmt)198 void NodePrinter::printImpl(const LetStmt *stmt) {
199   os << "LetStmt " << stmt << "\n";
200   printChildren(stmt->getVarDecl());
201 }
202 
printImpl(const ReplaceStmt * stmt)203 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
204   os << "ReplaceStmt " << stmt << "\n";
205   printChildren(stmt->getRootOpExpr());
206   printChildren("ReplValues", stmt->getReplExprs());
207 }
208 
printImpl(const ReturnStmt * stmt)209 void NodePrinter::printImpl(const ReturnStmt *stmt) {
210   os << "ReturnStmt " << stmt << "\n";
211   printChildren(stmt->getResultExpr());
212 }
213 
printImpl(const RewriteStmt * stmt)214 void NodePrinter::printImpl(const RewriteStmt *stmt) {
215   os << "RewriteStmt " << stmt << "\n";
216   printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
217 }
218 
printImpl(const AttributeExpr * expr)219 void NodePrinter::printImpl(const AttributeExpr *expr) {
220   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
221 }
222 
printImpl(const CallExpr * expr)223 void NodePrinter::printImpl(const CallExpr *expr) {
224   os << "CallExpr " << expr << " Type<";
225   print(expr->getType());
226   os << ">\n";
227   printChildren(expr->getCallableExpr());
228   printChildren("Arguments", expr->getArguments());
229 }
230 
printImpl(const DeclRefExpr * expr)231 void NodePrinter::printImpl(const DeclRefExpr *expr) {
232   os << "DeclRefExpr " << expr << " Type<";
233   print(expr->getType());
234   os << ">\n";
235   printChildren(expr->getDecl());
236 }
237 
printImpl(const MemberAccessExpr * expr)238 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
239   os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
240      << "> Type<";
241   print(expr->getType());
242   os << ">\n";
243   printChildren(expr->getParentExpr());
244 }
245 
printImpl(const OperationExpr * expr)246 void NodePrinter::printImpl(const OperationExpr *expr) {
247   os << "OperationExpr " << expr << " Type<";
248   print(expr->getType());
249   os << ">\n";
250 
251   printChildren(expr->getNameDecl());
252   printChildren("Operands", expr->getOperands());
253   printChildren("Result Types", expr->getResultTypes());
254   printChildren("Attributes", expr->getAttributes());
255 }
256 
printImpl(const TupleExpr * expr)257 void NodePrinter::printImpl(const TupleExpr *expr) {
258   os << "TupleExpr " << expr << " Type<";
259   print(expr->getType());
260   os << ">\n";
261 
262   printChildren(expr->getElements());
263 }
264 
printImpl(const TypeExpr * expr)265 void NodePrinter::printImpl(const TypeExpr *expr) {
266   os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
267 }
268 
printImpl(const AttrConstraintDecl * decl)269 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
270   os << "AttrConstraintDecl " << decl << "\n";
271   if (const auto *typeExpr = decl->getTypeExpr())
272     printChildren(typeExpr);
273 }
274 
printImpl(const OpConstraintDecl * decl)275 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
276   os << "OpConstraintDecl " << decl << "\n";
277   printChildren(decl->getNameDecl());
278 }
279 
printImpl(const TypeConstraintDecl * decl)280 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
281   os << "TypeConstraintDecl " << decl << "\n";
282 }
283 
printImpl(const TypeRangeConstraintDecl * decl)284 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
285   os << "TypeRangeConstraintDecl " << decl << "\n";
286 }
287 
printImpl(const UserConstraintDecl * decl)288 void NodePrinter::printImpl(const UserConstraintDecl *decl) {
289   os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
290      << "> ResultType<" << decl->getResultType() << ">";
291   if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
292     os << " Code<";
293     llvm::printEscapedString(*codeBlock, os);
294     os << ">";
295   }
296   os << "\n";
297   printChildren("Inputs", decl->getInputs());
298   printChildren("Results", decl->getResults());
299   if (const CompoundStmt *body = decl->getBody())
300     printChildren(body);
301 }
302 
printImpl(const ValueConstraintDecl * decl)303 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
304   os << "ValueConstraintDecl " << decl << "\n";
305   if (const auto *typeExpr = decl->getTypeExpr())
306     printChildren(typeExpr);
307 }
308 
printImpl(const ValueRangeConstraintDecl * decl)309 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
310   os << "ValueRangeConstraintDecl " << decl << "\n";
311   if (const auto *typeExpr = decl->getTypeExpr())
312     printChildren(typeExpr);
313 }
314 
printImpl(const NamedAttributeDecl * decl)315 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
316   os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
317      << ">\n";
318   printChildren(decl->getValue());
319 }
320 
printImpl(const OpNameDecl * decl)321 void NodePrinter::printImpl(const OpNameDecl *decl) {
322   os << "OpNameDecl " << decl;
323   if (Optional<StringRef> name = decl->getName())
324     os << " Name<" << name << ">";
325   os << "\n";
326 }
327 
printImpl(const PatternDecl * decl)328 void NodePrinter::printImpl(const PatternDecl *decl) {
329   os << "PatternDecl " << decl;
330   if (const Name *name = decl->getName())
331     os << " Name<" << name->getName() << ">";
332   if (Optional<uint16_t> benefit = decl->getBenefit())
333     os << " Benefit<" << *benefit << ">";
334   if (decl->hasBoundedRewriteRecursion())
335     os << " Recursion";
336 
337   os << "\n";
338   printChildren(decl->getBody());
339 }
340 
printImpl(const UserRewriteDecl * decl)341 void NodePrinter::printImpl(const UserRewriteDecl *decl) {
342   os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
343      << "> ResultType<" << decl->getResultType() << ">";
344   if (Optional<StringRef> codeBlock = decl->getCodeBlock()) {
345     os << " Code<";
346     llvm::printEscapedString(*codeBlock, os);
347     os << ">";
348   }
349   os << "\n";
350   printChildren("Inputs", decl->getInputs());
351   printChildren("Results", decl->getResults());
352   if (const CompoundStmt *body = decl->getBody())
353     printChildren(body);
354 }
355 
printImpl(const VariableDecl * decl)356 void NodePrinter::printImpl(const VariableDecl *decl) {
357   os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
358      << "> Type<";
359   print(decl->getType());
360   os << ">\n";
361   if (Expr *initExpr = decl->getInitExpr())
362     printChildren(initExpr);
363 
364   auto constraints =
365       llvm::map_range(decl->getConstraints(),
366                       [](const ConstraintRef &ref) { return ref.constraint; });
367   printChildren("Constraints", constraints);
368 }
369 
printImpl(const Module * module)370 void NodePrinter::printImpl(const Module *module) {
371   os << "Module " << module << "\n";
372   printChildren(module->getChildren());
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // Entry point
377 //===----------------------------------------------------------------------===//
378 
print(raw_ostream & os) const379 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
380 
print(raw_ostream & os) const381 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
382