1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
22 //
23 //===----------------------------------------------------------------------===//
24 
25 #include "mlir/Analysis/CallGraph.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Transforms/InliningUtils.h"
30 #include "mlir/Transforms/Passes.h"
31 #include "llvm/ADT/SCCIterator.h"
32 #include "llvm/Support/Parallel.h"
33 
34 using namespace mlir;
35 
36 static llvm::cl::opt<bool> disableCanonicalization(
37     "mlir-disable-inline-simplify",
38     llvm::cl::desc("Disable running simplifications during inlining"),
39     llvm::cl::ReallyHidden, llvm::cl::init(false));
40 
41 static llvm::cl::opt<unsigned> maxInliningIterations(
42     "mlir-max-inline-iterations",
43     llvm::cl::desc("Maximum number of iterations when inlining within an SCC"),
44     llvm::cl::ReallyHidden, llvm::cl::init(4));
45 
46 //===----------------------------------------------------------------------===//
47 // CallGraph traversal
48 //===----------------------------------------------------------------------===//
49 
50 /// Run a given transformation over the SCCs of the callgraph in a bottom up
51 /// traversal.
52 static void runTransformOnCGSCCs(
53     const CallGraph &cg,
54     function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) {
55   std::vector<CallGraphNode *> currentSCCVec;
56   auto cgi = llvm::scc_begin(&cg);
57   while (!cgi.isAtEnd()) {
58     // Copy the current SCC and increment so that the transformer can modify the
59     // SCC without invalidating our iterator.
60     currentSCCVec = *cgi;
61     ++cgi;
62     sccTransformer(currentSCCVec);
63   }
64 }
65 
66 namespace {
67 /// This struct represents a resolved call to a given callgraph node. Given that
68 /// the call does not actually contain a direct reference to the
69 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
70 /// explicitly.
71 struct ResolvedCall {
72   ResolvedCall(CallOpInterface call, CallGraphNode *targetNode)
73       : call(call), targetNode(targetNode) {}
74   CallOpInterface call;
75   CallGraphNode *targetNode;
76 };
77 } // end anonymous namespace
78 
79 /// Collect all of the callable operations within the given range of blocks. If
80 /// `traverseNestedCGNodes` is true, this will also collect call operations
81 /// inside of nested callgraph nodes.
82 static void collectCallOps(llvm::iterator_range<Region::iterator> blocks,
83                            CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
84                            bool traverseNestedCGNodes) {
85   SmallVector<Block *, 8> worklist;
86   auto addToWorklist = [&](llvm::iterator_range<Region::iterator> blocks) {
87     for (Block &block : blocks)
88       worklist.push_back(&block);
89   };
90 
91   addToWorklist(blocks);
92   while (!worklist.empty()) {
93     for (Operation &op : *worklist.pop_back_val()) {
94       if (auto call = dyn_cast<CallOpInterface>(op)) {
95         CallGraphNode *node =
96             cg.resolveCallable(call.getCallableForCallee(), &op);
97         if (!node->isExternal())
98           calls.emplace_back(call, node);
99         continue;
100       }
101 
102       // If this is not a call, traverse the nested regions. If
103       // `traverseNestedCGNodes` is false, then don't traverse nested call graph
104       // regions.
105       for (auto &nestedRegion : op.getRegions())
106         if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
107           addToWorklist(nestedRegion);
108     }
109   }
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Inliner
114 //===----------------------------------------------------------------------===//
115 namespace {
116 /// This class provides a specialization of the main inlining interface.
117 struct Inliner : public InlinerInterface {
118   Inliner(MLIRContext *context, CallGraph &cg)
119       : InlinerInterface(context), cg(cg) {}
120 
121   /// Process a set of blocks that have been inlined. This callback is invoked
122   /// *before* inlined terminator operations have been processed.
123   void processInlinedBlocks(
124       llvm::iterator_range<Region::iterator> inlinedBlocks) final {
125     collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
126   }
127 
128   /// The current set of call instructions to consider for inlining.
129   SmallVector<ResolvedCall, 8> calls;
130 
131   /// The callgraph being operated on.
132   CallGraph &cg;
133 };
134 } // namespace
135 
136 /// Returns true if the given call should be inlined.
137 static bool shouldInline(ResolvedCall &resolvedCall) {
138   // Don't allow inlining terminator calls. We currently don't support this
139   // case.
140   if (resolvedCall.call.getOperation()->isKnownTerminator())
141     return false;
142 
143   // Don't allow inlining if the target is an ancestor of the call. This
144   // prevents inlining recursively.
145   if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
146           resolvedCall.call.getParentRegion()))
147     return false;
148 
149   // Otherwise, inline.
150   return true;
151 }
152 
153 /// Attempt to inline calls within the given scc. This function returns
154 /// success if any calls were inlined, failure otherwise.
155 static LogicalResult inlineCallsInSCC(Inliner &inliner,
156                                       ArrayRef<CallGraphNode *> currentSCC) {
157   CallGraph &cg = inliner.cg;
158   auto &calls = inliner.calls;
159 
160   // Collect all of the direct calls within the nodes of the current SCC. We
161   // don't traverse nested callgraph nodes, because they are handled separately
162   // likely within a different SCC.
163   for (auto *node : currentSCC) {
164     if (!node->isExternal())
165       collectCallOps(*node->getCallableRegion(), cg, calls,
166                      /*traverseNestedCGNodes=*/false);
167   }
168   if (calls.empty())
169     return failure();
170 
171   // Try to inline each of the call operations. Don't cache the end iterator
172   // here as more calls may be added during inlining.
173   bool inlinedAnyCalls = false;
174   for (unsigned i = 0; i != calls.size(); ++i) {
175     ResolvedCall &it = calls[i];
176     if (!shouldInline(it))
177       continue;
178 
179     CallOpInterface call = it.call;
180     Region *targetRegion = it.targetNode->getCallableRegion();
181     LogicalResult inlineResult = inlineCall(
182         inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
183         targetRegion);
184     if (failed(inlineResult))
185       continue;
186 
187     // If the inlining was successful, then erase the call.
188     call.erase();
189     inlinedAnyCalls = true;
190   }
191   calls.clear();
192   return success(inlinedAnyCalls);
193 }
194 
195 /// Canonicalize the nodes within the given SCC with the given set of
196 /// canonicalization patterns.
197 static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC,
198                             MLIRContext *context,
199                             const OwningRewritePatternList &canonPatterns) {
200   // Collect the sets of nodes to canonicalize.
201   SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
202   for (auto *node : currentSCC) {
203     // Don't canonicalize the external node, it has no valid callable region.
204     if (node->isExternal())
205       continue;
206 
207     // Don't canonicalize nodes with children. Nodes with children
208     // require special handling as we may remove the node during
209     // canonicalization. In the future, we should be able to handle this
210     // case with proper node deletion tracking.
211     if (node->hasChildren())
212       continue;
213 
214     // We also won't apply canonicalizations for nodes that are not
215     // isolated. This avoids potentially mutating the regions of nodes defined
216     // above, this is also a stipulation of the 'applyPatternsGreedily' driver.
217     auto *region = node->getCallableRegion();
218     if (!region->getParentOp()->isKnownIsolatedFromAbove())
219       continue;
220     nodesToCanonicalize.push_back(node);
221   }
222   if (nodesToCanonicalize.empty())
223     return;
224 
225   // Canonicalize each of the nodes within the SCC in parallel.
226   // NOTE: This is simple now, because we don't enable canonicalizing nodes
227   // within children. When we remove this restriction, this logic will need to
228   // be reworked.
229   ParallelDiagnosticHandler canonicalizationHandler(context);
230   llvm::parallel::for_each_n(
231       llvm::parallel::par, /*Begin=*/size_t(0),
232       /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
233         // Set the order for this thread so that diagnostics will be properly
234         // ordered.
235         canonicalizationHandler.setOrderIDForThread(index);
236 
237         // Apply the canonicalization patterns to this region.
238         auto *node = nodesToCanonicalize[index];
239         applyPatternsGreedily(*node->getCallableRegion(), canonPatterns);
240 
241         // Make sure to reset the order ID for the diagnostic handler, as this
242         // thread may be used in a different context.
243         canonicalizationHandler.eraseOrderIDForThread();
244       });
245 }
246 
247 /// Attempt to inline calls within the given scc, and run canonicalizations with
248 /// the given patterns, until a fixed point is reached. This allows for the
249 /// inlining of newly devirtualized calls.
250 static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC,
251                       MLIRContext *context,
252                       const OwningRewritePatternList &canonPatterns) {
253   // If we successfully inlined any calls, run some simplifications on the
254   // nodes of the scc. Continue attempting to inline until we reach a fixed
255   // point, or a maximum iteration count. We canonicalize here as it may
256   // devirtualize new calls, as well as give us a better cost model.
257   unsigned iterationCount = 0;
258   while (succeeded(inlineCallsInSCC(inliner, currentSCC))) {
259     // If we aren't allowing simplifications or the max iteration count was
260     // reached, then bail out early.
261     if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
262       break;
263     canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns);
264   }
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // InlinerPass
269 //===----------------------------------------------------------------------===//
270 
271 // TODO(riverriddle) This pass should currently only be used for basic testing
272 // of inlining functionality.
273 namespace {
274 struct InlinerPass : public OperationPass<InlinerPass> {
275   void runOnOperation() override {
276     CallGraph &cg = getAnalysis<CallGraph>();
277     auto *context = &getContext();
278 
279     // Collect a set of canonicalization patterns to use when simplifying
280     // callable regions within an SCC.
281     OwningRewritePatternList canonPatterns;
282     for (auto *op : context->getRegisteredOperations())
283       op->getCanonicalizationPatterns(canonPatterns, context);
284 
285     // Run the inline transform in post-order over the SCCs in the callgraph.
286     Inliner inliner(context, cg);
287     runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) {
288       inlineSCC(inliner, scc, context, canonPatterns);
289     });
290   }
291 };
292 } // end anonymous namespace
293 
294 std::unique_ptr<Pass> mlir::createInlinerPass() {
295   return std::make_unique<InlinerPass>();
296 }
297 
298 static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
299