1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/ViewOpGraph.h" 10 #include "PassDetail.h" 11 #include "mlir/IR/Block.h" 12 #include "mlir/IR/Operation.h" 13 #include "mlir/Support/IndentedOstream.h" 14 #include "llvm/Support/Format.h" 15 16 using namespace mlir; 17 18 static const StringRef kLineStyleDataFlow = "solid"; 19 static const StringRef kShapeNode = "ellipse"; 20 static const StringRef kShapeNone = "plain"; 21 22 /// Return the size limits for eliding large attributes. 23 static int64_t getLargeAttributeSizeLimit() { 24 // Use the default from the printer flags if possible. 25 if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit()) 26 return *limit; 27 return 16; 28 } 29 30 /// Return all values printed onto a stream as a string. 31 static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 32 std::string buf; 33 llvm::raw_string_ostream os(buf); 34 func(os); 35 return os.str(); 36 } 37 38 /// Escape special characters such as '\n' and quotation marks. 39 static std::string escapeString(std::string str) { 40 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 41 } 42 43 /// Put quotation marks around a given string. 44 static std::string quoteString(std::string str) { return "\"" + str + "\""; } 45 46 using AttributeMap = llvm::StringMap<std::string>; 47 48 namespace { 49 50 /// This struct represents a node in the DOT language. Each node has an 51 /// identifier and an optional identifier for the cluster (subgraph) that 52 /// contains the node. 53 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 54 /// not between clusters. However, edges can be clipped to the boundary of a 55 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 56 /// cluster, an invisible "anchor" node is created. 57 struct Node { 58 public: 59 Node(int id = 0, Optional<int> clusterId = llvm::None) 60 : id(id), clusterId(clusterId) {} 61 62 int id; 63 Optional<int> clusterId; 64 }; 65 66 /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 67 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 68 /// about the Graphviz DOT language. 69 class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> { 70 public: 71 PrintOpPass(raw_ostream &os) : os(os) {} 72 PrintOpPass(const PrintOpPass &o) : os(o.os.getOStream()) {} 73 74 void runOnOperation() override { 75 emitGraph([&]() { 76 processOperation(getOperation()); 77 emitAllEdgeStmts(); 78 }); 79 } 80 81 private: 82 /// Emit all edges. This function should be called after all nodes have been 83 /// emitted. 84 void emitAllEdgeStmts() { 85 for (const std::string &edge : edges) 86 os << edge << ";\n"; 87 edges.clear(); 88 } 89 90 /// Emit a cluster (subgraph). The specified builder generates the body of the 91 /// cluster. Return the anchor node of the cluster. 92 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 93 int clusterId = ++counter; 94 os << "subgraph cluster_" << clusterId << " {\n"; 95 os.indent(); 96 // Emit invisible anchor node from/to which arrows can be drawn. 97 Node anchorNode = emitNodeStmt(" ", kShapeNone); 98 os << attrStmt("label", quoteString(escapeString(label))) << ";\n"; 99 builder(); 100 os.unindent(); 101 os << "}\n"; 102 return Node(anchorNode.id, clusterId); 103 } 104 105 /// Generate an attribute statement. 106 std::string attrStmt(const Twine &key, const Twine &value) { 107 return (key + " = " + value).str(); 108 } 109 110 /// Emit an attribute list. 111 void emitAttrList(raw_ostream &os, const AttributeMap &map) { 112 os << "["; 113 interleaveComma(map, os, [&](const auto &it) { 114 os << attrStmt(it.getKey(), it.getValue()); 115 }); 116 os << "]"; 117 } 118 119 // Print an MLIR attribute to `os`. Large attributes are truncated. 120 void emitMlirAttr(raw_ostream &os, Attribute attr) { 121 // A value used to elide large container attribute. 122 int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 123 124 // Always emit splat attributes. 125 if (attr.isa<SplatElementsAttr>()) { 126 attr.print(os); 127 return; 128 } 129 130 // Elide "big" elements attributes. 131 auto elements = attr.dyn_cast<ElementsAttr>(); 132 if (elements && elements.getNumElements() > largeAttrLimit) { 133 os << std::string(elements.getType().getRank(), '[') << "..." 134 << std::string(elements.getType().getRank(), ']') << " : " 135 << elements.getType(); 136 return; 137 } 138 139 auto array = attr.dyn_cast<ArrayAttr>(); 140 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 141 os << "[...]"; 142 return; 143 } 144 145 // Print all other attributes. 146 std::string buf; 147 llvm::raw_string_ostream ss(buf); 148 attr.print(ss); 149 os << truncateString(ss.str()); 150 } 151 152 /// Append an edge to the list of edges. 153 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 154 void emitEdgeStmt(Node n1, Node n2, std::string label, 155 StringRef style = kLineStyleDataFlow) { 156 AttributeMap attrs; 157 attrs["style"] = style.str(); 158 // Do not label edges that start/end at a cluster boundary. Such edges are 159 // clipped at the boundary, but labels are not. This can lead to labels 160 // floating around without any edge next to them. 161 if (!n1.clusterId && !n2.clusterId) 162 attrs["label"] = quoteString(escapeString(label)); 163 // Use `ltail` and `lhead` to draw edges between clusters. 164 if (n1.clusterId) 165 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 166 if (n2.clusterId) 167 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 168 169 edges.push_back(strFromOs([&](raw_ostream &os) { 170 os << llvm::format("v%i -> v%i ", n1.id, n2.id); 171 emitAttrList(os, attrs); 172 })); 173 } 174 175 /// Emit a graph. The specified builder generates the body of the graph. 176 void emitGraph(function_ref<void()> builder) { 177 os << "digraph G {\n"; 178 os.indent(); 179 // Edges between clusters are allowed only in compound mode. 180 os << attrStmt("compound", "true") << ";\n"; 181 builder(); 182 os.unindent(); 183 os << "}\n"; 184 } 185 186 /// Emit a node statement. 187 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) { 188 int nodeId = ++counter; 189 AttributeMap attrs; 190 attrs["label"] = quoteString(escapeString(label)); 191 attrs["shape"] = shape.str(); 192 os << llvm::format("v%i ", nodeId); 193 emitAttrList(os, attrs); 194 os << ";\n"; 195 return Node(nodeId); 196 } 197 198 /// Generate a label for an operation. 199 std::string getLabel(Operation *op) { 200 return strFromOs([&](raw_ostream &os) { 201 // Print operation name and type. 202 os << op->getName(); 203 if (printResultTypes) { 204 os << " : ("; 205 std::string buf; 206 llvm::raw_string_ostream ss(buf); 207 interleaveComma(op->getResultTypes(), ss); 208 os << truncateString(ss.str()) << ")"; 209 os << ")"; 210 } 211 212 // Print attributes. 213 if (printAttrs) { 214 os << "\n"; 215 for (const NamedAttribute &attr : op->getAttrs()) { 216 os << '\n' << attr.first << ": "; 217 emitMlirAttr(os, attr.second); 218 } 219 } 220 }); 221 } 222 223 /// Generate a label for a block argument. 224 std::string getLabel(BlockArgument arg) { 225 return "arg" + std::to_string(arg.getArgNumber()); 226 } 227 228 /// Process a block. Emit a cluster and one node per block argument and 229 /// operation inside the cluster. 230 void processBlock(Block &block) { 231 emitClusterStmt([&]() { 232 for (BlockArgument &blockArg : block.getArguments()) 233 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 234 235 // Emit a node for each operation. 236 for (Operation &op : block) 237 processOperation(&op); 238 }); 239 } 240 241 /// Process an operation. If the operation has regions, emit a cluster. 242 /// Otherwise, emit a node. 243 void processOperation(Operation *op) { 244 Node node; 245 if (op->getNumRegions() > 0) { 246 // Emit cluster for op with regions. 247 node = emitClusterStmt( 248 [&]() { 249 for (Region ®ion : op->getRegions()) 250 processRegion(region); 251 }, 252 getLabel(op)); 253 } else { 254 node = emitNodeStmt(getLabel(op)); 255 } 256 257 // Insert edges originating from each operand. 258 unsigned numOperands = op->getNumOperands(); 259 for (unsigned i = 0; i < numOperands; i++) 260 emitEdgeStmt(valueToNode[op->getOperand(i)], node, 261 /*label=*/numOperands == 1 ? "" : std::to_string(i)); 262 263 for (Value result : op->getResults()) 264 valueToNode[result] = node; 265 } 266 267 /// Process a region. 268 void processRegion(Region ®ion) { 269 for (Block &block : region.getBlocks()) 270 processBlock(block); 271 } 272 273 /// Truncate long strings. 274 std::string truncateString(std::string str) { 275 if (str.length() <= maxLabelLen) 276 return str; 277 return str.substr(0, maxLabelLen) + "..."; 278 } 279 280 /// Output stream to write DOT file to. 281 raw_indented_ostream os; 282 /// A list of edges. For simplicity, should be emitted after all nodes were 283 /// emitted. 284 std::vector<std::string> edges; 285 /// Mapping of SSA values to Graphviz nodes/clusters. 286 DenseMap<Value, Node> valueToNode; 287 /// Counter for generating unique node/subgraph identifiers. 288 int counter = 0; 289 }; 290 291 } // namespace 292 293 std::unique_ptr<Pass> 294 mlir::createPrintOpGraphPass(raw_ostream &os) { 295 return std::make_unique<PrintOpPass>(os); 296 } 297