1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements loop unroll and jam as a routine, much like
11 // LoopUnroll.cpp implements loop unroll.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/AssumptionCache.h"
18 #include "llvm/Analysis/DependenceAnalysis.h"
19 #include "llvm/Analysis/InstructionSimplify.h"
20 #include "llvm/Analysis/LoopAnalysisManager.h"
21 #include "llvm/Analysis/LoopIterator.h"
22 #include "llvm/Analysis/LoopPass.h"
23 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
24 #include "llvm/Analysis/ScalarEvolution.h"
25 #include "llvm/Analysis/ScalarEvolutionExpander.h"
26 #include "llvm/Analysis/Utils/Local.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/DataLayout.h"
29 #include "llvm/IR/DebugInfoMetadata.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
36 #include "llvm/Transforms/Utils/Cloning.h"
37 #include "llvm/Transforms/Utils/LoopSimplify.h"
38 #include "llvm/Transforms/Utils/LoopUtils.h"
39 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
40 #include "llvm/Transforms/Utils/UnrollLoop.h"
41 using namespace llvm;
42 
43 #define DEBUG_TYPE "loop-unroll-and-jam"
44 
45 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
46 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
47 
48 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
49 
50 // Partition blocks in an outer/inner loop pair into blocks before and after
51 // the loop
52 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
53                                      BasicBlockSet &ForeBlocks,
54                                      BasicBlockSet &SubLoopBlocks,
55                                      BasicBlockSet &AftBlocks,
56                                      DominatorTree *DT) {
57   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
58   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
59 
60   for (BasicBlock *BB : L->blocks()) {
61     if (!SubLoop->contains(BB)) {
62       if (DT->dominates(SubLoopLatch, BB))
63         AftBlocks.insert(BB);
64       else
65         ForeBlocks.insert(BB);
66     }
67   }
68 
69   // Check that all blocks in ForeBlocks together dominate the subloop
70   // TODO: This might ideally be done better with a dominator/postdominators.
71   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
72   for (BasicBlock *BB : ForeBlocks) {
73     if (BB == SubLoopPreHeader)
74       continue;
75     Instruction *TI = BB->getTerminator();
76     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
77       if (!ForeBlocks.count(TI->getSuccessor(i)))
78         return false;
79   }
80 
81   return true;
82 }
83 
84 // Looks at the phi nodes in Header for values coming from Latch. For these
85 // instructions and all their operands calls Visit on them, keeping going for
86 // all the operands in AftBlocks. Returns false if Visit returns false,
87 // otherwise returns true. This is used to process the instructions in the
88 // Aft blocks that need to be moved before the subloop. It is used in two
89 // places. One to check that the required set of instructions can be moved
90 // before the loop. Then to collect the instructions to actually move in
91 // moveHeaderPhiOperandsToForeBlocks.
92 template <typename T>
93 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
94                                      BasicBlockSet &AftBlocks, T Visit) {
95   SmallVector<Instruction *, 8> Worklist;
96   for (auto &Phi : Header->phis()) {
97     Value *V = Phi.getIncomingValueForBlock(Latch);
98     if (Instruction *I = dyn_cast<Instruction>(V))
99       Worklist.push_back(I);
100   }
101 
102   while (!Worklist.empty()) {
103     Instruction *I = Worklist.back();
104     Worklist.pop_back();
105     if (!Visit(I))
106       return false;
107 
108     if (AftBlocks.count(I->getParent()))
109       for (auto &U : I->operands())
110         if (Instruction *II = dyn_cast<Instruction>(U))
111           Worklist.push_back(II);
112   }
113 
114   return true;
115 }
116 
117 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
118 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
119                                               BasicBlock *Latch,
120                                               Instruction *InsertLoc,
121                                               BasicBlockSet &AftBlocks) {
122   // We need to ensure we move the instructions in the correct order,
123   // starting with the earliest required instruction and moving forward.
124   std::vector<Instruction *> Visited;
125   processHeaderPhiOperands(Header, Latch, AftBlocks,
126                            [&Visited, &AftBlocks](Instruction *I) {
127                              if (AftBlocks.count(I->getParent()))
128                                Visited.push_back(I);
129                              return true;
130                            });
131 
132   // Move all instructions in program order to before the InsertLoc
133   BasicBlock *InsertLocBB = InsertLoc->getParent();
134   for (Instruction *I : reverse(Visited)) {
135     if (I->getParent() != InsertLocBB)
136       I->moveBefore(InsertLoc);
137   }
138 }
139 
140 /*
141   This method performs Unroll and Jam. For a simple loop like:
142   for (i = ..)
143     Fore(i)
144     for (j = ..)
145       SubLoop(i, j)
146     Aft(i)
147 
148   Instead of doing normal inner or outer unrolling, we do:
149   for (i = .., i+=2)
150     Fore(i)
151     Fore(i+1)
152     for (j = ..)
153       SubLoop(i, j)
154       SubLoop(i+1, j)
155     Aft(i)
156     Aft(i+1)
157 
  So the outer loop is essentially unrolled and then the inner loops are fused
159   ("jammed") together into a single loop. This can increase speed when there
160   are loads in SubLoop that are invariant to i, as they become shared between
161   the now jammed inner loops.
162 
  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
164   Fore blocks are those before the inner loop, Aft are those after. Normal
165   Unroll code is used to copy each of these sets of blocks and the results are
166   combined together into the final form above.
167 
168   isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.
170 
171   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
172   necessary to create one and not fully unrolled).
173 */
LoopUnrollResult llvm::UnrollAndJamLoop(
    Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,
    bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
    AssumptionCache *AC, OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(L->getSubLoops().size() == 1);
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple == 1 || TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, LI, SE, DT, AC, true,
                                    EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(Preheader && LatchBlock && Header);
  assert(BI && !BI->isUnconditional());
  // Record which successor of each latch branch continues the loop; the
  // other successor of the outer latch is the loop exit.
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  // Index 0 of each list is the original (iteration 0) block.
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions from fore phi operands from AftBlocks into Fore.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, SubLoop->getLoopPreheader()->getTerminator(),
      AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // Annotate debug locations with the unroll count so sample-based profiling
  // can account for the duplicated code.
  if (Header->getParent()->isDebugInfoForProfiling())
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc())
            I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count));

  // Copy all blocks
  for (unsigned It = 1; It != Count; ++It) {
    std::vector<BasicBlock *> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;

    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Register the clone with the correct loop (clones of subloop blocks
      // belong to the subloop) and note when it is the first/last block of
      // its region for this iteration.
      if (ForeBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        SubLoop->addBasicBlockToLoop(New, *LI);

        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        ::remapInstruction(&I, LastValueMap);
        if (auto *II = dyn_cast<IntrinsicInst>(&I))
          if (II->getIntrinsicID() == Intrinsic::assume)
            AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phi's, pointing them at the latest version of the
    // value from the previous iteration's phis
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      // Collapse the phi to a single incoming edge: the value now flows in
      // from the previous iteration's last Fore block.
      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // is the last unrolled iterations values.

  // Update Phis in BB from OldBB to point to NewBB
  auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
                            BasicBlock *NewBB) {
    for (PHINode &Phi : BB->phis()) {
      int I = Phi.getBasicBlockIndex(OldBB);
      Phi.setIncomingBlock(I, NewBB);
    }
  };
  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  BasicBlock *Dest = SubLoopBlocksFirst[0];
  ForeTerm->setSuccessor(0, Dest);

  if (CompletelyUnroll) {
    // With the loop control gone, the header phis have a single remaining
    // input: the preheader value. Fold them away.
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    BasicBlock *Dest = ForeBlocksFirst[It];
    ForeTerm->setSuccessor(0, Dest);
  }

  // Subloop successors and phis
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
                  ForeBlocksLast.back());
  updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
                    ForeBlocksLast.back());
    updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *Term = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    BranchInst::Create(LoopExit, Term);
    Term->eraseFromParent();
  } else {
    Term->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
  }
  updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration's aft blocks
    // with an unconditional one to this iteration's aft blocks
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DT->applyUpdates(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
  while (!MergeBlocks.empty()) {
    BasicBlock *BB = *MergeBlocks.begin();
    BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
    if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
      BasicBlock *Dest = Term->getSuccessor(0);
      if (BasicBlock *Fold = foldBlockIntoPredecessor(Dest, LI, SE, DT)) {
        // Don't remove BB and add Fold as they are the same BB
        assert(Fold == BB);
        (void)Fold;
        MergeBlocks.erase(Dest);
      } else
        MergeBlocks.erase(BB);
    } else
      MergeBlocks.erase(BB);
  }

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  Loop *OuterL = L->getParentLoop();
  Loop *OutestLoop = OuterL ? OuterL : (!CompletelyUnroll ? L : SubLoop);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  assert(DT->verify());
#endif

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}
578 
579 static bool getLoadsAndStores(BasicBlockSet &Blocks,
580                               SmallVector<Value *, 4> &MemInstr) {
581   // Scan the BBs and collect legal loads and stores.
582   // Returns false if non-simple loads/stores are found.
583   for (BasicBlock *BB : Blocks) {
584     for (Instruction &I : *BB) {
585       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
586         if (!Ld->isSimple())
587           return false;
588         MemInstr.push_back(&I);
589       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
590         if (!St->isSimple())
591           return false;
592         MemInstr.push_back(&I);
593       } else if (I.mayReadOrWriteMemory()) {
594         return false;
595       }
596     }
597   }
598   return true;
599 }
600 
601 static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
602                               SmallVector<Value *, 4> &Later,
603                               unsigned LoopDepth, bool InnerLoop,
604                               DependenceInfo &DI) {
605   // Use DA to check for dependencies between loads and stores that make unroll
606   // and jam invalid
607   for (Value *I : Earlier) {
608     for (Value *J : Later) {
609       Instruction *Src = cast<Instruction>(I);
610       Instruction *Dst = cast<Instruction>(J);
611       if (Src == Dst)
612         continue;
613       // Ignore Input dependencies.
614       if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
615         continue;
616 
617       // Track dependencies, and if we find them take a conservative approach
618       // by allowing only = or < (not >), altough some > would be safe
619       // (depending upon unroll width).
620       // For the inner loop, we need to disallow any (> <) dependencies
621       // FIXME: Allow > so long as distance is less than unroll width
622       if (auto D = DI.depends(Src, Dst, true)) {
623         assert(D->isOrdered() && "Expected an output, flow or anti dep.");
624 
625         if (D->isConfused()) {
626           LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
627                             << "  " << *Src << "\n"
628                             << "  " << *Dst << "\n");
629           return false;
630         }
631         if (!InnerLoop) {
632           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) {
633             LLVM_DEBUG(dbgs() << "  > dependency between:\n"
634                               << "  " << *Src << "\n"
635                               << "  " << *Dst << "\n");
636             return false;
637           }
638         } else {
639           assert(LoopDepth + 1 <= D->getLevels());
640           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
641               D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) {
642             LLVM_DEBUG(dbgs() << "  < > dependency between:\n"
643                               << "  " << *Src << "\n"
644                               << "  " << *Dst << "\n");
645             return false;
646           }
647         }
648       }
649     }
650   }
651   return true;
652 }
653 
654 static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
655                               BasicBlockSet &SubLoopBlocks,
656                               BasicBlockSet &AftBlocks, DependenceInfo &DI) {
657   // Get all loads/store pairs for each blocks
658   SmallVector<Value *, 4> ForeMemInstr;
659   SmallVector<Value *, 4> SubLoopMemInstr;
660   SmallVector<Value *, 4> AftMemInstr;
661   if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
662       !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
663       !getLoadsAndStores(AftBlocks, AftMemInstr))
664     return false;
665 
666   // Check for dependencies between any blocks that may change order
667   unsigned LoopDepth = L->getLoopDepth();
668   return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
669                            DI) &&
670          checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
671          checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
672                            DI) &&
673          checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
674                            DI);
675 }
676 
// Returns true if L (an outer loop required to contain exactly one inner
// loop) is in a shape that UnrollAndJamLoop can transform safely, and no
// memory dependencies or side effects prevent the required reordering.
bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                                DependenceInfo &DI) {
  /* We currently handle outer loops like this:
        |
    ForeFirst    <----\    }
     Blocks           |    } ForeBlocks
    ForeLast          |    }
        |             |
    SubLoopFirst  <\  |    }
     Blocks        |  |    } SubLoopBlocks
    SubLoopLast   -/  |    }
        |             |
    AftFirst          |    }
     Blocks           |    } AftBlocks
    AftLast     ------/    }
        |

    There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
    and AftBlocks, providing that there is one edge from Fores to SubLoops,
    one edge from SubLoops to Afts and a single outer loop exit (from Afts).
    In practice we currently limit Aft blocks to a single block, and limit
    things further in the profitability checks of the unroll and jam pass.

    Because of the way we rearrange basic blocks, we also require that
    the Fore blocks on all unrolled iterations are safe to move before the
    SubLoop blocks of all iterations. So we require that the phi node looping
    operands of ForeHeader can be moved to at least the end of ForeEnd, so that
    we can arrange cloned Fore Blocks before the subloop and match up Phi's
    correctly.

    i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
    It needs to be safe to transform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.

    There are then a number of checks along the lines of no calls, no
    exceptions, inner loop IV is consistent, etc. Note that for loops requiring
    runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
    UnrollAndJamLoop if the trip count cannot be easily calculated.
  */

  if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
    return false;
  Loop *SubLoop = L->getSubLoops()[0];
  if (!SubLoop->isLoopSimplifyForm())
    return false;

  BasicBlock *Header = L->getHeader();
  BasicBlock *Latch = L->getLoopLatch();
  BasicBlock *Exit = L->getExitingBlock();
  BasicBlock *SubLoopHeader = SubLoop->getHeader();
  BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
  BasicBlock *SubLoopExit = SubLoop->getExitingBlock();

  // Require the latch to be the single exiting block, for both loops.
  if (Latch != Exit)
    return false;
  if (SubLoopLatch != SubLoopExit)
    return false;

  if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
    return false;
  }

  // Split blocks into Fore/SubLoop/Aft based on dominators
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
                                AftBlocks, &DT)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
    return false;
  }

  // Aft blocks may need to move instructions to fore blocks, which becomes more
  // difficult if there are multiple (potentially conditionally executed)
  // blocks. For now we just exclude loops with multiple aft blocks.
  if (AftBlocks.size() != 1) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
                         "multiple blocks after the loop\n");
    return false;
  }

  // Check inner loop backedge count is consistent on all iterations of the
  // outer loop
  if (!hasIterationCountInvariantInParent(SubLoop, SE)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
                         "not consistent on each iteration\n");
    return false;
  }

  // Check the loop safety info for exceptions.
  SimpleLoopSafetyInfo LSI;
  LSI.computeLoopSafetyInfo(L);
  if (LSI.anyBlockMayThrow()) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
    return false;
  }

  // We've ruled out the easy stuff and now need to check that there are no
  // interdependencies which may prevent us from moving the:
  //  ForeBlocks before Subloop and AftBlocks.
  //  Subloop before AftBlocks.
  //  ForeBlock phi operands before the subloop

  // Make sure we can move all instructions we need to before the subloop
  if (!processHeaderPhiOperands(
          Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
            if (SubLoop->contains(I->getParent()))
              return false;
            if (AftBlocks.count(I->getParent())) {
              // If we hit a phi node in afts we know we are done (probably
              // LCSSA)
              if (isa<PHINode>(I))
                return false;
              // Can't move instructions with side effects or memory
              // reads/writes
              if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
                return false;
            }
            // Keep going
            return true;
          })) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
                         "instructions after subloop to before it\n");
    return false;
  }

  // Check for memory dependencies which prohibit the unrolling we are doing.
  // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
  // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
  if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
    return false;
  }

  return true;
}
813