1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains interfaces and analyses for defining a nested callgraph. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/CallGraph.h" 14 #include "mlir/Analysis/CallInterfaces.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "llvm/ADT/PointerUnion.h" 18 #include "llvm/ADT/SCCIterator.h" 19 #include "llvm/Support/raw_ostream.h" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // CallInterfaces 25 //===----------------------------------------------------------------------===// 26 27 #include "mlir/Analysis/CallInterfaces.cpp.inc" 28 29 //===----------------------------------------------------------------------===// 30 // CallGraphNode 31 //===----------------------------------------------------------------------===// 32 33 /// Returns if this node refers to the indirect/external node. 34 bool CallGraphNode::isExternal() const { return !callableRegion; } 35 36 /// Return the callable region this node represents. This can only be called 37 /// on non-external nodes. 38 Region *CallGraphNode::getCallableRegion() const { 39 assert(!isExternal() && "the external node has no callable region"); 40 return callableRegion; 41 } 42 43 /// Adds an reference edge to the given node. This is only valid on the 44 /// external node. 45 void CallGraphNode::addAbstractEdge(CallGraphNode *node) { 46 assert(isExternal() && "abstract edges are only valid on external nodes"); 47 addEdge(node, Edge::Kind::Abstract); 48 } 49 50 /// Add an outgoing call edge from this node. 51 void CallGraphNode::addCallEdge(CallGraphNode *node) { 52 addEdge(node, Edge::Kind::Call); 53 } 54 55 /// Adds a reference edge to the given child node. 56 void CallGraphNode::addChildEdge(CallGraphNode *child) { 57 addEdge(child, Edge::Kind::Child); 58 } 59 60 /// Returns true if this node has any child edges. 61 bool CallGraphNode::hasChildren() const { 62 return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); }); 63 } 64 65 /// Add an edge to 'node' with the given kind. 66 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { 67 edges.insert({node, kind}); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // CallGraph 72 //===----------------------------------------------------------------------===// 73 74 /// Recursively compute the callgraph edges for the given operation. Computed 75 /// edges are placed into the given callgraph object. 76 static void computeCallGraph(Operation *op, CallGraph &cg, 77 CallGraphNode *parentNode, bool resolveCalls) { 78 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) { 79 // If there is no parent node, we ignore this operation. Even if this 80 // operation was a call, there would be no callgraph node to attribute it 81 // to. 82 if (!resolveCalls || !parentNode) 83 return; 84 parentNode->addCallEdge( 85 cg.resolveCallable(call.getCallableForCallee(), op)); 86 return; 87 } 88 89 // Compute the callgraph nodes and edges for each of the nested operations. 90 if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) { 91 if (auto *callableRegion = callable.getCallableRegion()) 92 parentNode = cg.getOrAddNode(callableRegion, parentNode); 93 else 94 return; 95 } 96 97 for (Region ®ion : op->getRegions()) 98 for (Block &block : region) 99 for (Operation &nested : block) 100 computeCallGraph(&nested, cg, parentNode, resolveCalls); 101 } 102 103 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { 104 // Make two passes over the graph, one to compute the callables and one to 105 // resolve the calls. We split these up as we may have nested callable objects 106 // that need to be reserved before the calls. 107 computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false); 108 computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true); 109 } 110 111 /// Get or add a call graph node for the given region. 112 CallGraphNode *CallGraph::getOrAddNode(Region *region, 113 CallGraphNode *parentNode) { 114 assert(region && isa<CallableOpInterface>(region->getParentOp()) && 115 "expected parent operation to be callable"); 116 std::unique_ptr<CallGraphNode> &node = nodes[region]; 117 if (!node) { 118 node.reset(new CallGraphNode(region)); 119 120 // Add this node to the given parent node if necessary. 121 if (parentNode) 122 parentNode->addChildEdge(node.get()); 123 else 124 // Otherwise, connect all callable nodes to the external node, this allows 125 // for conservatively including all callable nodes within the graph. 126 // FIXME(riverriddle) This isn't correct, this is only necessary for 127 // callable nodes that *could* be called from external sources. This 128 // requires extending the interface for callables to check if they may be 129 // referenced externally. 130 externalNode.addAbstractEdge(node.get()); 131 } 132 return node.get(); 133 } 134 135 /// Lookup a call graph node for the given region, or nullptr if none is 136 /// registered. 137 CallGraphNode *CallGraph::lookupNode(Region *region) const { 138 auto it = nodes.find(region); 139 return it == nodes.end() ? nullptr : it->second.get(); 140 } 141 142 /// Resolve the callable for given callee to a node in the callgraph, or the 143 /// external node if a valid node was not resolved. 144 CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, 145 Operation *from) const { 146 // Get the callee operation from the callable. 147 Operation *callee; 148 if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) 149 callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef); 150 else 151 callee = callable.get<Value>().getDefiningOp(); 152 153 // If the callee is non-null and is a valid callable object, try to get the 154 // called region from it. 155 if (callee && callee->getNumRegions()) { 156 if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) { 157 if (auto *node = lookupNode(callableOp.getCallableRegion())) 158 return node; 159 } 160 } 161 162 // If we don't have a valid direct region, this is an external call. 163 return getExternalNode(); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // Printing 168 169 /// Dump the graph in a human readable format. 170 void CallGraph::dump() const { print(llvm::errs()); } 171 void CallGraph::print(raw_ostream &os) const { 172 os << "// ---- CallGraph ----\n"; 173 174 // Functor used to output the name for the given node. 175 auto emitNodeName = [&](const CallGraphNode *node) { 176 if (node->isExternal()) { 177 os << "<External-Node>"; 178 return; 179 } 180 181 auto *callableRegion = node->getCallableRegion(); 182 auto *parentOp = callableRegion->getParentOp(); 183 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" 184 << callableRegion->getRegionNumber(); 185 if (auto attrs = parentOp->getAttrList().getDictionary()) 186 os << " : " << attrs; 187 }; 188 189 for (auto &nodeIt : nodes) { 190 const CallGraphNode *node = nodeIt.second.get(); 191 192 // Dump the header for this node. 193 os << "// - Node : "; 194 emitNodeName(node); 195 os << "\n"; 196 197 // Emit each of the edges. 198 for (auto &edge : *node) { 199 os << "// -- "; 200 if (edge.isCall()) 201 os << "Call"; 202 else if (edge.isChild()) 203 os << "Child"; 204 205 os << "-Edge : "; 206 emitNodeName(edge.getTarget()); 207 os << "\n"; 208 } 209 os << "//\n"; 210 } 211 212 os << "// -- SCCs --\n"; 213 214 for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) { 215 os << "// - SCC : \n"; 216 for (auto &node : scc) { 217 os << "// -- Node :"; 218 emitNodeName(node); 219 os << "\n"; 220 } 221 os << "\n"; 222 } 223 224 os << "// -------------------\n"; 225 } 226