1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains interfaces and analyses for defining a nested callgraph. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/CallGraph.h" 14 #include "mlir/IR/Operation.h" 15 #include "mlir/IR/SymbolTable.h" 16 #include "mlir/Interfaces/CallInterfaces.h" 17 #include "llvm/ADT/PointerUnion.h" 18 #include "llvm/ADT/SCCIterator.h" 19 #include "llvm/Support/raw_ostream.h" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // CallGraphNode 25 //===----------------------------------------------------------------------===// 26 27 /// Returns true if this node refers to the indirect/external node. 28 bool CallGraphNode::isExternal() const { return !callableRegion; } 29 30 /// Return the callable region this node represents. This can only be called 31 /// on non-external nodes. 32 Region *CallGraphNode::getCallableRegion() const { 33 assert(!isExternal() && "the external node has no callable region"); 34 return callableRegion; 35 } 36 37 /// Adds an reference edge to the given node. This is only valid on the 38 /// external node. 39 void CallGraphNode::addAbstractEdge(CallGraphNode *node) { 40 assert(isExternal() && "abstract edges are only valid on external nodes"); 41 addEdge(node, Edge::Kind::Abstract); 42 } 43 44 /// Add an outgoing call edge from this node. 45 void CallGraphNode::addCallEdge(CallGraphNode *node) { 46 addEdge(node, Edge::Kind::Call); 47 } 48 49 /// Adds a reference edge to the given child node. 50 void CallGraphNode::addChildEdge(CallGraphNode *child) { 51 addEdge(child, Edge::Kind::Child); 52 } 53 54 /// Returns true if this node has any child edges. 55 bool CallGraphNode::hasChildren() const { 56 return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); }); 57 } 58 59 /// Add an edge to 'node' with the given kind. 60 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { 61 edges.insert({node, kind}); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // CallGraph 66 //===----------------------------------------------------------------------===// 67 68 /// Recursively compute the callgraph edges for the given operation. Computed 69 /// edges are placed into the given callgraph object. 70 static void computeCallGraph(Operation *op, CallGraph &cg, 71 SymbolTableCollection &symbolTable, 72 CallGraphNode *parentNode, bool resolveCalls) { 73 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) { 74 // If there is no parent node, we ignore this operation. Even if this 75 // operation was a call, there would be no callgraph node to attribute it 76 // to. 77 if (resolveCalls && parentNode) 78 parentNode->addCallEdge(cg.resolveCallable(call, symbolTable)); 79 return; 80 } 81 82 // Compute the callgraph nodes and edges for each of the nested operations. 83 if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) { 84 if (auto *callableRegion = callable.getCallableRegion()) 85 parentNode = cg.getOrAddNode(callableRegion, parentNode); 86 else 87 return; 88 } 89 90 for (Region ®ion : op->getRegions()) 91 for (Operation &nested : region.getOps()) 92 computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls); 93 } 94 95 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { 96 // Make two passes over the graph, one to compute the callables and one to 97 // resolve the calls. We split these up as we may have nested callable objects 98 // that need to be reserved before the calls. 99 SymbolTableCollection symbolTable; 100 computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr, 101 /*resolveCalls=*/false); 102 computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr, 103 /*resolveCalls=*/true); 104 } 105 106 /// Get or add a call graph node for the given region. 107 CallGraphNode *CallGraph::getOrAddNode(Region *region, 108 CallGraphNode *parentNode) { 109 assert(region && isa<CallableOpInterface>(region->getParentOp()) && 110 "expected parent operation to be callable"); 111 std::unique_ptr<CallGraphNode> &node = nodes[region]; 112 if (!node) { 113 node.reset(new CallGraphNode(region)); 114 115 // Add this node to the given parent node if necessary. 116 if (parentNode) { 117 parentNode->addChildEdge(node.get()); 118 } else { 119 // Otherwise, connect all callable nodes to the external node, this allows 120 // for conservatively including all callable nodes within the graph. 121 // FIXME This isn't correct, this is only necessary for callable nodes 122 // that *could* be called from external sources. This requires extending 123 // the interface for callables to check if they may be referenced 124 // externally. 125 externalNode.addAbstractEdge(node.get()); 126 } 127 } 128 return node.get(); 129 } 130 131 /// Lookup a call graph node for the given region, or nullptr if none is 132 /// registered. 133 CallGraphNode *CallGraph::lookupNode(Region *region) const { 134 auto it = nodes.find(region); 135 return it == nodes.end() ? nullptr : it->second.get(); 136 } 137 138 /// Resolve the callable for given callee to a node in the callgraph, or the 139 /// external node if a valid node was not resolved. 140 CallGraphNode * 141 CallGraph::resolveCallable(CallOpInterface call, 142 SymbolTableCollection &symbolTable) const { 143 Operation *callable = call.resolveCallable(&symbolTable); 144 if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable)) 145 if (auto *node = lookupNode(callableOp.getCallableRegion())) 146 return node; 147 148 // If we don't have a valid direct region, this is an external call. 149 return getExternalNode(); 150 } 151 152 /// Erase the given node from the callgraph. 153 void CallGraph::eraseNode(CallGraphNode *node) { 154 // Erase any children of this node first. 155 if (node->hasChildren()) { 156 for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node)) 157 if (edge.isChild()) 158 eraseNode(edge.getTarget()); 159 } 160 // Erase any edges to this node from any other nodes. 161 for (auto &it : nodes) { 162 it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) { 163 return edge.getTarget() == node; 164 }); 165 } 166 nodes.erase(node->getCallableRegion()); 167 } 168 169 //===----------------------------------------------------------------------===// 170 // Printing 171 172 /// Dump the graph in a human readable format. 173 void CallGraph::dump() const { print(llvm::errs()); } 174 void CallGraph::print(raw_ostream &os) const { 175 os << "// ---- CallGraph ----\n"; 176 177 // Functor used to output the name for the given node. 178 auto emitNodeName = [&](const CallGraphNode *node) { 179 if (node->isExternal()) { 180 os << "<External-Node>"; 181 return; 182 } 183 184 auto *callableRegion = node->getCallableRegion(); 185 auto *parentOp = callableRegion->getParentOp(); 186 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" 187 << callableRegion->getRegionNumber(); 188 auto attrs = parentOp->getAttrDictionary(); 189 if (!attrs.empty()) 190 os << " : " << attrs; 191 }; 192 193 for (auto &nodeIt : nodes) { 194 const CallGraphNode *node = nodeIt.second.get(); 195 196 // Dump the header for this node. 197 os << "// - Node : "; 198 emitNodeName(node); 199 os << "\n"; 200 201 // Emit each of the edges. 202 for (auto &edge : *node) { 203 os << "// -- "; 204 if (edge.isCall()) 205 os << "Call"; 206 else if (edge.isChild()) 207 os << "Child"; 208 209 os << "-Edge : "; 210 emitNodeName(edge.getTarget()); 211 os << "\n"; 212 } 213 os << "//\n"; 214 } 215 216 os << "// -- SCCs --\n"; 217 218 for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) { 219 os << "// - SCC : \n"; 220 for (auto &node : scc) { 221 os << "// -- Node :"; 222 emitNodeName(node); 223 os << "\n"; 224 } 225 os << "\n"; 226 } 227 228 os << "// -------------------\n"; 229 } 230