//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
22 // 23 //===----------------------------------------------------------------------===// 24 25 #include "mlir/Analysis/CallGraph.h" 26 #include "mlir/IR/Builders.h" 27 #include "mlir/IR/PatternMatch.h" 28 #include "mlir/Pass/Pass.h" 29 #include "mlir/Transforms/InliningUtils.h" 30 #include "mlir/Transforms/Passes.h" 31 #include "llvm/ADT/SCCIterator.h" 32 #include "llvm/Support/Parallel.h" 33 34 using namespace mlir; 35 36 static llvm::cl::opt<bool> disableCanonicalization( 37 "mlir-disable-inline-simplify", 38 llvm::cl::desc("Disable running simplifications during inlining"), 39 llvm::cl::ReallyHidden, llvm::cl::init(false)); 40 41 static llvm::cl::opt<unsigned> maxInliningIterations( 42 "mlir-max-inline-iterations", 43 llvm::cl::desc("Maximum number of iterations when inlining within an SCC"), 44 llvm::cl::ReallyHidden, llvm::cl::init(4)); 45 46 //===----------------------------------------------------------------------===// 47 // CallGraph traversal 48 //===----------------------------------------------------------------------===// 49 50 /// Run a given transformation over the SCCs of the callgraph in a bottom up 51 /// traversal. 52 static void runTransformOnCGSCCs( 53 const CallGraph &cg, 54 function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { 55 std::vector<CallGraphNode *> currentSCCVec; 56 auto cgi = llvm::scc_begin(&cg); 57 while (!cgi.isAtEnd()) { 58 // Copy the current SCC and increment so that the transformer can modify the 59 // SCC without invalidating our iterator. 60 currentSCCVec = *cgi; 61 ++cgi; 62 sccTransformer(currentSCCVec); 63 } 64 } 65 66 namespace { 67 /// This struct represents a resolved call to a given callgraph node. Given that 68 /// the call does not actually contain a direct reference to the 69 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them 70 /// explicitly. 
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *targetNode)
      : call(call), targetNode(targetNode) {}

  /// The call operation being dispatched.
  CallOpInterface call;

  /// The callgraph node (and thus callable region) the call resolves to.
  CallGraphNode *targetNode;
};
} // end anonymous namespace

/// Collect all of the callable operations within the given range of blocks. If
/// `traverseNestedCGNodes` is true, this will also collect call operations
/// inside of nested callgraph nodes.
static void collectCallOps(llvm::iterator_range<Region::iterator> blocks,
                           CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
                           bool traverseNestedCGNodes) {
  SmallVector<Block *, 8> worklist;
  auto addToWorklist = [&](llvm::iterator_range<Region::iterator> blocks) {
    for (Block &block : blocks)
      worklist.push_back(&block);
  };

  addToWorklist(blocks);
  while (!worklist.empty()) {
    for (Operation &op : *worklist.pop_back_val()) {
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        CallGraphNode *node =
            cg.resolveCallable(call.getCallableForCallee(), &op);
        // Only record calls whose target has a visible body; external nodes
        // have nothing to inline.
        if (!node->isExternal())
          calls.emplace_back(call, node);
        continue;
      }

      // If this is not a call, traverse the nested regions. If
      // `traverseNestedCGNodes` is false, then don't traverse nested call graph
      // regions.
      for (auto &nestedRegion : op.getRegions())
        if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
          addToWorklist(nestedRegion);
    }
  }
}

//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
  Inliner(MLIRContext *context, CallGraph &cg)
      : InlinerInterface(context), cg(cg) {}

  /// Process a set of blocks that have been inlined. This callback is invoked
  /// *before* inlined terminator operations have been processed.
  void processInlinedBlocks(
      llvm::iterator_range<Region::iterator> inlinedBlocks) final {
    // Newly inlined blocks may themselves contain calls (e.g. ones that only
    // became resolvable after inlining); queue them in `calls` so the driver
    // loop below can consider them as well.
    collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
  }

  /// The current set of call instructions to consider for inlining. This may
  /// grow while inlining is in progress (see processInlinedBlocks).
  SmallVector<ResolvedCall, 8> calls;

  /// The callgraph being operated on.
  CallGraph &cg;
};
} // namespace

/// Returns true if the given call should be inlined.
static bool shouldInline(ResolvedCall &resolvedCall) {
  // Don't allow inlining terminator calls. We currently don't support this
  // case.
  if (resolvedCall.call.getOperation()->isKnownTerminator())
    return false;

  // Don't allow inlining if the target is an ancestor of the call. This
  // prevents inlining recursively.
  if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
          resolvedCall.call.getParentRegion()))
    return false;

  // Otherwise, inline.
  return true;
}

/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner,
                                      ArrayRef<CallGraphNode *> currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled separately
  // likely within a different SCC.
  for (auto *node : currentSCC) {
    if (!node->isExternal())
      collectCallOps(*node->getCallableRegion(), cg, calls,
                     /*traverseNestedCGNodes=*/false);
  }
  if (calls.empty())
    return failure();

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining. Indexing (rather than
  // iterators) is used for the same reason: `calls` may reallocate.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    ResolvedCall &it = calls[i];
    if (!shouldInline(it))
      continue;

    // Copy the call and target region out of the vector element before
    // inlining: processInlinedBlocks may append to `calls`, which can
    // invalidate the reference `it`.
    CallOpInterface call = it.call;
    Region *targetRegion = it.targetNode->getCallableRegion();
    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion);
    if (failed(inlineResult))
      continue;

    // If the inlining was successful, then erase the call.
    call.erase();
    inlinedAnyCalls = true;
  }
  calls.clear();
  return success(inlinedAnyCalls);
}

/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC,
                            MLIRContext *context,
                            const OwningRewritePatternList &canonPatterns) {
  // Collect the sets of nodes to canonicalize.
  SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
  for (auto *node : currentSCC) {
    // Don't canonicalize the external node, it has no valid callable region.
    if (node->isExternal())
      continue;

    // Don't canonicalize nodes with children. Nodes with children
    // require special handling as we may remove the node during
    // canonicalization. In the future, we should be able to handle this
    // case with proper node deletion tracking.
    if (node->hasChildren())
      continue;

    // We also won't apply canonicalizations for nodes that are not
    // isolated. This avoids potentially mutating the regions of nodes defined
    // above, this is also a stipulation of the 'applyPatternsGreedily' driver.
    auto *region = node->getCallableRegion();
    if (!region->getParentOp()->isKnownIsolatedFromAbove())
      continue;
    nodesToCanonicalize.push_back(node);
  }
  if (nodesToCanonicalize.empty())
    return;

  // Canonicalize each of the nodes within the SCC in parallel.
  // NOTE: This is simple now, because we don't enable canonicalizing nodes
  // within children. When we remove this restriction, this logic will need to
  // be reworked.
  ParallelDiagnosticHandler canonicalizationHandler(context);
  llvm::parallel::for_each_n(
      llvm::parallel::par, /*Begin=*/size_t(0),
      /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
        // Set the order for this thread so that diagnostics will be properly
        // ordered, i.e. emitted as if the nodes were processed sequentially.
        canonicalizationHandler.setOrderIDForThread(index);

        // Apply the canonicalization patterns to this region.
        auto *node = nodesToCanonicalize[index];
        applyPatternsGreedily(*node->getCallableRegion(), canonPatterns);

        // Make sure to reset the order ID for the diagnostic handler, as this
        // thread may be used in a different context.
        canonicalizationHandler.eraseOrderIDForThread();
      });
}

/// Attempt to inline calls within the given scc, and run canonicalizations with
/// the given patterns, until a fixed point is reached. This allows for the
/// inlining of newly devirtualized calls.
static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC,
                      MLIRContext *context,
                      const OwningRewritePatternList &canonPatterns) {
  // If we successfully inlined any calls, run some simplifications on the
  // nodes of the scc. Continue attempting to inline until we reach a fixed
  // point, or a maximum iteration count. We canonicalize here as it may
  // devirtualize new calls, as well as give us a better cost model.
  unsigned iterationCount = 0;
  while (succeeded(inlineCallsInSCC(inliner, currentSCC))) {
    // If we aren't allowing simplifications or the max iteration count was
    // reached, then bail out early.
    if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
      break;
    canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns);
  }
}

//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//

// TODO(riverriddle) This pass should currently only be used for basic testing
// of inlining functionality.
namespace {
struct InlinerPass : public OperationPass<InlinerPass> {
  void runOnOperation() override {
    CallGraph &cg = getAnalysis<CallGraph>();
    auto *context = &getContext();

    // Collect a set of canonicalization patterns to use when simplifying
    // callable regions within an SCC.
    OwningRewritePatternList canonPatterns;
    for (auto *op : context->getRegisteredOperations())
      op->getCanonicalizationPatterns(canonPatterns, context);

    // Run the inline transform in post-order over the SCCs in the callgraph.
    Inliner inliner(context, cg);
    runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) {
      inlineSCC(inliner, scc, context, canonPatterns);
    });
  }
};
} // end anonymous namespace

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}

static PassRegistration<InlinerPass> pass("inline", "Inline function calls");