1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the interface to tear out a code region, such as an
11 // individual loop or a parallel section, into a new function, replacing it with
12 // a call to the new function.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Transforms/Utils/CodeExtractor.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Analysis/BlockFrequencyInfo.h"
21 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
22 #include "llvm/Analysis/BranchProbabilityInfo.h"
23 #include "llvm/Analysis/LoopInfo.h"
24 #include "llvm/Analysis/RegionInfo.h"
25 #include "llvm/Analysis/RegionIterator.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Intrinsics.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/MDBuilder.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/Verifier.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/BlockFrequency.h"
38 #include "llvm/Support/CommandLine.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/ErrorHandling.h"
41 #include "llvm/Support/raw_ostream.h"
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include <algorithm>
44 #include <set>
45 using namespace llvm;
46 
47 #define DEBUG_TYPE "code-extractor"
48 
49 // Provide a command-line option to aggregate function arguments into a struct
50 // for functions produced by the code extractor. This is useful when converting
51 // extracted functions to pthread-based code, as only one argument (void*) can
52 // be passed in to pthread_create().
53 static cl::opt<bool>
54 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
55                  cl::desc("Aggregate arguments to code-extracted functions"));
56 
57 /// \brief Test whether a block is valid for extraction.
58 bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) {
59   // Landing pads must be in the function where they were inserted for cleanup.
60   if (BB.isEHPad())
61     return false;
62   // taking the address of a basic block moved to another function is illegal
63   if (BB.hasAddressTaken())
64     return false;
65 
66   // don't hoist code that uses another basicblock address, as it's likely to
67   // lead to unexpected behavior, like cross-function jumps
68   SmallPtrSet<User const *, 16> Visited;
69   SmallVector<User const *, 16> ToVisit;
70 
71   for (Instruction const &Inst : BB)
72     ToVisit.push_back(&Inst);
73 
74   while (!ToVisit.empty()) {
75     User const *Curr = ToVisit.pop_back_val();
76     if (!Visited.insert(Curr).second)
77       continue;
78     if (isa<BlockAddress const>(Curr))
79       return false; // even a reference to self is likely to be not compatible
80 
81     if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
82       continue;
83 
84     for (auto const &U : Curr->operands()) {
85       if (auto *UU = dyn_cast<User>(U))
86         ToVisit.push_back(UU);
87     }
88   }
89 
90   // Don't hoist code containing allocas, invokes, or vastarts.
91   for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
92     if (isa<AllocaInst>(I) || isa<InvokeInst>(I))
93       return false;
94     if (const CallInst *CI = dyn_cast<CallInst>(I))
95       if (const Function *F = CI->getCalledFunction())
96         if (F->getIntrinsicID() == Intrinsic::vastart)
97           return false;
98   }
99 
100   return true;
101 }
102 
103 /// \brief Build a set of blocks to extract if the input blocks are viable.
104 static SetVector<BasicBlock *>
105 buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) {
106   assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
107   SetVector<BasicBlock *> Result;
108 
109   // Loop over the blocks, adding them to our set-vector, and aborting with an
110   // empty set if we encounter invalid blocks.
111   for (BasicBlock *BB : BBs) {
112 
113     // If this block is dead, don't process it.
114     if (DT && !DT->isReachableFromEntry(BB))
115       continue;
116 
117     if (!Result.insert(BB))
118       llvm_unreachable("Repeated basic blocks in extraction input");
119     if (!CodeExtractor::isBlockValidForExtraction(*BB)) {
120       Result.clear();
121       return Result;
122     }
123   }
124 
125 #ifndef NDEBUG
126   for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()),
127                                          E = Result.end();
128        I != E; ++I)
129     for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I);
130          PI != PE; ++PI)
131       assert(Result.count(*PI) &&
132              "No blocks in this region may have entries from outside the region"
133              " except for the first block!");
134 #endif
135 
136   return Result;
137 }
138 
139 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
140                              bool AggregateArgs, BlockFrequencyInfo *BFI,
141                              BranchProbabilityInfo *BPI)
142     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
143       BPI(BPI), Blocks(buildExtractionBlockSet(BBs, DT)), NumExitBlocks(~0U) {}
144 
145 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
146                              BlockFrequencyInfo *BFI,
147                              BranchProbabilityInfo *BPI)
148     : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
149       BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT)),
150       NumExitBlocks(~0U) {}
151 
152 /// definedInRegion - Return true if the specified value is defined in the
153 /// extracted region.
154 static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
155   if (Instruction *I = dyn_cast<Instruction>(V))
156     if (Blocks.count(I->getParent()))
157       return true;
158   return false;
159 }
160 
161 /// definedInCaller - Return true if the specified value is defined in the
162 /// function being code extracted, but not in the region being extracted.
163 /// These values must be passed in as live-ins to the function.
164 static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
165   if (isa<Argument>(V)) return true;
166   if (Instruction *I = dyn_cast<Instruction>(V))
167     if (!Blocks.count(I->getParent()))
168       return true;
169   return false;
170 }
171 
172 static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
173   BasicBlock *CommonExitBlock = nullptr;
174   auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
175     for (auto *Succ : successors(Block)) {
176       // Internal edges, ok.
177       if (Blocks.count(Succ))
178         continue;
179       if (!CommonExitBlock) {
180         CommonExitBlock = Succ;
181         continue;
182       }
183       if (CommonExitBlock == Succ)
184         continue;
185 
186       return true;
187     }
188     return false;
189   };
190 
191   if (any_of(Blocks, hasNonCommonExitSucc))
192     return nullptr;
193 
194   return CommonExitBlock;
195 }
196 
197 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
198     Instruction *Addr) const {
199   AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
200   Function *Func = (*Blocks.begin())->getParent();
201   for (BasicBlock &BB : *Func) {
202     if (Blocks.count(&BB))
203       continue;
204     for (Instruction &II : BB) {
205 
206       if (isa<DbgInfoIntrinsic>(II))
207         continue;
208 
209       unsigned Opcode = II.getOpcode();
210       Value *MemAddr = nullptr;
211       switch (Opcode) {
212       case Instruction::Store:
213       case Instruction::Load: {
214         if (Opcode == Instruction::Store) {
215           StoreInst *SI = cast<StoreInst>(&II);
216           MemAddr = SI->getPointerOperand();
217         } else {
218           LoadInst *LI = cast<LoadInst>(&II);
219           MemAddr = LI->getPointerOperand();
220         }
221         // Global variable can not be aliased with locals.
222         if (dyn_cast<Constant>(MemAddr))
223           break;
224         Value *Base = MemAddr->stripInBoundsConstantOffsets();
225         if (!dyn_cast<AllocaInst>(Base) || Base == AI)
226           return false;
227         break;
228       }
229       default: {
230         IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
231         if (IntrInst) {
232           if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start ||
233               IntrInst->getIntrinsicID() == Intrinsic::lifetime_end)
234             break;
235           return false;
236         }
237         // Treat all the other cases conservatively if it has side effects.
238         if (II.mayHaveSideEffects())
239           return false;
240       }
241       }
242     }
243   }
244 
245   return true;
246 }
247 
248 BasicBlock *
249 CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
250   BasicBlock *SinglePredFromOutlineRegion = nullptr;
251   assert(!Blocks.count(CommonExitBlock) &&
252          "Expect a block outside the region!");
253   for (auto *Pred : predecessors(CommonExitBlock)) {
254     if (!Blocks.count(Pred))
255       continue;
256     if (!SinglePredFromOutlineRegion) {
257       SinglePredFromOutlineRegion = Pred;
258     } else if (SinglePredFromOutlineRegion != Pred) {
259       SinglePredFromOutlineRegion = nullptr;
260       break;
261     }
262   }
263 
264   if (SinglePredFromOutlineRegion)
265     return SinglePredFromOutlineRegion;
266 
267 #ifndef NDEBUG
268   auto getFirstPHI = [](BasicBlock *BB) {
269     BasicBlock::iterator I = BB->begin();
270     PHINode *FirstPhi = nullptr;
271     while (I != BB->end()) {
272       PHINode *Phi = dyn_cast<PHINode>(I);
273       if (!Phi)
274         break;
275       if (!FirstPhi) {
276         FirstPhi = Phi;
277         break;
278       }
279     }
280     return FirstPhi;
281   };
282   // If there are any phi nodes, the single pred either exists or has already
283   // be created before code extraction.
284   assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
285 #endif
286 
287   BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
288       CommonExitBlock->getFirstNonPHI()->getIterator());
289 
290   for (auto *Pred : predecessors(CommonExitBlock)) {
291     if (Blocks.count(Pred))
292       continue;
293     Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
294   }
295   // Now add the old exit block to the outline region.
296   Blocks.insert(CommonExitBlock);
297   return CommonExitBlock;
298 }
299 
300 void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
301                                 BasicBlock *&ExitBlock) const {
302   Function *Func = (*Blocks.begin())->getParent();
303   ExitBlock = getCommonExitBlock(Blocks);
304 
305   for (BasicBlock &BB : *Func) {
306     if (Blocks.count(&BB))
307       continue;
308     for (Instruction &II : BB) {
309       auto *AI = dyn_cast<AllocaInst>(&II);
310       if (!AI)
311         continue;
312 
313       // Find the pair of life time markers for address 'Addr' that are either
314       // defined inside the outline region or can legally be shrinkwrapped into
315       // the outline region. If there are not other untracked uses of the
316       // address, return the pair of markers if found; otherwise return a pair
317       // of nullptr.
318       auto GetLifeTimeMarkers =
319           [&](Instruction *Addr, bool &SinkLifeStart,
320               bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
321         Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
322 
323         for (User *U : Addr->users()) {
324           IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
325           if (IntrInst) {
326             if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
327               // Do not handle the case where AI has multiple start markers.
328               if (LifeStart)
329                 return std::make_pair<Instruction *>(nullptr, nullptr);
330               LifeStart = IntrInst;
331             }
332             if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
333               if (LifeEnd)
334                 return std::make_pair<Instruction *>(nullptr, nullptr);
335               LifeEnd = IntrInst;
336             }
337             continue;
338           }
339           // Find untracked uses of the address, bail.
340           if (!definedInRegion(Blocks, U))
341             return std::make_pair<Instruction *>(nullptr, nullptr);
342         }
343 
344         if (!LifeStart || !LifeEnd)
345           return std::make_pair<Instruction *>(nullptr, nullptr);
346 
347         SinkLifeStart = !definedInRegion(Blocks, LifeStart);
348         HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
349         // Do legality Check.
350         if ((SinkLifeStart || HoistLifeEnd) &&
351             !isLegalToShrinkwrapLifetimeMarkers(Addr))
352           return std::make_pair<Instruction *>(nullptr, nullptr);
353 
354         // Check to see if we have a place to do hoisting, if not, bail.
355         if (HoistLifeEnd && !ExitBlock)
356           return std::make_pair<Instruction *>(nullptr, nullptr);
357 
358         return std::make_pair(LifeStart, LifeEnd);
359       };
360 
361       bool SinkLifeStart = false, HoistLifeEnd = false;
362       auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
363 
364       if (Markers.first) {
365         if (SinkLifeStart)
366           SinkCands.insert(Markers.first);
367         SinkCands.insert(AI);
368         if (HoistLifeEnd)
369           HoistCands.insert(Markers.second);
370         continue;
371       }
372 
373       // Follow the bitcast.
374       Instruction *MarkerAddr = nullptr;
375       for (User *U : AI->users()) {
376 
377         if (U->stripInBoundsConstantOffsets() == AI) {
378           SinkLifeStart = false;
379           HoistLifeEnd = false;
380           Instruction *Bitcast = cast<Instruction>(U);
381           Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
382           if (Markers.first) {
383             MarkerAddr = Bitcast;
384             continue;
385           }
386         }
387 
388         // Found unknown use of AI.
389         if (!definedInRegion(Blocks, U)) {
390           MarkerAddr = nullptr;
391           break;
392         }
393       }
394 
395       if (MarkerAddr) {
396         if (SinkLifeStart)
397           SinkCands.insert(Markers.first);
398         if (!definedInRegion(Blocks, MarkerAddr))
399           SinkCands.insert(MarkerAddr);
400         SinkCands.insert(AI);
401         if (HoistLifeEnd)
402           HoistCands.insert(Markers.second);
403       }
404     }
405   }
406 }
407 
408 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
409                                       const ValueSet &SinkCands) const {
410 
411   for (BasicBlock *BB : Blocks) {
412     // If a used value is defined outside the region, it's an input.  If an
413     // instruction is used outside the region, it's an output.
414     for (Instruction &II : *BB) {
415       for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
416            ++OI) {
417         Value *V = *OI;
418         if (!SinkCands.count(V) && definedInCaller(Blocks, V))
419           Inputs.insert(V);
420       }
421 
422       for (User *U : II.users())
423         if (!definedInRegion(Blocks, U)) {
424           Outputs.insert(&II);
425           break;
426         }
427     }
428   }
429 }
430 
431 /// severSplitPHINodes - If a PHI node has multiple inputs from outside of the
432 /// region, we need to split the entry block of the region so that the PHI node
433 /// is easier to deal with.
434 void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) {
435   unsigned NumPredsFromRegion = 0;
436   unsigned NumPredsOutsideRegion = 0;
437 
438   if (Header != &Header->getParent()->getEntryBlock()) {
439     PHINode *PN = dyn_cast<PHINode>(Header->begin());
440     if (!PN) return;  // No PHI nodes.
441 
442     // If the header node contains any PHI nodes, check to see if there is more
443     // than one entry from outside the region.  If so, we need to sever the
444     // header block into two.
445     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
446       if (Blocks.count(PN->getIncomingBlock(i)))
447         ++NumPredsFromRegion;
448       else
449         ++NumPredsOutsideRegion;
450 
451     // If there is one (or fewer) predecessor from outside the region, we don't
452     // need to do anything special.
453     if (NumPredsOutsideRegion <= 1) return;
454   }
455 
456   // Otherwise, we need to split the header block into two pieces: one
457   // containing PHI nodes merging values from outside of the region, and a
458   // second that contains all of the code for the block and merges back any
459   // incoming values from inside of the region.
460   BasicBlock *NewBB = llvm::SplitBlock(Header, Header->getFirstNonPHI(), DT);
461 
462   // We only want to code extract the second block now, and it becomes the new
463   // header of the region.
464   BasicBlock *OldPred = Header;
465   Blocks.remove(OldPred);
466   Blocks.insert(NewBB);
467   Header = NewBB;
468 
469   // Okay, now we need to adjust the PHI nodes and any branches from within the
470   // region to go to the new header block instead of the old header block.
471   if (NumPredsFromRegion) {
472     PHINode *PN = cast<PHINode>(OldPred->begin());
473     // Loop over all of the predecessors of OldPred that are in the region,
474     // changing them to branch to NewBB instead.
475     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
476       if (Blocks.count(PN->getIncomingBlock(i))) {
477         TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator();
478         TI->replaceUsesOfWith(OldPred, NewBB);
479       }
480 
481     // Okay, everything within the region is now branching to the right block, we
482     // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
483     BasicBlock::iterator AfterPHIs;
484     for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
485       PHINode *PN = cast<PHINode>(AfterPHIs);
486       // Create a new PHI node in the new region, which has an incoming value
487       // from OldPred of PN.
488       PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
489                                        PN->getName() + ".ce", &NewBB->front());
490       PN->replaceAllUsesWith(NewPN);
491       NewPN->addIncoming(PN, OldPred);
492 
493       // Loop over all of the incoming value in PN, moving them to NewPN if they
494       // are from the extracted region.
495       for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
496         if (Blocks.count(PN->getIncomingBlock(i))) {
497           NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
498           PN->removeIncomingValue(i);
499           --i;
500         }
501       }
502     }
503   }
504 }
505 
506 void CodeExtractor::splitReturnBlocks() {
507   for (BasicBlock *Block : Blocks)
508     if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
509       BasicBlock *New =
510           Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
511       if (DT) {
512         // Old dominates New. New node dominates all other nodes dominated
513         // by Old.
514         DomTreeNode *OldNode = DT->getNode(Block);
515         SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
516                                                OldNode->end());
517 
518         DomTreeNode *NewNode = DT->addNewBlock(New, Block);
519 
520         for (DomTreeNode *I : Children)
521           DT->changeImmediateDominator(I, NewNode);
522       }
523     }
524 }
525 
526 /// constructFunction - make a function based on inputs and outputs, as follows:
527 /// f(in0, ..., inN, out0, ..., outN)
528 ///
529 Function *CodeExtractor::constructFunction(const ValueSet &inputs,
530                                            const ValueSet &outputs,
531                                            BasicBlock *header,
532                                            BasicBlock *newRootNode,
533                                            BasicBlock *newHeader,
534                                            Function *oldFunction,
535                                            Module *M) {
536   DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
537   DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
538 
539   // This function returns unsigned, outputs will go back by reference.
540   switch (NumExitBlocks) {
541   case 0:
542   case 1: RetTy = Type::getVoidTy(header->getContext()); break;
543   case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
544   default: RetTy = Type::getInt16Ty(header->getContext()); break;
545   }
546 
547   std::vector<Type*> paramTy;
548 
549   // Add the types of the input values to the function's argument list
550   for (Value *value : inputs) {
551     DEBUG(dbgs() << "value used in func: " << *value << "\n");
552     paramTy.push_back(value->getType());
553   }
554 
555   // Add the types of the output values to the function's argument list.
556   for (Value *output : outputs) {
557     DEBUG(dbgs() << "instr used in func: " << *output << "\n");
558     if (AggregateArgs)
559       paramTy.push_back(output->getType());
560     else
561       paramTy.push_back(PointerType::getUnqual(output->getType()));
562   }
563 
564   DEBUG({
565     dbgs() << "Function type: " << *RetTy << " f(";
566     for (Type *i : paramTy)
567       dbgs() << *i << ", ";
568     dbgs() << ")\n";
569   });
570 
571   StructType *StructTy;
572   if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
573     StructTy = StructType::get(M->getContext(), paramTy);
574     paramTy.clear();
575     paramTy.push_back(PointerType::getUnqual(StructTy));
576   }
577   FunctionType *funcType =
578                   FunctionType::get(RetTy, paramTy, false);
579 
580   // Create the new function
581   Function *newFunction = Function::Create(funcType,
582                                            GlobalValue::InternalLinkage,
583                                            oldFunction->getName() + "_" +
584                                            header->getName(), M);
585   // If the old function is no-throw, so is the new one.
586   if (oldFunction->doesNotThrow())
587     newFunction->setDoesNotThrow();
588 
589   // Inherit the uwtable attribute if we need to.
590   if (oldFunction->hasUWTable())
591     newFunction->setHasUWTable();
592 
593   // Inherit all of the target dependent attributes.
594   //  (e.g. If the extracted region contains a call to an x86.sse
595   //  instruction we need to make sure that the extracted region has the
596   //  "target-features" attribute allowing it to be lowered.
597   // FIXME: This should be changed to check to see if a specific
598   //           attribute can not be inherited.
599   AttrBuilder AB(oldFunction->getAttributes().getFnAttributes());
600   for (const auto &Attr : AB.td_attrs())
601     newFunction->addFnAttr(Attr.first, Attr.second);
602 
603   newFunction->getBasicBlockList().push_back(newRootNode);
604 
605   // Create an iterator to name all of the arguments we inserted.
606   Function::arg_iterator AI = newFunction->arg_begin();
607 
608   // Rewrite all users of the inputs in the extracted region to use the
609   // arguments (or appropriate addressing into struct) instead.
610   for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
611     Value *RewriteVal;
612     if (AggregateArgs) {
613       Value *Idx[2];
614       Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
615       Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
616       TerminatorInst *TI = newFunction->begin()->getTerminator();
617       GetElementPtrInst *GEP = GetElementPtrInst::Create(
618           StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
619       RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI);
620     } else
621       RewriteVal = &*AI++;
622 
623     std::vector<User*> Users(inputs[i]->user_begin(), inputs[i]->user_end());
624     for (User *use : Users)
625       if (Instruction *inst = dyn_cast<Instruction>(use))
626         if (Blocks.count(inst->getParent()))
627           inst->replaceUsesOfWith(inputs[i], RewriteVal);
628   }
629 
630   // Set names for input and output arguments.
631   if (!AggregateArgs) {
632     AI = newFunction->arg_begin();
633     for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
634       AI->setName(inputs[i]->getName());
635     for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
636       AI->setName(outputs[i]->getName()+".out");
637   }
638 
639   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
640   // within the new function. This must be done before we lose track of which
641   // blocks were originally in the code region.
642   std::vector<User*> Users(header->user_begin(), header->user_end());
643   for (unsigned i = 0, e = Users.size(); i != e; ++i)
644     // The BasicBlock which contains the branch is not in the region
645     // modify the branch target to a new block
646     if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i]))
647       if (!Blocks.count(TI->getParent()) &&
648           TI->getParent()->getParent() == oldFunction)
649         TI->replaceUsesOfWith(header, newHeader);
650 
651   return newFunction;
652 }
653 
654 /// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI
655 /// that uses the value within the basic block, and return the predecessor
656 /// block associated with that use, or return 0 if none is found.
657 static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) {
658   for (Use &U : Used->uses()) {
659      PHINode *P = dyn_cast<PHINode>(U.getUser());
660      if (P && P->getParent() == BB)
661        return P->getIncomingBlock(U);
662   }
663 
664   return nullptr;
665 }
666 
667 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
668 /// the call instruction, splitting any PHI nodes in the header block as
669 /// necessary.
670 void CodeExtractor::
671 emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
672                            ValueSet &inputs, ValueSet &outputs) {
673   // Emit a call to the new function, passing in: *pointer to struct (if
674   // aggregating parameters), or plan inputs and allocated memory for outputs
675   std::vector<Value*> params, StructValues, ReloadOutputs, Reloads;
676 
677   Module *M = newFunction->getParent();
678   LLVMContext &Context = M->getContext();
679   const DataLayout &DL = M->getDataLayout();
680 
681   // Add inputs as params, or to be filled into the struct
682   for (Value *input : inputs)
683     if (AggregateArgs)
684       StructValues.push_back(input);
685     else
686       params.push_back(input);
687 
688   // Create allocas for the outputs
689   for (Value *output : outputs) {
690     if (AggregateArgs) {
691       StructValues.push_back(output);
692     } else {
693       AllocaInst *alloca =
694         new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
695                        nullptr, output->getName() + ".loc",
696                        &codeReplacer->getParent()->front().front());
697       ReloadOutputs.push_back(alloca);
698       params.push_back(alloca);
699     }
700   }
701 
702   StructType *StructArgTy = nullptr;
703   AllocaInst *Struct = nullptr;
704   if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
705     std::vector<Type*> ArgTypes;
706     for (ValueSet::iterator v = StructValues.begin(),
707            ve = StructValues.end(); v != ve; ++v)
708       ArgTypes.push_back((*v)->getType());
709 
710     // Allocate a struct at the beginning of this function
711     StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
712     Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
713                             "structArg",
714                             &codeReplacer->getParent()->front().front());
715     params.push_back(Struct);
716 
717     for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
718       Value *Idx[2];
719       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
720       Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
721       GetElementPtrInst *GEP = GetElementPtrInst::Create(
722           StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
723       codeReplacer->getInstList().push_back(GEP);
724       StoreInst *SI = new StoreInst(StructValues[i], GEP);
725       codeReplacer->getInstList().push_back(SI);
726     }
727   }
728 
729   // Emit the call to the function
730   CallInst *call = CallInst::Create(newFunction, params,
731                                     NumExitBlocks > 1 ? "targetBlock" : "");
732   codeReplacer->getInstList().push_back(call);
733 
734   Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
735   unsigned FirstOut = inputs.size();
736   if (!AggregateArgs)
737     std::advance(OutputArgBegin, inputs.size());
738 
739   // Reload the outputs passed in by reference
740   for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
741     Value *Output = nullptr;
742     if (AggregateArgs) {
743       Value *Idx[2];
744       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
745       Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
746       GetElementPtrInst *GEP = GetElementPtrInst::Create(
747           StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
748       codeReplacer->getInstList().push_back(GEP);
749       Output = GEP;
750     } else {
751       Output = ReloadOutputs[i];
752     }
753     LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
754     Reloads.push_back(load);
755     codeReplacer->getInstList().push_back(load);
756     std::vector<User*> Users(outputs[i]->user_begin(), outputs[i]->user_end());
757     for (unsigned u = 0, e = Users.size(); u != e; ++u) {
758       Instruction *inst = cast<Instruction>(Users[u]);
759       if (!Blocks.count(inst->getParent()))
760         inst->replaceUsesOfWith(outputs[i], load);
761     }
762   }
763 
764   // Now we can emit a switch statement using the call as a value.
765   SwitchInst *TheSwitch =
766       SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
767                          codeReplacer, 0, codeReplacer);
768 
769   // Since there may be multiple exits from the original region, make the new
770   // function return an unsigned, switch on that number.  This loop iterates
771   // over all of the blocks in the extracted region, updating any terminator
772   // instructions in the to-be-extracted region that branch to blocks that are
773   // not in the region to be extracted.
774   std::map<BasicBlock*, BasicBlock*> ExitBlockMap;
775 
776   unsigned switchVal = 0;
777   for (BasicBlock *Block : Blocks) {
778     TerminatorInst *TI = Block->getTerminator();
779     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
780       if (!Blocks.count(TI->getSuccessor(i))) {
781         BasicBlock *OldTarget = TI->getSuccessor(i);
782         // add a new basic block which returns the appropriate value
783         BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
784         if (!NewTarget) {
785           // If we don't already have an exit stub for this non-extracted
786           // destination, create one now!
787           NewTarget = BasicBlock::Create(Context,
788                                          OldTarget->getName() + ".exitStub",
789                                          newFunction);
790           unsigned SuccNum = switchVal++;
791 
792           Value *brVal = nullptr;
793           switch (NumExitBlocks) {
794           case 0:
795           case 1: break;  // No value needed.
796           case 2:         // Conditional branch, return a bool
797             brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
798             break;
799           default:
800             brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
801             break;
802           }
803 
804           ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget);
805 
806           // Update the switch instruction.
807           TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
808                                               SuccNum),
809                              OldTarget);
810 
811           // Restore values just before we exit
812           Function::arg_iterator OAI = OutputArgBegin;
813           for (unsigned out = 0, e = outputs.size(); out != e; ++out) {
814             // For an invoke, the normal destination is the only one that is
815             // dominated by the result of the invocation
816             BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent();
817 
818             bool DominatesDef = true;
819 
820             BasicBlock *NormalDest = nullptr;
821             if (auto *Invoke = dyn_cast<InvokeInst>(outputs[out]))
822               NormalDest = Invoke->getNormalDest();
823 
824             if (NormalDest) {
825               DefBlock = NormalDest;
826 
827               // Make sure we are looking at the original successor block, not
828               // at a newly inserted exit block, which won't be in the dominator
829               // info.
830               for (const auto &I : ExitBlockMap)
831                 if (DefBlock == I.second) {
832                   DefBlock = I.first;
833                   break;
834                 }
835 
836               // In the extract block case, if the block we are extracting ends
837               // with an invoke instruction, make sure that we don't emit a
838               // store of the invoke value for the unwind block.
839               if (!DT && DefBlock != OldTarget)
840                 DominatesDef = false;
841             }
842 
843             if (DT) {
844               DominatesDef = DT->dominates(DefBlock, OldTarget);
845 
846               // If the output value is used by a phi in the target block,
847               // then we need to test for dominance of the phi's predecessor
848               // instead.  Unfortunately, this is a little complicated since we
849               // have already rewritten uses of the value to uses of the reload.
850               BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out],
851                                                           OldTarget);
852               if (pred && DT && DT->dominates(DefBlock, pred))
853                 DominatesDef = true;
854             }
855 
856             if (DominatesDef) {
857               if (AggregateArgs) {
858                 Value *Idx[2];
859                 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
860                 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context),
861                                           FirstOut+out);
862                 GetElementPtrInst *GEP = GetElementPtrInst::Create(
863                     StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(),
864                     NTRet);
865                 new StoreInst(outputs[out], GEP, NTRet);
866               } else {
867                 new StoreInst(outputs[out], &*OAI, NTRet);
868               }
869             }
870             // Advance output iterator even if we don't emit a store
871             if (!AggregateArgs) ++OAI;
872           }
873         }
874 
875         // rewrite the original branch instruction with this new target
876         TI->setSuccessor(i, NewTarget);
877       }
878   }
879 
880   // Now that we've done the deed, simplify the switch instruction.
881   Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
882   switch (NumExitBlocks) {
883   case 0:
884     // There are no successors (the block containing the switch itself), which
885     // means that previously this was the last part of the function, and hence
886     // this should be rewritten as a `ret'
887 
888     // Check if the function should return a value
889     if (OldFnRetTy->isVoidTy()) {
890       ReturnInst::Create(Context, nullptr, TheSwitch);  // Return void
891     } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
892       // return what we have
893       ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
894     } else {
895       // Otherwise we must have code extracted an unwind or something, just
896       // return whatever we want.
897       ReturnInst::Create(Context,
898                          Constant::getNullValue(OldFnRetTy), TheSwitch);
899     }
900 
901     TheSwitch->eraseFromParent();
902     break;
903   case 1:
904     // Only a single destination, change the switch into an unconditional
905     // branch.
906     BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
907     TheSwitch->eraseFromParent();
908     break;
909   case 2:
910     BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
911                        call, TheSwitch);
912     TheSwitch->eraseFromParent();
913     break;
914   default:
915     // Otherwise, make the default destination of the switch instruction be one
916     // of the other successors.
917     TheSwitch->setCondition(call);
918     TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
919     // Remove redundant case
920     TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
921     break;
922   }
923 }
924 
925 void CodeExtractor::moveCodeToFunction(Function *newFunction) {
926   Function *oldFunc = (*Blocks.begin())->getParent();
927   Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
928   Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
929 
930   for (BasicBlock *Block : Blocks) {
931     // Delete the basic block from the old function, and the list of blocks
932     oldBlocks.remove(Block);
933 
934     // Insert this basic block into the new function
935     newBlocks.push_back(Block);
936   }
937 }
938 
939 void CodeExtractor::calculateNewCallTerminatorWeights(
940     BasicBlock *CodeReplacer,
941     DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
942     BranchProbabilityInfo *BPI) {
943   typedef BlockFrequencyInfoImplBase::Distribution Distribution;
944   typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
945 
946   // Update the branch weights for the exit block.
947   TerminatorInst *TI = CodeReplacer->getTerminator();
948   SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
949 
950   // Block Frequency distribution with dummy node.
951   Distribution BranchDist;
952 
953   // Add each of the frequencies of the successors.
954   for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
955     BlockNode ExitNode(i);
956     uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
957     if (ExitFreq != 0)
958       BranchDist.addExit(ExitNode, ExitFreq);
959     else
960       BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
961   }
962 
963   // Check for no total weight.
964   if (BranchDist.Total == 0)
965     return;
966 
967   // Normalize the distribution so that they can fit in unsigned.
968   BranchDist.normalize();
969 
970   // Create normalized branch weights and set the metadata.
971   for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
972     const auto &Weight = BranchDist.Weights[I];
973 
974     // Get the weight and update the current BFI.
975     BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
976     BranchProbability BP(Weight.Amount, BranchDist.Total);
977     BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
978   }
979   TI->setMetadata(
980       LLVMContext::MD_prof,
981       MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
982 }
983 
984 Function *CodeExtractor::extractCodeRegion() {
985   if (!isEligible())
986     return nullptr;
987 
988   ValueSet inputs, outputs, SinkingCands, HoistingCands;
989   BasicBlock *CommonExit = nullptr;
990 
991   // Assumption: this is a single-entry code region, and the header is the first
992   // block in the region.
993   BasicBlock *header = *Blocks.begin();
994 
995   // Calculate the entry frequency of the new function before we change the root
996   //   block.
997   BlockFrequency EntryFreq;
998   if (BFI) {
999     assert(BPI && "Both BPI and BFI are required to preserve profile info");
1000     for (BasicBlock *Pred : predecessors(header)) {
1001       if (Blocks.count(Pred))
1002         continue;
1003       EntryFreq +=
1004           BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
1005     }
1006   }
1007 
1008   // If we have to split PHI nodes or the entry block, do so now.
1009   severSplitPHINodes(header);
1010 
1011   // If we have any return instructions in the region, split those blocks so
1012   // that the return is not in the region.
1013   splitReturnBlocks();
1014 
1015   Function *oldFunction = header->getParent();
1016 
1017   // This takes place of the original loop
1018   BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
1019                                                 "codeRepl", oldFunction,
1020                                                 header);
1021 
1022   // The new function needs a root node because other nodes can branch to the
1023   // head of the region, but the entry node of a function cannot have preds.
1024   BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
1025                                                "newFuncRoot");
1026   newFuncRoot->getInstList().push_back(BranchInst::Create(header));
1027 
1028   findAllocas(SinkingCands, HoistingCands, CommonExit);
1029   assert(HoistingCands.empty() || CommonExit);
1030 
1031   // Find inputs to, outputs from the code region.
1032   findInputsOutputs(inputs, outputs, SinkingCands);
1033 
1034   // Now sink all instructions which only have non-phi uses inside the region
1035   for (auto *II : SinkingCands)
1036     cast<Instruction>(II)->moveBefore(*newFuncRoot,
1037                                       newFuncRoot->getFirstInsertionPt());
1038 
1039   if (!HoistingCands.empty()) {
1040     auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
1041     Instruction *TI = HoistToBlock->getTerminator();
1042     for (auto *II : HoistingCands)
1043       cast<Instruction>(II)->moveBefore(TI);
1044   }
1045 
1046   // Calculate the exit blocks for the extracted region and the total exit
1047   //  weights for each of those blocks.
1048   DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
1049   SmallPtrSet<BasicBlock *, 1> ExitBlocks;
1050   for (BasicBlock *Block : Blocks) {
1051     for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
1052          ++SI) {
1053       if (!Blocks.count(*SI)) {
1054         // Update the branch weight for this successor.
1055         if (BFI) {
1056           BlockFrequency &BF = ExitWeights[*SI];
1057           BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
1058         }
1059         ExitBlocks.insert(*SI);
1060       }
1061     }
1062   }
1063   NumExitBlocks = ExitBlocks.size();
1064 
1065   // Construct new function based on inputs/outputs & add allocas for all defs.
1066   Function *newFunction = constructFunction(inputs, outputs, header,
1067                                             newFuncRoot,
1068                                             codeReplacer, oldFunction,
1069                                             oldFunction->getParent());
1070 
1071   // Update the entry count of the function.
1072   if (BFI) {
1073     Optional<uint64_t> EntryCount =
1074         BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
1075     if (EntryCount.hasValue())
1076       newFunction->setEntryCount(EntryCount.getValue());
1077     BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
1078   }
1079 
1080   emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
1081 
1082   moveCodeToFunction(newFunction);
1083 
1084   // Update the branch weights for the exit block.
1085   if (BFI && NumExitBlocks > 1)
1086     calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
1087 
1088   // Loop over all of the PHI nodes in the header block, and change any
1089   // references to the old incoming edge to be the new incoming edge.
1090   for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
1091     PHINode *PN = cast<PHINode>(I);
1092     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
1093       if (!Blocks.count(PN->getIncomingBlock(i)))
1094         PN->setIncomingBlock(i, newFuncRoot);
1095   }
1096 
1097   // Look at all successors of the codeReplacer block.  If any of these blocks
1098   // had PHI nodes in them, we need to update the "from" block to be the code
1099   // replacer, not the original block in the extracted region.
1100   std::vector<BasicBlock*> Succs(succ_begin(codeReplacer),
1101                                  succ_end(codeReplacer));
1102   for (unsigned i = 0, e = Succs.size(); i != e; ++i)
1103     for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) {
1104       PHINode *PN = cast<PHINode>(I);
1105       std::set<BasicBlock*> ProcessedPreds;
1106       for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
1107         if (Blocks.count(PN->getIncomingBlock(i))) {
1108           if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second)
1109             PN->setIncomingBlock(i, codeReplacer);
1110           else {
1111             // There were multiple entries in the PHI for this block, now there
1112             // is only one, so remove the duplicated entries.
1113             PN->removeIncomingValue(i, false);
1114             --i; --e;
1115           }
1116         }
1117     }
1118 
1119   DEBUG(if (verifyFunction(*newFunction))
1120         report_fatal_error("verifyFunction failed!"));
1121   return newFunction;
1122 }
1123