1 //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass extracts the specified basic blocks from the module into their 10 // own functions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/STLExtras.h" 15 #include "llvm/ADT/Statistic.h" 16 #include "llvm/IR/Instructions.h" 17 #include "llvm/IR/Module.h" 18 #include "llvm/Pass.h" 19 #include "llvm/Support/CommandLine.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/MemoryBuffer.h" 22 #include "llvm/Transforms/IPO.h" 23 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 24 #include "llvm/Transforms/Utils/CodeExtractor.h" 25 26 using namespace llvm; 27 28 #define DEBUG_TYPE "block-extractor" 29 30 STATISTIC(NumExtracted, "Number of basic blocks extracted"); 31 32 static cl::opt<std::string> BlockExtractorFile( 33 "extract-blocks-file", cl::value_desc("filename"), 34 cl::desc("A file containing list of basic blocks to extract"), cl::Hidden); 35 36 cl::opt<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs", 37 cl::desc("Erase the existing functions"), 38 cl::Hidden); 39 namespace { 40 class BlockExtractor : public ModulePass { 41 SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks; 42 bool EraseFunctions; 43 /// Map a function name to groups of blocks. 44 SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> 45 BlocksByName; 46 47 public: 48 static char ID; 49 BlockExtractor(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, 50 bool EraseFunctions) 51 : ModulePass(ID), EraseFunctions(EraseFunctions) { 52 // We want one group per element of the input list. 53 for (BasicBlock *BB : BlocksToExtract) { 54 SmallVector<BasicBlock *, 16> NewGroup; 55 NewGroup.push_back(BB); 56 GroupsOfBlocks.push_back(NewGroup); 57 } 58 if (!BlockExtractorFile.empty()) 59 loadFile(); 60 } 61 BlockExtractor() : BlockExtractor(SmallVector<BasicBlock *, 0>(), false) {} 62 bool runOnModule(Module &M) override; 63 64 private: 65 void loadFile(); 66 void splitLandingPadPreds(Function &F); 67 }; 68 } // end anonymous namespace 69 70 char BlockExtractor::ID = 0; 71 INITIALIZE_PASS(BlockExtractor, "extract-blocks", 72 "Extract basic blocks from module", false, false) 73 74 ModulePass *llvm::createBlockExtractorPass() { return new BlockExtractor(); } 75 ModulePass *llvm::createBlockExtractorPass( 76 const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) { 77 return new BlockExtractor(BlocksToExtract, EraseFunctions); 78 } 79 80 /// Gets all of the blocks specified in the input file. 81 void BlockExtractor::loadFile() { 82 auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile); 83 if (ErrOrBuf.getError()) 84 report_fatal_error("BlockExtractor couldn't load the file."); 85 // Read the file. 86 auto &Buf = *ErrOrBuf; 87 SmallVector<StringRef, 16> Lines; 88 Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1, 89 /*KeepEmpty=*/false); 90 for (const auto &Line : Lines) { 91 SmallVector<StringRef, 4> LineSplit; 92 Line.split(LineSplit, ' ', /*MaxSplit=*/-1, 93 /*KeepEmpty=*/false); 94 if (LineSplit.empty()) 95 continue; 96 SmallVector<StringRef, 4> BBNames; 97 LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, 98 /*KeepEmpty=*/false); 99 if (BBNames.empty()) 100 report_fatal_error("Missing bbs name"); 101 BlocksByName.push_back({LineSplit[0], {BBNames.begin(), BBNames.end()}}); 102 } 103 } 104 105 /// Extracts the landing pads to make sure all of them have only one 106 /// predecessor. 107 void BlockExtractor::splitLandingPadPreds(Function &F) { 108 for (BasicBlock &BB : F) { 109 for (Instruction &I : BB) { 110 if (!isa<InvokeInst>(&I)) 111 continue; 112 InvokeInst *II = cast<InvokeInst>(&I); 113 BasicBlock *Parent = II->getParent(); 114 BasicBlock *LPad = II->getUnwindDest(); 115 116 // Look through the landing pad's predecessors. If one of them ends in an 117 // 'invoke', then we want to split the landing pad. 118 bool Split = false; 119 for (auto PredBB : predecessors(LPad)) { 120 if (PredBB->isLandingPad() && PredBB != Parent && 121 isa<InvokeInst>(Parent->getTerminator())) { 122 Split = true; 123 break; 124 } 125 } 126 127 if (!Split) 128 continue; 129 130 SmallVector<BasicBlock *, 2> NewBBs; 131 SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs); 132 } 133 } 134 } 135 136 bool BlockExtractor::runOnModule(Module &M) { 137 138 bool Changed = false; 139 140 // Get all the functions. 141 SmallVector<Function *, 4> Functions; 142 for (Function &F : M) { 143 splitLandingPadPreds(F); 144 Functions.push_back(&F); 145 } 146 147 // Get all the blocks specified in the input file. 148 unsigned NextGroupIdx = GroupsOfBlocks.size(); 149 GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size()); 150 for (const auto &BInfo : BlocksByName) { 151 Function *F = M.getFunction(BInfo.first); 152 if (!F) 153 report_fatal_error("Invalid function name specified in the input file"); 154 for (const auto &BBInfo : BInfo.second) { 155 auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { 156 return BB.getName().equals(BBInfo); 157 }); 158 if (Res == F->end()) 159 report_fatal_error("Invalid block name specified in the input file"); 160 GroupsOfBlocks[NextGroupIdx].push_back(&*Res); 161 } 162 ++NextGroupIdx; 163 } 164 165 // Extract each group of basic blocks. 166 for (auto &BBs : GroupsOfBlocks) { 167 SmallVector<BasicBlock *, 32> BlocksToExtractVec; 168 for (BasicBlock *BB : BBs) { 169 // Check if the module contains BB. 170 if (BB->getParent()->getParent() != &M) 171 report_fatal_error("Invalid basic block"); 172 LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " 173 << BB->getParent()->getName() << ":" << BB->getName() 174 << "\n"); 175 BlocksToExtractVec.push_back(BB); 176 if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) 177 BlocksToExtractVec.push_back(II->getUnwindDest()); 178 ++NumExtracted; 179 Changed = true; 180 } 181 Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(); 182 if (F) 183 LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() 184 << "' in: " << F->getName() << '\n'); 185 else 186 LLVM_DEBUG(dbgs() << "Failed to extract for group '" 187 << (*BBs.begin())->getName() << "'\n"); 188 } 189 190 // Erase the functions. 191 if (EraseFunctions || BlockExtractorEraseFuncs) { 192 for (Function *F : Functions) { 193 LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName() 194 << "\n"); 195 F->deleteBody(); 196 } 197 // Set linkage as ExternalLinkage to avoid erasing unreachable functions. 198 for (Function &F : M) 199 F.setLinkage(GlobalValue::ExternalLinkage); 200 Changed = true; 201 } 202 203 return Changed; 204 } 205