1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/iterator_range.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/DependenceAnalysis.h"
27 #include "llvm/Analysis/DomTreeUpdater.h"
28 #include "llvm/Analysis/LoopInfo.h"
29 #include "llvm/Analysis/LoopIterator.h"
30 #include "llvm/Analysis/MustExecute.h"
31 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
32 #include "llvm/Analysis/ScalarEvolution.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/DebugLoc.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/Dominators.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/Use.h"
43 #include "llvm/IR/User.h"
44 #include "llvm/IR/Value.h"
45 #include "llvm/IR/ValueHandle.h"
46 #include "llvm/IR/ValueMap.h"
47 #include "llvm/Support/Casting.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/ErrorHandling.h"
50 #include "llvm/Support/GenericDomTree.h"
51 #include "llvm/Support/raw_ostream.h"
52 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
53 #include "llvm/Transforms/Utils/Cloning.h"
54 #include "llvm/Transforms/Utils/LoopUtils.h"
55 #include "llvm/Transforms/Utils/UnrollLoop.h"
56 #include "llvm/Transforms/Utils/ValueMapper.h"
57 #include <assert.h>
58 #include <memory>
59 #include <type_traits>
60 #include <vector>
61 
62 using namespace llvm;
63 
64 #define DEBUG_TYPE "loop-unroll-and-jam"
65 
66 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
67 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
68 
69 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
70 
71 // Partition blocks in an outer/inner loop pair into blocks before and after
72 // the loop
73 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
74                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
75   Loop *SubLoop = L.getSubLoops()[0];
76   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
77 
78   for (BasicBlock *BB : L.blocks()) {
79     if (!SubLoop->contains(BB)) {
80       if (DT.dominates(SubLoopLatch, BB))
81         AftBlocks.insert(BB);
82       else
83         ForeBlocks.insert(BB);
84     }
85   }
86 
87   // Check that all blocks in ForeBlocks together dominate the subloop
88   // TODO: This might ideally be done better with a dominator/postdominators.
89   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
90   for (BasicBlock *BB : ForeBlocks) {
91     if (BB == SubLoopPreHeader)
92       continue;
93     Instruction *TI = BB->getTerminator();
94     for (BasicBlock *Succ : successors(TI))
95       if (!ForeBlocks.count(Succ))
96         return false;
97   }
98 
99   return true;
100 }
101 
102 /// Partition blocks in a loop nest into blocks before and after each inner
103 /// loop.
104 static bool partitionOuterLoopBlocks(
105     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
106     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
107     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
108   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
109 
110   for (Loop *L : Root.getLoopsInPreorder()) {
111     if (L == &JamLoop)
112       break;
113 
114     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
115       return false;
116   }
117 
118   return true;
119 }
120 
121 // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
122 // than 2 levels loop.
123 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
124                                      BasicBlockSet &ForeBlocks,
125                                      BasicBlockSet &SubLoopBlocks,
126                                      BasicBlockSet &AftBlocks,
127                                      DominatorTree *DT) {
128   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
129   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
130 }
131 
132 // Looks at the phi nodes in Header for values coming from Latch. For these
133 // instructions and all their operands calls Visit on them, keeping going for
134 // all the operands in AftBlocks. Returns false if Visit returns false,
135 // otherwise returns true. This is used to process the instructions in the
136 // Aft blocks that need to be moved before the subloop. It is used in two
137 // places. One to check that the required set of instructions can be moved
138 // before the loop. Then to collect the instructions to actually move in
139 // moveHeaderPhiOperandsToForeBlocks.
140 template <typename T>
141 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
142                                      BasicBlockSet &AftBlocks, T Visit) {
143   SmallVector<Instruction *, 8> Worklist;
144   for (auto &Phi : Header->phis()) {
145     Value *V = Phi.getIncomingValueForBlock(Latch);
146     if (Instruction *I = dyn_cast<Instruction>(V))
147       Worklist.push_back(I);
148   }
149 
150   while (!Worklist.empty()) {
151     Instruction *I = Worklist.back();
152     Worklist.pop_back();
153     if (!Visit(I))
154       return false;
155 
156     if (AftBlocks.count(I->getParent()))
157       for (auto &U : I->operands())
158         if (Instruction *II = dyn_cast<Instruction>(U))
159           Worklist.push_back(II);
160   }
161 
162   return true;
163 }
164 
165 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
166 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
167                                               BasicBlock *Latch,
168                                               Instruction *InsertLoc,
169                                               BasicBlockSet &AftBlocks) {
170   // We need to ensure we move the instructions in the correct order,
171   // starting with the earliest required instruction and moving forward.
172   std::vector<Instruction *> Visited;
173   processHeaderPhiOperands(Header, Latch, AftBlocks,
174                            [&Visited, &AftBlocks](Instruction *I) {
175                              if (AftBlocks.count(I->getParent()))
176                                Visited.push_back(I);
177                              return true;
178                            });
179 
180   // Move all instructions in program order to before the InsertLoc
181   BasicBlock *InsertLocBB = InsertLoc->getParent();
182   for (Instruction *I : reverse(Visited)) {
183     if (I->getParent() != InsertLocBB)
184       I->moveBefore(InsertLoc);
185   }
186 }
187 
188 /*
189   This method performs Unroll and Jam. For a simple loop like:
190   for (i = ..)
191     Fore(i)
192     for (j = ..)
193       SubLoop(i, j)
194     Aft(i)
195 
196   Instead of doing normal inner or outer unrolling, we do:
197   for (i = .., i+=2)
198     Fore(i)
199     Fore(i+1)
200     for (j = ..)
201       SubLoop(i, j)
202       SubLoop(i+1, j)
203     Aft(i)
204     Aft(i+1)
205 
  So the outer loop is essentially unrolled and then the inner loops are fused
207   ("jammed") together into a single loop. This can increase speed when there
208   are loads in SubLoop that are invariant to i, as they become shared between
209   the now jammed inner loops.
210 
  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
212   Fore blocks are those before the inner loop, Aft are those after. Normal
213   Unroll code is used to copy each of these sets of blocks and the results are
214   combined together into the final form above.
215 
216   isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.
218 
219   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
220   necessary to create one and not fully unrolled).
221 */
LoopUnrollResult
llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
                       unsigned TripMultiple, bool UnrollRemainder,
                       LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
                       AssumptionCache *AC, const TargetTransformInfo *TTI,
                       OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(Header && "No header.");
  assert(L->getSubLoops().size() == 1);
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple == 1 || TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, /*ForgetAllSCEV*/ false,
                                    LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(Preheader && "No preheader");
  assert(LatchBlock && "No latch block");
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(BI && !BI->isUnconditional());
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
  // Remember which successor of the subloop latch branch stays inside the
  // subloop, so the jammed latch can later be rewired with the same polarity.
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  // Index 0 of each vector is the original (iteration-0) block.
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions from fore phi operands from AftBlocks into Fore.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // For sample-based profiling, scale each debug location's discriminator
  // duplication factor by Count so the duplicated bodies can still be
  // attributed.
  if (Header->getParent()->isDebugInfoForProfiling())
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc()) {
            auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
            if (NewDIL)
              I.setDebugLoc(NewDIL.getValue());
            else
              LLVM_DEBUG(dbgs()
                         << "Failed to create new discriminator: "
                         << DIL->getFilename() << " Line: " << DIL->getLine());
          }

  // Copy all blocks
  for (unsigned It = 1; It != Count; ++It) {
    SmallVector<BasicBlock *, 8> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;

    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Record the clone in the First/Last vector of whichever partition the
      // original block belongs to.
      if (ForeBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        SubLoop->addBasicBlockToLoop(New, *LI);

        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        L->addBasicBlockToLoop(New, *LI);

        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        // VMap keys are const Value*; cast away constness to store the
        // previous iteration's mapping in PrevItValueMap.
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    remapInstructionsInBlocks(NewBlocks, LastValueMap);
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        // Cloned assumes need to be re-registered with the assumption cache.
        if (auto *II = dyn_cast<IntrinsicInst>(&I))
          if (II->getIntrinsicID() == Intrinsic::assume)
            AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phi's, pointing them at the latest version of the
    // value from the previous iteration's phis
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      // Rewire the phi to take NewValue from the previous Fore copy and drop
      // the second (Aft) incoming edge.
      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // is the last unrolled iterations values.

  // Update Phis in BB from OldBB to point to NewBB
  auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
                            BasicBlock *NewBB) {
    for (PHINode &Phi : BB->phis()) {
      int I = Phi.getBasicBlockIndex(OldBB);
      Phi.setIncomingBlock(I, NewBB);
    }
  };
  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
  ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);

  if (CompletelyUnroll) {
    // The loop control is gone; header phis take their preheader value.
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
    ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
  }

  // Subloop successors and phis
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
                  ForeBlocksLast.back());
  updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
                    ForeBlocksLast.back());
    updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    // All jammed subloop copies share one header; collect their phis there.
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    BranchInst::Create(LoopExit, AftTerm);
    AftTerm->eraseFromParent();
  } else {
    AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
    assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
           "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
  }
  updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
                  SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
                    SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DTU.applyUpdatesPermissive(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
  while (!MergeBlocks.empty()) {
    BasicBlock *BB = *MergeBlocks.begin();
    BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
    if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
      BasicBlock *Dest = Term->getSuccessor(0);
      BasicBlock *Fold = Dest->getUniquePredecessor();
      if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) {
        // Don't remove BB and add Fold as they are the same BB
        assert(Fold == BB);
        (void)Fold;
        MergeBlocks.erase(Dest);
      } else
        MergeBlocks.erase(BB);
    } else
      MergeBlocks.erase(BB);
  }
  // Apply updates to the DomTree.
  DT = &DTU.getDomTree();

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
                          TTI);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  // Verify from the outermost loop that could have been affected: up to two
  // parents above SubLoop.
  Loop *OutestLoop = SubLoop->getParentLoop()
                         ? SubLoop->getParentLoop()->getParentLoop()
                               ? SubLoop->getParentLoop()->getParentLoop()
                               : SubLoop->getParentLoop()
                         : SubLoop;
  assert(DT->verify());
  LI->verify(*DT);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  SE->verify();
#endif

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}
648 
649 static bool getLoadsAndStores(BasicBlockSet &Blocks,
650                               SmallVector<Instruction *, 4> &MemInstr) {
651   // Scan the BBs and collect legal loads and stores.
652   // Returns false if non-simple loads/stores are found.
653   for (BasicBlock *BB : Blocks) {
654     for (Instruction &I : *BB) {
655       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
656         if (!Ld->isSimple())
657           return false;
658         MemInstr.push_back(&I);
659       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
660         if (!St->isSimple())
661           return false;
662         MemInstr.push_back(&I);
663       } else if (I.mayReadOrWriteMemory()) {
664         return false;
665       }
666     }
667   }
668   return true;
669 }
670 
671 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
672                                        unsigned UnrollLevel, unsigned JamLevel,
673                                        bool Sequentialized, Dependence *D) {
674   // UnrollLevel might carry the dependency Src --> Dst
675   // Does a different loop after unrolling?
676   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
677        ++CurLoopDepth) {
678     auto JammedDir = D->getDirection(CurLoopDepth);
679     if (JammedDir == Dependence::DVEntry::LT)
680       return true;
681 
682     if (JammedDir & Dependence::DVEntry::GT)
683       return false;
684   }
685 
686   return true;
687 }
688 
689 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
690                                         unsigned UnrollLevel, unsigned JamLevel,
691                                         bool Sequentialized, Dependence *D) {
692   // UnrollLevel might carry the dependency Dst --> Src
693   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
694        ++CurLoopDepth) {
695     auto JammedDir = D->getDirection(CurLoopDepth);
696     if (JammedDir == Dependence::DVEntry::GT)
697       return true;
698 
699     if (JammedDir & Dependence::DVEntry::LT)
700       return false;
701   }
702 
703   // Backward dependencies are only preserved if not interleaved.
704   return Sequentialized;
705 }
706 
// Check whether it is semantically safe to interleave Src and Dst, considering
// any potential dependency between them.
//
// @param UnrollLevel The level of the loop being unrolled
// @param JamLevel    The level of the loop being jammed; if Src and Dst are on
// different levels, the outermost common loop counts as jammed level
//
// @return true if is safe and false if there is a dependency violation.
static bool checkDependency(Instruction *Src, Instruction *Dst,
                            unsigned UnrollLevel, unsigned JamLevel,
                            bool Sequentialized, DependenceInfo &DI) {
  assert(UnrollLevel <= JamLevel &&
         "Expecting JamLevel to be at least UnrollLevel");

  if (Src == Dst)
    return true;
  // Ignore Input dependencies.
  if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
    return true;

  // Check whether unroll-and-jam may violate a dependency.
  // By construction, every dependency will be lexicographically non-negative
  // (a lexicographically negative dependency would violate the current
  // execution order), such as
  //   (0,0,>,*,*)
  // Unroll-and-jam changes the GT execution of two executions to the same
  // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
  // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
  //   (0,0,>=,*,*)
  // Now, the dependency is not necessarily non-negative anymore, i.e.
  // unroll-and-jam may violate correctness.
  std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
  if (!D)
    return true;
  assert(D->isOrdered() && "Expected an output, flow or anti dep.");

  // A confused dependence means the analysis could not classify it; be
  // conservative and reject.
  if (D->isConfused()) {
    LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
                      << "  " << *Src << "\n"
                      << "  " << *Dst << "\n");
    return false;
  }

  // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
  // non-equal direction, then the locations accessed in the inner levels cannot
  // overlap in memory. We assume the indexes never overlap into neighboring
  // dimensions.
  for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
    if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
      return true;

  auto UnrollDirection = D->getDirection(UnrollLevel);

  // If the distance carried by the unrolled loop is 0, then after unrolling
  // that distance will become non-zero resulting in non-overlapping accesses in
  // the inner loops.
  if (UnrollDirection == Dependence::DVEntry::EQ)
    return true;

  if (UnrollDirection & Dependence::DVEntry::LT &&
      !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                  Sequentialized, D.get()))
    return false;

  if (UnrollDirection & Dependence::DVEntry::GT &&
      !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
                                   Sequentialized, D.get()))
    return false;

  return true;
}
777 
778 static bool
779 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
780                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
781                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
782                   DependenceInfo &DI, LoopInfo &LI) {
783   SmallVector<BasicBlockSet, 8> AllBlocks;
784   for (Loop *L : Root.getLoopsInPreorder())
785     if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
786       AllBlocks.push_back(ForeBlocksMap.lookup(L));
787   AllBlocks.push_back(SubLoopBlocks);
788   for (Loop *L : Root.getLoopsInPreorder())
789     if (AftBlocksMap.find(L) != AftBlocksMap.end())
790       AllBlocks.push_back(AftBlocksMap.lookup(L));
791 
792   unsigned LoopDepth = Root.getLoopDepth();
793   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
794   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
795   for (BasicBlockSet &Blocks : AllBlocks) {
796     CurrentLoadsAndStores.clear();
797     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
798       return false;
799 
800     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
801     unsigned CurLoopDepth = CurLoop->getLoopDepth();
802 
803     for (auto *Earlier : EarlierLoadsAndStores) {
804       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
805       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
806       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
807       for (auto *Later : CurrentLoadsAndStores) {
808         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
809                              DI))
810           return false;
811       }
812     }
813 
814     size_t NumInsts = CurrentLoadsAndStores.size();
815     for (size_t I = 0; I < NumInsts; ++I) {
816       for (size_t J = I; J < NumInsts; ++J) {
817         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
818                              LoopDepth, CurLoopDepth, true, DI))
819           return false;
820       }
821     }
822 
823     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
824                                  CurrentLoadsAndStores.end());
825   }
826   return true;
827 }
828 
829 static bool isEligibleLoopForm(const Loop &Root) {
830   // Root must have a child.
831   if (Root.getSubLoops().size() != 1)
832     return false;
833 
834   const Loop *L = &Root;
835   do {
836     // All loops in Root need to be in simplify and rotated form.
837     if (!L->isLoopSimplifyForm())
838       return false;
839 
840     if (!L->isRotatedForm())
841       return false;
842 
843     if (L->getHeader()->hasAddressTaken()) {
844       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
845       return false;
846     }
847 
848     unsigned SubLoopsSize = L->getSubLoops().size();
849     if (SubLoopsSize == 0)
850       return true;
851 
852     // Only one child is allowed.
853     if (SubLoopsSize != 1)
854       return false;
855 
856     L = L->getSubLoops()[0];
857   } while (L);
858 
859   return true;
860 }
861 
862 static Loop *getInnerMostLoop(Loop *L) {
863   while (!L->getSubLoops().empty())
864     L = L->getSubLoops()[0];
865   return L;
866 }
867 
868 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
869                                 DependenceInfo &DI, LoopInfo &LI) {
870   if (!isEligibleLoopForm(*L)) {
871     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
872     return false;
873   }
874 
875   /* We currently handle outer loops like this:
876         |
877     ForeFirst    <------\   }
878      Blocks             |   } ForeBlocks of L
879     ForeLast            |   }
880         |               |
881        ...              |
882         |               |
883     ForeFirst    <----\ |   }
884      Blocks           | |   } ForeBlocks of a inner loop of L
885     ForeLast          | |   }
886         |             | |
887     JamLoopFirst  <\  | |   }
888      Blocks        |  | |   } JamLoopBlocks of the innermost loop
889     JamLoopLast   -/  | |   }
890         |             | |
891     AftFirst          | |   }
892      Blocks           | |   } AftBlocks of a inner loop of L
893     AftLast     ------/ |   }
894         |               |
895        ...              |
896         |               |
897     AftFirst            |   }
898      Blocks             |   } AftBlocks of L
899     AftLast     --------/   }
900         |
901 
902     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
903     and AftBlocks, providing that there is one edge from Fores to SubLoops,
904     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
905     In practice we currently limit Aft blocks to a single block, and limit
906     things further in the profitablility checks of the unroll and jam pass.
907 
908     Because of the way we rearrange basic blocks, we also require that
909     the Fore blocks of L on all unrolled iterations are safe to move before the
910     blocks of the direct child of L of all iterations. So we require that the
911     phi node looping operands of ForeHeader can be moved to at least the end of
912     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
913     match up Phi's correctly.
914 
915     i.e. The old order of blocks used to be
916            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
917          It needs to be safe to transform this to
918            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
919 
920     There are then a number of checks along the lines of no calls, no
921     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
922     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
923     UnrollAndJamLoop if the trip count cannot be easily calculated.
924   */
925 
926   // Split blocks into Fore/SubLoop/Aft based on dominators
927   Loop *JamLoop = getInnerMostLoop(L);
928   BasicBlockSet SubLoopBlocks;
929   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
930   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
931   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
932                                 AftBlocksMap, DT)) {
933     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
934     return false;
935   }
936 
937   // Aft blocks may need to move instructions to fore blocks, which becomes more
938   // difficult if there are multiple (potentially conditionally executed)
939   // blocks. For now we just exclude loops with multiple aft blocks.
940   if (AftBlocksMap[L].size() != 1) {
941     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
942                          "multiple blocks after the loop\n");
943     return false;
944   }
945 
946   // Check inner loop backedge count is consistent on all iterations of the
947   // outer loop
948   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
949         return !hasIterationCountInvariantInParent(SubLoop, SE);
950       })) {
951     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
952                          "not consistent on each iteration\n");
953     return false;
954   }
955 
956   // Check the loop safety info for exceptions.
957   SimpleLoopSafetyInfo LSI;
958   LSI.computeLoopSafetyInfo(L);
959   if (LSI.anyBlockMayThrow()) {
960     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
961     return false;
962   }
963 
964   // We've ruled out the easy stuff and now need to check that there are no
965   // interdependencies which may prevent us from moving the:
966   //  ForeBlocks before Subloop and AftBlocks.
967   //  Subloop before AftBlocks.
968   //  ForeBlock phi operands before the subloop
969 
970   // Make sure we can move all instructions we need to before the subloop
971   BasicBlock *Header = L->getHeader();
972   BasicBlock *Latch = L->getLoopLatch();
973   BasicBlockSet AftBlocks = AftBlocksMap[L];
974   Loop *SubLoop = L->getSubLoops()[0];
975   if (!processHeaderPhiOperands(
976           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
977             if (SubLoop->contains(I->getParent()))
978               return false;
979             if (AftBlocks.count(I->getParent())) {
980               // If we hit a phi node in afts we know we are done (probably
981               // LCSSA)
982               if (isa<PHINode>(I))
983                 return false;
984               // Can't move instructions with side effects or memory
985               // reads/writes
986               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
987                 return false;
988             }
989             // Keep going
990             return true;
991           })) {
992     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
993                          "instructions after subloop to before it\n");
994     return false;
995   }
996 
997   // Check for memory dependencies which prohibit the unrolling we are doing.
998   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
999   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
1000   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
1001                          LI)) {
1002     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
1003     return false;
1004   }
1005 
1006   return true;
1007 }
1008