1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/ADT/Sequence.h"
11 #include "llvm/Analysis/LoopAccessAnalysis.h"
12 #include "llvm/Analysis/LoopAnalysisManager.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/LoopIterator.h"
15 #include "llvm/Analysis/LoopPass.h"
16 #include "llvm/Analysis/MemorySSA.h"
17 #include "llvm/Analysis/MemorySSAUpdater.h"
18 #include "llvm/Analysis/ScalarEvolution.h"
19 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
22 #include "llvm/Transforms/Utils/Cloning.h"
23 #include "llvm/Transforms/Utils/LoopSimplify.h"
24 #include "llvm/Transforms/Utils/LoopUtils.h"
25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
26 
27 #define DEBUG_TYPE "loop-bound-split"
28 
29 namespace llvm {
30 
31 using namespace PatternMatch;
32 
33 namespace {
34 struct ConditionInfo {
35   /// Branch instruction with this condition
36   BranchInst *BI;
37   /// ICmp instruction with this condition
38   ICmpInst *ICmp;
39   /// Preciate info
40   ICmpInst::Predicate Pred;
41   /// AddRec llvm value
42   Value *AddRecValue;
43   /// Bound llvm value
44   Value *BoundValue;
45   /// AddRec SCEV
46   const SCEV *AddRecSCEV;
47   /// Bound SCEV
48   const SCEV *BoundSCEV;
49 
50   ConditionInfo()
51       : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE),
52         AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
53         BoundSCEV(nullptr) {}
54 };
55 } // namespace
56 
57 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
58                         ConditionInfo &Cond) {
59   Cond.ICmp = ICmp;
60   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
61                          m_Value(Cond.BoundValue)))) {
62     Cond.AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
63     Cond.BoundSCEV = SE.getSCEV(Cond.BoundValue);
64     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
65     if (isa<SCEVAddRecExpr>(Cond.BoundSCEV) &&
66         !isa<SCEVAddRecExpr>(Cond.AddRecSCEV)) {
67       std::swap(Cond.AddRecValue, Cond.BoundValue);
68       std::swap(Cond.AddRecSCEV, Cond.BoundSCEV);
69       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
70     }
71   }
72 }
73 
74 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
75                                 ConditionInfo &Cond, bool IsExitCond) {
76   if (IsExitCond) {
77     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
78     if (isa<SCEVCouldNotCompute>(ExitCount))
79       return false;
80 
81     Cond.BoundSCEV = ExitCount;
82     return true;
83   }
84 
85   // For non-exit condtion, if pred is LT, keep existing bound.
86   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
87     return true;
88 
89   // For non-exit condition, if pre is LE, try to convert it to LT.
90   //      Range                 Range
91   // AddRec <= Bound  -->  AddRec < Bound + 1
92   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
93     return false;
94 
95   if (IntegerType *BoundSCEVIntType =
96           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
97     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
98     APInt Max = ICmpInst::isSigned(Cond.Pred)
99                     ? APInt::getSignedMaxValue(BitWidth)
100                     : APInt::getMaxValue(BitWidth);
101     const SCEV *MaxSCEV = SE.getConstant(Max);
102     // Check Bound < INT_MAX
103     ICmpInst::Predicate Pred =
104         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
105     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
106       const SCEV *BoundPlusOneSCEV =
107           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
108       Cond.BoundSCEV = BoundPlusOneSCEV;
109       Cond.Pred = Pred;
110       return true;
111     }
112   }
113 
114   // ToDo: Support ICMP_NE/EQ.
115 
116   return false;
117 }
118 
119 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
120                                     ICmpInst *ICmp, ConditionInfo &Cond,
121                                     bool IsExitCond) {
122   analyzeICmp(SE, ICmp, Cond);
123 
124   // The BoundSCEV should be evaluated at loop entry.
125   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
126     return false;
127 
128   const SCEVAddRecExpr *AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Cond.AddRecSCEV);
129   // Allowed AddRec as induction variable.
130   if (!AddRecSCEV)
131     return false;
132 
133   if (!AddRecSCEV->isAffine())
134     return false;
135 
136   const SCEV *StepRecSCEV = AddRecSCEV->getStepRecurrence(SE);
137   // Allowed constant step.
138   if (!isa<SCEVConstant>(StepRecSCEV))
139     return false;
140 
141   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
142   // Allowed positive step for now.
143   // TODO: Support negative step.
144   if (StepCI->isNegative() || StepCI->isZero())
145     return false;
146 
147   // Calculate upper bound.
148   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
149     return false;
150 
151   return true;
152 }
153 
154 static bool isProcessableCondBI(const ScalarEvolution &SE,
155                                 const BranchInst *BI) {
156   BasicBlock *TrueSucc = nullptr;
157   BasicBlock *FalseSucc = nullptr;
158   ICmpInst::Predicate Pred;
159   Value *LHS, *RHS;
160   if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
161                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
162     return false;
163 
164   if (!SE.isSCEVable(LHS->getType()))
165     return false;
166   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
167 
168   if (TrueSucc == FalseSucc)
169     return false;
170 
171   return true;
172 }
173 
174 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
175                               ScalarEvolution &SE, ConditionInfo &Cond) {
176   // Skip function with optsize.
177   if (L.getHeader()->getParent()->hasOptSize())
178     return false;
179 
180   // Split only innermost loop.
181   if (!L.isInnermost())
182     return false;
183 
184   // Check loop is in simplified form.
185   if (!L.isLoopSimplifyForm())
186     return false;
187 
188   // Check loop is in LCSSA form.
189   if (!L.isLCSSAForm(DT))
190     return false;
191 
192   // Skip loop that cannot be cloned.
193   if (!L.isSafeToClone())
194     return false;
195 
196   BasicBlock *ExitingBB = L.getExitingBlock();
197   // Assumed only one exiting block.
198   if (!ExitingBB)
199     return false;
200 
201   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
202   if (!ExitingBI)
203     return false;
204 
205   // Allowed only conditional branch with ICmp.
206   if (!isProcessableCondBI(SE, ExitingBI))
207     return false;
208 
209   // Check the condition is processable.
210   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
211   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
212     return false;
213 
214   Cond.BI = ExitingBI;
215   return true;
216 }
217 
218 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
219   // If the conditional branch splits a loop into two halves, we could
220   // generally say it is profitable.
221   //
222   // ToDo: Add more profitable cases here.
223 
224   // Check this branch causes diamond CFG.
225   BasicBlock *Succ0 = BI->getSuccessor(0);
226   BasicBlock *Succ1 = BI->getSuccessor(1);
227 
228   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
229   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
230   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
231     return false;
232 
233   // ToDo: Calculate each successor's instruction cost.
234 
235   return true;
236 }
237 
238 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
239                                       ConditionInfo &ExitingCond,
240                                       ConditionInfo &SplitCandidateCond) {
241   for (auto *BB : L.blocks()) {
242     // Skip condition of backedge.
243     if (L.getLoopLatch() == BB)
244       continue;
245 
246     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
247     if (!BI)
248       continue;
249 
250     // Check conditional branch with ICmp.
251     if (!isProcessableCondBI(SE, BI))
252       continue;
253 
254     // Skip loop invariant condition.
255     if (L.isLoopInvariant(BI->getCondition()))
256       continue;
257 
258     // Check the condition is processable.
259     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
260     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
261                                  /*IsExitCond*/ false))
262       continue;
263 
264     if (ExitingCond.BoundSCEV->getType() !=
265         SplitCandidateCond.BoundSCEV->getType())
266       continue;
267 
268     // After transformation, we assume the split condition of the pre-loop is
269     // always true. In order to guarantee it, we need to check the start value
270     // of the split cond AddRec satisfies the split condition.
271     const SCEV *SplitAddRecStartSCEV =
272         cast<SCEVAddRecExpr>(SplitCandidateCond.AddRecSCEV)->getStart();
273     if (!SE.isKnownPredicate(SplitCandidateCond.Pred, SplitAddRecStartSCEV,
274                              SplitCandidateCond.BoundSCEV))
275       continue;
276 
277     SplitCandidateCond.BI = BI;
278     return BI;
279   }
280 
281   return nullptr;
282 }
283 
284 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
285                            ScalarEvolution &SE, LPMUpdater &U) {
286   ConditionInfo SplitCandidateCond;
287   ConditionInfo ExitingCond;
288 
289   // Check we can split this loop's bound.
290   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
291     return false;
292 
293   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
294     return false;
295 
296   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
297     return false;
298 
299   // Now, we have a split candidate. Let's build a form as below.
300   //    +--------------------+
301   //    |     preheader      |
302   //    |  set up newbound   |
303   //    +--------------------+
304   //             |     /----------------\
305   //    +--------v----v------+          |
306   //    |      header        |---\      |
307   //    | with true condition|   |      |
308   //    +--------------------+   |      |
309   //             |               |      |
310   //    +--------v-----------+   |      |
311   //    |     if.then.BB     |   |      |
312   //    +--------------------+   |      |
313   //             |               |      |
314   //    +--------v-----------<---/      |
315   //    |       latch        >----------/
316   //    |   with newbound    |
317   //    +--------------------+
318   //             |
319   //    +--------v-----------+
320   //    |     preheader2     |--------------\
321   //    | if (AddRec i !=    |              |
322   //    |     org bound)     |              |
323   //    +--------------------+              |
324   //             |     /----------------\   |
325   //    +--------v----v------+          |   |
326   //    |      header2       |---\      |   |
327   //    | conditional branch |   |      |   |
328   //    |with false condition|   |      |   |
329   //    +--------------------+   |      |   |
330   //             |               |      |   |
331   //    +--------v-----------+   |      |   |
332   //    |    if.then.BB2     |   |      |   |
333   //    +--------------------+   |      |   |
334   //             |               |      |   |
335   //    +--------v-----------<---/      |   |
336   //    |       latch2       >----------/   |
337   //    |   with org bound   |              |
338   //    +--------v-----------+              |
339   //             |                          |
340   //             |  +---------------+       |
341   //             +-->     exit      <-------/
342   //                +---------------+
343 
344   // Let's create post loop.
345   SmallVector<BasicBlock *, 8> PostLoopBlocks;
346   Loop *PostLoop;
347   ValueToValueMapTy VMap;
348   BasicBlock *PreHeader = L.getLoopPreheader();
349   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
350   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
351                                     ".split", &LI, &DT, PostLoopBlocks);
352   remapInstructionsInBlocks(PostLoopBlocks, VMap);
353 
354   // Add conditional branch to check we can skip post-loop in its preheader.
355   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
356   IRBuilder<> Builder(PostLoopPreHeader);
357   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
358   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
359   Value *Cond =
360       Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue);
361   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
362   OrigBI->eraseFromParent();
363 
364   // Create new loop bound and add it into preheader of pre-loop.
365   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
366   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
367   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
368                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
369                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
370 
371   SCEVExpander Expander(
372       SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
373   Instruction *InsertPt = SplitLoopPH->getTerminator();
374   Value *NewBoundValue =
375       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
376   NewBoundValue->setName("new.bound");
377 
378   // Replace exiting bound value of pre-loop NewBound.
379   ExitingCond.ICmp->setOperand(1, NewBoundValue);
380 
381   // Replace IV's start value of post-loop by NewBound.
382   for (PHINode &PN : L.getHeader()->phis()) {
383     // Find PHI with exiting condition from pre-loop.
384     if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) {
385       for (Value *Op : PN.incoming_values()) {
386         if (Op == ExitingCond.AddRecValue) {
387           // Find cloned PHI for post-loop.
388           PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
389           PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader,
390                                                NewBoundValue);
391         }
392       }
393     }
394   }
395 
396   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
397   LLVMContext &Context = PreHeader->getContext();
398   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
399 
400   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
401   BranchInst *ClonedSplitCandidateBI =
402       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
403   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
404 
405   // Replace exit branch target of pre-loop by post-loop's preheader.
406   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
407     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
408   else
409     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
410 
411   // Update phi node in exit block of post-loop.
412   for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
413     for (auto i : seq<int>(0, PN.getNumOperands())) {
414       // Check incoming block is pre-loop's exiting block.
415       if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
416         // Replace pre-loop's exiting block by post-loop's preheader.
417         PN.setIncomingBlock(i, PostLoopPreHeader);
418         // Add a new incoming value with post-loop's exiting block.
419         PN.addIncoming(VMap[PN.getIncomingValue(i)],
420                        PostLoop->getExitingBlock());
421       }
422     }
423   }
424 
425   // Update dominator tree.
426   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
427   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
428 
429   // Invalidate cached SE information.
430   SE.forgetLoop(&L);
431 
432   // Canonicalize loops.
433   // TODO: Try to update LCSSA information according to above change.
434   formLCSSA(L, DT, &LI, &SE);
435   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
436   formLCSSA(*PostLoop, DT, &LI, &SE);
437   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
438 
439   // Add new post-loop to loop pass manager.
440   U.addSiblingLoops(PostLoop);
441 
442   return true;
443 }
444 
445 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
446                                           LoopStandardAnalysisResults &AR,
447                                           LPMUpdater &U) {
448   Function &F = *L.getHeader()->getParent();
449   (void)F;
450 
451   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
452                     << "\n");
453 
454   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
455     return PreservedAnalyses::all();
456 
457   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
458   AR.LI.verify(AR.DT);
459 
460   return getLoopPassPreservedAnalyses();
461 }
462 
463 } // end namespace llvm
464