1 //===- NodePrinter.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Context.h"
10 #include "mlir/Tools/PDLL/AST/Nodes.h"
11 #include "llvm/ADT/StringExtras.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 #include "llvm/Support/SaveAndRestore.h"
14 #include "llvm/Support/ScopedPrinter.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll::ast;
18 
19 //===----------------------------------------------------------------------===//
20 // NodePrinter
21 //===----------------------------------------------------------------------===//
22 
23 namespace {
24 class NodePrinter {
25 public:
26   NodePrinter(raw_ostream &os) : os(os) {}
27 
28   /// Print the given type to the stream.
29   void print(Type type);
30 
31   /// Print the given node to the stream.
32   void print(const Node *node);
33 
34 private:
35   /// Print a range containing children of a node.
36   template <typename RangeT,
37             std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
38                 * = nullptr>
39   void printChildren(RangeT &&range) {
40     if (llvm::empty(range))
41       return;
42 
43     // Print the first N-1 elements with a prefix of "|-".
44     auto it = std::begin(range);
45     for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
46       print(*it);
47 
48     // Print the last element.
49     elementIndentStack.back() = true;
50     print(*it);
51   }
52   template <typename RangeT, typename... OthersT,
53             std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
54                 * = nullptr>
55   void printChildren(RangeT &&range, OthersT &&...others) {
56     printChildren(ArrayRef<const Node *>({range, others...}));
57   }
58   /// Print a range containing children of a node, nesting the children under
59   /// the given label.
60   template <typename RangeT>
61   void printChildren(StringRef label, RangeT &&range) {
62     if (llvm::empty(range))
63       return;
64     elementIndentStack.reserve(elementIndentStack.size() + 1);
65     llvm::SaveAndRestore<bool> lastElement(elementIndentStack.back(), true);
66 
67     printIndent();
68     os << label << "`\n";
69     elementIndentStack.push_back(/*isLastElt*/ false);
70     printChildren(std::forward<RangeT>(range));
71     elementIndentStack.pop_back();
72   }
73 
74   /// Print the given derived node to the stream.
75   void printImpl(const CompoundStmt *stmt);
76   void printImpl(const EraseStmt *stmt);
77   void printImpl(const LetStmt *stmt);
78 
79   void printImpl(const AttributeExpr *expr);
80   void printImpl(const DeclRefExpr *expr);
81   void printImpl(const MemberAccessExpr *expr);
82   void printImpl(const TypeExpr *expr);
83 
84   void printImpl(const AttrConstraintDecl *decl);
85   void printImpl(const OpConstraintDecl *decl);
86   void printImpl(const TypeConstraintDecl *decl);
87   void printImpl(const TypeRangeConstraintDecl *decl);
88   void printImpl(const ValueConstraintDecl *decl);
89   void printImpl(const ValueRangeConstraintDecl *decl);
90   void printImpl(const OpNameDecl *decl);
91   void printImpl(const PatternDecl *decl);
92   void printImpl(const VariableDecl *decl);
93   void printImpl(const Module *module);
94 
95   /// Print the current indent stack.
96   void printIndent() {
97     if (elementIndentStack.empty())
98       return;
99 
100     for (bool isLastElt : llvm::makeArrayRef(elementIndentStack).drop_back())
101       os << (isLastElt ? "  " : " |");
102     os << (elementIndentStack.back() ? " `" : " |");
103   }
104 
105   /// The raw output stream.
106   raw_ostream &os;
107 
108   /// A stack of indents and a flag indicating if the current element being
109   /// printed at that indent is the last element.
110   SmallVector<bool> elementIndentStack;
111 };
112 } // namespace
113 
114 void NodePrinter::print(Type type) {
115   // Protect against invalid inputs.
116   if (!type) {
117     os << "Type<NULL>";
118     return;
119   }
120 
121   TypeSwitch<Type>(type)
122       .Case([&](AttributeType) { os << "Attr"; })
123       .Case([&](ConstraintType) { os << "Constraint"; })
124       .Case([&](OperationType type) {
125         os << "Op";
126         if (Optional<StringRef> name = type.getName())
127           os << "<" << *name << ">";
128       })
129       .Case([&](RangeType type) {
130         print(type.getElementType());
131         os << "Range";
132       })
133       .Case([&](TypeType) { os << "Type"; })
134       .Case([&](ValueType) { os << "Value"; })
135       .Default([](Type) { llvm_unreachable("unknown AST type"); });
136 }
137 
138 void NodePrinter::print(const Node *node) {
139   printIndent();
140   os << "-";
141 
142   elementIndentStack.push_back(/*isLastElt*/ false);
143   TypeSwitch<const Node *>(node)
144       .Case<
145           // Statements.
146           const CompoundStmt, const EraseStmt, const LetStmt,
147 
148           // Expressions.
149           const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
150           const TypeExpr,
151 
152           // Decls.
153           const AttrConstraintDecl, const OpConstraintDecl,
154           const TypeConstraintDecl, const TypeRangeConstraintDecl,
155           const ValueConstraintDecl, const ValueRangeConstraintDecl,
156           const OpNameDecl, const PatternDecl, const VariableDecl,
157 
158           const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
159       .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
160   elementIndentStack.pop_back();
161 }
162 
163 void NodePrinter::printImpl(const CompoundStmt *stmt) {
164   os << "CompoundStmt " << stmt << "\n";
165   printChildren(stmt->getChildren());
166 }
167 
168 void NodePrinter::printImpl(const EraseStmt *stmt) {
169   os << "EraseStmt " << stmt << "\n";
170   printChildren(stmt->getRootOpExpr());
171 }
172 
173 void NodePrinter::printImpl(const LetStmt *stmt) {
174   os << "LetStmt " << stmt << "\n";
175   printChildren(stmt->getVarDecl());
176 }
177 
178 void NodePrinter::printImpl(const AttributeExpr *expr) {
179   os << "AttributeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
180 }
181 
182 void NodePrinter::printImpl(const DeclRefExpr *expr) {
183   os << "DeclRefExpr " << expr << " Type<";
184   print(expr->getType());
185   os << ">\n";
186   printChildren(expr->getDecl());
187 }
188 
189 void NodePrinter::printImpl(const MemberAccessExpr *expr) {
190   os << "MemberAccessExpr " << expr << " Member<" << expr->getMemberName()
191      << "> Type<";
192   print(expr->getType());
193   os << ">\n";
194   printChildren(expr->getParentExpr());
195 }
196 
197 void NodePrinter::printImpl(const TypeExpr *expr) {
198   os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
199 }
200 
201 void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
202   os << "AttrConstraintDecl " << decl << "\n";
203   if (const auto *typeExpr = decl->getTypeExpr())
204     printChildren(typeExpr);
205 }
206 
207 void NodePrinter::printImpl(const OpConstraintDecl *decl) {
208   os << "OpConstraintDecl " << decl << "\n";
209   printChildren(decl->getNameDecl());
210 }
211 
212 void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
213   os << "TypeConstraintDecl " << decl << "\n";
214 }
215 
216 void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
217   os << "TypeRangeConstraintDecl " << decl << "\n";
218 }
219 
220 void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
221   os << "ValueConstraintDecl " << decl << "\n";
222   if (const auto *typeExpr = decl->getTypeExpr())
223     printChildren(typeExpr);
224 }
225 
226 void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
227   os << "ValueRangeConstraintDecl " << decl << "\n";
228   if (const auto *typeExpr = decl->getTypeExpr())
229     printChildren(typeExpr);
230 }
231 
232 void NodePrinter::printImpl(const OpNameDecl *decl) {
233   os << "OpNameDecl " << decl;
234   if (Optional<StringRef> name = decl->getName())
235     os << " Name<" << name << ">";
236   os << "\n";
237 }
238 
239 void NodePrinter::printImpl(const PatternDecl *decl) {
240   os << "PatternDecl " << decl;
241   if (const Name *name = decl->getName())
242     os << " Name<" << name->getName() << ">";
243   if (Optional<uint16_t> benefit = decl->getBenefit())
244     os << " Benefit<" << *benefit << ">";
245   if (decl->hasBoundedRewriteRecursion())
246     os << " Recursion";
247 
248   os << "\n";
249   printChildren(decl->getBody());
250 }
251 
252 void NodePrinter::printImpl(const VariableDecl *decl) {
253   os << "VariableDecl " << decl << " Name<" << decl->getName().getName()
254      << "> Type<";
255   print(decl->getType());
256   os << ">\n";
257   if (Expr *initExpr = decl->getInitExpr())
258     printChildren(initExpr);
259 
260   auto constraints =
261       llvm::map_range(decl->getConstraints(),
262                       [](const ConstraintRef &ref) { return ref.constraint; });
263   printChildren("Constraints", constraints);
264 }
265 
266 void NodePrinter::printImpl(const Module *module) {
267   os << "Module " << module << "\n";
268   printChildren(module->getChildren());
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // Entry point
273 //===----------------------------------------------------------------------===//
274 
275 void Node::print(raw_ostream &os) const { NodePrinter(os).print(this); }
276 
277 void Type::print(raw_ostream &os) const { NodePrinter(os).print(*this); }
278