1 //===- Inliner.cpp - Pass to inline function calls ------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a basic inlining algorithm that operates bottom up over 10 // the Strongly Connect Components(SCCs) of the CallGraph. This enables a more 11 // incremental propagation of inlining decisions from the leafs to the roots of 12 // the callgraph. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "mlir/Analysis/CallGraph.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/Pass/Pass.h" 20 #include "mlir/Transforms/InliningUtils.h" 21 #include "mlir/Transforms/Passes.h" 22 #include "llvm/ADT/SCCIterator.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Support/Parallel.h" 25 26 #define DEBUG_TYPE "inlining" 27 28 using namespace mlir; 29 30 static llvm::cl::opt<bool> disableCanonicalization( 31 "mlir-disable-inline-simplify", 32 llvm::cl::desc("Disable running simplifications during inlining"), 33 llvm::cl::ReallyHidden, llvm::cl::init(false)); 34 35 static llvm::cl::opt<unsigned> maxInliningIterations( 36 "mlir-max-inline-iterations", 37 llvm::cl::desc("Maximum number of iterations when inlining within an SCC"), 38 llvm::cl::ReallyHidden, llvm::cl::init(4)); 39 40 //===----------------------------------------------------------------------===// 41 // CallGraph traversal 42 //===----------------------------------------------------------------------===// 43 44 /// Run a given transformation over the SCCs of the callgraph in a bottom up 45 /// traversal. 46 static void runTransformOnCGSCCs( 47 const CallGraph &cg, 48 function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { 49 std::vector<CallGraphNode *> currentSCCVec; 50 auto cgi = llvm::scc_begin(&cg); 51 while (!cgi.isAtEnd()) { 52 // Copy the current SCC and increment so that the transformer can modify the 53 // SCC without invalidating our iterator. 54 currentSCCVec = *cgi; 55 ++cgi; 56 sccTransformer(currentSCCVec); 57 } 58 } 59 60 namespace { 61 /// This struct represents a resolved call to a given callgraph node. Given that 62 /// the call does not actually contain a direct reference to the 63 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them 64 /// explicitly. 65 struct ResolvedCall { 66 ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) 67 : call(call), targetNode(targetNode) {} 68 CallOpInterface call; 69 CallGraphNode *targetNode; 70 }; 71 } // end anonymous namespace 72 73 /// Collect all of the callable operations within the given range of blocks. If 74 /// `traverseNestedCGNodes` is true, this will also collect call operations 75 /// inside of nested callgraph nodes. 76 static void collectCallOps(iterator_range<Region::iterator> blocks, 77 CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, 78 bool traverseNestedCGNodes) { 79 SmallVector<Block *, 8> worklist; 80 auto addToWorklist = [&](iterator_range<Region::iterator> blocks) { 81 for (Block &block : blocks) 82 worklist.push_back(&block); 83 }; 84 85 addToWorklist(blocks); 86 while (!worklist.empty()) { 87 for (Operation &op : *worklist.pop_back_val()) { 88 if (auto call = dyn_cast<CallOpInterface>(op)) { 89 CallInterfaceCallable callable = call.getCallableForCallee(); 90 91 // TODO(riverriddle) Support inlining nested call references. 92 if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { 93 if (!symRef.isa<FlatSymbolRefAttr>()) 94 continue; 95 } 96 97 CallGraphNode *node = cg.resolveCallable(callable, &op); 98 if (!node->isExternal()) 99 calls.emplace_back(call, node); 100 continue; 101 } 102 103 // If this is not a call, traverse the nested regions. If 104 // `traverseNestedCGNodes` is false, then don't traverse nested call graph 105 // regions. 106 for (auto &nestedRegion : op.getRegions()) 107 if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) 108 addToWorklist(nestedRegion); 109 } 110 } 111 } 112 113 //===----------------------------------------------------------------------===// 114 // Inliner 115 //===----------------------------------------------------------------------===// 116 namespace { 117 /// This class provides a specialization of the main inlining interface. 118 struct Inliner : public InlinerInterface { 119 Inliner(MLIRContext *context, CallGraph &cg) 120 : InlinerInterface(context), cg(cg) {} 121 122 /// Process a set of blocks that have been inlined. This callback is invoked 123 /// *before* inlined terminator operations have been processed. 124 void 125 processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { 126 collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); 127 } 128 129 /// The current set of call instructions to consider for inlining. 130 SmallVector<ResolvedCall, 8> calls; 131 132 /// The callgraph being operated on. 133 CallGraph &cg; 134 }; 135 } // namespace 136 137 /// Returns true if the given call should be inlined. 138 static bool shouldInline(ResolvedCall &resolvedCall) { 139 // Don't allow inlining terminator calls. We currently don't support this 140 // case. 141 if (resolvedCall.call.getOperation()->isKnownTerminator()) 142 return false; 143 144 // Don't allow inlining if the target is an ancestor of the call. This 145 // prevents inlining recursively. 146 if (resolvedCall.targetNode->getCallableRegion()->isAncestor( 147 resolvedCall.call.getParentRegion())) 148 return false; 149 150 // Otherwise, inline. 151 return true; 152 } 153 154 /// Attempt to inline calls within the given scc. This function returns 155 /// success if any calls were inlined, failure otherwise. 156 static LogicalResult inlineCallsInSCC(Inliner &inliner, 157 ArrayRef<CallGraphNode *> currentSCC) { 158 CallGraph &cg = inliner.cg; 159 auto &calls = inliner.calls; 160 161 // Collect all of the direct calls within the nodes of the current SCC. We 162 // don't traverse nested callgraph nodes, because they are handled separately 163 // likely within a different SCC. 164 for (auto *node : currentSCC) { 165 if (!node->isExternal()) 166 collectCallOps(*node->getCallableRegion(), cg, calls, 167 /*traverseNestedCGNodes=*/false); 168 } 169 if (calls.empty()) 170 return failure(); 171 172 // Try to inline each of the call operations. Don't cache the end iterator 173 // here as more calls may be added during inlining. 174 bool inlinedAnyCalls = false; 175 for (unsigned i = 0; i != calls.size(); ++i) { 176 ResolvedCall &it = calls[i]; 177 LLVM_DEBUG({ 178 llvm::dbgs() << "* Considering inlining call: "; 179 it.call.dump(); 180 }); 181 if (!shouldInline(it)) 182 continue; 183 184 CallOpInterface call = it.call; 185 Region *targetRegion = it.targetNode->getCallableRegion(); 186 LogicalResult inlineResult = inlineCall( 187 inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()), 188 targetRegion); 189 if (failed(inlineResult)) 190 continue; 191 192 // If the inlining was successful, then erase the call. 193 call.erase(); 194 inlinedAnyCalls = true; 195 } 196 calls.clear(); 197 return success(inlinedAnyCalls); 198 } 199 200 /// Canonicalize the nodes within the given SCC with the given set of 201 /// canonicalization patterns. 202 static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC, 203 MLIRContext *context, 204 const OwningRewritePatternList &canonPatterns) { 205 // Collect the sets of nodes to canonicalize. 206 SmallVector<CallGraphNode *, 4> nodesToCanonicalize; 207 for (auto *node : currentSCC) { 208 // Don't canonicalize the external node, it has no valid callable region. 209 if (node->isExternal()) 210 continue; 211 212 // Don't canonicalize nodes with children. Nodes with children 213 // require special handling as we may remove the node during 214 // canonicalization. In the future, we should be able to handle this 215 // case with proper node deletion tracking. 216 if (node->hasChildren()) 217 continue; 218 219 // We also won't apply canonicalizations for nodes that are not 220 // isolated. This avoids potentially mutating the regions of nodes defined 221 // above, this is also a stipulation of the 'applyPatternsGreedily' driver. 222 auto *region = node->getCallableRegion(); 223 if (!region->getParentOp()->isKnownIsolatedFromAbove()) 224 continue; 225 nodesToCanonicalize.push_back(node); 226 } 227 if (nodesToCanonicalize.empty()) 228 return; 229 230 // Canonicalize each of the nodes within the SCC in parallel. 231 // NOTE: This is simple now, because we don't enable canonicalizing nodes 232 // within children. When we remove this restriction, this logic will need to 233 // be reworked. 234 ParallelDiagnosticHandler canonicalizationHandler(context); 235 llvm::parallel::for_each_n( 236 llvm::parallel::par, /*Begin=*/size_t(0), 237 /*End=*/nodesToCanonicalize.size(), [&](size_t index) { 238 // Set the order for this thread so that diagnostics will be properly 239 // ordered. 240 canonicalizationHandler.setOrderIDForThread(index); 241 242 // Apply the canonicalization patterns to this region. 243 auto *node = nodesToCanonicalize[index]; 244 applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); 245 246 // Make sure to reset the order ID for the diagnostic handler, as this 247 // thread may be used in a different context. 248 canonicalizationHandler.eraseOrderIDForThread(); 249 }); 250 } 251 252 /// Attempt to inline calls within the given scc, and run canonicalizations with 253 /// the given patterns, until a fixed point is reached. This allows for the 254 /// inlining of newly devirtualized calls. 255 static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC, 256 MLIRContext *context, 257 const OwningRewritePatternList &canonPatterns) { 258 // If we successfully inlined any calls, run some simplifications on the 259 // nodes of the scc. Continue attempting to inline until we reach a fixed 260 // point, or a maximum iteration count. We canonicalize here as it may 261 // devirtualize new calls, as well as give us a better cost model. 262 unsigned iterationCount = 0; 263 while (succeeded(inlineCallsInSCC(inliner, currentSCC))) { 264 // If we aren't allowing simplifications or the max iteration count was 265 // reached, then bail out early. 266 if (disableCanonicalization || ++iterationCount >= maxInliningIterations) 267 break; 268 canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns); 269 } 270 } 271 272 //===----------------------------------------------------------------------===// 273 // InlinerPass 274 //===----------------------------------------------------------------------===// 275 276 // TODO(riverriddle) This pass should currently only be used for basic testing 277 // of inlining functionality. 278 namespace { 279 struct InlinerPass : public OperationPass<InlinerPass> { 280 void runOnOperation() override { 281 CallGraph &cg = getAnalysis<CallGraph>(); 282 auto *context = &getContext(); 283 284 // The inliner should only be run on operations that define a symbol table, 285 // as the callgraph will need to resolve references. 286 Operation *op = getOperation(); 287 if (!op->hasTrait<OpTrait::SymbolTable>()) { 288 op->emitOpError() << " was scheduled to run under the inliner, but does " 289 "not define a symbol table"; 290 return signalPassFailure(); 291 } 292 293 // Collect a set of canonicalization patterns to use when simplifying 294 // callable regions within an SCC. 295 OwningRewritePatternList canonPatterns; 296 for (auto *op : context->getRegisteredOperations()) 297 op->getCanonicalizationPatterns(canonPatterns, context); 298 299 // Run the inline transform in post-order over the SCCs in the callgraph. 300 Inliner inliner(context, cg); 301 runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) { 302 inlineSCC(inliner, scc, context, canonPatterns); 303 }); 304 } 305 }; 306 } // end anonymous namespace 307 308 std::unique_ptr<Pass> mlir::createInlinerPass() { 309 return std::make_unique<InlinerPass>(); 310 } 311 312 static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); 313