1 //===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass is used to ensure that functions have at most one return and one 10 // unreachable instruction in them. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 15 #include "llvm/IR/BasicBlock.h" 16 #include "llvm/IR/Function.h" 17 #include "llvm/IR/Instructions.h" 18 #include "llvm/IR/Type.h" 19 #include "llvm/InitializePasses.h" 20 #include "llvm/Transforms/Utils.h" 21 using namespace llvm; 22 23 char UnifyFunctionExitNodes::ID = 0; 24 25 UnifyFunctionExitNodes::UnifyFunctionExitNodes() : FunctionPass(ID) { 26 initializeUnifyFunctionExitNodesPass(*PassRegistry::getPassRegistry()); 27 } 28 29 INITIALIZE_PASS(UnifyFunctionExitNodes, "mergereturn", 30 "Unify function exit nodes", false, false) 31 32 Pass *llvm::createUnifyFunctionExitNodesPass() { 33 return new UnifyFunctionExitNodes(); 34 } 35 36 void UnifyFunctionExitNodes::getAnalysisUsage(AnalysisUsage &AU) const{ 37 // We preserve the non-critical-edgeness property 38 AU.addPreservedID(BreakCriticalEdgesID); 39 // This is a cluster of orthogonal Transforms 40 AU.addPreservedID(LowerSwitchID); 41 } 42 43 bool UnifyFunctionExitNodes::unifyUnreachableBlocks(Function &F) { 44 std::vector<BasicBlock*> UnreachableBlocks; 45 46 for (BasicBlock &I : F) 47 if (isa<UnreachableInst>(I.getTerminator())) 48 UnreachableBlocks.push_back(&I); 49 50 if (UnreachableBlocks.size() <= 1) 51 return false; 52 53 BasicBlock *UnreachableBlock = 54 BasicBlock::Create(F.getContext(), "UnifiedUnreachableBlock", &F); 55 new UnreachableInst(F.getContext(), UnreachableBlock); 56 57 for (BasicBlock *BB : UnreachableBlocks) { 58 BB->getInstList().pop_back(); // Remove the unreachable inst. 59 BranchInst::Create(UnreachableBlock, BB); 60 } 61 62 return true; 63 } 64 65 bool UnifyFunctionExitNodes::unifyReturnBlocks(Function &F) { 66 std::vector<BasicBlock *> ReturningBlocks; 67 68 for (BasicBlock &I : F) 69 if (isa<ReturnInst>(I.getTerminator())) 70 ReturningBlocks.push_back(&I); 71 72 if (ReturningBlocks.size() <= 1) 73 return false; 74 75 // Insert a new basic block into the function, add PHI nodes (if the function 76 // returns values), and convert all of the return instructions into 77 // unconditional branches. 78 BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(), 79 "UnifiedReturnBlock", &F); 80 81 PHINode *PN = nullptr; 82 if (F.getReturnType()->isVoidTy()) { 83 ReturnInst::Create(F.getContext(), nullptr, NewRetBlock); 84 } else { 85 // If the function doesn't return void... add a PHI node to the block... 86 PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(), 87 "UnifiedRetVal"); 88 NewRetBlock->getInstList().push_back(PN); 89 ReturnInst::Create(F.getContext(), PN, NewRetBlock); 90 } 91 92 // Loop over all of the blocks, replacing the return instruction with an 93 // unconditional branch. 94 for (BasicBlock *BB : ReturningBlocks) { 95 // Add an incoming element to the PHI node for every return instruction that 96 // is merging into this new block... 97 if (PN) 98 PN->addIncoming(BB->getTerminator()->getOperand(0), BB); 99 100 BB->getInstList().pop_back(); // Remove the return insn 101 BranchInst::Create(NewRetBlock, BB); 102 } 103 104 return true; 105 } 106 107 // Unify all exit nodes of the CFG by creating a new BasicBlock, and converting 108 // all returns to unconditional branches to this new basic block. Also, unify 109 // all unreachable blocks. 110 bool UnifyFunctionExitNodes::runOnFunction(Function &F) { 111 bool Changed = false; 112 Changed |= unifyUnreachableBlocks(F); 113 Changed |= unifyReturnBlocks(F); 114 return Changed; 115 } 116