1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/AssumptionCache.h"
17 #include "llvm/Analysis/DependenceAnalysis.h"
18 #include "llvm/Analysis/InstructionSimplify.h"
19 #include "llvm/Analysis/LoopAnalysisManager.h"
20 #include "llvm/Analysis/LoopIterator.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
23 #include "llvm/Analysis/ScalarEvolution.h"
24 #include "llvm/Analysis/ScalarEvolutionExpander.h"
25 #include "llvm/Analysis/Utils/Local.h"
26 #include "llvm/IR/BasicBlock.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/DebugInfoMetadata.h"
29 #include "llvm/IR/Dominators.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
35 #include "llvm/Transforms/Utils/Cloning.h"
36 #include "llvm/Transforms/Utils/LoopSimplify.h"
37 #include "llvm/Transforms/Utils/LoopUtils.h"
38 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
39 #include "llvm/Transforms/Utils/UnrollLoop.h"
40 using namespace llvm;
41 
42 #define DEBUG_TYPE "loop-unroll-and-jam"
43 
44 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
45 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
46 
47 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
48 
49 // Partition blocks in an outer/inner loop pair into blocks before and after
50 // the loop
51 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
52                                      BasicBlockSet &ForeBlocks,
53                                      BasicBlockSet &SubLoopBlocks,
54                                      BasicBlockSet &AftBlocks,
55                                      DominatorTree *DT) {
56   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
57   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
58 
59   for (BasicBlock *BB : L->blocks()) {
60     if (!SubLoop->contains(BB)) {
61       if (DT->dominates(SubLoopLatch, BB))
62         AftBlocks.insert(BB);
63       else
64         ForeBlocks.insert(BB);
65     }
66   }
67 
68   // Check that all blocks in ForeBlocks together dominate the subloop
69   // TODO: This might ideally be done better with a dominator/postdominators.
70   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
71   for (BasicBlock *BB : ForeBlocks) {
72     if (BB == SubLoopPreHeader)
73       continue;
74     Instruction *TI = BB->getTerminator();
75     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
76       if (!ForeBlocks.count(TI->getSuccessor(i)))
77         return false;
78   }
79 
80   return true;
81 }
82 
83 // Looks at the phi nodes in Header for values coming from Latch. For these
84 // instructions and all their operands calls Visit on them, keeping going for
85 // all the operands in AftBlocks. Returns false if Visit returns false,
86 // otherwise returns true. This is used to process the instructions in the
87 // Aft blocks that need to be moved before the subloop. It is used in two
88 // places. One to check that the required set of instructions can be moved
89 // before the loop. Then to collect the instructions to actually move in
90 // moveHeaderPhiOperandsToForeBlocks.
91 template <typename T>
92 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
93                                      BasicBlockSet &AftBlocks, T Visit) {
94   SmallVector<Instruction *, 8> Worklist;
95   for (auto &Phi : Header->phis()) {
96     Value *V = Phi.getIncomingValueForBlock(Latch);
97     if (Instruction *I = dyn_cast<Instruction>(V))
98       Worklist.push_back(I);
99   }
100 
101   while (!Worklist.empty()) {
102     Instruction *I = Worklist.back();
103     Worklist.pop_back();
104     if (!Visit(I))
105       return false;
106 
107     if (AftBlocks.count(I->getParent()))
108       for (auto &U : I->operands())
109         if (Instruction *II = dyn_cast<Instruction>(U))
110           Worklist.push_back(II);
111   }
112 
113   return true;
114 }
115 
116 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
117 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
118                                               BasicBlock *Latch,
119                                               Instruction *InsertLoc,
120                                               BasicBlockSet &AftBlocks) {
121   // We need to ensure we move the instructions in the correct order,
122   // starting with the earliest required instruction and moving forward.
123   std::vector<Instruction *> Visited;
124   processHeaderPhiOperands(Header, Latch, AftBlocks,
125                            [&Visited, &AftBlocks](Instruction *I) {
126                              if (AftBlocks.count(I->getParent()))
127                                Visited.push_back(I);
128                              return true;
129                            });
130 
131   // Move all instructions in program order to before the InsertLoc
132   BasicBlock *InsertLocBB = InsertLoc->getParent();
133   for (Instruction *I : reverse(Visited)) {
134     if (I->getParent() != InsertLocBB)
135       I->moveBefore(InsertLoc);
136   }
137 }
138 
/*
  This method performs Unroll and Jam. For a simple loop like:
  for (i = ..)
    Fore(i)
    for (j = ..)
      SubLoop(i, j)
    Aft(i)

  Instead of doing normal inner or outer unrolling, we do:
  for (i = .., i+=2)
    Fore(i)
    Fore(i+1)
    for (j = ..)
      SubLoop(i, j)
      SubLoop(i+1, j)
    Aft(i)
    Aft(i+1)

  So the outer loop is essentially unrolled and then the inner loops are fused
  ("jammed") together into a single loop. This can increase speed when there
  are loads in SubLoop that are invariant to i, as they become shared between
  the now jammed inner loops.

  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
  Fore blocks are those before the inner loop, Aft are those after. Normal
  Unroll code is used to copy each of these sets of blocks and the results are
  combined together into the final form above.

  isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.

  If EpilogueLoop is non-null, it receives the epilogue loop (if it was
  necessary to create one and not fully unrolled).

  Returns LoopUnrollResult::Unmodified if nothing was done (nothing to unroll,
  or the runtime remainder loop could not be generated), FullyUnrolled if
  Count equals TripCount, and PartiallyUnrolled otherwise.
*/
LoopUnrollResult llvm::UnrollAndJamLoop(
    Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,
    bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
    AssumptionCache *AC, OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(Header && "No header.");
  // Exactly one inner loop is expected; it is the loop that gets jammed.
  assert(L->getSubLoops().size() == 1);
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple == 1 || TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, /*ForgetAllSCEV*/ false,
                                    LI, SE, DT, AC, true, EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    // Builds the common prefix of the partial-unroll remark; the callers below
    // append the trip-multiple or run-time trip count suffix.
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(Preheader && "No preheader");
  assert(LatchBlock && "No latch block");
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(BI && !BI->isUnconditional());
  // ContinueOnTrue records whether successor 0 of the latch branch stays in
  // the loop; LoopExit is then the other successor of that branch.
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
  // Same question for the inner loop's latch terminator.
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  // The boolean result (whether the layout is acceptable) is not checked here;
  // only the filled-in sets are used.
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier.
  // Index 0 of each vector is the original block; index It is filled in below
  // with the clone made for unrolled iteration It.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions from fore phi operands from AftBlocks into Fore.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, SubLoop->getLoopPreheader()->getTerminator(),
      AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // When the function is built with debug info for profiling, rewrite the
  // debug location of every non-debug-intrinsic instruction in the loop using
  // cloneByMultiplyingDuplicationFactor(Count). Failures are only logged.
  if (Header->getParent()->isDebugInfoForProfiling())
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc()) {
            auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
            if (NewDIL)
              I.setDebugLoc(NewDIL.getValue());
            else
              LLVM_DEBUG(dbgs()
                         << "Failed to create new discriminator: "
                         << DIL->getFilename() << " Line: " << DIL->getLine());
          }

  // Copy all blocks
  for (unsigned It = 1; It != Count; ++It) {
    std::vector<BasicBlock *> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;

    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Register the clone with the loop its original belongs to, and record
      // it if the original is the first/last block of its partition.
      if (ForeBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        SubLoop->addBasicBlockToLoop(New, *LI);

        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      // The first block of each partition is hung below the last block of the
      // same partition from the previous iteration.
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        ::remapInstruction(&I, LastValueMap);
        // Cloned llvm.assume calls must be registered with the AssumptionCache.
        if (auto *II = dyn_cast<IntrinsicInst>(&I))
          if (II->getIntrinsicID() == Intrinsic::assume)
            AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phi's, pointing them at the latest version of the
    // value from the previous iteration's phis
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      // The phi is reduced to a single incoming value arriving from the last
      // Fore block of the previous iteration.
      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // is the last unrolled iterations values.

  // Update Phis in BB from OldBB to point to NewBB
  auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
                            BasicBlock *NewBB) {
    for (PHINode &Phi : BB->phis()) {
      int I = Phi.getBasicBlockIndex(OldBB);
      Phi.setIncomingBlock(I, NewBB);
    }
  };
  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          // Only the first incoming entry for OldBB is rewritten.
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  BasicBlock *Dest = SubLoopBlocksFirst[0];
  ForeTerm->setSuccessor(0, Dest);

  if (CompletelyUnroll) {
    // No backedge remains, so each header phi is replaced by its value coming
    // from the preheader and then erased.
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    BasicBlock *Dest = ForeBlocksFirst[It];
    ForeTerm->setSuccessor(0, Dest);
  }

  // Subloop successors and phis
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
                  ForeBlocksLast.back());
  updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
                    ForeBlocksLast.back());
    updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *Term = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    // The loop control is gone: branch unconditionally to the exit.
    BranchInst::Create(LoopExit, Term);
    Term->eraseFromParent();
  } else {
    Term->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
  }
  updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration's last aft
    // block with an unconditional one to this iteration's first aft block
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DTU.applyUpdatesPermissive(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  // Candidates are the last block of each partition in every iteration. A
  // candidate stays in the set after a successful merge so that it can absorb
  // further successors; it is dropped once nothing more can be merged into it.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
  while (!MergeBlocks.empty()) {
    BasicBlock *BB = *MergeBlocks.begin();
    BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
    if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
      BasicBlock *Dest = Term->getSuccessor(0);
      BasicBlock *Fold = Dest->getUniquePredecessor();
      if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) {
        // Don't remove BB and add Fold as they are the same BB
        assert(Fold == BB);
        (void)Fold;
        MergeBlocks.erase(Dest);
      } else
        MergeBlocks.erase(BB);
    } else
      MergeBlocks.erase(BB);
  }
  // Apply updates to the DomTree.
  DT = &DTU.getDomTree();

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  Loop *OuterL = L->getParentLoop();
  Loop *OutestLoop = OuterL ? OuterL : (!CompletelyUnroll ? L : SubLoop);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  assert(DT->verify());
#endif

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}
590 
591 static bool getLoadsAndStores(BasicBlockSet &Blocks,
592                               SmallVector<Value *, 4> &MemInstr) {
593   // Scan the BBs and collect legal loads and stores.
594   // Returns false if non-simple loads/stores are found.
595   for (BasicBlock *BB : Blocks) {
596     for (Instruction &I : *BB) {
597       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
598         if (!Ld->isSimple())
599           return false;
600         MemInstr.push_back(&I);
601       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
602         if (!St->isSimple())
603           return false;
604         MemInstr.push_back(&I);
605       } else if (I.mayReadOrWriteMemory()) {
606         return false;
607       }
608     }
609   }
610   return true;
611 }
612 
613 static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
614                               SmallVector<Value *, 4> &Later,
615                               unsigned LoopDepth, bool InnerLoop,
616                               DependenceInfo &DI) {
617   // Use DA to check for dependencies between loads and stores that make unroll
618   // and jam invalid
619   for (Value *I : Earlier) {
620     for (Value *J : Later) {
621       Instruction *Src = cast<Instruction>(I);
622       Instruction *Dst = cast<Instruction>(J);
623       if (Src == Dst)
624         continue;
625       // Ignore Input dependencies.
626       if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
627         continue;
628 
629       // Track dependencies, and if we find them take a conservative approach
630       // by allowing only = or < (not >), altough some > would be safe
631       // (depending upon unroll width).
632       // For the inner loop, we need to disallow any (> <) dependencies
633       // FIXME: Allow > so long as distance is less than unroll width
634       if (auto D = DI.depends(Src, Dst, true)) {
635         assert(D->isOrdered() && "Expected an output, flow or anti dep.");
636 
637         if (D->isConfused()) {
638           LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
639                             << "  " << *Src << "\n"
640                             << "  " << *Dst << "\n");
641           return false;
642         }
643         if (!InnerLoop) {
644           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) {
645             LLVM_DEBUG(dbgs() << "  > dependency between:\n"
646                               << "  " << *Src << "\n"
647                               << "  " << *Dst << "\n");
648             return false;
649           }
650         } else {
651           assert(LoopDepth + 1 <= D->getLevels());
652           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
653               D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) {
654             LLVM_DEBUG(dbgs() << "  < > dependency between:\n"
655                               << "  " << *Src << "\n"
656                               << "  " << *Dst << "\n");
657             return false;
658           }
659         }
660       }
661     }
662   }
663   return true;
664 }
665 
666 static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
667                               BasicBlockSet &SubLoopBlocks,
668                               BasicBlockSet &AftBlocks, DependenceInfo &DI) {
669   // Get all loads/store pairs for each blocks
670   SmallVector<Value *, 4> ForeMemInstr;
671   SmallVector<Value *, 4> SubLoopMemInstr;
672   SmallVector<Value *, 4> AftMemInstr;
673   if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
674       !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
675       !getLoadsAndStores(AftBlocks, AftMemInstr))
676     return false;
677 
678   // Check for dependencies between any blocks that may change order
679   unsigned LoopDepth = L->getLoopDepth();
680   return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
681                            DI) &&
682          checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
683          checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
684                            DI) &&
685          checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
686                            DI);
687 }
688 
689 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
690                                 DependenceInfo &DI) {
691   /* We currently handle outer loops like this:
692         |
693     ForeFirst    <----\    }
694      Blocks           |    } ForeBlocks
695     ForeLast          |    }
696         |             |
697     SubLoopFirst  <\  |    }
698      Blocks        |  |    } SubLoopBlocks
699     SubLoopLast   -/  |    }
700         |             |
701     AftFirst          |    }
702      Blocks           |    } AftBlocks
703     AftLast     ------/    }
704         |
705 
706     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
707     and AftBlocks, providing that there is one edge from Fores to SubLoops,
708     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
709     In practice we currently limit Aft blocks to a single block, and limit
710     things further in the profitablility checks of the unroll and jam pass.
711 
712     Because of the way we rearrange basic blocks, we also require that
713     the Fore blocks on all unrolled iterations are safe to move before the
714     SubLoop blocks of all iterations. So we require that the phi node looping
715     operands of ForeHeader can be moved to at least the end of ForeEnd, so that
716     we can arrange cloned Fore Blocks before the subloop and match up Phi's
717     correctly.
718 
719     i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
720     It needs to be safe to tranform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.
721 
722     There are then a number of checks along the lines of no calls, no
723     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
724     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
725     UnrollAndJamLoop if the trip count cannot be easily calculated.
726   */
727 
728   if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
729     return false;
730   Loop *SubLoop = L->getSubLoops()[0];
731   if (!SubLoop->isLoopSimplifyForm())
732     return false;
733 
734   BasicBlock *Header = L->getHeader();
735   BasicBlock *Latch = L->getLoopLatch();
736   BasicBlock *Exit = L->getExitingBlock();
737   BasicBlock *SubLoopHeader = SubLoop->getHeader();
738   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
739   BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
740 
741   if (Latch != Exit)
742     return false;
743   if (SubLoopLatch != SubLoopExit)
744     return false;
745 
746   if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) {
747     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
748     return false;
749   }
750 
751   // Split blocks into Fore/SubLoop/Aft based on dominators
752   BasicBlockSet SubLoopBlocks;
753   BasicBlockSet ForeBlocks;
754   BasicBlockSet AftBlocks;
755   if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
756                                 AftBlocks, &DT)) {
757     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
758     return false;
759   }
760 
761   // Aft blocks may need to move instructions to fore blocks, which becomes more
762   // difficult if there are multiple (potentially conditionally executed)
763   // blocks. For now we just exclude loops with multiple aft blocks.
764   if (AftBlocks.size() != 1) {
765     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
766                          "multiple blocks after the loop\n");
767     return false;
768   }
769 
770   // Check inner loop backedge count is consistent on all iterations of the
771   // outer loop
772   if (!hasIterationCountInvariantInParent(SubLoop, SE)) {
773     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
774                          "not consistent on each iteration\n");
775     return false;
776   }
777 
778   // Check the loop safety info for exceptions.
779   SimpleLoopSafetyInfo LSI;
780   LSI.computeLoopSafetyInfo(L);
781   if (LSI.anyBlockMayThrow()) {
782     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
783     return false;
784   }
785 
786   // We've ruled out the easy stuff and now need to check that there are no
787   // interdependencies which may prevent us from moving the:
788   //  ForeBlocks before Subloop and AftBlocks.
789   //  Subloop before AftBlocks.
790   //  ForeBlock phi operands before the subloop
791 
792   // Make sure we can move all instructions we need to before the subloop
793   if (!processHeaderPhiOperands(
794           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
795             if (SubLoop->contains(I->getParent()))
796               return false;
797             if (AftBlocks.count(I->getParent())) {
798               // If we hit a phi node in afts we know we are done (probably
799               // LCSSA)
800               if (isa<PHINode>(I))
801                 return false;
802               // Can't move instructions with side effects or memory
803               // reads/writes
804               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
805                 return false;
806             }
807             // Keep going
808             return true;
809           })) {
810     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
811                          "instructions after subloop to before it\n");
812     return false;
813   }
814 
815   // Check for memory dependencies which prohibit the unrolling we are doing.
816   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
817   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
818   if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) {
819     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
820     return false;
821   }
822 
823   return true;
824 }
825