//===- Inliner.cpp - Pass to inline function calls ------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// CallGraph traversal
//===----------------------------------------------------------------------===//

/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
41 static void runTransformOnCGSCCs( 42 const CallGraph &cg, 43 function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { 44 for (auto cgi = llvm::scc_begin(&cg); !cgi.isAtEnd(); ++cgi) 45 sccTransformer(*cgi); 46 } 47 48 namespace { 49 /// This struct represents a resolved call to a given callgraph node. Given that 50 /// the call does not actually contain a direct reference to the 51 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them 52 /// explicitly. 53 struct ResolvedCall { 54 ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) 55 : call(call), targetNode(targetNode) {} 56 CallOpInterface call; 57 CallGraphNode *targetNode; 58 }; 59 } // end anonymous namespace 60 61 /// Collect all of the callable operations within the given range of blocks. If 62 /// `traverseNestedCGNodes` is true, this will also collect call operations 63 /// inside of nested callgraph nodes. 64 static void collectCallOps(llvm::iterator_range<Region::iterator> blocks, 65 CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, 66 bool traverseNestedCGNodes) { 67 SmallVector<Block *, 8> worklist; 68 auto addToWorklist = [&](llvm::iterator_range<Region::iterator> blocks) { 69 for (Block &block : blocks) 70 worklist.push_back(&block); 71 }; 72 73 addToWorklist(blocks); 74 while (!worklist.empty()) { 75 for (Operation &op : *worklist.pop_back_val()) { 76 if (auto call = dyn_cast<CallOpInterface>(op)) { 77 CallGraphNode *node = 78 cg.resolveCallable(call.getCallableForCallee(), &op); 79 if (!node->isExternal()) 80 calls.emplace_back(call, node); 81 continue; 82 } 83 84 // If this is not a call, traverse the nested regions. If 85 // `traverseNestedCGNodes` is false, then don't traverse nested call graph 86 // regions. 
87 for (auto &nestedRegion : op.getRegions()) 88 if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) 89 addToWorklist(nestedRegion); 90 } 91 } 92 } 93 94 //===----------------------------------------------------------------------===// 95 // Inliner 96 //===----------------------------------------------------------------------===// 97 namespace { 98 /// This class provides a specialization of the main inlining interface. 99 struct Inliner : public InlinerInterface { 100 Inliner(MLIRContext *context, CallGraph &cg) 101 : InlinerInterface(context), cg(cg) {} 102 103 /// Process a set of blocks that have been inlined. This callback is invoked 104 /// *before* inlined terminator operations have been processed. 105 void processInlinedBlocks( 106 llvm::iterator_range<Region::iterator> inlinedBlocks) final { 107 collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); 108 } 109 110 /// The current set of call instructions to consider for inlining. 111 SmallVector<ResolvedCall, 8> calls; 112 113 /// The callgraph being operated on. 114 CallGraph &cg; 115 }; 116 } // namespace 117 118 /// Returns true if the given call should be inlined. 119 static bool shouldInline(ResolvedCall &resolvedCall) { 120 // Don't allow inlining terminator calls. We currently don't support this 121 // case. 122 if (resolvedCall.call.getOperation()->isKnownTerminator()) 123 return false; 124 125 // Don't allow inlining if the target is an ancestor of the call. This 126 // prevents inlining recursively. 127 if (resolvedCall.targetNode->getCallableRegion()->isAncestor( 128 resolvedCall.call.getParentRegion())) 129 return false; 130 131 // Otherwise, inline. 132 return true; 133 } 134 135 /// Attempt to inline calls within the given scc. 136 static void inlineCallsInSCC(Inliner &inliner, 137 ArrayRef<CallGraphNode *> currentSCC) { 138 CallGraph &cg = inliner.cg; 139 auto &calls = inliner.calls; 140 141 // Collect all of the direct calls within the nodes of the current SCC. 
We 142 // don't traverse nested callgraph nodes, because they are handled separately 143 // likely within a different SCC. 144 for (auto *node : currentSCC) { 145 if (!node->isExternal()) 146 collectCallOps(*node->getCallableRegion(), cg, calls, 147 /*traverseNestedCGNodes=*/false); 148 } 149 if (calls.empty()) 150 return; 151 152 // Try to inline each of the call operations. Don't cache the end iterator 153 // here as more calls may be added during inlining. 154 for (unsigned i = 0; i != calls.size(); ++i) { 155 ResolvedCall &it = calls[i]; 156 if (!shouldInline(it)) 157 continue; 158 159 CallOpInterface call = it.call; 160 Region *targetRegion = it.targetNode->getCallableRegion(); 161 LogicalResult inlineResult = inlineCall( 162 inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()), 163 targetRegion); 164 if (failed(inlineResult)) 165 continue; 166 167 // If the inlining was successful, then erase the call. 168 call.erase(); 169 } 170 calls.clear(); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // InlinerPass 175 //===----------------------------------------------------------------------===// 176 177 // TODO(riverriddle) This pass should currently only be used for basic testing 178 // of inlining functionality. 179 namespace { 180 struct InlinerPass : public OperationPass<InlinerPass> { 181 void runOnOperation() override { 182 CallGraph &cg = getAnalysis<CallGraph>(); 183 Inliner inliner(&getContext(), cg); 184 185 // Run the inline transform in post-order over the SCCs in the callgraph. 186 runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) { 187 inlineCallsInSCC(inliner, scc); 188 }); 189 } 190 }; 191 } // end anonymous namespace 192 193 static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); 194