1 //===- PartialInlining.cpp - Inline parts of functions --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass performs partial inlining, typically by inlining an if statement
11 // that surrounds the body of the function.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "partialinlining"
27 
28 STATISTIC(NumPartialInlined, "Number of functions partially inlined");
29 
30 namespace {
31   struct PartialInliner : public ModulePass {
32     void getAnalysisUsage(AnalysisUsage &AU) const override { }
33     static char ID; // Pass identification, replacement for typeid
34     PartialInliner() : ModulePass(ID) {
35       initializePartialInlinerPass(*PassRegistry::getPassRegistry());
36     }
37 
38     bool runOnModule(Module& M) override;
39 
40   private:
41     Function* unswitchFunction(Function* F);
42   };
43 }
44 
45 char PartialInliner::ID = 0;
46 INITIALIZE_PASS(PartialInliner, "partial-inliner",
47                 "Partial Inliner", false, false)
48 
49 ModulePass* llvm::createPartialInliningPass() { return new PartialInliner(); }
50 
51 Function* PartialInliner::unswitchFunction(Function* F) {
52   // First, verify that this function is an unswitching candidate...
53   BasicBlock *entryBlock = &F->front();
54   BranchInst *BR = dyn_cast<BranchInst>(entryBlock->getTerminator());
55   if (!BR || BR->isUnconditional())
56     return nullptr;
57 
58   BasicBlock* returnBlock = nullptr;
59   BasicBlock* nonReturnBlock = nullptr;
60   unsigned returnCount = 0;
61   for (BasicBlock *BB : successors(entryBlock)) {
62     if (isa<ReturnInst>(BB->getTerminator())) {
63       returnBlock = BB;
64       returnCount++;
65     } else
66       nonReturnBlock = BB;
67   }
68 
69   if (returnCount != 1)
70     return nullptr;
71 
72   // Clone the function, so that we can hack away on it.
73   ValueToValueMapTy VMap;
74   Function* duplicateFunction = CloneFunction(F, VMap);
75   duplicateFunction->setLinkage(GlobalValue::InternalLinkage);
76   BasicBlock* newEntryBlock = cast<BasicBlock>(VMap[entryBlock]);
77   BasicBlock* newReturnBlock = cast<BasicBlock>(VMap[returnBlock]);
78   BasicBlock* newNonReturnBlock = cast<BasicBlock>(VMap[nonReturnBlock]);
79 
80   // Go ahead and update all uses to the duplicate, so that we can just
81   // use the inliner functionality when we're done hacking.
82   F->replaceAllUsesWith(duplicateFunction);
83 
84   // Special hackery is needed with PHI nodes that have inputs from more than
85   // one extracted block.  For simplicity, just split the PHIs into a two-level
86   // sequence of PHIs, some of which will go in the extracted region, and some
87   // of which will go outside.
88   BasicBlock* preReturn = newReturnBlock;
89   newReturnBlock = newReturnBlock->splitBasicBlock(
90       newReturnBlock->getFirstNonPHI()->getIterator());
91   BasicBlock::iterator I = preReturn->begin();
92   Instruction *Ins = &newReturnBlock->front();
93   while (I != preReturn->end()) {
94     PHINode* OldPhi = dyn_cast<PHINode>(I);
95     if (!OldPhi) break;
96 
97     PHINode *retPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins);
98     OldPhi->replaceAllUsesWith(retPhi);
99     Ins = newReturnBlock->getFirstNonPHI();
100 
101     retPhi->addIncoming(&*I, preReturn);
102     retPhi->addIncoming(OldPhi->getIncomingValueForBlock(newEntryBlock),
103                         newEntryBlock);
104     OldPhi->removeIncomingValue(newEntryBlock);
105 
106     ++I;
107   }
108   newEntryBlock->getTerminator()->replaceUsesOfWith(preReturn, newReturnBlock);
109 
110   // Gather up the blocks that we're going to extract.
111   std::vector<BasicBlock*> toExtract;
112   toExtract.push_back(newNonReturnBlock);
113   for (BasicBlock &BB : *duplicateFunction)
114     if (&BB != newEntryBlock && &BB != newReturnBlock &&
115         &BB != newNonReturnBlock)
116       toExtract.push_back(&BB);
117 
118   // The CodeExtractor needs a dominator tree.
119   DominatorTree DT;
120   DT.recalculate(*duplicateFunction);
121 
122   // Extract the body of the if.
123   Function* extractedFunction
124     = CodeExtractor(toExtract, &DT).extractCodeRegion();
125 
126   InlineFunctionInfo IFI;
127 
128   // Inline the top-level if test into all callers.
129   std::vector<User *> Users(duplicateFunction->user_begin(),
130                             duplicateFunction->user_end());
131   for (User *User : Users)
132     if (CallInst *CI = dyn_cast<CallInst>(User))
133       InlineFunction(CI, IFI);
134     else if (InvokeInst *II = dyn_cast<InvokeInst>(User))
135       InlineFunction(II, IFI);
136 
137   // Ditch the duplicate, since we're done with it, and rewrite all remaining
138   // users (function pointers, etc.) back to the original function.
139   duplicateFunction->replaceAllUsesWith(F);
140   duplicateFunction->eraseFromParent();
141 
142   ++NumPartialInlined;
143 
144   return extractedFunction;
145 }
146 
147 bool PartialInliner::runOnModule(Module& M) {
148   if (skipModule(M))
149     return false;
150 
151   std::vector<Function*> worklist;
152   worklist.reserve(M.size());
153   for (Function &F : M)
154     if (!F.use_empty() && !F.isDeclaration())
155       worklist.push_back(&F);
156 
157   bool changed = false;
158   while (!worklist.empty()) {
159     Function* currFunc = worklist.back();
160     worklist.pop_back();
161 
162     if (currFunc->use_empty()) continue;
163 
164     bool recursive = false;
165     for (User *U : currFunc->users())
166       if (Instruction* I = dyn_cast<Instruction>(U))
167         if (I->getParent()->getParent() == currFunc) {
168           recursive = true;
169           break;
170         }
171     if (recursive) continue;
172 
173 
174     if (Function* newFunc = unswitchFunction(currFunc)) {
175       worklist.push_back(newFunc);
176       changed = true;
177     }
178 
179   }
180 
181   return changed;
182 }
183