1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Statistic.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/iterator_range.h"
24 #include "llvm/Analysis/AssumptionCache.h"
25 #include "llvm/Analysis/DependenceAnalysis.h"
26 #include "llvm/Analysis/DomTreeUpdater.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/LoopIterator.h"
29 #include "llvm/Analysis/MustExecute.h"
30 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
31 #include "llvm/Analysis/ScalarEvolution.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/DebugInfoMetadata.h"
34 #include "llvm/IR/DebugLoc.h"
35 #include "llvm/IR/DiagnosticInfo.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/Instruction.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/IntrinsicInst.h"
41 #include "llvm/IR/Use.h"
42 #include "llvm/IR/User.h"
43 #include "llvm/IR/Value.h"
44 #include "llvm/IR/ValueHandle.h"
45 #include "llvm/IR/ValueMap.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/ErrorHandling.h"
49 #include "llvm/Support/GenericDomTree.h"
50 #include "llvm/Support/raw_ostream.h"
51 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
52 #include "llvm/Transforms/Utils/Cloning.h"
53 #include "llvm/Transforms/Utils/LoopUtils.h"
54 #include "llvm/Transforms/Utils/UnrollLoop.h"
55 #include "llvm/Transforms/Utils/ValueMapper.h"
56 #include <assert.h>
57 #include <memory>
58 #include <type_traits>
59 #include <vector>
60 
61 using namespace llvm;
62 
63 #define DEBUG_TYPE "loop-unroll-and-jam"
64 
65 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
66 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
67 
68 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
69 
70 // Partition blocks in an outer/inner loop pair into blocks before and after
71 // the loop
72 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
73                                      BasicBlockSet &ForeBlocks,
74                                      BasicBlockSet &SubLoopBlocks,
75                                      BasicBlockSet &AftBlocks,
76                                      DominatorTree *DT) {
77   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
78   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
79 
80   for (BasicBlock *BB : L->blocks()) {
81     if (!SubLoop->contains(BB)) {
82       if (DT->dominates(SubLoopLatch, BB))
83         AftBlocks.insert(BB);
84       else
85         ForeBlocks.insert(BB);
86     }
87   }
88 
89   // Check that all blocks in ForeBlocks together dominate the subloop
90   // TODO: This might ideally be done better with a dominator/postdominators.
91   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
92   for (BasicBlock *BB : ForeBlocks) {
93     if (BB == SubLoopPreHeader)
94       continue;
95     Instruction *TI = BB->getTerminator();
96     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
97       if (!ForeBlocks.count(TI->getSuccessor(i)))
98         return false;
99   }
100 
101   return true;
102 }
103 
104 // Looks at the phi nodes in Header for values coming from Latch. For these
105 // instructions and all their operands calls Visit on them, keeping going for
106 // all the operands in AftBlocks. Returns false if Visit returns false,
107 // otherwise returns true. This is used to process the instructions in the
108 // Aft blocks that need to be moved before the subloop. It is used in two
109 // places. One to check that the required set of instructions can be moved
110 // before the loop. Then to collect the instructions to actually move in
111 // moveHeaderPhiOperandsToForeBlocks.
112 template <typename T>
113 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
114                                      BasicBlockSet &AftBlocks, T Visit) {
115   SmallVector<Instruction *, 8> Worklist;
116   for (auto &Phi : Header->phis()) {
117     Value *V = Phi.getIncomingValueForBlock(Latch);
118     if (Instruction *I = dyn_cast<Instruction>(V))
119       Worklist.push_back(I);
120   }
121 
122   while (!Worklist.empty()) {
123     Instruction *I = Worklist.back();
124     Worklist.pop_back();
125     if (!Visit(I))
126       return false;
127 
128     if (AftBlocks.count(I->getParent()))
129       for (auto &U : I->operands())
130         if (Instruction *II = dyn_cast<Instruction>(U))
131           Worklist.push_back(II);
132   }
133 
134   return true;
135 }
136 
137 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
138 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
139                                               BasicBlock *Latch,
140                                               Instruction *InsertLoc,
141                                               BasicBlockSet &AftBlocks) {
142   // We need to ensure we move the instructions in the correct order,
143   // starting with the earliest required instruction and moving forward.
144   std::vector<Instruction *> Visited;
145   processHeaderPhiOperands(Header, Latch, AftBlocks,
146                            [&Visited, &AftBlocks](Instruction *I) {
147                              if (AftBlocks.count(I->getParent()))
148                                Visited.push_back(I);
149                              return true;
150                            });
151 
152   // Move all instructions in program order to before the InsertLoc
153   BasicBlock *InsertLocBB = InsertLoc->getParent();
154   for (Instruction *I : reverse(Visited)) {
155     if (I->getParent() != InsertLocBB)
156       I->moveBefore(InsertLoc);
157   }
158 }
159 
160 /*
161   This method performs Unroll and Jam. For a simple loop like:
162   for (i = ..)
163     Fore(i)
164     for (j = ..)
165       SubLoop(i, j)
166     Aft(i)
167 
168   Instead of doing normal inner or outer unrolling, we do:
169   for (i = .., i+=2)
170     Fore(i)
171     Fore(i+1)
172     for (j = ..)
173       SubLoop(i, j)
174       SubLoop(i+1, j)
175     Aft(i)
176     Aft(i+1)
177 
178   So the outer loop is essetially unrolled and then the inner loops are fused
179   ("jammed") together into a single loop. This can increase speed when there
180   are loads in SubLoop that are invariant to i, as they become shared between
181   the now jammed inner loops.
182 
183   We do this by spliting the blocks in the loop into Fore, Subloop and Aft.
184   Fore blocks are those before the inner loop, Aft are those after. Normal
185   Unroll code is used to copy each of these sets of blocks and the results are
186   combined together into the final form above.
187 
188   isSafeToUnrollAndJam should be used prior to calling this to make sure the
189   unrolling will be valid. Checking profitablility is also advisable.
190 
191   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
192   necessary to create one and not fully unrolled).
193 */
194 LoopUnrollResult
195 llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
196                        unsigned TripMultiple, bool UnrollRemainder,
197                        LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
198                        AssumptionCache *AC, const TargetTransformInfo *TTI,
199                        OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
200 
201   // When we enter here we should have already checked that it is safe
202   BasicBlock *Header = L->getHeader();
203   assert(Header && "No header.");
204   assert(L->getSubLoops().size() == 1);
205   Loop *SubLoop = *L->begin();
206 
207   // Don't enter the unroll code if there is nothing to do.
208   if (TripCount == 0 && Count < 2) {
209     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
210     return LoopUnrollResult::Unmodified;
211   }
212 
213   assert(Count > 0);
214   assert(TripMultiple > 0);
215   assert(TripCount == 0 || TripCount % TripMultiple == 0);
216 
217   // Are we eliminating the loop control altogether?
218   bool CompletelyUnroll = (Count == TripCount);
219 
220   // We use the runtime remainder in cases where we don't know trip multiple
221   if (TripMultiple == 1 || TripMultiple % Count != 0) {
222     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
223                                     /*UseEpilogRemainder*/ true,
224                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
225                                     LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
226       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
227                            "generated when assuming runtime trip count\n");
228       return LoopUnrollResult::Unmodified;
229     }
230   }
231 
232   // Notify ScalarEvolution that the loop will be substantially changed,
233   // if not outright eliminated.
234   if (SE) {
235     SE->forgetLoop(L);
236     SE->forgetLoop(SubLoop);
237   }
238 
239   using namespace ore;
240   // Report the unrolling decision.
241   if (CompletelyUnroll) {
242     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
243                       << Header->getName() << " with trip count " << TripCount
244                       << "!\n");
245     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
246                                  L->getHeader())
247               << "completely unroll and jammed loop with "
248               << NV("UnrollCount", TripCount) << " iterations");
249   } else {
250     auto DiagBuilder = [&]() {
251       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
252                               L->getHeader());
253       return Diag << "unroll and jammed loop by a factor of "
254                   << NV("UnrollCount", Count);
255     };
256 
257     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
258                       << " by " << Count);
259     if (TripMultiple != 1) {
260       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
261       ORE->emit([&]() {
262         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
263                              << " trips per branch";
264       });
265     } else {
266       LLVM_DEBUG(dbgs() << " with run-time trip count");
267       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
268     }
269     LLVM_DEBUG(dbgs() << "!\n");
270   }
271 
272   BasicBlock *Preheader = L->getLoopPreheader();
273   BasicBlock *LatchBlock = L->getLoopLatch();
274   assert(Preheader && "No preheader");
275   assert(LatchBlock && "No latch block");
276   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
277   assert(BI && !BI->isUnconditional());
278   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
279   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
280   bool SubLoopContinueOnTrue = SubLoop->contains(
281       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
282 
283   // Partition blocks in an outer/inner loop pair into blocks before and after
284   // the loop
285   BasicBlockSet SubLoopBlocks;
286   BasicBlockSet ForeBlocks;
287   BasicBlockSet AftBlocks;
288   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
289                            DT);
290 
291   // We keep track of the entering/first and exiting/last block of each of
292   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
293   // blocks easier.
294   std::vector<BasicBlock *> ForeBlocksFirst;
295   std::vector<BasicBlock *> ForeBlocksLast;
296   std::vector<BasicBlock *> SubLoopBlocksFirst;
297   std::vector<BasicBlock *> SubLoopBlocksLast;
298   std::vector<BasicBlock *> AftBlocksFirst;
299   std::vector<BasicBlock *> AftBlocksLast;
300   ForeBlocksFirst.push_back(Header);
301   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
302   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
303   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
304   AftBlocksFirst.push_back(SubLoop->getExitBlock());
305   AftBlocksLast.push_back(L->getExitingBlock());
306   // Maps Blocks[0] -> Blocks[It]
307   ValueToValueMapTy LastValueMap;
308 
309   // Move any instructions from fore phi operands from AftBlocks into Fore.
310   moveHeaderPhiOperandsToForeBlocks(
311       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
312 
313   // The current on-the-fly SSA update requires blocks to be processed in
314   // reverse postorder so that LastValueMap contains the correct value at each
315   // exit.
316   LoopBlocksDFS DFS(L);
317   DFS.perform(LI);
318   // Stash the DFS iterators before adding blocks to the loop.
319   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
320   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
321 
322   if (Header->getParent()->isDebugInfoForProfiling())
323     for (BasicBlock *BB : L->getBlocks())
324       for (Instruction &I : *BB)
325         if (!isa<DbgInfoIntrinsic>(&I))
326           if (const DILocation *DIL = I.getDebugLoc()) {
327             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
328             if (NewDIL)
329               I.setDebugLoc(NewDIL.getValue());
330             else
331               LLVM_DEBUG(dbgs()
332                          << "Failed to create new discriminator: "
333                          << DIL->getFilename() << " Line: " << DIL->getLine());
334           }
335 
336   // Copy all blocks
337   for (unsigned It = 1; It != Count; ++It) {
338     SmallVector<BasicBlock *, 8> NewBlocks;
339     // Maps Blocks[It] -> Blocks[It-1]
340     DenseMap<Value *, Value *> PrevItValueMap;
341 
342     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
343       ValueToValueMapTy VMap;
344       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
345       Header->getParent()->getBasicBlockList().push_back(New);
346 
347       if (ForeBlocks.count(*BB)) {
348         L->addBasicBlockToLoop(New, *LI);
349 
350         if (*BB == ForeBlocksFirst[0])
351           ForeBlocksFirst.push_back(New);
352         if (*BB == ForeBlocksLast[0])
353           ForeBlocksLast.push_back(New);
354       } else if (SubLoopBlocks.count(*BB)) {
355         SubLoop->addBasicBlockToLoop(New, *LI);
356 
357         if (*BB == SubLoopBlocksFirst[0])
358           SubLoopBlocksFirst.push_back(New);
359         if (*BB == SubLoopBlocksLast[0])
360           SubLoopBlocksLast.push_back(New);
361       } else if (AftBlocks.count(*BB)) {
362         L->addBasicBlockToLoop(New, *LI);
363 
364         if (*BB == AftBlocksFirst[0])
365           AftBlocksFirst.push_back(New);
366         if (*BB == AftBlocksLast[0])
367           AftBlocksLast.push_back(New);
368       } else {
369         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
370       }
371 
372       // Update our running maps of newest clones
373       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
374       LastValueMap[*BB] = New;
375       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
376            VI != VE; ++VI) {
377         PrevItValueMap[VI->second] =
378             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
379         LastValueMap[VI->first] = VI->second;
380       }
381 
382       NewBlocks.push_back(New);
383 
384       // Update DomTree:
385       if (*BB == ForeBlocksFirst[0])
386         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
387       else if (*BB == SubLoopBlocksFirst[0])
388         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
389       else if (*BB == AftBlocksFirst[0])
390         DT->addNewBlock(New, AftBlocksLast[It - 1]);
391       else {
392         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
393         // structure.
394         auto BBDomNode = DT->getNode(*BB);
395         auto BBIDom = BBDomNode->getIDom();
396         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
397         assert(OriginalBBIDom);
398         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
399         DT->addNewBlock(
400             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
401       }
402     }
403 
404     // Remap all instructions in the most recent iteration
405     remapInstructionsInBlocks(NewBlocks, LastValueMap);
406     for (BasicBlock *NewBlock : NewBlocks) {
407       for (Instruction &I : *NewBlock) {
408         if (auto *II = dyn_cast<IntrinsicInst>(&I))
409           if (II->getIntrinsicID() == Intrinsic::assume)
410             AC->registerAssumption(II);
411       }
412     }
413 
414     // Alter the ForeBlocks phi's, pointing them at the latest version of the
415     // value from the previous iteration's phis
416     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
417       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
418       assert(OldValue && "should have incoming edge from Aft[It]");
419       Value *NewValue = OldValue;
420       if (Value *PrevValue = PrevItValueMap[OldValue])
421         NewValue = PrevValue;
422 
423       assert(Phi.getNumOperands() == 2);
424       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
425       Phi.setIncomingValue(0, NewValue);
426       Phi.removeIncomingValue(1);
427     }
428   }
429 
430   // Now that all the basic blocks for the unrolled iterations are in place,
431   // finish up connecting the blocks and phi nodes. At this point LastValueMap
432   // is the last unrolled iterations values.
433 
434   // Update Phis in BB from OldBB to point to NewBB
435   auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
436                             BasicBlock *NewBB) {
437     for (PHINode &Phi : BB->phis()) {
438       int I = Phi.getBasicBlockIndex(OldBB);
439       Phi.setIncomingBlock(I, NewBB);
440     }
441   };
442   // Update Phis in BB from OldBB to point to NewBB and use the latest value
443   // from LastValueMap
444   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
445                                      BasicBlock *NewBB,
446                                      ValueToValueMapTy &LastValueMap) {
447     for (PHINode &Phi : BB->phis()) {
448       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
449         if (Phi.getIncomingBlock(b) == OldBB) {
450           Value *OldValue = Phi.getIncomingValue(b);
451           if (Value *LastValue = LastValueMap[OldValue])
452             Phi.setIncomingValue(b, LastValue);
453           Phi.setIncomingBlock(b, NewBB);
454           break;
455         }
456       }
457     }
458   };
459   // Move all the phis from Src into Dest
460   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
461     Instruction *insertPoint = Dest->getFirstNonPHI();
462     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
463       Phi->moveBefore(insertPoint);
464   };
465 
466   // Update the PHI values outside the loop to point to the last block
467   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
468                            LastValueMap);
469 
470   // Update ForeBlocks successors and phi nodes
471   BranchInst *ForeTerm =
472       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
473   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
474   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
475 
476   if (CompletelyUnroll) {
477     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
478       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
479       Phi->getParent()->getInstList().erase(Phi);
480     }
481   } else {
482     // Update the PHI values to point to the last aft block
483     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
484                              AftBlocksLast.back(), LastValueMap);
485   }
486 
487   for (unsigned It = 1; It != Count; It++) {
488     // Remap ForeBlock successors from previous iteration to this
489     BranchInst *ForeTerm =
490         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
491     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
492     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
493   }
494 
495   // Subloop successors and phis
496   BranchInst *SubTerm =
497       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
498   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
499   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
500   updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
501                   ForeBlocksLast.back());
502   updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
503                   SubLoopBlocksLast.back());
504 
505   for (unsigned It = 1; It != Count; It++) {
506     // Replace the conditional branch of the previous iteration subloop with an
507     // unconditional one to this one
508     BranchInst *SubTerm =
509         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
510     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
511     SubTerm->eraseFromParent();
512 
513     updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
514                     ForeBlocksLast.back());
515     updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
516                     SubLoopBlocksLast.back());
517     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
518   }
519 
520   // Aft blocks successors and phis
521   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
522   if (CompletelyUnroll) {
523     BranchInst::Create(LoopExit, AftTerm);
524     AftTerm->eraseFromParent();
525   } else {
526     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
527     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
528            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
529   }
530   updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
531                   SubLoopBlocksLast.back());
532 
533   for (unsigned It = 1; It != Count; It++) {
534     // Replace the conditional branch of the previous iteration subloop with an
535     // unconditional one to this one
536     BranchInst *AftTerm =
537         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
538     BranchInst::Create(AftBlocksFirst[It], AftTerm);
539     AftTerm->eraseFromParent();
540 
541     updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
542                     SubLoopBlocksLast.back());
543     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
544   }
545 
546   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
547   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
548   // new ones required.
549   if (Count != 1) {
550     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
551     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
552                            SubLoopBlocksFirst[0]);
553     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
554                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
555 
556     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
557                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
558     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
559                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
560     DTU.applyUpdatesPermissive(DTUpdates);
561   }
562 
563   // Merge adjacent basic blocks, if possible.
564   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
565   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
566   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
567   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
568   while (!MergeBlocks.empty()) {
569     BasicBlock *BB = *MergeBlocks.begin();
570     BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
571     if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
572       BasicBlock *Dest = Term->getSuccessor(0);
573       BasicBlock *Fold = Dest->getUniquePredecessor();
574       if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) {
575         // Don't remove BB and add Fold as they are the same BB
576         assert(Fold == BB);
577         (void)Fold;
578         MergeBlocks.erase(Dest);
579       } else
580         MergeBlocks.erase(BB);
581     } else
582       MergeBlocks.erase(BB);
583   }
584   // Apply updates to the DomTree.
585   DT = &DTU.getDomTree();
586 
587   // At this point, the code is well formed.  We now do a quick sweep over the
588   // inserted code, doing constant propagation and dead code elimination as we
589   // go.
590   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
591   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
592                           TTI);
593 
594   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
595   ++NumUnrolledAndJammed;
596 
597   // Update LoopInfo if the loop is completely removed.
598   if (CompletelyUnroll)
599     LI->erase(L);
600 
601 #ifndef NDEBUG
602   // We shouldn't have done anything to break loop simplify form or LCSSA.
603   Loop *OutestLoop = SubLoop->getParentLoop()
604                          ? SubLoop->getParentLoop()->getParentLoop()
605                                ? SubLoop->getParentLoop()->getParentLoop()
606                                : SubLoop->getParentLoop()
607                          : SubLoop;
608   assert(DT->verify());
609   LI->verify(*DT);
610   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
611   if (!CompletelyUnroll)
612     assert(L->isLoopSimplifyForm());
613   assert(SubLoop->isLoopSimplifyForm());
614   SE->verify();
615 #endif
616 
617   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
618                           : LoopUnrollResult::PartiallyUnrolled;
619 }
620 
621 static bool getLoadsAndStores(BasicBlockSet &Blocks,
622                               SmallVector<Value *, 4> &MemInstr) {
623   // Scan the BBs and collect legal loads and stores.
624   // Returns false if non-simple loads/stores are found.
625   for (BasicBlock *BB : Blocks) {
626     for (Instruction &I : *BB) {
627       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
628         if (!Ld->isSimple())
629           return false;
630         MemInstr.push_back(&I);
631       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
632         if (!St->isSimple())
633           return false;
634         MemInstr.push_back(&I);
635       } else if (I.mayReadOrWriteMemory()) {
636         return false;
637       }
638     }
639   }
640   return true;
641 }
642 
643 static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
644                               SmallVector<Value *, 4> &Later,
645                               unsigned LoopDepth, bool InnerLoop,
646                               DependenceInfo &DI) {
647   // Use DA to check for dependencies between loads and stores that make unroll
648   // and jam invalid
649   for (Value *I : Earlier) {
650     for (Value *J : Later) {
651       Instruction *Src = cast<Instruction>(I);
652       Instruction *Dst = cast<Instruction>(J);
653       if (Src == Dst)
654         continue;
655       // Ignore Input dependencies.
656       if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
657         continue;
658 
659       // Track dependencies, and if we find them take a conservative approach
660       // by allowing only = or < (not >), altough some > would be safe
661       // (depending upon unroll width).
662       // For the inner loop, we need to disallow any (> <) dependencies
663       // FIXME: Allow > so long as distance is less than unroll width
664       if (auto D = DI.depends(Src, Dst, true)) {
665         assert(D->isOrdered() && "Expected an output, flow or anti dep.");
666 
667         if (D->isConfused()) {
668           LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
669                             << "  " << *Src << "\n"
670                             << "  " << *Dst << "\n");
671           return false;
672         }
673         if (!InnerLoop) {
674           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) {
675             LLVM_DEBUG(dbgs() << "  > dependency between:\n"
676                               << "  " << *Src << "\n"
677                               << "  " << *Dst << "\n");
678             return false;
679           }
680         } else {
681           assert(LoopDepth + 1 <= D->getLevels());
682           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
683               D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) {
684             LLVM_DEBUG(dbgs() << "  < > dependency between:\n"
685                               << "  " << *Src << "\n"
686                               << "  " << *Dst << "\n");
687             return false;
688           }
689         }
690       }
691     }
692   }
693   return true;
694 }
695 
696 static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
697                               BasicBlockSet &SubLoopBlocks,
698                               BasicBlockSet &AftBlocks, DependenceInfo &DI) {
699   // Get all loads/store pairs for each blocks
700   SmallVector<Value *, 4> ForeMemInstr;
701   SmallVector<Value *, 4> SubLoopMemInstr;
702   SmallVector<Value *, 4> AftMemInstr;
703   if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
704       !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
705       !getLoadsAndStores(AftBlocks, AftMemInstr))
706     return false;
707 
708   // Check for dependencies between any blocks that may change order
709   unsigned LoopDepth = L->getLoopDepth();
710   return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
711                            DI) &&
712          checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
713          checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
714                            DI) &&
715          checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
716                            DI);
717 }
718 
719 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
720                                 DependenceInfo &DI) {
721   /* We currently handle outer loops like this:
722         |
723     ForeFirst    <----\    }
724      Blocks           |    } ForeBlocks
725     ForeLast          |    }
726         |             |
727     SubLoopFirst  <\  |    }
728      Blocks        |  |    } SubLoopBlocks
729     SubLoopLast   -/  |    }
730         |             |
731     AftFirst          |    }
732      Blocks           |    } AftBlocks
733     AftLast     ------/    }
734         |
735 
736     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
737     and AftBlocks, providing that there is one edge from Fores to SubLoops,
738     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
739     In practice we currently limit Aft blocks to a single block, and limit
740     things further in the profitablility checks of the unroll and jam pass.
741 
742     Because of the way we rearrange basic blocks, we also require that
743     the Fore blocks on all unrolled iterations are safe to move before the
744     SubLoop blocks of all iterations. So we require that the phi node looping
745     operands of ForeHeader can be moved to at least the end of ForeEnd, so that
746     we can arrange cloned Fore Blocks before the subloop and match up Phi's
747     correctly.
748 
749     i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
750     It needs to be safe to tranform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.
751 
752     There are then a number of checks along the lines of no calls, no
753     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
754     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
755     UnrollAndJamLoop if the trip count cannot be easily calculated.
756   */
757 
758   if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
759     return false;
760   Loop *SubLoop = L->getSubLoops()[0];
761   if (!SubLoop->isLoopSimplifyForm())
762     return false;
763 
764   BasicBlock *Header = L->getHeader();
765   BasicBlock *Latch = L->getLoopLatch();
766   BasicBlock *Exit = L->getExitingBlock();
767   BasicBlock *SubLoopHeader = SubLoop->getHeader();
768   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
769   BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
770 
771   if (Latch != Exit)
772     return false;
773   if (SubLoopLatch != SubLoopExit)
774     return false;
775 
776   if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) {
777     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
778     return false;
779   }
780 
781   // Split blocks into Fore/SubLoop/Aft based on dominators
782   BasicBlockSet SubLoopBlocks;
783   BasicBlockSet ForeBlocks;
784   BasicBlockSet AftBlocks;
785   if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
786                                 AftBlocks, &DT)) {
787     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
788     return false;
789   }
790 
791   // Aft blocks may need to move instructions to fore blocks, which becomes more
792   // difficult if there are multiple (potentially conditionally executed)
793   // blocks. For now we just exclude loops with multiple aft blocks.
794   if (AftBlocks.size() != 1) {
795     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
796                          "multiple blocks after the loop\n");
797     return false;
798   }
799 
800   // Check inner loop backedge count is consistent on all iterations of the
801   // outer loop
802   if (!hasIterationCountInvariantInParent(SubLoop, SE)) {
803     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
804                          "not consistent on each iteration\n");
805     return false;
806   }
807 
808   // Check the loop safety info for exceptions.
809   SimpleLoopSafetyInfo LSI;
810   LSI.computeLoopSafetyInfo(L);
811   if (LSI.anyBlockMayThrow()) {
812     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
813     return false;
814   }
815 
816   // We've ruled out the easy stuff and now need to check that there are no
817   // interdependencies which may prevent us from moving the:
818   //  ForeBlocks before Subloop and AftBlocks.
819   //  Subloop before AftBlocks.
820   //  ForeBlock phi operands before the subloop
821 
822   // Make sure we can move all instructions we need to before the subloop
823   if (!processHeaderPhiOperands(
824           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
825             if (SubLoop->contains(I->getParent()))
826               return false;
827             if (AftBlocks.count(I->getParent())) {
828               // If we hit a phi node in afts we know we are done (probably
829               // LCSSA)
830               if (isa<PHINode>(I))
831                 return false;
832               // Can't move instructions with side effects or memory
833               // reads/writes
834               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
835                 return false;
836             }
837             // Keep going
838             return true;
839           })) {
840     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
841                          "instructions after subloop to before it\n");
842     return false;
843   }
844 
845   // Check for memory dependencies which prohibit the unrolling we are doing.
846   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
847   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
848   if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) {
849     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
850     return false;
851   }
852 
853   return true;
854 }
855