1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/ADT/Sequence.h"
11 #include "llvm/Analysis/LoopAccessAnalysis.h"
12 #include "llvm/Analysis/LoopAnalysisManager.h"
13 #include "llvm/Analysis/LoopInfo.h"
14 #include "llvm/Analysis/LoopIterator.h"
15 #include "llvm/Analysis/LoopPass.h"
16 #include "llvm/Analysis/MemorySSA.h"
17 #include "llvm/Analysis/MemorySSAUpdater.h"
18 #include "llvm/Analysis/ScalarEvolution.h"
19 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
20 #include "llvm/IR/PatternMatch.h"
21 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
22 #include "llvm/Transforms/Utils/Cloning.h"
23 #include "llvm/Transforms/Utils/LoopSimplify.h"
24 #include "llvm/Transforms/Utils/LoopUtils.h"
25 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
26 
27 #define DEBUG_TYPE "loop-bound-split"
28 
29 namespace llvm {
30 
31 using namespace PatternMatch;
32 
33 namespace {
34 struct ConditionInfo {
35   /// Branch instruction with this condition
36   BranchInst *BI;
37   /// ICmp instruction with this condition
38   ICmpInst *ICmp;
39   /// Preciate info
40   ICmpInst::Predicate Pred;
41   /// AddRec llvm value
42   Value *AddRecValue;
43   /// Bound llvm value
44   Value *BoundValue;
45   /// AddRec SCEV
46   const SCEV *AddRecSCEV;
47   /// Bound SCEV
48   const SCEV *BoundSCEV;
49 
50   ConditionInfo()
51       : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE),
52         AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
53         BoundSCEV(nullptr) {}
54 };
55 } // namespace
56 
57 static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
58                         ConditionInfo &Cond) {
59   Cond.ICmp = ICmp;
60   if (match(ICmp, m_ICmp(Cond.Pred, m_Value(Cond.AddRecValue),
61                          m_Value(Cond.BoundValue)))) {
62     Cond.AddRecSCEV = SE.getSCEV(Cond.AddRecValue);
63     Cond.BoundSCEV = SE.getSCEV(Cond.BoundValue);
64     // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
65     if (isa<SCEVAddRecExpr>(Cond.BoundSCEV) &&
66         !isa<SCEVAddRecExpr>(Cond.AddRecSCEV)) {
67       std::swap(Cond.AddRecValue, Cond.BoundValue);
68       std::swap(Cond.AddRecSCEV, Cond.BoundSCEV);
69       Cond.Pred = ICmpInst::getSwappedPredicate(Cond.Pred);
70     }
71   }
72 }
73 
74 static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
75                                 ConditionInfo &Cond, bool IsExitCond) {
76   if (IsExitCond) {
77     const SCEV *ExitCount = SE.getExitCount(&L, Cond.ICmp->getParent());
78     if (isa<SCEVCouldNotCompute>(ExitCount))
79       return false;
80 
81     Cond.BoundSCEV = ExitCount;
82     return true;
83   }
84 
85   // For non-exit condtion, if pred is LT, keep existing bound.
86   if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
87     return true;
88 
89   // For non-exit condition, if pre is LE, try to convert it to LT.
90   //      Range                 Range
91   // AddRec <= Bound  -->  AddRec < Bound + 1
92   if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
93     return false;
94 
95   if (IntegerType *BoundSCEVIntType =
96           dyn_cast<IntegerType>(Cond.BoundSCEV->getType())) {
97     unsigned BitWidth = BoundSCEVIntType->getBitWidth();
98     APInt Max = ICmpInst::isSigned(Cond.Pred)
99                     ? APInt::getSignedMaxValue(BitWidth)
100                     : APInt::getMaxValue(BitWidth);
101     const SCEV *MaxSCEV = SE.getConstant(Max);
102     // Check Bound < INT_MAX
103     ICmpInst::Predicate Pred =
104         ICmpInst::isSigned(Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
105     if (SE.isKnownPredicate(Pred, Cond.BoundSCEV, MaxSCEV)) {
106       const SCEV *BoundPlusOneSCEV =
107           SE.getAddExpr(Cond.BoundSCEV, SE.getOne(BoundSCEVIntType));
108       Cond.BoundSCEV = BoundPlusOneSCEV;
109       Cond.Pred = Pred;
110       return true;
111     }
112   }
113 
114   // ToDo: Support ICMP_NE/EQ.
115 
116   return false;
117 }
118 
119 static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
120                                     ICmpInst *ICmp, ConditionInfo &Cond,
121                                     bool IsExitCond) {
122   analyzeICmp(SE, ICmp, Cond);
123 
124   // The BoundSCEV should be evaluated at loop entry.
125   if (!SE.isAvailableAtLoopEntry(Cond.BoundSCEV, &L))
126     return false;
127 
128   const SCEVAddRecExpr *AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Cond.AddRecSCEV);
129   // Allowed AddRec as induction variable.
130   if (!AddRecSCEV)
131     return false;
132 
133   if (!AddRecSCEV->isAffine())
134     return false;
135 
136   const SCEV *StepRecSCEV = AddRecSCEV->getStepRecurrence(SE);
137   // Allowed constant step.
138   if (!isa<SCEVConstant>(StepRecSCEV))
139     return false;
140 
141   ConstantInt *StepCI = cast<SCEVConstant>(StepRecSCEV)->getValue();
142   // Allowed positive step for now.
143   // TODO: Support negative step.
144   if (StepCI->isNegative() || StepCI->isZero())
145     return false;
146 
147   // Calculate upper bound.
148   if (!calculateUpperBound(L, SE, Cond, IsExitCond))
149     return false;
150 
151   return true;
152 }
153 
154 static bool isProcessableCondBI(const ScalarEvolution &SE,
155                                 const BranchInst *BI) {
156   BasicBlock *TrueSucc = nullptr;
157   BasicBlock *FalseSucc = nullptr;
158   ICmpInst::Predicate Pred;
159   Value *LHS, *RHS;
160   if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
161                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
162     return false;
163 
164   if (!SE.isSCEVable(LHS->getType()))
165     return false;
166   assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
167 
168   if (TrueSucc == FalseSucc)
169     return false;
170 
171   return true;
172 }
173 
174 static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
175                               ScalarEvolution &SE, ConditionInfo &Cond) {
176   // Skip function with optsize.
177   if (L.getHeader()->getParent()->hasOptSize())
178     return false;
179 
180   // Split only innermost loop.
181   if (!L.isInnermost())
182     return false;
183 
184   // Check loop is in simplified form.
185   if (!L.isLoopSimplifyForm())
186     return false;
187 
188   // Check loop is in LCSSA form.
189   if (!L.isLCSSAForm(DT))
190     return false;
191 
192   // Skip loop that cannot be cloned.
193   if (!L.isSafeToClone())
194     return false;
195 
196   BasicBlock *ExitingBB = L.getExitingBlock();
197   // Assumed only one exiting block.
198   if (!ExitingBB)
199     return false;
200 
201   BranchInst *ExitingBI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
202   if (!ExitingBI)
203     return false;
204 
205   // Allowed only conditional branch with ICmp.
206   if (!isProcessableCondBI(SE, ExitingBI))
207     return false;
208 
209   // Check the condition is processable.
210   ICmpInst *ICmp = cast<ICmpInst>(ExitingBI->getCondition());
211   if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
212     return false;
213 
214   Cond.BI = ExitingBI;
215   return true;
216 }
217 
218 static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
219   // If the conditional branch splits a loop into two halves, we could
220   // generally say it is profitable.
221   //
222   // ToDo: Add more profitable cases here.
223 
224   // Check this branch causes diamond CFG.
225   BasicBlock *Succ0 = BI->getSuccessor(0);
226   BasicBlock *Succ1 = BI->getSuccessor(1);
227 
228   BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
229   BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
230   if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
231     return false;
232 
233   // ToDo: Calculate each successor's instruction cost.
234 
235   return true;
236 }
237 
238 static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
239                                       ConditionInfo &ExitingCond,
240                                       ConditionInfo &SplitCandidateCond) {
241   for (auto *BB : L.blocks()) {
242     // Skip condition of backedge.
243     if (L.getLoopLatch() == BB)
244       continue;
245 
246     auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
247     if (!BI)
248       continue;
249 
250     // Check conditional branch with ICmp.
251     if (!isProcessableCondBI(SE, BI))
252       continue;
253 
254     // Skip loop invariant condition.
255     if (L.isLoopInvariant(BI->getCondition()))
256       continue;
257 
258     // Check the condition is processable.
259     ICmpInst *ICmp = cast<ICmpInst>(BI->getCondition());
260     if (!hasProcessableCondition(L, SE, ICmp, SplitCandidateCond,
261                                  /*IsExitCond*/ false))
262       continue;
263 
264     if (ExitingCond.BoundSCEV->getType() !=
265         SplitCandidateCond.BoundSCEV->getType())
266       continue;
267 
268     SplitCandidateCond.BI = BI;
269     return BI;
270   }
271 
272   return nullptr;
273 }
274 
275 static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
276                            ScalarEvolution &SE, LPMUpdater &U) {
277   ConditionInfo SplitCandidateCond;
278   ConditionInfo ExitingCond;
279 
280   // Check we can split this loop's bound.
281   if (!canSplitLoopBound(L, DT, SE, ExitingCond))
282     return false;
283 
284   if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
285     return false;
286 
287   if (!isProfitableToTransform(L, SplitCandidateCond.BI))
288     return false;
289 
290   // Now, we have a split candidate. Let's build a form as below.
291   //    +--------------------+
292   //    |     preheader      |
293   //    |  set up newbound   |
294   //    +--------------------+
295   //             |     /----------------\
296   //    +--------v----v------+          |
297   //    |      header        |---\      |
298   //    | with true condition|   |      |
299   //    +--------------------+   |      |
300   //             |               |      |
301   //    +--------v-----------+   |      |
302   //    |     if.then.BB     |   |      |
303   //    +--------------------+   |      |
304   //             |               |      |
305   //    +--------v-----------<---/      |
306   //    |       latch        >----------/
307   //    |   with newbound    |
308   //    +--------------------+
309   //             |
310   //    +--------v-----------+
311   //    |     preheader2     |--------------\
312   //    | if (AddRec i !=    |              |
313   //    |     org bound)     |              |
314   //    +--------------------+              |
315   //             |     /----------------\   |
316   //    +--------v----v------+          |   |
317   //    |      header2       |---\      |   |
318   //    | conditional branch |   |      |   |
319   //    |with false condition|   |      |   |
320   //    +--------------------+   |      |   |
321   //             |               |      |   |
322   //    +--------v-----------+   |      |   |
323   //    |    if.then.BB2     |   |      |   |
324   //    +--------------------+   |      |   |
325   //             |               |      |   |
326   //    +--------v-----------<---/      |   |
327   //    |       latch2       >----------/   |
328   //    |   with org bound   |              |
329   //    +--------v-----------+              |
330   //             |                          |
331   //             |  +---------------+       |
332   //             +-->     exit      <-------/
333   //                +---------------+
334 
335   // Let's create post loop.
336   SmallVector<BasicBlock *, 8> PostLoopBlocks;
337   Loop *PostLoop;
338   ValueToValueMapTy VMap;
339   BasicBlock *PreHeader = L.getLoopPreheader();
340   BasicBlock *SplitLoopPH = SplitEdge(PreHeader, L.getHeader(), &DT, &LI);
341   PostLoop = cloneLoopWithPreheader(L.getExitBlock(), SplitLoopPH, &L, VMap,
342                                     ".split", &LI, &DT, PostLoopBlocks);
343   remapInstructionsInBlocks(PostLoopBlocks, VMap);
344 
345   // Add conditional branch to check we can skip post-loop in its preheader.
346   BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
347   IRBuilder<> Builder(PostLoopPreHeader);
348   Instruction *OrigBI = PostLoopPreHeader->getTerminator();
349   ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
350   Value *Cond =
351       Builder.CreateICmp(Pred, ExitingCond.AddRecValue, ExitingCond.BoundValue);
352   Builder.CreateCondBr(Cond, PostLoop->getHeader(), PostLoop->getExitBlock());
353   OrigBI->eraseFromParent();
354 
355   // Create new loop bound and add it into preheader of pre-loop.
356   const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
357   const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
358   NewBoundSCEV = ICmpInst::isSigned(ExitingCond.Pred)
359                      ? SE.getSMinExpr(NewBoundSCEV, SplitBoundSCEV)
360                      : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV);
361 
362   SCEVExpander Expander(
363       SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split");
364   Instruction *InsertPt = SplitLoopPH->getTerminator();
365   Value *NewBoundValue =
366       Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt);
367   NewBoundValue->setName("new.bound");
368 
369   // Replace exiting bound value of pre-loop NewBound.
370   ExitingCond.ICmp->setOperand(1, NewBoundValue);
371 
372   // Replace IV's start value of post-loop by NewBound.
373   for (PHINode &PN : L.getHeader()->phis()) {
374     // Find PHI with exiting condition from pre-loop.
375     if (SE.isSCEVable(PN.getType()) && isa<SCEVAddRecExpr>(SE.getSCEV(&PN))) {
376       for (Value *Op : PN.incoming_values()) {
377         if (Op == ExitingCond.AddRecValue) {
378           // Find cloned PHI for post-loop.
379           PHINode *PostLoopPN = cast<PHINode>(VMap[&PN]);
380           PostLoopPN->setIncomingValueForBlock(PostLoopPreHeader,
381                                                NewBoundValue);
382         }
383       }
384     }
385   }
386 
387   // Replace SplitCandidateCond.BI's condition of pre-loop by True.
388   LLVMContext &Context = PreHeader->getContext();
389   SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
390 
391   // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
392   BranchInst *ClonedSplitCandidateBI =
393       cast<BranchInst>(VMap[SplitCandidateCond.BI]);
394   ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
395 
396   // Replace exit branch target of pre-loop by post-loop's preheader.
397   if (L.getExitBlock() == ExitingCond.BI->getSuccessor(0))
398     ExitingCond.BI->setSuccessor(0, PostLoopPreHeader);
399   else
400     ExitingCond.BI->setSuccessor(1, PostLoopPreHeader);
401 
402   // Update phi node in exit block of post-loop.
403   for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
404     for (auto i : seq<int>(0, PN.getNumOperands())) {
405       // Check incoming block is pre-loop's exiting block.
406       if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
407         // Replace pre-loop's exiting block by post-loop's preheader.
408         PN.setIncomingBlock(i, PostLoopPreHeader);
409         // Add a new incoming value with post-loop's exiting block.
410         PN.addIncoming(VMap[PN.getIncomingValue(i)],
411                        PostLoop->getExitingBlock());
412       }
413     }
414   }
415 
416   // Update dominator tree.
417   DT.changeImmediateDominator(PostLoopPreHeader, L.getExitingBlock());
418   DT.changeImmediateDominator(PostLoop->getExitBlock(), PostLoopPreHeader);
419 
420   // Invalidate cached SE information.
421   SE.forgetLoop(&L);
422 
423   // Canonicalize loops.
424   // TODO: Try to update LCSSA information according to above change.
425   formLCSSA(L, DT, &LI, &SE);
426   simplifyLoop(&L, &DT, &LI, &SE, nullptr, nullptr, true);
427   formLCSSA(*PostLoop, DT, &LI, &SE);
428   simplifyLoop(PostLoop, &DT, &LI, &SE, nullptr, nullptr, true);
429 
430   // Add new post-loop to loop pass manager.
431   U.addSiblingLoops(PostLoop);
432 
433   return true;
434 }
435 
436 PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
437                                           LoopStandardAnalysisResults &AR,
438                                           LPMUpdater &U) {
439   Function &F = *L.getHeader()->getParent();
440   (void)F;
441 
442   LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
443                     << "\n");
444 
445   if (!splitLoopBound(L, AR.DT, AR.LI, AR.SE, U))
446     return PreservedAnalyses::all();
447 
448   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
449   AR.LI.verify(AR.DT);
450 
451   return getLoopPassPreservedAnalyses();
452 }
453 
454 } // end namespace llvm
455