//===- CodeExtractor.cpp - Pull code region into a new function -----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements the interface to tear out a code region, such as an
// individual loop or a parallel section, into a new function, replacing it with
// a call to the new function.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Instructions.h"
#include "llvm/Intrinsics.h"
#include "llvm/LLVMContext.h"
#include "llvm/Module.h"
#include "llvm/Pass.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/Verifier.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include <algorithm>
#include <set>
using namespace llvm;

// Provide a command-line option to aggregate function arguments into a struct
// for functions produced by the code extractor. This is useful when converting
// extracted functions to pthread-based code, as only one argument (void*) can
// be passed in to pthread_create().
//
// The flag is hidden from the default help output (cl::Hidden). Every
// CodeExtractor constructor in this file ORs it with its own AggregateArgs
// parameter, so setting it forces aggregation for all extractions.
static cl::opt<bool>
AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
                 cl::desc("Aggregate arguments to code-extracted functions"));

/// \brief Test whether a block is valid for extraction.
///
/// A block is rejected when it is a landing pad, or when it contains an
/// alloca, an invoke, or a direct call to the vastart intrinsic.
47 static bool isBlockValidForExtraction(const BasicBlock &BB) { 48 // Landing pads must be in the function where they were inserted for cleanup. 49 if (BB.isLandingPad()) 50 return false; 51 52 // Don't hoist code containing allocas, invokes, or vastarts. 53 for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { 54 if (isa<AllocaInst>(I) || isa<InvokeInst>(I)) 55 return false; 56 if (const CallInst *CI = dyn_cast<CallInst>(I)) 57 if (const Function *F = CI->getCalledFunction()) 58 if (F->getIntrinsicID() == Intrinsic::vastart) 59 return false; 60 } 61 62 return true; 63 } 64 65 /// \brief Build a set of blocks to extract if the input blocks are viable. 66 static SetVector<BasicBlock *> 67 buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) { 68 SetVector<BasicBlock *> Result; 69 70 // Loop over the blocks, adding them to our set-vector, and aborting with an 71 // empty set if we encounter invalid blocks. 72 for (ArrayRef<BasicBlock *>::iterator I = BBs.begin(), E = BBs.end(); 73 I != E; ++I) { 74 if (!Result.insert(*I)) 75 continue; 76 77 if (!isBlockValidForExtraction(**I)) { 78 Result.clear(); 79 break; 80 } 81 } 82 83 return Result; 84 } 85 86 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs) 87 : DT(0), AggregateArgs(AggregateArgs||AggregateArgsOpt), 88 Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} 89 90 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, 91 bool AggregateArgs) 92 : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), 93 Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} 94 95 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs) 96 : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), 97 Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} 98 99 100 /// definedInRegion - Return true if the specified value is defined in the 101 /// extracted region. 
102 static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) { 103 if (Instruction *I = dyn_cast<Instruction>(V)) 104 if (Blocks.count(I->getParent())) 105 return true; 106 return false; 107 } 108 109 /// definedInCaller - Return true if the specified value is defined in the 110 /// function being code extracted, but not in the region being extracted. 111 /// These values must be passed in as live-ins to the function. 112 static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) { 113 if (isa<Argument>(V)) return true; 114 if (Instruction *I = dyn_cast<Instruction>(V)) 115 if (!Blocks.count(I->getParent())) 116 return true; 117 return false; 118 } 119 120 /// severSplitPHINodes - If a PHI node has multiple inputs from outside of the 121 /// region, we need to split the entry block of the region so that the PHI node 122 /// is easier to deal with. 123 void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { 124 unsigned NumPredsFromRegion = 0; 125 unsigned NumPredsOutsideRegion = 0; 126 127 if (Header != &Header->getParent()->getEntryBlock()) { 128 PHINode *PN = dyn_cast<PHINode>(Header->begin()); 129 if (!PN) return; // No PHI nodes. 130 131 // If the header node contains any PHI nodes, check to see if there is more 132 // than one entry from outside the region. If so, we need to sever the 133 // header block into two. 134 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 135 if (Blocks.count(PN->getIncomingBlock(i))) 136 ++NumPredsFromRegion; 137 else 138 ++NumPredsOutsideRegion; 139 140 // If there is one (or fewer) predecessor from outside the region, we don't 141 // need to do anything special. 
142 if (NumPredsOutsideRegion <= 1) return; 143 } 144 145 // Otherwise, we need to split the header block into two pieces: one 146 // containing PHI nodes merging values from outside of the region, and a 147 // second that contains all of the code for the block and merges back any 148 // incoming values from inside of the region. 149 BasicBlock::iterator AfterPHIs = Header->getFirstNonPHI(); 150 BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs, 151 Header->getName()+".ce"); 152 153 // We only want to code extract the second block now, and it becomes the new 154 // header of the region. 155 BasicBlock *OldPred = Header; 156 Blocks.remove(OldPred); 157 Blocks.insert(NewBB); 158 Header = NewBB; 159 160 // Okay, update dominator sets. The blocks that dominate the new one are the 161 // blocks that dominate TIBB plus the new block itself. 162 if (DT) 163 DT->splitBlock(NewBB); 164 165 // Okay, now we need to adjust the PHI nodes and any branches from within the 166 // region to go to the new header block instead of the old header block. 167 if (NumPredsFromRegion) { 168 PHINode *PN = cast<PHINode>(OldPred->begin()); 169 // Loop over all of the predecessors of OldPred that are in the region, 170 // changing them to branch to NewBB instead. 171 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 172 if (Blocks.count(PN->getIncomingBlock(i))) { 173 TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator(); 174 TI->replaceUsesOfWith(OldPred, NewBB); 175 } 176 177 // Okay, everything within the region is now branching to the right block, we 178 // just have to update the PHI nodes now, inserting PHI nodes into NewBB. 179 for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) { 180 PHINode *PN = cast<PHINode>(AfterPHIs); 181 // Create a new PHI node in the new region, which has an incoming value 182 // from OldPred of PN. 
183 PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion, 184 PN->getName()+".ce", NewBB->begin()); 185 NewPN->addIncoming(PN, OldPred); 186 187 // Loop over all of the incoming value in PN, moving them to NewPN if they 188 // are from the extracted region. 189 for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { 190 if (Blocks.count(PN->getIncomingBlock(i))) { 191 NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i)); 192 PN->removeIncomingValue(i); 193 --i; 194 } 195 } 196 } 197 } 198 } 199 200 void CodeExtractor::splitReturnBlocks() { 201 for (SetVector<BasicBlock *>::iterator I = Blocks.begin(), E = Blocks.end(); 202 I != E; ++I) 203 if (ReturnInst *RI = dyn_cast<ReturnInst>((*I)->getTerminator())) { 204 BasicBlock *New = (*I)->splitBasicBlock(RI, (*I)->getName()+".ret"); 205 if (DT) { 206 // Old dominates New. New node dominates all other nodes dominated 207 // by Old. 208 DomTreeNode *OldNode = DT->getNode(*I); 209 SmallVector<DomTreeNode*, 8> Children; 210 for (DomTreeNode::iterator DI = OldNode->begin(), DE = OldNode->end(); 211 DI != DE; ++DI) 212 Children.push_back(*DI); 213 214 DomTreeNode *NewNode = DT->addNewBlock(New, *I); 215 216 for (SmallVector<DomTreeNode*, 8>::iterator I = Children.begin(), 217 E = Children.end(); I != E; ++I) 218 DT->changeImmediateDominator(*I, NewNode); 219 } 220 } 221 } 222 223 // findInputsOutputs - Find inputs to, outputs from the code region. 224 // 225 void CodeExtractor::findInputsOutputs(ValueSet &inputs, ValueSet &outputs) { 226 std::set<BasicBlock*> ExitBlocks; 227 for (SetVector<BasicBlock*>::const_iterator ci = Blocks.begin(), 228 ce = Blocks.end(); ci != ce; ++ci) { 229 BasicBlock *BB = *ci; 230 231 for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { 232 // If a used value is defined outside the region, it's an input. If an 233 // instruction is used outside the region, it's an output. 
234 for (User::op_iterator O = I->op_begin(), E = I->op_end(); O != E; ++O) 235 if (definedInCaller(Blocks, *O)) 236 inputs.insert(*O); 237 238 // Consider uses of this instruction (outputs). 239 for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); 240 UI != E; ++UI) 241 if (!definedInRegion(Blocks, *UI)) { 242 outputs.insert(I); 243 break; 244 } 245 } // for: insts 246 247 // Keep track of the exit blocks from the region. 248 TerminatorInst *TI = BB->getTerminator(); 249 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) 250 if (!Blocks.count(TI->getSuccessor(i))) 251 ExitBlocks.insert(TI->getSuccessor(i)); 252 } // for: basic blocks 253 254 NumExitBlocks = ExitBlocks.size(); 255 } 256 257 /// constructFunction - make a function based on inputs and outputs, as follows: 258 /// f(in0, ..., inN, out0, ..., outN) 259 /// 260 Function *CodeExtractor::constructFunction(const ValueSet &inputs, 261 const ValueSet &outputs, 262 BasicBlock *header, 263 BasicBlock *newRootNode, 264 BasicBlock *newHeader, 265 Function *oldFunction, 266 Module *M) { 267 DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); 268 DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); 269 270 // This function returns unsigned, outputs will go back by reference. 271 switch (NumExitBlocks) { 272 case 0: 273 case 1: RetTy = Type::getVoidTy(header->getContext()); break; 274 case 2: RetTy = Type::getInt1Ty(header->getContext()); break; 275 default: RetTy = Type::getInt16Ty(header->getContext()); break; 276 } 277 278 std::vector<Type*> paramTy; 279 280 // Add the types of the input values to the function's argument list 281 for (ValueSet::const_iterator i = inputs.begin(), e = inputs.end(); 282 i != e; ++i) { 283 const Value *value = *i; 284 DEBUG(dbgs() << "value used in func: " << *value << "\n"); 285 paramTy.push_back(value->getType()); 286 } 287 288 // Add the types of the output values to the function's argument list. 
289 for (ValueSet::const_iterator I = outputs.begin(), E = outputs.end(); 290 I != E; ++I) { 291 DEBUG(dbgs() << "instr used in func: " << **I << "\n"); 292 if (AggregateArgs) 293 paramTy.push_back((*I)->getType()); 294 else 295 paramTy.push_back(PointerType::getUnqual((*I)->getType())); 296 } 297 298 DEBUG(dbgs() << "Function type: " << *RetTy << " f("); 299 for (std::vector<Type*>::iterator i = paramTy.begin(), 300 e = paramTy.end(); i != e; ++i) 301 DEBUG(dbgs() << **i << ", "); 302 DEBUG(dbgs() << ")\n"); 303 304 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 305 PointerType *StructPtr = 306 PointerType::getUnqual(StructType::get(M->getContext(), paramTy)); 307 paramTy.clear(); 308 paramTy.push_back(StructPtr); 309 } 310 FunctionType *funcType = 311 FunctionType::get(RetTy, paramTy, false); 312 313 // Create the new function 314 Function *newFunction = Function::Create(funcType, 315 GlobalValue::InternalLinkage, 316 oldFunction->getName() + "_" + 317 header->getName(), M); 318 // If the old function is no-throw, so is the new one. 319 if (oldFunction->doesNotThrow()) 320 newFunction->setDoesNotThrow(true); 321 322 newFunction->getBasicBlockList().push_back(newRootNode); 323 324 // Create an iterator to name all of the arguments we inserted. 325 Function::arg_iterator AI = newFunction->arg_begin(); 326 327 // Rewrite all users of the inputs in the extracted region to use the 328 // arguments (or appropriate addressing into struct) instead. 
329 for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 330 Value *RewriteVal; 331 if (AggregateArgs) { 332 Value *Idx[2]; 333 Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); 334 Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i); 335 TerminatorInst *TI = newFunction->begin()->getTerminator(); 336 GetElementPtrInst *GEP = 337 GetElementPtrInst::Create(AI, Idx, "gep_" + inputs[i]->getName(), TI); 338 RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI); 339 } else 340 RewriteVal = AI++; 341 342 std::vector<User*> Users(inputs[i]->use_begin(), inputs[i]->use_end()); 343 for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end(); 344 use != useE; ++use) 345 if (Instruction* inst = dyn_cast<Instruction>(*use)) 346 if (Blocks.count(inst->getParent())) 347 inst->replaceUsesOfWith(inputs[i], RewriteVal); 348 } 349 350 // Set names for input and output arguments. 351 if (!AggregateArgs) { 352 AI = newFunction->arg_begin(); 353 for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) 354 AI->setName(inputs[i]->getName()); 355 for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) 356 AI->setName(outputs[i]->getName()+".out"); 357 } 358 359 // Rewrite branches to basic blocks outside of the loop to new dummy blocks 360 // within the new function. This must be done before we lose track of which 361 // blocks were originally in the code region. 
362 std::vector<User*> Users(header->use_begin(), header->use_end()); 363 for (unsigned i = 0, e = Users.size(); i != e; ++i) 364 // The BasicBlock which contains the branch is not in the region 365 // modify the branch target to a new block 366 if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i])) 367 if (!Blocks.count(TI->getParent()) && 368 TI->getParent()->getParent() == oldFunction) 369 TI->replaceUsesOfWith(header, newHeader); 370 371 return newFunction; 372 } 373 374 /// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI 375 /// that uses the value within the basic block, and return the predecessor 376 /// block associated with that use, or return 0 if none is found. 377 static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { 378 for (Value::use_iterator UI = Used->use_begin(), 379 UE = Used->use_end(); UI != UE; ++UI) { 380 PHINode *P = dyn_cast<PHINode>(*UI); 381 if (P && P->getParent() == BB) 382 return P->getIncomingBlock(UI); 383 } 384 385 return 0; 386 } 387 388 /// emitCallAndSwitchStatement - This method sets up the caller side by adding 389 /// the call instruction, splitting any PHI nodes in the header block as 390 /// necessary. 
391 void CodeExtractor:: 392 emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, 393 ValueSet &inputs, ValueSet &outputs) { 394 // Emit a call to the new function, passing in: *pointer to struct (if 395 // aggregating parameters), or plan inputs and allocated memory for outputs 396 std::vector<Value*> params, StructValues, ReloadOutputs, Reloads; 397 398 LLVMContext &Context = newFunction->getContext(); 399 400 // Add inputs as params, or to be filled into the struct 401 for (ValueSet::iterator i = inputs.begin(), e = inputs.end(); i != e; ++i) 402 if (AggregateArgs) 403 StructValues.push_back(*i); 404 else 405 params.push_back(*i); 406 407 // Create allocas for the outputs 408 for (ValueSet::iterator i = outputs.begin(), e = outputs.end(); i != e; ++i) { 409 if (AggregateArgs) { 410 StructValues.push_back(*i); 411 } else { 412 AllocaInst *alloca = 413 new AllocaInst((*i)->getType(), 0, (*i)->getName()+".loc", 414 codeReplacer->getParent()->begin()->begin()); 415 ReloadOutputs.push_back(alloca); 416 params.push_back(alloca); 417 } 418 } 419 420 AllocaInst *Struct = 0; 421 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 422 std::vector<Type*> ArgTypes; 423 for (ValueSet::iterator v = StructValues.begin(), 424 ve = StructValues.end(); v != ve; ++v) 425 ArgTypes.push_back((*v)->getType()); 426 427 // Allocate a struct at the beginning of this function 428 Type *StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); 429 Struct = 430 new AllocaInst(StructArgTy, 0, "structArg", 431 codeReplacer->getParent()->begin()->begin()); 432 params.push_back(Struct); 433 434 for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 435 Value *Idx[2]; 436 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 437 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); 438 GetElementPtrInst *GEP = 439 GetElementPtrInst::Create(Struct, Idx, 440 "gep_" + StructValues[i]->getName()); 441 codeReplacer->getInstList().push_back(GEP); 442 
StoreInst *SI = new StoreInst(StructValues[i], GEP); 443 codeReplacer->getInstList().push_back(SI); 444 } 445 } 446 447 // Emit the call to the function 448 CallInst *call = CallInst::Create(newFunction, params, 449 NumExitBlocks > 1 ? "targetBlock" : ""); 450 codeReplacer->getInstList().push_back(call); 451 452 Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); 453 unsigned FirstOut = inputs.size(); 454 if (!AggregateArgs) 455 std::advance(OutputArgBegin, inputs.size()); 456 457 // Reload the outputs passed in by reference 458 for (unsigned i = 0, e = outputs.size(); i != e; ++i) { 459 Value *Output = 0; 460 if (AggregateArgs) { 461 Value *Idx[2]; 462 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 463 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); 464 GetElementPtrInst *GEP 465 = GetElementPtrInst::Create(Struct, Idx, 466 "gep_reload_" + outputs[i]->getName()); 467 codeReplacer->getInstList().push_back(GEP); 468 Output = GEP; 469 } else { 470 Output = ReloadOutputs[i]; 471 } 472 LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); 473 Reloads.push_back(load); 474 codeReplacer->getInstList().push_back(load); 475 std::vector<User*> Users(outputs[i]->use_begin(), outputs[i]->use_end()); 476 for (unsigned u = 0, e = Users.size(); u != e; ++u) { 477 Instruction *inst = cast<Instruction>(Users[u]); 478 if (!Blocks.count(inst->getParent())) 479 inst->replaceUsesOfWith(outputs[i], load); 480 } 481 } 482 483 // Now we can emit a switch statement using the call as a value. 484 SwitchInst *TheSwitch = 485 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), 486 codeReplacer, 0, codeReplacer); 487 488 // Since there may be multiple exits from the original region, make the new 489 // function return an unsigned, switch on that number. 
This loop iterates 490 // over all of the blocks in the extracted region, updating any terminator 491 // instructions in the to-be-extracted region that branch to blocks that are 492 // not in the region to be extracted. 493 std::map<BasicBlock*, BasicBlock*> ExitBlockMap; 494 495 unsigned switchVal = 0; 496 for (SetVector<BasicBlock*>::const_iterator i = Blocks.begin(), 497 e = Blocks.end(); i != e; ++i) { 498 TerminatorInst *TI = (*i)->getTerminator(); 499 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) 500 if (!Blocks.count(TI->getSuccessor(i))) { 501 BasicBlock *OldTarget = TI->getSuccessor(i); 502 // add a new basic block which returns the appropriate value 503 BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; 504 if (!NewTarget) { 505 // If we don't already have an exit stub for this non-extracted 506 // destination, create one now! 507 NewTarget = BasicBlock::Create(Context, 508 OldTarget->getName() + ".exitStub", 509 newFunction); 510 unsigned SuccNum = switchVal++; 511 512 Value *brVal = 0; 513 switch (NumExitBlocks) { 514 case 0: 515 case 1: break; // No value needed. 516 case 2: // Conditional branch, return a bool 517 brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); 518 break; 519 default: 520 brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); 521 break; 522 } 523 524 ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget); 525 526 // Update the switch instruction. 
527 TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), 528 SuccNum), 529 OldTarget); 530 531 // Restore values just before we exit 532 Function::arg_iterator OAI = OutputArgBegin; 533 for (unsigned out = 0, e = outputs.size(); out != e; ++out) { 534 // For an invoke, the normal destination is the only one that is 535 // dominated by the result of the invocation 536 BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent(); 537 538 bool DominatesDef = true; 539 540 if (InvokeInst *Invoke = dyn_cast<InvokeInst>(outputs[out])) { 541 DefBlock = Invoke->getNormalDest(); 542 543 // Make sure we are looking at the original successor block, not 544 // at a newly inserted exit block, which won't be in the dominator 545 // info. 546 for (std::map<BasicBlock*, BasicBlock*>::iterator I = 547 ExitBlockMap.begin(), E = ExitBlockMap.end(); I != E; ++I) 548 if (DefBlock == I->second) { 549 DefBlock = I->first; 550 break; 551 } 552 553 // In the extract block case, if the block we are extracting ends 554 // with an invoke instruction, make sure that we don't emit a 555 // store of the invoke value for the unwind block. 556 if (!DT && DefBlock != OldTarget) 557 DominatesDef = false; 558 } 559 560 if (DT) { 561 DominatesDef = DT->dominates(DefBlock, OldTarget); 562 563 // If the output value is used by a phi in the target block, 564 // then we need to test for dominance of the phi's predecessor 565 // instead. Unfortunately, this a little complicated since we 566 // have already rewritten uses of the value to uses of the reload. 
567 BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], 568 OldTarget); 569 if (pred && DT && DT->dominates(DefBlock, pred)) 570 DominatesDef = true; 571 } 572 573 if (DominatesDef) { 574 if (AggregateArgs) { 575 Value *Idx[2]; 576 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 577 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), 578 FirstOut+out); 579 GetElementPtrInst *GEP = 580 GetElementPtrInst::Create(OAI, Idx, 581 "gep_" + outputs[out]->getName(), 582 NTRet); 583 new StoreInst(outputs[out], GEP, NTRet); 584 } else { 585 new StoreInst(outputs[out], OAI, NTRet); 586 } 587 } 588 // Advance output iterator even if we don't emit a store 589 if (!AggregateArgs) ++OAI; 590 } 591 } 592 593 // rewrite the original branch instruction with this new target 594 TI->setSuccessor(i, NewTarget); 595 } 596 } 597 598 // Now that we've done the deed, simplify the switch instruction. 599 Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); 600 switch (NumExitBlocks) { 601 case 0: 602 // There are no successors (the block containing the switch itself), which 603 // means that previously this was the last part of the function, and hence 604 // this should be rewritten as a `ret' 605 606 // Check if the function should return a value 607 if (OldFnRetTy->isVoidTy()) { 608 ReturnInst::Create(Context, 0, TheSwitch); // Return void 609 } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { 610 // return what we have 611 ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); 612 } else { 613 // Otherwise we must have code extracted an unwind or something, just 614 // return whatever we want. 615 ReturnInst::Create(Context, 616 Constant::getNullValue(OldFnRetTy), TheSwitch); 617 } 618 619 TheSwitch->eraseFromParent(); 620 break; 621 case 1: 622 // Only a single destination, change the switch into an unconditional 623 // branch. 
624 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); 625 TheSwitch->eraseFromParent(); 626 break; 627 case 2: 628 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), 629 call, TheSwitch); 630 TheSwitch->eraseFromParent(); 631 break; 632 default: 633 // Otherwise, make the default destination of the switch instruction be one 634 // of the other successors. 635 TheSwitch->setCondition(call); 636 TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); 637 // Remove redundant case 638 TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); 639 break; 640 } 641 } 642 643 void CodeExtractor::moveCodeToFunction(Function *newFunction) { 644 Function *oldFunc = (*Blocks.begin())->getParent(); 645 Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); 646 Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); 647 648 for (SetVector<BasicBlock*>::const_iterator i = Blocks.begin(), 649 e = Blocks.end(); i != e; ++i) { 650 // Delete the basic block from the old function, and the list of blocks 651 oldBlocks.remove(*i); 652 653 // Insert this basic block into the new function 654 newBlocks.push_back(*i); 655 } 656 } 657 658 Function *CodeExtractor::extractCodeRegion() { 659 if (!isEligible()) 660 return 0; 661 662 ValueSet inputs, outputs; 663 664 // Assumption: this is a single-entry code region, and the header is the first 665 // block in the region. 666 BasicBlock *header = *Blocks.begin(); 667 668 for (SetVector<BasicBlock *>::iterator BI = llvm::next(Blocks.begin()), 669 BE = Blocks.end(); 670 BI != BE; ++BI) 671 for (pred_iterator PI = pred_begin(*BI), E = pred_end(*BI); 672 PI != E; ++PI) 673 assert(Blocks.count(*PI) && 674 "No blocks in this region may have entries from outside the region" 675 " except for the first block!"); 676 677 // If we have to split PHI nodes or the entry block, do so now. 
678 severSplitPHINodes(header); 679 680 // If we have any return instructions in the region, split those blocks so 681 // that the return is not in the region. 682 splitReturnBlocks(); 683 684 Function *oldFunction = header->getParent(); 685 686 // This takes place of the original loop 687 BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), 688 "codeRepl", oldFunction, 689 header); 690 691 // The new function needs a root node because other nodes can branch to the 692 // head of the region, but the entry node of a function cannot have preds. 693 BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), 694 "newFuncRoot"); 695 newFuncRoot->getInstList().push_back(BranchInst::Create(header)); 696 697 // Find inputs to, outputs from the code region. 698 findInputsOutputs(inputs, outputs); 699 700 // Construct new function based on inputs/outputs & add allocas for all defs. 701 Function *newFunction = constructFunction(inputs, outputs, header, 702 newFuncRoot, 703 codeReplacer, oldFunction, 704 oldFunction->getParent()); 705 706 emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); 707 708 moveCodeToFunction(newFunction); 709 710 // Loop over all of the PHI nodes in the header block, and change any 711 // references to the old incoming edge to be the new incoming edge. 712 for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) { 713 PHINode *PN = cast<PHINode>(I); 714 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 715 if (!Blocks.count(PN->getIncomingBlock(i))) 716 PN->setIncomingBlock(i, newFuncRoot); 717 } 718 719 // Look at all successors of the codeReplacer block. If any of these blocks 720 // had PHI nodes in them, we need to update the "from" block to be the code 721 // replacer, not the original block in the extracted region. 
722 std::vector<BasicBlock*> Succs(succ_begin(codeReplacer), 723 succ_end(codeReplacer)); 724 for (unsigned i = 0, e = Succs.size(); i != e; ++i) 725 for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) { 726 PHINode *PN = cast<PHINode>(I); 727 std::set<BasicBlock*> ProcessedPreds; 728 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 729 if (Blocks.count(PN->getIncomingBlock(i))) { 730 if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) 731 PN->setIncomingBlock(i, codeReplacer); 732 else { 733 // There were multiple entries in the PHI for this block, now there 734 // is only one, so remove the duplicated entries. 735 PN->removeIncomingValue(i, false); 736 --i; --e; 737 } 738 } 739 } 740 741 //cerr << "NEW FUNCTION: " << *newFunction; 742 // verifyFunction(*newFunction); 743 744 // cerr << "OLD FUNCTION: " << *oldFunction; 745 // verifyFunction(*oldFunction); 746 747 DEBUG(if (verifyFunction(*newFunction)) 748 report_fatal_error("verifyFunction failed!")); 749 return newFunction; 750 } 751