1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and analyses for defining a nested callgraph.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // CallGraphNode
25 //===----------------------------------------------------------------------===//
26 
27 /// Returns true if this node refers to the indirect/external node.
28 bool CallGraphNode::isExternal() const { return !callableRegion; }
29 
30 /// Return the callable region this node represents. This can only be called
31 /// on non-external nodes.
32 Region *CallGraphNode::getCallableRegion() const {
33   assert(!isExternal() && "the external node has no callable region");
34   return callableRegion;
35 }
36 
37 /// Adds an reference edge to the given node. This is only valid on the
38 /// external node.
39 void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
40   assert(isExternal() && "abstract edges are only valid on external nodes");
41   addEdge(node, Edge::Kind::Abstract);
42 }
43 
44 /// Add an outgoing call edge from this node.
45 void CallGraphNode::addCallEdge(CallGraphNode *node) {
46   addEdge(node, Edge::Kind::Call);
47 }
48 
49 /// Adds a reference edge to the given child node.
50 void CallGraphNode::addChildEdge(CallGraphNode *child) {
51   addEdge(child, Edge::Kind::Child);
52 }
53 
54 /// Returns true if this node has any child edges.
55 bool CallGraphNode::hasChildren() const {
56   return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
57 }
58 
59 /// Add an edge to 'node' with the given kind.
60 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
61   edges.insert({node, kind});
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // CallGraph
66 //===----------------------------------------------------------------------===//
67 
68 /// Recursively compute the callgraph edges for the given operation. Computed
69 /// edges are placed into the given callgraph object.
70 static void computeCallGraph(Operation *op, CallGraph &cg,
71                              CallGraphNode *parentNode, bool resolveCalls) {
72   if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
73     // If there is no parent node, we ignore this operation. Even if this
74     // operation was a call, there would be no callgraph node to attribute it
75     // to.
76     if (resolveCalls && parentNode)
77       parentNode->addCallEdge(cg.resolveCallable(call));
78     return;
79   }
80 
81   // Compute the callgraph nodes and edges for each of the nested operations.
82   if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
83     if (auto *callableRegion = callable.getCallableRegion())
84       parentNode = cg.getOrAddNode(callableRegion, parentNode);
85     else
86       return;
87   }
88 
89   for (Region &region : op->getRegions())
90     for (Operation &nested : region.getOps())
91       computeCallGraph(&nested, cg, parentNode, resolveCalls);
92 }
93 
94 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
95   // Make two passes over the graph, one to compute the callables and one to
96   // resolve the calls. We split these up as we may have nested callable objects
97   // that need to be reserved before the calls.
98   computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
99   computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
100 }
101 
102 /// Get or add a call graph node for the given region.
103 CallGraphNode *CallGraph::getOrAddNode(Region *region,
104                                        CallGraphNode *parentNode) {
105   assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
106          "expected parent operation to be callable");
107   std::unique_ptr<CallGraphNode> &node = nodes[region];
108   if (!node) {
109     node.reset(new CallGraphNode(region));
110 
111     // Add this node to the given parent node if necessary.
112     if (parentNode)
113       parentNode->addChildEdge(node.get());
114     else
115       // Otherwise, connect all callable nodes to the external node, this allows
116       // for conservatively including all callable nodes within the graph.
117       // FIXME(riverriddle) This isn't correct, this is only necessary for
118       // callable nodes that *could* be called from external sources. This
119       // requires extending the interface for callables to check if they may be
120       // referenced externally.
121       externalNode.addAbstractEdge(node.get());
122   }
123   return node.get();
124 }
125 
126 /// Lookup a call graph node for the given region, or nullptr if none is
127 /// registered.
128 CallGraphNode *CallGraph::lookupNode(Region *region) const {
129   auto it = nodes.find(region);
130   return it == nodes.end() ? nullptr : it->second.get();
131 }
132 
133 /// Resolve the callable for given callee to a node in the callgraph, or the
134 /// external node if a valid node was not resolved.
135 CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
136   Operation *callable = call.resolveCallable();
137   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
138     if (auto *node = lookupNode(callableOp.getCallableRegion()))
139       return node;
140 
141   // If we don't have a valid direct region, this is an external call.
142   return getExternalNode();
143 }
144 
145 /// Erase the given node from the callgraph.
146 void CallGraph::eraseNode(CallGraphNode *node) {
147   // Erase any children of this node first.
148   if (node->hasChildren()) {
149     for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
150       if (edge.isChild())
151         eraseNode(edge.getTarget());
152   }
153   // Erase any edges to this node from any other nodes.
154   for (auto &it : nodes) {
155     it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
156       return edge.getTarget() == node;
157     });
158   }
159   nodes.erase(node->getCallableRegion());
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // Printing
164 
165 /// Dump the graph in a human readable format.
166 void CallGraph::dump() const { print(llvm::errs()); }
167 void CallGraph::print(raw_ostream &os) const {
168   os << "// ---- CallGraph ----\n";
169 
170   // Functor used to output the name for the given node.
171   auto emitNodeName = [&](const CallGraphNode *node) {
172     if (node->isExternal()) {
173       os << "<External-Node>";
174       return;
175     }
176 
177     auto *callableRegion = node->getCallableRegion();
178     auto *parentOp = callableRegion->getParentOp();
179     os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
180        << callableRegion->getRegionNumber();
181     auto attrs = parentOp->getAttrDictionary();
182     if (!attrs.empty())
183       os << " : " << attrs;
184   };
185 
186   for (auto &nodeIt : nodes) {
187     const CallGraphNode *node = nodeIt.second.get();
188 
189     // Dump the header for this node.
190     os << "// - Node : ";
191     emitNodeName(node);
192     os << "\n";
193 
194     // Emit each of the edges.
195     for (auto &edge : *node) {
196       os << "// -- ";
197       if (edge.isCall())
198         os << "Call";
199       else if (edge.isChild())
200         os << "Child";
201 
202       os << "-Edge : ";
203       emitNodeName(edge.getTarget());
204       os << "\n";
205     }
206     os << "//\n";
207   }
208 
209   os << "// -- SCCs --\n";
210 
211   for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
212     os << "// - SCC : \n";
213     for (auto &node : scc) {
214       os << "// -- Node :";
215       emitNodeName(node);
216       os << "\n";
217     }
218     os << "\n";
219   }
220 
221   os << "// -------------------\n";
222 }
223