12660623aSJacques Pienaar //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
22660623aSJacques Pienaar //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62660623aSJacques Pienaar //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
82660623aSJacques Pienaar
92660623aSJacques Pienaar #include "mlir/Transforms/ViewOpGraph.h"
101834ad4aSRiver Riddle #include "PassDetail.h"
112660623aSJacques Pienaar #include "mlir/IR/Block.h"
1236d3efeaSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
132660623aSJacques Pienaar #include "mlir/IR/Operation.h"
148d15b7dcSMatthias Springer #include "mlir/Support/IndentedOstream.h"
158d15b7dcSMatthias Springer #include "llvm/Support/Format.h"
169102a16bSMatthias Springer #include "llvm/Support/GraphWriter.h"
1736d3efeaSRiver Riddle #include <utility>
182660623aSJacques Pienaar
194562e389SRiver Riddle using namespace mlir;
204562e389SRiver Riddle
219102a16bSMatthias Springer static const StringRef kLineStyleControlFlow = "dashed";
228d15b7dcSMatthias Springer static const StringRef kLineStyleDataFlow = "solid";
238d15b7dcSMatthias Springer static const StringRef kShapeNode = "ellipse";
248d15b7dcSMatthias Springer static const StringRef kShapeNone = "plain";
258d15b7dcSMatthias Springer
26400ad6f9SRiver Riddle /// Return the size limits for eliding large attributes.
getLargeAttributeSizeLimit()27400ad6f9SRiver Riddle static int64_t getLargeAttributeSizeLimit() {
28400ad6f9SRiver Riddle // Use the default from the printer flags if possible.
29400ad6f9SRiver Riddle if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
30400ad6f9SRiver Riddle return *limit;
31400ad6f9SRiver Riddle return 16;
32400ad6f9SRiver Riddle }
33400ad6f9SRiver Riddle
348d15b7dcSMatthias Springer /// Return all values printed onto a stream as a string.
strFromOs(function_ref<void (raw_ostream &)> func)358d15b7dcSMatthias Springer static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
368d15b7dcSMatthias Springer std::string buf;
378d15b7dcSMatthias Springer llvm::raw_string_ostream os(buf);
388d15b7dcSMatthias Springer func(os);
398d15b7dcSMatthias Springer return os.str();
402660623aSJacques Pienaar }
418d15b7dcSMatthias Springer
428d15b7dcSMatthias Springer /// Escape special characters such as '\n' and quotation marks.
escapeString(std::string str)438d15b7dcSMatthias Springer static std::string escapeString(std::string str) {
448d15b7dcSMatthias Springer return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
452660623aSJacques Pienaar }
468d15b7dcSMatthias Springer
/// Return `str` wrapped in double quotation marks. The contents are not
/// escaped; callers escape first when needed.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
518d15b7dcSMatthias Springer
528d15b7dcSMatthias Springer using AttributeMap = llvm::StringMap<std::string>;
538d15b7dcSMatthias Springer
548d15b7dcSMatthias Springer namespace {
558d15b7dcSMatthias Springer
568d15b7dcSMatthias Springer /// This struct represents a node in the DOT language. Each node has an
578d15b7dcSMatthias Springer /// identifier and an optional identifier for the cluster (subgraph) that
588d15b7dcSMatthias Springer /// contains the node.
598d15b7dcSMatthias Springer /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
608d15b7dcSMatthias Springer /// not between clusters. However, edges can be clipped to the boundary of a
618d15b7dcSMatthias Springer /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
628d15b7dcSMatthias Springer /// cluster, an invisible "anchor" node is created.
638d15b7dcSMatthias Springer struct Node {
648d15b7dcSMatthias Springer public:
Node__anon2ace0bc30211::Node658d15b7dcSMatthias Springer Node(int id = 0, Optional<int> clusterId = llvm::None)
668d15b7dcSMatthias Springer : id(id), clusterId(clusterId) {}
678d15b7dcSMatthias Springer
688d15b7dcSMatthias Springer int id;
698d15b7dcSMatthias Springer Optional<int> clusterId;
702660623aSJacques Pienaar };
712660623aSJacques Pienaar
728d15b7dcSMatthias Springer /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
738d15b7dcSMatthias Springer /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
748d15b7dcSMatthias Springer /// about the Graphviz DOT language.
758bd08a9fSUday Bondhugula class PrintOpPass : public ViewOpGraphBase<PrintOpPass> {
768d15b7dcSMatthias Springer public:
PrintOpPass(raw_ostream & os)778d15b7dcSMatthias Springer PrintOpPass(raw_ostream &os) : os(os) {}
PrintOpPass(const PrintOpPass & o)7866b1e629SMatthias Springer PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
792660623aSJacques Pienaar
runOnOperation()808d15b7dcSMatthias Springer void runOnOperation() override {
818d15b7dcSMatthias Springer emitGraph([&]() {
828d15b7dcSMatthias Springer processOperation(getOperation());
838d15b7dcSMatthias Springer emitAllEdgeStmts();
848d15b7dcSMatthias Springer });
85736ad206SJing Pu }
86736ad206SJing Pu
879102a16bSMatthias Springer /// Create a CFG graph for a region. Used in `Region::viewGraph`.
emitRegionCFG(Region & region)889102a16bSMatthias Springer void emitRegionCFG(Region ®ion) {
899102a16bSMatthias Springer printControlFlowEdges = true;
909102a16bSMatthias Springer printDataFlowEdges = false;
919102a16bSMatthias Springer emitGraph([&]() { processRegion(region); });
929102a16bSMatthias Springer }
939102a16bSMatthias Springer
948d15b7dcSMatthias Springer private:
958d15b7dcSMatthias Springer /// Emit all edges. This function should be called after all nodes have been
968d15b7dcSMatthias Springer /// emitted.
emitAllEdgeStmts()978d15b7dcSMatthias Springer void emitAllEdgeStmts() {
988d15b7dcSMatthias Springer for (const std::string &edge : edges)
998d15b7dcSMatthias Springer os << edge << ";\n";
1008d15b7dcSMatthias Springer edges.clear();
1018d15b7dcSMatthias Springer }
10217606a10SJing Pu
1038d15b7dcSMatthias Springer /// Emit a cluster (subgraph). The specified builder generates the body of the
1048d15b7dcSMatthias Springer /// cluster. Return the anchor node of the cluster.
emitClusterStmt(function_ref<void ()> builder,std::string label="")1058d15b7dcSMatthias Springer Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
1068d15b7dcSMatthias Springer int clusterId = ++counter;
1078d15b7dcSMatthias Springer os << "subgraph cluster_" << clusterId << " {\n";
1088d15b7dcSMatthias Springer os.indent();
1098d15b7dcSMatthias Springer // Emit invisible anchor node from/to which arrows can be drawn.
1108d15b7dcSMatthias Springer Node anchorNode = emitNodeStmt(" ", kShapeNone);
1111fc096afSMehdi Amini os << attrStmt("label", quoteString(escapeString(std::move(label))))
1121fc096afSMehdi Amini << ";\n";
1138d15b7dcSMatthias Springer builder();
1148d15b7dcSMatthias Springer os.unindent();
1158d15b7dcSMatthias Springer os << "}\n";
1168d15b7dcSMatthias Springer return Node(anchorNode.id, clusterId);
1178d15b7dcSMatthias Springer }
1188d15b7dcSMatthias Springer
1198d15b7dcSMatthias Springer /// Generate an attribute statement.
attrStmt(const Twine & key,const Twine & value)1208d15b7dcSMatthias Springer std::string attrStmt(const Twine &key, const Twine &value) {
1218d15b7dcSMatthias Springer return (key + " = " + value).str();
1228d15b7dcSMatthias Springer }
1238d15b7dcSMatthias Springer
1248d15b7dcSMatthias Springer /// Emit an attribute list.
emitAttrList(raw_ostream & os,const AttributeMap & map)1258d15b7dcSMatthias Springer void emitAttrList(raw_ostream &os, const AttributeMap &map) {
1268d15b7dcSMatthias Springer os << "[";
1278d15b7dcSMatthias Springer interleaveComma(map, os, [&](const auto &it) {
128438f700bSMatthias Springer os << this->attrStmt(it.getKey(), it.getValue());
1298d15b7dcSMatthias Springer });
1308d15b7dcSMatthias Springer os << "]";
1318d15b7dcSMatthias Springer }
1328d15b7dcSMatthias Springer
1338d15b7dcSMatthias Springer // Print an MLIR attribute to `os`. Large attributes are truncated.
emitMlirAttr(raw_ostream & os,Attribute attr)1348d15b7dcSMatthias Springer void emitMlirAttr(raw_ostream &os, Attribute attr) {
135400ad6f9SRiver Riddle // A value used to elide large container attribute.
136400ad6f9SRiver Riddle int64_t largeAttrLimit = getLargeAttributeSizeLimit();
1378d15b7dcSMatthias Springer
1382660623aSJacques Pienaar // Always emit splat attributes.
1398d15b7dcSMatthias Springer if (attr.isa<SplatElementsAttr>()) {
1408d15b7dcSMatthias Springer attr.print(os);
1418d15b7dcSMatthias Springer return;
1422660623aSJacques Pienaar }
1432660623aSJacques Pienaar
1442660623aSJacques Pienaar // Elide "big" elements attributes.
1458d15b7dcSMatthias Springer auto elements = attr.dyn_cast<ElementsAttr>();
146400ad6f9SRiver Riddle if (elements && elements.getNumElements() > largeAttrLimit) {
1472b86e27dSJacques Pienaar os << std::string(elements.getType().getRank(), '[') << "..."
1482b86e27dSJacques Pienaar << std::string(elements.getType().getRank(), ']') << " : "
1492b86e27dSJacques Pienaar << elements.getType();
1508d15b7dcSMatthias Springer return;
1512660623aSJacques Pienaar }
1522660623aSJacques Pienaar
1538d15b7dcSMatthias Springer auto array = attr.dyn_cast<ArrayAttr>();
154400ad6f9SRiver Riddle if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
155563b5910SJing Pu os << "[...]";
1568d15b7dcSMatthias Springer return;
157563b5910SJing Pu }
158563b5910SJing Pu
1592660623aSJacques Pienaar // Print all other attributes.
160a87be1c1SMatthias Springer std::string buf;
161a87be1c1SMatthias Springer llvm::raw_string_ostream ss(buf);
162a87be1c1SMatthias Springer attr.print(ss);
163a87be1c1SMatthias Springer os << truncateString(ss.str());
1642660623aSJacques Pienaar }
1652660623aSJacques Pienaar
1668d15b7dcSMatthias Springer /// Append an edge to the list of edges.
1678d15b7dcSMatthias Springer /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
emitEdgeStmt(Node n1,Node n2,std::string label,StringRef style)1689102a16bSMatthias Springer void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
1698d15b7dcSMatthias Springer AttributeMap attrs;
1708d15b7dcSMatthias Springer attrs["style"] = style.str();
1718d15b7dcSMatthias Springer // Do not label edges that start/end at a cluster boundary. Such edges are
1728d15b7dcSMatthias Springer // clipped at the boundary, but labels are not. This can lead to labels
1738d15b7dcSMatthias Springer // floating around without any edge next to them.
1748d15b7dcSMatthias Springer if (!n1.clusterId && !n2.clusterId)
1751fc096afSMehdi Amini attrs["label"] = quoteString(escapeString(std::move(label)));
1768d15b7dcSMatthias Springer // Use `ltail` and `lhead` to draw edges between clusters.
1778d15b7dcSMatthias Springer if (n1.clusterId)
1788d15b7dcSMatthias Springer attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
1798d15b7dcSMatthias Springer if (n2.clusterId)
1808d15b7dcSMatthias Springer attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
1812660623aSJacques Pienaar
1828d15b7dcSMatthias Springer edges.push_back(strFromOs([&](raw_ostream &os) {
1838d15b7dcSMatthias Springer os << llvm::format("v%i -> v%i ", n1.id, n2.id);
1848d15b7dcSMatthias Springer emitAttrList(os, attrs);
1858d15b7dcSMatthias Springer }));
18602d7b260SJacques Pienaar }
1872660623aSJacques Pienaar
1888d15b7dcSMatthias Springer /// Emit a graph. The specified builder generates the body of the graph.
emitGraph(function_ref<void ()> builder)1898d15b7dcSMatthias Springer void emitGraph(function_ref<void()> builder) {
1908d15b7dcSMatthias Springer os << "digraph G {\n";
1918d15b7dcSMatthias Springer os.indent();
1928d15b7dcSMatthias Springer // Edges between clusters are allowed only in compound mode.
1938d15b7dcSMatthias Springer os << attrStmt("compound", "true") << ";\n";
1948d15b7dcSMatthias Springer builder();
1958d15b7dcSMatthias Springer os.unindent();
1968d15b7dcSMatthias Springer os << "}\n";
1972660623aSJacques Pienaar }
1982660623aSJacques Pienaar
1998d15b7dcSMatthias Springer /// Emit a node statement.
emitNodeStmt(std::string label,StringRef shape=kShapeNode)2008d15b7dcSMatthias Springer Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
2018d15b7dcSMatthias Springer int nodeId = ++counter;
2028d15b7dcSMatthias Springer AttributeMap attrs;
2031fc096afSMehdi Amini attrs["label"] = quoteString(escapeString(std::move(label)));
2048d15b7dcSMatthias Springer attrs["shape"] = shape.str();
2058d15b7dcSMatthias Springer os << llvm::format("v%i ", nodeId);
2068d15b7dcSMatthias Springer emitAttrList(os, attrs);
2078d15b7dcSMatthias Springer os << ";\n";
2088d15b7dcSMatthias Springer return Node(nodeId);
2092660623aSJacques Pienaar }
2102660623aSJacques Pienaar
2118d15b7dcSMatthias Springer /// Generate a label for an operation.
getLabel(Operation * op)2128d15b7dcSMatthias Springer std::string getLabel(Operation *op) {
2138d15b7dcSMatthias Springer return strFromOs([&](raw_ostream &os) {
2148d15b7dcSMatthias Springer // Print operation name and type.
215a87be1c1SMatthias Springer os << op->getName();
216a87be1c1SMatthias Springer if (printResultTypes) {
217a87be1c1SMatthias Springer os << " : (";
218a87be1c1SMatthias Springer std::string buf;
219a87be1c1SMatthias Springer llvm::raw_string_ostream ss(buf);
220a87be1c1SMatthias Springer interleaveComma(op->getResultTypes(), ss);
221a87be1c1SMatthias Springer os << truncateString(ss.str()) << ")";
222a87be1c1SMatthias Springer os << ")";
223a87be1c1SMatthias Springer }
2242660623aSJacques Pienaar
2258d15b7dcSMatthias Springer // Print attributes.
226a87be1c1SMatthias Springer if (printAttrs) {
227a87be1c1SMatthias Springer os << "\n";
2288d15b7dcSMatthias Springer for (const NamedAttribute &attr : op->getAttrs()) {
2290c7890c8SRiver Riddle os << '\n' << attr.getName().getValue() << ": ";
2300c7890c8SRiver Riddle emitMlirAttr(os, attr.getValue());
2318d15b7dcSMatthias Springer }
232a87be1c1SMatthias Springer }
2338d15b7dcSMatthias Springer });
2348d15b7dcSMatthias Springer }
2358d15b7dcSMatthias Springer
2368d15b7dcSMatthias Springer /// Generate a label for a block argument.
getLabel(BlockArgument arg)2378d15b7dcSMatthias Springer std::string getLabel(BlockArgument arg) {
2388d15b7dcSMatthias Springer return "arg" + std::to_string(arg.getArgNumber());
2398d15b7dcSMatthias Springer }
2408d15b7dcSMatthias Springer
2418d15b7dcSMatthias Springer /// Process a block. Emit a cluster and one node per block argument and
2428d15b7dcSMatthias Springer /// operation inside the cluster.
processBlock(Block & block)2438d15b7dcSMatthias Springer void processBlock(Block &block) {
2448d15b7dcSMatthias Springer emitClusterStmt([&]() {
2458d15b7dcSMatthias Springer for (BlockArgument &blockArg : block.getArguments())
2468d15b7dcSMatthias Springer valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
2478d15b7dcSMatthias Springer
2488d15b7dcSMatthias Springer // Emit a node for each operation.
2499102a16bSMatthias Springer Optional<Node> prevNode;
2509102a16bSMatthias Springer for (Operation &op : block) {
2519102a16bSMatthias Springer Node nextNode = processOperation(&op);
2529102a16bSMatthias Springer if (printControlFlowEdges && prevNode)
2539102a16bSMatthias Springer emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
2549102a16bSMatthias Springer kLineStyleControlFlow);
2559102a16bSMatthias Springer prevNode = nextNode;
2569102a16bSMatthias Springer }
2578d15b7dcSMatthias Springer });
2588d15b7dcSMatthias Springer }
2598d15b7dcSMatthias Springer
2608d15b7dcSMatthias Springer /// Process an operation. If the operation has regions, emit a cluster.
2618d15b7dcSMatthias Springer /// Otherwise, emit a node.
processOperation(Operation * op)2629102a16bSMatthias Springer Node processOperation(Operation *op) {
2638d15b7dcSMatthias Springer Node node;
2648d15b7dcSMatthias Springer if (op->getNumRegions() > 0) {
2658d15b7dcSMatthias Springer // Emit cluster for op with regions.
2668d15b7dcSMatthias Springer node = emitClusterStmt(
2678d15b7dcSMatthias Springer [&]() {
2688d15b7dcSMatthias Springer for (Region ®ion : op->getRegions())
2698d15b7dcSMatthias Springer processRegion(region);
2708d15b7dcSMatthias Springer },
2718d15b7dcSMatthias Springer getLabel(op));
2728d15b7dcSMatthias Springer } else {
2738d15b7dcSMatthias Springer node = emitNodeStmt(getLabel(op));
2748d15b7dcSMatthias Springer }
2758d15b7dcSMatthias Springer
2769102a16bSMatthias Springer // Insert data flow edges originating from each operand.
2779102a16bSMatthias Springer if (printDataFlowEdges) {
2788d15b7dcSMatthias Springer unsigned numOperands = op->getNumOperands();
2798d15b7dcSMatthias Springer for (unsigned i = 0; i < numOperands; i++)
2808d15b7dcSMatthias Springer emitEdgeStmt(valueToNode[op->getOperand(i)], node,
2819102a16bSMatthias Springer /*label=*/numOperands == 1 ? "" : std::to_string(i),
2829102a16bSMatthias Springer kLineStyleDataFlow);
2839102a16bSMatthias Springer }
2848d15b7dcSMatthias Springer
2858d15b7dcSMatthias Springer for (Value result : op->getResults())
2868d15b7dcSMatthias Springer valueToNode[result] = node;
2879102a16bSMatthias Springer
2889102a16bSMatthias Springer return node;
2898d15b7dcSMatthias Springer }
2908d15b7dcSMatthias Springer
2918d15b7dcSMatthias Springer /// Process a region.
processRegion(Region & region)2928d15b7dcSMatthias Springer void processRegion(Region ®ion) {
2938d15b7dcSMatthias Springer for (Block &block : region.getBlocks())
2948d15b7dcSMatthias Springer processBlock(block);
2958d15b7dcSMatthias Springer }
2968d15b7dcSMatthias Springer
297a87be1c1SMatthias Springer /// Truncate long strings.
truncateString(std::string str)298a87be1c1SMatthias Springer std::string truncateString(std::string str) {
299a87be1c1SMatthias Springer if (str.length() <= maxLabelLen)
300a87be1c1SMatthias Springer return str;
301a87be1c1SMatthias Springer return str.substr(0, maxLabelLen) + "...";
302a87be1c1SMatthias Springer }
303a87be1c1SMatthias Springer
3048d15b7dcSMatthias Springer /// Output stream to write DOT file to.
3058d15b7dcSMatthias Springer raw_indented_ostream os;
3068d15b7dcSMatthias Springer /// A list of edges. For simplicity, should be emitted after all nodes were
3078d15b7dcSMatthias Springer /// emitted.
3088d15b7dcSMatthias Springer std::vector<std::string> edges;
3098d15b7dcSMatthias Springer /// Mapping of SSA values to Graphviz nodes/clusters.
3108d15b7dcSMatthias Springer DenseMap<Value, Node> valueToNode;
3118d15b7dcSMatthias Springer /// Counter for generating unique node/subgraph identifiers.
3128d15b7dcSMatthias Springer int counter = 0;
3132660623aSJacques Pienaar };
3148d15b7dcSMatthias Springer
3152660623aSJacques Pienaar } // namespace
3162660623aSJacques Pienaar
createPrintOpGraphPass(raw_ostream & os)317*b7f93c28SJeff Niu std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
3188d15b7dcSMatthias Springer return std::make_unique<PrintOpPass>(os);
3192660623aSJacques Pienaar }
3209102a16bSMatthias Springer
3219102a16bSMatthias Springer /// Generate a CFG for a region and show it in a window.
llvmViewGraph(Region & region,const Twine & name)3229102a16bSMatthias Springer static void llvmViewGraph(Region ®ion, const Twine &name) {
3239102a16bSMatthias Springer int fd;
3249102a16bSMatthias Springer std::string filename = llvm::createGraphFilename(name.str(), fd);
3259102a16bSMatthias Springer {
3269102a16bSMatthias Springer llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
3279102a16bSMatthias Springer if (fd == -1) {
3289102a16bSMatthias Springer llvm::errs() << "error opening file '" << filename << "' for writing\n";
3299102a16bSMatthias Springer return;
3309102a16bSMatthias Springer }
3319102a16bSMatthias Springer PrintOpPass pass(os);
3329102a16bSMatthias Springer pass.emitRegionCFG(region);
3339102a16bSMatthias Springer }
3349102a16bSMatthias Springer llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
3359102a16bSMatthias Springer }
3369102a16bSMatthias Springer
viewGraph(const Twine & regionName)3379102a16bSMatthias Springer void mlir::Region::viewGraph(const Twine ®ionName) {
3389102a16bSMatthias Springer llvmViewGraph(*this, regionName);
3399102a16bSMatthias Springer }
3409102a16bSMatthias Springer
viewGraph()3419102a16bSMatthias Springer void mlir::Region::viewGraph() { viewGraph("region"); }
342