1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Context.h"
10 #include "mlir/Tools/PDLL/AST/Nodes.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll::ast;
18 
19 //===----------------------------------------------------------------------===//
20 // NodePrinter
21 //===----------------------------------------------------------------------===//
22 
23 namespace {
24 class NodePrinter {
25 public:
26   NodePrinter(raw_ostream &os) : os(os) {}
27 
28   /// Print the given type to the stream.
29   void print(Type type);
30 
31   /// Print the given node to the stream.
32   void print(const Node *node);
33 
34 private:
35   /// Print a range containing children of a node.
36   template <typename RangeT,
37             std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
38                 * = nullptr>
39   void printChildren(RangeT &&range) {
40     if (llvm::empty(range))
41       return;
42 
43     // Print the first N-1 elements with a prefix of "|-".
44     auto it = std::begin(range);
45     for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
46       print(*it);
47 
48     // Print the last element.
49     elementIndentStack.back() = true;
50     print(*it);
51   }
52   template <typename RangeT, typename... OthersT,
53             std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
54                 * = nullptr>
55   void printChildren(RangeT &&range, OthersT &&...others) {
56     printChildren(ArrayRef<const Node *>({range, others...}));
57   }
58   /// Print a range containing children of a node, nesting the children under
59   /// the given label.
60   template <typename RangeT>
61   void printChildren(StringRef label, RangeT &&range) {
62     if (llvm::empty(range))
63       return;
64     elementIndentStack.reserve(elementIndentStack.size() + 1);
65     llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
66 
67     printIndent();
68     os << label << "`\n";
69     elementIndentStack.push_back(/*isLastElt*/ false);
70     printChildren(std::forward<RangeT>(range));
71     elementIndentStack.pop_back();
72   }
73 
74   /// Print the given derived node to the stream.
75   void printImpl(const CompoundStmt *stmt);
76   void printImpl(const EraseStmt *stmt);
77   void printImpl(const LetStmt *stmt);
78   void printImpl(const ReplaceStmt *stmt);
79   void printImpl(const RewriteStmt *stmt);
80 
81   void printImpl(const AttributeExpr *expr);
82   void printImpl(const DeclRefExpr *expr);
83   void printImpl(const MemberAccessExpr *expr);
84   void printImpl(const OperationExpr *expr);
85   void printImpl(const TupleExpr *expr);
86   void printImpl(const TypeExpr *expr);
87 
88   void printImpl(const AttrConstraintDecl *decl);
89   void printImpl(const OpConstraintDecl *decl);
90   void printImpl(const TypeConstraintDecl *decl);
91   void printImpl(const TypeRangeConstraintDecl *decl);
92   void printImpl(const ValueConstraintDecl *decl);
93   void printImpl(const ValueRangeConstraintDecl *decl);
94   void printImpl(const NamedAttributeDecl *decl);
95   void printImpl(const OpNameDecl *decl);
96   void printImpl(const PatternDecl *decl);
97   void printImpl(const VariableDecl *decl);
98   void printImpl(const Module *module);
99 
100   /// Print the current indent stack.
101   void printIndent() {
102     if (elementIndentStack.empty())
103       return;
104 
105     for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
106       os << (isLastElt ? "  " : " |");
107     os << (elementIndentStack.back() ? " `" : " |");
108   }
109 
110   /// The raw output stream.
111   raw_ostream &os;
112 
113   /// A stack of indents and a flag indicating if the current element being
114   /// printed at that indent is the last element.
115   SmallVector<bool> elementIndentStack;
116 };
117 } // namespace
118 
119 void NodePrinter::print(Type type) {
120   // Protect against invalid inputs.
121   if (!type) {
122     os << "Type<NULL>";
123     return;
124   }
125 
126   TypeSwitch<Type>(type)
127       .Case([&](AttributeType) { os << "Attr"; })
128       .Case([&](ConstraintType) { os << "Constraint"; })
129       .Case([&](OperationType type) {
130         os << "Op";
131         if (Optional<StringRef> name = type.getName())
132           os << "<" << *name << ">";
133       })
134       .Case([&](RangeType type) {
135         print(type.getElementType());
136         os << "Range";
137       })
138       .Case([&](TupleType type) {
139         os << "Tuple<";
140         llvm::interleaveComma(
141             llvm::zip(type.getElementNames(), type.getElementTypes()), os,
142             [&](auto it) {
143               if (!std::get<0>(it).empty())
144                 os << std::get<0>(it) << ": ";
145               this->print(std::get<1>(it));
146             });
147         os << ">";
148       })
149       .Case([&](TypeType) { os << "Type"; })
150       .Case([&](ValueType) { os << "Value"; })
151       .Default([](Type) { llvm_unreachable("unknown AST type"); });
152 }
153 
154 void NodePrinter::print(const Node *node) {
155   printIndent();
156   os << "-";
157 
158   elementIndentStack.push_back(/*isLastElt*/ false);
159   TypeSwitch<const Node *>(node)
160       .Case<
161           // Statements.
162           const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
163           const RewriteStmt,
164 
165           // Expressions.
166           const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
167           const OperationExpr, const TupleExpr, const TypeExpr,
168 
169           // Decls.
170           const AttrConstraintDecl, const OpConstraintDecl,
171           const TypeConstraintDecl, const TypeRangeConstraintDecl,
172           const ValueConstraintDecl, const ValueRangeConstraintDecl,
173           const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
174           const VariableDecl,
175 
176           const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
177       .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
178   elementIndentStack.pop_back();
179 }
180 
181 void NodePrinter::printImpl(const CompoundStmt *stmt) {
182   os << "CompoundStmt " << stmt << "\n";
183   printChildren(stmt->getChildren());
184 }
185 
186 void NodePrinter::printImpl(const EraseStmt *stmt) {
187   os << "EraseStmt " << stmt << "\n";
188   printChildren(stmt->getRootOpExpr());
189 }
190 
191 void NodePrinter::printImpl(const LetStmt *stmt) {
192   os << "LetStmt " << stmt << "\n";
193   printChildren(stmt->getVarDecl());
194 }
195 
196 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
197   os << "ReplaceStmt " << stmt << "\n";
198   printChildren(stmt->getRootOpExpr());
199   printChildren("ReplValues", stmt->getReplExprs());
200 }
201 
202 void NodePrinter::printImpl(const RewriteStmt *stmt) {
203   os << "RewriteStmt " << stmt << "\n";
204   printChildren(stmt->getRootOpExpr(), stmt->getRewriteBody());
205 }
206 
207 void NodePrinter::printImpl(const AttributeExpr *expr) {
208   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
209 }
210 
211 void NodePrinter::printImpl(const DeclRefExpr *expr) {
212   os << "DeclRefExpr " << expr << " Type<";
213   print(expr->getType());
214   os << ">\n";
215   printChildren(expr->getDecl());
216 }
217 
218 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
219   os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
220      << "> Type<";
221   print(expr->getType());
222   os << ">\n";
223   printChildren(expr->getParentExpr());
224 }
225 
226 void NodePrinter::printImpl(const OperationExpr *expr) {
227   os << "OperationExpr " << expr << " Type<";
228   print(expr->getType());
229   os << ">\n";
230 
231   printChildren(expr->getNameDecl());
232   printChildren("Operands", expr->getOperands());
233   printChildren("Result Types", expr->getResultTypes());
234   printChildren("Attributes", expr->getAttributes());
235 }
236 
237 void NodePrinter::printImpl(const TupleExpr *expr) {
238   os << "TupleExpr " << expr << " Type<";
239   print(expr->getType());
240   os << ">\n";
241 
242   printChildren(expr->getElements());
243 }
244 
245 void NodePrinter::printImpl(const TypeExpr *expr) {
246   os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
247 }
248 
249 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
250   os << "AttrConstraintDecl " << decl << "\n";
251   if (const auto *typeExpr = decl->getTypeExpr())
252     printChildren(typeExpr);
253 }
254 
255 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
256   os << "OpConstraintDecl " << decl << "\n";
257   printChildren(decl->getNameDecl());
258 }
259 
260 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
261   os << "TypeConstraintDecl " << decl << "\n";
262 }
263 
264 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
265   os << "TypeRangeConstraintDecl " << decl << "\n";
266 }
267 
268 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
269   os << "ValueConstraintDecl " << decl << "\n";
270   if (const auto *typeExpr = decl->getTypeExpr())
271     printChildren(typeExpr);
272 }
273 
274 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
275   os << "ValueRangeConstraintDecl " << decl << "\n";
276   if (const auto *typeExpr = decl->getTypeExpr())
277     printChildren(typeExpr);
278 }
279 
280 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
281   os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
282      << ">\n";
283   printChildren(decl->getValue());
284 }
285 
286 void NodePrinter::printImpl(const OpNameDecl *decl) {
287   os << "OpNameDecl " << decl;
288   if (Optional<StringRef> name = decl->getName())
289     os << " Name<" << name << ">";
290   os << "\n";
291 }
292 
293 void NodePrinter::printImpl(const PatternDecl *decl) {
294   os << "PatternDecl " << decl;
295   if (const Name *name = decl->getName())
296     os << " Name<" << name->getName() << ">";
297   if (Optional<uint16_t> benefit = decl->getBenefit())
298     os << " Benefit<" << *benefit << ">";
299   if (decl->hasBoundedRewriteRecursion())
300     os << " Recursion";
301 
302   os << "\n";
303   printChildren(decl->getBody());
304 }
305 
306 void NodePrinter::printImpl(const VariableDecl *decl) {
307   os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
308      << "> Type<";
309   print(decl->getType());
310   os << ">\n";
311   if (Expr *initExpr = decl->getInitExpr())
312     printChildren(initExpr);
313 
314   auto constraints =
315       llvm::map_range(decl->getConstraints(),
316                       [](const ConstraintRef &ref) { return ref.constraint; });
317   printChildren("Constraints", constraints);
318 }
319 
320 void NodePrinter::printImpl(const Module *module) {
321   os << "Module " << module << "\n";
322   printChildren(module->getChildren());
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // Entry point
327 //===----------------------------------------------------------------------===//
328 
329 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
330 
331 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
332