1 //===- Inliner.cpp - Pass to inline function calls ------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a basic inlining algorithm that operates bottom up over 10 // the Strongly Connect Components(SCCs) of the CallGraph. This enables a more 11 // incremental propagation of inlining decisions from the leafs to the roots of 12 // the callgraph. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "mlir/Analysis/CallGraph.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/InliningUtils.h" 21 #include "mlir/Transforms/Passes.h" 22 #include "llvm/ADT/SCCIterator.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Support/Parallel.h" 25 26 #define DEBUG_TYPE "inlining" 27 28 using namespace mlir; 29 30 static llvm::cl::opt<bool> disableCanonicalization( 31 "mlir-disable-inline-simplify", 32 llvm::cl::desc("Disable running simplifications during inlining"), 33 llvm::cl::ReallyHidden, llvm::cl::init(false)); 34 35 static llvm::cl::opt<unsigned> maxInliningIterations( 36 "mlir-max-inline-iterations", 37 llvm::cl::desc("Maximum number of iterations when inlining within an SCC"), 38 llvm::cl::ReallyHidden, llvm::cl::init(4)); 39 40 //===----------------------------------------------------------------------===// 41 // CallGraph traversal 42 //===----------------------------------------------------------------------===// 43 44 /// Run a given transformation over the SCCs of the callgraph in a bottom up 45 /// traversal. 46 static void runTransformOnCGSCCs( 47 const CallGraph &cg, 48 function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { 49 std::vector<CallGraphNode *> currentSCCVec; 50 auto cgi = llvm::scc_begin(&cg); 51 while (!cgi.isAtEnd()) { 52 // Copy the current SCC and increment so that the transformer can modify the 53 // SCC without invalidating our iterator. 54 currentSCCVec = *cgi; 55 ++cgi; 56 sccTransformer(currentSCCVec); 57 } 58 } 59 60 namespace { 61 /// This struct represents a resolved call to a given callgraph node. Given that 62 /// the call does not actually contain a direct reference to the 63 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them 64 /// explicitly. 65 struct ResolvedCall { 66 ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) 67 : call(call), targetNode(targetNode) {} 68 CallOpInterface call; 69 CallGraphNode *targetNode; 70 }; 71 } // end anonymous namespace 72 73 /// Collect all of the callable operations within the given range of blocks. If 74 /// `traverseNestedCGNodes` is true, this will also collect call operations 75 /// inside of nested callgraph nodes. 76 static void collectCallOps(iterator_range<Region::iterator> blocks, 77 CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, 78 bool traverseNestedCGNodes) { 79 SmallVector<Block *, 8> worklist; 80 auto addToWorklist = [&](iterator_range<Region::iterator> blocks) { 81 for (Block &block : blocks) 82 worklist.push_back(&block); 83 }; 84 85 addToWorklist(blocks); 86 while (!worklist.empty()) { 87 for (Operation &op : *worklist.pop_back_val()) { 88 if (auto call = dyn_cast<CallOpInterface>(op)) { 89 CallGraphNode *node = 90 cg.resolveCallable(call.getCallableForCallee(), &op); 91 if (!node->isExternal()) 92 calls.emplace_back(call, node); 93 continue; 94 } 95 96 // If this is not a call, traverse the nested regions. If 97 // `traverseNestedCGNodes` is false, then don't traverse nested call graph 98 // regions. 99 for (auto &nestedRegion : op.getRegions()) 100 if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) 101 addToWorklist(nestedRegion); 102 } 103 } 104 } 105 106 //===----------------------------------------------------------------------===// 107 // Inliner 108 //===----------------------------------------------------------------------===// 109 namespace { 110 /// This class provides a specialization of the main inlining interface. 111 struct Inliner : public InlinerInterface { 112 Inliner(MLIRContext *context, CallGraph &cg) 113 : InlinerInterface(context), cg(cg) {} 114 115 /// Process a set of blocks that have been inlined. This callback is invoked 116 /// *before* inlined terminator operations have been processed. 117 void 118 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { 119 collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); 120 } 121 122 /// The current set of call instructions to consider for inlining. 123 SmallVector<ResolvedCall, 8> calls; 124 125 /// The callgraph being operated on. 126 CallGraph &cg; 127 }; 128 } // namespace 129 130 /// Returns true if the given call should be inlined. 131 static bool shouldInline(ResolvedCall &resolvedCall) { 132 // Don't allow inlining terminator calls. We currently don't support this 133 // case. 134 if (resolvedCall.call.getOperation()->isKnownTerminator()) 135 return false; 136 137 // Don't allow inlining if the target is an ancestor of the call. This 138 // prevents inlining recursively. 139 if (resolvedCall.targetNode->getCallableRegion()->isAncestor( 140 resolvedCall.call.getParentRegion())) 141 return false; 142 143 // Otherwise, inline. 144 return true; 145 } 146 147 /// Attempt to inline calls within the given scc. This function returns 148 /// success if any calls were inlined, failure otherwise. 149 static LogicalResult inlineCallsInSCC(Inliner &inliner, 150 ArrayRef<CallGraphNode *> currentSCC) { 151 CallGraph &cg = inliner.cg; 152 auto &calls = inliner.calls; 153 154 // Collect all of the direct calls within the nodes of the current SCC. We 155 // don't traverse nested callgraph nodes, because they are handled separately 156 // likely within a different SCC. 157 for (auto *node : currentSCC) { 158 if (!node->isExternal()) 159 collectCallOps(*node->getCallableRegion(), cg, calls, 160 /*traverseNestedCGNodes=*/false); 161 } 162 if (calls.empty()) 163 return failure(); 164 165 // Try to inline each of the call operations. Don't cache the end iterator 166 // here as more calls may be added during inlining. 167 bool inlinedAnyCalls = false; 168 for (unsigned i = 0; i != calls.size(); ++i) { 169 ResolvedCall &it = calls[i]; 170 LLVM_DEBUG({ 171 llvm::dbgs() << "* Considering inlining call: "; 172 it.call.dump(); 173 }); 174 if (!shouldInline(it)) 175 continue; 176 177 CallOpInterface call = it.call; 178 Region *targetRegion = it.targetNode->getCallableRegion(); 179 LogicalResult inlineResult = inlineCall( 180 inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()), 181 targetRegion); 182 if (failed(inlineResult)) 183 continue; 184 185 // If the inlining was successful, then erase the call. 186 call.erase(); 187 inlinedAnyCalls = true; 188 } 189 calls.clear(); 190 return success(inlinedAnyCalls); 191 } 192 193 /// Canonicalize the nodes within the given SCC with the given set of 194 /// canonicalization patterns. 195 static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC, 196 MLIRContext *context, 197 const OwningRewritePatternList &canonPatterns) { 198 // Collect the sets of nodes to canonicalize. 199 SmallVector<CallGraphNode *, 4> nodesToCanonicalize; 200 for (auto *node : currentSCC) { 201 // Don't canonicalize the external node, it has no valid callable region. 202 if (node->isExternal()) 203 continue; 204 205 // Don't canonicalize nodes with children. Nodes with children 206 // require special handling as we may remove the node during 207 // canonicalization. In the future, we should be able to handle this 208 // case with proper node deletion tracking. 209 if (node->hasChildren()) 210 continue; 211 212 // We also won't apply canonicalizations for nodes that are not 213 // isolated. This avoids potentially mutating the regions of nodes defined 214 // above, this is also a stipulation of the 'applyPatternsGreedily' driver. 215 auto *region = node->getCallableRegion(); 216 if (!region->getParentOp()->isKnownIsolatedFromAbove()) 217 continue; 218 nodesToCanonicalize.push_back(node); 219 } 220 if (nodesToCanonicalize.empty()) 221 return; 222 223 // Canonicalize each of the nodes within the SCC in parallel. 224 // NOTE: This is simple now, because we don't enable canonicalizing nodes 225 // within children. When we remove this restriction, this logic will need to 226 // be reworked. 227 ParallelDiagnosticHandler canonicalizationHandler(context); 228 llvm::parallel::for_each_n( 229 llvm::parallel::par, /*Begin=*/size_t(0), 230 /*End=*/nodesToCanonicalize.size(), [&](size_t index) { 231 // Set the order for this thread so that diagnostics will be properly 232 // ordered. 233 canonicalizationHandler.setOrderIDForThread(index); 234 235 // Apply the canonicalization patterns to this region. 236 auto *node = nodesToCanonicalize[index]; 237 applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); 238 239 // Make sure to reset the order ID for the diagnostic handler, as this 240 // thread may be used in a different context. 241 canonicalizationHandler.eraseOrderIDForThread(); 242 }); 243 } 244 245 /// Attempt to inline calls within the given scc, and run canonicalizations with 246 /// the given patterns, until a fixed point is reached. This allows for the 247 /// inlining of newly devirtualized calls. 248 static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC, 249 MLIRContext *context, 250 const OwningRewritePatternList &canonPatterns) { 251 // If we successfully inlined any calls, run some simplifications on the 252 // nodes of the scc. Continue attempting to inline until we reach a fixed 253 // point, or a maximum iteration count. We canonicalize here as it may 254 // devirtualize new calls, as well as give us a better cost model. 255 unsigned iterationCount = 0; 256 while (succeeded(inlineCallsInSCC(inliner, currentSCC))) { 257 // If we aren't allowing simplifications or the max iteration count was 258 // reached, then bail out early. 259 if (disableCanonicalization || ++iterationCount >= maxInliningIterations) 260 break; 261 canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns); 262 } 263 } 264 265 //===----------------------------------------------------------------------===// 266 // InlinerPass 267 //===----------------------------------------------------------------------===// 268 269 // TODO(riverriddle) This pass should currently only be used for basic testing 270 // of inlining functionality. 271 namespace { 272 struct InlinerPass : public OperationPass<InlinerPass> { 273 void runOnOperation() override { 274 CallGraph &cg = getAnalysis<CallGraph>(); 275 auto *context = &getContext(); 276 277 // Collect a set of canonicalization patterns to use when simplifying 278 // callable regions within an SCC. 279 OwningRewritePatternList canonPatterns; 280 for (auto *op : context->getRegisteredOperations()) 281 op->getCanonicalizationPatterns(canonPatterns, context); 282 283 // Run the inline transform in post-order over the SCCs in the callgraph. 284 Inliner inliner(context, cg); 285 runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) { 286 inlineSCC(inliner, scc, context, canonPatterns); 287 }); 288 } 289 }; 290 } // end anonymous namespace 291 292 std::unique_ptr<Pass> mlir::createInlinerPass() { 293 return std::make_unique<InlinerPass>(); 294 } 295 296 static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); 297