1 //===- PartialInlining.cpp - Inline parts of functions --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass performs partial inlining, typically by inlining an if statement
11 // that surrounds the body of the function.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/IPO/PartialInlining.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/BlockFrequencyInfo.h"
18 #include "llvm/Analysis/BranchProbabilityInfo.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/IR/CFG.h"
21 #include "llvm/IR/Dominators.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Transforms/IPO.h"
26 #include "llvm/Transforms/Utils/Cloning.h"
27 #include "llvm/Transforms/Utils/CodeExtractor.h"
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "partialinlining"
31 
32 STATISTIC(NumPartialInlined, "Number of functions partially inlined");
33 
34 namespace {
35 struct PartialInlinerImpl {
36   PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(std::move(IFI)) {}
37   bool run(Module &M);
38   Function *unswitchFunction(Function *F);
39 
40 private:
41   InlineFunctionInfo IFI;
42 };
43 struct PartialInlinerLegacyPass : public ModulePass {
44   static char ID; // Pass identification, replacement for typeid
45   PartialInlinerLegacyPass() : ModulePass(ID) {
46     initializePartialInlinerLegacyPassPass(*PassRegistry::getPassRegistry());
47   }
48 
49   void getAnalysisUsage(AnalysisUsage &AU) const override {
50     AU.addRequired<AssumptionCacheTracker>();
51   }
52   bool runOnModule(Module &M) override {
53     if (skipModule(M))
54       return false;
55 
56     AssumptionCacheTracker *ACT = &getAnalysis<AssumptionCacheTracker>();
57     std::function<AssumptionCache &(Function &)> GetAssumptionCache =
58         [&ACT](Function &F) -> AssumptionCache & {
59       return ACT->getAssumptionCache(F);
60     };
61     InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
62     return PartialInlinerImpl(IFI).run(M);
63   }
64 };
65 }
66 
67 Function *PartialInlinerImpl::unswitchFunction(Function *F) {
68   // First, verify that this function is an unswitching candidate...
69   BasicBlock *EntryBlock = &F->front();
70   BranchInst *BR = dyn_cast<BranchInst>(EntryBlock->getTerminator());
71   if (!BR || BR->isUnconditional())
72     return nullptr;
73 
74   BasicBlock *ReturnBlock = nullptr;
75   BasicBlock *NonReturnBlock = nullptr;
76   unsigned ReturnCount = 0;
77   for (BasicBlock *BB : successors(EntryBlock)) {
78     if (isa<ReturnInst>(BB->getTerminator())) {
79       ReturnBlock = BB;
80       ReturnCount++;
81     } else
82       NonReturnBlock = BB;
83   }
84 
85   if (ReturnCount != 1)
86     return nullptr;
87 
88   auto canAllUsesBeReplaced = [](Function *F) {
89     std::vector<User *> Users(F->user_begin(), F->user_end());
90     for (User *User : Users) {
91       Function *Callee = nullptr;
92       if (CallInst *CI = dyn_cast<CallInst>(User))
93         Callee = CallSite(CI).getCalledFunction();
94       else if (InvokeInst *II = dyn_cast<InvokeInst>(User))
95         Callee = CallSite(II).getCalledFunction();
96 
97       if (Callee != F)
98         return false;
99     }
100 
101     return true;
102   };
103 
104   if (!canAllUsesBeReplaced(F))
105     return nullptr;
106 
107   // Clone the function, so that we can hack away on it.
108   ValueToValueMapTy VMap;
109   Function *DuplicateFunction = CloneFunction(F, VMap);
110   DuplicateFunction->setLinkage(GlobalValue::InternalLinkage);
111   BasicBlock *NewEntryBlock = cast<BasicBlock>(VMap[EntryBlock]);
112   BasicBlock *NewReturnBlock = cast<BasicBlock>(VMap[ReturnBlock]);
113   BasicBlock *NewNonReturnBlock = cast<BasicBlock>(VMap[NonReturnBlock]);
114 
115   // Go ahead and update all uses to the duplicate, so that we can just
116   // use the inliner functionality when we're done hacking.
117   F->replaceAllUsesWith(DuplicateFunction);
118 
119   // Special hackery is needed with PHI nodes that have inputs from more than
120   // one extracted block.  For simplicity, just split the PHIs into a two-level
121   // sequence of PHIs, some of which will go in the extracted region, and some
122   // of which will go outside.
123   BasicBlock *PreReturn = NewReturnBlock;
124   NewReturnBlock = NewReturnBlock->splitBasicBlock(
125       NewReturnBlock->getFirstNonPHI()->getIterator());
126   BasicBlock::iterator I = PreReturn->begin();
127   Instruction *Ins = &NewReturnBlock->front();
128   while (I != PreReturn->end()) {
129     PHINode *OldPhi = dyn_cast<PHINode>(I);
130     if (!OldPhi)
131       break;
132 
133     PHINode *RetPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins);
134     OldPhi->replaceAllUsesWith(RetPhi);
135     Ins = NewReturnBlock->getFirstNonPHI();
136 
137     RetPhi->addIncoming(&*I, PreReturn);
138     RetPhi->addIncoming(OldPhi->getIncomingValueForBlock(NewEntryBlock),
139                         NewEntryBlock);
140     OldPhi->removeIncomingValue(NewEntryBlock);
141 
142     ++I;
143   }
144   NewEntryBlock->getTerminator()->replaceUsesOfWith(PreReturn, NewReturnBlock);
145 
146   // Gather up the blocks that we're going to extract.
147   std::vector<BasicBlock *> ToExtract;
148   ToExtract.push_back(NewNonReturnBlock);
149   for (BasicBlock &BB : *DuplicateFunction)
150     if (&BB != NewEntryBlock && &BB != NewReturnBlock &&
151         &BB != NewNonReturnBlock)
152       ToExtract.push_back(&BB);
153 
154   // The CodeExtractor needs a dominator tree.
155   DominatorTree DT;
156   DT.recalculate(*DuplicateFunction);
157 
158   // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo.
159   LoopInfo LI(DT);
160   BranchProbabilityInfo BPI(*DuplicateFunction, LI);
161   BlockFrequencyInfo BFI(*DuplicateFunction, BPI, LI);
162 
163   // Extract the body of the if.
164   Function *ExtractedFunction =
165       CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, &BFI, &BPI)
166           .extractCodeRegion();
167 
168   // Inline the top-level if test into all callers.
169   std::vector<User *> Users(DuplicateFunction->user_begin(),
170                             DuplicateFunction->user_end());
171   for (User *User : Users)
172     if (CallInst *CI = dyn_cast<CallInst>(User))
173       InlineFunction(CI, IFI);
174     else if (InvokeInst *II = dyn_cast<InvokeInst>(User))
175       InlineFunction(II, IFI);
176 
177   // Ditch the duplicate, since we're done with it, and rewrite all remaining
178   // users (function pointers, etc.) back to the original function.
179   DuplicateFunction->replaceAllUsesWith(F);
180   DuplicateFunction->eraseFromParent();
181 
182   ++NumPartialInlined;
183 
184   return ExtractedFunction;
185 }
186 
187 bool PartialInlinerImpl::run(Module &M) {
188   std::vector<Function *> Worklist;
189   Worklist.reserve(M.size());
190   for (Function &F : M)
191     if (!F.use_empty() && !F.isDeclaration())
192       Worklist.push_back(&F);
193 
194   bool Changed = false;
195   while (!Worklist.empty()) {
196     Function *CurrFunc = Worklist.back();
197     Worklist.pop_back();
198 
199     if (CurrFunc->use_empty())
200       continue;
201 
202     bool Recursive = false;
203     for (User *U : CurrFunc->users())
204       if (Instruction *I = dyn_cast<Instruction>(U))
205         if (I->getParent()->getParent() == CurrFunc) {
206           Recursive = true;
207           break;
208         }
209     if (Recursive)
210       continue;
211 
212     if (Function *NewFunc = unswitchFunction(CurrFunc)) {
213       Worklist.push_back(NewFunc);
214       Changed = true;
215     }
216   }
217 
218   return Changed;
219 }
220 
// Out-of-class definition of the static pass identifier declared in
// PartialInlinerLegacyPass.
221 char PartialInlinerLegacyPass::ID = 0;
// Registration of the legacy pass under the name "partial-inliner" with the
// description "Partial Inliner". This is what provides the
// initializePartialInlinerLegacyPassPass function called from the pass
// constructor.
// NOTE(review): the trailing `false, false` are presumably the CFG-only and
// is-analysis flags of the INITIALIZE_PASS macros -- confirm against
// PassSupport.h.
222 INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner",
223                       "Partial Inliner", false, false)
// AssumptionCacheTracker is the analysis requested in getAnalysisUsage and
// queried in runOnModule.
224 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
225 INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner",
226                     "Partial Inliner", false, false)
227 
228 ModulePass *llvm::createPartialInliningPass() {
229   return new PartialInlinerLegacyPass();
230 }
231 
232 PreservedAnalyses PartialInlinerPass::run(Module &M,
233                                           ModuleAnalysisManager &AM) {
234   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
235   std::function<AssumptionCache &(Function &)> GetAssumptionCache =
236       [&FAM](Function &F) -> AssumptionCache & {
237     return FAM.getResult<AssumptionAnalysis>(F);
238   };
239   InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
240   if (PartialInlinerImpl(IFI).run(M))
241     return PreservedAnalyses::none();
242   return PreservedAnalyses::all();
243 }
244