1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/ADT/Sequence.h"
11 #include "llvm/Analysis/LoopAnalysisManager.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/MemorySSAUpdater.h"
14 #include "llvm/Analysis/ScalarEvolution.h"
15 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
16 #include "llvm/IR/PatternMatch.h"
17 #include "llvm/Transforms/Scalar/LoopPassManager.h"
18 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
19 #include "llvm/Transforms/Utils/Cloning.h"
20 #include "llvm/Transforms/Utils/LoopSimplify.h"
21 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
22 
23 #define DEBUG_TYPE "loop-bound-split"
24 
25 namespace llvm {
26 
27 using namespace PatternMatch;
28 
29 namespace {
30 struct ConditionInfo {
31   /// Branch instruction with this condition
32   BranchInst *BI = nullptr;
33   /// ICmp instruction with this condition
34   ICmpInst *ICmp = nullptr;
35   /// Preciate info
36   ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
37   /// AddRec llvm value
38   Value *AddRecValue = nullptr;
39   /// Non PHI AddRec llvm value
40   Value *NonPHIAddRecValue;
41   /// Bound llvm value
42   Value *BoundValue = nullptr;
43   /// AddRec SCEV
44   const SCEVAddRecExpr *AddRecSCEV = nullptr;
45   /// Bound SCEV
46   const SCEV *BoundSCEV = nullptr;
47 
48   ConditionInfo() = default;
49 };
50 } // namespace
51 
52 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
53                         ConditionInfo &Cond, const Loop &L) {
54   Cond.ICmp = ICmp;
55   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
56                          m_Value(Cond.BoundValue)))) {
57     const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
58     const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
59     const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
60     const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
61     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
62     if (!LHSAddRecSCEV && RHSAddRecSCEV) {
63       std::swap(Cond.AddRecValue, Cond.BoundValue);
64       std::swap(AddRecSCEV, BoundSCEV);
65       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
66     }
67 
68     Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
69     Cond.BoundSCEV = BoundSCEV;
70     Cond.NonPHIAddRecValue = Cond.AddRecValue;
71 
72     // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
73     // value from backedge.
74     if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) {
75       PHINode *PN = cast<PHINode>(Cond.AddRecValue);
76       Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());
77     }
78   }
79 }
80 
81 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
82                                 ConditionInfo &Cond, bool IsExitCond) {
83   if (IsExitCond) {
84     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
85     if (isa<SCEVCouldNotCompute>(ExitCount))
86       return false;
87 
88     Cond.BoundSCEV = ExitCount;
89     return true;
90   }
91 
92   // For non-exit condtion, if pred is LT, keep existing bound.
93   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
94     return true;
95 
96   // For non-exit condition, if pre is LE, try to convert it to LT.
97   //      Range                 Range
98   // AddRec <= Bound  -->  AddRec < Bound + 1
99   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
100     return false;
101 
102   if (IntegerType *BoundSCEVIntType =
103           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
104     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
105     APInt Max = ICmpInst::isSigned(Cond.Pred)
106                     ? APInt::getSignedMaxValue(BitWidth)
107                     : APInt::getMaxValue(BitWidth);
108     const SCEV *MaxSCEV = SE.getConstant(Max);
109     // Check Bound < INT_MAX
110     ICmpInst::Predicate Pred =
111         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
112     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
113       const SCEV *BoundPlusOneSCEV =
114           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
115       Cond.BoundSCEV = BoundPlusOneSCEV;
116       Cond.Pred = Pred;
117       return true;
118     }
119   }
120 
121   // ToDo: Support ICMP_NE/EQ.
122 
123   return false;
124 }
125 
126 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
127                                     ICmpInst *ICmp, ConditionInfo &Cond,
128                                     bool IsExitCond) {
129   analyzeICmp(SE, ICmp, Cond, L);
130 
131   // The BoundSCEV should be evaluated at loop entry.
132   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
133     return false;
134 
135   // Allowed AddRec as induction variable.
136   if (!Cond.AddRecSCEV)
137     return false;
138 
139   if (!Cond.AddRecSCEV->isAffine())
140     return false;
141 
142   const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
143   // Allowed constant step.
144   if (!isa<SCEVConstant>(StepRecSCEV))
145     return false;
146 
147   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
148   // Allowed positive step for now.
149   // TODO: Support negative step.
150   if (StepCI->isNegative() || StepCI->isZero())
151     return false;
152 
153   // Calculate upper bound.
154   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
155     return false;
156 
157   return true;
158 }
159 
160 static bool isProcessableCondBI(const ScalarEvolution &SE,
161                                 const BranchInst *BI) {
162   BasicBlock *TrueSucc = nullptr;
163   BasicBlock *FalseSucc = nullptr;
164   ICmpInst::Predicate Pred;
165   Value *LHS, *RHS;
166   if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
167                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
168     return false;
169 
170   if (!SE.isSCEVable(LHS->getType()))
171     return false;
172   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
173 
174   if (TrueSucc == FalseSucc)
175     return false;
176 
177   return true;
178 }
179 
180 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
181                               ScalarEvolution &SE, ConditionInfo &Cond) {
182   // Skip function with optsize.
183   if (L.getHeader()->getParent()->hasOptSize())
184     return false;
185 
186   // Split only innermost loop.
187   if (!L.isInnermost())
188     return false;
189 
190   // Check loop is in simplified form.
191   if (!L.isLoopSimplifyForm())
192     return false;
193 
194   // Check loop is in LCSSA form.
195   if (!L.isLCSSAForm(DT))
196     return false;
197 
198   // Skip loop that cannot be cloned.
199   if (!L.isSafeToClone())
200     return false;
201 
202   BasicBlock *ExitingBB = L.getExitingBlock();
203   // Assumed only one exiting block.
204   if (!ExitingBB)
205     return false;
206 
207   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
208   if (!ExitingBI)
209     return false;
210 
211   // Allowed only conditional branch with ICmp.
212   if (!isProcessableCondBI(SE, ExitingBI))
213     return false;
214 
215   // Check the condition is processable.
216   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
217   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
218     return false;
219 
220   Cond.BI = ExitingBI;
221   return true;
222 }
223 
224 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
225   // If the conditional branch splits a loop into two halves, we could
226   // generally say it is profitable.
227   //
228   // ToDo: Add more profitable cases here.
229 
230   // Check this branch causes diamond CFG.
231   BasicBlock *Succ0 = BI->getSuccessor(0);
232   BasicBlock *Succ1 = BI->getSuccessor(1);
233 
234   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
235   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
236   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
237     return false;
238 
239   // ToDo: Calculate each successor's instruction cost.
240 
241   return true;
242 }
243 
244 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
245                                       ConditionInfo &ExitingCond,
246                                       ConditionInfo &SplitCandidateCond) {
247   for (auto *BB : L.blocks()) {
248     // Skip condition of backedge.
249     if (L.getLoopLatch() == BB)
250       continue;
251 
252     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
253     if (!BI)
254       continue;
255 
256     // Check conditional branch with ICmp.
257     if (!isProcessableCondBI(SE, BI))
258       continue;
259 
260     // Skip loop invariant condition.
261     if (L.isLoopInvariant(BI->getCondition()))
262       continue;
263 
264     // Check the condition is processable.
265     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
266     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
267                                  /*IsExitCond*/ false))
268       continue;
269 
270     if (ExitingCond.BoundSCEV->getType() !=
271         SplitCandidateCond.BoundSCEV->getType())
272       continue;
273 
274     // After transformation, we assume the split condition of the pre-loop is
275     // always true. In order to guarantee it, we need to check the start value
276     // of the split cond AddRec satisfies the split condition.
277     if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,
278                                      SplitCandidateCond.AddRecSCEV->getStart(),
279                                      SplitCandidateCond.BoundSCEV))
280       continue;
281 
282     SplitCandidateCond.BI = BI;
283     return BI;
284   }
285 
286   return nullptr;
287 }
288 
289 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
290                            ScalarEvolution &SE, LPMUpdater &U) {
291   ConditionInfo SplitCandidateCond;
292   ConditionInfo ExitingCond;
293 
294   // Check we can split this loop's bound.
295   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
296     return false;
297 
298   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
299     return false;
300 
301   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
302     return false;
303 
304   // Now, we have a split candidate. Let's build a form as below.
305   //    +--------------------+
306   //    |     preheader      |
307   //    |  set up newbound   |
308   //    +--------------------+
309   //             |     /----------------\
310   //    +--------v----v------+          |
311   //    |      header        |---\      |
312   //    | with true condition|   |      |
313   //    +--------------------+   |      |
314   //             |               |      |
315   //    +--------v-----------+   |      |
316   //    |     if.then.BB     |   |      |
317   //    +--------------------+   |      |
318   //             |               |      |
319   //    +--------v-----------<---/      |
320   //    |       latch        >----------/
321   //    |   with newbound    |
322   //    +--------------------+
323   //             |
324   //    +--------v-----------+
325   //    |     preheader2     |--------------\
326   //    | if (AddRec i !=    |              |
327   //    |     org bound)     |              |
328   //    +--------------------+              |
329   //             |     /----------------\   |
330   //    +--------v----v------+          |   |
331   //    |      header2       |---\      |   |
332   //    | conditional branch |   |      |   |
333   //    |with false condition|   |      |   |
334   //    +--------------------+   |      |   |
335   //             |               |      |   |
336   //    +--------v-----------+   |      |   |
337   //    |    if.then.BB2     |   |      |   |
338   //    +--------------------+   |      |   |
339   //             |               |      |   |
340   //    +--------v-----------<---/      |   |
341   //    |       latch2       >----------/   |
342   //    |   with org bound   |              |
343   //    +--------v-----------+              |
344   //             |                          |
345   //             |  +---------------+       |
346   //             +-->     exit      <-------/
347   //                +---------------+
348 
349   // Let's create post loop.
350   SmallVector<BasicBlock *, 8> PostLoopBlocks;
351   Loop *PostLoop;
352   ValueToValueMapTy VMap;
353   BasicBlock *PreHeader = L.getLoopPreheader();
354   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
355   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
356                                     ".split", &LI, &DT, PostLoopBlocks);
357   remapInstructionsInBlocks(PostLoopBlocks, VMap);
358 
359   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
360   IRBuilder<> Builder(&PostLoopPreHeader->front());
361 
362   // Update phi nodes in header of post-loop.
363   bool isExitingLatch =
364       (L.getExitingBlock() == L.getLoopLatch()) ? true : false;
365   Value *ExitingCondLCSSAPhi = nullptr;
366   for (PHINode &PN : L.getHeader()->phis()) {
367     // Create LCSSA phi node in preheader of post-loop.
368     PHINode *LCSSAPhi =
369         Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
370     LCSSAPhi->setDebugLoc(PN.getDebugLoc());
371     // If the exiting block is loop latch, the phi does not have the update at
372     // last iteration. In this case, update lcssa phi with value from backedge.
373     LCSSAPhi->addIncoming(
374         isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,
375         L.getExitingBlock());
376 
377     // Update the start value of phi node in post-loop with the LCSSA phi node.
378     PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
379     PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);
380 
381     // Find PHI with exiting condition from pre-loop. The PHI should be
382     // SCEVAddRecExpr and have same incoming value from backedge with
383     // ExitingCond.
384     if (!SE.isSCEVable(PN.getType()))
385       continue;
386 
387     const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
388     if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
389                        PN.getIncomingValueForBlock(L.getLoopLatch()))
390       ExitingCondLCSSAPhi = LCSSAPhi;
391   }
392 
393   // Add conditional branch to check we can skip post-loop in its preheader.
394   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
395   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
396   Value *Cond =
397       Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);
398   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
399   OrigBI->eraseFromParent();
400 
401   // Create new loop bound and add it into preheader of pre-loop.
402   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
403   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
404   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
405                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
406                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
407 
408   SCEVExpander Expander(
409       SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
410   Instruction *InsertPt = SplitLoopPH->getTerminator();
411   Value *NewBoundValue =
412       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
413   NewBoundValue->setName("new.bound");
414 
415   // Replace exiting bound value of pre-loop NewBound.
416   ExitingCond.ICmp->setOperand(1, NewBoundValue);
417 
418   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
419   LLVMContext &Context = PreHeader->getContext();
420   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
421 
422   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
423   BranchInst *ClonedSplitCandidateBI =
424       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
425   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
426 
427   // Replace exit branch target of pre-loop by post-loop's preheader.
428   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
429     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
430   else
431     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
432 
433   // Update phi node in exit block of post-loop.
434   Builder.SetInsertPoint(&PostLoopPreHeader->front());
435   for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
436     for (auto i : seq<int>(0, PN.getNumOperands())) {
437       // Check incoming block is pre-loop's exiting block.
438       if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
439         Value *IncomingValue = PN.getIncomingValue(i);
440 
441         // Create LCSSA phi node for incoming value.
442         PHINode *LCSSAPhi =
443             Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
444         LCSSAPhi->setDebugLoc(PN.getDebugLoc());
445         LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));
446 
447         // Replace pre-loop's exiting block by post-loop's preheader.
448         PN.setIncomingBlock(i, PostLoopPreHeader);
449         // Replace incoming value by LCSSAPhi.
450         PN.setIncomingValue(i, LCSSAPhi);
451         // Add a new incoming value with post-loop's exiting block.
452         PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());
453       }
454     }
455   }
456 
457   // Update dominator tree.
458   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
459   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
460 
461   // Invalidate cached SE information.
462   SE.forgetLoop(&L);
463 
464   // Canonicalize loops.
465   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
466   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
467 
468   // Add new post-loop to loop pass manager.
469   U.addSiblingLoops(PostLoop);
470 
471   return true;
472 }
473 
474 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
475                                           LoopStandardAnalysisResults &AR,
476                                           LPMUpdater &U) {
477   Function &F = *L.getHeader()->getParent();
478   (void)F;
479 
480   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
481                     << "\n");
482 
483   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
484     return PreservedAnalyses::all();
485 
486   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
487   AR.LI.verify(AR.DT);
488 
489   return getLoopPassPreservedAnalyses();
490 }
491 
492 } // end namespace llvm
493