18965011fSRiver Riddle //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
28965011fSRiver Riddle //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68965011fSRiver Riddle //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
88965011fSRiver Riddle //
98965011fSRiver Riddle // This file contains interfaces and analyses for defining a nested callgraph.
108965011fSRiver Riddle //
118965011fSRiver Riddle //===----------------------------------------------------------------------===//
128965011fSRiver Riddle
138cb405a8SRiver Riddle #include "mlir/Analysis/CallGraph.h"
148cb405a8SRiver Riddle #include "mlir/IR/Operation.h"
158cb405a8SRiver Riddle #include "mlir/IR/SymbolTable.h"
167ce1e7abSRiver Riddle #include "mlir/Interfaces/CallInterfaces.h"
178965011fSRiver Riddle #include "llvm/ADT/PointerUnion.h"
188cb405a8SRiver Riddle #include "llvm/ADT/SCCIterator.h"
198cb405a8SRiver Riddle #include "llvm/Support/raw_ostream.h"
208965011fSRiver Riddle
218965011fSRiver Riddle using namespace mlir;
228965011fSRiver Riddle
238965011fSRiver Riddle //===----------------------------------------------------------------------===//
248cb405a8SRiver Riddle // CallGraphNode
258cb405a8SRiver Riddle //===----------------------------------------------------------------------===//
268cb405a8SRiver Riddle
27deb99610SKamlesh Kumar /// Returns true if this node refers to the indirect/external node.
isExternal() const288cb405a8SRiver Riddle bool CallGraphNode::isExternal() const { return !callableRegion; }
298cb405a8SRiver Riddle
308cb405a8SRiver Riddle /// Return the callable region this node represents. This can only be called
318cb405a8SRiver Riddle /// on non-external nodes.
getCallableRegion() const328cb405a8SRiver Riddle Region *CallGraphNode::getCallableRegion() const {
338cb405a8SRiver Riddle assert(!isExternal() && "the external node has no callable region");
348cb405a8SRiver Riddle return callableRegion;
358cb405a8SRiver Riddle }
368cb405a8SRiver Riddle
378cb405a8SRiver Riddle /// Adds an reference edge to the given node. This is only valid on the
388cb405a8SRiver Riddle /// external node.
addAbstractEdge(CallGraphNode * node)398cb405a8SRiver Riddle void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
408cb405a8SRiver Riddle assert(isExternal() && "abstract edges are only valid on external nodes");
418cb405a8SRiver Riddle addEdge(node, Edge::Kind::Abstract);
428cb405a8SRiver Riddle }
438cb405a8SRiver Riddle
448cb405a8SRiver Riddle /// Add an outgoing call edge from this node.
addCallEdge(CallGraphNode * node)458cb405a8SRiver Riddle void CallGraphNode::addCallEdge(CallGraphNode *node) {
468cb405a8SRiver Riddle addEdge(node, Edge::Kind::Call);
478cb405a8SRiver Riddle }
488cb405a8SRiver Riddle
498cb405a8SRiver Riddle /// Adds a reference edge to the given child node.
addChildEdge(CallGraphNode * child)508cb405a8SRiver Riddle void CallGraphNode::addChildEdge(CallGraphNode *child) {
518cb405a8SRiver Riddle addEdge(child, Edge::Kind::Child);
528cb405a8SRiver Riddle }
538cb405a8SRiver Riddle
546b1cc3c6SRiver Riddle /// Returns true if this node has any child edges.
hasChildren() const556b1cc3c6SRiver Riddle bool CallGraphNode::hasChildren() const {
566b1cc3c6SRiver Riddle return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
576b1cc3c6SRiver Riddle }
586b1cc3c6SRiver Riddle
598cb405a8SRiver Riddle /// Add an edge to 'node' with the given kind.
addEdge(CallGraphNode * node,Edge::Kind kind)608cb405a8SRiver Riddle void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
618cb405a8SRiver Riddle edges.insert({node, kind});
628cb405a8SRiver Riddle }
638cb405a8SRiver Riddle
648cb405a8SRiver Riddle //===----------------------------------------------------------------------===//
658cb405a8SRiver Riddle // CallGraph
668cb405a8SRiver Riddle //===----------------------------------------------------------------------===//
678cb405a8SRiver Riddle
686b1cc3c6SRiver Riddle /// Recursively compute the callgraph edges for the given operation. Computed
696b1cc3c6SRiver Riddle /// edges are placed into the given callgraph object.
computeCallGraph(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable,CallGraphNode * parentNode,bool resolveCalls)706b1cc3c6SRiver Riddle static void computeCallGraph(Operation *op, CallGraph &cg,
71*a5ea6045SRiver Riddle SymbolTableCollection &symbolTable,
72c7748404SRiver Riddle CallGraphNode *parentNode, bool resolveCalls) {
73c7748404SRiver Riddle if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
748cb405a8SRiver Riddle // If there is no parent node, we ignore this operation. Even if this
75c7748404SRiver Riddle // operation was a call, there would be no callgraph node to attribute it
76c7748404SRiver Riddle // to.
775c159b91SRiver Riddle if (resolveCalls && parentNode)
78*a5ea6045SRiver Riddle parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
79c7748404SRiver Riddle return;
80c7748404SRiver Riddle }
81c7748404SRiver Riddle
82c7748404SRiver Riddle // Compute the callgraph nodes and edges for each of the nested operations.
83c7748404SRiver Riddle if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
84c7748404SRiver Riddle if (auto *callableRegion = callable.getCallableRegion())
85c7748404SRiver Riddle parentNode = cg.getOrAddNode(callableRegion, parentNode);
86c7748404SRiver Riddle else
87c7748404SRiver Riddle return;
88c7748404SRiver Riddle }
89c7748404SRiver Riddle
90c7748404SRiver Riddle for (Region ®ion : op->getRegions())
911e4faf23SRiver Riddle for (Operation &nested : region.getOps())
92*a5ea6045SRiver Riddle computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
938cb405a8SRiver Riddle }
948cb405a8SRiver Riddle
CallGraph(Operation * op)958cb405a8SRiver Riddle CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
96c7748404SRiver Riddle // Make two passes over the graph, one to compute the callables and one to
97c7748404SRiver Riddle // resolve the calls. We split these up as we may have nested callable objects
98c7748404SRiver Riddle // that need to be reserved before the calls.
99*a5ea6045SRiver Riddle SymbolTableCollection symbolTable;
100*a5ea6045SRiver Riddle computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
101*a5ea6045SRiver Riddle /*resolveCalls=*/false);
102*a5ea6045SRiver Riddle computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
103*a5ea6045SRiver Riddle /*resolveCalls=*/true);
1048cb405a8SRiver Riddle }
1058cb405a8SRiver Riddle
1068cb405a8SRiver Riddle /// Get or add a call graph node for the given region.
getOrAddNode(Region * region,CallGraphNode * parentNode)1078cb405a8SRiver Riddle CallGraphNode *CallGraph::getOrAddNode(Region *region,
1088cb405a8SRiver Riddle CallGraphNode *parentNode) {
1098cb405a8SRiver Riddle assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
1108cb405a8SRiver Riddle "expected parent operation to be callable");
1118cb405a8SRiver Riddle std::unique_ptr<CallGraphNode> &node = nodes[region];
1128cb405a8SRiver Riddle if (!node) {
1138cb405a8SRiver Riddle node.reset(new CallGraphNode(region));
1148cb405a8SRiver Riddle
1158cb405a8SRiver Riddle // Add this node to the given parent node if necessary.
116*a5ea6045SRiver Riddle if (parentNode) {
1178cb405a8SRiver Riddle parentNode->addChildEdge(node.get());
118*a5ea6045SRiver Riddle } else {
1198cb405a8SRiver Riddle // Otherwise, connect all callable nodes to the external node, this allows
1208cb405a8SRiver Riddle // for conservatively including all callable nodes within the graph.
121*a5ea6045SRiver Riddle // FIXME This isn't correct, this is only necessary for callable nodes
122*a5ea6045SRiver Riddle // that *could* be called from external sources. This requires extending
123*a5ea6045SRiver Riddle // the interface for callables to check if they may be referenced
124*a5ea6045SRiver Riddle // externally.
1258cb405a8SRiver Riddle externalNode.addAbstractEdge(node.get());
1268cb405a8SRiver Riddle }
127*a5ea6045SRiver Riddle }
1288cb405a8SRiver Riddle return node.get();
1298cb405a8SRiver Riddle }
1308cb405a8SRiver Riddle
1318cb405a8SRiver Riddle /// Lookup a call graph node for the given region, or nullptr if none is
1328cb405a8SRiver Riddle /// registered.
lookupNode(Region * region) const1338cb405a8SRiver Riddle CallGraphNode *CallGraph::lookupNode(Region *region) const {
1348cb405a8SRiver Riddle auto it = nodes.find(region);
1358cb405a8SRiver Riddle return it == nodes.end() ? nullptr : it->second.get();
1368cb405a8SRiver Riddle }
1378cb405a8SRiver Riddle
1388cb405a8SRiver Riddle /// Resolve the callable for given callee to a node in the callgraph, or the
1398cb405a8SRiver Riddle /// external node if a valid node was not resolved.
140*a5ea6045SRiver Riddle CallGraphNode *
resolveCallable(CallOpInterface call,SymbolTableCollection & symbolTable) const141*a5ea6045SRiver Riddle CallGraph::resolveCallable(CallOpInterface call,
142*a5ea6045SRiver Riddle SymbolTableCollection &symbolTable) const {
143*a5ea6045SRiver Riddle Operation *callable = call.resolveCallable(&symbolTable);
1445c159b91SRiver Riddle if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
145c7748404SRiver Riddle if (auto *node = lookupNode(callableOp.getCallableRegion()))
1468cb405a8SRiver Riddle return node;
1478cb405a8SRiver Riddle
1488cb405a8SRiver Riddle // If we don't have a valid direct region, this is an external call.
1498cb405a8SRiver Riddle return getExternalNode();
1508cb405a8SRiver Riddle }
1518cb405a8SRiver Riddle
1524be504a9SRiver Riddle /// Erase the given node from the callgraph.
eraseNode(CallGraphNode * node)1534be504a9SRiver Riddle void CallGraph::eraseNode(CallGraphNode *node) {
1544be504a9SRiver Riddle // Erase any children of this node first.
1554be504a9SRiver Riddle if (node->hasChildren()) {
1564be504a9SRiver Riddle for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
1574be504a9SRiver Riddle if (edge.isChild())
1584be504a9SRiver Riddle eraseNode(edge.getTarget());
1594be504a9SRiver Riddle }
1604be504a9SRiver Riddle // Erase any edges to this node from any other nodes.
1614be504a9SRiver Riddle for (auto &it : nodes) {
1624be504a9SRiver Riddle it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
1634be504a9SRiver Riddle return edge.getTarget() == node;
1644be504a9SRiver Riddle });
1654be504a9SRiver Riddle }
1664be504a9SRiver Riddle nodes.erase(node->getCallableRegion());
1674be504a9SRiver Riddle }
1684be504a9SRiver Riddle
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
1716b1cc3c6SRiver Riddle
172e5026165SAlexander Belyaev /// Dump the graph in a human readable format.
dump() const1738cb405a8SRiver Riddle void CallGraph::dump() const { print(llvm::errs()); }
print(raw_ostream & os) const1748cb405a8SRiver Riddle void CallGraph::print(raw_ostream &os) const {
1758cb405a8SRiver Riddle os << "// ---- CallGraph ----\n";
1768cb405a8SRiver Riddle
1778cb405a8SRiver Riddle // Functor used to output the name for the given node.
1788cb405a8SRiver Riddle auto emitNodeName = [&](const CallGraphNode *node) {
1798cb405a8SRiver Riddle if (node->isExternal()) {
1808cb405a8SRiver Riddle os << "<External-Node>";
1818cb405a8SRiver Riddle return;
1828cb405a8SRiver Riddle }
1838cb405a8SRiver Riddle
1848cb405a8SRiver Riddle auto *callableRegion = node->getCallableRegion();
1858cb405a8SRiver Riddle auto *parentOp = callableRegion->getParentOp();
1868cb405a8SRiver Riddle os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
1878cb405a8SRiver Riddle << callableRegion->getRegionNumber();
1885eae715aSJacques Pienaar auto attrs = parentOp->getAttrDictionary();
1895eae715aSJacques Pienaar if (!attrs.empty())
1908cb405a8SRiver Riddle os << " : " << attrs;
1918cb405a8SRiver Riddle };
1928cb405a8SRiver Riddle
1938cb405a8SRiver Riddle for (auto &nodeIt : nodes) {
1948cb405a8SRiver Riddle const CallGraphNode *node = nodeIt.second.get();
1958cb405a8SRiver Riddle
1968cb405a8SRiver Riddle // Dump the header for this node.
1978cb405a8SRiver Riddle os << "// - Node : ";
1988cb405a8SRiver Riddle emitNodeName(node);
1998cb405a8SRiver Riddle os << "\n";
2008cb405a8SRiver Riddle
2018cb405a8SRiver Riddle // Emit each of the edges.
2028cb405a8SRiver Riddle for (auto &edge : *node) {
2038cb405a8SRiver Riddle os << "// -- ";
2048cb405a8SRiver Riddle if (edge.isCall())
2058cb405a8SRiver Riddle os << "Call";
2068cb405a8SRiver Riddle else if (edge.isChild())
2078cb405a8SRiver Riddle os << "Child";
2088cb405a8SRiver Riddle
2098cb405a8SRiver Riddle os << "-Edge : ";
2108cb405a8SRiver Riddle emitNodeName(edge.getTarget());
2118cb405a8SRiver Riddle os << "\n";
2128cb405a8SRiver Riddle }
2138cb405a8SRiver Riddle os << "//\n";
2148cb405a8SRiver Riddle }
2158cb405a8SRiver Riddle
2168cb405a8SRiver Riddle os << "// -- SCCs --\n";
2178cb405a8SRiver Riddle
2188cb405a8SRiver Riddle for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
2198cb405a8SRiver Riddle os << "// - SCC : \n";
2208cb405a8SRiver Riddle for (auto &node : scc) {
2218cb405a8SRiver Riddle os << "// -- Node :";
2228cb405a8SRiver Riddle emitNodeName(node);
2238cb405a8SRiver Riddle os << "\n";
2248cb405a8SRiver Riddle }
2258cb405a8SRiver Riddle os << "\n";
2268cb405a8SRiver Riddle }
2278cb405a8SRiver Riddle
2288cb405a8SRiver Riddle os << "// -------------------\n";
2298cb405a8SRiver Riddle }
230