1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // A pass wrapper around the ExtractLoop() scalar transformation to extract each 11 // top-level loop into its own new function. If the loop is the ONLY loop in a 12 // given function, it is not touched. This is a pass most useful for debugging 13 // via bugpoint. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #define DEBUG_TYPE "loop-extract" 18 #include "llvm/Transforms/IPO.h" 19 #include "llvm/Instructions.h" 20 #include "llvm/Module.h" 21 #include "llvm/Pass.h" 22 #include "llvm/Analysis/Dominators.h" 23 #include "llvm/Analysis/LoopInfo.h" 24 #include "llvm/Support/CommandLine.h" 25 #include "llvm/Support/Compiler.h" 26 #include "llvm/Transforms/Scalar.h" 27 #include "llvm/Transforms/Utils/FunctionUtils.h" 28 #include "llvm/ADT/Statistic.h" 29 #include <fstream> 30 #include <set> 31 using namespace llvm; 32 33 STATISTIC(NumExtracted, "Number of loops extracted"); 34 35 namespace { 36 // FIXME: This is not a function pass, but the PassManager doesn't allow 37 // Module passes to require FunctionPasses, so we can't get loop info if we're 38 // not a function pass. 39 struct VISIBILITY_HIDDEN LoopExtractor : public FunctionPass { 40 static char ID; // Pass identification, replacement for typeid 41 unsigned NumLoops; 42 43 explicit LoopExtractor(unsigned numLoops = ~0) 44 : FunctionPass(&ID), NumLoops(numLoops) {} 45 46 virtual bool runOnFunction(Function &F); 47 48 virtual void getAnalysisUsage(AnalysisUsage &AU) const { 49 AU.addRequiredID(BreakCriticalEdgesID); 50 AU.addRequiredID(LoopSimplifyID); 51 AU.addRequired<DominatorTree>(); 52 AU.addRequired<LoopInfo>(); 53 } 54 }; 55 } 56 57 char LoopExtractor::ID = 0; 58 static RegisterPass<LoopExtractor> 59 X("loop-extract", "Extract loops into new functions"); 60 61 namespace { 62 /// SingleLoopExtractor - For bugpoint. 63 struct SingleLoopExtractor : public LoopExtractor { 64 static char ID; // Pass identification, replacement for typeid 65 SingleLoopExtractor() : LoopExtractor(1) {} 66 }; 67 } // End anonymous namespace 68 69 char SingleLoopExtractor::ID = 0; 70 static RegisterPass<SingleLoopExtractor> 71 Y("loop-extract-single", "Extract at most one loop into a new function"); 72 73 // createLoopExtractorPass - This pass extracts all natural loops from the 74 // program into a function if it can. 75 // 76 FunctionPass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } 77 78 bool LoopExtractor::runOnFunction(Function &F) { 79 LoopInfo &LI = getAnalysis<LoopInfo>(); 80 81 // If this function has no loops, there is nothing to do. 82 if (LI.empty()) 83 return false; 84 85 DominatorTree &DT = getAnalysis<DominatorTree>(); 86 87 // If there is more than one top-level loop in this function, extract all of 88 // the loops. 89 bool Changed = false; 90 if (LI.end()-LI.begin() > 1) { 91 for (LoopInfo::iterator i = LI.begin(), e = LI.end(); i != e; ++i) { 92 if (NumLoops == 0) return Changed; 93 --NumLoops; 94 Changed |= ExtractLoop(DT, *i) != 0; 95 ++NumExtracted; 96 } 97 } else { 98 // Otherwise there is exactly one top-level loop. If this function is more 99 // than a minimal wrapper around the loop, extract the loop. 100 Loop *TLL = *LI.begin(); 101 bool ShouldExtractLoop = false; 102 103 // Extract the loop if the entry block doesn't branch to the loop header. 104 TerminatorInst *EntryTI = F.getEntryBlock().getTerminator(); 105 if (!isa<BranchInst>(EntryTI) || 106 !cast<BranchInst>(EntryTI)->isUnconditional() || 107 EntryTI->getSuccessor(0) != TLL->getHeader()) 108 ShouldExtractLoop = true; 109 else { 110 // Check to see if any exits from the loop are more than just return 111 // blocks. 112 SmallVector<BasicBlock*, 8> ExitBlocks; 113 TLL->getExitBlocks(ExitBlocks); 114 for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) 115 if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) { 116 ShouldExtractLoop = true; 117 break; 118 } 119 } 120 121 if (ShouldExtractLoop) { 122 if (NumLoops == 0) return Changed; 123 --NumLoops; 124 Changed |= ExtractLoop(DT, TLL) != 0; 125 ++NumExtracted; 126 } else { 127 // Okay, this function is a minimal container around the specified loop. 128 // If we extract the loop, we will continue to just keep extracting it 129 // infinitely... so don't extract it. However, if the loop contains any 130 // subloops, extract them. 131 for (Loop::iterator i = TLL->begin(), e = TLL->end(); i != e; ++i) { 132 if (NumLoops == 0) return Changed; 133 --NumLoops; 134 Changed |= ExtractLoop(DT, *i) != 0; 135 ++NumExtracted; 136 } 137 } 138 } 139 140 return Changed; 141 } 142 143 // createSingleLoopExtractorPass - This pass extracts one natural loop from the 144 // program into a function if it can. This is used by bugpoint. 145 // 146 FunctionPass *llvm::createSingleLoopExtractorPass() { 147 return new SingleLoopExtractor(); 148 } 149 150 151 // BlockFile - A file which contains a list of blocks that should not be 152 // extracted. 153 static cl::opt<std::string> 154 BlockFile("extract-blocks-file", cl::value_desc("filename"), 155 cl::desc("A file containing list of basic blocks to not extract"), 156 cl::Hidden); 157 158 namespace { 159 /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks 160 /// from the module into their own functions except for those specified by the 161 /// BlocksToNotExtract list. 162 class BlockExtractorPass : public ModulePass { 163 void LoadFile(const char *Filename); 164 165 std::vector<BasicBlock*> BlocksToNotExtract; 166 std::vector<std::pair<std::string, std::string> > BlocksToNotExtractByName; 167 public: 168 static char ID; // Pass identification, replacement for typeid 169 explicit BlockExtractorPass(const std::vector<BasicBlock*> &B) 170 : ModulePass(&ID), BlocksToNotExtract(B) { 171 if (!BlockFile.empty()) 172 LoadFile(BlockFile.c_str()); 173 } 174 BlockExtractorPass() : ModulePass(&ID) {} 175 176 bool runOnModule(Module &M); 177 }; 178 } 179 180 char BlockExtractorPass::ID = 0; 181 static RegisterPass<BlockExtractorPass> 182 XX("extract-blocks", "Extract Basic Blocks From Module (for bugpoint use)"); 183 184 // createBlockExtractorPass - This pass extracts all blocks (except those 185 // specified in the argument list) from the functions in the module. 186 // 187 ModulePass *llvm::createBlockExtractorPass(const std::vector<BasicBlock*> &BTNE) 188 { 189 return new BlockExtractorPass(BTNE); 190 } 191 192 void BlockExtractorPass::LoadFile(const char *Filename) { 193 // Load the BlockFile... 194 std::ifstream In(Filename); 195 if (!In.good()) { 196 cerr << "WARNING: BlockExtractor couldn't load file '" << Filename 197 << "'!\n"; 198 return; 199 } 200 while (In) { 201 std::string FunctionName, BlockName; 202 In >> FunctionName; 203 In >> BlockName; 204 if (!BlockName.empty()) 205 BlocksToNotExtractByName.push_back( 206 std::make_pair(FunctionName, BlockName)); 207 } 208 } 209 210 bool BlockExtractorPass::runOnModule(Module &M) { 211 std::set<BasicBlock*> TranslatedBlocksToNotExtract; 212 for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { 213 BasicBlock *BB = BlocksToNotExtract[i]; 214 Function *F = BB->getParent(); 215 216 // Map the corresponding function in this module. 217 Function *MF = M.getFunction(F->getName()); 218 assert(MF->getFunctionType() == F->getFunctionType() && "Wrong function?"); 219 220 // Figure out which index the basic block is in its function. 221 Function::iterator BBI = MF->begin(); 222 std::advance(BBI, std::distance(F->begin(), Function::iterator(BB))); 223 TranslatedBlocksToNotExtract.insert(BBI); 224 } 225 226 while (!BlocksToNotExtractByName.empty()) { 227 // There's no way to find BBs by name without looking at every BB inside 228 // every Function. Fortunately, this is always empty except when used by 229 // bugpoint in which case correctness is more important than performance. 230 231 std::string &FuncName = BlocksToNotExtractByName.back().first; 232 std::string &BlockName = BlocksToNotExtractByName.back().second; 233 234 for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) { 235 Function &F = *FI; 236 if (F.getName() != FuncName) continue; 237 238 for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { 239 BasicBlock &BB = *BI; 240 if (BB.getName() != BlockName) continue; 241 242 TranslatedBlocksToNotExtract.insert(BI); 243 } 244 } 245 246 BlocksToNotExtractByName.pop_back(); 247 } 248 249 // Now that we know which blocks to not extract, figure out which ones we WANT 250 // to extract. 251 std::vector<BasicBlock*> BlocksToExtract; 252 for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) 253 for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) 254 if (!TranslatedBlocksToNotExtract.count(BB)) 255 BlocksToExtract.push_back(BB); 256 257 for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i) 258 ExtractBasicBlock(BlocksToExtract[i]); 259 260 return !BlocksToExtract.empty(); 261 } 262