1 //===- Inliner.cpp - Pass to inline function calls ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a basic inlining algorithm that operates bottom up over 10 // the Strongly Connect Components(SCCs) of the CallGraph. This enables a more 11 // incremental propagation of inlining decisions from the leafs to the roots of 12 // the callgraph. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "mlir/Analysis/CallGraph.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/InliningUtils.h" 21 #include "mlir/Transforms/Passes.h" 22 #include "llvm/ADT/SCCIterator.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Support/Parallel.h" 25 26 #define DEBUG_TYPE "inlining" 27 28 using namespace mlir; 29 30 static llvm::cl::opt<bool> disableCanonicalization( 31 "mlir-disable-inline-simplify", 32 llvm::cl::desc("Disable running simplifications during inlining"), 33 llvm::cl::ReallyHidden, llvm::cl::init(false)); 34 35 static llvm::cl::opt<unsigned> maxInliningIterations( 36 "mlir-max-inline-iterations", 37 llvm::cl::desc("Maximum number of iterations when inlining within an SCC"), 38 llvm::cl::ReallyHidden, llvm::cl::init(4)); 39 40 //===----------------------------------------------------------------------===// 41 // CallGraph traversal 42 //===----------------------------------------------------------------------===// 43 44 /// Run a given transformation over the SCCs of the callgraph in a bottom up 45 /// traversal. 46 static void runTransformOnCGSCCs( 47 const CallGraph &cg, 48 function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { 49 std::vector<CallGraphNode *> currentSCCVec; 50 auto cgi = llvm::scc_begin(&cg); 51 while (!cgi.isAtEnd()) { 52 // Copy the current SCC and increment so that the transformer can modify the 53 // SCC without invalidating our iterator. 54 currentSCCVec = *cgi; 55 ++cgi; 56 sccTransformer(currentSCCVec); 57 } 58 } 59 60 namespace { 61 /// This struct represents a resolved call to a given callgraph node. Given that 62 /// the call does not actually contain a direct reference to the 63 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them 64 /// explicitly. 65 struct ResolvedCall { 66 ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) 67 : call(call), targetNode(targetNode) {} 68 CallOpInterface call; 69 CallGraphNode *targetNode; 70 }; 71 } // end anonymous namespace 72 73 /// Collect all of the callable operations within the given range of blocks. If 74 /// `traverseNestedCGNodes` is true, this will also collect call operations 75 /// inside of nested callgraph nodes. 76 static void collectCallOps(iterator_range<Region::iterator> blocks, 77 CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, 78 bool traverseNestedCGNodes) { 79 SmallVector<Block *, 8> worklist; 80 auto addToWorklist = [&](iterator_range<Region::iterator> blocks) { 81 for (Block &block : blocks) 82 worklist.push_back(&block); 83 }; 84 85 addToWorklist(blocks); 86 while (!worklist.empty()) { 87 for (Operation &op : *worklist.pop_back_val()) { 88 if (auto call = dyn_cast<CallOpInterface>(op)) { 89 // TODO(riverriddle) Support inlining nested call references. 90 CallInterfaceCallable callable = call.getCallableForCallee(); 91 if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { 92 if (!symRef.isa<FlatSymbolRefAttr>()) 93 continue; 94 } 95 96 CallGraphNode *node = cg.resolveCallable(call); 97 if (!node->isExternal()) 98 calls.emplace_back(call, node); 99 continue; 100 } 101 102 // If this is not a call, traverse the nested regions. If 103 // `traverseNestedCGNodes` is false, then don't traverse nested call graph 104 // regions. 105 for (auto &nestedRegion : op.getRegions()) 106 if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) 107 addToWorklist(nestedRegion); 108 } 109 } 110 } 111 112 //===----------------------------------------------------------------------===// 113 // Inliner 114 //===----------------------------------------------------------------------===// 115 namespace { 116 /// This class provides a specialization of the main inlining interface. 117 struct Inliner : public InlinerInterface { 118 Inliner(MLIRContext *context, CallGraph &cg) 119 : InlinerInterface(context), cg(cg) {} 120 121 /// Process a set of blocks that have been inlined. This callback is invoked 122 /// *before* inlined terminator operations have been processed. 123 void 124 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { 125 collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); 126 } 127 128 /// The current set of call instructions to consider for inlining. 129 SmallVector<ResolvedCall, 8> calls; 130 131 /// The callgraph being operated on. 132 CallGraph &cg; 133 }; 134 } // namespace 135 136 /// Returns true if the given call should be inlined. 137 static bool shouldInline(ResolvedCall &resolvedCall) { 138 // Don't allow inlining terminator calls. We currently don't support this 139 // case. 140 if (resolvedCall.call.getOperation()->isKnownTerminator()) 141 return false; 142 143 // Don't allow inlining if the target is an ancestor of the call. This 144 // prevents inlining recursively. 145 if (resolvedCall.targetNode->getCallableRegion()->isAncestor( 146 resolvedCall.call.getParentRegion())) 147 return false; 148 149 // Otherwise, inline. 150 return true; 151 } 152 153 /// Attempt to inline calls within the given scc. This function returns 154 /// success if any calls were inlined, failure otherwise. 155 static LogicalResult inlineCallsInSCC(Inliner &inliner, 156 ArrayRef<CallGraphNode *> currentSCC) { 157 CallGraph &cg = inliner.cg; 158 auto &calls = inliner.calls; 159 160 // Collect all of the direct calls within the nodes of the current SCC. We 161 // don't traverse nested callgraph nodes, because they are handled separately 162 // likely within a different SCC. 163 for (auto *node : currentSCC) { 164 if (!node->isExternal()) 165 collectCallOps(*node->getCallableRegion(), cg, calls, 166 /*traverseNestedCGNodes=*/false); 167 } 168 if (calls.empty()) 169 return failure(); 170 171 // Try to inline each of the call operations. Don't cache the end iterator 172 // here as more calls may be added during inlining. 173 bool inlinedAnyCalls = false; 174 for (unsigned i = 0; i != calls.size(); ++i) { 175 ResolvedCall &it = calls[i]; 176 LLVM_DEBUG({ 177 llvm::dbgs() << "* Considering inlining call: "; 178 it.call.dump(); 179 }); 180 if (!shouldInline(it)) 181 continue; 182 183 CallOpInterface call = it.call; 184 Region *targetRegion = it.targetNode->getCallableRegion(); 185 LogicalResult inlineResult = inlineCall( 186 inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()), 187 targetRegion); 188 if (failed(inlineResult)) 189 continue; 190 191 // If the inlining was successful, then erase the call. 192 call.erase(); 193 inlinedAnyCalls = true; 194 } 195 calls.clear(); 196 return success(inlinedAnyCalls); 197 } 198 199 /// Canonicalize the nodes within the given SCC with the given set of 200 /// canonicalization patterns. 201 static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC, 202 MLIRContext *context, 203 const OwningRewritePatternList &canonPatterns) { 204 // Collect the sets of nodes to canonicalize. 205 SmallVector<CallGraphNode *, 4> nodesToCanonicalize; 206 for (auto *node : currentSCC) { 207 // Don't canonicalize the external node, it has no valid callable region. 208 if (node->isExternal()) 209 continue; 210 211 // Don't canonicalize nodes with children. Nodes with children 212 // require special handling as we may remove the node during 213 // canonicalization. In the future, we should be able to handle this 214 // case with proper node deletion tracking. 215 if (node->hasChildren()) 216 continue; 217 218 // We also won't apply canonicalizations for nodes that are not 219 // isolated. This avoids potentially mutating the regions of nodes defined 220 // above, this is also a stipulation of the 'applyPatternsGreedily' driver. 221 auto *region = node->getCallableRegion(); 222 if (!region->getParentOp()->isKnownIsolatedFromAbove()) 223 continue; 224 nodesToCanonicalize.push_back(node); 225 } 226 if (nodesToCanonicalize.empty()) 227 return; 228 229 // Canonicalize each of the nodes within the SCC in parallel. 230 // NOTE: This is simple now, because we don't enable canonicalizing nodes 231 // within children. When we remove this restriction, this logic will need to 232 // be reworked. 233 ParallelDiagnosticHandler canonicalizationHandler(context); 234 llvm::parallel::for_each_n( 235 llvm::parallel::par, /*Begin=*/size_t(0), 236 /*End=*/nodesToCanonicalize.size(), [&](size_t index) { 237 // Set the order for this thread so that diagnostics will be properly 238 // ordered. 239 canonicalizationHandler.setOrderIDForThread(index); 240 241 // Apply the canonicalization patterns to this region. 242 auto *node = nodesToCanonicalize[index]; 243 applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); 244 245 // Make sure to reset the order ID for the diagnostic handler, as this 246 // thread may be used in a different context. 247 canonicalizationHandler.eraseOrderIDForThread(); 248 }); 249 } 250 251 /// Attempt to inline calls within the given scc, and run canonicalizations with 252 /// the given patterns, until a fixed point is reached. This allows for the 253 /// inlining of newly devirtualized calls. 254 static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC, 255 MLIRContext *context, 256 const OwningRewritePatternList &canonPatterns) { 257 // If we successfully inlined any calls, run some simplifications on the 258 // nodes of the scc. Continue attempting to inline until we reach a fixed 259 // point, or a maximum iteration count. We canonicalize here as it may 260 // devirtualize new calls, as well as give us a better cost model. 261 unsigned iterationCount = 0; 262 while (succeeded(inlineCallsInSCC(inliner, currentSCC))) { 263 // If we aren't allowing simplifications or the max iteration count was 264 // reached, then bail out early. 265 if (disableCanonicalization || ++iterationCount >= maxInliningIterations) 266 break; 267 canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns); 268 } 269 } 270 271 //===----------------------------------------------------------------------===// 272 // InlinerPass 273 //===----------------------------------------------------------------------===// 274 275 // TODO(riverriddle) This pass should currently only be used for basic testing 276 // of inlining functionality. 277 namespace { 278 struct InlinerPass : public OperationPass<InlinerPass> { 279 void runOnOperation() override { 280 CallGraph &cg = getAnalysis<CallGraph>(); 281 auto *context = &getContext(); 282 283 // The inliner should only be run on operations that define a symbol table, 284 // as the callgraph will need to resolve references. 285 Operation *op = getOperation(); 286 if (!op->hasTrait<OpTrait::SymbolTable>()) { 287 op->emitOpError() << " was scheduled to run under the inliner, but does " 288 "not define a symbol table"; 289 return signalPassFailure(); 290 } 291 292 // Collect a set of canonicalization patterns to use when simplifying 293 // callable regions within an SCC. 294 OwningRewritePatternList canonPatterns; 295 for (auto *op : context->getRegisteredOperations()) 296 op->getCanonicalizationPatterns(canonPatterns, context); 297 298 // Run the inline transform in post-order over the SCCs in the callgraph. 299 Inliner inliner(context, cg); 300 runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) { 301 inlineSCC(inliner, scc, context, canonPatterns); 302 }); 303 } 304 }; 305 } // end anonymous namespace 306 307 std::unique_ptr<Pass> mlir::createInlinerPass() { 308 return std::make_unique<InlinerPass>(); 309 } 310 311 static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); 312