1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/ADT/Sequence.h"
11 #include "llvm/Analysis/LoopAccessAnalysis.h"
12 #include "llvm/Analysis/LoopAnalysisManager.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/LoopIterator.h"
15 #include "llvm/Analysis/LoopPass.h"
16 #include "llvm/Analysis/MemorySSA.h"
17 #include "llvm/Analysis/MemorySSAUpdater.h"
18 #include "llvm/Analysis/ScalarEvolution.h"
19 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
22 #include "llvm/Transforms/Utils/Cloning.h"
23 #include "llvm/Transforms/Utils/LoopSimplify.h"
24 #include "llvm/Transforms/Utils/LoopUtils.h"
25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
26 
27 #define DEBUG_TYPE "loop-bound-split"
28 
29 namespace llvm {
30 
31 using namespace PatternMatch;
32 
33 namespace {
34 struct ConditionInfo {
35   /// Branch instruction with this condition
36   BranchInst *BI = nullptr;
37   /// ICmp instruction with this condition
38   ICmpInst *ICmp = nullptr;
39   /// Preciate info
40   ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
41   /// AddRec llvm value
42   Value *AddRecValue = nullptr;
43   /// Non PHI AddRec llvm value
44   Value *NonPHIAddRecValue;
45   /// Bound llvm value
46   Value *BoundValue = nullptr;
47   /// AddRec SCEV
48   const SCEVAddRecExpr *AddRecSCEV = nullptr;
49   /// Bound SCEV
50   const SCEV *BoundSCEV = nullptr;
51 
52   ConditionInfo() = default;
53 };
54 } // namespace
55 
56 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
57                         ConditionInfo &Cond, const Loop &L) {
58   Cond.ICmp = ICmp;
59   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
60                          m_Value(Cond.BoundValue)))) {
61     const SCEV *AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
62     const SCEV *BoundSCEV = SE.getSCEV(Cond.BoundValue);
63     const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
64     const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(BoundSCEV);
65     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
66     if (!LHSAddRecSCEV && RHSAddRecSCEV) {
67       std::swap(Cond.AddRecValue, Cond.BoundValue);
68       std::swap(AddRecSCEV, BoundSCEV);
69       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
70     }
71 
72     Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(AddRecSCEV);
73     Cond.BoundSCEV = BoundSCEV;
74     Cond.NonPHIAddRecValue = Cond.AddRecValue;
75 
76     // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
77     // value from backedge.
78     if (Cond.AddRecSCEV && isa<PHINode>(Cond.AddRecValue)) {
79       PHINode *PN = cast<PHINode>(Cond.AddRecValue);
80       Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(L.getLoopLatch());
81     }
82   }
83 }
84 
85 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
86                                 ConditionInfo &Cond, bool IsExitCond) {
87   if (IsExitCond) {
88     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
89     if (isa<SCEVCouldNotCompute>(ExitCount))
90       return false;
91 
92     Cond.BoundSCEV = ExitCount;
93     return true;
94   }
95 
96   // For non-exit condtion, if pred is LT, keep existing bound.
97   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
98     return true;
99 
100   // For non-exit condition, if pre is LE, try to convert it to LT.
101   //      Range                 Range
102   // AddRec <= Bound  -->  AddRec < Bound + 1
103   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
104     return false;
105 
106   if (IntegerType *BoundSCEVIntType =
107           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
108     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
109     APInt Max = ICmpInst::isSigned(Cond.Pred)
110                     ? APInt::getSignedMaxValue(BitWidth)
111                     : APInt::getMaxValue(BitWidth);
112     const SCEV *MaxSCEV = SE.getConstant(Max);
113     // Check Bound < INT_MAX
114     ICmpInst::Predicate Pred =
115         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
116     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
117       const SCEV *BoundPlusOneSCEV =
118           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
119       Cond.BoundSCEV = BoundPlusOneSCEV;
120       Cond.Pred = Pred;
121       return true;
122     }
123   }
124 
125   // ToDo: Support ICMP_NE/EQ.
126 
127   return false;
128 }
129 
130 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
131                                     ICmpInst *ICmp, ConditionInfo &Cond,
132                                     bool IsExitCond) {
133   analyzeICmp(SE, ICmp, Cond, L);
134 
135   // The BoundSCEV should be evaluated at loop entry.
136   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
137     return false;
138 
139   // Allowed AddRec as induction variable.
140   if (!Cond.AddRecSCEV)
141     return false;
142 
143   if (!Cond.AddRecSCEV->isAffine())
144     return false;
145 
146   const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
147   // Allowed constant step.
148   if (!isa<SCEVConstant>(StepRecSCEV))
149     return false;
150 
151   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
152   // Allowed positive step for now.
153   // TODO: Support negative step.
154   if (StepCI->isNegative() || StepCI->isZero())
155     return false;
156 
157   // Calculate upper bound.
158   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
159     return false;
160 
161   return true;
162 }
163 
164 static bool isProcessableCondBI(const ScalarEvolution &SE,
165                                 const BranchInst *BI) {
166   BasicBlock *TrueSucc = nullptr;
167   BasicBlock *FalseSucc = nullptr;
168   ICmpInst::Predicate Pred;
169   Value *LHS, *RHS;
170   if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
171                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
172     return false;
173 
174   if (!SE.isSCEVable(LHS->getType()))
175     return false;
176   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
177 
178   if (TrueSucc == FalseSucc)
179     return false;
180 
181   return true;
182 }
183 
184 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
185                               ScalarEvolution &SE, ConditionInfo &Cond) {
186   // Skip function with optsize.
187   if (L.getHeader()->getParent()->hasOptSize())
188     return false;
189 
190   // Split only innermost loop.
191   if (!L.isInnermost())
192     return false;
193 
194   // Check loop is in simplified form.
195   if (!L.isLoopSimplifyForm())
196     return false;
197 
198   // Check loop is in LCSSA form.
199   if (!L.isLCSSAForm(DT))
200     return false;
201 
202   // Skip loop that cannot be cloned.
203   if (!L.isSafeToClone())
204     return false;
205 
206   BasicBlock *ExitingBB = L.getExitingBlock();
207   // Assumed only one exiting block.
208   if (!ExitingBB)
209     return false;
210 
211   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
212   if (!ExitingBI)
213     return false;
214 
215   // Allowed only conditional branch with ICmp.
216   if (!isProcessableCondBI(SE, ExitingBI))
217     return false;
218 
219   // Check the condition is processable.
220   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
221   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
222     return false;
223 
224   Cond.BI = ExitingBI;
225   return true;
226 }
227 
228 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
229   // If the conditional branch splits a loop into two halves, we could
230   // generally say it is profitable.
231   //
232   // ToDo: Add more profitable cases here.
233 
234   // Check this branch causes diamond CFG.
235   BasicBlock *Succ0 = BI->getSuccessor(0);
236   BasicBlock *Succ1 = BI->getSuccessor(1);
237 
238   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
239   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
240   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
241     return false;
242 
243   // ToDo: Calculate each successor's instruction cost.
244 
245   return true;
246 }
247 
248 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
249                                       ConditionInfo &ExitingCond,
250                                       ConditionInfo &SplitCandidateCond) {
251   for (auto *BB : L.blocks()) {
252     // Skip condition of backedge.
253     if (L.getLoopLatch() == BB)
254       continue;
255 
256     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
257     if (!BI)
258       continue;
259 
260     // Check conditional branch with ICmp.
261     if (!isProcessableCondBI(SE, BI))
262       continue;
263 
264     // Skip loop invariant condition.
265     if (L.isLoopInvariant(BI->getCondition()))
266       continue;
267 
268     // Check the condition is processable.
269     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
270     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
271                                  /*IsExitCond*/ false))
272       continue;
273 
274     if (ExitingCond.BoundSCEV->getType() !=
275         SplitCandidateCond.BoundSCEV->getType())
276       continue;
277 
278     // After transformation, we assume the split condition of the pre-loop is
279     // always true. In order to guarantee it, we need to check the start value
280     // of the split cond AddRec satisfies the split condition.
281     if (!SE.isLoopEntryGuardedByCond(&L, SplitCandidateCond.Pred,
282                                      SplitCandidateCond.AddRecSCEV->getStart(),
283                                      SplitCandidateCond.BoundSCEV))
284       continue;
285 
286     SplitCandidateCond.BI = BI;
287     return BI;
288   }
289 
290   return nullptr;
291 }
292 
293 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
294                            ScalarEvolution &SE, LPMUpdater &U) {
295   ConditionInfo SplitCandidateCond;
296   ConditionInfo ExitingCond;
297 
298   // Check we can split this loop's bound.
299   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
300     return false;
301 
302   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
303     return false;
304 
305   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
306     return false;
307 
308   // Now, we have a split candidate. Let's build a form as below.
309   //    +--------------------+
310   //    |     preheader      |
311   //    |  set up newbound   |
312   //    +--------------------+
313   //             |     /----------------\
314   //    +--------v----v------+          |
315   //    |      header        |---\      |
316   //    | with true condition|   |      |
317   //    +--------------------+   |      |
318   //             |               |      |
319   //    +--------v-----------+   |      |
320   //    |     if.then.BB     |   |      |
321   //    +--------------------+   |      |
322   //             |               |      |
323   //    +--------v-----------<---/      |
324   //    |       latch        >----------/
325   //    |   with newbound    |
326   //    +--------------------+
327   //             |
328   //    +--------v-----------+
329   //    |     preheader2     |--------------\
330   //    | if (AddRec i !=    |              |
331   //    |     org bound)     |              |
332   //    +--------------------+              |
333   //             |     /----------------\   |
334   //    +--------v----v------+          |   |
335   //    |      header2       |---\      |   |
336   //    | conditional branch |   |      |   |
337   //    |with false condition|   |      |   |
338   //    +--------------------+   |      |   |
339   //             |               |      |   |
340   //    +--------v-----------+   |      |   |
341   //    |    if.then.BB2     |   |      |   |
342   //    +--------------------+   |      |   |
343   //             |               |      |   |
344   //    +--------v-----------<---/      |   |
345   //    |       latch2       >----------/   |
346   //    |   with org bound   |              |
347   //    +--------v-----------+              |
348   //             |                          |
349   //             |  +---------------+       |
350   //             +-->     exit      <-------/
351   //                +---------------+
352 
353   // Let's create post loop.
354   SmallVector<BasicBlock *, 8> PostLoopBlocks;
355   Loop *PostLoop;
356   ValueToValueMapTy VMap;
357   BasicBlock *PreHeader = L.getLoopPreheader();
358   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
359   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
360                                     ".split", &LI, &DT, PostLoopBlocks);
361   remapInstructionsInBlocks(PostLoopBlocks, VMap);
362 
363   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
364   IRBuilder<> Builder(&PostLoopPreHeader->front());
365 
366   // Update phi nodes in header of post-loop.
367   bool isExitingLatch =
368       (L.getExitingBlock() == L.getLoopLatch()) ? true : false;
369   Value *ExitingCondLCSSAPhi = nullptr;
370   for (PHINode &PN : L.getHeader()->phis()) {
371     // Create LCSSA phi node in preheader of post-loop.
372     PHINode *LCSSAPhi =
373         Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
374     LCSSAPhi->setDebugLoc(PN.getDebugLoc());
375     // If the exiting block is loop latch, the phi does not have the update at
376     // last iteration. In this case, update lcssa phi with value from backedge.
377     LCSSAPhi->addIncoming(
378         isExitingLatch ? PN.getIncomingValueForBlock(L.getLoopLatch()) : &PN,
379         L.getExitingBlock());
380 
381     // Update the start value of phi node in post-loop with the LCSSA phi node.
382     PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
383     PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader, LCSSAPhi);
384 
385     // Find PHI with exiting condition from pre-loop. The PHI should be
386     // SCEVAddRecExpr and have same incoming value from backedge with
387     // ExitingCond.
388     if (!SE.isSCEVable(PN.getType()))
389       continue;
390 
391     const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
392     if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
393                        PN.getIncomingValueForBlock(L.getLoopLatch()))
394       ExitingCondLCSSAPhi = LCSSAPhi;
395   }
396 
397   // Add conditional branch to check we can skip post-loop in its preheader.
398   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
399   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
400   Value *Cond =
401       Builder.CreateICmp(Pred, ExitingCondLCSSAPhi, ExitingCond.BoundValue);
402   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
403   OrigBI->eraseFromParent();
404 
405   // Create new loop bound and add it into preheader of pre-loop.
406   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
407   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
408   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
409                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
410                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
411 
412   SCEVExpander Expander(
413       SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
414   Instruction *InsertPt = SplitLoopPH->getTerminator();
415   Value *NewBoundValue =
416       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
417   NewBoundValue->setName("new.bound");
418 
419   // Replace exiting bound value of pre-loop NewBound.
420   ExitingCond.ICmp->setOperand(1, NewBoundValue);
421 
422   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
423   LLVMContext &Context = PreHeader->getContext();
424   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
425 
426   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
427   BranchInst *ClonedSplitCandidateBI =
428       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
429   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
430 
431   // Replace exit branch target of pre-loop by post-loop's preheader.
432   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
433     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
434   else
435     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
436 
437   // Update phi node in exit block of post-loop.
438   Builder.SetInsertPoint(&PostLoopPreHeader->front());
439   for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
440     for (auto i : seq<int>(0, PN.getNumOperands())) {
441       // Check incoming block is pre-loop's exiting block.
442       if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
443         Value *IncomingValue = PN.getIncomingValue(i);
444 
445         // Create LCSSA phi node for incoming value.
446         PHINode *LCSSAPhi =
447             Builder.CreatePHI(PN.getType(), 1, PN.getName() + ".lcssa");
448         LCSSAPhi->setDebugLoc(PN.getDebugLoc());
449         LCSSAPhi->addIncoming(IncomingValue, PN.getIncomingBlock(i));
450 
451         // Replace pre-loop's exiting block by post-loop's preheader.
452         PN.setIncomingBlock(i, PostLoopPreHeader);
453         // Replace incoming value by LCSSAPhi.
454         PN.setIncomingValue(i, LCSSAPhi);
455         // Add a new incoming value with post-loop's exiting block.
456         PN.addIncoming(VMap[IncomingValue], PostLoop->getExitingBlock());
457       }
458     }
459   }
460 
461   // Update dominator tree.
462   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
463   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
464 
465   // Invalidate cached SE information.
466   SE.forgetLoop(&L);
467 
468   // Canonicalize loops.
469   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
470   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
471 
472   // Add new post-loop to loop pass manager.
473   U.addSiblingLoops(PostLoop);
474 
475   return true;
476 }
477 
478 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
479                                           LoopStandardAnalysisResults &AR,
480                                           LPMUpdater &U) {
481   Function &F = *L.getHeader()->getParent();
482   (void)F;
483 
484   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
485                     << "\n");
486 
487   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
488     return PreservedAnalyses::all();
489 
490   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
491   AR.LI.verify(AR.DT);
492 
493   return getLoopPassPreservedAnalyses();
494 }
495 
496 } // end namespace llvm
497