1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements loop unroll and jam as a routine, much like
11 // LoopUnroll.cpp implements loop unroll.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/ADT/SmallPtrSet.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/AssumptionCache.h"
18 #include "llvm/Analysis/DependenceAnalysis.h"
19 #include "llvm/Analysis/InstructionSimplify.h"
20 #include "llvm/Analysis/LoopAnalysisManager.h"
21 #include "llvm/Analysis/LoopIterator.h"
22 #include "llvm/Analysis/LoopPass.h"
23 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
24 #include "llvm/Analysis/ScalarEvolution.h"
25 #include "llvm/Analysis/ScalarEvolutionExpander.h"
26 #include "llvm/Analysis/Utils/Local.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/DataLayout.h"
29 #include "llvm/IR/DebugInfoMetadata.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
36 #include "llvm/Transforms/Utils/Cloning.h"
37 #include "llvm/Transforms/Utils/LoopSimplify.h"
38 #include "llvm/Transforms/Utils/LoopUtils.h"
39 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
40 #include "llvm/Transforms/Utils/UnrollLoop.h"
41 using namespace llvm;
42 
43 #define DEBUG_TYPE "loop-unroll-and-jam"
44 
45 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
46 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
47 
48 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
49 
50 // Partition blocks in an outer/inner loop pair into blocks before and after
51 // the loop
52 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
53                                      BasicBlockSet &ForeBlocks,
54                                      BasicBlockSet &SubLoopBlocks,
55                                      BasicBlockSet &AftBlocks,
56                                      DominatorTree *DT) {
57   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
58   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
59 
60   for (BasicBlock *BB : L->blocks()) {
61     if (!SubLoop->contains(BB)) {
62       if (DT->dominates(SubLoopLatch, BB))
63         AftBlocks.insert(BB);
64       else
65         ForeBlocks.insert(BB);
66     }
67   }
68 
69   // Check that all blocks in ForeBlocks together dominate the subloop
70   // TODO: This might ideally be done better with a dominator/postdominators.
71   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
72   for (BasicBlock *BB : ForeBlocks) {
73     if (BB == SubLoopPreHeader)
74       continue;
75     TerminatorInst *TI = BB->getTerminator();
76     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
77       if (!ForeBlocks.count(TI->getSuccessor(i)))
78         return false;
79   }
80 
81   return true;
82 }
83 
84 // Looks at the phi nodes in Header for values coming from Latch. For these
85 // instructions and all their operands calls Visit on them, keeping going for
86 // all the operands in AftBlocks. Returns false if Visit returns false,
87 // otherwise returns true. This is used to process the instructions in the
88 // Aft blocks that need to be moved before the subloop. It is used in two
89 // places. One to check that the required set of instructions can be moved
90 // before the loop. Then to collect the instructions to actually move in
91 // moveHeaderPhiOperandsToForeBlocks.
92 template <typename T>
93 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
94                                      BasicBlockSet &AftBlocks, T Visit) {
95   SmallVector<Instruction *, 8> Worklist;
96   for (auto &Phi : Header->phis()) {
97     Value *V = Phi.getIncomingValueForBlock(Latch);
98     if (Instruction *I = dyn_cast<Instruction>(V))
99       Worklist.push_back(I);
100   }
101 
102   while (!Worklist.empty()) {
103     Instruction *I = Worklist.back();
104     Worklist.pop_back();
105     if (!Visit(I))
106       return false;
107 
108     if (AftBlocks.count(I->getParent()))
109       for (auto &U : I->operands())
110         if (Instruction *II = dyn_cast<Instruction>(U))
111           Worklist.push_back(II);
112   }
113 
114   return true;
115 }
116 
117 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
118 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
119                                               BasicBlock *Latch,
120                                               Instruction *InsertLoc,
121                                               BasicBlockSet &AftBlocks) {
122   // We need to ensure we move the instructions in the correct order,
123   // starting with the earliest required instruction and moving forward.
124   std::vector<Instruction *> Visited;
125   processHeaderPhiOperands(Header, Latch, AftBlocks,
126                            [&Visited, &AftBlocks](Instruction *I) {
127                              if (AftBlocks.count(I->getParent()))
128                                Visited.push_back(I);
129                              return true;
130                            });
131 
132   // Move all instructions in program order to before the InsertLoc
133   BasicBlock *InsertLocBB = InsertLoc->getParent();
134   for (Instruction *I : reverse(Visited)) {
135     if (I->getParent() != InsertLocBB)
136       I->moveBefore(InsertLoc);
137   }
138 }
139 
140 /*
141   This method performs Unroll and Jam. For a simple loop like:
142   for (i = ..)
143     Fore(i)
144     for (j = ..)
145       SubLoop(i, j)
146     Aft(i)
147 
148   Instead of doing normal inner or outer unrolling, we do:
149   for (i = .., i+=2)
150     Fore(i)
151     Fore(i+1)
152     for (j = ..)
153       SubLoop(i, j)
154       SubLoop(i+1, j)
155     Aft(i)
156     Aft(i+1)
157 
158   So the outer loop is essetially unrolled and then the inner loops are fused
159   ("jammed") together into a single loop. This can increase speed when there
160   are loads in SubLoop that are invariant to i, as they become shared between
161   the now jammed inner loops.
162 
163   We do this by spliting the blocks in the loop into Fore, Subloop and Aft.
164   Fore blocks are those before the inner loop, Aft are those after. Normal
165   Unroll code is used to copy each of these sets of blocks and the results are
166   combined together into the final form above.
167 
168   isSafeToUnrollAndJam should be used prior to calling this to make sure the
169   unrolling will be valid. Checking profitablility is also advisable.
170 */
171 LoopUnrollResult
172 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
173                        unsigned TripMultiple, bool UnrollRemainder,
174                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
175                        AssumptionCache *AC, OptimizationRemarkEmitter *ORE) {
176 
177   // When we enter here we should have already checked that it is safe
178   BasicBlock *Header = L->getHeader();
179   assert(L->getSubLoops().size() == 1);
180   Loop *SubLoop = *L->begin();
181 
182   // Don't enter the unroll code if there is nothing to do.
183   if (TripCount == 0 && Count < 2) {
184     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
185     return LoopUnrollResult::Unmodified;
186   }
187 
188   assert(Count > 0);
189   assert(TripMultiple > 0);
190   assert(TripCount == 0 || TripCount % TripMultiple == 0);
191 
192   // Are we eliminating the loop control altogether?
193   bool CompletelyUnroll = (Count == TripCount);
194 
195   // We use the runtime remainder in cases where we don't know trip multiple
196   if (TripMultiple == 1 || TripMultiple % Count != 0) {
197     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
198                                     /*UseEpilogRemainder*/ true,
199                                     UnrollRemainder, LI, SE, DT, AC, true)) {
200       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
201                            "generated when assuming runtime trip count\n");
202       return LoopUnrollResult::Unmodified;
203     }
204   }
205 
206   // Notify ScalarEvolution that the loop will be substantially changed,
207   // if not outright eliminated.
208   if (SE) {
209     SE->forgetLoop(L);
210     SE->forgetLoop(SubLoop);
211   }
212 
213   using namespace ore;
214   // Report the unrolling decision.
215   if (CompletelyUnroll) {
216     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
217                       << Header->getName() << " with trip count " << TripCount
218                       << "!\n");
219     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
220                                  L->getHeader())
221               << "completely unroll and jammed loop with "
222               << NV("UnrollCount", TripCount) << " iterations");
223   } else {
224     auto DiagBuilder = [&]() {
225       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
226                               L->getHeader());
227       return Diag << "unroll and jammed loop by a factor of "
228                   << NV("UnrollCount", Count);
229     };
230 
231     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
232                       << " by " << Count);
233     if (TripMultiple != 1) {
234       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
235       ORE->emit([&]() {
236         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
237                              << " trips per branch";
238       });
239     } else {
240       LLVM_DEBUG(dbgs() << " with run-time trip count");
241       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
242     }
243     LLVM_DEBUG(dbgs() << "!\n");
244   }
245 
246   BasicBlock *Preheader = L->getLoopPreheader();
247   BasicBlock *LatchBlock = L->getLoopLatch();
248   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
249   assert(Preheader && LatchBlock && Header);
250   assert(BI && !BI->isUnconditional());
251   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
252   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
253   bool SubLoopContinueOnTrue = SubLoop->contains(
254       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
255 
256   // Partition blocks in an outer/inner loop pair into blocks before and after
257   // the loop
258   BasicBlockSet SubLoopBlocks;
259   BasicBlockSet ForeBlocks;
260   BasicBlockSet AftBlocks;
261   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
262                            DT);
263 
264   // We keep track of the entering/first and exiting/last block of each of
265   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
266   // blocks easier.
267   std::vector<BasicBlock *> ForeBlocksFirst;
268   std::vector<BasicBlock *> ForeBlocksLast;
269   std::vector<BasicBlock *> SubLoopBlocksFirst;
270   std::vector<BasicBlock *> SubLoopBlocksLast;
271   std::vector<BasicBlock *> AftBlocksFirst;
272   std::vector<BasicBlock *> AftBlocksLast;
273   ForeBlocksFirst.push_back(Header);
274   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
275   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
276   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
277   AftBlocksFirst.push_back(SubLoop->getExitBlock());
278   AftBlocksLast.push_back(L->getExitingBlock());
279   // Maps Blocks[0] -> Blocks[It]
280   ValueToValueMapTy LastValueMap;
281 
282   // Move any instructions from fore phi operands from AftBlocks into Fore.
283   moveHeaderPhiOperandsToForeBlocks(
284       Header, LatchBlock, SubLoop->getLoopPreheader()->getTerminator(),
285       AftBlocks);
286 
287   // The current on-the-fly SSA update requires blocks to be processed in
288   // reverse postorder so that LastValueMap contains the correct value at each
289   // exit.
290   LoopBlocksDFS DFS(L);
291   DFS.perform(LI);
292   // Stash the DFS iterators before adding blocks to the loop.
293   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
294   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
295 
296   if (Header->getParent()->isDebugInfoForProfiling())
297     for (BasicBlock *BB : L->getBlocks())
298       for (Instruction &I : *BB)
299         if (!isa<DbgInfoIntrinsic>(&I))
300           if (const DILocation *DIL = I.getDebugLoc())
301             I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count));
302 
303   // Copy all blocks
304   for (unsigned It = 1; It != Count; ++It) {
305     std::vector<BasicBlock *> NewBlocks;
306     // Maps Blocks[It] -> Blocks[It-1]
307     DenseMap<Value *, Value *> PrevItValueMap;
308 
309     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
310       ValueToValueMapTy VMap;
311       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
312       Header->getParent()->getBasicBlockList().push_back(New);
313 
314       if (ForeBlocks.count(*BB)) {
315         L->addBasicBlockToLoop(New, *LI);
316 
317         if (*BB == ForeBlocksFirst[0])
318           ForeBlocksFirst.push_back(New);
319         if (*BB == ForeBlocksLast[0])
320           ForeBlocksLast.push_back(New);
321       } else if (SubLoopBlocks.count(*BB)) {
322         SubLoop->addBasicBlockToLoop(New, *LI);
323 
324         if (*BB == SubLoopBlocksFirst[0])
325           SubLoopBlocksFirst.push_back(New);
326         if (*BB == SubLoopBlocksLast[0])
327           SubLoopBlocksLast.push_back(New);
328       } else if (AftBlocks.count(*BB)) {
329         L->addBasicBlockToLoop(New, *LI);
330 
331         if (*BB == AftBlocksFirst[0])
332           AftBlocksFirst.push_back(New);
333         if (*BB == AftBlocksLast[0])
334           AftBlocksLast.push_back(New);
335       } else {
336         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
337       }
338 
339       // Update our running maps of newest clones
340       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
341       LastValueMap[*BB] = New;
342       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
343            VI != VE; ++VI) {
344         PrevItValueMap[VI->second] =
345             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
346         LastValueMap[VI->first] = VI->second;
347       }
348 
349       NewBlocks.push_back(New);
350 
351       // Update DomTree:
352       if (*BB == ForeBlocksFirst[0])
353         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
354       else if (*BB == SubLoopBlocksFirst[0])
355         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
356       else if (*BB == AftBlocksFirst[0])
357         DT->addNewBlock(New, AftBlocksLast[It - 1]);
358       else {
359         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
360         // structure.
361         auto BBDomNode = DT->getNode(*BB);
362         auto BBIDom = BBDomNode->getIDom();
363         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
364         assert(OriginalBBIDom);
365         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
366         DT->addNewBlock(
367             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
368       }
369     }
370 
371     // Remap all instructions in the most recent iteration
372     for (BasicBlock *NewBlock : NewBlocks) {
373       for (Instruction &I : *NewBlock) {
374         ::remapInstruction(&I, LastValueMap);
375         if (auto *II = dyn_cast<IntrinsicInst>(&I))
376           if (II->getIntrinsicID() == Intrinsic::assume)
377             AC->registerAssumption(II);
378       }
379     }
380 
381     // Alter the ForeBlocks phi's, pointing them at the latest version of the
382     // value from the previous iteration's phis
383     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
384       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
385       assert(OldValue && "should have incoming edge from Aft[It]");
386       Value *NewValue = OldValue;
387       if (Value *PrevValue = PrevItValueMap[OldValue])
388         NewValue = PrevValue;
389 
390       assert(Phi.getNumOperands() == 2);
391       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
392       Phi.setIncomingValue(0, NewValue);
393       Phi.removeIncomingValue(1);
394     }
395   }
396 
397   // Now that all the basic blocks for the unrolled iterations are in place,
398   // finish up connecting the blocks and phi nodes. At this point LastValueMap
399   // is the last unrolled iterations values.
400 
401   // Update Phis in BB from OldBB to point to NewBB
402   auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
403                             BasicBlock *NewBB) {
404     for (PHINode &Phi : BB->phis()) {
405       int I = Phi.getBasicBlockIndex(OldBB);
406       Phi.setIncomingBlock(I, NewBB);
407     }
408   };
409   // Update Phis in BB from OldBB to point to NewBB and use the latest value
410   // from LastValueMap
411   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
412                                      BasicBlock *NewBB,
413                                      ValueToValueMapTy &LastValueMap) {
414     for (PHINode &Phi : BB->phis()) {
415       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
416         if (Phi.getIncomingBlock(b) == OldBB) {
417           Value *OldValue = Phi.getIncomingValue(b);
418           if (Value *LastValue = LastValueMap[OldValue])
419             Phi.setIncomingValue(b, LastValue);
420           Phi.setIncomingBlock(b, NewBB);
421           break;
422         }
423       }
424     }
425   };
426   // Move all the phis from Src into Dest
427   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
428     Instruction *insertPoint = Dest->getFirstNonPHI();
429     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
430       Phi->moveBefore(insertPoint);
431   };
432 
433   // Update the PHI values outside the loop to point to the last block
434   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
435                            LastValueMap);
436 
437   // Update ForeBlocks successors and phi nodes
438   BranchInst *ForeTerm =
439       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
440   BasicBlock *Dest = SubLoopBlocksFirst[0];
441   ForeTerm->setSuccessor(0, Dest);
442 
443   if (CompletelyUnroll) {
444     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
445       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
446       Phi->getParent()->getInstList().erase(Phi);
447     }
448   } else {
449     // Update the PHI values to point to the last aft block
450     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
451                              AftBlocksLast.back(), LastValueMap);
452   }
453 
454   for (unsigned It = 1; It != Count; It++) {
455     // Remap ForeBlock successors from previous iteration to this
456     BranchInst *ForeTerm =
457         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
458     BasicBlock *Dest = ForeBlocksFirst[It];
459     ForeTerm->setSuccessor(0, Dest);
460   }
461 
462   // Subloop successors and phis
463   BranchInst *SubTerm =
464       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
465   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
466   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
467   updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
468                   ForeBlocksLast.back());
469   updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
470                   SubLoopBlocksLast.back());
471 
472   for (unsigned It = 1; It != Count; It++) {
473     // Replace the conditional branch of the previous iteration subloop with an
474     // unconditional one to this one
475     BranchInst *SubTerm =
476         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
477     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
478     SubTerm->eraseFromParent();
479 
480     updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
481                     ForeBlocksLast.back());
482     updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
483                     SubLoopBlocksLast.back());
484     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
485   }
486 
487   // Aft blocks successors and phis
488   BranchInst *Term = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
489   if (CompletelyUnroll) {
490     BranchInst::Create(LoopExit, Term);
491     Term->eraseFromParent();
492   } else {
493     Term->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
494   }
495   updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
496                   SubLoopBlocksLast.back());
497 
498   for (unsigned It = 1; It != Count; It++) {
499     // Replace the conditional branch of the previous iteration subloop with an
500     // unconditional one to this one
501     BranchInst *AftTerm =
502         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
503     BranchInst::Create(AftBlocksFirst[It], AftTerm);
504     AftTerm->eraseFromParent();
505 
506     updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
507                     SubLoopBlocksLast.back());
508     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
509   }
510 
511   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
512   // new ones required.
513   if (Count != 1) {
514     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
515     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
516                            SubLoopBlocksFirst[0]);
517     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
518                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
519 
520     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
521                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
522     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
523                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
524     DT->applyUpdates(DTUpdates);
525   }
526 
527   // Merge adjacent basic blocks, if possible.
528   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
529   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
530   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
531   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
532   while (!MergeBlocks.empty()) {
533     BasicBlock *BB = *MergeBlocks.begin();
534     BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
535     if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
536       BasicBlock *Dest = Term->getSuccessor(0);
537       if (BasicBlock *Fold = foldBlockIntoPredecessor(Dest, LI, SE, DT)) {
538         // Don't remove BB and add Fold as they are the same BB
539         assert(Fold == BB);
540         (void)Fold;
541         MergeBlocks.erase(Dest);
542       } else
543         MergeBlocks.erase(BB);
544     } else
545       MergeBlocks.erase(BB);
546   }
547 
548   // At this point, the code is well formed.  We now do a quick sweep over the
549   // inserted code, doing constant propagation and dead code elimination as we
550   // go.
551   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC);
552   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC);
553 
554   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
555   ++NumUnrolledAndJammed;
556 
557 #ifndef NDEBUG
558   // We shouldn't have done anything to break loop simplify form or LCSSA.
559   Loop *OuterL = L->getParentLoop();
560   Loop *OutestLoop = OuterL ? OuterL : (!CompletelyUnroll ? L : SubLoop);
561   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
562   if (!CompletelyUnroll)
563     assert(L->isLoopSimplifyForm());
564   assert(SubLoop->isLoopSimplifyForm());
565   assert(DT->verify());
566 #endif
567 
568   // Update LoopInfo if the loop is completely removed.
569   if (CompletelyUnroll)
570     LI->erase(L);
571 
572   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
573                           : LoopUnrollResult::PartiallyUnrolled;
574 }
575 
576 static bool getLoadsAndStores(BasicBlockSet &Blocks,
577                               SmallVector<Value *, 4> &MemInstr) {
578   // Scan the BBs and collect legal loads and stores.
579   // Returns false if non-simple loads/stores are found.
580   for (BasicBlock *BB : Blocks) {
581     for (Instruction &I : *BB) {
582       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
583         if (!Ld->isSimple())
584           return false;
585         MemInstr.push_back(&I);
586       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
587         if (!St->isSimple())
588           return false;
589         MemInstr.push_back(&I);
590       } else if (I.mayReadOrWriteMemory()) {
591         return false;
592       }
593     }
594   }
595   return true;
596 }
597 
598 static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
599                               SmallVector<Value *, 4> &Later,
600                               unsigned LoopDepth, bool InnerLoop,
601                               DependenceInfo &DI) {
602   // Use DA to check for dependencies between loads and stores that make unroll
603   // and jam invalid
604   for (Value *I : Earlier) {
605     for (Value *J : Later) {
606       Instruction *Src = cast<Instruction>(I);
607       Instruction *Dst = cast<Instruction>(J);
608       if (Src == Dst)
609         continue;
610       // Ignore Input dependencies.
611       if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
612         continue;
613 
614       // Track dependencies, and if we find them take a conservative approach
615       // by allowing only = or < (not >), altough some > would be safe
616       // (depending upon unroll width).
617       // For the inner loop, we need to disallow any (> <) dependencies
618       // FIXME: Allow > so long as distance is less than unroll width
619       if (auto D = DI.depends(Src, Dst, true)) {
620         assert(D->isOrdered() && "Expected an output, flow or anti dep.");
621 
622         if (D->isConfused()) {
623           LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
624                             << "  " << *Src << "\n"
625                             << "  " << *Dst << "\n");
626           return false;
627         }
628         if (!InnerLoop) {
629           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) {
630             LLVM_DEBUG(dbgs() << "  > dependency between:\n"
631                               << "  " << *Src << "\n"
632                               << "  " << *Dst << "\n");
633             return false;
634           }
635         } else {
636           assert(LoopDepth + 1 <= D->getLevels());
637           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
638               D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) {
639             LLVM_DEBUG(dbgs() << "  < > dependency between:\n"
640                               << "  " << *Src << "\n"
641                               << "  " << *Dst << "\n");
642             return false;
643           }
644         }
645       }
646     }
647   }
648   return true;
649 }
650 
651 static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
652                               BasicBlockSet &SubLoopBlocks,
653                               BasicBlockSet &AftBlocks, DependenceInfo &DI) {
654   // Get all loads/store pairs for each blocks
655   SmallVector<Value *, 4> ForeMemInstr;
656   SmallVector<Value *, 4> SubLoopMemInstr;
657   SmallVector<Value *, 4> AftMemInstr;
658   if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
659       !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
660       !getLoadsAndStores(AftBlocks, AftMemInstr))
661     return false;
662 
663   // Check for dependencies between any blocks that may change order
664   unsigned LoopDepth = L->getLoopDepth();
665   return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
666                            DI) &&
667          checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
668          checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
669                            DI) &&
670          checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
671                            DI);
672 }
673 
674 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
675                                 DependenceInfo &DI) {
676   /* We currently handle outer loops like this:
677         |
678     ForeFirst    <----\    }
679      Blocks           |    } ForeBlocks
680     ForeLast          |    }
681         |             |
682     SubLoopFirst  <\  |    }
683      Blocks        |  |    } SubLoopBlocks
684     SubLoopLast   -/  |    }
685         |             |
686     AftFirst          |    }
687      Blocks           |    } AftBlocks
688     AftLast     ------/    }
689         |
690 
691     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
692     and AftBlocks, providing that there is one edge from Fores to SubLoops,
693     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
694     In practice we currently limit Aft blocks to a single block, and limit
695     things further in the profitablility checks of the unroll and jam pass.
696 
697     Because of the way we rearrange basic blocks, we also require that
698     the Fore blocks on all unrolled iterations are safe to move before the
699     SubLoop blocks of all iterations. So we require that the phi node looping
700     operands of ForeHeader can be moved to at least the end of ForeEnd, so that
701     we can arrange cloned Fore Blocks before the subloop and match up Phi's
702     correctly.
703 
704     i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
705     It needs to be safe to tranform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.
706 
707     There are then a number of checks along the lines of no calls, no
708     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
709     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
710     UnrollAndJamLoop if the trip count cannot be easily calculated.
711   */
712 
713   if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
714     return false;
715   Loop *SubLoop = L->getSubLoops()[0];
716   if (!SubLoop->isLoopSimplifyForm())
717     return false;
718 
719   BasicBlock *Header = L->getHeader();
720   BasicBlock *Latch = L->getLoopLatch();
721   BasicBlock *Exit = L->getExitingBlock();
722   BasicBlock *SubLoopHeader = SubLoop->getHeader();
723   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
724   BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
725 
726   if (Latch != Exit)
727     return false;
728   if (SubLoopLatch != SubLoopExit)
729     return false;
730 
731   if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) {
732     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
733     return false;
734   }
735 
736   // Split blocks into Fore/SubLoop/Aft based on dominators
737   BasicBlockSet SubLoopBlocks;
738   BasicBlockSet ForeBlocks;
739   BasicBlockSet AftBlocks;
740   if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
741                                 AftBlocks, &DT)) {
742     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
743     return false;
744   }
745 
746   // Aft blocks may need to move instructions to fore blocks, which becomes more
747   // difficult if there are multiple (potentially conditionally executed)
748   // blocks. For now we just exclude loops with multiple aft blocks.
749   if (AftBlocks.size() != 1) {
750     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
751                          "multiple blocks after the loop\n");
752     return false;
753   }
754 
755   // Check inner loop backedge count is consistent on all iterations of the
756   // outer loop
757   if (!hasIterationCountInvariantInParent(SubLoop, SE)) {
758     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
759                          "not consistent on each iteration\n");
760     return false;
761   }
762 
763   // Check the loop safety info for exceptions.
764   LoopSafetyInfo LSI;
765   LSI.computeLoopSafetyInfo(L);
766   if (LSI.anyBlockMayThrow()) {
767     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
768     return false;
769   }
770 
771   // We've ruled out the easy stuff and now need to check that there are no
772   // interdependencies which may prevent us from moving the:
773   //  ForeBlocks before Subloop and AftBlocks.
774   //  Subloop before AftBlocks.
775   //  ForeBlock phi operands before the subloop
776 
777   // Make sure we can move all instructions we need to before the subloop
778   if (!processHeaderPhiOperands(
779           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
780             if (SubLoop->contains(I->getParent()))
781               return false;
782             if (AftBlocks.count(I->getParent())) {
783               // If we hit a phi node in afts we know we are done (probably
784               // LCSSA)
785               if (isa<PHINode>(I))
786                 return false;
787               // Can't move instructions with side effects or memory
788               // reads/writes
789               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
790                 return false;
791             }
792             // Keep going
793             return true;
794           })) {
795     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
796                          "instructions after subloop to before it\n");
797     return false;
798   }
799 
800   // Check for memory dependencies which prohibit the unrolling we are doing.
801   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
802   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
803   if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) {
804     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
805     return false;
806   }
807 
808   return true;
809 }
810