1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file contains interfaces and analyses for defining a nested callgraph.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Analysis/CallGraph.h"
23 #include "mlir/Analysis/CallInterfaces.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/IR/SymbolTable.h"
26 #include "llvm/ADT/PointerUnion.h"
27 #include "llvm/ADT/SCCIterator.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // CallInterfaces
34 //===----------------------------------------------------------------------===//
35 
36 #include "mlir/Analysis/CallInterfaces.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // CallGraphNode
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns if this node refers to the indirect/external node.
43 bool CallGraphNode::isExternal() const { return !callableRegion; }
44 
45 /// Return the callable region this node represents. This can only be called
46 /// on non-external nodes.
47 Region *CallGraphNode::getCallableRegion() const {
48   assert(!isExternal() && "the external node has no callable region");
49   return callableRegion;
50 }
51 
52 /// Adds an reference edge to the given node. This is only valid on the
53 /// external node.
54 void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
55   assert(isExternal() && "abstract edges are only valid on external nodes");
56   addEdge(node, Edge::Kind::Abstract);
57 }
58 
59 /// Add an outgoing call edge from this node.
60 void CallGraphNode::addCallEdge(CallGraphNode *node) {
61   addEdge(node, Edge::Kind::Call);
62 }
63 
64 /// Adds a reference edge to the given child node.
65 void CallGraphNode::addChildEdge(CallGraphNode *child) {
66   addEdge(child, Edge::Kind::Child);
67 }
68 
69 /// Add an edge to 'node' with the given kind.
70 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
71   edges.insert({node, kind});
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // CallGraph
76 //===----------------------------------------------------------------------===//
77 
78 /// Compute the set of callgraph nodes that are created by regions nested within
79 /// 'op'.
80 static void computeCallables(Operation *op, CallGraph &cg,
81                              CallGraphNode *parentNode) {
82   if (op->getNumRegions() == 0)
83     return;
84   if (auto callableOp = dyn_cast<CallableOpInterface>(op)) {
85     SmallVector<Region *, 1> callables;
86     callableOp.getCallableRegions(callables);
87     for (auto *callableRegion : callables)
88       cg.getOrAddNode(callableRegion, parentNode);
89   }
90 }
91 
92 /// Recursively compute the callgraph edges for the given operation. Computed
93 /// edges are placed into the given callgraph object.
94 static void computeCallGraph(Operation *op, CallGraph &cg,
95                              CallGraphNode *parentNode) {
96   // Compute the callgraph nodes and edges for each of the nested operations.
97   auto isCallable = isa<CallableOpInterface>(op);
98   for (auto &region : op->getRegions()) {
99     // Check to see if this region is a callable node, if so this is the parent
100     // node of the nested region.
101     CallGraphNode *nestedParentNode;
102     if (!isCallable || !(nestedParentNode = cg.lookupNode(&region)))
103       nestedParentNode = parentNode;
104 
105     // Iterate over the nested operations twice:
106     /// One to fully create nodes in the for each callable region of a nested
107     /// operation;
108     for (auto &block : region)
109       for (auto &nested : block)
110         computeCallables(&nested, cg, nestedParentNode);
111     /// And another to recursively compute the callgraph.
112     for (auto &block : region)
113       for (auto &nested : block)
114         computeCallGraph(&nested, cg, nestedParentNode);
115   }
116 
117   // If there is no parent node, we ignore this operation. Even if this
118   // operation was a call, there would be no callgraph node to attribute it to.
119   if (!parentNode)
120     return;
121 
122   // If this is a call operation, resolve the callee.
123   if (auto call = dyn_cast<CallOpInterface>(op))
124     parentNode->addCallEdge(
125         cg.resolveCallable(call.getCallableForCallee(), op));
126 }
127 
128 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
129   computeCallGraph(op, *this, /*parentNode=*/nullptr);
130 }
131 
132 /// Get or add a call graph node for the given region.
133 CallGraphNode *CallGraph::getOrAddNode(Region *region,
134                                        CallGraphNode *parentNode) {
135   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
136          "expected parent operation to be callable");
137   std::unique_ptr<CallGraphNode> &node = nodes[region];
138   if (!node) {
139     node.reset(new CallGraphNode(region));
140 
141     // Add this node to the given parent node if necessary.
142     if (parentNode)
143       parentNode->addChildEdge(node.get());
144     else
145       // Otherwise, connect all callable nodes to the external node, this allows
146       // for conservatively including all callable nodes within the graph.
147       // FIXME(riverriddle) This isn't correct, this is only necessary for
148       // callable nodes that *could* be called from external sources. This
149       // requires extending the interface for callables to check if they may be
150       // referenced externally.
151       externalNode.addAbstractEdge(node.get());
152   }
153   return node.get();
154 }
155 
156 /// Lookup a call graph node for the given region, or nullptr if none is
157 /// registered.
158 CallGraphNode *CallGraph::lookupNode(Region *region) const {
159   auto it = nodes.find(region);
160   return it == nodes.end() ? nullptr : it->second.get();
161 }
162 
163 /// Resolve the callable for given callee to a node in the callgraph, or the
164 /// external node if a valid node was not resolved.
165 CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
166                                           Operation *from) const {
167   // Get the callee operation from the callable.
168   Operation *callee;
169   if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
170     callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef.getValue());
171   else
172     callee = callable.get<Value *>()->getDefiningOp();
173 
174   // If the callee is non-null and is a valid callable object, try to get the
175   // called region from it.
176   if (callee && callee->getNumRegions()) {
177     if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
178       if (auto *node = lookupNode(callableOp.getCallableRegion(callable)))
179         return node;
180     }
181   }
182 
183   // If we don't have a valid direct region, this is an external call.
184   return getExternalNode();
185 }
186 
187 /// Dump the the graph in a human readable format.
188 void CallGraph::dump() const { print(llvm::errs()); }
189 void CallGraph::print(raw_ostream &os) const {
190   os << "// ---- CallGraph ----\n";
191 
192   // Functor used to output the name for the given node.
193   auto emitNodeName = [&](const CallGraphNode *node) {
194     if (node->isExternal()) {
195       os << "<External-Node>";
196       return;
197     }
198 
199     auto *callableRegion = node->getCallableRegion();
200     auto *parentOp = callableRegion->getParentOp();
201     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
202        << callableRegion->getRegionNumber();
203     if (auto attrs = parentOp->getAttrList().getDictionary())
204       os << " : " << attrs;
205   };
206 
207   for (auto &nodeIt : nodes) {
208     const CallGraphNode *node = nodeIt.second.get();
209 
210     // Dump the header for this node.
211     os << "// - Node : ";
212     emitNodeName(node);
213     os << "\n";
214 
215     // Emit each of the edges.
216     for (auto &edge : *node) {
217       os << "// -- ";
218       if (edge.isCall())
219         os << "Call";
220       else if (edge.isChild())
221         os << "Child";
222 
223       os << "-Edge : ";
224       emitNodeName(edge.getTarget());
225       os << "\n";
226     }
227     os << "//\n";
228   }
229 
230   os << "// -- SCCs --\n";
231 
232   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
233     os << "// - SCC : \n";
234     for (auto &node : scc) {
235       os << "// -- Node :";
236       emitNodeName(node);
237       os << "\n";
238     }
239     os << "\n";
240   }
241 
242   os << "// -------------------\n";
243 }
244