1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Statistic.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/ADT/iterator_range.h"
24 #include "llvm/Analysis/AssumptionCache.h"
25 #include "llvm/Analysis/DependenceAnalysis.h"
26 #include "llvm/Analysis/DomTreeUpdater.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/LoopIterator.h"
29 #include "llvm/Analysis/MustExecute.h"
30 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
31 #include "llvm/Analysis/ScalarEvolution.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/DebugInfoMetadata.h"
34 #include "llvm/IR/DebugLoc.h"
35 #include "llvm/IR/DiagnosticInfo.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/Instruction.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/IntrinsicInst.h"
41 #include "llvm/IR/Use.h"
42 #include "llvm/IR/User.h"
43 #include "llvm/IR/Value.h"
44 #include "llvm/IR/ValueHandle.h"
45 #include "llvm/IR/ValueMap.h"
46 #include "llvm/Support/Casting.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/ErrorHandling.h"
49 #include "llvm/Support/GenericDomTree.h"
50 #include "llvm/Support/raw_ostream.h"
51 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
52 #include "llvm/Transforms/Utils/Cloning.h"
53 #include "llvm/Transforms/Utils/LoopUtils.h"
54 #include "llvm/Transforms/Utils/UnrollLoop.h"
55 #include "llvm/Transforms/Utils/ValueMapper.h"
56 #include <assert.h>
57 #include <memory>
58 #include <type_traits>
59 #include <vector>
60 
61 using namespace llvm;
62 
63 #define DEBUG_TYPE "loop-unroll-and-jam"
64 
65 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
66 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
67 
68 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
69 
70 // Partition blocks in an outer/inner loop pair into blocks before and after
71 // the loop
72 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
73                                      BasicBlockSet &ForeBlocks,
74                                      BasicBlockSet &SubLoopBlocks,
75                                      BasicBlockSet &AftBlocks,
76                                      DominatorTree *DT) {
77   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
78   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
79 
80   for (BasicBlock *BB : L->blocks()) {
81     if (!SubLoop->contains(BB)) {
82       if (DT->dominates(SubLoopLatch, BB))
83         AftBlocks.insert(BB);
84       else
85         ForeBlocks.insert(BB);
86     }
87   }
88 
89   // Check that all blocks in ForeBlocks together dominate the subloop
90   // TODO: This might ideally be done better with a dominator/postdominators.
91   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
92   for (BasicBlock *BB : ForeBlocks) {
93     if (BB == SubLoopPreHeader)
94       continue;
95     Instruction *TI = BB->getTerminator();
96     for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
97       if (!ForeBlocks.count(TI->getSuccessor(i)))
98         return false;
99   }
100 
101   return true;
102 }
103 
104 // Looks at the phi nodes in Header for values coming from Latch. For these
105 // instructions and all their operands calls Visit on them, keeping going for
106 // all the operands in AftBlocks. Returns false if Visit returns false,
107 // otherwise returns true. This is used to process the instructions in the
108 // Aft blocks that need to be moved before the subloop. It is used in two
109 // places. One to check that the required set of instructions can be moved
110 // before the loop. Then to collect the instructions to actually move in
111 // moveHeaderPhiOperandsToForeBlocks.
112 template <typename T>
113 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
114                                      BasicBlockSet &AftBlocks, T Visit) {
115   SmallVector<Instruction *, 8> Worklist;
116   for (auto &Phi : Header->phis()) {
117     Value *V = Phi.getIncomingValueForBlock(Latch);
118     if (Instruction *I = dyn_cast<Instruction>(V))
119       Worklist.push_back(I);
120   }
121 
122   while (!Worklist.empty()) {
123     Instruction *I = Worklist.back();
124     Worklist.pop_back();
125     if (!Visit(I))
126       return false;
127 
128     if (AftBlocks.count(I->getParent()))
129       for (auto &U : I->operands())
130         if (Instruction *II = dyn_cast<Instruction>(U))
131           Worklist.push_back(II);
132   }
133 
134   return true;
135 }
136 
137 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
138 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
139                                               BasicBlock *Latch,
140                                               Instruction *InsertLoc,
141                                               BasicBlockSet &AftBlocks) {
142   // We need to ensure we move the instructions in the correct order,
143   // starting with the earliest required instruction and moving forward.
144   std::vector<Instruction *> Visited;
145   processHeaderPhiOperands(Header, Latch, AftBlocks,
146                            [&Visited, &AftBlocks](Instruction *I) {
147                              if (AftBlocks.count(I->getParent()))
148                                Visited.push_back(I);
149                              return true;
150                            });
151 
152   // Move all instructions in program order to before the InsertLoc
153   BasicBlock *InsertLocBB = InsertLoc->getParent();
154   for (Instruction *I : reverse(Visited)) {
155     if (I->getParent() != InsertLocBB)
156       I->moveBefore(InsertLoc);
157   }
158 }
159 
160 /*
161   This method performs Unroll and Jam. For a simple loop like:
162   for (i = ..)
163     Fore(i)
164     for (j = ..)
165       SubLoop(i, j)
166     Aft(i)
167 
168   Instead of doing normal inner or outer unrolling, we do:
169   for (i = .., i+=2)
170     Fore(i)
171     Fore(i+1)
172     for (j = ..)
173       SubLoop(i, j)
174       SubLoop(i+1, j)
175     Aft(i)
176     Aft(i+1)
177 
178   So the outer loop is essetially unrolled and then the inner loops are fused
179   ("jammed") together into a single loop. This can increase speed when there
180   are loads in SubLoop that are invariant to i, as they become shared between
181   the now jammed inner loops.
182 
183   We do this by spliting the blocks in the loop into Fore, Subloop and Aft.
184   Fore blocks are those before the inner loop, Aft are those after. Normal
185   Unroll code is used to copy each of these sets of blocks and the results are
186   combined together into the final form above.
187 
188   isSafeToUnrollAndJam should be used prior to calling this to make sure the
189   unrolling will be valid. Checking profitablility is also advisable.
190 
191   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
192   necessary to create one and not fully unrolled).
193 */
194 LoopUnrollResult llvm::UnrollAndJamLoop(
195     Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,
196     bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
197     AssumptionCache *AC, OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {
198 
199   // When we enter here we should have already checked that it is safe
200   BasicBlock *Header = L->getHeader();
201   assert(Header && "No header.");
202   assert(L->getSubLoops().size() == 1);
203   Loop *SubLoop = *L->begin();
204 
205   // Don't enter the unroll code if there is nothing to do.
206   if (TripCount == 0 && Count < 2) {
207     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
208     return LoopUnrollResult::Unmodified;
209   }
210 
211   assert(Count > 0);
212   assert(TripMultiple > 0);
213   assert(TripCount == 0 || TripCount % TripMultiple == 0);
214 
215   // Are we eliminating the loop control altogether?
216   bool CompletelyUnroll = (Count == TripCount);
217 
218   // We use the runtime remainder in cases where we don't know trip multiple
219   if (TripMultiple == 1 || TripMultiple % Count != 0) {
220     if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
221                                     /*UseEpilogRemainder*/ true,
222                                     UnrollRemainder, /*ForgetAllSCEV*/ false,
223                                     LI, SE, DT, AC, true, EpilogueLoop)) {
224       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
225                            "generated when assuming runtime trip count\n");
226       return LoopUnrollResult::Unmodified;
227     }
228   }
229 
230   // Notify ScalarEvolution that the loop will be substantially changed,
231   // if not outright eliminated.
232   if (SE) {
233     SE->forgetLoop(L);
234     SE->forgetLoop(SubLoop);
235   }
236 
237   using namespace ore;
238   // Report the unrolling decision.
239   if (CompletelyUnroll) {
240     LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
241                       << Header->getName() << " with trip count " << TripCount
242                       << "!\n");
243     ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
244                                  L->getHeader())
245               << "completely unroll and jammed loop with "
246               << NV("UnrollCount", TripCount) << " iterations");
247   } else {
248     auto DiagBuilder = [&]() {
249       OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
250                               L->getHeader());
251       return Diag << "unroll and jammed loop by a factor of "
252                   << NV("UnrollCount", Count);
253     };
254 
255     LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
256                       << " by " << Count);
257     if (TripMultiple != 1) {
258       LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
259       ORE->emit([&]() {
260         return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
261                              << " trips per branch";
262       });
263     } else {
264       LLVM_DEBUG(dbgs() << " with run-time trip count");
265       ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
266     }
267     LLVM_DEBUG(dbgs() << "!\n");
268   }
269 
270   BasicBlock *Preheader = L->getLoopPreheader();
271   BasicBlock *LatchBlock = L->getLoopLatch();
272   assert(Preheader && "No preheader");
273   assert(LatchBlock && "No latch block");
274   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
275   assert(BI && !BI->isUnconditional());
276   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
277   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
278   bool SubLoopContinueOnTrue = SubLoop->contains(
279       SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
280 
281   // Partition blocks in an outer/inner loop pair into blocks before and after
282   // the loop
283   BasicBlockSet SubLoopBlocks;
284   BasicBlockSet ForeBlocks;
285   BasicBlockSet AftBlocks;
286   partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
287                            DT);
288 
289   // We keep track of the entering/first and exiting/last block of each of
290   // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
291   // blocks easier.
292   std::vector<BasicBlock *> ForeBlocksFirst;
293   std::vector<BasicBlock *> ForeBlocksLast;
294   std::vector<BasicBlock *> SubLoopBlocksFirst;
295   std::vector<BasicBlock *> SubLoopBlocksLast;
296   std::vector<BasicBlock *> AftBlocksFirst;
297   std::vector<BasicBlock *> AftBlocksLast;
298   ForeBlocksFirst.push_back(Header);
299   ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
300   SubLoopBlocksFirst.push_back(SubLoop->getHeader());
301   SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
302   AftBlocksFirst.push_back(SubLoop->getExitBlock());
303   AftBlocksLast.push_back(L->getExitingBlock());
304   // Maps Blocks[0] -> Blocks[It]
305   ValueToValueMapTy LastValueMap;
306 
307   // Move any instructions from fore phi operands from AftBlocks into Fore.
308   moveHeaderPhiOperandsToForeBlocks(
309       Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);
310 
311   // The current on-the-fly SSA update requires blocks to be processed in
312   // reverse postorder so that LastValueMap contains the correct value at each
313   // exit.
314   LoopBlocksDFS DFS(L);
315   DFS.perform(LI);
316   // Stash the DFS iterators before adding blocks to the loop.
317   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
318   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
319 
320   if (Header->getParent()->isDebugInfoForProfiling())
321     for (BasicBlock *BB : L->getBlocks())
322       for (Instruction &I : *BB)
323         if (!isa<DbgInfoIntrinsic>(&I))
324           if (const DILocation *DIL = I.getDebugLoc()) {
325             auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
326             if (NewDIL)
327               I.setDebugLoc(NewDIL.getValue());
328             else
329               LLVM_DEBUG(dbgs()
330                          << "Failed to create new discriminator: "
331                          << DIL->getFilename() << " Line: " << DIL->getLine());
332           }
333 
334   // Copy all blocks
335   for (unsigned It = 1; It != Count; ++It) {
336     SmallVector<BasicBlock *, 8> NewBlocks;
337     // Maps Blocks[It] -> Blocks[It-1]
338     DenseMap<Value *, Value *> PrevItValueMap;
339 
340     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
341       ValueToValueMapTy VMap;
342       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
343       Header->getParent()->getBasicBlockList().push_back(New);
344 
345       if (ForeBlocks.count(*BB)) {
346         L->addBasicBlockToLoop(New, *LI);
347 
348         if (*BB == ForeBlocksFirst[0])
349           ForeBlocksFirst.push_back(New);
350         if (*BB == ForeBlocksLast[0])
351           ForeBlocksLast.push_back(New);
352       } else if (SubLoopBlocks.count(*BB)) {
353         SubLoop->addBasicBlockToLoop(New, *LI);
354 
355         if (*BB == SubLoopBlocksFirst[0])
356           SubLoopBlocksFirst.push_back(New);
357         if (*BB == SubLoopBlocksLast[0])
358           SubLoopBlocksLast.push_back(New);
359       } else if (AftBlocks.count(*BB)) {
360         L->addBasicBlockToLoop(New, *LI);
361 
362         if (*BB == AftBlocksFirst[0])
363           AftBlocksFirst.push_back(New);
364         if (*BB == AftBlocksLast[0])
365           AftBlocksLast.push_back(New);
366       } else {
367         llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
368       }
369 
370       // Update our running maps of newest clones
371       PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
372       LastValueMap[*BB] = New;
373       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
374            VI != VE; ++VI) {
375         PrevItValueMap[VI->second] =
376             const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
377         LastValueMap[VI->first] = VI->second;
378       }
379 
380       NewBlocks.push_back(New);
381 
382       // Update DomTree:
383       if (*BB == ForeBlocksFirst[0])
384         DT->addNewBlock(New, ForeBlocksLast[It - 1]);
385       else if (*BB == SubLoopBlocksFirst[0])
386         DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
387       else if (*BB == AftBlocksFirst[0])
388         DT->addNewBlock(New, AftBlocksLast[It - 1]);
389       else {
390         // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
391         // structure.
392         auto BBDomNode = DT->getNode(*BB);
393         auto BBIDom = BBDomNode->getIDom();
394         BasicBlock *OriginalBBIDom = BBIDom->getBlock();
395         assert(OriginalBBIDom);
396         assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
397         DT->addNewBlock(
398             New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
399       }
400     }
401 
402     // Remap all instructions in the most recent iteration
403     remapInstructionsInBlocks(NewBlocks, LastValueMap);
404     for (BasicBlock *NewBlock : NewBlocks) {
405       for (Instruction &I : *NewBlock) {
406         if (auto *II = dyn_cast<IntrinsicInst>(&I))
407           if (II->getIntrinsicID() == Intrinsic::assume)
408             AC->registerAssumption(II);
409       }
410     }
411 
412     // Alter the ForeBlocks phi's, pointing them at the latest version of the
413     // value from the previous iteration's phis
414     for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
415       Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
416       assert(OldValue && "should have incoming edge from Aft[It]");
417       Value *NewValue = OldValue;
418       if (Value *PrevValue = PrevItValueMap[OldValue])
419         NewValue = PrevValue;
420 
421       assert(Phi.getNumOperands() == 2);
422       Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
423       Phi.setIncomingValue(0, NewValue);
424       Phi.removeIncomingValue(1);
425     }
426   }
427 
428   // Now that all the basic blocks for the unrolled iterations are in place,
429   // finish up connecting the blocks and phi nodes. At this point LastValueMap
430   // is the last unrolled iterations values.
431 
432   // Update Phis in BB from OldBB to point to NewBB
433   auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
434                             BasicBlock *NewBB) {
435     for (PHINode &Phi : BB->phis()) {
436       int I = Phi.getBasicBlockIndex(OldBB);
437       Phi.setIncomingBlock(I, NewBB);
438     }
439   };
440   // Update Phis in BB from OldBB to point to NewBB and use the latest value
441   // from LastValueMap
442   auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
443                                      BasicBlock *NewBB,
444                                      ValueToValueMapTy &LastValueMap) {
445     for (PHINode &Phi : BB->phis()) {
446       for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
447         if (Phi.getIncomingBlock(b) == OldBB) {
448           Value *OldValue = Phi.getIncomingValue(b);
449           if (Value *LastValue = LastValueMap[OldValue])
450             Phi.setIncomingValue(b, LastValue);
451           Phi.setIncomingBlock(b, NewBB);
452           break;
453         }
454       }
455     }
456   };
457   // Move all the phis from Src into Dest
458   auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
459     Instruction *insertPoint = Dest->getFirstNonPHI();
460     while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
461       Phi->moveBefore(insertPoint);
462   };
463 
464   // Update the PHI values outside the loop to point to the last block
465   updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
466                            LastValueMap);
467 
468   // Update ForeBlocks successors and phi nodes
469   BranchInst *ForeTerm =
470       cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
471   assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
472   ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);
473 
474   if (CompletelyUnroll) {
475     while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
476       Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
477       Phi->getParent()->getInstList().erase(Phi);
478     }
479   } else {
480     // Update the PHI values to point to the last aft block
481     updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
482                              AftBlocksLast.back(), LastValueMap);
483   }
484 
485   for (unsigned It = 1; It != Count; It++) {
486     // Remap ForeBlock successors from previous iteration to this
487     BranchInst *ForeTerm =
488         cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
489     assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
490     ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
491   }
492 
493   // Subloop successors and phis
494   BranchInst *SubTerm =
495       cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
496   SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
497   SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
498   updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
499                   ForeBlocksLast.back());
500   updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
501                   SubLoopBlocksLast.back());
502 
503   for (unsigned It = 1; It != Count; It++) {
504     // Replace the conditional branch of the previous iteration subloop with an
505     // unconditional one to this one
506     BranchInst *SubTerm =
507         cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
508     BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
509     SubTerm->eraseFromParent();
510 
511     updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
512                     ForeBlocksLast.back());
513     updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
514                     SubLoopBlocksLast.back());
515     movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
516   }
517 
518   // Aft blocks successors and phis
519   BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
520   if (CompletelyUnroll) {
521     BranchInst::Create(LoopExit, AftTerm);
522     AftTerm->eraseFromParent();
523   } else {
524     AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
525     assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
526            "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
527   }
528   updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
529                   SubLoopBlocksLast.back());
530 
531   for (unsigned It = 1; It != Count; It++) {
532     // Replace the conditional branch of the previous iteration subloop with an
533     // unconditional one to this one
534     BranchInst *AftTerm =
535         cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
536     BranchInst::Create(AftBlocksFirst[It], AftTerm);
537     AftTerm->eraseFromParent();
538 
539     updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
540                     SubLoopBlocksLast.back());
541     movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
542   }
543 
544   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
545   // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
546   // new ones required.
547   if (Count != 1) {
548     SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
549     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
550                            SubLoopBlocksFirst[0]);
551     DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
552                            SubLoopBlocksLast[0], AftBlocksFirst[0]);
553 
554     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
555                            ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
556     DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
557                            SubLoopBlocksLast.back(), AftBlocksFirst[0]);
558     DTU.applyUpdatesPermissive(DTUpdates);
559   }
560 
561   // Merge adjacent basic blocks, if possible.
562   SmallPtrSet<BasicBlock *, 16> MergeBlocks;
563   MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
564   MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
565   MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
566   while (!MergeBlocks.empty()) {
567     BasicBlock *BB = *MergeBlocks.begin();
568     BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
569     if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
570       BasicBlock *Dest = Term->getSuccessor(0);
571       BasicBlock *Fold = Dest->getUniquePredecessor();
572       if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) {
573         // Don't remove BB and add Fold as they are the same BB
574         assert(Fold == BB);
575         (void)Fold;
576         MergeBlocks.erase(Dest);
577       } else
578         MergeBlocks.erase(BB);
579     } else
580       MergeBlocks.erase(BB);
581   }
582   // Apply updates to the DomTree.
583   DT = &DTU.getDomTree();
584 
585   // At this point, the code is well formed.  We now do a quick sweep over the
586   // inserted code, doing constant propagation and dead code elimination as we
587   // go.
588   simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC);
589   simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC);
590 
591   NumCompletelyUnrolledAndJammed += CompletelyUnroll;
592   ++NumUnrolledAndJammed;
593 
594   // Update LoopInfo if the loop is completely removed.
595   if (CompletelyUnroll)
596     LI->erase(L);
597 
598 #ifndef NDEBUG
599   // We shouldn't have done anything to break loop simplify form or LCSSA.
600   Loop *OutestLoop = SubLoop->getParentLoop()
601                          ? SubLoop->getParentLoop()->getParentLoop()
602                                ? SubLoop->getParentLoop()->getParentLoop()
603                                : SubLoop->getParentLoop()
604                          : SubLoop;
605   assert(DT->verify());
606   LI->verify(*DT);
607   assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
608   if (!CompletelyUnroll)
609     assert(L->isLoopSimplifyForm());
610   assert(SubLoop->isLoopSimplifyForm());
611   SE->verify();
612 #endif
613 
614   return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
615                           : LoopUnrollResult::PartiallyUnrolled;
616 }
617 
618 static bool getLoadsAndStores(BasicBlockSet &Blocks,
619                               SmallVector<Value *, 4> &MemInstr) {
620   // Scan the BBs and collect legal loads and stores.
621   // Returns false if non-simple loads/stores are found.
622   for (BasicBlock *BB : Blocks) {
623     for (Instruction &I : *BB) {
624       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
625         if (!Ld->isSimple())
626           return false;
627         MemInstr.push_back(&I);
628       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
629         if (!St->isSimple())
630           return false;
631         MemInstr.push_back(&I);
632       } else if (I.mayReadOrWriteMemory()) {
633         return false;
634       }
635     }
636   }
637   return true;
638 }
639 
640 static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
641                               SmallVector<Value *, 4> &Later,
642                               unsigned LoopDepth, bool InnerLoop,
643                               DependenceInfo &DI) {
644   // Use DA to check for dependencies between loads and stores that make unroll
645   // and jam invalid
646   for (Value *I : Earlier) {
647     for (Value *J : Later) {
648       Instruction *Src = cast<Instruction>(I);
649       Instruction *Dst = cast<Instruction>(J);
650       if (Src == Dst)
651         continue;
652       // Ignore Input dependencies.
653       if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
654         continue;
655 
656       // Track dependencies, and if we find them take a conservative approach
657       // by allowing only = or < (not >), altough some > would be safe
658       // (depending upon unroll width).
659       // For the inner loop, we need to disallow any (> <) dependencies
660       // FIXME: Allow > so long as distance is less than unroll width
661       if (auto D = DI.depends(Src, Dst, true)) {
662         assert(D->isOrdered() && "Expected an output, flow or anti dep.");
663 
664         if (D->isConfused()) {
665           LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
666                             << "  " << *Src << "\n"
667                             << "  " << *Dst << "\n");
668           return false;
669         }
670         if (!InnerLoop) {
671           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) {
672             LLVM_DEBUG(dbgs() << "  > dependency between:\n"
673                               << "  " << *Src << "\n"
674                               << "  " << *Dst << "\n");
675             return false;
676           }
677         } else {
678           assert(LoopDepth + 1 <= D->getLevels());
679           if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
680               D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) {
681             LLVM_DEBUG(dbgs() << "  < > dependency between:\n"
682                               << "  " << *Src << "\n"
683                               << "  " << *Dst << "\n");
684             return false;
685           }
686         }
687       }
688     }
689   }
690   return true;
691 }
692 
693 static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
694                               BasicBlockSet &SubLoopBlocks,
695                               BasicBlockSet &AftBlocks, DependenceInfo &DI) {
696   // Get all loads/store pairs for each blocks
697   SmallVector<Value *, 4> ForeMemInstr;
698   SmallVector<Value *, 4> SubLoopMemInstr;
699   SmallVector<Value *, 4> AftMemInstr;
700   if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
701       !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
702       !getLoadsAndStores(AftBlocks, AftMemInstr))
703     return false;
704 
705   // Check for dependencies between any blocks that may change order
706   unsigned LoopDepth = L->getLoopDepth();
707   return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
708                            DI) &&
709          checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
710          checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
711                            DI) &&
712          checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
713                            DI);
714 }
715 
716 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
717                                 DependenceInfo &DI) {
718   /* We currently handle outer loops like this:
719         |
720     ForeFirst    <----\    }
721      Blocks           |    } ForeBlocks
722     ForeLast          |    }
723         |             |
724     SubLoopFirst  <\  |    }
725      Blocks        |  |    } SubLoopBlocks
726     SubLoopLast   -/  |    }
727         |             |
728     AftFirst          |    }
729      Blocks           |    } AftBlocks
730     AftLast     ------/    }
731         |
732 
733     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
734     and AftBlocks, providing that there is one edge from Fores to SubLoops,
735     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
736     In practice we currently limit Aft blocks to a single block, and limit
737     things further in the profitability checks of the unroll and jam pass.
738 
739     Because of the way we rearrange basic blocks, we also require that
740     the Fore blocks on all unrolled iterations are safe to move before the
741     SubLoop blocks of all iterations. So we require that the phi node looping
742     operands of ForeHeader can be moved to at least the end of ForeEnd, so that
743     we can arrange cloned Fore Blocks before the subloop and match up Phi's
744     correctly.
745 
746     i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
747     It needs to be safe to tranform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.
748 
749     There are then a number of checks along the lines of no calls, no
750     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
751     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
752     UnrollAndJamLoop if the trip count cannot be easily calculated.
753   */
754 
755   if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
756     return false;
757   Loop *SubLoop = L->getSubLoops()[0];
758   if (!SubLoop->isLoopSimplifyForm())
759     return false;
760 
761   BasicBlock *Header = L->getHeader();
762   BasicBlock *Latch = L->getLoopLatch();
763   BasicBlock *Exit = L->getExitingBlock();
764   BasicBlock *SubLoopHeader = SubLoop->getHeader();
765   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
766   BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
767 
768   if (Latch != Exit)
769     return false;
770   if (SubLoopLatch != SubLoopExit)
771     return false;
772 
773   if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) {
774     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
775     return false;
776   }
777 
778   // Split blocks into Fore/SubLoop/Aft based on dominators
779   BasicBlockSet SubLoopBlocks;
780   BasicBlockSet ForeBlocks;
781   BasicBlockSet AftBlocks;
782   if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
783                                 AftBlocks, &DT)) {
784     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
785     return false;
786   }
787 
788   // Aft blocks may need to move instructions to fore blocks, which becomes more
789   // difficult if there are multiple (potentially conditionally executed)
790   // blocks. For now we just exclude loops with multiple aft blocks.
791   if (AftBlocks.size() != 1) {
792     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
793                          "multiple blocks after the loop\n");
794     return false;
795   }
796 
797   // Check inner loop backedge count is consistent on all iterations of the
798   // outer loop
799   if (!hasIterationCountInvariantInParent(SubLoop, SE)) {
800     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
801                          "not consistent on each iteration\n");
802     return false;
803   }
804 
805   // Check the loop safety info for exceptions.
806   SimpleLoopSafetyInfo LSI;
807   LSI.computeLoopSafetyInfo(L);
808   if (LSI.anyBlockMayThrow()) {
809     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
810     return false;
811   }
812 
813   // We've ruled out the easy stuff and now need to check that there are no
814   // interdependencies which may prevent us from moving the:
815   //  ForeBlocks before Subloop and AftBlocks.
816   //  Subloop before AftBlocks.
817   //  ForeBlock phi operands before the subloop
818 
819   // Make sure we can move all instructions we need to before the subloop
820   if (!processHeaderPhiOperands(
821           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
822             if (SubLoop->contains(I->getParent()))
823               return false;
824             if (AftBlocks.count(I->getParent())) {
825               // If we hit a phi node in afts we know we are done (probably
826               // LCSSA)
827               if (isa<PHINode>(I))
828                 return false;
829               // Can't move instructions with side effects or memory
830               // reads/writes
831               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
832                 return false;
833             }
834             // Keep going
835             return true;
836           })) {
837     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
838                          "instructions after subloop to before it\n");
839     return false;
840   }
841 
842   // Check for memory dependencies which prohibit the unrolling we are doing.
843   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
844   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
845   if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) {
846     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
847     return false;
848   }
849 
850   return true;
851 }
852