1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and analyses for defining a nested callgraph.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/CallGraph.h"
14 #include "mlir/Analysis/CallInterfaces.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "llvm/ADT/PointerUnion.h"
18 #include "llvm/ADT/SCCIterator.h"
19 #include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // CallInterfaces
25 //===----------------------------------------------------------------------===//
26 
27 #include "mlir/Analysis/CallInterfaces.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // CallGraphNode
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns if this node refers to the indirect/external node.
34 bool CallGraphNode::isExternal() const { return !callableRegion; }
35 
36 /// Return the callable region this node represents. This can only be called
37 /// on non-external nodes.
38 Region *CallGraphNode::getCallableRegion() const {
39   assert(!isExternal() && "the external node has no callable region");
40   return callableRegion;
41 }
42 
43 /// Adds an reference edge to the given node. This is only valid on the
44 /// external node.
45 void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
46   assert(isExternal() && "abstract edges are only valid on external nodes");
47   addEdge(node, Edge::Kind::Abstract);
48 }
49 
50 /// Add an outgoing call edge from this node.
51 void CallGraphNode::addCallEdge(CallGraphNode *node) {
52   addEdge(node, Edge::Kind::Call);
53 }
54 
55 /// Adds a reference edge to the given child node.
56 void CallGraphNode::addChildEdge(CallGraphNode *child) {
57   addEdge(child, Edge::Kind::Child);
58 }
59 
60 /// Returns true if this node has any child edges.
61 bool CallGraphNode::hasChildren() const {
62   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
63 }
64 
65 /// Add an edge to 'node' with the given kind.
66 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
67   edges.insert({node, kind});
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // CallGraph
72 //===----------------------------------------------------------------------===//
73 
74 /// Recursively compute the callgraph edges for the given operation. Computed
75 /// edges are placed into the given callgraph object.
76 static void computeCallGraph(Operation *op, CallGraph &cg,
77                              CallGraphNode *parentNode);
78 
79 /// Compute the set of callgraph nodes that are created by regions nested within
80 /// 'op'.
81 static void computeCallables(Operation *op, CallGraph &cg,
82                              CallGraphNode *parentNode) {
83   if (op->getNumRegions() == 0)
84     return;
85   if (auto callableOp = dyn_cast<CallableOpInterface>(op)) {
86     SmallVector<Region *, 1> callables;
87     callableOp.getCallableRegions(callables);
88     for (auto *callableRegion : callables)
89       cg.getOrAddNode(callableRegion, parentNode);
90   }
91 }
92 
93 /// Recursively compute the callgraph edges within the given region. Computed
94 /// edges are placed into the given callgraph object.
95 static void computeCallGraph(Region &region, CallGraph &cg,
96                              CallGraphNode *parentNode) {
97   // Iterate over the nested operations twice:
98   /// One to fully create nodes in the for each callable region of a nested
99   /// operation;
100   for (auto &block : region)
101     for (auto &nested : block)
102       computeCallables(&nested, cg, parentNode);
103 
104   /// And another to recursively compute the callgraph.
105   for (auto &block : region)
106     for (auto &nested : block)
107       computeCallGraph(&nested, cg, parentNode);
108 }
109 
110 /// Recursively compute the callgraph edges for the given operation. Computed
111 /// edges are placed into the given callgraph object.
112 static void computeCallGraph(Operation *op, CallGraph &cg,
113                              CallGraphNode *parentNode) {
114   // Compute the callgraph nodes and edges for each of the nested operations.
115   auto isCallable = isa<CallableOpInterface>(op);
116   for (auto &region : op->getRegions()) {
117     // Check to see if this region is a callable node, if so this is the parent
118     // node of the nested region.
119     CallGraphNode *nestedParentNode;
120     if (!isCallable || !(nestedParentNode = cg.lookupNode(&region)))
121       nestedParentNode = parentNode;
122     computeCallGraph(region, cg, nestedParentNode);
123   }
124 
125   // If there is no parent node, we ignore this operation. Even if this
126   // operation was a call, there would be no callgraph node to attribute it to.
127   if (!parentNode)
128     return;
129 
130   // If this is a call operation, resolve the callee.
131   if (auto call = dyn_cast<CallOpInterface>(op))
132     parentNode->addCallEdge(
133         cg.resolveCallable(call.getCallableForCallee(), op));
134 }
135 
136 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
137   computeCallGraph(op, *this, /*parentNode=*/nullptr);
138 }
139 
140 /// Get or add a call graph node for the given region.
141 CallGraphNode *CallGraph::getOrAddNode(Region *region,
142                                        CallGraphNode *parentNode) {
143   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
144          "expected parent operation to be callable");
145   std::unique_ptr<CallGraphNode> &node = nodes[region];
146   if (!node) {
147     node.reset(new CallGraphNode(region));
148 
149     // Add this node to the given parent node if necessary.
150     if (parentNode)
151       parentNode->addChildEdge(node.get());
152     else
153       // Otherwise, connect all callable nodes to the external node, this allows
154       // for conservatively including all callable nodes within the graph.
155       // FIXME(riverriddle) This isn't correct, this is only necessary for
156       // callable nodes that *could* be called from external sources. This
157       // requires extending the interface for callables to check if they may be
158       // referenced externally.
159       externalNode.addAbstractEdge(node.get());
160   }
161   return node.get();
162 }
163 
164 /// Lookup a call graph node for the given region, or nullptr if none is
165 /// registered.
166 CallGraphNode *CallGraph::lookupNode(Region *region) const {
167   auto it = nodes.find(region);
168   return it == nodes.end() ? nullptr : it->second.get();
169 }
170 
171 /// Resolve the callable for given callee to a node in the callgraph, or the
172 /// external node if a valid node was not resolved.
173 CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
174                                           Operation *from) const {
175   // Get the callee operation from the callable.
176   Operation *callee;
177   if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
178     // TODO(riverriddle) Support nested references.
179     callee = SymbolTable::lookupNearestSymbolFrom(from,
180                                                   symbolRef.getRootReference());
181   else
182     callee = callable.get<Value>().getDefiningOp();
183 
184   // If the callee is non-null and is a valid callable object, try to get the
185   // called region from it.
186   if (callee && callee->getNumRegions()) {
187     if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
188       if (auto *node = lookupNode(callableOp.getCallableRegion(callable)))
189         return node;
190     }
191   }
192 
193   // If we don't have a valid direct region, this is an external call.
194   return getExternalNode();
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // Printing
199 
200 /// Dump the graph in a human readable format.
201 void CallGraph::dump() const { print(llvm::errs()); }
202 void CallGraph::print(raw_ostream &os) const {
203   os << "// ---- CallGraph ----\n";
204 
205   // Functor used to output the name for the given node.
206   auto emitNodeName = [&](const CallGraphNode *node) {
207     if (node->isExternal()) {
208       os << "<External-Node>";
209       return;
210     }
211 
212     auto *callableRegion = node->getCallableRegion();
213     auto *parentOp = callableRegion->getParentOp();
214     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
215        << callableRegion->getRegionNumber();
216     if (auto attrs = parentOp->getAttrList().getDictionary())
217       os << " : " << attrs;
218   };
219 
220   for (auto &nodeIt : nodes) {
221     const CallGraphNode *node = nodeIt.second.get();
222 
223     // Dump the header for this node.
224     os << "// - Node : ";
225     emitNodeName(node);
226     os << "\n";
227 
228     // Emit each of the edges.
229     for (auto &edge : *node) {
230       os << "// -- ";
231       if (edge.isCall())
232         os << "Call";
233       else if (edge.isChild())
234         os << "Child";
235 
236       os << "-Edge : ";
237       emitNodeName(edge.getTarget());
238       os << "\n";
239     }
240     os << "//\n";
241   }
242 
243   os << "// -- SCCs --\n";
244 
245   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
246     os << "// - SCC : \n";
247     for (auto &node : scc) {
248       os << "// -- Node :";
249       emitNodeName(node);
250       os << "\n";
251     }
252     os << "\n";
253   }
254 
255   os << "// -------------------\n";
256 }
257