1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the interface to tear out a code region, such as an
11 // individual loop or a parallel section, into a new function, replacing it with
12 // a call to the new function.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #include "llvm/Transforms/Utils/CodeExtractor.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/Optional.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Analysis/BlockFrequencyInfo.h"
25 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
26 #include "llvm/Analysis/BranchProbabilityInfo.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/IR/Argument.h"
29 #include "llvm/IR/Attributes.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/CFG.h"
32 #include "llvm/IR/Constant.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DataLayout.h"
35 #include "llvm/IR/DerivedTypes.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/GlobalValue.h"
39 #include "llvm/IR/InstrTypes.h"
40 #include "llvm/IR/Instruction.h"
41 #include "llvm/IR/Instructions.h"
42 #include "llvm/IR/IntrinsicInst.h"
43 #include "llvm/IR/Intrinsics.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/IR/MDBuilder.h"
46 #include "llvm/IR/Module.h"
47 #include "llvm/IR/Type.h"
48 #include "llvm/IR/User.h"
49 #include "llvm/IR/Value.h"
50 #include "llvm/IR/Verifier.h"
51 #include "llvm/Pass.h"
52 #include "llvm/Support/BlockFrequency.h"
53 #include "llvm/Support/BranchProbability.h"
54 #include "llvm/Support/Casting.h"
55 #include "llvm/Support/CommandLine.h"
56 #include "llvm/Support/Debug.h"
57 #include "llvm/Support/ErrorHandling.h"
58 #include "llvm/Support/raw_ostream.h"
59 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
60 #include "llvm/Transforms/Utils/Local.h"
61 #include <cassert>
62 #include <cstdint>
63 #include <iterator>
64 #include <map>
65 #include <set>
66 #include <utility>
67 #include <vector>
68
// All LLVM APIs used below live in the llvm namespace.
using namespace llvm;
// Short alias for Function::ProfileCount. It is not referenced in the portion
// of the file reviewed here -- presumably used further down (verify).
using ProfileCount = Function::ProfileCount;

// Debug-output category used by the LLVM_DEBUG statements in this file.
#define DEBUG_TYPE "code-extractor"
73
74 // Provide a command-line option to aggregate function arguments into a struct
75 // for functions produced by the code extractor. This is useful when converting
76 // extracted functions to pthread-based code, as only one argument (void*) can
77 // be passed in to pthread_create().
78 static cl::opt<bool>
79 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
80 cl::desc("Aggregate arguments to code-extracted functions"));
81
82 /// Test whether a block is valid for extraction.
isBlockValidForExtraction(const BasicBlock & BB,const SetVector<BasicBlock * > & Result,bool AllowVarArgs,bool AllowAlloca)83 static bool isBlockValidForExtraction(const BasicBlock &BB,
84 const SetVector<BasicBlock *> &Result,
85 bool AllowVarArgs, bool AllowAlloca) {
86 // taking the address of a basic block moved to another function is illegal
87 if (BB.hasAddressTaken())
88 return false;
89
90 // don't hoist code that uses another basicblock address, as it's likely to
91 // lead to unexpected behavior, like cross-function jumps
92 SmallPtrSet<User const *, 16> Visited;
93 SmallVector<User const *, 16> ToVisit;
94
95 for (Instruction const &Inst : BB)
96 ToVisit.push_back(&Inst);
97
98 while (!ToVisit.empty()) {
99 User const *Curr = ToVisit.pop_back_val();
100 if (!Visited.insert(Curr).second)
101 continue;
102 if (isa<BlockAddress const>(Curr))
103 return false; // even a reference to self is likely to be not compatible
104
105 if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
106 continue;
107
108 for (auto const &U : Curr->operands()) {
109 if (auto *UU = dyn_cast<User>(U))
110 ToVisit.push_back(UU);
111 }
112 }
113
114 // If explicitly requested, allow vastart and alloca. For invoke instructions
115 // verify that extraction is valid.
116 for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
117 if (isa<AllocaInst>(I)) {
118 if (!AllowAlloca)
119 return false;
120 continue;
121 }
122
123 if (const auto *II = dyn_cast<InvokeInst>(I)) {
124 // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
125 // must be a part of the subgraph which is being extracted.
126 if (auto *UBB = II->getUnwindDest())
127 if (!Result.count(UBB))
128 return false;
129 continue;
130 }
131
132 // All catch handlers of a catchswitch instruction as well as the unwind
133 // destination must be in the subgraph.
134 if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
135 if (auto *UBB = CSI->getUnwindDest())
136 if (!Result.count(UBB))
137 return false;
138 for (auto *HBB : CSI->handlers())
139 if (!Result.count(const_cast<BasicBlock*>(HBB)))
140 return false;
141 continue;
142 }
143
144 // Make sure that entire catch handler is within subgraph. It is sufficient
145 // to check that catch return's block is in the list.
146 if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
147 for (const auto *U : CPI->users())
148 if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
149 if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
150 return false;
151 continue;
152 }
153
154 // And do similar checks for cleanup handler - the entire handler must be
155 // in subgraph which is going to be extracted. For cleanup return should
156 // additionally check that the unwind destination is also in the subgraph.
157 if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
158 for (const auto *U : CPI->users())
159 if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
160 if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
161 return false;
162 continue;
163 }
164 if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
165 if (auto *UBB = CRI->getUnwindDest())
166 if (!Result.count(UBB))
167 return false;
168 continue;
169 }
170
171 if (const CallInst *CI = dyn_cast<CallInst>(I)) {
172 if (const Function *F = CI->getCalledFunction()) {
173 auto IID = F->getIntrinsicID();
174 if (IID == Intrinsic::vastart) {
175 if (AllowVarArgs)
176 continue;
177 else
178 return false;
179 }
180
181 // Currently, we miscompile outlined copies of eh_typid_for. There are
182 // proposals for fixing this in llvm.org/PR39545.
183 if (IID == Intrinsic::eh_typeid_for)
184 return false;
185 }
186 }
187 }
188
189 return true;
190 }
191
192 /// Build a set of blocks to extract if the input blocks are viable.
193 static SetVector<BasicBlock *>
buildExtractionBlockSet(ArrayRef<BasicBlock * > BBs,DominatorTree * DT,bool AllowVarArgs,bool AllowAlloca)194 buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
195 bool AllowVarArgs, bool AllowAlloca) {
196 assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
197 SetVector<BasicBlock *> Result;
198
199 // Loop over the blocks, adding them to our set-vector, and aborting with an
200 // empty set if we encounter invalid blocks.
201 for (BasicBlock *BB : BBs) {
202 // If this block is dead, don't process it.
203 if (DT && !DT->isReachableFromEntry(BB))
204 continue;
205
206 if (!Result.insert(BB))
207 llvm_unreachable("Repeated basic blocks in extraction input");
208 }
209
210 for (auto *BB : Result) {
211 if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
212 return {};
213
214 // Make sure that the first block is not a landing pad.
215 if (BB == Result.front()) {
216 if (BB->isEHPad()) {
217 LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
218 return {};
219 }
220 continue;
221 }
222
223 // All blocks other than the first must not have predecessors outside of
224 // the subgraph which is being extracted.
225 for (auto *PBB : predecessors(BB))
226 if (!Result.count(PBB)) {
227 LLVM_DEBUG(
228 dbgs() << "No blocks in this region may have entries from "
229 "outside the region except for the first block!\n");
230 return {};
231 }
232 }
233
234 return Result;
235 }
236
CodeExtractor(ArrayRef<BasicBlock * > BBs,DominatorTree * DT,bool AggregateArgs,BlockFrequencyInfo * BFI,BranchProbabilityInfo * BPI,bool AllowVarArgs,bool AllowAlloca,std::string Suffix)237 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
238 bool AggregateArgs, BlockFrequencyInfo *BFI,
239 BranchProbabilityInfo *BPI, bool AllowVarArgs,
240 bool AllowAlloca, std::string Suffix)
241 : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
242 BPI(BPI), AllowVarArgs(AllowVarArgs),
243 Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
244 Suffix(Suffix) {}
245
CodeExtractor(DominatorTree & DT,Loop & L,bool AggregateArgs,BlockFrequencyInfo * BFI,BranchProbabilityInfo * BPI,std::string Suffix)246 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
247 BlockFrequencyInfo *BFI,
248 BranchProbabilityInfo *BPI, std::string Suffix)
249 : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
250 BPI(BPI), AllowVarArgs(false),
251 Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
252 /* AllowVarArgs */ false,
253 /* AllowAlloca */ false)),
254 Suffix(Suffix) {}
255
256 /// definedInRegion - Return true if the specified value is defined in the
257 /// extracted region.
definedInRegion(const SetVector<BasicBlock * > & Blocks,Value * V)258 static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
259 if (Instruction *I = dyn_cast<Instruction>(V))
260 if (Blocks.count(I->getParent()))
261 return true;
262 return false;
263 }
264
265 /// definedInCaller - Return true if the specified value is defined in the
266 /// function being code extracted, but not in the region being extracted.
267 /// These values must be passed in as live-ins to the function.
definedInCaller(const SetVector<BasicBlock * > & Blocks,Value * V)268 static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
269 if (isa<Argument>(V)) return true;
270 if (Instruction *I = dyn_cast<Instruction>(V))
271 if (!Blocks.count(I->getParent()))
272 return true;
273 return false;
274 }
275
getCommonExitBlock(const SetVector<BasicBlock * > & Blocks)276 static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
277 BasicBlock *CommonExitBlock = nullptr;
278 auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
279 for (auto *Succ : successors(Block)) {
280 // Internal edges, ok.
281 if (Blocks.count(Succ))
282 continue;
283 if (!CommonExitBlock) {
284 CommonExitBlock = Succ;
285 continue;
286 }
287 if (CommonExitBlock == Succ)
288 continue;
289
290 return true;
291 }
292 return false;
293 };
294
295 if (any_of(Blocks, hasNonCommonExitSucc))
296 return nullptr;
297
298 return CommonExitBlock;
299 }
300
isLegalToShrinkwrapLifetimeMarkers(Instruction * Addr) const301 bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
302 Instruction *Addr) const {
303 AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
304 Function *Func = (*Blocks.begin())->getParent();
305 for (BasicBlock &BB : *Func) {
306 if (Blocks.count(&BB))
307 continue;
308 for (Instruction &II : BB) {
309 if (isa<DbgInfoIntrinsic>(II))
310 continue;
311
312 unsigned Opcode = II.getOpcode();
313 Value *MemAddr = nullptr;
314 switch (Opcode) {
315 case Instruction::Store:
316 case Instruction::Load: {
317 if (Opcode == Instruction::Store) {
318 StoreInst *SI = cast<StoreInst>(&II);
319 MemAddr = SI->getPointerOperand();
320 } else {
321 LoadInst *LI = cast<LoadInst>(&II);
322 MemAddr = LI->getPointerOperand();
323 }
324 // Global variable can not be aliased with locals.
325 if (dyn_cast<Constant>(MemAddr))
326 break;
327 Value *Base = MemAddr->stripInBoundsConstantOffsets();
328 if (!dyn_cast<AllocaInst>(Base) || Base == AI)
329 return false;
330 break;
331 }
332 default: {
333 IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
334 if (IntrInst) {
335 if (IntrInst->isLifetimeStartOrEnd())
336 break;
337 return false;
338 }
339 // Treat all the other cases conservatively if it has side effects.
340 if (II.mayHaveSideEffects())
341 return false;
342 }
343 }
344 }
345 }
346
347 return true;
348 }
349
350 BasicBlock *
findOrCreateBlockForHoisting(BasicBlock * CommonExitBlock)351 CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
352 BasicBlock *SinglePredFromOutlineRegion = nullptr;
353 assert(!Blocks.count(CommonExitBlock) &&
354 "Expect a block outside the region!");
355 for (auto *Pred : predecessors(CommonExitBlock)) {
356 if (!Blocks.count(Pred))
357 continue;
358 if (!SinglePredFromOutlineRegion) {
359 SinglePredFromOutlineRegion = Pred;
360 } else if (SinglePredFromOutlineRegion != Pred) {
361 SinglePredFromOutlineRegion = nullptr;
362 break;
363 }
364 }
365
366 if (SinglePredFromOutlineRegion)
367 return SinglePredFromOutlineRegion;
368
369 #ifndef NDEBUG
370 auto getFirstPHI = [](BasicBlock *BB) {
371 BasicBlock::iterator I = BB->begin();
372 PHINode *FirstPhi = nullptr;
373 while (I != BB->end()) {
374 PHINode *Phi = dyn_cast<PHINode>(I);
375 if (!Phi)
376 break;
377 if (!FirstPhi) {
378 FirstPhi = Phi;
379 break;
380 }
381 }
382 return FirstPhi;
383 };
384 // If there are any phi nodes, the single pred either exists or has already
385 // be created before code extraction.
386 assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
387 #endif
388
389 BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
390 CommonExitBlock->getFirstNonPHI()->getIterator());
391
392 for (auto PI = pred_begin(CommonExitBlock), PE = pred_end(CommonExitBlock);
393 PI != PE;) {
394 BasicBlock *Pred = *PI++;
395 if (Blocks.count(Pred))
396 continue;
397 Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
398 }
399 // Now add the old exit block to the outline region.
400 Blocks.insert(CommonExitBlock);
401 return CommonExitBlock;
402 }
403
findAllocas(ValueSet & SinkCands,ValueSet & HoistCands,BasicBlock * & ExitBlock) const404 void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
405 BasicBlock *&ExitBlock) const {
406 Function *Func = (*Blocks.begin())->getParent();
407 ExitBlock = getCommonExitBlock(Blocks);
408
409 for (BasicBlock &BB : *Func) {
410 if (Blocks.count(&BB))
411 continue;
412 for (Instruction &II : BB) {
413 auto *AI = dyn_cast<AllocaInst>(&II);
414 if (!AI)
415 continue;
416
417 // Find the pair of life time markers for address 'Addr' that are either
418 // defined inside the outline region or can legally be shrinkwrapped into
419 // the outline region. If there are not other untracked uses of the
420 // address, return the pair of markers if found; otherwise return a pair
421 // of nullptr.
422 auto GetLifeTimeMarkers =
423 [&](Instruction *Addr, bool &SinkLifeStart,
424 bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
425 Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
426
427 for (User *U : Addr->users()) {
428 IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
429 if (IntrInst) {
430 if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
431 // Do not handle the case where AI has multiple start markers.
432 if (LifeStart)
433 return std::make_pair<Instruction *>(nullptr, nullptr);
434 LifeStart = IntrInst;
435 }
436 if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
437 if (LifeEnd)
438 return std::make_pair<Instruction *>(nullptr, nullptr);
439 LifeEnd = IntrInst;
440 }
441 continue;
442 }
443 // Find untracked uses of the address, bail.
444 if (!definedInRegion(Blocks, U))
445 return std::make_pair<Instruction *>(nullptr, nullptr);
446 }
447
448 if (!LifeStart || !LifeEnd)
449 return std::make_pair<Instruction *>(nullptr, nullptr);
450
451 SinkLifeStart = !definedInRegion(Blocks, LifeStart);
452 HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
453 // Do legality Check.
454 if ((SinkLifeStart || HoistLifeEnd) &&
455 !isLegalToShrinkwrapLifetimeMarkers(Addr))
456 return std::make_pair<Instruction *>(nullptr, nullptr);
457
458 // Check to see if we have a place to do hoisting, if not, bail.
459 if (HoistLifeEnd && !ExitBlock)
460 return std::make_pair<Instruction *>(nullptr, nullptr);
461
462 return std::make_pair(LifeStart, LifeEnd);
463 };
464
465 bool SinkLifeStart = false, HoistLifeEnd = false;
466 auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
467
468 if (Markers.first) {
469 if (SinkLifeStart)
470 SinkCands.insert(Markers.first);
471 SinkCands.insert(AI);
472 if (HoistLifeEnd)
473 HoistCands.insert(Markers.second);
474 continue;
475 }
476
477 // Follow the bitcast.
478 Instruction *MarkerAddr = nullptr;
479 for (User *U : AI->users()) {
480 if (U->stripInBoundsConstantOffsets() == AI) {
481 SinkLifeStart = false;
482 HoistLifeEnd = false;
483 Instruction *Bitcast = cast<Instruction>(U);
484 Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
485 if (Markers.first) {
486 MarkerAddr = Bitcast;
487 continue;
488 }
489 }
490
491 // Found unknown use of AI.
492 if (!definedInRegion(Blocks, U)) {
493 MarkerAddr = nullptr;
494 break;
495 }
496 }
497
498 if (MarkerAddr) {
499 if (SinkLifeStart)
500 SinkCands.insert(Markers.first);
501 if (!definedInRegion(Blocks, MarkerAddr))
502 SinkCands.insert(MarkerAddr);
503 SinkCands.insert(AI);
504 if (HoistLifeEnd)
505 HoistCands.insert(Markers.second);
506 }
507 }
508 }
509 }
510
findInputsOutputs(ValueSet & Inputs,ValueSet & Outputs,const ValueSet & SinkCands) const511 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
512 const ValueSet &SinkCands) const {
513 for (BasicBlock *BB : Blocks) {
514 // If a used value is defined outside the region, it's an input. If an
515 // instruction is used outside the region, it's an output.
516 for (Instruction &II : *BB) {
517 for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
518 ++OI) {
519 Value *V = *OI;
520 if (!SinkCands.count(V) && definedInCaller(Blocks, V))
521 Inputs.insert(V);
522 }
523
524 for (User *U : II.users())
525 if (!definedInRegion(Blocks, U)) {
526 Outputs.insert(&II);
527 break;
528 }
529 }
530 }
531 }
532
/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
/// of the region, we need to split the entry block of the region so that the
/// PHI node is easier to deal with.
///
/// When \p Header is the function's entry block, the split is performed
/// unconditionally and the PHI counting below is skipped. Otherwise the split
/// happens only if the first PHI of \p Header has more than one incoming entry
/// from outside the region. After the split, the PHI-only upper half
/// (OldPred) is removed from Blocks, the lower half (NewBB) is added, and
/// \p Header is updated to NewBB.
void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
  unsigned NumPredsFromRegion = 0;
  unsigned NumPredsOutsideRegion = 0;

  if (Header != &Header->getParent()->getEntryBlock()) {
    PHINode *PN = dyn_cast<PHINode>(Header->begin());
    if (!PN) return; // No PHI nodes.

    // If the header node contains any PHI nodes, check to see if there is more
    // than one entry from outside the region. If so, we need to sever the
    // header block into two.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(PN->getIncomingBlock(i)))
        ++NumPredsFromRegion;
      else
        ++NumPredsOutsideRegion;

    // If there is one (or fewer) predecessor from outside the region, we don't
    // need to do anything special.
    if (NumPredsOutsideRegion <= 1) return;
  }

  // Otherwise, we need to split the header block into two pieces: one
  // containing PHI nodes merging values from outside of the region, and a
  // second that contains all of the code for the block and merges back any
  // incoming values from inside of the region.
  // The dominator tree (if any) is passed along so SplitBlock can update it.
  BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHI(), DT);

  // We only want to code extract the second block now, and it becomes the new
  // header of the region.
  BasicBlock *OldPred = Header;
  Blocks.remove(OldPred);
  Blocks.insert(NewBB);
  Header = NewBB;

  // Okay, now we need to adjust the PHI nodes and any branches from within the
  // region to go to the new header block instead of the old header block.
  // NumPredsFromRegion is always zero on the entry-block path, so this is
  // skipped there.
  if (NumPredsFromRegion) {
    PHINode *PN = cast<PHINode>(OldPred->begin());
    // Loop over all of the predecessors of OldPred that are in the region,
    // changing them to branch to NewBB instead.
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (Blocks.count(PN->getIncomingBlock(i))) {
        Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
        TI->replaceUsesOfWith(OldPred, NewBB);
      }

    // Okay, everything within the region is now branching to the right block, we
    // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
    BasicBlock::iterator AfterPHIs;
    for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
      PHINode *PN = cast<PHINode>(AfterPHIs);
      // Create a new PHI node in the new region, which has an incoming value
      // from OldPred of PN.
      PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
                                       PN->getName() + ".ce", &NewBB->front());
      // RAUW happens before NewPN takes PN as an operand, so NewPN's own
      // incoming value from OldPred is not rewritten to itself.
      PN->replaceAllUsesWith(NewPN);
      NewPN->addIncoming(PN, OldPred);

      // Loop over all of the incoming value in PN, moving them to NewPN if they
      // are from the extracted region.
      for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
        if (Blocks.count(PN->getIncomingBlock(i))) {
          NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
          PN->removeIncomingValue(i);
          // Removal shifts later entries down by one; revisit this index.
          --i;
        }
      }
    }
  }
}
607
608 /// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
609 /// outlined region, we split these PHIs on two: one with inputs from region
610 /// and other with remaining incoming blocks; then first PHIs are placed in
611 /// outlined region.
severSplitPHINodesOfExits(const SmallPtrSetImpl<BasicBlock * > & Exits)612 void CodeExtractor::severSplitPHINodesOfExits(
613 const SmallPtrSetImpl<BasicBlock *> &Exits) {
614 for (BasicBlock *ExitBB : Exits) {
615 BasicBlock *NewBB = nullptr;
616
617 for (PHINode &PN : ExitBB->phis()) {
618 // Find all incoming values from the outlining region.
619 SmallVector<unsigned, 2> IncomingVals;
620 for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
621 if (Blocks.count(PN.getIncomingBlock(i)))
622 IncomingVals.push_back(i);
623
624 // Do not process PHI if there is one (or fewer) predecessor from region.
625 // If PHI has exactly one predecessor from region, only this one incoming
626 // will be replaced on codeRepl block, so it should be safe to skip PHI.
627 if (IncomingVals.size() <= 1)
628 continue;
629
630 // Create block for new PHIs and add it to the list of outlined if it
631 // wasn't done before.
632 if (!NewBB) {
633 NewBB = BasicBlock::Create(ExitBB->getContext(),
634 ExitBB->getName() + ".split",
635 ExitBB->getParent(), ExitBB);
636 SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBB),
637 pred_end(ExitBB));
638 for (BasicBlock *PredBB : Preds)
639 if (Blocks.count(PredBB))
640 PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
641 BranchInst::Create(ExitBB, NewBB);
642 Blocks.insert(NewBB);
643 }
644
645 // Split this PHI.
646 PHINode *NewPN =
647 PHINode::Create(PN.getType(), IncomingVals.size(),
648 PN.getName() + ".ce", NewBB->getFirstNonPHI());
649 for (unsigned i : IncomingVals)
650 NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
651 for (unsigned i : reverse(IncomingVals))
652 PN.removeIncomingValue(i, false);
653 PN.addIncoming(NewPN, NewBB);
654 }
655 }
656 }
657
splitReturnBlocks()658 void CodeExtractor::splitReturnBlocks() {
659 for (BasicBlock *Block : Blocks)
660 if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
661 BasicBlock *New =
662 Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
663 if (DT) {
664 // Old dominates New. New node dominates all other nodes dominated
665 // by Old.
666 DomTreeNode *OldNode = DT->getNode(Block);
667 SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
668 OldNode->end());
669
670 DomTreeNode *NewNode = DT->addNewBlock(New, Block);
671
672 for (DomTreeNode *I : Children)
673 DT->changeImmediateDominator(I, NewNode);
674 }
675 }
676 }
677
678 /// constructFunction - make a function based on inputs and outputs, as follows:
679 /// f(in0, ..., inN, out0, ..., outN)
constructFunction(const ValueSet & inputs,const ValueSet & outputs,BasicBlock * header,BasicBlock * newRootNode,BasicBlock * newHeader,Function * oldFunction,Module * M)680 Function *CodeExtractor::constructFunction(const ValueSet &inputs,
681 const ValueSet &outputs,
682 BasicBlock *header,
683 BasicBlock *newRootNode,
684 BasicBlock *newHeader,
685 Function *oldFunction,
686 Module *M) {
687 LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
688 LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
689
690 // This function returns unsigned, outputs will go back by reference.
691 switch (NumExitBlocks) {
692 case 0:
693 case 1: RetTy = Type::getVoidTy(header->getContext()); break;
694 case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
695 default: RetTy = Type::getInt16Ty(header->getContext()); break;
696 }
697
698 std::vector<Type *> paramTy;
699
700 // Add the types of the input values to the function's argument list
701 for (Value *value : inputs) {
702 LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
703 paramTy.push_back(value->getType());
704 }
705
706 // Add the types of the output values to the function's argument list.
707 for (Value *output : outputs) {
708 LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
709 if (AggregateArgs)
710 paramTy.push_back(output->getType());
711 else
712 paramTy.push_back(PointerType::getUnqual(output->getType()));
713 }
714
715 LLVM_DEBUG({
716 dbgs() << "Function type: " << *RetTy << " f(";
717 for (Type *i : paramTy)
718 dbgs() << *i << ", ";
719 dbgs() << ")\n";
720 });
721
722 StructType *StructTy;
723 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
724 StructTy = StructType::get(M->getContext(), paramTy);
725 paramTy.clear();
726 paramTy.push_back(PointerType::getUnqual(StructTy));
727 }
728 FunctionType *funcType =
729 FunctionType::get(RetTy, paramTy,
730 AllowVarArgs && oldFunction->isVarArg());
731
732 std::string SuffixToUse =
733 Suffix.empty()
734 ? (header->getName().empty() ? "extracted" : header->getName().str())
735 : Suffix;
736 // Create the new function
737 Function *newFunction = Function::Create(
738 funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
739 oldFunction->getName() + "." + SuffixToUse, M);
740 // If the old function is no-throw, so is the new one.
741 if (oldFunction->doesNotThrow())
742 newFunction->setDoesNotThrow();
743
744 // Inherit the uwtable attribute if we need to.
745 if (oldFunction->hasUWTable())
746 newFunction->setHasUWTable();
747
748 // Inherit all of the target dependent attributes and white-listed
749 // target independent attributes.
750 // (e.g. If the extracted region contains a call to an x86.sse
751 // instruction we need to make sure that the extracted region has the
752 // "target-features" attribute allowing it to be lowered.
753 // FIXME: This should be changed to check to see if a specific
754 // attribute can not be inherited.
755 for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) {
756 if (Attr.isStringAttribute()) {
757 if (Attr.getKindAsString() == "thunk")
758 continue;
759 } else
760 switch (Attr.getKindAsEnum()) {
761 // Those attributes cannot be propagated safely. Explicitly list them
762 // here so we get a warning if new attributes are added. This list also
763 // includes non-function attributes.
764 case Attribute::Alignment:
765 case Attribute::AllocSize:
766 case Attribute::ArgMemOnly:
767 case Attribute::Builtin:
768 case Attribute::ByVal:
769 case Attribute::Convergent:
770 case Attribute::Dereferenceable:
771 case Attribute::DereferenceableOrNull:
772 case Attribute::InAlloca:
773 case Attribute::InReg:
774 case Attribute::InaccessibleMemOnly:
775 case Attribute::InaccessibleMemOrArgMemOnly:
776 case Attribute::JumpTable:
777 case Attribute::Naked:
778 case Attribute::Nest:
779 case Attribute::NoAlias:
780 case Attribute::NoBuiltin:
781 case Attribute::NoCapture:
782 case Attribute::NoReturn:
783 case Attribute::None:
784 case Attribute::NonNull:
785 case Attribute::ReadNone:
786 case Attribute::ReadOnly:
787 case Attribute::Returned:
788 case Attribute::ReturnsTwice:
789 case Attribute::SExt:
790 case Attribute::Speculatable:
791 case Attribute::StackAlignment:
792 case Attribute::StructRet:
793 case Attribute::SwiftError:
794 case Attribute::SwiftSelf:
795 case Attribute::WriteOnly:
796 case Attribute::ZExt:
797 case Attribute::EndAttrKinds:
798 continue;
799 // Those attributes should be safe to propagate to the extracted function.
800 case Attribute::AlwaysInline:
801 case Attribute::Cold:
802 case Attribute::NoRecurse:
803 case Attribute::InlineHint:
804 case Attribute::MinSize:
805 case Attribute::NoDuplicate:
806 case Attribute::NoImplicitFloat:
807 case Attribute::NoInline:
808 case Attribute::NonLazyBind:
809 case Attribute::NoRedZone:
810 case Attribute::NoUnwind:
811 case Attribute::OptForFuzzing:
812 case Attribute::OptimizeNone:
813 case Attribute::OptimizeForSize:
814 case Attribute::SafeStack:
815 case Attribute::ShadowCallStack:
816 case Attribute::SanitizeAddress:
817 case Attribute::SanitizeMemory:
818 case Attribute::SanitizeThread:
819 case Attribute::SanitizeHWAddress:
820 case Attribute::SpeculativeLoadHardening:
821 case Attribute::StackProtect:
822 case Attribute::StackProtectReq:
823 case Attribute::StackProtectStrong:
824 case Attribute::StrictFP:
825 case Attribute::UWTable:
826 case Attribute::NoCfCheck:
827 break;
828 }
829
830 newFunction->addFnAttr(Attr);
831 }
832 newFunction->getBasicBlockList().push_back(newRootNode);
833
834 // Create an iterator to name all of the arguments we inserted.
835 Function::arg_iterator AI = newFunction->arg_begin();
836
837 // Rewrite all users of the inputs in the extracted region to use the
838 // arguments (or appropriate addressing into struct) instead.
839 for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
840 Value *RewriteVal;
841 if (AggregateArgs) {
842 Value *Idx[2];
843 Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
844 Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
845 Instruction *TI = newFunction->begin()->getTerminator();
846 GetElementPtrInst *GEP = GetElementPtrInst::Create(
847 StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
848 RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI);
849 } else
850 RewriteVal = &*AI++;
851
852 std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
853 for (User *use : Users)
854 if (Instruction *inst = dyn_cast<Instruction>(use))
855 if (Blocks.count(inst->getParent()))
856 inst->replaceUsesOfWith(inputs[i], RewriteVal);
857 }
858
859 // Set names for input and output arguments.
860 if (!AggregateArgs) {
861 AI = newFunction->arg_begin();
862 for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
863 AI->setName(inputs[i]->getName());
864 for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
865 AI->setName(outputs[i]->getName()+".out");
866 }
867
868 // Rewrite branches to basic blocks outside of the loop to new dummy blocks
869 // within the new function. This must be done before we lose track of which
870 // blocks were originally in the code region.
871 std::vector<User *> Users(header->user_begin(), header->user_end());
872 for (unsigned i = 0, e = Users.size(); i != e; ++i)
873 // The BasicBlock which contains the branch is not in the region
874 // modify the branch target to a new block
875 if (Instruction *I = dyn_cast<Instruction>(Users[i]))
876 if (I->isTerminator() && !Blocks.count(I->getParent()) &&
877 I->getParent()->getParent() == oldFunction)
878 I->replaceUsesOfWith(header, newHeader);
879
880 return newFunction;
881 }
882
883 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
884 /// the call instruction, splitting any PHI nodes in the header block as
885 /// necessary.
emitCallAndSwitchStatement(Function * newFunction,BasicBlock * codeReplacer,ValueSet & inputs,ValueSet & outputs)886 CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
887 BasicBlock *codeReplacer,
888 ValueSet &inputs,
889 ValueSet &outputs) {
890 // Emit a call to the new function, passing in: *pointer to struct (if
891 // aggregating parameters), or plan inputs and allocated memory for outputs
892 std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
893
894 Module *M = newFunction->getParent();
895 LLVMContext &Context = M->getContext();
896 const DataLayout &DL = M->getDataLayout();
897 CallInst *call = nullptr;
898
899 // Add inputs as params, or to be filled into the struct
900 for (Value *input : inputs)
901 if (AggregateArgs)
902 StructValues.push_back(input);
903 else
904 params.push_back(input);
905
906 // Create allocas for the outputs
907 for (Value *output : outputs) {
908 if (AggregateArgs) {
909 StructValues.push_back(output);
910 } else {
911 AllocaInst *alloca =
912 new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
913 nullptr, output->getName() + ".loc",
914 &codeReplacer->getParent()->front().front());
915 ReloadOutputs.push_back(alloca);
916 params.push_back(alloca);
917 }
918 }
919
920 StructType *StructArgTy = nullptr;
921 AllocaInst *Struct = nullptr;
922 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
923 std::vector<Type *> ArgTypes;
924 for (ValueSet::iterator v = StructValues.begin(),
925 ve = StructValues.end(); v != ve; ++v)
926 ArgTypes.push_back((*v)->getType());
927
928 // Allocate a struct at the beginning of this function
929 StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
930 Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
931 "structArg",
932 &codeReplacer->getParent()->front().front());
933 params.push_back(Struct);
934
935 for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
936 Value *Idx[2];
937 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
938 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
939 GetElementPtrInst *GEP = GetElementPtrInst::Create(
940 StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
941 codeReplacer->getInstList().push_back(GEP);
942 StoreInst *SI = new StoreInst(StructValues[i], GEP);
943 codeReplacer->getInstList().push_back(SI);
944 }
945 }
946
947 // Emit the call to the function
948 call = CallInst::Create(newFunction, params,
949 NumExitBlocks > 1 ? "targetBlock" : "");
950 // Add debug location to the new call, if the original function has debug
951 // info. In that case, the terminator of the entry block of the extracted
952 // function contains the first debug location of the extracted function,
953 // set in extractCodeRegion.
954 if (codeReplacer->getParent()->getSubprogram()) {
955 if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
956 call->setDebugLoc(DL);
957 }
958 codeReplacer->getInstList().push_back(call);
959
960 Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
961 unsigned FirstOut = inputs.size();
962 if (!AggregateArgs)
963 std::advance(OutputArgBegin, inputs.size());
964
965 // Reload the outputs passed in by reference.
966 Function::arg_iterator OAI = OutputArgBegin;
967 for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
968 Value *Output = nullptr;
969 if (AggregateArgs) {
970 Value *Idx[2];
971 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
972 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
973 GetElementPtrInst *GEP = GetElementPtrInst::Create(
974 StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
975 codeReplacer->getInstList().push_back(GEP);
976 Output = GEP;
977 } else {
978 Output = ReloadOutputs[i];
979 }
980 LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
981 Reloads.push_back(load);
982 codeReplacer->getInstList().push_back(load);
983 std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
984 for (unsigned u = 0, e = Users.size(); u != e; ++u) {
985 Instruction *inst = cast<Instruction>(Users[u]);
986 if (!Blocks.count(inst->getParent()))
987 inst->replaceUsesOfWith(outputs[i], load);
988 }
989
990 // Store to argument right after the definition of output value.
991 auto *OutI = dyn_cast<Instruction>(outputs[i]);
992 if (!OutI)
993 continue;
994
995 // Find proper insertion point.
996 BasicBlock::iterator InsertPt;
997 // In case OutI is an invoke, we insert the store at the beginning in the
998 // 'normal destination' BB. Otherwise we insert the store right after OutI.
999 if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
1000 InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
1001 else if (auto *Phi = dyn_cast<PHINode>(OutI))
1002 InsertPt = Phi->getParent()->getFirstInsertionPt();
1003 else
1004 InsertPt = std::next(OutI->getIterator());
1005
1006 assert(OAI != newFunction->arg_end() &&
1007 "Number of output arguments should match "
1008 "the amount of defined values");
1009 if (AggregateArgs) {
1010 Value *Idx[2];
1011 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
1012 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
1013 GetElementPtrInst *GEP = GetElementPtrInst::Create(
1014 StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt);
1015 new StoreInst(outputs[i], GEP, &*InsertPt);
1016 // Since there should be only one struct argument aggregating
1017 // all the output values, we shouldn't increment OAI, which always
1018 // points to the struct argument, in this case.
1019 } else {
1020 new StoreInst(outputs[i], &*OAI, &*InsertPt);
1021 ++OAI;
1022 }
1023 }
1024
1025 // Now we can emit a switch statement using the call as a value.
1026 SwitchInst *TheSwitch =
1027 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
1028 codeReplacer, 0, codeReplacer);
1029
1030 // Since there may be multiple exits from the original region, make the new
1031 // function return an unsigned, switch on that number. This loop iterates
1032 // over all of the blocks in the extracted region, updating any terminator
1033 // instructions in the to-be-extracted region that branch to blocks that are
1034 // not in the region to be extracted.
1035 std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
1036
1037 unsigned switchVal = 0;
1038 for (BasicBlock *Block : Blocks) {
1039 Instruction *TI = Block->getTerminator();
1040 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
1041 if (!Blocks.count(TI->getSuccessor(i))) {
1042 BasicBlock *OldTarget = TI->getSuccessor(i);
1043 // add a new basic block which returns the appropriate value
1044 BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
1045 if (!NewTarget) {
1046 // If we don't already have an exit stub for this non-extracted
1047 // destination, create one now!
1048 NewTarget = BasicBlock::Create(Context,
1049 OldTarget->getName() + ".exitStub",
1050 newFunction);
1051 unsigned SuccNum = switchVal++;
1052
1053 Value *brVal = nullptr;
1054 switch (NumExitBlocks) {
1055 case 0:
1056 case 1: break; // No value needed.
1057 case 2: // Conditional branch, return a bool
1058 brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
1059 break;
1060 default:
1061 brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
1062 break;
1063 }
1064
1065 ReturnInst::Create(Context, brVal, NewTarget);
1066
1067 // Update the switch instruction.
1068 TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
1069 SuccNum),
1070 OldTarget);
1071 }
1072
1073 // rewrite the original branch instruction with this new target
1074 TI->setSuccessor(i, NewTarget);
1075 }
1076 }
1077
1078 // Now that we've done the deed, simplify the switch instruction.
1079 Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
1080 switch (NumExitBlocks) {
1081 case 0:
1082 // There are no successors (the block containing the switch itself), which
1083 // means that previously this was the last part of the function, and hence
1084 // this should be rewritten as a `ret'
1085
1086 // Check if the function should return a value
1087 if (OldFnRetTy->isVoidTy()) {
1088 ReturnInst::Create(Context, nullptr, TheSwitch); // Return void
1089 } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
1090 // return what we have
1091 ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
1092 } else {
1093 // Otherwise we must have code extracted an unwind or something, just
1094 // return whatever we want.
1095 ReturnInst::Create(Context,
1096 Constant::getNullValue(OldFnRetTy), TheSwitch);
1097 }
1098
1099 TheSwitch->eraseFromParent();
1100 break;
1101 case 1:
1102 // Only a single destination, change the switch into an unconditional
1103 // branch.
1104 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
1105 TheSwitch->eraseFromParent();
1106 break;
1107 case 2:
1108 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
1109 call, TheSwitch);
1110 TheSwitch->eraseFromParent();
1111 break;
1112 default:
1113 // Otherwise, make the default destination of the switch instruction be one
1114 // of the other successors.
1115 TheSwitch->setCondition(call);
1116 TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
1117 // Remove redundant case
1118 TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
1119 break;
1120 }
1121
1122 return call;
1123 }
1124
moveCodeToFunction(Function * newFunction)1125 void CodeExtractor::moveCodeToFunction(Function *newFunction) {
1126 Function *oldFunc = (*Blocks.begin())->getParent();
1127 Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
1128 Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
1129
1130 for (BasicBlock *Block : Blocks) {
1131 // Delete the basic block from the old function, and the list of blocks
1132 oldBlocks.remove(Block);
1133
1134 // Insert this basic block into the new function
1135 newBlocks.push_back(Block);
1136 }
1137 }
1138
calculateNewCallTerminatorWeights(BasicBlock * CodeReplacer,DenseMap<BasicBlock *,BlockFrequency> & ExitWeights,BranchProbabilityInfo * BPI)1139 void CodeExtractor::calculateNewCallTerminatorWeights(
1140 BasicBlock *CodeReplacer,
1141 DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
1142 BranchProbabilityInfo *BPI) {
1143 using Distribution = BlockFrequencyInfoImplBase::Distribution;
1144 using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
1145
1146 // Update the branch weights for the exit block.
1147 Instruction *TI = CodeReplacer->getTerminator();
1148 SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
1149
1150 // Block Frequency distribution with dummy node.
1151 Distribution BranchDist;
1152
1153 // Add each of the frequencies of the successors.
1154 for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
1155 BlockNode ExitNode(i);
1156 uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
1157 if (ExitFreq != 0)
1158 BranchDist.addExit(ExitNode, ExitFreq);
1159 else
1160 BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
1161 }
1162
1163 // Check for no total weight.
1164 if (BranchDist.Total == 0)
1165 return;
1166
1167 // Normalize the distribution so that they can fit in unsigned.
1168 BranchDist.normalize();
1169
1170 // Create normalized branch weights and set the metadata.
1171 for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
1172 const auto &Weight = BranchDist.Weights[I];
1173
1174 // Get the weight and update the current BFI.
1175 BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
1176 BranchProbability BP(Weight.Amount, BranchDist.Total);
1177 BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
1178 }
1179 TI->setMetadata(
1180 LLVMContext::MD_prof,
1181 MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
1182 }
1183
1184 /// Scan the extraction region for lifetime markers which reference inputs.
1185 /// Erase these markers. Return the inputs which were referenced.
1186 ///
1187 /// The extraction region is defined by a set of blocks (\p Blocks), and a set
1188 /// of allocas which will be moved from the caller function into the extracted
1189 /// function (\p SunkAllocas).
1190 static SetVector<Value *>
eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock * > & Blocks,const SetVector<Value * > & SunkAllocas)1191 eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
1192 const SetVector<Value *> &SunkAllocas) {
1193 SetVector<Value *> InputObjectsWithLifetime;
1194 for (BasicBlock *BB : Blocks) {
1195 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1196 auto *II = dyn_cast<IntrinsicInst>(&*It);
1197 ++It;
1198 if (!II || !II->isLifetimeStartOrEnd())
1199 continue;
1200
1201 // Get the memory operand of the lifetime marker. If the underlying
1202 // object is a sunk alloca, or is otherwise defined in the extraction
1203 // region, the lifetime marker must not be erased.
1204 Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
1205 if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
1206 continue;
1207
1208 InputObjectsWithLifetime.insert(Mem);
1209 II->eraseFromParent();
1210 }
1211 }
1212 return InputObjectsWithLifetime;
1213 }
1214
1215 /// Insert lifetime start/end markers surrounding the call to the new function
1216 /// for objects defined in the caller.
insertLifetimeMarkersSurroundingCall(Module * M,const SetVector<Value * > & InputObjectsWithLifetime,CallInst * TheCall)1217 static void insertLifetimeMarkersSurroundingCall(
1218 Module *M, const SetVector<Value *> &InputObjectsWithLifetime,
1219 CallInst *TheCall) {
1220 if (InputObjectsWithLifetime.empty())
1221 return;
1222
1223 LLVMContext &Ctx = M->getContext();
1224 auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
1225 auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
1226 auto LifetimeStartFn = llvm::Intrinsic::getDeclaration(
1227 M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
1228 auto LifetimeEndFn = llvm::Intrinsic::getDeclaration(
1229 M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
1230 for (Value *Mem : InputObjectsWithLifetime) {
1231 assert((!isa<Instruction>(Mem) ||
1232 cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) &&
1233 "Input memory not defined in original function");
1234 Value *MemAsI8Ptr = nullptr;
1235 if (Mem->getType() == Int8PtrTy)
1236 MemAsI8Ptr = Mem;
1237 else
1238 MemAsI8Ptr =
1239 CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
1240
1241 auto StartMarker =
1242 CallInst::Create(LifetimeStartFn, {NegativeOne, MemAsI8Ptr});
1243 StartMarker->insertBefore(TheCall);
1244 auto EndMarker = CallInst::Create(LifetimeEndFn, {NegativeOne, MemAsI8Ptr});
1245 EndMarker->insertAfter(TheCall);
1246 }
1247 }
1248
/// Extract the code region held in Blocks into a new function and replace it
/// in the original function with a call to that function.
///
/// Returns nullptr if isEligible() fails. Also returns nullptr if AllowVarArgs
/// is set, the enclosing function is variadic, and some block outside the
/// region calls llvm.va_start or llvm.va_end. Both checks run before this
/// function itself modifies any IR.
///
/// Otherwise, in order:
///  * if BFI is set, computes the region's entry frequency from the header's
///    predecessors outside the region (BPI must then be set too);
///  * splits return blocks, collects the exit blocks and (with BFI) the
///    per-exit edge frequencies, and severs the entry/exit PHI nodes;
///  * creates the caller-side "codeRepl" block and the new function's
///    "newFuncRoot" entry block, which branches to the header;
///  * sinks SinkingCands into newFuncRoot and hoists HoistingCands into the
///    block returned by findOrCreateBlockForHoisting(CommonExit);
///  * erases lifetime markers on region inputs (they are re-created around
///    the call later);
///  * builds the function, emits the call/dispatch, and moves the blocks;
///  * fixes up PHIs in the header and exit blocks;
///  * erases debug intrinsics in the new function, plus any debug users of its
///    instructions;
///  * marks the function noreturn when no block ends in `ret` or `resume`.
///
/// Returns the new function.
Function *CodeExtractor::extractCodeRegion() {
  if (!isEligible())
    return nullptr;

  // Assumption: this is a single-entry code region, and the header is the first
  // block in the region.
  BasicBlock *header = *Blocks.begin();
  Function *oldFunction = header->getParent();

  // For functions with varargs, check that varargs handling is only done in the
  // outlined function, i.e vastart and vaend are only used in outlined blocks.
  if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) {
    auto containsVarArgIntrinsic = [](Instruction &I) {
      if (const CallInst *CI = dyn_cast<CallInst>(&I))
        if (const Function *F = CI->getCalledFunction())
          return F->getIntrinsicID() == Intrinsic::vastart ||
                 F->getIntrinsicID() == Intrinsic::vaend;
      return false;
    };

    for (auto &BB : *oldFunction) {
      if (Blocks.count(&BB))
        continue;
      if (llvm::any_of(BB, containsVarArgIntrinsic))
        return nullptr;
    }
  }
  ValueSet inputs, outputs, SinkingCands, HoistingCands;
  BasicBlock *CommonExit = nullptr;

  // Calculate the entry frequency of the new function before we change the root
  // block. Only edges from predecessors outside the region contribute.
  BlockFrequency EntryFreq;
  if (BFI) {
    assert(BPI && "Both BPI and BFI are required to preserve profile info");
    for (BasicBlock *Pred : predecessors(header)) {
      if (Blocks.count(Pred))
        continue;
      EntryFreq +=
          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
    }
  }

  // If we have any return instructions in the region, split those blocks so
  // that the return is not in the region.
  splitReturnBlocks();

  // Calculate the exit blocks for the extracted region and the total exit
  // weights for each of those blocks. An exit block is any successor of a
  // region block that is not itself in the region.
  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
  SmallPtrSet<BasicBlock *, 1> ExitBlocks;
  for (BasicBlock *Block : Blocks) {
    for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
         ++SI) {
      if (!Blocks.count(*SI)) {
        // Update the branch weight for this successor.
        if (BFI) {
          BlockFrequency &BF = ExitWeights[*SI];
          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
        }
        ExitBlocks.insert(*SI);
      }
    }
  }
  NumExitBlocks = ExitBlocks.size();

  // If we have to split PHI nodes of the entry or exit blocks, do so now.
  severSplitPHINodesOfEntry(header);
  severSplitPHINodesOfExits(ExitBlocks);

  // This takes place of the original loop. It is inserted into the old
  // function just before the header.
  BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
                                                "codeRepl", oldFunction,
                                                header);

  // The new function needs a root node because other nodes can branch to the
  // head of the region, but the entry node of a function cannot have preds.
  BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
                                               "newFuncRoot");
  auto *BranchI = BranchInst::Create(header);
  // If the original function has debug info, we have to add a debug location
  // to the new branch instruction from the artificial entry block.
  // We use the debug location of the first instruction in the extracted
  // blocks that has one, as there is no other equivalent line in the source
  // code. any_of is used only for its short-circuiting (stop at the first
  // instruction with a location); its result is deliberately discarded.
  if (oldFunction->getSubprogram()) {
    any_of(Blocks, [&BranchI](const BasicBlock *BB) {
      return any_of(*BB, [&BranchI](const Instruction &I) {
        if (!I.getDebugLoc())
          return false;
        BranchI->setDebugLoc(I.getDebugLoc());
        return true;
      });
    });
  }
  newFuncRoot->getInstList().push_back(BranchI);

  findAllocas(SinkingCands, HoistingCands, CommonExit);
  assert(HoistingCands.empty() || CommonExit);

  // Find inputs to, outputs from the code region.
  findInputsOutputs(inputs, outputs, SinkingCands);

  // Now sink all instructions which only have non-phi uses inside the region
  for (auto *II : SinkingCands)
    cast<Instruction>(II)->moveBefore(*newFuncRoot,
                                      newFuncRoot->getFirstInsertionPt());

  // Hoisting candidates are moved before the terminator of the block returned
  // by findOrCreateBlockForHoisting(CommonExit).
  if (!HoistingCands.empty()) {
    auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
    Instruction *TI = HoistToBlock->getTerminator();
    for (auto *II : HoistingCands)
      cast<Instruction>(II)->moveBefore(TI);
  }

  // Collect objects which are inputs to the extraction region and also
  // referenced by lifetime start/end markers within it. The effects of these
  // markers must be replicated in the calling function to prevent the stack
  // coloring pass from merging slots which store input objects.
  ValueSet InputObjectsWithLifetime =
      eraseLifetimeMarkersOnInputs(Blocks, SinkingCands);

  // Construct new function based on inputs/outputs & add allocas for all defs.
  Function *newFunction =
      constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
                        oldFunction, oldFunction->getParent());

  // Update the entry count of the function, and give codeRepl the region's
  // former entry frequency.
  if (BFI) {
    auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
    if (Count.hasValue())
      newFunction->setEntryCount(
          ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME
    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
  }

  // Emit the call and exit dispatch into codeRepl. This must happen while the
  // region blocks are still in the old function (Blocks is consulted to find
  // out-of-region users and exits).
  CallInst *TheCall =
      emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);

  moveCodeToFunction(newFunction);

  // Replicate the effects of any lifetime start/end markers which referenced
  // input objects in the extraction region by placing markers around the call.
  insertLifetimeMarkersSurroundingCall(oldFunction->getParent(),
                                       InputObjectsWithLifetime, TheCall);

  // Propagate personality info to the new function if there is one.
  if (oldFunction->hasPersonalityFn())
    newFunction->setPersonalityFn(oldFunction->getPersonalityFn());

  // Update the branch weights for the exit block. With at most one exit there
  // is no conditional dispatch to weight.
  if (BFI && NumExitBlocks > 1)
    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);

  // Loop over all of the PHI nodes in the header and exit blocks, and change
  // any references to the old incoming edge to be the new incoming edge.
  // Header: incoming edges from outside the region now come from newFuncRoot.
  for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
    PHINode *PN = cast<PHINode>(I);
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
      if (!Blocks.count(PN->getIncomingBlock(i)))
        PN->setIncomingBlock(i, newFuncRoot);
  }

  // Exit blocks: the first incoming edge from inside the region is retargeted
  // to codeRepl. Any further in-region incoming values are asserted to be
  // identical, but their entries are left as they are.
  for (BasicBlock *ExitBB : ExitBlocks)
    for (PHINode &PN : ExitBB->phis()) {
      Value *IncomingCodeReplacerVal = nullptr;
      for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
        // Ignore incoming values from outside of the extracted region.
        if (!Blocks.count(PN.getIncomingBlock(i)))
          continue;

        // Ensure that there is only one incoming value from codeReplacer.
        if (!IncomingCodeReplacerVal) {
          PN.setIncomingBlock(i, codeReplacer);
          IncomingCodeReplacerVal = PN.getIncomingValue(i);
        } else
          assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
                 "PHI has two incompatbile incoming values from codeRepl");
      }
    }

  // Erase debug info intrinsics. Variable updates within the new function are
  // invisible to debuggers. This could be improved by defining a DISubprogram
  // for the new function.
  for (BasicBlock &BB : *newFunction) {
    auto BlockIt = BB.begin();
    // Remove debug info intrinsics from the new function. The iterator is
    // advanced before a possible erase so it stays valid.
    while (BlockIt != BB.end()) {
      Instruction *Inst = &*BlockIt;
      ++BlockIt;
      if (isa<DbgInfoIntrinsic>(Inst))
        Inst->eraseFromParent();
    }
    // Remove debug info intrinsics which refer to values in the new function
    // from the old function. (findDbgUsers is not limited to the old function,
    // so users in later blocks of the new function are erased here as well.)
    SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
    for (Instruction &I : BB)
      findDbgUsers(DbgUsers, &I);
    for (DbgVariableIntrinsic *DVI : DbgUsers)
      DVI->eraseFromParent();
  }

  // Mark the new function `noreturn` if applicable. Terminators which resume
  // exception propagation are treated as returning instructions. This is to
  // avoid inserting traps after calls to outlined functions which unwind.
  bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) {
    const Instruction *Term = BB.getTerminator();
    return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
  });
  if (doesNotReturn)
    newFunction->setDoesNotReturn();

  // In debug builds, verify both functions and abort on failure.
  LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
    newFunction->dump();
    report_fatal_error("verification of newFunction failed!");
  });
  LLVM_DEBUG(if (verifyFunction(*oldFunction))
             report_fatal_error("verification of oldFunction failed!"));
  return newFunction;
}
1468