12660623aSJacques Pienaar //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===// 22660623aSJacques Pienaar // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62660623aSJacques Pienaar // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 82660623aSJacques Pienaar 92660623aSJacques Pienaar #include "mlir/Transforms/ViewOpGraph.h" 101834ad4aSRiver Riddle #include "PassDetail.h" 112660623aSJacques Pienaar #include "mlir/IR/Block.h" 12*36d3efeaSRiver Riddle #include "mlir/IR/BuiltinTypes.h" 132660623aSJacques Pienaar #include "mlir/IR/Operation.h" 148d15b7dcSMatthias Springer #include "mlir/Support/IndentedOstream.h" 158d15b7dcSMatthias Springer #include "llvm/Support/Format.h" 169102a16bSMatthias Springer #include "llvm/Support/GraphWriter.h" 17*36d3efeaSRiver Riddle #include <utility> 182660623aSJacques Pienaar 194562e389SRiver Riddle using namespace mlir; 204562e389SRiver Riddle 219102a16bSMatthias Springer static const StringRef kLineStyleControlFlow = "dashed"; 228d15b7dcSMatthias Springer static const StringRef kLineStyleDataFlow = "solid"; 238d15b7dcSMatthias Springer static const StringRef kShapeNode = "ellipse"; 248d15b7dcSMatthias Springer static const StringRef kShapeNone = "plain"; 258d15b7dcSMatthias Springer 26400ad6f9SRiver Riddle /// Return the size limits for eliding large attributes. 27400ad6f9SRiver Riddle static int64_t getLargeAttributeSizeLimit() { 28400ad6f9SRiver Riddle // Use the default from the printer flags if possible. 29400ad6f9SRiver Riddle if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit()) 30400ad6f9SRiver Riddle return *limit; 31400ad6f9SRiver Riddle return 16; 32400ad6f9SRiver Riddle } 33400ad6f9SRiver Riddle 348d15b7dcSMatthias Springer /// Return all values printed onto a stream as a string. 358d15b7dcSMatthias Springer static std::string strFromOs(function_ref<void(raw_ostream &)> func) { 368d15b7dcSMatthias Springer std::string buf; 378d15b7dcSMatthias Springer llvm::raw_string_ostream os(buf); 388d15b7dcSMatthias Springer func(os); 398d15b7dcSMatthias Springer return os.str(); 402660623aSJacques Pienaar } 418d15b7dcSMatthias Springer 428d15b7dcSMatthias Springer /// Escape special characters such as '\n' and quotation marks. 438d15b7dcSMatthias Springer static std::string escapeString(std::string str) { 448d15b7dcSMatthias Springer return strFromOs([&](raw_ostream &os) { os.write_escaped(str); }); 452660623aSJacques Pienaar } 468d15b7dcSMatthias Springer 478d15b7dcSMatthias Springer /// Put quotation marks around a given string. 481fc096afSMehdi Amini static std::string quoteString(const std::string &str) { 491fc096afSMehdi Amini return "\"" + str + "\""; 501fc096afSMehdi Amini } 518d15b7dcSMatthias Springer 528d15b7dcSMatthias Springer using AttributeMap = llvm::StringMap<std::string>; 538d15b7dcSMatthias Springer 548d15b7dcSMatthias Springer namespace { 558d15b7dcSMatthias Springer 568d15b7dcSMatthias Springer /// This struct represents a node in the DOT language. Each node has an 578d15b7dcSMatthias Springer /// identifier and an optional identifier for the cluster (subgraph) that 588d15b7dcSMatthias Springer /// contains the node. 598d15b7dcSMatthias Springer /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but 608d15b7dcSMatthias Springer /// not between clusters. However, edges can be clipped to the boundary of a 618d15b7dcSMatthias Springer /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new 628d15b7dcSMatthias Springer /// cluster, an invisible "anchor" node is created. 638d15b7dcSMatthias Springer struct Node { 648d15b7dcSMatthias Springer public: 658d15b7dcSMatthias Springer Node(int id = 0, Optional<int> clusterId = llvm::None) 668d15b7dcSMatthias Springer : id(id), clusterId(clusterId) {} 678d15b7dcSMatthias Springer 688d15b7dcSMatthias Springer int id; 698d15b7dcSMatthias Springer Optional<int> clusterId; 702660623aSJacques Pienaar }; 712660623aSJacques Pienaar 728d15b7dcSMatthias Springer /// This pass generates a Graphviz dataflow visualization of an MLIR operation. 738d15b7dcSMatthias Springer /// Note: See https://www.graphviz.org/doc/info/lang.html for more information 748d15b7dcSMatthias Springer /// about the Graphviz DOT language. 758bd08a9fSUday Bondhugula class PrintOpPass : public ViewOpGraphBase<PrintOpPass> { 768d15b7dcSMatthias Springer public: 778d15b7dcSMatthias Springer PrintOpPass(raw_ostream &os) : os(os) {} 7866b1e629SMatthias Springer PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {} 792660623aSJacques Pienaar 808d15b7dcSMatthias Springer void runOnOperation() override { 818d15b7dcSMatthias Springer emitGraph([&]() { 828d15b7dcSMatthias Springer processOperation(getOperation()); 838d15b7dcSMatthias Springer emitAllEdgeStmts(); 848d15b7dcSMatthias Springer }); 85736ad206SJing Pu } 86736ad206SJing Pu 879102a16bSMatthias Springer /// Create a CFG graph for a region. Used in `Region::viewGraph`. 889102a16bSMatthias Springer void emitRegionCFG(Region ®ion) { 899102a16bSMatthias Springer printControlFlowEdges = true; 909102a16bSMatthias Springer printDataFlowEdges = false; 919102a16bSMatthias Springer emitGraph([&]() { processRegion(region); }); 929102a16bSMatthias Springer } 939102a16bSMatthias Springer 948d15b7dcSMatthias Springer private: 958d15b7dcSMatthias Springer /// Emit all edges. This function should be called after all nodes have been 968d15b7dcSMatthias Springer /// emitted. 978d15b7dcSMatthias Springer void emitAllEdgeStmts() { 988d15b7dcSMatthias Springer for (const std::string &edge : edges) 998d15b7dcSMatthias Springer os << edge << ";\n"; 1008d15b7dcSMatthias Springer edges.clear(); 1018d15b7dcSMatthias Springer } 10217606a10SJing Pu 1038d15b7dcSMatthias Springer /// Emit a cluster (subgraph). The specified builder generates the body of the 1048d15b7dcSMatthias Springer /// cluster. Return the anchor node of the cluster. 1058d15b7dcSMatthias Springer Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { 1068d15b7dcSMatthias Springer int clusterId = ++counter; 1078d15b7dcSMatthias Springer os << "subgraph cluster_" << clusterId << " {\n"; 1088d15b7dcSMatthias Springer os.indent(); 1098d15b7dcSMatthias Springer // Emit invisible anchor node from/to which arrows can be drawn. 1108d15b7dcSMatthias Springer Node anchorNode = emitNodeStmt(" ", kShapeNone); 1111fc096afSMehdi Amini os << attrStmt("label", quoteString(escapeString(std::move(label)))) 1121fc096afSMehdi Amini << ";\n"; 1138d15b7dcSMatthias Springer builder(); 1148d15b7dcSMatthias Springer os.unindent(); 1158d15b7dcSMatthias Springer os << "}\n"; 1168d15b7dcSMatthias Springer return Node(anchorNode.id, clusterId); 1178d15b7dcSMatthias Springer } 1188d15b7dcSMatthias Springer 1198d15b7dcSMatthias Springer /// Generate an attribute statement. 1208d15b7dcSMatthias Springer std::string attrStmt(const Twine &key, const Twine &value) { 1218d15b7dcSMatthias Springer return (key + " = " + value).str(); 1228d15b7dcSMatthias Springer } 1238d15b7dcSMatthias Springer 1248d15b7dcSMatthias Springer /// Emit an attribute list. 1258d15b7dcSMatthias Springer void emitAttrList(raw_ostream &os, const AttributeMap &map) { 1268d15b7dcSMatthias Springer os << "["; 1278d15b7dcSMatthias Springer interleaveComma(map, os, [&](const auto &it) { 128438f700bSMatthias Springer os << this->attrStmt(it.getKey(), it.getValue()); 1298d15b7dcSMatthias Springer }); 1308d15b7dcSMatthias Springer os << "]"; 1318d15b7dcSMatthias Springer } 1328d15b7dcSMatthias Springer 1338d15b7dcSMatthias Springer // Print an MLIR attribute to `os`. Large attributes are truncated. 1348d15b7dcSMatthias Springer void emitMlirAttr(raw_ostream &os, Attribute attr) { 135400ad6f9SRiver Riddle // A value used to elide large container attribute. 136400ad6f9SRiver Riddle int64_t largeAttrLimit = getLargeAttributeSizeLimit(); 1378d15b7dcSMatthias Springer 1382660623aSJacques Pienaar // Always emit splat attributes. 1398d15b7dcSMatthias Springer if (attr.isa<SplatElementsAttr>()) { 1408d15b7dcSMatthias Springer attr.print(os); 1418d15b7dcSMatthias Springer return; 1422660623aSJacques Pienaar } 1432660623aSJacques Pienaar 1442660623aSJacques Pienaar // Elide "big" elements attributes. 1458d15b7dcSMatthias Springer auto elements = attr.dyn_cast<ElementsAttr>(); 146400ad6f9SRiver Riddle if (elements && elements.getNumElements() > largeAttrLimit) { 1472b86e27dSJacques Pienaar os << std::string(elements.getType().getRank(), '[') << "..." 1482b86e27dSJacques Pienaar << std::string(elements.getType().getRank(), ']') << " : " 1492b86e27dSJacques Pienaar << elements.getType(); 1508d15b7dcSMatthias Springer return; 1512660623aSJacques Pienaar } 1522660623aSJacques Pienaar 1538d15b7dcSMatthias Springer auto array = attr.dyn_cast<ArrayAttr>(); 154400ad6f9SRiver Riddle if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) { 155563b5910SJing Pu os << "[...]"; 1568d15b7dcSMatthias Springer return; 157563b5910SJing Pu } 158563b5910SJing Pu 1592660623aSJacques Pienaar // Print all other attributes. 160a87be1c1SMatthias Springer std::string buf; 161a87be1c1SMatthias Springer llvm::raw_string_ostream ss(buf); 162a87be1c1SMatthias Springer attr.print(ss); 163a87be1c1SMatthias Springer os << truncateString(ss.str()); 1642660623aSJacques Pienaar } 1652660623aSJacques Pienaar 1668d15b7dcSMatthias Springer /// Append an edge to the list of edges. 1678d15b7dcSMatthias Springer /// Note: Edges are written to the output stream via `emitAllEdgeStmts`. 1689102a16bSMatthias Springer void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) { 1698d15b7dcSMatthias Springer AttributeMap attrs; 1708d15b7dcSMatthias Springer attrs["style"] = style.str(); 1718d15b7dcSMatthias Springer // Do not label edges that start/end at a cluster boundary. Such edges are 1728d15b7dcSMatthias Springer // clipped at the boundary, but labels are not. This can lead to labels 1738d15b7dcSMatthias Springer // floating around without any edge next to them. 1748d15b7dcSMatthias Springer if (!n1.clusterId && !n2.clusterId) 1751fc096afSMehdi Amini attrs["label"] = quoteString(escapeString(std::move(label))); 1768d15b7dcSMatthias Springer // Use `ltail` and `lhead` to draw edges between clusters. 1778d15b7dcSMatthias Springer if (n1.clusterId) 1788d15b7dcSMatthias Springer attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId); 1798d15b7dcSMatthias Springer if (n2.clusterId) 1808d15b7dcSMatthias Springer attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId); 1812660623aSJacques Pienaar 1828d15b7dcSMatthias Springer edges.push_back(strFromOs([&](raw_ostream &os) { 1838d15b7dcSMatthias Springer os << llvm::format("v%i -> v%i ", n1.id, n2.id); 1848d15b7dcSMatthias Springer emitAttrList(os, attrs); 1858d15b7dcSMatthias Springer })); 18602d7b260SJacques Pienaar } 1872660623aSJacques Pienaar 1888d15b7dcSMatthias Springer /// Emit a graph. The specified builder generates the body of the graph. 1898d15b7dcSMatthias Springer void emitGraph(function_ref<void()> builder) { 1908d15b7dcSMatthias Springer os << "digraph G {\n"; 1918d15b7dcSMatthias Springer os.indent(); 1928d15b7dcSMatthias Springer // Edges between clusters are allowed only in compound mode. 1938d15b7dcSMatthias Springer os << attrStmt("compound", "true") << ";\n"; 1948d15b7dcSMatthias Springer builder(); 1958d15b7dcSMatthias Springer os.unindent(); 1968d15b7dcSMatthias Springer os << "}\n"; 1972660623aSJacques Pienaar } 1982660623aSJacques Pienaar 1998d15b7dcSMatthias Springer /// Emit a node statement. 2008d15b7dcSMatthias Springer Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) { 2018d15b7dcSMatthias Springer int nodeId = ++counter; 2028d15b7dcSMatthias Springer AttributeMap attrs; 2031fc096afSMehdi Amini attrs["label"] = quoteString(escapeString(std::move(label))); 2048d15b7dcSMatthias Springer attrs["shape"] = shape.str(); 2058d15b7dcSMatthias Springer os << llvm::format("v%i ", nodeId); 2068d15b7dcSMatthias Springer emitAttrList(os, attrs); 2078d15b7dcSMatthias Springer os << ";\n"; 2088d15b7dcSMatthias Springer return Node(nodeId); 2092660623aSJacques Pienaar } 2102660623aSJacques Pienaar 2118d15b7dcSMatthias Springer /// Generate a label for an operation. 2128d15b7dcSMatthias Springer std::string getLabel(Operation *op) { 2138d15b7dcSMatthias Springer return strFromOs([&](raw_ostream &os) { 2148d15b7dcSMatthias Springer // Print operation name and type. 215a87be1c1SMatthias Springer os << op->getName(); 216a87be1c1SMatthias Springer if (printResultTypes) { 217a87be1c1SMatthias Springer os << " : ("; 218a87be1c1SMatthias Springer std::string buf; 219a87be1c1SMatthias Springer llvm::raw_string_ostream ss(buf); 220a87be1c1SMatthias Springer interleaveComma(op->getResultTypes(), ss); 221a87be1c1SMatthias Springer os << truncateString(ss.str()) << ")"; 222a87be1c1SMatthias Springer os << ")"; 223a87be1c1SMatthias Springer } 2242660623aSJacques Pienaar 2258d15b7dcSMatthias Springer // Print attributes. 226a87be1c1SMatthias Springer if (printAttrs) { 227a87be1c1SMatthias Springer os << "\n"; 2288d15b7dcSMatthias Springer for (const NamedAttribute &attr : op->getAttrs()) { 2290c7890c8SRiver Riddle os << '\n' << attr.getName().getValue() << ": "; 2300c7890c8SRiver Riddle emitMlirAttr(os, attr.getValue()); 2318d15b7dcSMatthias Springer } 232a87be1c1SMatthias Springer } 2338d15b7dcSMatthias Springer }); 2348d15b7dcSMatthias Springer } 2358d15b7dcSMatthias Springer 2368d15b7dcSMatthias Springer /// Generate a label for a block argument. 2378d15b7dcSMatthias Springer std::string getLabel(BlockArgument arg) { 2388d15b7dcSMatthias Springer return "arg" + std::to_string(arg.getArgNumber()); 2398d15b7dcSMatthias Springer } 2408d15b7dcSMatthias Springer 2418d15b7dcSMatthias Springer /// Process a block. Emit a cluster and one node per block argument and 2428d15b7dcSMatthias Springer /// operation inside the cluster. 2438d15b7dcSMatthias Springer void processBlock(Block &block) { 2448d15b7dcSMatthias Springer emitClusterStmt([&]() { 2458d15b7dcSMatthias Springer for (BlockArgument &blockArg : block.getArguments()) 2468d15b7dcSMatthias Springer valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg)); 2478d15b7dcSMatthias Springer 2488d15b7dcSMatthias Springer // Emit a node for each operation. 2499102a16bSMatthias Springer Optional<Node> prevNode; 2509102a16bSMatthias Springer for (Operation &op : block) { 2519102a16bSMatthias Springer Node nextNode = processOperation(&op); 2529102a16bSMatthias Springer if (printControlFlowEdges && prevNode) 2539102a16bSMatthias Springer emitEdgeStmt(*prevNode, nextNode, /*label=*/"", 2549102a16bSMatthias Springer kLineStyleControlFlow); 2559102a16bSMatthias Springer prevNode = nextNode; 2569102a16bSMatthias Springer } 2578d15b7dcSMatthias Springer }); 2588d15b7dcSMatthias Springer } 2598d15b7dcSMatthias Springer 2608d15b7dcSMatthias Springer /// Process an operation. If the operation has regions, emit a cluster. 2618d15b7dcSMatthias Springer /// Otherwise, emit a node. 2629102a16bSMatthias Springer Node processOperation(Operation *op) { 2638d15b7dcSMatthias Springer Node node; 2648d15b7dcSMatthias Springer if (op->getNumRegions() > 0) { 2658d15b7dcSMatthias Springer // Emit cluster for op with regions. 2668d15b7dcSMatthias Springer node = emitClusterStmt( 2678d15b7dcSMatthias Springer [&]() { 2688d15b7dcSMatthias Springer for (Region ®ion : op->getRegions()) 2698d15b7dcSMatthias Springer processRegion(region); 2708d15b7dcSMatthias Springer }, 2718d15b7dcSMatthias Springer getLabel(op)); 2728d15b7dcSMatthias Springer } else { 2738d15b7dcSMatthias Springer node = emitNodeStmt(getLabel(op)); 2748d15b7dcSMatthias Springer } 2758d15b7dcSMatthias Springer 2769102a16bSMatthias Springer // Insert data flow edges originating from each operand. 2779102a16bSMatthias Springer if (printDataFlowEdges) { 2788d15b7dcSMatthias Springer unsigned numOperands = op->getNumOperands(); 2798d15b7dcSMatthias Springer for (unsigned i = 0; i < numOperands; i++) 2808d15b7dcSMatthias Springer emitEdgeStmt(valueToNode[op->getOperand(i)], node, 2819102a16bSMatthias Springer /*label=*/numOperands == 1 ? "" : std::to_string(i), 2829102a16bSMatthias Springer kLineStyleDataFlow); 2839102a16bSMatthias Springer } 2848d15b7dcSMatthias Springer 2858d15b7dcSMatthias Springer for (Value result : op->getResults()) 2868d15b7dcSMatthias Springer valueToNode[result] = node; 2879102a16bSMatthias Springer 2889102a16bSMatthias Springer return node; 2898d15b7dcSMatthias Springer } 2908d15b7dcSMatthias Springer 2918d15b7dcSMatthias Springer /// Process a region. 2928d15b7dcSMatthias Springer void processRegion(Region ®ion) { 2938d15b7dcSMatthias Springer for (Block &block : region.getBlocks()) 2948d15b7dcSMatthias Springer processBlock(block); 2958d15b7dcSMatthias Springer } 2968d15b7dcSMatthias Springer 297a87be1c1SMatthias Springer /// Truncate long strings. 298a87be1c1SMatthias Springer std::string truncateString(std::string str) { 299a87be1c1SMatthias Springer if (str.length() <= maxLabelLen) 300a87be1c1SMatthias Springer return str; 301a87be1c1SMatthias Springer return str.substr(0, maxLabelLen) + "..."; 302a87be1c1SMatthias Springer } 303a87be1c1SMatthias Springer 3048d15b7dcSMatthias Springer /// Output stream to write DOT file to. 3058d15b7dcSMatthias Springer raw_indented_ostream os; 3068d15b7dcSMatthias Springer /// A list of edges. For simplicity, should be emitted after all nodes were 3078d15b7dcSMatthias Springer /// emitted. 3088d15b7dcSMatthias Springer std::vector<std::string> edges; 3098d15b7dcSMatthias Springer /// Mapping of SSA values to Graphviz nodes/clusters. 3108d15b7dcSMatthias Springer DenseMap<Value, Node> valueToNode; 3118d15b7dcSMatthias Springer /// Counter for generating unique node/subgraph identifiers. 3128d15b7dcSMatthias Springer int counter = 0; 3132660623aSJacques Pienaar }; 3148d15b7dcSMatthias Springer 3152660623aSJacques Pienaar } // namespace 3162660623aSJacques Pienaar 3178d15b7dcSMatthias Springer std::unique_ptr<Pass> 3188d15b7dcSMatthias Springer mlir::createPrintOpGraphPass(raw_ostream &os) { 3198d15b7dcSMatthias Springer return std::make_unique<PrintOpPass>(os); 3202660623aSJacques Pienaar } 3219102a16bSMatthias Springer 3229102a16bSMatthias Springer /// Generate a CFG for a region and show it in a window. 3239102a16bSMatthias Springer static void llvmViewGraph(Region ®ion, const Twine &name) { 3249102a16bSMatthias Springer int fd; 3259102a16bSMatthias Springer std::string filename = llvm::createGraphFilename(name.str(), fd); 3269102a16bSMatthias Springer { 3279102a16bSMatthias Springer llvm::raw_fd_ostream os(fd, /*shouldClose=*/true); 3289102a16bSMatthias Springer if (fd == -1) { 3299102a16bSMatthias Springer llvm::errs() << "error opening file '" << filename << "' for writing\n"; 3309102a16bSMatthias Springer return; 3319102a16bSMatthias Springer } 3329102a16bSMatthias Springer PrintOpPass pass(os); 3339102a16bSMatthias Springer pass.emitRegionCFG(region); 3349102a16bSMatthias Springer } 3359102a16bSMatthias Springer llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT); 3369102a16bSMatthias Springer } 3379102a16bSMatthias Springer 3389102a16bSMatthias Springer void mlir::Region::viewGraph(const Twine ®ionName) { 3399102a16bSMatthias Springer llvmViewGraph(*this, regionName); 3409102a16bSMatthias Springer } 3419102a16bSMatthias Springer 3429102a16bSMatthias Springer void mlir::Region::viewGraph() { viewGraph("region"); } 343