1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a basic inlining algorithm that operates bottom up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "mlir/Analysis/CallGraph.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "mlir/Transforms/Passes.h"
22 #include "llvm/ADT/SCCIterator.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/Parallel.h"
25 
26 #define DEBUG_TYPE "inlining"
27 
28 using namespace mlir;
29 
/// Hidden boolean command-line option, off by default. When it is set, the
/// per-SCC inlining loop in this file stops after its first successful round
/// of inlining and never runs the canonicalization step.
static llvm::cl::opt<bool> disableCanonicalization(
    "mlir-disable-inline-simplify",
    llvm::cl::desc("Disable running simplifications during inlining"),
    llvm::cl::ReallyHidden, llvm::cl::init(false));

/// Hidden unsigned command-line option, initialized to 4. It caps the number
/// of inline-then-canonicalize rounds attempted on a single SCC.
static llvm::cl::opt<unsigned> maxInliningIterations(
    "mlir-max-inline-iterations",
    llvm::cl::desc("Maximum number of iterations when inlining within an SCC"),
    llvm::cl::ReallyHidden, llvm::cl::init(4));
39 
40 //===----------------------------------------------------------------------===//
41 // CallGraph traversal
42 //===----------------------------------------------------------------------===//
43 
44 /// Run a given transformation over the SCCs of the callgraph in a bottom up
45 /// traversal.
46 static void runTransformOnCGSCCs(
47     const CallGraph &cg,
48     function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) {
49   std::vector<CallGraphNode *> currentSCCVec;
50   auto cgi = llvm::scc_begin(&cg);
51   while (!cgi.isAtEnd()) {
52     // Copy the current SCC and increment so that the transformer can modify the
53     // SCC without invalidating our iterator.
54     currentSCCVec = *cgi;
55     ++cgi;
56     sccTransformer(currentSCCVec);
57   }
58 }
59 
60 namespace {
61 /// This struct represents a resolved call to a given callgraph node. Given that
62 /// the call does not actually contain a direct reference to the
63 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
64 /// explicitly.
65 struct ResolvedCall {
66   ResolvedCall(CallOpInterface call, CallGraphNode *targetNode)
67       : call(call), targetNode(targetNode) {}
68   CallOpInterface call;
69   CallGraphNode *targetNode;
70 };
71 } // end anonymous namespace
72 
73 /// Collect all of the callable operations within the given range of blocks. If
74 /// `traverseNestedCGNodes` is true, this will also collect call operations
75 /// inside of nested callgraph nodes.
76 static void collectCallOps(iterator_range<Region::iterator> blocks,
77                            CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
78                            bool traverseNestedCGNodes) {
79   SmallVector<Block *, 8> worklist;
80   auto addToWorklist = [&](iterator_range<Region::iterator> blocks) {
81     for (Block &block : blocks)
82       worklist.push_back(&block);
83   };
84 
85   addToWorklist(blocks);
86   while (!worklist.empty()) {
87     for (Operation &op : *worklist.pop_back_val()) {
88       if (auto call = dyn_cast<CallOpInterface>(op)) {
89         CallInterfaceCallable callable = call.getCallableForCallee();
90 
91         // TODO(riverriddle) Support inlining nested call references.
92         if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
93           if (!symRef.isa<FlatSymbolRefAttr>())
94             continue;
95         }
96 
97         CallGraphNode *node = cg.resolveCallable(callable, &op);
98         if (!node->isExternal())
99           calls.emplace_back(call, node);
100         continue;
101       }
102 
103       // If this is not a call, traverse the nested regions. If
104       // `traverseNestedCGNodes` is false, then don't traverse nested call graph
105       // regions.
106       for (auto &nestedRegion : op.getRegions())
107         if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
108           addToWorklist(nestedRegion);
109     }
110   }
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // Inliner
115 //===----------------------------------------------------------------------===//
116 namespace {
117 /// This class provides a specialization of the main inlining interface.
118 struct Inliner : public InlinerInterface {
119   Inliner(MLIRContext *context, CallGraph &cg)
120       : InlinerInterface(context), cg(cg) {}
121 
122   /// Process a set of blocks that have been inlined. This callback is invoked
123   /// *before* inlined terminator operations have been processed.
124   void
125   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
126     collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
127   }
128 
129   /// The current set of call instructions to consider for inlining.
130   SmallVector<ResolvedCall, 8> calls;
131 
132   /// The callgraph being operated on.
133   CallGraph &cg;
134 };
135 } // namespace
136 
137 /// Returns true if the given call should be inlined.
138 static bool shouldInline(ResolvedCall &resolvedCall) {
139   // Don't allow inlining terminator calls. We currently don't support this
140   // case.
141   if (resolvedCall.call.getOperation()->isKnownTerminator())
142     return false;
143 
144   // Don't allow inlining if the target is an ancestor of the call. This
145   // prevents inlining recursively.
146   if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
147           resolvedCall.call.getParentRegion()))
148     return false;
149 
150   // Otherwise, inline.
151   return true;
152 }
153 
154 /// Attempt to inline calls within the given scc. This function returns
155 /// success if any calls were inlined, failure otherwise.
156 static LogicalResult inlineCallsInSCC(Inliner &inliner,
157                                       ArrayRef<CallGraphNode *> currentSCC) {
158   CallGraph &cg = inliner.cg;
159   auto &calls = inliner.calls;
160 
161   // Collect all of the direct calls within the nodes of the current SCC. We
162   // don't traverse nested callgraph nodes, because they are handled separately
163   // likely within a different SCC.
164   for (auto *node : currentSCC) {
165     if (!node->isExternal())
166       collectCallOps(*node->getCallableRegion(), cg, calls,
167                      /*traverseNestedCGNodes=*/false);
168   }
169   if (calls.empty())
170     return failure();
171 
172   // Try to inline each of the call operations. Don't cache the end iterator
173   // here as more calls may be added during inlining.
174   bool inlinedAnyCalls = false;
175   for (unsigned i = 0; i != calls.size(); ++i) {
176     ResolvedCall &it = calls[i];
177     LLVM_DEBUG({
178       llvm::dbgs() << "* Considering inlining call: ";
179       it.call.dump();
180     });
181     if (!shouldInline(it))
182       continue;
183 
184     CallOpInterface call = it.call;
185     Region *targetRegion = it.targetNode->getCallableRegion();
186     LogicalResult inlineResult = inlineCall(
187         inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
188         targetRegion);
189     if (failed(inlineResult))
190       continue;
191 
192     // If the inlining was successful, then erase the call.
193     call.erase();
194     inlinedAnyCalls = true;
195   }
196   calls.clear();
197   return success(inlinedAnyCalls);
198 }
199 
200 /// Canonicalize the nodes within the given SCC with the given set of
201 /// canonicalization patterns.
202 static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC,
203                             MLIRContext *context,
204                             const OwningRewritePatternList &canonPatterns) {
205   // Collect the sets of nodes to canonicalize.
206   SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
207   for (auto *node : currentSCC) {
208     // Don't canonicalize the external node, it has no valid callable region.
209     if (node->isExternal())
210       continue;
211 
212     // Don't canonicalize nodes with children. Nodes with children
213     // require special handling as we may remove the node during
214     // canonicalization. In the future, we should be able to handle this
215     // case with proper node deletion tracking.
216     if (node->hasChildren())
217       continue;
218 
219     // We also won't apply canonicalizations for nodes that are not
220     // isolated. This avoids potentially mutating the regions of nodes defined
221     // above, this is also a stipulation of the 'applyPatternsGreedily' driver.
222     auto *region = node->getCallableRegion();
223     if (!region->getParentOp()->isKnownIsolatedFromAbove())
224       continue;
225     nodesToCanonicalize.push_back(node);
226   }
227   if (nodesToCanonicalize.empty())
228     return;
229 
230   // Canonicalize each of the nodes within the SCC in parallel.
231   // NOTE: This is simple now, because we don't enable canonicalizing nodes
232   // within children. When we remove this restriction, this logic will need to
233   // be reworked.
234   ParallelDiagnosticHandler canonicalizationHandler(context);
235   llvm::parallel::for_each_n(
236       llvm::parallel::par, /*Begin=*/size_t(0),
237       /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
238         // Set the order for this thread so that diagnostics will be properly
239         // ordered.
240         canonicalizationHandler.setOrderIDForThread(index);
241 
242         // Apply the canonicalization patterns to this region.
243         auto *node = nodesToCanonicalize[index];
244         applyPatternsGreedily(*node->getCallableRegion(), canonPatterns);
245 
246         // Make sure to reset the order ID for the diagnostic handler, as this
247         // thread may be used in a different context.
248         canonicalizationHandler.eraseOrderIDForThread();
249       });
250 }
251 
252 /// Attempt to inline calls within the given scc, and run canonicalizations with
253 /// the given patterns, until a fixed point is reached. This allows for the
254 /// inlining of newly devirtualized calls.
255 static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC,
256                       MLIRContext *context,
257                       const OwningRewritePatternList &canonPatterns) {
258   // If we successfully inlined any calls, run some simplifications on the
259   // nodes of the scc. Continue attempting to inline until we reach a fixed
260   // point, or a maximum iteration count. We canonicalize here as it may
261   // devirtualize new calls, as well as give us a better cost model.
262   unsigned iterationCount = 0;
263   while (succeeded(inlineCallsInSCC(inliner, currentSCC))) {
264     // If we aren't allowing simplifications or the max iteration count was
265     // reached, then bail out early.
266     if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
267       break;
268     canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns);
269   }
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // InlinerPass
274 //===----------------------------------------------------------------------===//
275 
276 // TODO(riverriddle) This pass should currently only be used for basic testing
277 // of inlining functionality.
278 namespace {
279 struct InlinerPass : public OperationPass<InlinerPass> {
280   void runOnOperation() override {
281     CallGraph &cg = getAnalysis<CallGraph>();
282     auto *context = &getContext();
283 
284     // The inliner should only be run on operations that define a symbol table,
285     // as the callgraph will need to resolve references.
286     Operation *op = getOperation();
287     if (!op->hasTrait<OpTrait::SymbolTable>()) {
288       op->emitOpError() << " was scheduled to run under the inliner, but does "
289                            "not define a symbol table";
290       return signalPassFailure();
291     }
292 
293     // Collect a set of canonicalization patterns to use when simplifying
294     // callable regions within an SCC.
295     OwningRewritePatternList canonPatterns;
296     for (auto *op : context->getRegisteredOperations())
297       op->getCanonicalizationPatterns(canonPatterns, context);
298 
299     // Run the inline transform in post-order over the SCCs in the callgraph.
300     Inliner inliner(context, cg);
301     runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) {
302       inlineSCC(inliner, scc, context, canonPatterns);
303     });
304   }
305 };
306 } // end anonymous namespace
307 
308 std::unique_ptr<Pass> mlir::createInlinerPass() {
309   return std::make_unique<InlinerPass>();
310 }
311 
/// Static registration object for InlinerPass, constructed with the strings
/// "inline" and "Inline function calls".
// NOTE(review): presumably the first string is the pass's command-line
// argument and the second its description -- confirm against PassRegistration.
static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
313