1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Context.h"
10 #include "mlir/Tools/PDLL/AST/Nodes.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll::ast;
18 
19 //===----------------------------------------------------------------------===//
20 // NodePrinter
21 //===----------------------------------------------------------------------===//
22 
23 namespace {
24 class NodePrinter {
25 public:
26   NodePrinter(raw_ostream &os) : os(os) {}
27 
28   /// Print the given type to the stream.
29   void print(Type type);
30 
31   /// Print the given node to the stream.
32   void print(const Node *node);
33 
34 private:
35   /// Print a range containing children of a node.
36   template <typename RangeT,
37             std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
38                 * = nullptr>
39   void printChildren(RangeT &&range) {
40     if (llvm::empty(range))
41       return;
42 
43     // Print the first N-1 elements with a prefix of "|-".
44     auto it = std::begin(range);
45     for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
46       print(*it);
47 
48     // Print the last element.
49     elementIndentStack.back() = true;
50     print(*it);
51   }
52   template <typename RangeT, typename... OthersT,
53             std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
54                 * = nullptr>
55   void printChildren(RangeT &&range, OthersT &&...others) {
56     printChildren(ArrayRef<const Node *>({range, others...}));
57   }
58   /// Print a range containing children of a node, nesting the children under
59   /// the given label.
60   template <typename RangeT>
61   void printChildren(StringRef label, RangeT &&range) {
62     if (llvm::empty(range))
63       return;
64     elementIndentStack.reserve(elementIndentStack.size() + 1);
65     llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
66 
67     printIndent();
68     os << label << "`\n";
69     elementIndentStack.push_back(/*isLastElt*/ false);
70     printChildren(std::forward<RangeT>(range));
71     elementIndentStack.pop_back();
72   }
73 
74   /// Print the given derived node to the stream.
75   void printImpl(const CompoundStmt *stmt);
76   void printImpl(const EraseStmt *stmt);
77   void printImpl(const LetStmt *stmt);
78   void printImpl(const ReplaceStmt *stmt);
79 
80   void printImpl(const AttributeExpr *expr);
81   void printImpl(const DeclRefExpr *expr);
82   void printImpl(const MemberAccessExpr *expr);
83   void printImpl(const OperationExpr *expr);
84   void printImpl(const TupleExpr *expr);
85   void printImpl(const TypeExpr *expr);
86 
87   void printImpl(const AttrConstraintDecl *decl);
88   void printImpl(const OpConstraintDecl *decl);
89   void printImpl(const TypeConstraintDecl *decl);
90   void printImpl(const TypeRangeConstraintDecl *decl);
91   void printImpl(const ValueConstraintDecl *decl);
92   void printImpl(const ValueRangeConstraintDecl *decl);
93   void printImpl(const NamedAttributeDecl *decl);
94   void printImpl(const OpNameDecl *decl);
95   void printImpl(const PatternDecl *decl);
96   void printImpl(const VariableDecl *decl);
97   void printImpl(const Module *module);
98 
99   /// Print the current indent stack.
100   void printIndent() {
101     if (elementIndentStack.empty())
102       return;
103 
104     for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
105       os << (isLastElt ? "  " : " |");
106     os << (elementIndentStack.back() ? " `" : " |");
107   }
108 
109   /// The raw output stream.
110   raw_ostream &os;
111 
112   /// A stack of indents and a flag indicating if the current element being
113   /// printed at that indent is the last element.
114   SmallVector<bool> elementIndentStack;
115 };
116 } // namespace
117 
118 void NodePrinter::print(Type type) {
119   // Protect against invalid inputs.
120   if (!type) {
121     os << "Type<NULL>";
122     return;
123   }
124 
125   TypeSwitch<Type>(type)
126       .Case([&](AttributeType) { os << "Attr"; })
127       .Case([&](ConstraintType) { os << "Constraint"; })
128       .Case([&](OperationType type) {
129         os << "Op";
130         if (Optional<StringRef> name = type.getName())
131           os << "<" << *name << ">";
132       })
133       .Case([&](RangeType type) {
134         print(type.getElementType());
135         os << "Range";
136       })
137       .Case([&](TupleType type) {
138         os << "Tuple<";
139         llvm::interleaveComma(
140             llvm::zip(type.getElementNames(), type.getElementTypes()), os,
141             [&](auto it) {
142               if (!std::get<0>(it).empty())
143                 os << std::get<0>(it) << ": ";
144               this->print(std::get<1>(it));
145             });
146         os << ">";
147       })
148       .Case([&](TypeType) { os << "Type"; })
149       .Case([&](ValueType) { os << "Value"; })
150       .Default([](Type) { llvm_unreachable("unknown AST type"); });
151 }
152 
153 void NodePrinter::print(const Node *node) {
154   printIndent();
155   os << "-";
156 
157   elementIndentStack.push_back(/*isLastElt*/ false);
158   TypeSwitch<const Node *>(node)
159       .Case<
160           // Statements.
161           const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
162 
163           // Expressions.
164           const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
165           const OperationExpr, const TupleExpr, const TypeExpr,
166 
167           // Decls.
168           const AttrConstraintDecl, const OpConstraintDecl,
169           const TypeConstraintDecl, const TypeRangeConstraintDecl,
170           const ValueConstraintDecl, const ValueRangeConstraintDecl,
171           const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
172           const VariableDecl,
173 
174           const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
175       .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
176   elementIndentStack.pop_back();
177 }
178 
179 void NodePrinter::printImpl(const CompoundStmt *stmt) {
180   os << "CompoundStmt " << stmt << "\n";
181   printChildren(stmt->getChildren());
182 }
183 
184 void NodePrinter::printImpl(const EraseStmt *stmt) {
185   os << "EraseStmt " << stmt << "\n";
186   printChildren(stmt->getRootOpExpr());
187 }
188 
189 void NodePrinter::printImpl(const LetStmt *stmt) {
190   os << "LetStmt " << stmt << "\n";
191   printChildren(stmt->getVarDecl());
192 }
193 
194 void NodePrinter::printImpl(const ReplaceStmt *stmt) {
195   os << "ReplaceStmt " << stmt << "\n";
196   printChildren(stmt->getRootOpExpr());
197   printChildren("ReplValues", stmt->getReplExprs());
198 }
199 
200 void NodePrinter::printImpl(const AttributeExpr *expr) {
201   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
202 }
203 
204 void NodePrinter::printImpl(const DeclRefExpr *expr) {
205   os << "DeclRefExpr " << expr << " Type<";
206   print(expr->getType());
207   os << ">\n";
208   printChildren(expr->getDecl());
209 }
210 
211 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
212   os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
213      << "> Type<";
214   print(expr->getType());
215   os << ">\n";
216   printChildren(expr->getParentExpr());
217 }
218 
219 void NodePrinter::printImpl(const OperationExpr *expr) {
220   os << "OperationExpr " << expr << " Type<";
221   print(expr->getType());
222   os << ">\n";
223 
224   printChildren(expr->getNameDecl());
225   printChildren("Operands", expr->getOperands());
226   printChildren("Result Types", expr->getResultTypes());
227   printChildren("Attributes", expr->getAttributes());
228 }
229 
230 void NodePrinter::printImpl(const TupleExpr *expr) {
231   os << "TupleExpr " << expr << " Type<";
232   print(expr->getType());
233   os << ">\n";
234 
235   printChildren(expr->getElements());
236 }
237 
238 void NodePrinter::printImpl(const TypeExpr *expr) {
239   os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
240 }
241 
242 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
243   os << "AttrConstraintDecl " << decl << "\n";
244   if (const auto *typeExpr = decl->getTypeExpr())
245     printChildren(typeExpr);
246 }
247 
248 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
249   os << "OpConstraintDecl " << decl << "\n";
250   printChildren(decl->getNameDecl());
251 }
252 
253 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
254   os << "TypeConstraintDecl " << decl << "\n";
255 }
256 
257 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
258   os << "TypeRangeConstraintDecl " << decl << "\n";
259 }
260 
261 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
262   os << "ValueConstraintDecl " << decl << "\n";
263   if (const auto *typeExpr = decl->getTypeExpr())
264     printChildren(typeExpr);
265 }
266 
267 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
268   os << "ValueRangeConstraintDecl " << decl << "\n";
269   if (const auto *typeExpr = decl->getTypeExpr())
270     printChildren(typeExpr);
271 }
272 
273 void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
274   os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
275      << ">\n";
276   printChildren(decl->getValue());
277 }
278 
279 void NodePrinter::printImpl(const OpNameDecl *decl) {
280   os << "OpNameDecl " << decl;
281   if (Optional<StringRef> name = decl->getName())
282     os << " Name<" << name << ">";
283   os << "\n";
284 }
285 
286 void NodePrinter::printImpl(const PatternDecl *decl) {
287   os << "PatternDecl " << decl;
288   if (const Name *name = decl->getName())
289     os << " Name<" << name->getName() << ">";
290   if (Optional<uint16_t> benefit = decl->getBenefit())
291     os << " Benefit<" << *benefit << ">";
292   if (decl->hasBoundedRewriteRecursion())
293     os << " Recursion";
294 
295   os << "\n";
296   printChildren(decl->getBody());
297 }
298 
299 void NodePrinter::printImpl(const VariableDecl *decl) {
300   os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
301      << "> Type<";
302   print(decl->getType());
303   os << ">\n";
304   if (Expr *initExpr = decl->getInitExpr())
305     printChildren(initExpr);
306 
307   auto constraints =
308       llvm::map_range(decl->getConstraints(),
309                       [](const ConstraintRef &ref) { return ref.constraint; });
310   printChildren("Constraints", constraints);
311 }
312 
313 void NodePrinter::printImpl(const Module *module) {
314   os << "Module " << module << "\n";
315   printChildren(module->getChildren());
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // Entry point
320 //===----------------------------------------------------------------------===//
321 
322 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
323 
324 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
325