1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and analyses for defining a nested callgraph.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/CallGraph.h"
14 #include "mlir/Analysis/CallInterfaces.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "llvm/ADT/PointerUnion.h"
18 #include "llvm/ADT/SCCIterator.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // CallInterfaces
25 //===----------------------------------------------------------------------===//
26 
27 #include "mlir/Analysis/CallInterfaces.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // CallGraphNode
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns if this node refers to the indirect/external node.
34 bool CallGraphNode::isExternal() const { return !callableRegion; }
35 
36 /// Return the callable region this node represents. This can only be called
37 /// on non-external nodes.
38 Region *CallGraphNode::getCallableRegion() const {
39   assert(!isExternal() && "the external node has no callable region");
40   return callableRegion;
41 }
42 
43 /// Adds an reference edge to the given node. This is only valid on the
44 /// external node.
45 void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
46   assert(isExternal() && "abstract edges are only valid on external nodes");
47   addEdge(node, Edge::Kind::Abstract);
48 }
49 
50 /// Add an outgoing call edge from this node.
51 void CallGraphNode::addCallEdge(CallGraphNode *node) {
52   addEdge(node, Edge::Kind::Call);
53 }
54 
55 /// Adds a reference edge to the given child node.
56 void CallGraphNode::addChildEdge(CallGraphNode *child) {
57   addEdge(child, Edge::Kind::Child);
58 }
59 
60 /// Returns true if this node has any child edges.
61 bool CallGraphNode::hasChildren() const {
62   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
63 }
64 
65 /// Add an edge to 'node' with the given kind.
66 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
67   edges.insert({node, kind});
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // CallGraph
72 //===----------------------------------------------------------------------===//
73 
74 /// Recursively compute the callgraph edges for the given operation. Computed
75 /// edges are placed into the given callgraph object.
76 static void computeCallGraph(Operation *op, CallGraph &cg,
77                              CallGraphNode *parentNode, bool resolveCalls) {
78   if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
79     // If there is no parent node, we ignore this operation. Even if this
80     // operation was a call, there would be no callgraph node to attribute it
81     // to.
82     if (!resolveCalls || !parentNode)
83       return;
84     parentNode->addCallEdge(
85         cg.resolveCallable(call.getCallableForCallee(), op));
86     return;
87   }
88 
89   // Compute the callgraph nodes and edges for each of the nested operations.
90   if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
91     if (auto *callableRegion = callable.getCallableRegion())
92       parentNode = cg.getOrAddNode(callableRegion, parentNode);
93     else
94       return;
95   }
96 
97   for (Region &region : op->getRegions())
98     for (Block &block : region)
99       for (Operation &nested : block)
100         computeCallGraph(&nested, cg, parentNode, resolveCalls);
101 }
102 
103 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
104   // Make two passes over the graph, one to compute the callables and one to
105   // resolve the calls. We split these up as we may have nested callable objects
106   // that need to be reserved before the calls.
107   computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
108   computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
109 }
110 
111 /// Get or add a call graph node for the given region.
112 CallGraphNode *CallGraph::getOrAddNode(Region *region,
113                                        CallGraphNode *parentNode) {
114   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
115          "expected parent operation to be callable");
116   std::unique_ptr<CallGraphNode> &node = nodes[region];
117   if (!node) {
118     node.reset(new CallGraphNode(region));
119 
120     // Add this node to the given parent node if necessary.
121     if (parentNode)
122       parentNode->addChildEdge(node.get());
123     else
124       // Otherwise, connect all callable nodes to the external node, this allows
125       // for conservatively including all callable nodes within the graph.
126       // FIXME(riverriddle) This isn't correct, this is only necessary for
127       // callable nodes that *could* be called from external sources. This
128       // requires extending the interface for callables to check if they may be
129       // referenced externally.
130       externalNode.addAbstractEdge(node.get());
131   }
132   return node.get();
133 }
134 
135 /// Lookup a call graph node for the given region, or nullptr if none is
136 /// registered.
137 CallGraphNode *CallGraph::lookupNode(Region *region) const {
138   auto it = nodes.find(region);
139   return it == nodes.end() ? nullptr : it->second.get();
140 }
141 
142 /// Resolve the callable for given callee to a node in the callgraph, or the
143 /// external node if a valid node was not resolved.
144 CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
145                                           Operation *from) const {
146   // Get the callee operation from the callable.
147   Operation *callee;
148   if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
149     callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
150   else
151     callee = callable.get<Value>().getDefiningOp();
152 
153   // If the callee is non-null and is a valid callable object, try to get the
154   // called region from it.
155   if (callee && callee->getNumRegions()) {
156     if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
157       if (auto *node = lookupNode(callableOp.getCallableRegion()))
158         return node;
159     }
160   }
161 
162   // If we don't have a valid direct region, this is an external call.
163   return getExternalNode();
164 }
165 
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
168 
169 /// Dump the graph in a human readable format.
170 void CallGraph::dump() const { print(llvm::errs()); }
171 void CallGraph::print(raw_ostream &os) const {
172   os << "// ---- CallGraph ----\n";
173 
174   // Functor used to output the name for the given node.
175   auto emitNodeName = [&](const CallGraphNode *node) {
176     if (node->isExternal()) {
177       os << "<External-Node>";
178       return;
179     }
180 
181     auto *callableRegion = node->getCallableRegion();
182     auto *parentOp = callableRegion->getParentOp();
183     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
184        << callableRegion->getRegionNumber();
185     if (auto attrs = parentOp->getAttrList().getDictionary())
186       os << " : " << attrs;
187   };
188 
189   for (auto &nodeIt : nodes) {
190     const CallGraphNode *node = nodeIt.second.get();
191 
192     // Dump the header for this node.
193     os << "// - Node : ";
194     emitNodeName(node);
195     os << "\n";
196 
197     // Emit each of the edges.
198     for (auto &edge : *node) {
199       os << "// -- ";
200       if (edge.isCall())
201         os << "Call";
202       else if (edge.isChild())
203         os << "Child";
204 
205       os << "-Edge : ";
206       emitNodeName(edge.getTarget());
207       os << "\n";
208     }
209     os << "//\n";
210   }
211 
212   os << "// -- SCCs --\n";
213 
214   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
215     os << "// - SCC : \n";
216     for (auto &node : scc) {
217       os << "// -- Node :";
218       emitNodeName(node);
219       os << "\n";
220     }
221     os << "\n";
222   }
223 
224   os << "// -------------------\n";
225 }
226