1 //===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop unroll and jam as a routine, much like
10 // LoopUnroll.cpp implements loop unroll.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Optional.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/ADT/iterator_range.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/DependenceAnalysis.h"
27 #include "llvm/Analysis/DomTreeUpdater.h"
28 #include "llvm/Analysis/LoopInfo.h"
29 #include "llvm/Analysis/LoopIterator.h"
30 #include "llvm/Analysis/MustExecute.h"
31 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
32 #include "llvm/Analysis/ScalarEvolution.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/DebugInfoMetadata.h"
35 #include "llvm/IR/DebugLoc.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/Dominators.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/Use.h"
43 #include "llvm/IR/User.h"
44 #include "llvm/IR/Value.h"
45 #include "llvm/IR/ValueHandle.h"
46 #include "llvm/IR/ValueMap.h"
47 #include "llvm/Support/Casting.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/ErrorHandling.h"
50 #include "llvm/Support/GenericDomTree.h"
51 #include "llvm/Support/raw_ostream.h"
52 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
53 #include "llvm/Transforms/Utils/Cloning.h"
54 #include "llvm/Transforms/Utils/LoopUtils.h"
55 #include "llvm/Transforms/Utils/UnrollLoop.h"
56 #include "llvm/Transforms/Utils/ValueMapper.h"
57 #include <assert.h>
58 #include <memory>
59 #include <type_traits>
60 #include <vector>
61 
62 using namespace llvm;
63 
64 #define DEBUG_TYPE "loop-unroll-and-jam"
65 
66 STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed");
67 STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed");
68 
69 typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet;
70 
71 // Partition blocks in an outer/inner loop pair into blocks before and after
72 // the loop
73 static bool partitionLoopBlocks(Loop &L, BasicBlockSet &ForeBlocks,
74                                 BasicBlockSet &AftBlocks, DominatorTree &DT) {
75   Loop *SubLoop = L.getSubLoops()[0];
76   BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
77 
78   for (BasicBlock *BB : L.blocks()) {
79     if (!SubLoop->contains(BB)) {
80       if (DT.dominates(SubLoopLatch, BB))
81         AftBlocks.insert(BB);
82       else
83         ForeBlocks.insert(BB);
84     }
85   }
86 
87   // Check that all blocks in ForeBlocks together dominate the subloop
88   // TODO: This might ideally be done better with a dominator/postdominators.
89   BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader();
90   for (BasicBlock *BB : ForeBlocks) {
91     if (BB == SubLoopPreHeader)
92       continue;
93     Instruction *TI = BB->getTerminator();
94     for (BasicBlock *Succ : successors(TI))
95       if (!ForeBlocks.count(Succ))
96         return false;
97   }
98 
99   return true;
100 }
101 
102 /// Partition blocks in a loop nest into blocks before and after each inner
103 /// loop.
104 static bool partitionOuterLoopBlocks(
105     Loop &Root, Loop &JamLoop, BasicBlockSet &JamLoopBlocks,
106     DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
107     DenseMap<Loop *, BasicBlockSet> &AftBlocksMap, DominatorTree &DT) {
108   JamLoopBlocks.insert(JamLoop.block_begin(), JamLoop.block_end());
109 
110   for (Loop *L : Root.getLoopsInPreorder()) {
111     if (L == &JamLoop)
112       break;
113 
114     if (!partitionLoopBlocks(*L, ForeBlocksMap[L], AftBlocksMap[L], DT))
115       return false;
116   }
117 
118   return true;
119 }
120 
121 // TODO Remove when UnrollAndJamLoop changed to support unroll and jamming more
122 // than 2 levels loop.
123 static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop,
124                                      BasicBlockSet &ForeBlocks,
125                                      BasicBlockSet &SubLoopBlocks,
126                                      BasicBlockSet &AftBlocks,
127                                      DominatorTree *DT) {
128   SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end());
129   return partitionLoopBlocks(*L, ForeBlocks, AftBlocks, *DT);
130 }
131 
132 // Looks at the phi nodes in Header for values coming from Latch. For these
133 // instructions and all their operands calls Visit on them, keeping going for
134 // all the operands in AftBlocks. Returns false if Visit returns false,
135 // otherwise returns true. This is used to process the instructions in the
136 // Aft blocks that need to be moved before the subloop. It is used in two
137 // places. One to check that the required set of instructions can be moved
138 // before the loop. Then to collect the instructions to actually move in
139 // moveHeaderPhiOperandsToForeBlocks.
140 template <typename T>
141 static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
142                                      BasicBlockSet &AftBlocks, T Visit) {
143   SmallVector<Instruction *, 8> Worklist;
144   for (auto &Phi : Header->phis()) {
145     Value *V = Phi.getIncomingValueForBlock(Latch);
146     if (Instruction *I = dyn_cast<Instruction>(V))
147       Worklist.push_back(I);
148   }
149 
150   while (!Worklist.empty()) {
151     Instruction *I = Worklist.back();
152     Worklist.pop_back();
153     if (!Visit(I))
154       return false;
155 
156     if (AftBlocks.count(I->getParent()))
157       for (auto &U : I->operands())
158         if (Instruction *II = dyn_cast<Instruction>(U))
159           Worklist.push_back(II);
160   }
161 
162   return true;
163 }
164 
165 // Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
166 static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
167                                               BasicBlock *Latch,
168                                               Instruction *InsertLoc,
169                                               BasicBlockSet &AftBlocks) {
170   // We need to ensure we move the instructions in the correct order,
171   // starting with the earliest required instruction and moving forward.
172   std::vector<Instruction *> Visited;
173   processHeaderPhiOperands(Header, Latch, AftBlocks,
174                            [&Visited, &AftBlocks](Instruction *I) {
175                              if (AftBlocks.count(I->getParent()))
176                                Visited.push_back(I);
177                              return true;
178                            });
179 
180   // Move all instructions in program order to before the InsertLoc
181   BasicBlock *InsertLocBB = InsertLoc->getParent();
182   for (Instruction *I : reverse(Visited)) {
183     if (I->getParent() != InsertLocBB)
184       I->moveBefore(InsertLoc);
185   }
186 }
187 
188 /*
189   This method performs Unroll and Jam. For a simple loop like:
190   for (i = ..)
191     Fore(i)
192     for (j = ..)
193       SubLoop(i, j)
194     Aft(i)
195 
196   Instead of doing normal inner or outer unrolling, we do:
197   for (i = .., i+=2)
198     Fore(i)
199     Fore(i+1)
200     for (j = ..)
201       SubLoop(i, j)
202       SubLoop(i+1, j)
203     Aft(i)
204     Aft(i+1)
205 
  So the outer loop is essentially unrolled and then the inner loops are fused
207   ("jammed") together into a single loop. This can increase speed when there
208   are loads in SubLoop that are invariant to i, as they become shared between
209   the now jammed inner loops.
210 
  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
212   Fore blocks are those before the inner loop, Aft are those after. Normal
213   Unroll code is used to copy each of these sets of blocks and the results are
214   combined together into the final form above.
215 
216   isSafeToUnrollAndJam should be used prior to calling this to make sure the
  unrolling will be valid. Checking profitability is also advisable.
218 
219   If EpilogueLoop is non-null, it receives the epilogue loop (if it was
220   necessary to create one and not fully unrolled).
221 */
LoopUnrollResult
llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
                       unsigned TripMultiple, bool UnrollRemainder,
                       LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
                       AssumptionCache *AC, const TargetTransformInfo *TTI,
                       OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) {

  // When we enter here we should have already checked that it is safe
  BasicBlock *Header = L->getHeader();
  assert(Header && "No header.");
  assert(L->getSubLoops().size() == 1);
  // The nest is exactly two deep here; SubLoop is the inner loop whose
  // unrolled copies will be jammed together.
  Loop *SubLoop = *L->begin();

  // Don't enter the unroll code if there is nothing to do.
  if (TripCount == 0 && Count < 2) {
    LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n");
    return LoopUnrollResult::Unmodified;
  }

  assert(Count > 0);
  assert(TripMultiple > 0);
  assert(TripCount == 0 || TripCount % TripMultiple == 0);

  // Are we eliminating the loop control altogether?
  bool CompletelyUnroll = (Count == TripCount);

  // We use the runtime remainder in cases where we don't know trip multiple
  if (TripMultiple == 1 || TripMultiple % Count != 0) {
    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
                                    /*UseEpilogRemainder*/ true,
                                    UnrollRemainder, /*ForgetAllSCEV*/ false,
                                    LI, SE, DT, AC, TTI, true, EpilogueLoop)) {
      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
                           "generated when assuming runtime trip count\n");
      return LoopUnrollResult::Unmodified;
    }
  }

  // Notify ScalarEvolution that the loop will be substantially changed,
  // if not outright eliminated.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetLoop(SubLoop);
  }

  using namespace ore;
  // Report the unrolling decision.
  if (CompletelyUnroll) {
    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
                      << Header->getName() << " with trip count " << TripCount
                      << "!\n");
    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
                                 L->getHeader())
              << "completely unroll and jammed loop with "
              << NV("UnrollCount", TripCount) << " iterations");
  } else {
    auto DiagBuilder = [&]() {
      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
                              L->getHeader());
      return Diag << "unroll and jammed loop by a factor of "
                  << NV("UnrollCount", Count);
    };

    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
                      << " by " << Count);
    if (TripMultiple != 1) {
      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
      ORE->emit([&]() {
        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
                             << " trips per branch";
      });
    } else {
      LLVM_DEBUG(dbgs() << " with run-time trip count");
      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
    }
    LLVM_DEBUG(dbgs() << "!\n");
  }

  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *LatchBlock = L->getLoopLatch();
  assert(Preheader && "No preheader");
  assert(LatchBlock && "No latch block");
  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
  assert(BI && !BI->isUnconditional());
  // ContinueOnTrue is true when the latch branch stays in the loop on its
  // first ("true") successor; LoopExit is then the other successor.
  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
  // Same question for the inner loop's latch branch.
  bool SubLoopContinueOnTrue = SubLoop->contains(
      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));

  // Partition blocks in an outer/inner loop pair into blocks before and after
  // the loop
  BasicBlockSet SubLoopBlocks;
  BasicBlockSet ForeBlocks;
  BasicBlockSet AftBlocks;
  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
                           DT);

  // We keep track of the entering/first and exiting/last block of each of
  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
  // blocks easier.
  std::vector<BasicBlock *> ForeBlocksFirst;
  std::vector<BasicBlock *> ForeBlocksLast;
  std::vector<BasicBlock *> SubLoopBlocksFirst;
  std::vector<BasicBlock *> SubLoopBlocksLast;
  std::vector<BasicBlock *> AftBlocksFirst;
  std::vector<BasicBlock *> AftBlocksLast;
  // Index 0 holds the original (iteration-0) blocks; each unrolled iteration
  // appends its cloned counterparts below.
  ForeBlocksFirst.push_back(Header);
  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
  AftBlocksFirst.push_back(SubLoop->getExitBlock());
  AftBlocksLast.push_back(L->getExitingBlock());
  // Maps Blocks[0] -> Blocks[It]
  ValueToValueMapTy LastValueMap;

  // Move any instructions from fore phi operands from AftBlocks into Fore.
  moveHeaderPhiOperandsToForeBlocks(
      Header, LatchBlock, ForeBlocksLast[0]->getTerminator(), AftBlocks);

  // The current on-the-fly SSA update requires blocks to be processed in
  // reverse postorder so that LastValueMap contains the correct value at each
  // exit.
  LoopBlocksDFS DFS(L);
  DFS.perform(LI);
  // Stash the DFS iterators before adding blocks to the loop.
  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();

  // For sample-based profiling, fold the unroll factor into each debug
  // location's duplication factor so samples can still be attributed across
  // the Count duplicated copies of the body.
  if (Header->getParent()->isDebugInfoForProfiling())
    for (BasicBlock *BB : L->getBlocks())
      for (Instruction &I : *BB)
        if (!isa<DbgInfoIntrinsic>(&I))
          if (const DILocation *DIL = I.getDebugLoc()) {
            auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count);
            if (NewDIL)
              I.setDebugLoc(NewDIL.getValue());
            else
              LLVM_DEBUG(dbgs()
                         << "Failed to create new discriminator: "
                         << DIL->getFilename() << " Line: " << DIL->getLine());
          }

  // Copy all blocks
  for (unsigned It = 1; It != Count; ++It) {
    SmallVector<BasicBlock *, 8> NewBlocks;
    // Maps Blocks[It] -> Blocks[It-1]
    DenseMap<Value *, Value *> PrevItValueMap;
    SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
    NewLoops[L] = L;
    NewLoops[SubLoop] = SubLoop;

    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
      ValueToValueMapTy VMap;
      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
      Header->getParent()->getBasicBlockList().push_back(New);

      // Tell LI about New.
      addClonedBlockToLoopInfo(*BB, New, LI, NewLoops);

      // Record the clone in the per-iteration First/Last tracking vectors.
      if (ForeBlocks.count(*BB)) {
        if (*BB == ForeBlocksFirst[0])
          ForeBlocksFirst.push_back(New);
        if (*BB == ForeBlocksLast[0])
          ForeBlocksLast.push_back(New);
      } else if (SubLoopBlocks.count(*BB)) {
        if (*BB == SubLoopBlocksFirst[0])
          SubLoopBlocksFirst.push_back(New);
        if (*BB == SubLoopBlocksLast[0])
          SubLoopBlocksLast.push_back(New);
      } else if (AftBlocks.count(*BB)) {
        if (*BB == AftBlocksFirst[0])
          AftBlocksFirst.push_back(New);
        if (*BB == AftBlocksLast[0])
          AftBlocksLast.push_back(New);
      } else {
        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
      }

      // Update our running maps of newest clones
      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
      LastValueMap[*BB] = New;
      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
           VI != VE; ++VI) {
        PrevItValueMap[VI->second] =
            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
        LastValueMap[VI->first] = VI->second;
      }

      NewBlocks.push_back(New);

      // Update DomTree:
      if (*BB == ForeBlocksFirst[0])
        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
      else if (*BB == SubLoopBlocksFirst[0])
        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
      else if (*BB == AftBlocksFirst[0])
        DT->addNewBlock(New, AftBlocksLast[It - 1]);
      else {
        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
        // structure.
        auto BBDomNode = DT->getNode(*BB);
        auto BBIDom = BBDomNode->getIDom();
        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
        assert(OriginalBBIDom);
        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
        DT->addNewBlock(
            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
      }
    }

    // Remap all instructions in the most recent iteration
    remapInstructionsInBlocks(NewBlocks, LastValueMap);
    for (BasicBlock *NewBlock : NewBlocks) {
      for (Instruction &I : *NewBlock) {
        // Cloned assumptions must be re-registered with the cache.
        if (auto *II = dyn_cast<IntrinsicInst>(&I))
          if (II->getIntrinsicID() == Intrinsic::assume)
            AC->registerAssumption(II);
      }
    }

    // Alter the ForeBlocks phi's, pointing them at the latest version of the
    // value from the previous iteration's phis
    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
      assert(OldValue && "should have incoming edge from Aft[It]");
      Value *NewValue = OldValue;
      if (Value *PrevValue = PrevItValueMap[OldValue])
        NewValue = PrevValue;

      assert(Phi.getNumOperands() == 2);
      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
      Phi.setIncomingValue(0, NewValue);
      Phi.removeIncomingValue(1);
    }
  }

  // Now that all the basic blocks for the unrolled iterations are in place,
  // finish up connecting the blocks and phi nodes. At this point LastValueMap
  // is the last unrolled iterations values.

  // Update Phis in BB from OldBB to point to NewBB and use the latest value
  // from LastValueMap
  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
                                     BasicBlock *NewBB,
                                     ValueToValueMapTy &LastValueMap) {
    for (PHINode &Phi : BB->phis()) {
      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
        if (Phi.getIncomingBlock(b) == OldBB) {
          Value *OldValue = Phi.getIncomingValue(b);
          if (Value *LastValue = LastValueMap[OldValue])
            Phi.setIncomingValue(b, LastValue);
          Phi.setIncomingBlock(b, NewBB);
          break;
        }
      }
    }
  };
  // Move all the phis from Src into Dest
  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
    Instruction *insertPoint = Dest->getFirstNonPHI();
    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
      Phi->moveBefore(insertPoint);
  };

  // Update the PHI values outside the loop to point to the last block
  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
                           LastValueMap);

  // Update ForeBlocks successors and phi nodes
  BranchInst *ForeTerm =
      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
  assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
  ForeTerm->setSuccessor(0, SubLoopBlocksFirst[0]);

  if (CompletelyUnroll) {
    // No loop left; header phis collapse to their preheader values.
    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
      Phi->getParent()->getInstList().erase(Phi);
    }
  } else {
    // Update the PHI values to point to the last aft block
    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
                             AftBlocksLast.back(), LastValueMap);
  }

  for (unsigned It = 1; It != Count; It++) {
    // Remap ForeBlock successors from previous iteration to this
    BranchInst *ForeTerm =
        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
    assert(ForeTerm->getNumSuccessors() == 1 && "Expecting one successor");
    ForeTerm->setSuccessor(0, ForeBlocksFirst[It]);
  }

  // Subloop successors and phis
  BranchInst *SubTerm =
      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
  SubLoopBlocksFirst[0]->replacePhiUsesWith(ForeBlocksLast[0],
                                            ForeBlocksLast.back());
  SubLoopBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                            SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *SubTerm =
        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
    SubTerm->eraseFromParent();

    SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It],
                                               ForeBlocksLast.back());
    SubLoopBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                               SubLoopBlocksLast.back());
    // The jammed copies share a single header; hoist the phis there.
    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
  }

  // Aft blocks successors and phis
  BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
  if (CompletelyUnroll) {
    BranchInst::Create(LoopExit, AftTerm);
    AftTerm->eraseFromParent();
  } else {
    AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
    assert(AftTerm->getSuccessor(ContinueOnTrue) == LoopExit &&
           "Expecting the ContinueOnTrue successor of AftTerm to be LoopExit");
  }
  AftBlocksFirst[0]->replacePhiUsesWith(SubLoopBlocksLast[0],
                                        SubLoopBlocksLast.back());

  for (unsigned It = 1; It != Count; It++) {
    // Replace the conditional branch of the previous iteration subloop with an
    // unconditional one to this one
    BranchInst *AftTerm =
        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
    BranchInst::Create(AftBlocksFirst[It], AftTerm);
    AftTerm->eraseFromParent();

    AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It],
                                           SubLoopBlocksLast.back());
    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
  }

  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
  // new ones required.
  if (Count != 1) {
    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
                           SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
                           SubLoopBlocksLast[0], AftBlocksFirst[0]);

    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
    DTU.applyUpdatesPermissive(DTUpdates);
  }

  // Merge adjacent basic blocks, if possible.
  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());

  MergeBlockSuccessorsIntoGivenBlocks(MergeBlocks, L, &DTU, LI);

  // Apply updates to the DomTree.
  DT = &DTU.getDomTree();

  // At this point, the code is well formed.  We now do a quick sweep over the
  // inserted code, doing constant propagation and dead code elimination as we
  // go.
  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC, TTI);
  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC,
                          TTI);

  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
  ++NumUnrolledAndJammed;

  // Update LoopInfo if the loop is completely removed.
  if (CompletelyUnroll)
    LI->erase(L);

#ifndef NDEBUG
  // We shouldn't have done anything to break loop simplify form or LCSSA.
  Loop *OutestLoop = SubLoop->getParentLoop()
                         ? SubLoop->getParentLoop()->getParentLoop()
                               ? SubLoop->getParentLoop()->getParentLoop()
                               : SubLoop->getParentLoop()
                         : SubLoop;
  assert(DT->verify());
  LI->verify(*DT);
  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
  if (!CompletelyUnroll)
    assert(L->isLoopSimplifyForm());
  assert(SubLoop->isLoopSimplifyForm());
  SE->verify();
#endif

  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
                          : LoopUnrollResult::PartiallyUnrolled;
}
627 
628 static bool getLoadsAndStores(BasicBlockSet &Blocks,
629                               SmallVector<Instruction *, 4> &MemInstr) {
630   // Scan the BBs and collect legal loads and stores.
631   // Returns false if non-simple loads/stores are found.
632   for (BasicBlock *BB : Blocks) {
633     for (Instruction &I : *BB) {
634       if (auto *Ld = dyn_cast<LoadInst>(&I)) {
635         if (!Ld->isSimple())
636           return false;
637         MemInstr.push_back(&I);
638       } else if (auto *St = dyn_cast<StoreInst>(&I)) {
639         if (!St->isSimple())
640           return false;
641         MemInstr.push_back(&I);
642       } else if (I.mayReadOrWriteMemory()) {
643         return false;
644       }
645     }
646   }
647   return true;
648 }
649 
650 static bool preservesForwardDependence(Instruction *Src, Instruction *Dst,
651                                        unsigned UnrollLevel, unsigned JamLevel,
652                                        bool Sequentialized, Dependence *D) {
653   // UnrollLevel might carry the dependency Src --> Dst
654   // Does a different loop after unrolling?
655   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
656        ++CurLoopDepth) {
657     auto JammedDir = D->getDirection(CurLoopDepth);
658     if (JammedDir == Dependence::DVEntry::LT)
659       return true;
660 
661     if (JammedDir & Dependence::DVEntry::GT)
662       return false;
663   }
664 
665   return true;
666 }
667 
668 static bool preservesBackwardDependence(Instruction *Src, Instruction *Dst,
669                                         unsigned UnrollLevel, unsigned JamLevel,
670                                         bool Sequentialized, Dependence *D) {
671   // UnrollLevel might carry the dependency Dst --> Src
672   for (unsigned CurLoopDepth = UnrollLevel + 1; CurLoopDepth <= JamLevel;
673        ++CurLoopDepth) {
674     auto JammedDir = D->getDirection(CurLoopDepth);
675     if (JammedDir == Dependence::DVEntry::GT)
676       return true;
677 
678     if (JammedDir & Dependence::DVEntry::LT)
679       return false;
680   }
681 
682   // Backward dependencies are only preserved if not interleaved.
683   return Sequentialized;
684 }
685 
686 // Check whether it is semantically safe Src and Dst considering any potential
687 // dependency between them.
688 //
689 // @param UnrollLevel The level of the loop being unrolled
690 // @param JamLevel    The level of the loop being jammed; if Src and Dst are on
691 // different levels, the outermost common loop counts as jammed level
692 //
693 // @return true if is safe and false if there is a dependency violation.
694 static bool checkDependency(Instruction *Src, Instruction *Dst,
695                             unsigned UnrollLevel, unsigned JamLevel,
696                             bool Sequentialized, DependenceInfo &DI) {
697   assert(UnrollLevel <= JamLevel &&
698          "Expecting JamLevel to be at least UnrollLevel");
699 
700   if (Src == Dst)
701     return true;
702   // Ignore Input dependencies.
703   if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
704     return true;
705 
706   // Check whether unroll-and-jam may violate a dependency.
707   // By construction, every dependency will be lexicographically non-negative
708   // (if it was, it would violate the current execution order), such as
709   //   (0,0,>,*,*)
710   // Unroll-and-jam changes the GT execution of two executions to the same
711   // iteration of the chosen unroll level. That is, a GT dependence becomes a GE
712   // dependence (or EQ, if we fully unrolled the loop) at the loop's position:
713   //   (0,0,>=,*,*)
714   // Now, the dependency is not necessarily non-negative anymore, i.e.
715   // unroll-and-jam may violate correctness.
716   std::unique_ptr<Dependence> D = DI.depends(Src, Dst, true);
717   if (!D)
718     return true;
719   assert(D->isOrdered() && "Expected an output, flow or anti dep.");
720 
721   if (D->isConfused()) {
722     LLVM_DEBUG(dbgs() << "  Confused dependency between:\n"
723                       << "  " << *Src << "\n"
724                       << "  " << *Dst << "\n");
725     return false;
726   }
727 
728   // If outer levels (levels enclosing the loop being unroll-and-jammed) have a
729   // non-equal direction, then the locations accessed in the inner levels cannot
730   // overlap in memory. We assumes the indexes never overlap into neighboring
731   // dimensions.
732   for (unsigned CurLoopDepth = 1; CurLoopDepth < UnrollLevel; ++CurLoopDepth)
733     if (!(D->getDirection(CurLoopDepth) & Dependence::DVEntry::EQ))
734       return true;
735 
736   auto UnrollDirection = D->getDirection(UnrollLevel);
737 
738   // If the distance carried by the unrolled loop is 0, then after unrolling
739   // that distance will become non-zero resulting in non-overlapping accesses in
740   // the inner loops.
741   if (UnrollDirection == Dependence::DVEntry::EQ)
742     return true;
743 
744   if (UnrollDirection & Dependence::DVEntry::LT &&
745       !preservesForwardDependence(Src, Dst, UnrollLevel, JamLevel,
746                                   Sequentialized, D.get()))
747     return false;
748 
749   if (UnrollDirection & Dependence::DVEntry::GT &&
750       !preservesBackwardDependence(Src, Dst, UnrollLevel, JamLevel,
751                                    Sequentialized, D.get()))
752     return false;
753 
754   return true;
755 }
756 
757 static bool
758 checkDependencies(Loop &Root, const BasicBlockSet &SubLoopBlocks,
759                   const DenseMap<Loop *, BasicBlockSet> &ForeBlocksMap,
760                   const DenseMap<Loop *, BasicBlockSet> &AftBlocksMap,
761                   DependenceInfo &DI, LoopInfo &LI) {
762   SmallVector<BasicBlockSet, 8> AllBlocks;
763   for (Loop *L : Root.getLoopsInPreorder())
764     if (ForeBlocksMap.find(L) != ForeBlocksMap.end())
765       AllBlocks.push_back(ForeBlocksMap.lookup(L));
766   AllBlocks.push_back(SubLoopBlocks);
767   for (Loop *L : Root.getLoopsInPreorder())
768     if (AftBlocksMap.find(L) != AftBlocksMap.end())
769       AllBlocks.push_back(AftBlocksMap.lookup(L));
770 
771   unsigned LoopDepth = Root.getLoopDepth();
772   SmallVector<Instruction *, 4> EarlierLoadsAndStores;
773   SmallVector<Instruction *, 4> CurrentLoadsAndStores;
774   for (BasicBlockSet &Blocks : AllBlocks) {
775     CurrentLoadsAndStores.clear();
776     if (!getLoadsAndStores(Blocks, CurrentLoadsAndStores))
777       return false;
778 
779     Loop *CurLoop = LI.getLoopFor((*Blocks.begin())->front().getParent());
780     unsigned CurLoopDepth = CurLoop->getLoopDepth();
781 
782     for (auto *Earlier : EarlierLoadsAndStores) {
783       Loop *EarlierLoop = LI.getLoopFor(Earlier->getParent());
784       unsigned EarlierDepth = EarlierLoop->getLoopDepth();
785       unsigned CommonLoopDepth = std::min(EarlierDepth, CurLoopDepth);
786       for (auto *Later : CurrentLoadsAndStores) {
787         if (!checkDependency(Earlier, Later, LoopDepth, CommonLoopDepth, false,
788                              DI))
789           return false;
790       }
791     }
792 
793     size_t NumInsts = CurrentLoadsAndStores.size();
794     for (size_t I = 0; I < NumInsts; ++I) {
795       for (size_t J = I; J < NumInsts; ++J) {
796         if (!checkDependency(CurrentLoadsAndStores[I], CurrentLoadsAndStores[J],
797                              LoopDepth, CurLoopDepth, true, DI))
798           return false;
799       }
800     }
801 
802     EarlierLoadsAndStores.append(CurrentLoadsAndStores.begin(),
803                                  CurrentLoadsAndStores.end());
804   }
805   return true;
806 }
807 
808 static bool isEligibleLoopForm(const Loop &Root) {
809   // Root must have a child.
810   if (Root.getSubLoops().size() != 1)
811     return false;
812 
813   const Loop *L = &Root;
814   do {
815     // All loops in Root need to be in simplify and rotated form.
816     if (!L->isLoopSimplifyForm())
817       return false;
818 
819     if (!L->isRotatedForm())
820       return false;
821 
822     if (L->getHeader()->hasAddressTaken()) {
823       LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n");
824       return false;
825     }
826 
827     unsigned SubLoopsSize = L->getSubLoops().size();
828     if (SubLoopsSize == 0)
829       return true;
830 
831     // Only one child is allowed.
832     if (SubLoopsSize != 1)
833       return false;
834 
835     L = L->getSubLoops()[0];
836   } while (L);
837 
838   return true;
839 }
840 
841 static Loop *getInnerMostLoop(Loop *L) {
842   while (!L->getSubLoops().empty())
843     L = L->getSubLoops()[0];
844   return L;
845 }
846 
847 bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
848                                 DependenceInfo &DI, LoopInfo &LI) {
849   if (!isEligibleLoopForm(*L)) {
850     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Ineligible loop form\n");
851     return false;
852   }
853 
854   /* We currently handle outer loops like this:
855         |
856     ForeFirst    <------\   }
857      Blocks             |   } ForeBlocks of L
858     ForeLast            |   }
859         |               |
860        ...              |
861         |               |
862     ForeFirst    <----\ |   }
863      Blocks           | |   } ForeBlocks of a inner loop of L
864     ForeLast          | |   }
865         |             | |
866     JamLoopFirst  <\  | |   }
867      Blocks        |  | |   } JamLoopBlocks of the innermost loop
868     JamLoopLast   -/  | |   }
869         |             | |
870     AftFirst          | |   }
871      Blocks           | |   } AftBlocks of a inner loop of L
872     AftLast     ------/ |   }
873         |               |
874        ...              |
875         |               |
876     AftFirst            |   }
877      Blocks             |   } AftBlocks of L
878     AftLast     --------/   }
879         |
880 
881     There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
882     and AftBlocks, providing that there is one edge from Fores to SubLoops,
883     one edge from SubLoops to Afts and a single outer loop exit (from Afts).
884     In practice we currently limit Aft blocks to a single block, and limit
885     things further in the profitablility checks of the unroll and jam pass.
886 
887     Because of the way we rearrange basic blocks, we also require that
888     the Fore blocks of L on all unrolled iterations are safe to move before the
889     blocks of the direct child of L of all iterations. So we require that the
890     phi node looping operands of ForeHeader can be moved to at least the end of
891     ForeEnd, so that we can arrange cloned Fore Blocks before the subloop and
892     match up Phi's correctly.
893 
894     i.e. The old order of blocks used to be
895            (F1)1 (F2)1 J1_1 J1_2 (A2)1 (A1)1 (F1)2 (F2)2 J2_1 J2_2 (A2)2 (A1)2.
896          It needs to be safe to transform this to
897            (F1)1 (F1)2 (F2)1 (F2)2 J1_1 J1_2 J2_1 J2_2 (A2)1 (A2)2 (A1)1 (A1)2.
898 
899     There are then a number of checks along the lines of no calls, no
900     exceptions, inner loop IV is consistent, etc. Note that for loops requiring
901     runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
902     UnrollAndJamLoop if the trip count cannot be easily calculated.
903   */
904 
905   // Split blocks into Fore/SubLoop/Aft based on dominators
906   Loop *JamLoop = getInnerMostLoop(L);
907   BasicBlockSet SubLoopBlocks;
908   DenseMap<Loop *, BasicBlockSet> ForeBlocksMap;
909   DenseMap<Loop *, BasicBlockSet> AftBlocksMap;
910   if (!partitionOuterLoopBlocks(*L, *JamLoop, SubLoopBlocks, ForeBlocksMap,
911                                 AftBlocksMap, DT)) {
912     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n");
913     return false;
914   }
915 
916   // Aft blocks may need to move instructions to fore blocks, which becomes more
917   // difficult if there are multiple (potentially conditionally executed)
918   // blocks. For now we just exclude loops with multiple aft blocks.
919   if (AftBlocksMap[L].size() != 1) {
920     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle "
921                          "multiple blocks after the loop\n");
922     return false;
923   }
924 
925   // Check inner loop backedge count is consistent on all iterations of the
926   // outer loop
927   if (any_of(L->getLoopsInPreorder(), [&SE](Loop *SubLoop) {
928         return !hasIterationCountInvariantInParent(SubLoop, SE);
929       })) {
930     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is "
931                          "not consistent on each iteration\n");
932     return false;
933   }
934 
935   // Check the loop safety info for exceptions.
936   SimpleLoopSafetyInfo LSI;
937   LSI.computeLoopSafetyInfo(L);
938   if (LSI.anyBlockMayThrow()) {
939     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n");
940     return false;
941   }
942 
943   // We've ruled out the easy stuff and now need to check that there are no
944   // interdependencies which may prevent us from moving the:
945   //  ForeBlocks before Subloop and AftBlocks.
946   //  Subloop before AftBlocks.
947   //  ForeBlock phi operands before the subloop
948 
949   // Make sure we can move all instructions we need to before the subloop
950   BasicBlock *Header = L->getHeader();
951   BasicBlock *Latch = L->getLoopLatch();
952   BasicBlockSet AftBlocks = AftBlocksMap[L];
953   Loop *SubLoop = L->getSubLoops()[0];
954   if (!processHeaderPhiOperands(
955           Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
956             if (SubLoop->contains(I->getParent()))
957               return false;
958             if (AftBlocks.count(I->getParent())) {
959               // If we hit a phi node in afts we know we are done (probably
960               // LCSSA)
961               if (isa<PHINode>(I))
962                 return false;
963               // Can't move instructions with side effects or memory
964               // reads/writes
965               if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
966                 return false;
967             }
968             // Keep going
969             return true;
970           })) {
971     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required "
972                          "instructions after subloop to before it\n");
973     return false;
974   }
975 
976   // Check for memory dependencies which prohibit the unrolling we are doing.
977   // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
978   // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
979   if (!checkDependencies(*L, SubLoopBlocks, ForeBlocksMap, AftBlocksMap, DI,
980                          LI)) {
981     LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n");
982     return false;
983   }
984 
985   return true;
986 }
987