12660623aSJacques Pienaar //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
22660623aSJacques Pienaar //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62660623aSJacques Pienaar //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
82660623aSJacques Pienaar 
92660623aSJacques Pienaar #include "mlir/Transforms/ViewOpGraph.h"
101834ad4aSRiver Riddle #include "PassDetail.h"
112660623aSJacques Pienaar #include "mlir/IR/Block.h"
122660623aSJacques Pienaar #include "mlir/IR/Operation.h"
138d15b7dcSMatthias Springer #include "mlir/Support/IndentedOstream.h"
148d15b7dcSMatthias Springer #include "llvm/Support/Format.h"
159102a16bSMatthias Springer #include "llvm/Support/GraphWriter.h"
162660623aSJacques Pienaar 
174562e389SRiver Riddle using namespace mlir;
184562e389SRiver Riddle 
199102a16bSMatthias Springer static const StringRef kLineStyleControlFlow = "dashed";
208d15b7dcSMatthias Springer static const StringRef kLineStyleDataFlow = "solid";
218d15b7dcSMatthias Springer static const StringRef kShapeNode = "ellipse";
228d15b7dcSMatthias Springer static const StringRef kShapeNone = "plain";
238d15b7dcSMatthias Springer 
24400ad6f9SRiver Riddle /// Return the size limits for eliding large attributes.
25400ad6f9SRiver Riddle static int64_t getLargeAttributeSizeLimit() {
26400ad6f9SRiver Riddle   // Use the default from the printer flags if possible.
27400ad6f9SRiver Riddle   if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
28400ad6f9SRiver Riddle     return *limit;
29400ad6f9SRiver Riddle   return 16;
30400ad6f9SRiver Riddle }
31400ad6f9SRiver Riddle 
328d15b7dcSMatthias Springer /// Return all values printed onto a stream as a string.
338d15b7dcSMatthias Springer static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
348d15b7dcSMatthias Springer   std::string buf;
358d15b7dcSMatthias Springer   llvm::raw_string_ostream os(buf);
368d15b7dcSMatthias Springer   func(os);
378d15b7dcSMatthias Springer   return os.str();
382660623aSJacques Pienaar }
398d15b7dcSMatthias Springer 
408d15b7dcSMatthias Springer /// Escape special characters such as '\n' and quotation marks.
418d15b7dcSMatthias Springer static std::string escapeString(std::string str) {
428d15b7dcSMatthias Springer   return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
432660623aSJacques Pienaar }
448d15b7dcSMatthias Springer 
/// Wrap `str` in double quotation marks. The content is not escaped.
static std::string quoteString(std::string str) {
  str.insert(0, 1, '"');
  str += '"';
  return str;
}
478d15b7dcSMatthias Springer 
/// Maps Graphviz attribute names (e.g. "label", "shape", "style") to values
/// that are already formatted for output; values are emitted verbatim as
/// `name = value`.
using AttributeMap = llvm::StringMap<std::string>;
498d15b7dcSMatthias Springer 
508d15b7dcSMatthias Springer namespace {
518d15b7dcSMatthias Springer 
528d15b7dcSMatthias Springer /// This struct represents a node in the DOT language. Each node has an
538d15b7dcSMatthias Springer /// identifier and an optional identifier for the cluster (subgraph) that
548d15b7dcSMatthias Springer /// contains the node.
558d15b7dcSMatthias Springer /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
568d15b7dcSMatthias Springer /// not between clusters. However, edges can be clipped to the boundary of a
578d15b7dcSMatthias Springer /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
588d15b7dcSMatthias Springer /// cluster, an invisible "anchor" node is created.
598d15b7dcSMatthias Springer struct Node {
608d15b7dcSMatthias Springer public:
618d15b7dcSMatthias Springer   Node(int id = 0, Optional<int> clusterId = llvm::None)
628d15b7dcSMatthias Springer       : id(id), clusterId(clusterId) {}
638d15b7dcSMatthias Springer 
648d15b7dcSMatthias Springer   int id;
658d15b7dcSMatthias Springer   Optional<int> clusterId;
662660623aSJacques Pienaar };
672660623aSJacques Pienaar 
688d15b7dcSMatthias Springer /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
698d15b7dcSMatthias Springer /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
708d15b7dcSMatthias Springer /// about the Graphviz DOT language.
718d15b7dcSMatthias Springer class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
728d15b7dcSMatthias Springer public:
738d15b7dcSMatthias Springer   PrintOpPass(raw_ostream &os) : os(os) {}
748d15b7dcSMatthias Springer   PrintOpPass(const PrintOpPass &o) : os(o.os.getOStream()) {}
752660623aSJacques Pienaar 
768d15b7dcSMatthias Springer   void runOnOperation() override {
778d15b7dcSMatthias Springer     emitGraph([&]() {
788d15b7dcSMatthias Springer       processOperation(getOperation());
798d15b7dcSMatthias Springer       emitAllEdgeStmts();
808d15b7dcSMatthias Springer     });
81736ad206SJing Pu   }
82736ad206SJing Pu 
839102a16bSMatthias Springer   /// Create a CFG graph for a region. Used in `Region::viewGraph`.
849102a16bSMatthias Springer   void emitRegionCFG(Region &region) {
859102a16bSMatthias Springer     printControlFlowEdges = true;
869102a16bSMatthias Springer     printDataFlowEdges = false;
879102a16bSMatthias Springer     emitGraph([&]() { processRegion(region); });
889102a16bSMatthias Springer   }
899102a16bSMatthias Springer 
908d15b7dcSMatthias Springer private:
918d15b7dcSMatthias Springer   /// Emit all edges. This function should be called after all nodes have been
928d15b7dcSMatthias Springer   /// emitted.
938d15b7dcSMatthias Springer   void emitAllEdgeStmts() {
948d15b7dcSMatthias Springer     for (const std::string &edge : edges)
958d15b7dcSMatthias Springer       os << edge << ";\n";
968d15b7dcSMatthias Springer     edges.clear();
978d15b7dcSMatthias Springer   }
9817606a10SJing Pu 
998d15b7dcSMatthias Springer   /// Emit a cluster (subgraph). The specified builder generates the body of the
1008d15b7dcSMatthias Springer   /// cluster. Return the anchor node of the cluster.
1018d15b7dcSMatthias Springer   Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
1028d15b7dcSMatthias Springer     int clusterId = ++counter;
1038d15b7dcSMatthias Springer     os << "subgraph cluster_" << clusterId << " {\n";
1048d15b7dcSMatthias Springer     os.indent();
1058d15b7dcSMatthias Springer     // Emit invisible anchor node from/to which arrows can be drawn.
1068d15b7dcSMatthias Springer     Node anchorNode = emitNodeStmt(" ", kShapeNone);
1078d15b7dcSMatthias Springer     os << attrStmt("label", quoteString(escapeString(label))) << ";\n";
1088d15b7dcSMatthias Springer     builder();
1098d15b7dcSMatthias Springer     os.unindent();
1108d15b7dcSMatthias Springer     os << "}\n";
1118d15b7dcSMatthias Springer     return Node(anchorNode.id, clusterId);
1128d15b7dcSMatthias Springer   }
1138d15b7dcSMatthias Springer 
1148d15b7dcSMatthias Springer   /// Generate an attribute statement.
1158d15b7dcSMatthias Springer   std::string attrStmt(const Twine &key, const Twine &value) {
1168d15b7dcSMatthias Springer     return (key + " = " + value).str();
1178d15b7dcSMatthias Springer   }
1188d15b7dcSMatthias Springer 
1198d15b7dcSMatthias Springer   /// Emit an attribute list.
1208d15b7dcSMatthias Springer   void emitAttrList(raw_ostream &os, const AttributeMap &map) {
1218d15b7dcSMatthias Springer     os << "[";
1228d15b7dcSMatthias Springer     interleaveComma(map, os, [&](const auto &it) {
123*438f700bSMatthias Springer       os << this->attrStmt(it.getKey(), it.getValue());
1248d15b7dcSMatthias Springer     });
1258d15b7dcSMatthias Springer     os << "]";
1268d15b7dcSMatthias Springer   }
1278d15b7dcSMatthias Springer 
1288d15b7dcSMatthias Springer   // Print an MLIR attribute to `os`. Large attributes are truncated.
1298d15b7dcSMatthias Springer   void emitMlirAttr(raw_ostream &os, Attribute attr) {
130400ad6f9SRiver Riddle     // A value used to elide large container attribute.
131400ad6f9SRiver Riddle     int64_t largeAttrLimit = getLargeAttributeSizeLimit();
1328d15b7dcSMatthias Springer 
1332660623aSJacques Pienaar     // Always emit splat attributes.
1348d15b7dcSMatthias Springer     if (attr.isa<SplatElementsAttr>()) {
1358d15b7dcSMatthias Springer       attr.print(os);
1368d15b7dcSMatthias Springer       return;
1372660623aSJacques Pienaar     }
1382660623aSJacques Pienaar 
1392660623aSJacques Pienaar     // Elide "big" elements attributes.
1408d15b7dcSMatthias Springer     auto elements = attr.dyn_cast<ElementsAttr>();
141400ad6f9SRiver Riddle     if (elements && elements.getNumElements() > largeAttrLimit) {
1422b86e27dSJacques Pienaar       os << std::string(elements.getType().getRank(), '[') << "..."
1432b86e27dSJacques Pienaar          << std::string(elements.getType().getRank(), ']') << " : "
1442b86e27dSJacques Pienaar          << elements.getType();
1458d15b7dcSMatthias Springer       return;
1462660623aSJacques Pienaar     }
1472660623aSJacques Pienaar 
1488d15b7dcSMatthias Springer     auto array = attr.dyn_cast<ArrayAttr>();
149400ad6f9SRiver Riddle     if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
150563b5910SJing Pu       os << "[...]";
1518d15b7dcSMatthias Springer       return;
152563b5910SJing Pu     }
153563b5910SJing Pu 
1542660623aSJacques Pienaar     // Print all other attributes.
155a87be1c1SMatthias Springer     std::string buf;
156a87be1c1SMatthias Springer     llvm::raw_string_ostream ss(buf);
157a87be1c1SMatthias Springer     attr.print(ss);
158a87be1c1SMatthias Springer     os << truncateString(ss.str());
1592660623aSJacques Pienaar   }
1602660623aSJacques Pienaar 
1618d15b7dcSMatthias Springer   /// Append an edge to the list of edges.
1628d15b7dcSMatthias Springer   /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
1639102a16bSMatthias Springer   void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
1648d15b7dcSMatthias Springer     AttributeMap attrs;
1658d15b7dcSMatthias Springer     attrs["style"] = style.str();
1668d15b7dcSMatthias Springer     // Do not label edges that start/end at a cluster boundary. Such edges are
1678d15b7dcSMatthias Springer     // clipped at the boundary, but labels are not. This can lead to labels
1688d15b7dcSMatthias Springer     // floating around without any edge next to them.
1698d15b7dcSMatthias Springer     if (!n1.clusterId && !n2.clusterId)
1708d15b7dcSMatthias Springer       attrs["label"] = quoteString(escapeString(label));
1718d15b7dcSMatthias Springer     // Use `ltail` and `lhead` to draw edges between clusters.
1728d15b7dcSMatthias Springer     if (n1.clusterId)
1738d15b7dcSMatthias Springer       attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
1748d15b7dcSMatthias Springer     if (n2.clusterId)
1758d15b7dcSMatthias Springer       attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
1762660623aSJacques Pienaar 
1778d15b7dcSMatthias Springer     edges.push_back(strFromOs([&](raw_ostream &os) {
1788d15b7dcSMatthias Springer       os << llvm::format("v%i -> v%i ", n1.id, n2.id);
1798d15b7dcSMatthias Springer       emitAttrList(os, attrs);
1808d15b7dcSMatthias Springer     }));
18102d7b260SJacques Pienaar   }
1822660623aSJacques Pienaar 
1838d15b7dcSMatthias Springer   /// Emit a graph. The specified builder generates the body of the graph.
1848d15b7dcSMatthias Springer   void emitGraph(function_ref<void()> builder) {
1858d15b7dcSMatthias Springer     os << "digraph G {\n";
1868d15b7dcSMatthias Springer     os.indent();
1878d15b7dcSMatthias Springer     // Edges between clusters are allowed only in compound mode.
1888d15b7dcSMatthias Springer     os << attrStmt("compound", "true") << ";\n";
1898d15b7dcSMatthias Springer     builder();
1908d15b7dcSMatthias Springer     os.unindent();
1918d15b7dcSMatthias Springer     os << "}\n";
1922660623aSJacques Pienaar   }
1932660623aSJacques Pienaar 
1948d15b7dcSMatthias Springer   /// Emit a node statement.
1958d15b7dcSMatthias Springer   Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
1968d15b7dcSMatthias Springer     int nodeId = ++counter;
1978d15b7dcSMatthias Springer     AttributeMap attrs;
1988d15b7dcSMatthias Springer     attrs["label"] = quoteString(escapeString(label));
1998d15b7dcSMatthias Springer     attrs["shape"] = shape.str();
2008d15b7dcSMatthias Springer     os << llvm::format("v%i ", nodeId);
2018d15b7dcSMatthias Springer     emitAttrList(os, attrs);
2028d15b7dcSMatthias Springer     os << ";\n";
2038d15b7dcSMatthias Springer     return Node(nodeId);
2042660623aSJacques Pienaar   }
2052660623aSJacques Pienaar 
2068d15b7dcSMatthias Springer   /// Generate a label for an operation.
2078d15b7dcSMatthias Springer   std::string getLabel(Operation *op) {
2088d15b7dcSMatthias Springer     return strFromOs([&](raw_ostream &os) {
2098d15b7dcSMatthias Springer       // Print operation name and type.
210a87be1c1SMatthias Springer       os << op->getName();
211a87be1c1SMatthias Springer       if (printResultTypes) {
212a87be1c1SMatthias Springer         os << " : (";
213a87be1c1SMatthias Springer         std::string buf;
214a87be1c1SMatthias Springer         llvm::raw_string_ostream ss(buf);
215a87be1c1SMatthias Springer         interleaveComma(op->getResultTypes(), ss);
216a87be1c1SMatthias Springer         os << truncateString(ss.str()) << ")";
217a87be1c1SMatthias Springer         os << ")";
218a87be1c1SMatthias Springer       }
2192660623aSJacques Pienaar 
2208d15b7dcSMatthias Springer       // Print attributes.
221a87be1c1SMatthias Springer       if (printAttrs) {
222a87be1c1SMatthias Springer         os << "\n";
2238d15b7dcSMatthias Springer         for (const NamedAttribute &attr : op->getAttrs()) {
2248d15b7dcSMatthias Springer           os << '\n' << attr.first << ": ";
2258d15b7dcSMatthias Springer           emitMlirAttr(os, attr.second);
2268d15b7dcSMatthias Springer         }
227a87be1c1SMatthias Springer       }
2288d15b7dcSMatthias Springer     });
2298d15b7dcSMatthias Springer   }
2308d15b7dcSMatthias Springer 
2318d15b7dcSMatthias Springer   /// Generate a label for a block argument.
2328d15b7dcSMatthias Springer   std::string getLabel(BlockArgument arg) {
2338d15b7dcSMatthias Springer     return "arg" + std::to_string(arg.getArgNumber());
2348d15b7dcSMatthias Springer   }
2358d15b7dcSMatthias Springer 
2368d15b7dcSMatthias Springer   /// Process a block. Emit a cluster and one node per block argument and
2378d15b7dcSMatthias Springer   /// operation inside the cluster.
2388d15b7dcSMatthias Springer   void processBlock(Block &block) {
2398d15b7dcSMatthias Springer     emitClusterStmt([&]() {
2408d15b7dcSMatthias Springer       for (BlockArgument &blockArg : block.getArguments())
2418d15b7dcSMatthias Springer         valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
2428d15b7dcSMatthias Springer 
2438d15b7dcSMatthias Springer       // Emit a node for each operation.
2449102a16bSMatthias Springer       Optional<Node> prevNode;
2459102a16bSMatthias Springer       for (Operation &op : block) {
2469102a16bSMatthias Springer         Node nextNode = processOperation(&op);
2479102a16bSMatthias Springer         if (printControlFlowEdges && prevNode)
2489102a16bSMatthias Springer           emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
2499102a16bSMatthias Springer                        kLineStyleControlFlow);
2509102a16bSMatthias Springer         prevNode = nextNode;
2519102a16bSMatthias Springer       }
2528d15b7dcSMatthias Springer     });
2538d15b7dcSMatthias Springer   }
2548d15b7dcSMatthias Springer 
2558d15b7dcSMatthias Springer   /// Process an operation. If the operation has regions, emit a cluster.
2568d15b7dcSMatthias Springer   /// Otherwise, emit a node.
2579102a16bSMatthias Springer   Node processOperation(Operation *op) {
2588d15b7dcSMatthias Springer     Node node;
2598d15b7dcSMatthias Springer     if (op->getNumRegions() > 0) {
2608d15b7dcSMatthias Springer       // Emit cluster for op with regions.
2618d15b7dcSMatthias Springer       node = emitClusterStmt(
2628d15b7dcSMatthias Springer           [&]() {
2638d15b7dcSMatthias Springer             for (Region &region : op->getRegions())
2648d15b7dcSMatthias Springer               processRegion(region);
2658d15b7dcSMatthias Springer           },
2668d15b7dcSMatthias Springer           getLabel(op));
2678d15b7dcSMatthias Springer     } else {
2688d15b7dcSMatthias Springer       node = emitNodeStmt(getLabel(op));
2698d15b7dcSMatthias Springer     }
2708d15b7dcSMatthias Springer 
2719102a16bSMatthias Springer     // Insert data flow edges originating from each operand.
2729102a16bSMatthias Springer     if (printDataFlowEdges) {
2738d15b7dcSMatthias Springer       unsigned numOperands = op->getNumOperands();
2748d15b7dcSMatthias Springer       for (unsigned i = 0; i < numOperands; i++)
2758d15b7dcSMatthias Springer         emitEdgeStmt(valueToNode[op->getOperand(i)], node,
2769102a16bSMatthias Springer                      /*label=*/numOperands == 1 ? "" : std::to_string(i),
2779102a16bSMatthias Springer                      kLineStyleDataFlow);
2789102a16bSMatthias Springer     }
2798d15b7dcSMatthias Springer 
2808d15b7dcSMatthias Springer     for (Value result : op->getResults())
2818d15b7dcSMatthias Springer       valueToNode[result] = node;
2829102a16bSMatthias Springer 
2839102a16bSMatthias Springer     return node;
2848d15b7dcSMatthias Springer   }
2858d15b7dcSMatthias Springer 
2868d15b7dcSMatthias Springer   /// Process a region.
2878d15b7dcSMatthias Springer   void processRegion(Region &region) {
2888d15b7dcSMatthias Springer     for (Block &block : region.getBlocks())
2898d15b7dcSMatthias Springer       processBlock(block);
2908d15b7dcSMatthias Springer   }
2918d15b7dcSMatthias Springer 
292a87be1c1SMatthias Springer   /// Truncate long strings.
293a87be1c1SMatthias Springer   std::string truncateString(std::string str) {
294a87be1c1SMatthias Springer     if (str.length() <= maxLabelLen)
295a87be1c1SMatthias Springer       return str;
296a87be1c1SMatthias Springer     return str.substr(0, maxLabelLen) + "...";
297a87be1c1SMatthias Springer   }
298a87be1c1SMatthias Springer 
2998d15b7dcSMatthias Springer   /// Output stream to write DOT file to.
3008d15b7dcSMatthias Springer   raw_indented_ostream os;
3018d15b7dcSMatthias Springer   /// A list of edges. For simplicity, should be emitted after all nodes were
3028d15b7dcSMatthias Springer   /// emitted.
3038d15b7dcSMatthias Springer   std::vector<std::string> edges;
3048d15b7dcSMatthias Springer   /// Mapping of SSA values to Graphviz nodes/clusters.
3058d15b7dcSMatthias Springer   DenseMap<Value, Node> valueToNode;
3068d15b7dcSMatthias Springer   /// Counter for generating unique node/subgraph identifiers.
3078d15b7dcSMatthias Springer   int counter = 0;
3082660623aSJacques Pienaar };
3098d15b7dcSMatthias Springer 
3102660623aSJacques Pienaar } // namespace
3112660623aSJacques Pienaar 
3128d15b7dcSMatthias Springer std::unique_ptr<Pass>
3138d15b7dcSMatthias Springer mlir::createPrintOpGraphPass(raw_ostream &os) {
3148d15b7dcSMatthias Springer   return std::make_unique<PrintOpPass>(os);
3152660623aSJacques Pienaar }
3169102a16bSMatthias Springer 
3179102a16bSMatthias Springer /// Generate a CFG for a region and show it in a window.
3189102a16bSMatthias Springer static void llvmViewGraph(Region &region, const Twine &name) {
3199102a16bSMatthias Springer   int fd;
3209102a16bSMatthias Springer   std::string filename = llvm::createGraphFilename(name.str(), fd);
3219102a16bSMatthias Springer   {
3229102a16bSMatthias Springer     llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
3239102a16bSMatthias Springer     if (fd == -1) {
3249102a16bSMatthias Springer       llvm::errs() << "error opening file '" << filename << "' for writing\n";
3259102a16bSMatthias Springer       return;
3269102a16bSMatthias Springer     }
3279102a16bSMatthias Springer     PrintOpPass pass(os);
3289102a16bSMatthias Springer     pass.emitRegionCFG(region);
3299102a16bSMatthias Springer   }
3309102a16bSMatthias Springer   llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
3319102a16bSMatthias Springer }
3329102a16bSMatthias Springer 
3339102a16bSMatthias Springer void mlir::Region::viewGraph(const Twine &regionName) {
3349102a16bSMatthias Springer   llvmViewGraph(*this, regionName);
3359102a16bSMatthias Springer }
3369102a16bSMatthias Springer 
3379102a16bSMatthias Springer void mlir::Region::viewGraph() { viewGraph("region"); }
338