1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file contains interfaces and analyses for defining a nested callgraph.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Analysis/CallGraph.h"
23 #include "mlir/Analysis/CallInterfaces.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "llvm/ADT/PointerUnion.h"
27 #include "llvm/ADT/SCCIterator.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // CallInterfaces
34 //===----------------------------------------------------------------------===//
35 
36 #include "mlir/Analysis/CallInterfaces.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // CallGraphNode
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns if this node refers to the indirect/external node.
43 bool CallGraphNode::isExternal() const { return !callableRegion; }
44 
/// Return the callable region this node represents. This can only be called
/// on non-external nodes; calling it on the external node trips the assertion
/// below in builds with assertions enabled.
Region *CallGraphNode::getCallableRegion() const {
  assert(!isExternal() && "the external node has no callable region");
  return callableRegion;
}
51 
/// Adds an abstract edge to the given node. This is only valid on the
/// external node. CallGraph::getOrAddNode uses these edges to connect callable
/// nodes that have no parent node to the external node.
void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
  assert(isExternal() && "abstract edges are only valid on external nodes");
  addEdge(node, Edge::Kind::Abstract);
}
58 
/// Add an outgoing call edge from this node to `node`, the node that a call
/// made within this node's region resolved to (possibly the external node).
void CallGraphNode::addCallEdge(CallGraphNode *node) {
  addEdge(node, Edge::Kind::Call);
}
63 
/// Adds a child edge to the given node. CallGraph::getOrAddNode uses this to
/// record that `child`'s callable region is nested within this node's region.
void CallGraphNode::addChildEdge(CallGraphNode *child) {
  addEdge(child, Edge::Kind::Child);
}
68 
69 /// Returns true if this node has any child edges.
70 bool CallGraphNode::hasChildren() const {
71   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
72 }
73 
/// Add an edge to 'node' with the given kind. The addAbstractEdge,
/// addCallEdge and addChildEdge helpers all forward here.
/// NOTE(review): `edges` is presumably a set-like container (it uses
/// `insert`), so re-adding an identical edge may be a no-op; confirm against
/// the header.
void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
  edges.insert({node, kind});
}
78 
79 //===----------------------------------------------------------------------===//
80 // CallGraph
81 //===----------------------------------------------------------------------===//
82 
/// Recursively compute the callgraph edges for the given operation. Computed
/// edges are placed into the given callgraph object.
///
/// This is forward-declared because the Region overload below calls it before
/// its definition appears (the two overloads are mutually recursive).
static void computeCallGraph(Operation *op, CallGraph &cg,
                             CallGraphNode *parentNode);
87 
88 /// Compute the set of callgraph nodes that are created by regions nested within
89 /// 'op'.
90 static void computeCallables(Operation *op, CallGraph &cg,
91                              CallGraphNode *parentNode) {
92   if (op->getNumRegions() == 0)
93     return;
94   if (auto callableOp = dyn_cast<CallableOpInterface>(op)) {
95     SmallVector<Region *, 1> callables;
96     callableOp.getCallableRegions(callables);
97     for (auto *callableRegion : callables)
98       cg.getOrAddNode(callableRegion, parentNode);
99   }
100 }
101 
102 /// Recursively compute the callgraph edges within the given region. Computed
103 /// edges are placed into the given callgraph object.
104 static void computeCallGraph(Region &region, CallGraph &cg,
105                              CallGraphNode *parentNode) {
106   // Iterate over the nested operations twice:
107   /// One to fully create nodes in the for each callable region of a nested
108   /// operation;
109   for (auto &block : region)
110     for (auto &nested : block)
111       computeCallables(&nested, cg, parentNode);
112 
113   /// And another to recursively compute the callgraph.
114   for (auto &block : region)
115     for (auto &nested : block)
116       computeCallGraph(&nested, cg, parentNode);
117 }
118 
119 /// Recursively compute the callgraph edges for the given operation. Computed
120 /// edges are placed into the given callgraph object.
121 static void computeCallGraph(Operation *op, CallGraph &cg,
122                              CallGraphNode *parentNode) {
123   // Compute the callgraph nodes and edges for each of the nested operations.
124   auto isCallable = isa<CallableOpInterface>(op);
125   for (auto &region : op->getRegions()) {
126     // Check to see if this region is a callable node, if so this is the parent
127     // node of the nested region.
128     CallGraphNode *nestedParentNode;
129     if (!isCallable || !(nestedParentNode = cg.lookupNode(&region)))
130       nestedParentNode = parentNode;
131     computeCallGraph(region, cg, nestedParentNode);
132   }
133 
134   // If there is no parent node, we ignore this operation. Even if this
135   // operation was a call, there would be no callgraph node to attribute it to.
136   if (!parentNode)
137     return;
138 
139   // If this is a call operation, resolve the callee.
140   if (auto call = dyn_cast<CallOpInterface>(op))
141     parentNode->addCallEdge(
142         cg.resolveCallable(call.getCallableForCallee(), op));
143 }
144 
/// Build the callgraph for the callable regions nested within `op`. The
/// external node is the node without a callable region. `op` itself does not
/// receive a node, because only operations nested inside its regions are
/// scanned for callables. The walk starts with no parent node, so calls that
/// occur outside of every callable region are not recorded.
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
  computeCallGraph(op, *this, /*parentNode=*/nullptr);
}
148 
149 /// Get or add a call graph node for the given region.
150 CallGraphNode *CallGraph::getOrAddNode(Region *region,
151                                        CallGraphNode *parentNode) {
152   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
153          "expected parent operation to be callable");
154   std::unique_ptr<CallGraphNode> &node = nodes[region];
155   if (!node) {
156     node.reset(new CallGraphNode(region));
157 
158     // Add this node to the given parent node if necessary.
159     if (parentNode)
160       parentNode->addChildEdge(node.get());
161     else
162       // Otherwise, connect all callable nodes to the external node, this allows
163       // for conservatively including all callable nodes within the graph.
164       // FIXME(riverriddle) This isn't correct, this is only necessary for
165       // callable nodes that *could* be called from external sources. This
166       // requires extending the interface for callables to check if they may be
167       // referenced externally.
168       externalNode.addAbstractEdge(node.get());
169   }
170   return node.get();
171 }
172 
173 /// Lookup a call graph node for the given region, or nullptr if none is
174 /// registered.
175 CallGraphNode *CallGraph::lookupNode(Region *region) const {
176   auto it = nodes.find(region);
177   return it == nodes.end() ? nullptr : it->second.get();
178 }
179 
180 /// Resolve the callable for given callee to a node in the callgraph, or the
181 /// external node if a valid node was not resolved.
182 CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
183                                           Operation *from) const {
184   // Get the callee operation from the callable.
185   Operation *callee;
186   if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
187     // TODO(riverriddle) Support nested references.
188     callee = SymbolTable::lookupNearestSymbolFrom(from,
189                                                   symbolRef.getRootReference());
190   else
191     callee = callable.get<ValuePtr>()->getDefiningOp();
192 
193   // If the callee is non-null and is a valid callable object, try to get the
194   // called region from it.
195   if (callee && callee->getNumRegions()) {
196     if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
197       if (auto *node = lookupNode(callableOp.getCallableRegion(callable)))
198         return node;
199     }
200   }
201 
202   // If we don't have a valid direct region, this is an external call.
203   return getExternalNode();
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // Printing
208 
/// Dump the graph in a human readable format to llvm::errs(); a debugging
/// convenience wrapper around print().
void CallGraph::dump() const { print(llvm::errs()); }
211 void CallGraph::print(raw_ostream &os) const {
212   os << "// ---- CallGraph ----\n";
213 
214   // Functor used to output the name for the given node.
215   auto emitNodeName = [&](const CallGraphNode *node) {
216     if (node->isExternal()) {
217       os << "<External-Node>";
218       return;
219     }
220 
221     auto *callableRegion = node->getCallableRegion();
222     auto *parentOp = callableRegion->getParentOp();
223     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
224        << callableRegion->getRegionNumber();
225     if (auto attrs = parentOp->getAttrList().getDictionary())
226       os << " : " << attrs;
227   };
228 
229   for (auto &nodeIt : nodes) {
230     const CallGraphNode *node = nodeIt.second.get();
231 
232     // Dump the header for this node.
233     os << "// - Node : ";
234     emitNodeName(node);
235     os << "\n";
236 
237     // Emit each of the edges.
238     for (auto &edge : *node) {
239       os << "// -- ";
240       if (edge.isCall())
241         os << "Call";
242       else if (edge.isChild())
243         os << "Child";
244 
245       os << "-Edge : ";
246       emitNodeName(edge.getTarget());
247       os << "\n";
248     }
249     os << "//\n";
250   }
251 
252   os << "// -- SCCs --\n";
253 
254   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
255     os << "// - SCC : \n";
256     for (auto &node : scc) {
257       os << "// -- Node :";
258       emitNodeName(node);
259       os << "\n";
260     }
261     os << "\n";
262   }
263 
264   os << "// -------------------\n";
265 }
266