1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains interfaces and analyses for defining a nested callgraph. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/CallGraph.h" 14 #include "mlir/Analysis/CallInterfaces.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/SymbolTable.h" 17 #include "llvm/ADT/PointerUnion.h" 18 #include "llvm/ADT/SCCIterator.h" 19 #include "llvm/Support/raw_ostream.h" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // CallInterfaces 25 //===----------------------------------------------------------------------===// 26 27 #include "mlir/Analysis/CallInterfaces.cpp.inc" 28 29 //===----------------------------------------------------------------------===// 30 // CallGraphNode 31 //===----------------------------------------------------------------------===// 32 33 /// Returns if this node refers to the indirect/external node. 34 bool CallGraphNode::isExternal() const { return !callableRegion; } 35 36 /// Return the callable region this node represents. This can only be called 37 /// on non-external nodes. 38 Region *CallGraphNode::getCallableRegion() const { 39 assert(!isExternal() && "the external node has no callable region"); 40 return callableRegion; 41 } 42 43 /// Adds an reference edge to the given node. This is only valid on the 44 /// external node. 45 void CallGraphNode::addAbstractEdge(CallGraphNode *node) { 46 assert(isExternal() && "abstract edges are only valid on external nodes"); 47 addEdge(node, Edge::Kind::Abstract); 48 } 49 50 /// Add an outgoing call edge from this node. 51 void CallGraphNode::addCallEdge(CallGraphNode *node) { 52 addEdge(node, Edge::Kind::Call); 53 } 54 55 /// Adds a reference edge to the given child node. 56 void CallGraphNode::addChildEdge(CallGraphNode *child) { 57 addEdge(child, Edge::Kind::Child); 58 } 59 60 /// Returns true if this node has any child edges. 61 bool CallGraphNode::hasChildren() const { 62 return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); }); 63 } 64 65 /// Add an edge to 'node' with the given kind. 66 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { 67 edges.insert({node, kind}); 68 } 69 70 //===----------------------------------------------------------------------===// 71 // CallGraph 72 //===----------------------------------------------------------------------===// 73 74 /// Recursively compute the callgraph edges for the given operation. Computed 75 /// edges are placed into the given callgraph object. 76 static void computeCallGraph(Operation *op, CallGraph &cg, 77 CallGraphNode *parentNode); 78 79 /// Compute the set of callgraph nodes that are created by regions nested within 80 /// 'op'. 81 static void computeCallables(Operation *op, CallGraph &cg, 82 CallGraphNode *parentNode) { 83 if (op->getNumRegions() == 0) 84 return; 85 if (auto callableOp = dyn_cast<CallableOpInterface>(op)) { 86 SmallVector<Region *, 1> callables; 87 callableOp.getCallableRegions(callables); 88 for (auto *callableRegion : callables) 89 cg.getOrAddNode(callableRegion, parentNode); 90 } 91 } 92 93 /// Recursively compute the callgraph edges within the given region. Computed 94 /// edges are placed into the given callgraph object. 95 static void computeCallGraph(Region ®ion, CallGraph &cg, 96 CallGraphNode *parentNode) { 97 // Iterate over the nested operations twice: 98 /// One to fully create nodes in the for each callable region of a nested 99 /// operation; 100 for (auto &block : region) 101 for (auto &nested : block) 102 computeCallables(&nested, cg, parentNode); 103 104 /// And another to recursively compute the callgraph. 105 for (auto &block : region) 106 for (auto &nested : block) 107 computeCallGraph(&nested, cg, parentNode); 108 } 109 110 /// Recursively compute the callgraph edges for the given operation. Computed 111 /// edges are placed into the given callgraph object. 112 static void computeCallGraph(Operation *op, CallGraph &cg, 113 CallGraphNode *parentNode) { 114 // Compute the callgraph nodes and edges for each of the nested operations. 115 auto isCallable = isa<CallableOpInterface>(op); 116 for (auto ®ion : op->getRegions()) { 117 // Check to see if this region is a callable node, if so this is the parent 118 // node of the nested region. 119 CallGraphNode *nestedParentNode; 120 if (!isCallable || !(nestedParentNode = cg.lookupNode(®ion))) 121 nestedParentNode = parentNode; 122 computeCallGraph(region, cg, nestedParentNode); 123 } 124 125 // If there is no parent node, we ignore this operation. Even if this 126 // operation was a call, there would be no callgraph node to attribute it to. 127 if (!parentNode) 128 return; 129 130 // If this is a call operation, resolve the callee. 131 if (auto call = dyn_cast<CallOpInterface>(op)) 132 parentNode->addCallEdge( 133 cg.resolveCallable(call.getCallableForCallee(), op)); 134 } 135 136 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { 137 computeCallGraph(op, *this, /*parentNode=*/nullptr); 138 } 139 140 /// Get or add a call graph node for the given region. 141 CallGraphNode *CallGraph::getOrAddNode(Region *region, 142 CallGraphNode *parentNode) { 143 assert(region && isa<CallableOpInterface>(region->getParentOp()) && 144 "expected parent operation to be callable"); 145 std::unique_ptr<CallGraphNode> &node = nodes[region]; 146 if (!node) { 147 node.reset(new CallGraphNode(region)); 148 149 // Add this node to the given parent node if necessary. 150 if (parentNode) 151 parentNode->addChildEdge(node.get()); 152 else 153 // Otherwise, connect all callable nodes to the external node, this allows 154 // for conservatively including all callable nodes within the graph. 155 // FIXME(riverriddle) This isn't correct, this is only necessary for 156 // callable nodes that *could* be called from external sources. This 157 // requires extending the interface for callables to check if they may be 158 // referenced externally. 159 externalNode.addAbstractEdge(node.get()); 160 } 161 return node.get(); 162 } 163 164 /// Lookup a call graph node for the given region, or nullptr if none is 165 /// registered. 166 CallGraphNode *CallGraph::lookupNode(Region *region) const { 167 auto it = nodes.find(region); 168 return it == nodes.end() ? nullptr : it->second.get(); 169 } 170 171 /// Resolve the callable for given callee to a node in the callgraph, or the 172 /// external node if a valid node was not resolved. 173 CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, 174 Operation *from) const { 175 // Get the callee operation from the callable. 176 Operation *callee; 177 if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) 178 // TODO(riverriddle) Support nested references. 179 callee = SymbolTable::lookupNearestSymbolFrom(from, 180 symbolRef.getRootReference()); 181 else 182 callee = callable.get<Value>().getDefiningOp(); 183 184 // If the callee is non-null and is a valid callable object, try to get the 185 // called region from it. 186 if (callee && callee->getNumRegions()) { 187 if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) { 188 if (auto *node = lookupNode(callableOp.getCallableRegion(callable))) 189 return node; 190 } 191 } 192 193 // If we don't have a valid direct region, this is an external call. 194 return getExternalNode(); 195 } 196 197 //===----------------------------------------------------------------------===// 198 // Printing 199 200 /// Dump the graph in a human readable format. 201 void CallGraph::dump() const { print(llvm::errs()); } 202 void CallGraph::print(raw_ostream &os) const { 203 os << "// ---- CallGraph ----\n"; 204 205 // Functor used to output the name for the given node. 206 auto emitNodeName = [&](const CallGraphNode *node) { 207 if (node->isExternal()) { 208 os << "<External-Node>"; 209 return; 210 } 211 212 auto *callableRegion = node->getCallableRegion(); 213 auto *parentOp = callableRegion->getParentOp(); 214 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" 215 << callableRegion->getRegionNumber(); 216 if (auto attrs = parentOp->getAttrList().getDictionary()) 217 os << " : " << attrs; 218 }; 219 220 for (auto &nodeIt : nodes) { 221 const CallGraphNode *node = nodeIt.second.get(); 222 223 // Dump the header for this node. 224 os << "// - Node : "; 225 emitNodeName(node); 226 os << "\n"; 227 228 // Emit each of the edges. 229 for (auto &edge : *node) { 230 os << "// -- "; 231 if (edge.isCall()) 232 os << "Call"; 233 else if (edge.isChild()) 234 os << "Child"; 235 236 os << "-Edge : "; 237 emitNodeName(edge.getTarget()); 238 os << "\n"; 239 } 240 os << "//\n"; 241 } 242 243 os << "// -- SCCs --\n"; 244 245 for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) { 246 os << "// - SCC : \n"; 247 for (auto &node : scc) { 248 os << "// -- Node :"; 249 emitNodeName(node); 250 os << "\n"; 251 } 252 os << "\n"; 253 } 254 255 os << "// -------------------\n"; 256 } 257