1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/AssumptionCache.h"
17 #include "llvm/Analysis/DependenceAnalysis.h"
18 #include "llvm/Analysis/InstructionSimplify.h"
19 #include "llvm/Analysis/LoopAnalysisManager.h"
20 #include "llvm/Analysis/LoopIterator.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
23 #include "llvm/Analysis/ScalarEvolution.h"
24 #include "llvm/Analysis/Utils/Local.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/DebugInfoMetadata.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
34 #include "llvm/Transforms/Utils/Cloning.h"
35 #include "llvm/Transforms/Utils/LoopSimplify.h"
36 #include "llvm/Transforms/Utils/LoopUtils.h"
37 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
38 #include "llvm/Transforms/Utils/UnrollLoop.h"
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "loop-unroll-and-jam"
42 
43 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
44 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
45 
46 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
47 
48 // Partition blocks in an outer/inner loop pair into blocks before and after
49 // the loop
50 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
51                                      BasicBlockSet &ForeBlocks,
52                                      BasicBlockSet &SubLoopBlocks,
53                                      BasicBlockSet &AftBlocks,
54                                      DominatorTree *DT) {
55   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
56   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
57 
58   for (BasicBlock *BB : L->blocks()) {
59     if (!SubLoop->contains(BB)) {
60       if (DT->dominates(SubLoopLatch, BB))
61         AftBlocks.insert(BB);
62       else
63         ForeBlocks.insert(BB);
64     }
65   }
66 
67   // Check that all blocks in ForeBlocks together dominate the subloop
68   // TODO: This might ideally be done better with a dominator/postdominators.
69   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
70   for (BasicBlock *BB : ForeBlocks) {
71     if (BB == SubLoopPreHeader)
72       continue;
73     Instruction *TI = BB->getTerminator();
74     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
75       if (!ForeBlocks.count(TI->getSuccessor(i)))
76         return false;
77   }
78 
79   return true;
80 }
81 
82 // Looks at the phi nodes in Header for values coming from Latch. For these
83 // instructions and all their operands calls Visit on them, keeping going for
84 // all the operands in AftBlocks. Returns false if Visit returns false,
85 // otherwise returns true. This is used to process the instructions in the
86 // Aft blocks that need to be moved before the subloop. It is used in two
87 // places. One to check that the required set of instructions can be moved
88 // before the loop. Then to collect the instructions to actually move in
89 // moveHeaderPhiOperandsToForeBlocks.
90 template <typename T>
91 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
92                                      BasicBlockSet &AftBlocks, T Visit) {
93   SmallVector<Instruction *, 8> Worklist;
94   for (auto &Phi : Header->phis()) {
95     Value *V = Phi.getIncomingValueForBlock(Latch);
96     if (Instruction *I = dyn_cast<Instruction>(V))
97       Worklist.push_back(I);
98   }
99 
100   while (!Worklist.empty()) {
101     Instruction *I = Worklist.back();
102     Worklist.pop_back();
103     if (!Visit(I))
104       return false;
105 
106     if (AftBlocks.count(I->getParent()))
107       for (auto &U : I->operands())
108         if (Instruction *II = dyn_cast<Instruction>(U))
109           Worklist.push_back(II);
110   }
111 
112   return true;
113 }
114 
115 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
116 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
117                                               BasicBlock *Latch,
118                                               Instruction *InsertLoc,
119                                               BasicBlockSet &AftBlocks) {
120   // We need to ensure we move the instructions in the correct order,
121   // starting with the earliest required instruction and moving forward.
122   std::vector<Instruction *> Visited;
123   processHeaderPhiOperands(Header, Latch, AftBlocks,
124                            [&Visited, &AftBlocks](Instruction *I) {
125                              if (AftBlocks.count(I->getParent()))
126                                Visited.push_back(I);
127                              return true;
128                            });
129 
130   // Move all instructions in program order to before the InsertLoc
131   BasicBlock *InsertLocBB = InsertLoc->getParent();
132   for (Instruction *I : reverse(Visited)) {
133     if (I->getParent() != InsertLocBB)
134       I->moveBefore(InsertLoc);
135   }
136 }
137 
138 /*
139   This method performs Unroll and Jam. For a simple loop like:
140   for (i = ..)
141     Fore(i)
142     for (j = ..)
143       SubLoop(i, j)
144     Aft(i)
145 
146   Instead of doing normal inner or outer unrolling, we do:
147   for (i = .., i+=2)
148     Fore(i)
149     Fore(i+1)
150     for (j = ..)
151       SubLoop(i, j)
152       SubLoop(i+1, j)
153     Aft(i)
154     Aft(i+1)
155 
  So the outer loop is essentially unrolled and then the inner loops are fused
157   ("jammed") together into a single loop. This can increase speed when there
158   are loads in SubLoop that are invariant to i, as they become shared between
159   the now jammed inner loops.
160 
  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
162   Fore blocks are those before the inner loop, Aft are those after. Normal
163   Unroll code is used to copy each of these sets of blocks and the results are
164   combined together into the final form above.
165 
166   isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.
168 
169   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
170   necessary to create one and not fully unrolled).
171 */
172 LoopUnrollResult llvm::UnrollAndJamLoop(
173     Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,
174     bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
175     AssumptionCache *AC, OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
176 
177   // When we enter here we should have already checked that it is safe
178   BasicBlock *Header = L->getHeader();
179   assert(Header && "No header.");
180   assert(L->getSubLoops().size() == 1);
181   Loop *SubLoop = *L->begin();
182 
183   // Don't enter the unroll code if there is nothing to do.
184   if (TripCount == 0 && Count < 2) {
185     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
186     return LoopUnrollResult::Unmodified;
187   }
188 
189   assert(Count > 0);
190   assert(TripMultiple > 0);
191   assert(TripCount == 0 || TripCount % TripMultiple == 0);
192 
193   // Are we eliminating the loop control altogether?
194   bool CompletelyUnroll = (Count == TripCount);
195 
196   // We use the runtime remainder in cases where we don't know trip multiple
197   if (TripMultiple == 1 || TripMultiple % Count != 0) {
198     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
199                                     /*UseEpilogRemainder*/ true,
200                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
201                                     LI, SE, DT, AC, true, EpilogueLoop)) {
202       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
203                            "generated when assuming runtime trip count\n");
204       return LoopUnrollResult::Unmodified;
205     }
206   }
207 
208   // Notify ScalarEvolution that the loop will be substantially changed,
209   // if not outright eliminated.
210   if (SE) {
211     SE->forgetLoop(L);
212     SE->forgetLoop(SubLoop);
213   }
214 
215   using namespace ore;
216   // Report the unrolling decision.
217   if (CompletelyUnroll) {
218     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
219                       << Header->getName() << " with trip count " << TripCount
220                       << "!\n");
221     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
222                                  L->getHeader())
223               << "completely unroll and jammed loop with "
224               << NV("UnrollCount", TripCount) << " iterations");
225   } else {
226     auto DiagBuilder = [&]() {
227       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
228                               L->getHeader());
229       return Diag << "unroll and jammed loop by a factor of "
230                   << NV("UnrollCount", Count);
231     };
232 
233     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
234                       << " by " << Count);
235     if (TripMultiple != 1) {
236       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
237       ORE->emit([&]() {
238         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
239                              << " trips per branch";
240       });
241     } else {
242       LLVM_DEBUG(dbgs() << " with run-time trip count");
243       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
244     }
245     LLVM_DEBUG(dbgs() << "!\n");
246   }
247 
248   BasicBlock *Preheader = L->getLoopPreheader();
249   BasicBlock *LatchBlock = L->getLoopLatch();
250   assert(Preheader && "No preheader");
251   assert(LatchBlock && "No latch block");
252   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
253   assert(BI && !BI->isUnconditional());
254   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
255   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
256   bool SubLoopContinueOnTrue = SubLoop->contains(
257       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
258 
259   // Partition blocks in an outer/inner loop pair into blocks before and after
260   // the loop
261   BasicBlockSet SubLoopBlocks;
262   BasicBlockSet ForeBlocks;
263   BasicBlockSet AftBlocks;
264   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
265                            DT);
266 
267   // We keep track of the entering/first and exiting/last block of each of
268   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
269   // blocks easier.
270   std::vector<BasicBlock *> ForeBlocksFirst;
271   std::vector<BasicBlock *> ForeBlocksLast;
272   std::vector<BasicBlock *> SubLoopBlocksFirst;
273   std::vector<BasicBlock *> SubLoopBlocksLast;
274   std::vector<BasicBlock *> AftBlocksFirst;
275   std::vector<BasicBlock *> AftBlocksLast;
276   ForeBlocksFirst.push_back(Header);
277   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
278   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
279   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
280   AftBlocksFirst.push_back(SubLoop->getExitBlock());
281   AftBlocksLast.push_back(L->getExitingBlock());
282   // Maps Blocks[0] -> Blocks[It]
283   ValueToValueMapTy LastValueMap;
284 
285   // Move any instructions from fore phi operands from AftBlocks into Fore.
286   moveHeaderPhiOperandsToForeBlocks(
287       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
288 
289   // The current on-the-fly SSA update requires blocks to be processed in
290   // reverse postorder so that LastValueMap contains the correct value at each
291   // exit.
292   LoopBlocksDFS DFS(L);
293   DFS.perform(LI);
294   // Stash the DFS iterators before adding blocks to the loop.
295   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
296   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
297 
298   if (Header->getParent()->isDebugInfoForProfiling())
299     for (BasicBlock *BB : L->getBlocks())
300       for (Instruction &I : *BB)
301         if (!isa<DbgInfoIntrinsic>(&I))
302           if (const DILocation *DIL = I.getDebugLoc()) {
303             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
304             if (NewDIL)
305               I.setDebugLoc(NewDIL.getValue());
306             else
307               LLVM_DEBUG(dbgs()
308                          << "Failed to create new discriminator: "
309                          << DIL->getFilename() << " Line: " << DIL->getLine());
310           }
311 
312   // Copy all blocks
313   for (unsigned It = 1; It != Count; ++It) {
314     SmallVector<BasicBlock *, 8> NewBlocks;
315     // Maps Blocks[It] -> Blocks[It-1]
316     DenseMap<Value *, Value *> PrevItValueMap;
317 
318     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
319       ValueToValueMapTy VMap;
320       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
321       Header->getParent()->getBasicBlockList().push_back(New);
322 
323       if (ForeBlocks.count(*BB)) {
324         L->addBasicBlockToLoop(New, *LI);
325 
326         if (*BB == ForeBlocksFirst[0])
327           ForeBlocksFirst.push_back(New);
328         if (*BB == ForeBlocksLast[0])
329           ForeBlocksLast.push_back(New);
330       } else if (SubLoopBlocks.count(*BB)) {
331         SubLoop->addBasicBlockToLoop(New, *LI);
332 
333         if (*BB == SubLoopBlocksFirst[0])
334           SubLoopBlocksFirst.push_back(New);
335         if (*BB == SubLoopBlocksLast[0])
336           SubLoopBlocksLast.push_back(New);
337       } else if (AftBlocks.count(*BB)) {
338         L->addBasicBlockToLoop(New, *LI);
339 
340         if (*BB == AftBlocksFirst[0])
341           AftBlocksFirst.push_back(New);
342         if (*BB == AftBlocksLast[0])
343           AftBlocksLast.push_back(New);
344       } else {
345         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
346       }
347 
348       // Update our running maps of newest clones
349       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
350       LastValueMap[*BB] = New;
351       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
352            VI != VE; ++VI) {
353         PrevItValueMap[VI->second] =
354             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
355         LastValueMap[VI->first] = VI->second;
356       }
357 
358       NewBlocks.push_back(New);
359 
360       // Update DomTree:
361       if (*BB == ForeBlocksFirst[0])
362         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
363       else if (*BB == SubLoopBlocksFirst[0])
364         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
365       else if (*BB == AftBlocksFirst[0])
366         DT->addNewBlock(New, AftBlocksLast[It - 1]);
367       else {
368         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
369         // structure.
370         auto BBDomNode = DT->getNode(*BB);
371         auto BBIDom = BBDomNode->getIDom();
372         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
373         assert(OriginalBBIDom);
374         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
375         DT->addNewBlock(
376             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
377       }
378     }
379 
380     // Remap all instructions in the most recent iteration
381     remapInstructionsInBlocks(NewBlocks, LastValueMap);
382     for (BasicBlock *NewBlock : NewBlocks) {
383       for (Instruction &I : *NewBlock) {
384         if (auto *II = dyn_cast<IntrinsicInst>(&I))
385           if (II->getIntrinsicID() == Intrinsic::assume)
386             AC->registerAssumption(II);
387       }
388     }
389 
390     // Alter the ForeBlocks phi's, pointing them at the latest version of the
391     // value from the previous iteration's phis
392     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
393       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
394       assert(OldValue && "should have incoming edge from Aft[It]");
395       Value *NewValue = OldValue;
396       if (Value *PrevValue = PrevItValueMap[OldValue])
397         NewValue = PrevValue;
398 
399       assert(Phi.getNumOperands() == 2);
400       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
401       Phi.setIncomingValue(0, NewValue);
402       Phi.removeIncomingValue(1);
403     }
404   }
405 
406   // Now that all the basic blocks for the unrolled iterations are in place,
407   // finish up connecting the blocks and phi nodes. At this point LastValueMap
408   // is the last unrolled iterations values.
409 
410   // Update Phis in BB from OldBB to point to NewBB
411   auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
412                             BasicBlock *NewBB) {
413     for (PHINode &Phi : BB->phis()) {
414       int I = Phi.getBasicBlockIndex(OldBB);
415       Phi.setIncomingBlock(I, NewBB);
416     }
417   };
418   // Update Phis in BB from OldBB to point to NewBB and use the latest value
419   // from LastValueMap
420   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
421                                      BasicBlock *NewBB,
422                                      ValueToValueMapTy &LastValueMap) {
423     for (PHINode &Phi : BB->phis()) {
424       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
425         if (Phi.getIncomingBlock(b) == OldBB) {
426           Value *OldValue = Phi.getIncomingValue(b);
427           if (Value *LastValue = LastValueMap[OldValue])
428             Phi.setIncomingValue(b, LastValue);
429           Phi.setIncomingBlock(b, NewBB);
430           break;
431         }
432       }
433     }
434   };
435   // Move all the phis from Src into Dest
436   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
437     Instruction *insertPoint = Dest->getFirstNonPHI();
438     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
439       Phi->moveBefore(insertPoint);
440   };
441 
442   // Update the PHI values outside the loop to point to the last block
443   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
444                            LastValueMap);
445 
446   // Update ForeBlocks successors and phi nodes
447   BranchInst *ForeTerm =
448       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
449   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
450   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
451 
452   if (CompletelyUnroll) {
453     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
454       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
455       Phi->getParent()->getInstList().erase(Phi);
456     }
457   } else {
458     // Update the PHI values to point to the last aft block
459     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
460                              AftBlocksLast.back(), LastValueMap);
461   }
462 
463   for (unsigned It = 1; It != Count; It++) {
464     // Remap ForeBlock successors from previous iteration to this
465     BranchInst *ForeTerm =
466         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
467     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
468     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
469   }
470 
471   // Subloop successors and phis
472   BranchInst *SubTerm =
473       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
474   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
475   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
476   updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
477                   ForeBlocksLast.back());
478   updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
479                   SubLoopBlocksLast.back());
480 
481   for (unsigned It = 1; It != Count; It++) {
482     // Replace the conditional branch of the previous iteration subloop with an
483     // unconditional one to this one
484     BranchInst *SubTerm =
485         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
486     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
487     SubTerm->eraseFromParent();
488 
489     updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
490                     ForeBlocksLast.back());
491     updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
492                     SubLoopBlocksLast.back());
493     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
494   }
495 
496   // Aft blocks successors and phis
497   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
498   if (CompletelyUnroll) {
499     BranchInst::Create(LoopExit, AftTerm);
500     AftTerm->eraseFromParent();
501   } else {
502     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
503     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
504            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
505   }
506   updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
507                   SubLoopBlocksLast.back());
508 
509   for (unsigned It = 1; It != Count; It++) {
510     // Replace the conditional branch of the previous iteration subloop with an
511     // unconditional one to this one
512     BranchInst *AftTerm =
513         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
514     BranchInst::Create(AftBlocksFirst[It], AftTerm);
515     AftTerm->eraseFromParent();
516 
517     updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
518                     SubLoopBlocksLast.back());
519     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
520   }
521 
522   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
523   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
524   // new ones required.
525   if (Count != 1) {
526     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
527     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
528                            SubLoopBlocksFirst[0]);
529     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
530                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
531 
532     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
533                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
534     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
535                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
536     DTU.applyUpdatesPermissive(DTUpdates);
537   }
538 
539   // Merge adjacent basic blocks, if possible.
540   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
541   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
542   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
543   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
544   while (!MergeBlocks.empty()) {
545     BasicBlock *BB = *MergeBlocks.begin();
546     BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
547     if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
548       BasicBlock *Dest = Term->getSuccessor(0);
549       BasicBlock *Fold = Dest->getUniquePredecessor();
550       if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) {
551         // Don't remove BB and add Fold as they are the same BB
552         assert(Fold == BB);
553         (void)Fold;
554         MergeBlocks.erase(Dest);
555       } else
556         MergeBlocks.erase(BB);
557     } else
558       MergeBlocks.erase(BB);
559   }
560   // Apply updates to the DomTree.
561   DT = &DTU.getDomTree();
562 
563   // At this point, the code is well formed.  We now do a quick sweep over the
564   // inserted code, doing constant propagation and dead code elimination as we
565   // go.
566   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC);
567   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC);
568 
569   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
570   ++NumUnrolledAndJammed;
571 
572   // Update LoopInfo if the loop is completely removed.
573   if (CompletelyUnroll)
574     LI->erase(L);
575 
576 #ifndef NDEBUG
577   // We shouldn't have done anything to break loop simplify form or LCSSA.
578   Loop *OutestLoop = SubLoop->getParentLoop()
579                          ? SubLoop->getParentLoop()->getParentLoop()
580                                ? SubLoop->getParentLoop()->getParentLoop()
581                                : SubLoop->getParentLoop()
582                          : SubLoop;
583   assert(DT->verify());
584   LI->verify(*DT);
585   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
586   if (!CompletelyUnroll)
587     assert(L->isLoopSimplifyForm());
588   assert(SubLoop->isLoopSimplifyForm());
589   SE->verify();
590 #endif
591 
592   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
593                           : LoopUnrollResult::PartiallyUnrolled;
594 }
595 
596 static bool getLoadsAndStores(BasicBlockSet &Blocks,
597                               SmallVector<Value *, 4> &MemInstr) {
598   // Scan the BBs and collect legal loads and stores.
599   // Returns false if non-simple loads/stores are found.
600   for (BasicBlock *BB : Blocks) {
601     for (Instruction &I : *BB) {
602       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
603         if (!Ld->isSimple())
604           return false;
605         MemInstr.push_back(&I);
606       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
607         if (!St->isSimple())
608           return false;
609         MemInstr.push_back(&I);
610       } else if (I.mayReadOrWriteMemory()) {
611         return false;
612       }
613     }
614   }
615   return true;
616 }
617 
618 static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
619                               SmallVector<Value *, 4> &Later,
620                               unsigned LoopDepth, bool InnerLoop,
621                               DependenceInfo &DI) {
622   // Use DA to check for dependencies between loads and stores that make unroll
623   // and jam invalid
624   for (Value *I : Earlier) {
625     for (Value *J : Later) {
626       Instruction *Src = cast<Instruction>(I);
627       Instruction *Dst = cast<Instruction>(J);
628       if (Src == Dst)
629         continue;
630       // Ignore Input dependencies.
631       if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
632         continue;
633 
634       // Track dependencies, and if we find them take a conservative approach
635       // by allowing only = or < (not >), altough some > would be safe
636       // (depending upon unroll width).
637       // For the inner loop, we need to disallow any (> <) dependencies
638       // FIXME: Allow > so long as distance is less than unroll width
639       if (auto D = DI.depends(Src, Dst, true)) {
640         assert(D->isOrdered() && "Expected an output, flow or anti dep.");
641 
642         if (D->isConfused()) {
643           LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
644                             << "  " << *Src << "\n"
645                             << "  " << *Dst << "\n");
646           return false;
647         }
648         if (!InnerLoop) {
649           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) {
650             LLVM_DEBUG(dbgs() << "  > dependency between:\n"
651                               << "  " << *Src << "\n"
652                               << "  " << *Dst << "\n");
653             return false;
654           }
655         } else {
656           assert(LoopDepth + 1 <= D->getLevels());
657           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
658               D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) {
659             LLVM_DEBUG(dbgs() << "  < > dependency between:\n"
660                               << "  " << *Src << "\n"
661                               << "  " << *Dst << "\n");
662             return false;
663           }
664         }
665       }
666     }
667   }
668   return true;
669 }
670 
671 static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
672                               BasicBlockSet &SubLoopBlocks,
673                               BasicBlockSet &AftBlocks, DependenceInfo &DI) {
674   // Get all loads/store pairs for each blocks
675   SmallVector<Value *, 4> ForeMemInstr;
676   SmallVector<Value *, 4> SubLoopMemInstr;
677   SmallVector<Value *, 4> AftMemInstr;
678   if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
679       !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
680       !getLoadsAndStores(AftBlocks, AftMemInstr))
681     return false;
682 
683   // Check for dependencies between any blocks that may change order
684   unsigned LoopDepth = L->getLoopDepth();
685   return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
686                            DI) &&
687          checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
688          checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
689                            DI) &&
690          checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
691                            DI);
692 }
693 
694 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
695                                 DependenceInfo &DI) {
696   /* We currently handle outer loops like this:
697         |
698     ForeFirst    <----\    }
699      Blocks           |    } ForeBlocks
700     ForeLast          |    }
701         |             |
702     SubLoopFirst  <\  |    }
703      Blocks        |  |    } SubLoopBlocks
704     SubLoopLast   -/  |    }
705         |             |
706     AftFirst          |    }
707      Blocks           |    } AftBlocks
708     AftLast     ------/    }
709         |
710 
711     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
712     and AftBlocks, providing that there is one edge from Fores to SubLoops,
713     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
714     In practice we currently limit Aft blocks to a single block, and limit
715     things further in the profitablility checks of the unroll and jam pass.
716 
717     Because of the way we rearrange basic blocks, we also require that
718     the Fore blocks on all unrolled iterations are safe to move before the
719     SubLoop blocks of all iterations. So we require that the phi node looping
720     operands of ForeHeader can be moved to at least the end of ForeEnd, so that
721     we can arrange cloned Fore Blocks before the subloop and match up Phi's
722     correctly.
723 
724     i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
725     It needs to be safe to tranform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.
726 
727     There are then a number of checks along the lines of no calls, no
728     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
729     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
730     UnrollAndJamLoop if the trip count cannot be easily calculated.
731   */
732 
733   if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
734     return false;
735   Loop *SubLoop = L->getSubLoops()[0];
736   if (!SubLoop->isLoopSimplifyForm())
737     return false;
738 
739   BasicBlock *Header = L->getHeader();
740   BasicBlock *Latch = L->getLoopLatch();
741   BasicBlock *Exit = L->getExitingBlock();
742   BasicBlock *SubLoopHeader = SubLoop->getHeader();
743   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
744   BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
745 
746   if (Latch != Exit)
747     return false;
748   if (SubLoopLatch != SubLoopExit)
749     return false;
750 
751   if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) {
752     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
753     return false;
754   }
755 
756   // Split blocks into Fore/SubLoop/Aft based on dominators
757   BasicBlockSet SubLoopBlocks;
758   BasicBlockSet ForeBlocks;
759   BasicBlockSet AftBlocks;
760   if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
761                                 AftBlocks, &DT)) {
762     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
763     return false;
764   }
765 
766   // Aft blocks may need to move instructions to fore blocks, which becomes more
767   // difficult if there are multiple (potentially conditionally executed)
768   // blocks. For now we just exclude loops with multiple aft blocks.
769   if (AftBlocks.size() != 1) {
770     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
771                          "multiple blocks after the loop\n");
772     return false;
773   }
774 
775   // Check inner loop backedge count is consistent on all iterations of the
776   // outer loop
777   if (!hasIterationCountInvariantInParent(SubLoop, SE)) {
778     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
779                          "not consistent on each iteration\n");
780     return false;
781   }
782 
783   // Check the loop safety info for exceptions.
784   SimpleLoopSafetyInfo LSI;
785   LSI.computeLoopSafetyInfo(L);
786   if (LSI.anyBlockMayThrow()) {
787     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
788     return false;
789   }
790 
791   // We've ruled out the easy stuff and now need to check that there are no
792   // interdependencies which may prevent us from moving the:
793   //  ForeBlocks before Subloop and AftBlocks.
794   //  Subloop before AftBlocks.
795   //  ForeBlock phi operands before the subloop
796 
797   // Make sure we can move all instructions we need to before the subloop
798   if (!processHeaderPhiOperands(
799           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
800             if (SubLoop->contains(I->getParent()))
801               return false;
802             if (AftBlocks.count(I->getParent())) {
803               // If we hit a phi node in afts we know we are done (probably
804               // LCSSA)
805               if (isa<PHINode>(I))
806                 return false;
807               // Can't move instructions with side effects or memory
808               // reads/writes
809               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
810                 return false;
811             }
812             // Keep going
813             return true;
814           })) {
815     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
816                          "instructions after subloop to before it\n");
817     return false;
818   }
819 
820   // Check for memory dependencies which prohibit the unrolling we are doing.
821   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
822   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
823   if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) {
824     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
825     return false;
826   }
827 
828   return true;
829 }
830