1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file contains interfaces and analyses for defining a nested callgraph. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/Analysis/CallGraph.h" 23 #include "mlir/Analysis/CallInterfaces.h" 24 #include "mlir/IR/Operation.h" 25 #include "mlir/IR/SymbolTable.h" 26 #include "llvm/ADT/PointerUnion.h" 27 #include "llvm/ADT/SCCIterator.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 using namespace mlir; 31 32 //===----------------------------------------------------------------------===// 33 // CallInterfaces 34 //===----------------------------------------------------------------------===// 35 36 #include "mlir/Analysis/CallInterfaces.cpp.inc" 37 38 //===----------------------------------------------------------------------===// 39 // CallGraphNode 40 //===----------------------------------------------------------------------===// 41 42 /// Returns if this node refers to the indirect/external node. 43 bool CallGraphNode::isExternal() const { return !callableRegion; } 44 45 /// Return the callable region this node represents. This can only be called 46 /// on non-external nodes. 47 Region *CallGraphNode::getCallableRegion() const { 48 assert(!isExternal() && "the external node has no callable region"); 49 return callableRegion; 50 } 51 52 /// Adds an reference edge to the given node. This is only valid on the 53 /// external node. 54 void CallGraphNode::addAbstractEdge(CallGraphNode *node) { 55 assert(isExternal() && "abstract edges are only valid on external nodes"); 56 addEdge(node, Edge::Kind::Abstract); 57 } 58 59 /// Add an outgoing call edge from this node. 60 void CallGraphNode::addCallEdge(CallGraphNode *node) { 61 addEdge(node, Edge::Kind::Call); 62 } 63 64 /// Adds a reference edge to the given child node. 65 void CallGraphNode::addChildEdge(CallGraphNode *child) { 66 addEdge(child, Edge::Kind::Child); 67 } 68 69 /// Returns true if this node has any child edges. 70 bool CallGraphNode::hasChildren() const { 71 return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); }); 72 } 73 74 /// Add an edge to 'node' with the given kind. 75 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { 76 edges.insert({node, kind}); 77 } 78 79 //===----------------------------------------------------------------------===// 80 // CallGraph 81 //===----------------------------------------------------------------------===// 82 83 /// Recursively compute the callgraph edges for the given operation. Computed 84 /// edges are placed into the given callgraph object. 85 static void computeCallGraph(Operation *op, CallGraph &cg, 86 CallGraphNode *parentNode); 87 88 /// Compute the set of callgraph nodes that are created by regions nested within 89 /// 'op'. 90 static void computeCallables(Operation *op, CallGraph &cg, 91 CallGraphNode *parentNode) { 92 if (op->getNumRegions() == 0) 93 return; 94 if (auto callableOp = dyn_cast<CallableOpInterface>(op)) { 95 SmallVector<Region *, 1> callables; 96 callableOp.getCallableRegions(callables); 97 for (auto *callableRegion : callables) 98 cg.getOrAddNode(callableRegion, parentNode); 99 } 100 } 101 102 /// Recursively compute the callgraph edges within the given region. Computed 103 /// edges are placed into the given callgraph object. 104 static void computeCallGraph(Region ®ion, CallGraph &cg, 105 CallGraphNode *parentNode) { 106 // Iterate over the nested operations twice: 107 /// One to fully create nodes in the for each callable region of a nested 108 /// operation; 109 for (auto &block : region) 110 for (auto &nested : block) 111 computeCallables(&nested, cg, parentNode); 112 113 /// And another to recursively compute the callgraph. 114 for (auto &block : region) 115 for (auto &nested : block) 116 computeCallGraph(&nested, cg, parentNode); 117 } 118 119 /// Recursively compute the callgraph edges for the given operation. Computed 120 /// edges are placed into the given callgraph object. 121 static void computeCallGraph(Operation *op, CallGraph &cg, 122 CallGraphNode *parentNode) { 123 // Compute the callgraph nodes and edges for each of the nested operations. 124 auto isCallable = isa<CallableOpInterface>(op); 125 for (auto ®ion : op->getRegions()) { 126 // Check to see if this region is a callable node, if so this is the parent 127 // node of the nested region. 128 CallGraphNode *nestedParentNode; 129 if (!isCallable || !(nestedParentNode = cg.lookupNode(®ion))) 130 nestedParentNode = parentNode; 131 computeCallGraph(region, cg, nestedParentNode); 132 } 133 134 // If there is no parent node, we ignore this operation. Even if this 135 // operation was a call, there would be no callgraph node to attribute it to. 136 if (!parentNode) 137 return; 138 139 // If this is a call operation, resolve the callee. 140 if (auto call = dyn_cast<CallOpInterface>(op)) 141 parentNode->addCallEdge( 142 cg.resolveCallable(call.getCallableForCallee(), op)); 143 } 144 145 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { 146 computeCallGraph(op, *this, /*parentNode=*/nullptr); 147 } 148 149 /// Get or add a call graph node for the given region. 150 CallGraphNode *CallGraph::getOrAddNode(Region *region, 151 CallGraphNode *parentNode) { 152 assert(region && isa<CallableOpInterface>(region->getParentOp()) && 153 "expected parent operation to be callable"); 154 std::unique_ptr<CallGraphNode> &node = nodes[region]; 155 if (!node) { 156 node.reset(new CallGraphNode(region)); 157 158 // Add this node to the given parent node if necessary. 159 if (parentNode) 160 parentNode->addChildEdge(node.get()); 161 else 162 // Otherwise, connect all callable nodes to the external node, this allows 163 // for conservatively including all callable nodes within the graph. 164 // FIXME(riverriddle) This isn't correct, this is only necessary for 165 // callable nodes that *could* be called from external sources. This 166 // requires extending the interface for callables to check if they may be 167 // referenced externally. 168 externalNode.addAbstractEdge(node.get()); 169 } 170 return node.get(); 171 } 172 173 /// Lookup a call graph node for the given region, or nullptr if none is 174 /// registered. 175 CallGraphNode *CallGraph::lookupNode(Region *region) const { 176 auto it = nodes.find(region); 177 return it == nodes.end() ? nullptr : it->second.get(); 178 } 179 180 /// Resolve the callable for given callee to a node in the callgraph, or the 181 /// external node if a valid node was not resolved. 182 CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, 183 Operation *from) const { 184 // Get the callee operation from the callable. 185 Operation *callee; 186 if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) 187 // TODO(riverriddle) Support nested references. 188 callee = SymbolTable::lookupNearestSymbolFrom(from, 189 symbolRef.getRootReference()); 190 else 191 callee = callable.get<ValuePtr>()->getDefiningOp(); 192 193 // If the callee is non-null and is a valid callable object, try to get the 194 // called region from it. 195 if (callee && callee->getNumRegions()) { 196 if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) { 197 if (auto *node = lookupNode(callableOp.getCallableRegion(callable))) 198 return node; 199 } 200 } 201 202 // If we don't have a valid direct region, this is an external call. 203 return getExternalNode(); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // Printing 208 209 /// Dump the graph in a human readable format. 210 void CallGraph::dump() const { print(llvm::errs()); } 211 void CallGraph::print(raw_ostream &os) const { 212 os << "// ---- CallGraph ----\n"; 213 214 // Functor used to output the name for the given node. 215 auto emitNodeName = [&](const CallGraphNode *node) { 216 if (node->isExternal()) { 217 os << "<External-Node>"; 218 return; 219 } 220 221 auto *callableRegion = node->getCallableRegion(); 222 auto *parentOp = callableRegion->getParentOp(); 223 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" 224 << callableRegion->getRegionNumber(); 225 if (auto attrs = parentOp->getAttrList().getDictionary()) 226 os << " : " << attrs; 227 }; 228 229 for (auto &nodeIt : nodes) { 230 const CallGraphNode *node = nodeIt.second.get(); 231 232 // Dump the header for this node. 233 os << "// - Node : "; 234 emitNodeName(node); 235 os << "\n"; 236 237 // Emit each of the edges. 238 for (auto &edge : *node) { 239 os << "// -- "; 240 if (edge.isCall()) 241 os << "Call"; 242 else if (edge.isChild()) 243 os << "Child"; 244 245 os << "-Edge : "; 246 emitNodeName(edge.getTarget()); 247 os << "\n"; 248 } 249 os << "//\n"; 250 } 251 252 os << "// -- SCCs --\n"; 253 254 for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) { 255 os << "// - SCC : \n"; 256 for (auto &node : scc) { 257 os << "// -- Node :"; 258 emitNodeName(node); 259 os << "\n"; 260 } 261 os << "\n"; 262 } 263 264 os << "// -------------------\n"; 265 } 266