1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/ViewOpGraph.h"
10 #include "PassDetail.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/Support/IndentedOstream.h"
14 #include "llvm/Support/Format.h"
15 
16 using namespace mlir;
17 
/// Value of the `style` attribute used by default for edges created via
/// `emitEdgeStmt`.
static const StringRef kLineStyleDataFlow = "solid";
/// Value of the `shape` attribute used by default for nodes created via
/// `emitNodeStmt`.
static const StringRef kShapeNode = "ellipse";
/// Value of the `shape` attribute used for the anchor node that is emitted at
/// the beginning of every cluster.
static const StringRef kShapeNone = "plain";
21 
22 /// Return the size limits for eliding large attributes.
23 static int64_t getLargeAttributeSizeLimit() {
24   // Use the default from the printer flags if possible.
25   if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
26     return *limit;
27   return 16;
28 }
29 
30 /// Return all values printed onto a stream as a string.
31 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
32   std::string buf;
33   llvm::raw_string_ostream os(buf);
34   func(os);
35   return os.str();
36 }
37 
38 /// Escape special characters such as '\n' and quotation marks.
39 static std::string escapeString(std::string str) {
40   return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
41 }
42 
/// Wrap the given string in double quotation marks. The content itself is not
/// modified or escaped.
static std::string quoteString(std::string str) {
  str.insert(str.begin(), '"');
  str.push_back('"');
  return str;
}
45 
/// Attributes of a DOT node or edge: maps an attribute key to its value, both
/// as they are written into the attribute list by `emitAttrList`.
using AttributeMap = llvm::StringMap<std::string>;
47 
48 namespace {
49 
50 /// This struct represents a node in the DOT language. Each node has an
51 /// identifier and an optional identifier for the cluster (subgraph) that
52 /// contains the node.
53 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
54 /// not between clusters. However, edges can be clipped to the boundary of a
55 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
56 /// cluster, an invisible "anchor" node is created.
57 struct Node {
58 public:
59   Node(int id = 0, Optional<int> clusterId = llvm::None)
60       : id(id), clusterId(clusterId) {}
61 
62   int id;
63   Optional<int> clusterId;
64 };
65 
66 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
67 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
68 /// about the Graphviz DOT language.
69 class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
70 public:
71   PrintOpPass(raw_ostream &os) : os(os) {}
72   PrintOpPass(const PrintOpPass &o) : os(o.os.getOStream()) {}
73 
74   void runOnOperation() override {
75     emitGraph([&]() {
76       processOperation(getOperation());
77       emitAllEdgeStmts();
78     });
79   }
80 
81 private:
82   /// Emit all edges. This function should be called after all nodes have been
83   /// emitted.
84   void emitAllEdgeStmts() {
85     for (const std::string &edge : edges)
86       os << edge << ";\n";
87     edges.clear();
88   }
89 
90   /// Emit a cluster (subgraph). The specified builder generates the body of the
91   /// cluster. Return the anchor node of the cluster.
92   Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
93     int clusterId = ++counter;
94     os << "subgraph cluster_" << clusterId << " {\n";
95     os.indent();
96     // Emit invisible anchor node from/to which arrows can be drawn.
97     Node anchorNode = emitNodeStmt(" ", kShapeNone);
98     os << attrStmt("label", quoteString(escapeString(label))) << ";\n";
99     builder();
100     os.unindent();
101     os << "}\n";
102     return Node(anchorNode.id, clusterId);
103   }
104 
105   /// Generate an attribute statement.
106   std::string attrStmt(const Twine &key, const Twine &value) {
107     return (key + " = " + value).str();
108   }
109 
110   /// Emit an attribute list.
111   void emitAttrList(raw_ostream &os, const AttributeMap &map) {
112     os << "[";
113     interleaveComma(map, os, [&](const auto &it) {
114       os << attrStmt(it.getKey(), it.getValue());
115     });
116     os << "]";
117   }
118 
119   // Print an MLIR attribute to `os`. Large attributes are truncated.
120   void emitMlirAttr(raw_ostream &os, Attribute attr) {
121     // A value used to elide large container attribute.
122     int64_t largeAttrLimit = getLargeAttributeSizeLimit();
123 
124     // Always emit splat attributes.
125     if (attr.isa<SplatElementsAttr>()) {
126       attr.print(os);
127       return;
128     }
129 
130     // Elide "big" elements attributes.
131     auto elements = attr.dyn_cast<ElementsAttr>();
132     if (elements && elements.getNumElements() > largeAttrLimit) {
133       os << std::string(elements.getType().getRank(), '[') << "..."
134          << std::string(elements.getType().getRank(), ']') << " : "
135          << elements.getType();
136       return;
137     }
138 
139     auto array = attr.dyn_cast<ArrayAttr>();
140     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
141       os << "[...]";
142       return;
143     }
144 
145     // Print all other attributes.
146     std::string buf;
147     llvm::raw_string_ostream ss(buf);
148     attr.print(ss);
149     os << truncateString(ss.str());
150   }
151 
152   /// Append an edge to the list of edges.
153   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
154   void emitEdgeStmt(Node n1, Node n2, std::string label,
155                     StringRef style = kLineStyleDataFlow) {
156     AttributeMap attrs;
157     attrs["style"] = style.str();
158     // Do not label edges that start/end at a cluster boundary. Such edges are
159     // clipped at the boundary, but labels are not. This can lead to labels
160     // floating around without any edge next to them.
161     if (!n1.clusterId && !n2.clusterId)
162       attrs["label"] = quoteString(escapeString(label));
163     // Use `ltail` and `lhead` to draw edges between clusters.
164     if (n1.clusterId)
165       attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
166     if (n2.clusterId)
167       attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
168 
169     edges.push_back(strFromOs([&](raw_ostream &os) {
170       os << llvm::format("v%i -> v%i ", n1.id, n2.id);
171       emitAttrList(os, attrs);
172     }));
173   }
174 
175   /// Emit a graph. The specified builder generates the body of the graph.
176   void emitGraph(function_ref<void()> builder) {
177     os << "digraph G {\n";
178     os.indent();
179     // Edges between clusters are allowed only in compound mode.
180     os << attrStmt("compound", "true") << ";\n";
181     builder();
182     os.unindent();
183     os << "}\n";
184   }
185 
186   /// Emit a node statement.
187   Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
188     int nodeId = ++counter;
189     AttributeMap attrs;
190     attrs["label"] = quoteString(escapeString(label));
191     attrs["shape"] = shape.str();
192     os << llvm::format("v%i ", nodeId);
193     emitAttrList(os, attrs);
194     os << ";\n";
195     return Node(nodeId);
196   }
197 
198   /// Generate a label for an operation.
199   std::string getLabel(Operation *op) {
200     return strFromOs([&](raw_ostream &os) {
201       // Print operation name and type.
202       os << op->getName();
203       if (printResultTypes) {
204         os << " : (";
205         std::string buf;
206         llvm::raw_string_ostream ss(buf);
207         interleaveComma(op->getResultTypes(), ss);
208         os << truncateString(ss.str()) << ")";
209         os << ")";
210       }
211 
212       // Print attributes.
213       if (printAttrs) {
214         os << "\n";
215         for (const NamedAttribute &attr : op->getAttrs()) {
216           os << '\n' << attr.first << ": ";
217           emitMlirAttr(os, attr.second);
218         }
219       }
220     });
221   }
222 
223   /// Generate a label for a block argument.
224   std::string getLabel(BlockArgument arg) {
225     return "arg" + std::to_string(arg.getArgNumber());
226   }
227 
228   /// Process a block. Emit a cluster and one node per block argument and
229   /// operation inside the cluster.
230   void processBlock(Block &block) {
231     emitClusterStmt([&]() {
232       for (BlockArgument &blockArg : block.getArguments())
233         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
234 
235       // Emit a node for each operation.
236       for (Operation &op : block)
237         processOperation(&op);
238     });
239   }
240 
241   /// Process an operation. If the operation has regions, emit a cluster.
242   /// Otherwise, emit a node.
243   void processOperation(Operation *op) {
244     Node node;
245     if (op->getNumRegions() > 0) {
246       // Emit cluster for op with regions.
247       node = emitClusterStmt(
248           [&]() {
249             for (Region &region : op->getRegions())
250               processRegion(region);
251           },
252           getLabel(op));
253     } else {
254       node = emitNodeStmt(getLabel(op));
255     }
256 
257     // Insert edges originating from each operand.
258     unsigned numOperands = op->getNumOperands();
259     for (unsigned i = 0; i < numOperands; i++)
260       emitEdgeStmt(valueToNode[op->getOperand(i)], node,
261                    /*label=*/numOperands == 1 ? "" : std::to_string(i));
262 
263     for (Value result : op->getResults())
264       valueToNode[result] = node;
265   }
266 
267   /// Process a region.
268   void processRegion(Region &region) {
269     for (Block &block : region.getBlocks())
270       processBlock(block);
271   }
272 
273   /// Truncate long strings.
274   std::string truncateString(std::string str) {
275     if (str.length() <= maxLabelLen)
276       return str;
277     return str.substr(0, maxLabelLen) + "...";
278   }
279 
280   /// Output stream to write DOT file to.
281   raw_indented_ostream os;
282   /// A list of edges. For simplicity, should be emitted after all nodes were
283   /// emitted.
284   std::vector<std::string> edges;
285   /// Mapping of SSA values to Graphviz nodes/clusters.
286   DenseMap<Value, Node> valueToNode;
287   /// Counter for generating unique node/subgraph identifiers.
288   int counter = 0;
289 };
290 
291 } // namespace
292 
293 std::unique_ptr<Pass>
294 mlir::createPrintOpGraphPass(raw_ostream &os) {
295   return std::make_unique<PrintOpPass>(os);
296 }
297