1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file was developed by the LLVM research group and is distributed under 6 // the University of Illinois Open Source License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // A pass wrapper around the ExtractLoop() scalar transformation to extract each 11 // top-level loop into its own new function. If the loop is the ONLY loop in a 12 // given function, it is not touched. This is a pass most useful for debugging 13 // via bugpoint. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/Transforms/IPO.h" 18 #include "llvm/Instructions.h" 19 #include "llvm/Module.h" 20 #include "llvm/Pass.h" 21 #include "llvm/Analysis/Dominators.h" 22 #include "llvm/Analysis/LoopInfo.h" 23 #include "llvm/Transforms/Scalar.h" 24 #include "llvm/Transforms/Utils/FunctionUtils.h" 25 #include "llvm/ADT/Statistic.h" 26 using namespace llvm; 27 28 namespace { 29 Statistic<> NumExtracted("loop-extract", "Number of loops extracted"); 30 31 // FIXME: This is not a function pass, but the PassManager doesn't allow 32 // Module passes to require FunctionPasses, so we can't get loop info if we're 33 // not a function pass. 34 struct LoopExtractor : public FunctionPass { 35 unsigned NumLoops; 36 37 LoopExtractor(unsigned numLoops = ~0) : NumLoops(numLoops) {} 38 39 virtual bool runOnFunction(Function &F); 40 41 virtual void getAnalysisUsage(AnalysisUsage &AU) const { 42 AU.addRequiredID(BreakCriticalEdgesID); 43 AU.addRequiredID(LoopSimplifyID); 44 AU.addRequired<DominatorSet>(); 45 AU.addRequired<LoopInfo>(); 46 } 47 }; 48 49 RegisterOpt<LoopExtractor> 50 X("loop-extract", "Extract loops into new functions"); 51 52 /// SingleLoopExtractor - For bugpoint. 53 struct SingleLoopExtractor : public LoopExtractor { 54 SingleLoopExtractor() : LoopExtractor(1) {} 55 }; 56 57 RegisterOpt<SingleLoopExtractor> 58 Y("loop-extract-single", "Extract at most one loop into a new function"); 59 } // End anonymous namespace 60 61 // createLoopExtractorPass - This pass extracts all natural loops from the 62 // program into a function if it can. 63 // 64 FunctionPass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } 65 66 bool LoopExtractor::runOnFunction(Function &F) { 67 LoopInfo &LI = getAnalysis<LoopInfo>(); 68 69 // If this function has no loops, there is nothing to do. 70 if (LI.begin() == LI.end()) 71 return false; 72 73 DominatorSet &DS = getAnalysis<DominatorSet>(); 74 75 // If there is more than one top-level loop in this function, extract all of 76 // the loops. 77 bool Changed = false; 78 if (LI.end()-LI.begin() > 1) { 79 for (LoopInfo::iterator i = LI.begin(), e = LI.end(); i != e; ++i) { 80 if (NumLoops == 0) return Changed; 81 --NumLoops; 82 Changed |= ExtractLoop(DS, *i) != 0; 83 ++NumExtracted; 84 } 85 } else { 86 // Otherwise there is exactly one top-level loop. If this function is more 87 // than a minimal wrapper around the loop, extract the loop. 88 Loop *TLL = *LI.begin(); 89 bool ShouldExtractLoop = false; 90 91 // Extract the loop if the entry block doesn't branch to the loop header. 92 TerminatorInst *EntryTI = F.getEntryBlock().getTerminator(); 93 if (!isa<BranchInst>(EntryTI) || 94 !cast<BranchInst>(EntryTI)->isUnconditional() || 95 EntryTI->getSuccessor(0) != TLL->getHeader()) 96 ShouldExtractLoop = true; 97 else { 98 // Check to see if any exits from the loop are more than just return 99 // blocks. 100 std::vector<BasicBlock*> ExitBlocks; 101 TLL->getExitBlocks(ExitBlocks); 102 for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) 103 if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) { 104 ShouldExtractLoop = true; 105 break; 106 } 107 } 108 109 if (ShouldExtractLoop) { 110 if (NumLoops == 0) return Changed; 111 --NumLoops; 112 Changed |= ExtractLoop(DS, TLL) != 0; 113 ++NumExtracted; 114 } else { 115 // Okay, this function is a minimal container around the specified loop. 116 // If we extract the loop, we will continue to just keep extracting it 117 // infinitely... so don't extract it. However, if the loop contains any 118 // subloops, extract them. 119 for (Loop::iterator i = TLL->begin(), e = TLL->end(); i != e; ++i) { 120 if (NumLoops == 0) return Changed; 121 --NumLoops; 122 Changed |= ExtractLoop(DS, *i) != 0; 123 ++NumExtracted; 124 } 125 } 126 } 127 128 return Changed; 129 } 130 131 // createSingleLoopExtractorPass - This pass extracts one natural loop from the 132 // program into a function if it can. This is used by bugpoint. 133 // 134 FunctionPass *llvm::createSingleLoopExtractorPass() { 135 return new SingleLoopExtractor(); 136 } 137 138 139 namespace { 140 /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks 141 /// from the module into their own functions except for those specified by the 142 /// BlocksToNotExtract list. 143 class BlockExtractorPass : public ModulePass { 144 std::vector<BasicBlock*> BlocksToNotExtract; 145 public: 146 BlockExtractorPass(std::vector<BasicBlock*> &B) : BlocksToNotExtract(B) {} 147 BlockExtractorPass() {} 148 149 bool runOnModule(Module &M); 150 }; 151 RegisterOpt<BlockExtractorPass> 152 XX("extract-blocks", "Extract Basic Blocks From Module (for bugpoint use)"); 153 } 154 155 // createBlockExtractorPass - This pass extracts all blocks (except those 156 // specified in the argument list) from the functions in the module. 157 // 158 ModulePass *llvm::createBlockExtractorPass(std::vector<BasicBlock*> &BTNE) { 159 return new BlockExtractorPass(BTNE); 160 } 161 162 bool BlockExtractorPass::runOnModule(Module &M) { 163 std::set<BasicBlock*> TranslatedBlocksToNotExtract; 164 for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { 165 BasicBlock *BB = BlocksToNotExtract[i]; 166 Function *F = BB->getParent(); 167 168 // Map the corresponding function in this module. 169 Function *MF = M.getFunction(F->getName(), F->getFunctionType()); 170 171 // Figure out which index the basic block is in its function. 172 Function::iterator BBI = MF->begin(); 173 std::advance(BBI, std::distance(F->begin(), Function::iterator(BB))); 174 TranslatedBlocksToNotExtract.insert(BBI); 175 } 176 177 // Now that we know which blocks to not extract, figure out which ones we WANT 178 // to extract. 179 std::vector<BasicBlock*> BlocksToExtract; 180 for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) 181 for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) 182 if (!TranslatedBlocksToNotExtract.count(BB)) 183 BlocksToExtract.push_back(BB); 184 185 for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i) 186 ExtractBasicBlock(BlocksToExtract[i]); 187 188 return !BlocksToExtract.empty(); 189 } 190