1 //===- LoopExtractor.cpp - Extract each loop into a new function ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A pass wrapper around the CodeExtractor utility that extracts each top-level
// loop into its own new function. If a function is nothing more than a minimal
// wrapper around a single top-level loop, that loop is left in place
// (extracting it would only recreate the same shape), although its sub-loops
// may still be extracted. This pass is most useful for debugging via bugpoint.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/AssumptionCache.h"
18 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/IR/Dominators.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/InitializePasses.h"
23 #include "llvm/Pass.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Transforms/IPO.h"
26 #include "llvm/Transforms/Scalar.h"
27 #include "llvm/Transforms/Utils.h"
28 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
29 #include "llvm/Transforms/Utils/CodeExtractor.h"
30 #include <fstream>
31 #include <set>
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "loop-extract"
35 
36 STATISTIC(NumExtracted, "Number of loops extracted");
37 
38 namespace {
39   struct LoopExtractor : public ModulePass {
40     static char ID; // Pass identification, replacement for typeid
41 
42     // The number of natural loops to extract from the program into functions.
43     unsigned NumLoops;
44 
45     explicit LoopExtractor(unsigned numLoops = ~0)
46         : ModulePass(ID), NumLoops(numLoops) {
47       initializeLoopExtractorPass(*PassRegistry::getPassRegistry());
48     }
49 
50     bool runOnModule(Module &M) override;
51     bool runOnFunction(Function &F);
52 
53     bool extractLoops(Loop::iterator From, Loop::iterator To, LoopInfo &LI,
54                       DominatorTree &DT);
55     bool extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT);
56 
57     void getAnalysisUsage(AnalysisUsage &AU) const override {
58       AU.addRequiredID(BreakCriticalEdgesID);
59       AU.addRequired<DominatorTreeWrapperPass>();
60       AU.addRequired<LoopInfoWrapperPass>();
61       AU.addPreserved<LoopInfoWrapperPass>();
62       AU.addRequiredID(LoopSimplifyID);
63       AU.addUsedIfAvailable<AssumptionCacheTracker>();
64     }
65   };
66 }
67 
// Pass ID storage; ModulePass(ID) in the LoopExtractor constructor refers to
// this object.
char LoopExtractor::ID = 0;
// Register the pass as "loop-extract" and list the passes it depends on. These
// match the passes required in LoopExtractor::getAnalysisUsage().
INITIALIZE_PASS_BEGIN(LoopExtractor, "loop-extract",
                      "Extract loops into new functions", false, false)
INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_END(LoopExtractor, "loop-extract",
                    "Extract loops into new functions", false, false)
77 
78 namespace {
79   /// SingleLoopExtractor - For bugpoint.
80   struct SingleLoopExtractor : public LoopExtractor {
81     static char ID; // Pass identification, replacement for typeid
82     SingleLoopExtractor() : LoopExtractor(1) {}
83   };
84 } // End anonymous namespace
85 
// Pass ID storage for SingleLoopExtractor, referenced by the registration
// below.
char SingleLoopExtractor::ID = 0;
// Register the single-loop variant as "loop-extract-single".
INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single",
                "Extract at most one loop into a new function", false, false)
89 
90 // createLoopExtractorPass - This pass extracts all natural loops from the
91 // program into a function if it can.
92 //
93 Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); }
94 
95 bool LoopExtractor::runOnModule(Module &M) {
96   if (skipModule(M))
97     return false;
98 
99   if (M.empty())
100     return false;
101 
102   if (!NumLoops)
103     return false;
104 
105   bool Changed = false;
106 
107   // The end of the function list may change (new functions will be added at the
108   // end), so we run from the first to the current last.
109   auto I = M.begin(), E = --M.end();
110   while (true) {
111     Function &F = *I;
112 
113     Changed |= runOnFunction(F);
114     if (!NumLoops)
115       break;
116 
117     // If this is the last function.
118     if (I == E)
119       break;
120 
121     ++I;
122   }
123   return Changed;
124 }
125 
126 bool LoopExtractor::runOnFunction(Function &F) {
127   // Do not modify `optnone` functions.
128   if (F.hasOptNone())
129     return false;
130 
131   if (F.empty())
132     return false;
133 
134   LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
135 
136   // If there are no loops in the function.
137   if (LI.empty())
138     return false;
139 
140   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
141 
142   // If there is more than one top-level loop in this function, extract all of
143   // the loops.
144   if (std::next(LI.begin()) != LI.end())
145     return extractLoops(LI.begin(), LI.end(), LI, DT);
146 
147   // Otherwise there is exactly one top-level loop.
148   Loop *TLL = *LI.begin();
149 
150   // If the loop is in LoopSimplify form, then extract it only if this function
151   // is more than a minimal wrapper around the loop.
152   if (TLL->isLoopSimplifyForm()) {
153     bool ShouldExtractLoop = false;
154 
155     // Extract the loop if the entry block doesn't branch to the loop header.
156     Instruction *EntryTI = F.getEntryBlock().getTerminator();
157     if (!isa<BranchInst>(EntryTI) ||
158         !cast<BranchInst>(EntryTI)->isUnconditional() ||
159         EntryTI->getSuccessor(0) != TLL->getHeader()) {
160       ShouldExtractLoop = true;
161     } else {
162       // Check to see if any exits from the loop are more than just return
163       // blocks.
164       SmallVector<BasicBlock *, 8> ExitBlocks;
165       TLL->getExitBlocks(ExitBlocks);
166       for (auto *ExitBlock : ExitBlocks)
167         if (!isa<ReturnInst>(ExitBlock->getTerminator())) {
168           ShouldExtractLoop = true;
169           break;
170         }
171     }
172 
173     if (ShouldExtractLoop)
174       return extractLoop(TLL, LI, DT);
175   }
176 
177   // Okay, this function is a minimal container around the specified loop.
178   // If we extract the loop, we will continue to just keep extracting it
179   // infinitely... so don't extract it. However, if the loop contains any
180   // sub-loops, extract them.
181   return extractLoops(TLL->begin(), TLL->end(), LI, DT);
182 }
183 
184 bool LoopExtractor::extractLoops(Loop::iterator From, Loop::iterator To,
185                                  LoopInfo &LI, DominatorTree &DT) {
186   bool Changed = false;
187   SmallVector<Loop *, 8> Loops;
188 
189   // Save the list of loops, as it may change.
190   Loops.assign(From, To);
191   for (Loop *L : Loops) {
192     // If LoopSimplify form is not available, stay out of trouble.
193     if (!L->isLoopSimplifyForm())
194       continue;
195 
196     Changed |= extractLoop(L, LI, DT);
197     if (!NumLoops)
198       break;
199   }
200   return Changed;
201 }
202 
203 bool LoopExtractor::extractLoop(Loop *L, LoopInfo &LI, DominatorTree &DT) {
204   assert(NumLoops != 0);
205   AssumptionCache *AC = nullptr;
206   Function &Func = *L->getHeader()->getParent();
207   if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>())
208     AC = ACT->lookupAssumptionCache(Func);
209   CodeExtractorAnalysisCache CEAC(Func);
210   CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
211   if (Extractor.extractCodeRegion(CEAC)) {
212     LI.erase(L);
213     --NumLoops;
214     ++NumExtracted;
215     return true;
216   }
217   return false;
218 }
219 
220 // createSingleLoopExtractorPass - This pass extracts one natural loop from the
221 // program into a function if it can.  This is used by bugpoint.
222 //
223 Pass *llvm::createSingleLoopExtractorPass() {
224   return new SingleLoopExtractor();
225 }
226