1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/ViewOpGraph.h"
10 #include "PassDetail.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Support/IndentedOstream.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/GraphWriter.h"
17 #include <utility>
18
19 using namespace mlir;
20
21 static const StringRef kLineStyleControlFlow = "dashed";
22 static const StringRef kLineStyleDataFlow = "solid";
23 static const StringRef kShapeNode = "ellipse";
24 static const StringRef kShapeNone = "plain";
25
26 /// Return the size limits for eliding large attributes.
getLargeAttributeSizeLimit()27 static int64_t getLargeAttributeSizeLimit() {
28 // Use the default from the printer flags if possible.
29 if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
30 return *limit;
31 return 16;
32 }
33
34 /// Return all values printed onto a stream as a string.
strFromOs(function_ref<void (raw_ostream &)> func)35 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
36 std::string buf;
37 llvm::raw_string_ostream os(buf);
38 func(os);
39 return os.str();
40 }
41
42 /// Escape special characters such as '\n' and quotation marks.
escapeString(std::string str)43 static std::string escapeString(std::string str) {
44 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
45 }
46
/// Wrap `str` in double quotation marks. The contents are not escaped; callers
/// that need escaping pass the result of `escapeString`.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted.push_back('"');
  quoted.append(str);
  quoted.push_back('"');
  return quoted;
}
51
// A Graphviz attribute list for a node or edge: attribute name -> value text.
// Values are written verbatim by `emitAttrList`, so string values must already
// be quoted/escaped by the caller.
using AttributeMap = llvm::StringMap<std::string>;
53
54 namespace {
55
56 /// This struct represents a node in the DOT language. Each node has an
57 /// identifier and an optional identifier for the cluster (subgraph) that
58 /// contains the node.
59 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
60 /// not between clusters. However, edges can be clipped to the boundary of a
61 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
62 /// cluster, an invisible "anchor" node is created.
63 struct Node {
64 public:
Node__anon2ace0bc30211::Node65 Node(int id = 0, Optional<int> clusterId = llvm::None)
66 : id(id), clusterId(clusterId) {}
67
68 int id;
69 Optional<int> clusterId;
70 };
71
72 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
73 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
74 /// about the Graphviz DOT language.
75 class PrintOpPass : public ViewOpGraphBase<PrintOpPass> {
76 public:
PrintOpPass(raw_ostream & os)77 PrintOpPass(raw_ostream &os) : os(os) {}
PrintOpPass(const PrintOpPass & o)78 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
79
runOnOperation()80 void runOnOperation() override {
81 emitGraph([&]() {
82 processOperation(getOperation());
83 emitAllEdgeStmts();
84 });
85 }
86
87 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
emitRegionCFG(Region & region)88 void emitRegionCFG(Region ®ion) {
89 printControlFlowEdges = true;
90 printDataFlowEdges = false;
91 emitGraph([&]() { processRegion(region); });
92 }
93
94 private:
95 /// Emit all edges. This function should be called after all nodes have been
96 /// emitted.
emitAllEdgeStmts()97 void emitAllEdgeStmts() {
98 for (const std::string &edge : edges)
99 os << edge << ";\n";
100 edges.clear();
101 }
102
103 /// Emit a cluster (subgraph). The specified builder generates the body of the
104 /// cluster. Return the anchor node of the cluster.
emitClusterStmt(function_ref<void ()> builder,std::string label="")105 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
106 int clusterId = ++counter;
107 os << "subgraph cluster_" << clusterId << " {\n";
108 os.indent();
109 // Emit invisible anchor node from/to which arrows can be drawn.
110 Node anchorNode = emitNodeStmt(" ", kShapeNone);
111 os << attrStmt("label", quoteString(escapeString(std::move(label))))
112 << ";\n";
113 builder();
114 os.unindent();
115 os << "}\n";
116 return Node(anchorNode.id, clusterId);
117 }
118
119 /// Generate an attribute statement.
attrStmt(const Twine & key,const Twine & value)120 std::string attrStmt(const Twine &key, const Twine &value) {
121 return (key + " = " + value).str();
122 }
123
124 /// Emit an attribute list.
emitAttrList(raw_ostream & os,const AttributeMap & map)125 void emitAttrList(raw_ostream &os, const AttributeMap &map) {
126 os << "[";
127 interleaveComma(map, os, [&](const auto &it) {
128 os << this->attrStmt(it.getKey(), it.getValue());
129 });
130 os << "]";
131 }
132
133 // Print an MLIR attribute to `os`. Large attributes are truncated.
emitMlirAttr(raw_ostream & os,Attribute attr)134 void emitMlirAttr(raw_ostream &os, Attribute attr) {
135 // A value used to elide large container attribute.
136 int64_t largeAttrLimit = getLargeAttributeSizeLimit();
137
138 // Always emit splat attributes.
139 if (attr.isa<SplatElementsAttr>()) {
140 attr.print(os);
141 return;
142 }
143
144 // Elide "big" elements attributes.
145 auto elements = attr.dyn_cast<ElementsAttr>();
146 if (elements && elements.getNumElements() > largeAttrLimit) {
147 os << std::string(elements.getType().getRank(), '[') << "..."
148 << std::string(elements.getType().getRank(), ']') << " : "
149 << elements.getType();
150 return;
151 }
152
153 auto array = attr.dyn_cast<ArrayAttr>();
154 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
155 os << "[...]";
156 return;
157 }
158
159 // Print all other attributes.
160 std::string buf;
161 llvm::raw_string_ostream ss(buf);
162 attr.print(ss);
163 os << truncateString(ss.str());
164 }
165
166 /// Append an edge to the list of edges.
167 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
emitEdgeStmt(Node n1,Node n2,std::string label,StringRef style)168 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
169 AttributeMap attrs;
170 attrs["style"] = style.str();
171 // Do not label edges that start/end at a cluster boundary. Such edges are
172 // clipped at the boundary, but labels are not. This can lead to labels
173 // floating around without any edge next to them.
174 if (!n1.clusterId && !n2.clusterId)
175 attrs["label"] = quoteString(escapeString(std::move(label)));
176 // Use `ltail` and `lhead` to draw edges between clusters.
177 if (n1.clusterId)
178 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
179 if (n2.clusterId)
180 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
181
182 edges.push_back(strFromOs([&](raw_ostream &os) {
183 os << llvm::format("v%i -> v%i ", n1.id, n2.id);
184 emitAttrList(os, attrs);
185 }));
186 }
187
188 /// Emit a graph. The specified builder generates the body of the graph.
emitGraph(function_ref<void ()> builder)189 void emitGraph(function_ref<void()> builder) {
190 os << "digraph G {\n";
191 os.indent();
192 // Edges between clusters are allowed only in compound mode.
193 os << attrStmt("compound", "true") << ";\n";
194 builder();
195 os.unindent();
196 os << "}\n";
197 }
198
199 /// Emit a node statement.
emitNodeStmt(std::string label,StringRef shape=kShapeNode)200 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
201 int nodeId = ++counter;
202 AttributeMap attrs;
203 attrs["label"] = quoteString(escapeString(std::move(label)));
204 attrs["shape"] = shape.str();
205 os << llvm::format("v%i ", nodeId);
206 emitAttrList(os, attrs);
207 os << ";\n";
208 return Node(nodeId);
209 }
210
211 /// Generate a label for an operation.
getLabel(Operation * op)212 std::string getLabel(Operation *op) {
213 return strFromOs([&](raw_ostream &os) {
214 // Print operation name and type.
215 os << op->getName();
216 if (printResultTypes) {
217 os << " : (";
218 std::string buf;
219 llvm::raw_string_ostream ss(buf);
220 interleaveComma(op->getResultTypes(), ss);
221 os << truncateString(ss.str()) << ")";
222 os << ")";
223 }
224
225 // Print attributes.
226 if (printAttrs) {
227 os << "\n";
228 for (const NamedAttribute &attr : op->getAttrs()) {
229 os << '\n' << attr.getName().getValue() << ": ";
230 emitMlirAttr(os, attr.getValue());
231 }
232 }
233 });
234 }
235
236 /// Generate a label for a block argument.
getLabel(BlockArgument arg)237 std::string getLabel(BlockArgument arg) {
238 return "arg" + std::to_string(arg.getArgNumber());
239 }
240
241 /// Process a block. Emit a cluster and one node per block argument and
242 /// operation inside the cluster.
processBlock(Block & block)243 void processBlock(Block &block) {
244 emitClusterStmt([&]() {
245 for (BlockArgument &blockArg : block.getArguments())
246 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
247
248 // Emit a node for each operation.
249 Optional<Node> prevNode;
250 for (Operation &op : block) {
251 Node nextNode = processOperation(&op);
252 if (printControlFlowEdges && prevNode)
253 emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
254 kLineStyleControlFlow);
255 prevNode = nextNode;
256 }
257 });
258 }
259
260 /// Process an operation. If the operation has regions, emit a cluster.
261 /// Otherwise, emit a node.
processOperation(Operation * op)262 Node processOperation(Operation *op) {
263 Node node;
264 if (op->getNumRegions() > 0) {
265 // Emit cluster for op with regions.
266 node = emitClusterStmt(
267 [&]() {
268 for (Region ®ion : op->getRegions())
269 processRegion(region);
270 },
271 getLabel(op));
272 } else {
273 node = emitNodeStmt(getLabel(op));
274 }
275
276 // Insert data flow edges originating from each operand.
277 if (printDataFlowEdges) {
278 unsigned numOperands = op->getNumOperands();
279 for (unsigned i = 0; i < numOperands; i++)
280 emitEdgeStmt(valueToNode[op->getOperand(i)], node,
281 /*label=*/numOperands == 1 ? "" : std::to_string(i),
282 kLineStyleDataFlow);
283 }
284
285 for (Value result : op->getResults())
286 valueToNode[result] = node;
287
288 return node;
289 }
290
291 /// Process a region.
processRegion(Region & region)292 void processRegion(Region ®ion) {
293 for (Block &block : region.getBlocks())
294 processBlock(block);
295 }
296
297 /// Truncate long strings.
truncateString(std::string str)298 std::string truncateString(std::string str) {
299 if (str.length() <= maxLabelLen)
300 return str;
301 return str.substr(0, maxLabelLen) + "...";
302 }
303
304 /// Output stream to write DOT file to.
305 raw_indented_ostream os;
306 /// A list of edges. For simplicity, should be emitted after all nodes were
307 /// emitted.
308 std::vector<std::string> edges;
309 /// Mapping of SSA values to Graphviz nodes/clusters.
310 DenseMap<Value, Node> valueToNode;
311 /// Counter for generating unique node/subgraph identifiers.
312 int counter = 0;
313 };
314
315 } // namespace
316
createPrintOpGraphPass(raw_ostream & os)317 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
318 return std::make_unique<PrintOpPass>(os);
319 }
320
321 /// Generate a CFG for a region and show it in a window.
llvmViewGraph(Region & region,const Twine & name)322 static void llvmViewGraph(Region ®ion, const Twine &name) {
323 int fd;
324 std::string filename = llvm::createGraphFilename(name.str(), fd);
325 {
326 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
327 if (fd == -1) {
328 llvm::errs() << "error opening file '" << filename << "' for writing\n";
329 return;
330 }
331 PrintOpPass pass(os);
332 pass.emitRegionCFG(region);
333 }
334 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
335 }
336
viewGraph(const Twine & regionName)337 void mlir::Region::viewGraph(const Twine ®ionName) {
338 llvmViewGraph(*this, regionName);
339 }
340
viewGraph()341 void mlir::Region::viewGraph() { viewGraph("region"); }
342