1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
// This file implements a basic inlining algorithm that operates bottom-up over
// the Strongly Connected Components (SCCs) of the CallGraph. This enables a
// more incremental propagation of inlining decisions from the leaves to the
// roots of the callgraph.
22 //
23 //===----------------------------------------------------------------------===//
24 
25 #include "mlir/Analysis/CallGraph.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/Module.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Transforms/InliningUtils.h"
30 #include "mlir/Transforms/Passes.h"
31 #include "llvm/ADT/SCCIterator.h"
32 
33 using namespace mlir;
34 
35 //===----------------------------------------------------------------------===//
36 // CallGraph traversal
37 //===----------------------------------------------------------------------===//
38 
39 /// Run a given transformation over the SCCs of the callgraph in a bottom up
40 /// traversal.
41 static void runTransformOnCGSCCs(
42     const CallGraph &cg,
43     function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) {
44   for (auto cgi = llvm::scc_begin(&cg); !cgi.isAtEnd(); ++cgi)
45     sccTransformer(*cgi);
46 }
47 
48 namespace {
49 /// This struct represents a resolved call to a given callgraph node. Given that
50 /// the call does not actually contain a direct reference to the
51 /// Region(CallGraphNode) that it is dispatching to, we need to resolve them
52 /// explicitly.
53 struct ResolvedCall {
54   ResolvedCall(CallOpInterface call, CallGraphNode *targetNode)
55       : call(call), targetNode(targetNode) {}
56   CallOpInterface call;
57   CallGraphNode *targetNode;
58 };
59 } // end anonymous namespace
60 
61 /// Collect all of the callable operations within the given range of blocks. If
62 /// `traverseNestedCGNodes` is true, this will also collect call operations
63 /// inside of nested callgraph nodes.
64 static void collectCallOps(llvm::iterator_range<Region::iterator> blocks,
65                            CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls,
66                            bool traverseNestedCGNodes) {
67   SmallVector<Block *, 8> worklist;
68   auto addToWorklist = [&](llvm::iterator_range<Region::iterator> blocks) {
69     for (Block &block : blocks)
70       worklist.push_back(&block);
71   };
72 
73   addToWorklist(blocks);
74   while (!worklist.empty()) {
75     for (Operation &op : *worklist.pop_back_val()) {
76       if (auto call = dyn_cast<CallOpInterface>(op)) {
77         CallGraphNode *node =
78             cg.resolveCallable(call.getCallableForCallee(), &op);
79         if (!node->isExternal())
80           calls.emplace_back(call, node);
81         continue;
82       }
83 
84       // If this is not a call, traverse the nested regions. If
85       // `traverseNestedCGNodes` is false, then don't traverse nested call graph
86       // regions.
87       for (auto &nestedRegion : op.getRegions())
88         if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion))
89           addToWorklist(nestedRegion);
90     }
91   }
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Inliner
96 //===----------------------------------------------------------------------===//
97 namespace {
98 /// This class provides a specialization of the main inlining interface.
99 struct Inliner : public InlinerInterface {
100   Inliner(MLIRContext *context, CallGraph &cg)
101       : InlinerInterface(context), cg(cg) {}
102 
103   /// Process a set of blocks that have been inlined. This callback is invoked
104   /// *before* inlined terminator operations have been processed.
105   void processInlinedBlocks(
106       llvm::iterator_range<Region::iterator> inlinedBlocks) final {
107     collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true);
108   }
109 
110   /// The current set of call instructions to consider for inlining.
111   SmallVector<ResolvedCall, 8> calls;
112 
113   /// The callgraph being operated on.
114   CallGraph &cg;
115 };
116 } // namespace
117 
118 /// Returns true if the given call should be inlined.
119 static bool shouldInline(ResolvedCall &resolvedCall) {
120   // Don't allow inlining terminator calls. We currently don't support this
121   // case.
122   if (resolvedCall.call.getOperation()->isKnownTerminator())
123     return false;
124 
125   // Don't allow inlining if the target is an ancestor of the call. This
126   // prevents inlining recursively.
127   if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
128           resolvedCall.call.getParentRegion()))
129     return false;
130 
131   // Otherwise, inline.
132   return true;
133 }
134 
135 /// Attempt to inline calls within the given scc.
136 static void inlineCallsInSCC(Inliner &inliner,
137                              ArrayRef<CallGraphNode *> currentSCC) {
138   CallGraph &cg = inliner.cg;
139   auto &calls = inliner.calls;
140 
141   // Collect all of the direct calls within the nodes of the current SCC. We
142   // don't traverse nested callgraph nodes, because they are handled separately
143   // likely within a different SCC.
144   for (auto *node : currentSCC) {
145     if (!node->isExternal())
146       collectCallOps(*node->getCallableRegion(), cg, calls,
147                      /*traverseNestedCGNodes=*/false);
148   }
149   if (calls.empty())
150     return;
151 
152   // Try to inline each of the call operations. Don't cache the end iterator
153   // here as more calls may be added during inlining.
154   for (unsigned i = 0; i != calls.size(); ++i) {
155     ResolvedCall &it = calls[i];
156     if (!shouldInline(it))
157       continue;
158 
159     CallOpInterface call = it.call;
160     Region *targetRegion = it.targetNode->getCallableRegion();
161     LogicalResult inlineResult = inlineCall(
162         inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
163         targetRegion);
164     if (failed(inlineResult))
165       continue;
166 
167     // If the inlining was successful, then erase the call.
168     call.erase();
169   }
170   calls.clear();
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // InlinerPass
175 //===----------------------------------------------------------------------===//
176 
177 // TODO(riverriddle) This pass should currently only be used for basic testing
178 // of inlining functionality.
179 namespace {
180 struct InlinerPass : public OperationPass<InlinerPass> {
181   void runOnOperation() override {
182     CallGraph &cg = getAnalysis<CallGraph>();
183     Inliner inliner(&getContext(), cg);
184 
185     // Run the inline transform in post-order over the SCCs in the callgraph.
186     runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) {
187       inlineCallsInSCC(inliner, scc);
188     });
189   }
190 };
191 } // end anonymous namespace
192 
193 static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
194