1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and analyses for defining a nested callgraph.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/CallGraph.h"
14 #include "mlir/Analysis/CallInterfaces.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "llvm/ADT/PointerUnion.h"
18 #include "llvm/ADT/SCCIterator.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // CallInterfaces
25 //===----------------------------------------------------------------------===//
26 
27 #include "mlir/Analysis/CallInterfaces.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // CallGraphNode
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns if this node refers to the indirect/external node.
34 bool CallGraphNode::isExternal() const { return !callableRegion; }
35 
36 /// Return the callable region this node represents. This can only be called
37 /// on non-external nodes.
38 Region *CallGraphNode::getCallableRegion() const {
39   assert(!isExternal() && "the external node has no callable region");
40   return callableRegion;
41 }
42 
43 /// Adds an reference edge to the given node. This is only valid on the
44 /// external node.
45 void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
46   assert(isExternal() && "abstract edges are only valid on external nodes");
47   addEdge(node, Edge::Kind::Abstract);
48 }
49 
50 /// Add an outgoing call edge from this node.
51 void CallGraphNode::addCallEdge(CallGraphNode *node) {
52   addEdge(node, Edge::Kind::Call);
53 }
54 
55 /// Adds a reference edge to the given child node.
56 void CallGraphNode::addChildEdge(CallGraphNode *child) {
57   addEdge(child, Edge::Kind::Child);
58 }
59 
60 /// Returns true if this node has any child edges.
61 bool CallGraphNode::hasChildren() const {
62   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
63 }
64 
65 /// Add an edge to 'node' with the given kind.
66 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
67   edges.insert({node, kind});
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // CallGraph
72 //===----------------------------------------------------------------------===//
73 
74 /// Recursively compute the callgraph edges for the given operation. Computed
75 /// edges are placed into the given callgraph object.
76 static void computeCallGraph(Operation *op, CallGraph &cg,
77                              CallGraphNode *parentNode, bool resolveCalls) {
78   if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
79     // If there is no parent node, we ignore this operation. Even if this
80     // operation was a call, there would be no callgraph node to attribute it
81     // to.
82     if (resolveCalls && parentNode)
83       parentNode->addCallEdge(cg.resolveCallable(call));
84     return;
85   }
86 
87   // Compute the callgraph nodes and edges for each of the nested operations.
88   if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
89     if (auto *callableRegion = callable.getCallableRegion())
90       parentNode = cg.getOrAddNode(callableRegion, parentNode);
91     else
92       return;
93   }
94 
95   for (Region &region : op->getRegions())
96     for (Block &block : region)
97       for (Operation &nested : block)
98         computeCallGraph(&nested, cg, parentNode, resolveCalls);
99 }
100 
101 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
102   // Make two passes over the graph, one to compute the callables and one to
103   // resolve the calls. We split these up as we may have nested callable objects
104   // that need to be reserved before the calls.
105   computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
106   computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
107 }
108 
109 /// Get or add a call graph node for the given region.
110 CallGraphNode *CallGraph::getOrAddNode(Region *region,
111                                        CallGraphNode *parentNode) {
112   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
113          "expected parent operation to be callable");
114   std::unique_ptr<CallGraphNode> &node = nodes[region];
115   if (!node) {
116     node.reset(new CallGraphNode(region));
117 
118     // Add this node to the given parent node if necessary.
119     if (parentNode)
120       parentNode->addChildEdge(node.get());
121     else
122       // Otherwise, connect all callable nodes to the external node, this allows
123       // for conservatively including all callable nodes within the graph.
124       // FIXME(riverriddle) This isn't correct, this is only necessary for
125       // callable nodes that *could* be called from external sources. This
126       // requires extending the interface for callables to check if they may be
127       // referenced externally.
128       externalNode.addAbstractEdge(node.get());
129   }
130   return node.get();
131 }
132 
133 /// Lookup a call graph node for the given region, or nullptr if none is
134 /// registered.
135 CallGraphNode *CallGraph::lookupNode(Region *region) const {
136   auto it = nodes.find(region);
137   return it == nodes.end() ? nullptr : it->second.get();
138 }
139 
140 /// Resolve the callable for given callee to a node in the callgraph, or the
141 /// external node if a valid node was not resolved.
142 CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
143   Operation *callable = call.resolveCallable();
144   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
145     if (auto *node = lookupNode(callableOp.getCallableRegion()))
146       return node;
147 
148   // If we don't have a valid direct region, this is an external call.
149   return getExternalNode();
150 }
151 
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
154 
155 /// Dump the graph in a human readable format.
156 void CallGraph::dump() const { print(llvm::errs()); }
157 void CallGraph::print(raw_ostream &os) const {
158   os << "// ---- CallGraph ----\n";
159 
160   // Functor used to output the name for the given node.
161   auto emitNodeName = [&](const CallGraphNode *node) {
162     if (node->isExternal()) {
163       os << "<External-Node>";
164       return;
165     }
166 
167     auto *callableRegion = node->getCallableRegion();
168     auto *parentOp = callableRegion->getParentOp();
169     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
170        << callableRegion->getRegionNumber();
171     if (auto attrs = parentOp->getAttrList().getDictionary())
172       os << " : " << attrs;
173   };
174 
175   for (auto &nodeIt : nodes) {
176     const CallGraphNode *node = nodeIt.second.get();
177 
178     // Dump the header for this node.
179     os << "// - Node : ";
180     emitNodeName(node);
181     os << "\n";
182 
183     // Emit each of the edges.
184     for (auto &edge : *node) {
185       os << "// -- ";
186       if (edge.isCall())
187         os << "Call";
188       else if (edge.isChild())
189         os << "Child";
190 
191       os << "-Edge : ";
192       emitNodeName(edge.getTarget());
193       os << "\n";
194     }
195     os << "//\n";
196   }
197 
198   os << "// -- SCCs --\n";
199 
200   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
201     os << "// - SCC : \n";
202     for (auto &node : scc) {
203       os << "// -- Node :";
204       emitNodeName(node);
205       os << "\n";
206     }
207     os << "\n";
208   }
209 
210   os << "// -------------------\n";
211 }
212