1*c9157d92SDimitry Andric #include "llvm/Transforms/Utils/LoopConstrainer.h"
2*c9157d92SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
3*c9157d92SDimitry Andric #include "llvm/Analysis/ScalarEvolution.h"
4*c9157d92SDimitry Andric #include "llvm/Analysis/ScalarEvolutionExpressions.h"
5*c9157d92SDimitry Andric #include "llvm/IR/Dominators.h"
6*c9157d92SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
7*c9157d92SDimitry Andric #include "llvm/Transforms/Utils/LoopSimplify.h"
8*c9157d92SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
9*c9157d92SDimitry Andric #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10*c9157d92SDimitry Andric
11*c9157d92SDimitry Andric using namespace llvm;
12*c9157d92SDimitry Andric
// Metadata kind placed on the latch terminator of loops produced by
// LoopConstrainer::cloneLoop; parseLoopStructure rejects any loop whose latch
// carries it, so a clone is never constrained a second time. Both the pointee
// and the pointer are const so the tag cannot be re-pointed at runtime.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";

#define DEBUG_TYPE "loop-constrainer"
16*c9157d92SDimitry Andric
17*c9157d92SDimitry Andric /// Given a loop with an deccreasing induction variable, is it possible to
18*c9157d92SDimitry Andric /// safely calculate the bounds of a new loop using the given Predicate.
isSafeDecreasingBound(const SCEV * Start,const SCEV * BoundSCEV,const SCEV * Step,ICmpInst::Predicate Pred,unsigned LatchBrExitIdx,Loop * L,ScalarEvolution & SE)19*c9157d92SDimitry Andric static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
20*c9157d92SDimitry Andric const SCEV *Step, ICmpInst::Predicate Pred,
21*c9157d92SDimitry Andric unsigned LatchBrExitIdx, Loop *L,
22*c9157d92SDimitry Andric ScalarEvolution &SE) {
23*c9157d92SDimitry Andric if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
24*c9157d92SDimitry Andric Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
25*c9157d92SDimitry Andric return false;
26*c9157d92SDimitry Andric
27*c9157d92SDimitry Andric if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
28*c9157d92SDimitry Andric return false;
29*c9157d92SDimitry Andric
30*c9157d92SDimitry Andric assert(SE.isKnownNegative(Step) && "expecting negative step");
31*c9157d92SDimitry Andric
32*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
34*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
35*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
36*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
37*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
38*c9157d92SDimitry Andric
39*c9157d92SDimitry Andric bool IsSigned = ICmpInst::isSigned(Pred);
40*c9157d92SDimitry Andric // The predicate that we need to check that the induction variable lies
41*c9157d92SDimitry Andric // within bounds.
42*c9157d92SDimitry Andric ICmpInst::Predicate BoundPred =
43*c9157d92SDimitry Andric IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
44*c9157d92SDimitry Andric
45*c9157d92SDimitry Andric if (LatchBrExitIdx == 1)
46*c9157d92SDimitry Andric return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
47*c9157d92SDimitry Andric
48*c9157d92SDimitry Andric assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
49*c9157d92SDimitry Andric
50*c9157d92SDimitry Andric const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
51*c9157d92SDimitry Andric unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
52*c9157d92SDimitry Andric APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
53*c9157d92SDimitry Andric : APInt::getMinValue(BitWidth);
54*c9157d92SDimitry Andric const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
55*c9157d92SDimitry Andric
56*c9157d92SDimitry Andric const SCEV *MinusOne =
57*c9157d92SDimitry Andric SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
58*c9157d92SDimitry Andric
59*c9157d92SDimitry Andric return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
60*c9157d92SDimitry Andric SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
61*c9157d92SDimitry Andric }
62*c9157d92SDimitry Andric
63*c9157d92SDimitry Andric /// Given a loop with an increasing induction variable, is it possible to
64*c9157d92SDimitry Andric /// safely calculate the bounds of a new loop using the given Predicate.
isSafeIncreasingBound(const SCEV * Start,const SCEV * BoundSCEV,const SCEV * Step,ICmpInst::Predicate Pred,unsigned LatchBrExitIdx,Loop * L,ScalarEvolution & SE)65*c9157d92SDimitry Andric static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
66*c9157d92SDimitry Andric const SCEV *Step, ICmpInst::Predicate Pred,
67*c9157d92SDimitry Andric unsigned LatchBrExitIdx, Loop *L,
68*c9157d92SDimitry Andric ScalarEvolution &SE) {
69*c9157d92SDimitry Andric if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
70*c9157d92SDimitry Andric Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
71*c9157d92SDimitry Andric return false;
72*c9157d92SDimitry Andric
73*c9157d92SDimitry Andric if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
74*c9157d92SDimitry Andric return false;
75*c9157d92SDimitry Andric
76*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
77*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
78*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
79*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
80*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
81*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
82*c9157d92SDimitry Andric
83*c9157d92SDimitry Andric bool IsSigned = ICmpInst::isSigned(Pred);
84*c9157d92SDimitry Andric // The predicate that we need to check that the induction variable lies
85*c9157d92SDimitry Andric // within bounds.
86*c9157d92SDimitry Andric ICmpInst::Predicate BoundPred =
87*c9157d92SDimitry Andric IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
88*c9157d92SDimitry Andric
89*c9157d92SDimitry Andric if (LatchBrExitIdx == 1)
90*c9157d92SDimitry Andric return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
91*c9157d92SDimitry Andric
92*c9157d92SDimitry Andric assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
93*c9157d92SDimitry Andric
94*c9157d92SDimitry Andric const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
95*c9157d92SDimitry Andric unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
96*c9157d92SDimitry Andric APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
97*c9157d92SDimitry Andric : APInt::getMaxValue(BitWidth);
98*c9157d92SDimitry Andric const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
99*c9157d92SDimitry Andric
100*c9157d92SDimitry Andric return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
101*c9157d92SDimitry Andric SE.getAddExpr(BoundSCEV, Step)) &&
102*c9157d92SDimitry Andric SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
103*c9157d92SDimitry Andric }
104*c9157d92SDimitry Andric
105*c9157d92SDimitry Andric /// Returns estimate for max latch taken count of the loop of the narrowest
106*c9157d92SDimitry Andric /// available type. If the latch block has such estimate, it is returned.
107*c9157d92SDimitry Andric /// Otherwise, we use max exit count of whole loop (that is potentially of wider
108*c9157d92SDimitry Andric /// type than latch check itself), which is still better than no estimate.
getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution & SE,const Loop & L)109*c9157d92SDimitry Andric static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
110*c9157d92SDimitry Andric const Loop &L) {
111*c9157d92SDimitry Andric const SCEV *FromBlock =
112*c9157d92SDimitry Andric SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
113*c9157d92SDimitry Andric if (isa<SCEVCouldNotCompute>(FromBlock))
114*c9157d92SDimitry Andric return SE.getSymbolicMaxBackedgeTakenCount(&L);
115*c9157d92SDimitry Andric return FromBlock;
116*c9157d92SDimitry Andric }
117*c9157d92SDimitry Andric
118*c9157d92SDimitry Andric std::optional<LoopStructure>
parseLoopStructure(ScalarEvolution & SE,Loop & L,bool AllowUnsignedLatchCond,const char * & FailureReason)119*c9157d92SDimitry Andric LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
120*c9157d92SDimitry Andric bool AllowUnsignedLatchCond,
121*c9157d92SDimitry Andric const char *&FailureReason) {
122*c9157d92SDimitry Andric if (!L.isLoopSimplifyForm()) {
123*c9157d92SDimitry Andric FailureReason = "loop not in LoopSimplify form";
124*c9157d92SDimitry Andric return std::nullopt;
125*c9157d92SDimitry Andric }
126*c9157d92SDimitry Andric
127*c9157d92SDimitry Andric BasicBlock *Latch = L.getLoopLatch();
128*c9157d92SDimitry Andric assert(Latch && "Simplified loops only have one latch!");
129*c9157d92SDimitry Andric
130*c9157d92SDimitry Andric if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
131*c9157d92SDimitry Andric FailureReason = "loop has already been cloned";
132*c9157d92SDimitry Andric return std::nullopt;
133*c9157d92SDimitry Andric }
134*c9157d92SDimitry Andric
135*c9157d92SDimitry Andric if (!L.isLoopExiting(Latch)) {
136*c9157d92SDimitry Andric FailureReason = "no loop latch";
137*c9157d92SDimitry Andric return std::nullopt;
138*c9157d92SDimitry Andric }
139*c9157d92SDimitry Andric
140*c9157d92SDimitry Andric BasicBlock *Header = L.getHeader();
141*c9157d92SDimitry Andric BasicBlock *Preheader = L.getLoopPreheader();
142*c9157d92SDimitry Andric if (!Preheader) {
143*c9157d92SDimitry Andric FailureReason = "no preheader";
144*c9157d92SDimitry Andric return std::nullopt;
145*c9157d92SDimitry Andric }
146*c9157d92SDimitry Andric
147*c9157d92SDimitry Andric BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
148*c9157d92SDimitry Andric if (!LatchBr || LatchBr->isUnconditional()) {
149*c9157d92SDimitry Andric FailureReason = "latch terminator not conditional branch";
150*c9157d92SDimitry Andric return std::nullopt;
151*c9157d92SDimitry Andric }
152*c9157d92SDimitry Andric
153*c9157d92SDimitry Andric unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
154*c9157d92SDimitry Andric
155*c9157d92SDimitry Andric ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
156*c9157d92SDimitry Andric if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
157*c9157d92SDimitry Andric FailureReason = "latch terminator branch not conditional on integral icmp";
158*c9157d92SDimitry Andric return std::nullopt;
159*c9157d92SDimitry Andric }
160*c9157d92SDimitry Andric
161*c9157d92SDimitry Andric const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
162*c9157d92SDimitry Andric if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
163*c9157d92SDimitry Andric FailureReason = "could not compute latch count";
164*c9157d92SDimitry Andric return std::nullopt;
165*c9157d92SDimitry Andric }
166*c9157d92SDimitry Andric assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
167*c9157d92SDimitry Andric ScalarEvolution::LoopInvariant &&
168*c9157d92SDimitry Andric "loop variant exit count doesn't make sense!");
169*c9157d92SDimitry Andric
170*c9157d92SDimitry Andric ICmpInst::Predicate Pred = ICI->getPredicate();
171*c9157d92SDimitry Andric Value *LeftValue = ICI->getOperand(0);
172*c9157d92SDimitry Andric const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
173*c9157d92SDimitry Andric IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
174*c9157d92SDimitry Andric
175*c9157d92SDimitry Andric Value *RightValue = ICI->getOperand(1);
176*c9157d92SDimitry Andric const SCEV *RightSCEV = SE.getSCEV(RightValue);
177*c9157d92SDimitry Andric
178*c9157d92SDimitry Andric // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
179*c9157d92SDimitry Andric if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
180*c9157d92SDimitry Andric if (isa<SCEVAddRecExpr>(RightSCEV)) {
181*c9157d92SDimitry Andric std::swap(LeftSCEV, RightSCEV);
182*c9157d92SDimitry Andric std::swap(LeftValue, RightValue);
183*c9157d92SDimitry Andric Pred = ICmpInst::getSwappedPredicate(Pred);
184*c9157d92SDimitry Andric } else {
185*c9157d92SDimitry Andric FailureReason = "no add recurrences in the icmp";
186*c9157d92SDimitry Andric return std::nullopt;
187*c9157d92SDimitry Andric }
188*c9157d92SDimitry Andric }
189*c9157d92SDimitry Andric
190*c9157d92SDimitry Andric auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
191*c9157d92SDimitry Andric if (AR->getNoWrapFlags(SCEV::FlagNSW))
192*c9157d92SDimitry Andric return true;
193*c9157d92SDimitry Andric
194*c9157d92SDimitry Andric IntegerType *Ty = cast<IntegerType>(AR->getType());
195*c9157d92SDimitry Andric IntegerType *WideTy =
196*c9157d92SDimitry Andric IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
197*c9157d92SDimitry Andric
198*c9157d92SDimitry Andric const SCEVAddRecExpr *ExtendAfterOp =
199*c9157d92SDimitry Andric dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
200*c9157d92SDimitry Andric if (ExtendAfterOp) {
201*c9157d92SDimitry Andric const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
202*c9157d92SDimitry Andric const SCEV *ExtendedStep =
203*c9157d92SDimitry Andric SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
204*c9157d92SDimitry Andric
205*c9157d92SDimitry Andric bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
206*c9157d92SDimitry Andric ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
207*c9157d92SDimitry Andric
208*c9157d92SDimitry Andric if (NoSignedWrap)
209*c9157d92SDimitry Andric return true;
210*c9157d92SDimitry Andric }
211*c9157d92SDimitry Andric
212*c9157d92SDimitry Andric // We may have proved this when computing the sign extension above.
213*c9157d92SDimitry Andric return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
214*c9157d92SDimitry Andric };
215*c9157d92SDimitry Andric
216*c9157d92SDimitry Andric // `ICI` is interpreted as taking the backedge if the *next* value of the
217*c9157d92SDimitry Andric // induction variable satisfies some constraint.
218*c9157d92SDimitry Andric
219*c9157d92SDimitry Andric const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
220*c9157d92SDimitry Andric if (IndVarBase->getLoop() != &L) {
221*c9157d92SDimitry Andric FailureReason = "LHS in cmp is not an AddRec for this loop";
222*c9157d92SDimitry Andric return std::nullopt;
223*c9157d92SDimitry Andric }
224*c9157d92SDimitry Andric if (!IndVarBase->isAffine()) {
225*c9157d92SDimitry Andric FailureReason = "LHS in icmp not induction variable";
226*c9157d92SDimitry Andric return std::nullopt;
227*c9157d92SDimitry Andric }
228*c9157d92SDimitry Andric const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
229*c9157d92SDimitry Andric if (!isa<SCEVConstant>(StepRec)) {
230*c9157d92SDimitry Andric FailureReason = "LHS in icmp not induction variable";
231*c9157d92SDimitry Andric return std::nullopt;
232*c9157d92SDimitry Andric }
233*c9157d92SDimitry Andric ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
234*c9157d92SDimitry Andric
235*c9157d92SDimitry Andric if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
236*c9157d92SDimitry Andric FailureReason = "LHS in icmp needs nsw for equality predicates";
237*c9157d92SDimitry Andric return std::nullopt;
238*c9157d92SDimitry Andric }
239*c9157d92SDimitry Andric
240*c9157d92SDimitry Andric assert(!StepCI->isZero() && "Zero step?");
241*c9157d92SDimitry Andric bool IsIncreasing = !StepCI->isNegative();
242*c9157d92SDimitry Andric bool IsSignedPredicate;
243*c9157d92SDimitry Andric const SCEV *StartNext = IndVarBase->getStart();
244*c9157d92SDimitry Andric const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
245*c9157d92SDimitry Andric const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
246*c9157d92SDimitry Andric const SCEV *Step = SE.getSCEV(StepCI);
247*c9157d92SDimitry Andric
248*c9157d92SDimitry Andric const SCEV *FixedRightSCEV = nullptr;
249*c9157d92SDimitry Andric
250*c9157d92SDimitry Andric // If RightValue resides within loop (but still being loop invariant),
251*c9157d92SDimitry Andric // regenerate it as preheader.
252*c9157d92SDimitry Andric if (auto *I = dyn_cast<Instruction>(RightValue))
253*c9157d92SDimitry Andric if (L.contains(I->getParent()))
254*c9157d92SDimitry Andric FixedRightSCEV = RightSCEV;
255*c9157d92SDimitry Andric
256*c9157d92SDimitry Andric if (IsIncreasing) {
257*c9157d92SDimitry Andric bool DecreasedRightValueByOne = false;
258*c9157d92SDimitry Andric if (StepCI->isOne()) {
259*c9157d92SDimitry Andric // Try to turn eq/ne predicates to those we can work with.
260*c9157d92SDimitry Andric if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
261*c9157d92SDimitry Andric // while (++i != len) { while (++i < len) {
262*c9157d92SDimitry Andric // ... ---> ...
263*c9157d92SDimitry Andric // } }
264*c9157d92SDimitry Andric // If both parts are known non-negative, it is profitable to use
265*c9157d92SDimitry Andric // unsigned comparison in increasing loop. This allows us to make the
266*c9157d92SDimitry Andric // comparison check against "RightSCEV + 1" more optimistic.
267*c9157d92SDimitry Andric if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
268*c9157d92SDimitry Andric isKnownNonNegativeInLoop(RightSCEV, &L, SE))
269*c9157d92SDimitry Andric Pred = ICmpInst::ICMP_ULT;
270*c9157d92SDimitry Andric else
271*c9157d92SDimitry Andric Pred = ICmpInst::ICMP_SLT;
272*c9157d92SDimitry Andric else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
273*c9157d92SDimitry Andric // while (true) { while (true) {
274*c9157d92SDimitry Andric // if (++i == len) ---> if (++i > len - 1)
275*c9157d92SDimitry Andric // break; break;
276*c9157d92SDimitry Andric // ... ...
277*c9157d92SDimitry Andric // } }
278*c9157d92SDimitry Andric if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
279*c9157d92SDimitry Andric cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
280*c9157d92SDimitry Andric Pred = ICmpInst::ICMP_UGT;
281*c9157d92SDimitry Andric RightSCEV =
282*c9157d92SDimitry Andric SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
283*c9157d92SDimitry Andric DecreasedRightValueByOne = true;
284*c9157d92SDimitry Andric } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
285*c9157d92SDimitry Andric Pred = ICmpInst::ICMP_SGT;
286*c9157d92SDimitry Andric RightSCEV =
287*c9157d92SDimitry Andric SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
288*c9157d92SDimitry Andric DecreasedRightValueByOne = true;
289*c9157d92SDimitry Andric }
290*c9157d92SDimitry Andric }
291*c9157d92SDimitry Andric }
292*c9157d92SDimitry Andric
293*c9157d92SDimitry Andric bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
294*c9157d92SDimitry Andric bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
295*c9157d92SDimitry Andric bool FoundExpectedPred =
296*c9157d92SDimitry Andric (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
297*c9157d92SDimitry Andric
298*c9157d92SDimitry Andric if (!FoundExpectedPred) {
299*c9157d92SDimitry Andric FailureReason = "expected icmp slt semantically, found something else";
300*c9157d92SDimitry Andric return std::nullopt;
301*c9157d92SDimitry Andric }
302*c9157d92SDimitry Andric
303*c9157d92SDimitry Andric IsSignedPredicate = ICmpInst::isSigned(Pred);
304*c9157d92SDimitry Andric if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
305*c9157d92SDimitry Andric FailureReason = "unsigned latch conditions are explicitly prohibited";
306*c9157d92SDimitry Andric return std::nullopt;
307*c9157d92SDimitry Andric }
308*c9157d92SDimitry Andric
309*c9157d92SDimitry Andric if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
310*c9157d92SDimitry Andric LatchBrExitIdx, &L, SE)) {
311*c9157d92SDimitry Andric FailureReason = "Unsafe loop bounds";
312*c9157d92SDimitry Andric return std::nullopt;
313*c9157d92SDimitry Andric }
314*c9157d92SDimitry Andric if (LatchBrExitIdx == 0) {
315*c9157d92SDimitry Andric // We need to increase the right value unless we have already decreased
316*c9157d92SDimitry Andric // it virtually when we replaced EQ with SGT.
317*c9157d92SDimitry Andric if (!DecreasedRightValueByOne)
318*c9157d92SDimitry Andric FixedRightSCEV =
319*c9157d92SDimitry Andric SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
320*c9157d92SDimitry Andric } else {
321*c9157d92SDimitry Andric assert(!DecreasedRightValueByOne &&
322*c9157d92SDimitry Andric "Right value can be decreased only for LatchBrExitIdx == 0!");
323*c9157d92SDimitry Andric }
324*c9157d92SDimitry Andric } else {
325*c9157d92SDimitry Andric bool IncreasedRightValueByOne = false;
326*c9157d92SDimitry Andric if (StepCI->isMinusOne()) {
327*c9157d92SDimitry Andric // Try to turn eq/ne predicates to those we can work with.
328*c9157d92SDimitry Andric if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
329*c9157d92SDimitry Andric // while (--i != len) { while (--i > len) {
330*c9157d92SDimitry Andric // ... ---> ...
331*c9157d92SDimitry Andric // } }
332*c9157d92SDimitry Andric // We intentionally don't turn the predicate into UGT even if we know
333*c9157d92SDimitry Andric // that both operands are non-negative, because it will only pessimize
334*c9157d92SDimitry Andric // our check against "RightSCEV - 1".
335*c9157d92SDimitry Andric Pred = ICmpInst::ICMP_SGT;
336*c9157d92SDimitry Andric else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
337*c9157d92SDimitry Andric // while (true) { while (true) {
338*c9157d92SDimitry Andric // if (--i == len) ---> if (--i < len + 1)
339*c9157d92SDimitry Andric // break; break;
340*c9157d92SDimitry Andric // ... ...
341*c9157d92SDimitry Andric // } }
342*c9157d92SDimitry Andric if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
343*c9157d92SDimitry Andric cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
344*c9157d92SDimitry Andric Pred = ICmpInst::ICMP_ULT;
345*c9157d92SDimitry Andric RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
346*c9157d92SDimitry Andric IncreasedRightValueByOne = true;
347*c9157d92SDimitry Andric } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
348*c9157d92SDimitry Andric Pred = ICmpInst::ICMP_SLT;
349*c9157d92SDimitry Andric RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
350*c9157d92SDimitry Andric IncreasedRightValueByOne = true;
351*c9157d92SDimitry Andric }
352*c9157d92SDimitry Andric }
353*c9157d92SDimitry Andric }
354*c9157d92SDimitry Andric
355*c9157d92SDimitry Andric bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
356*c9157d92SDimitry Andric bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
357*c9157d92SDimitry Andric
358*c9157d92SDimitry Andric bool FoundExpectedPred =
359*c9157d92SDimitry Andric (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
360*c9157d92SDimitry Andric
361*c9157d92SDimitry Andric if (!FoundExpectedPred) {
362*c9157d92SDimitry Andric FailureReason = "expected icmp sgt semantically, found something else";
363*c9157d92SDimitry Andric return std::nullopt;
364*c9157d92SDimitry Andric }
365*c9157d92SDimitry Andric
366*c9157d92SDimitry Andric IsSignedPredicate =
367*c9157d92SDimitry Andric Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
368*c9157d92SDimitry Andric
369*c9157d92SDimitry Andric if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
370*c9157d92SDimitry Andric FailureReason = "unsigned latch conditions are explicitly prohibited";
371*c9157d92SDimitry Andric return std::nullopt;
372*c9157d92SDimitry Andric }
373*c9157d92SDimitry Andric
374*c9157d92SDimitry Andric if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
375*c9157d92SDimitry Andric LatchBrExitIdx, &L, SE)) {
376*c9157d92SDimitry Andric FailureReason = "Unsafe bounds";
377*c9157d92SDimitry Andric return std::nullopt;
378*c9157d92SDimitry Andric }
379*c9157d92SDimitry Andric
380*c9157d92SDimitry Andric if (LatchBrExitIdx == 0) {
381*c9157d92SDimitry Andric // We need to decrease the right value unless we have already increased
382*c9157d92SDimitry Andric // it virtually when we replaced EQ with SLT.
383*c9157d92SDimitry Andric if (!IncreasedRightValueByOne)
384*c9157d92SDimitry Andric FixedRightSCEV =
385*c9157d92SDimitry Andric SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
386*c9157d92SDimitry Andric } else {
387*c9157d92SDimitry Andric assert(!IncreasedRightValueByOne &&
388*c9157d92SDimitry Andric "Right value can be increased only for LatchBrExitIdx == 0!");
389*c9157d92SDimitry Andric }
390*c9157d92SDimitry Andric }
391*c9157d92SDimitry Andric BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
392*c9157d92SDimitry Andric
393*c9157d92SDimitry Andric assert(!L.contains(LatchExit) && "expected an exit block!");
394*c9157d92SDimitry Andric const DataLayout &DL = Preheader->getModule()->getDataLayout();
395*c9157d92SDimitry Andric SCEVExpander Expander(SE, DL, "loop-constrainer");
396*c9157d92SDimitry Andric Instruction *Ins = Preheader->getTerminator();
397*c9157d92SDimitry Andric
398*c9157d92SDimitry Andric if (FixedRightSCEV)
399*c9157d92SDimitry Andric RightValue =
400*c9157d92SDimitry Andric Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
401*c9157d92SDimitry Andric
402*c9157d92SDimitry Andric Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
403*c9157d92SDimitry Andric IndVarStartV->setName("indvar.start");
404*c9157d92SDimitry Andric
405*c9157d92SDimitry Andric LoopStructure Result;
406*c9157d92SDimitry Andric
407*c9157d92SDimitry Andric Result.Tag = "main";
408*c9157d92SDimitry Andric Result.Header = Header;
409*c9157d92SDimitry Andric Result.Latch = Latch;
410*c9157d92SDimitry Andric Result.LatchBr = LatchBr;
411*c9157d92SDimitry Andric Result.LatchExit = LatchExit;
412*c9157d92SDimitry Andric Result.LatchBrExitIdx = LatchBrExitIdx;
413*c9157d92SDimitry Andric Result.IndVarStart = IndVarStartV;
414*c9157d92SDimitry Andric Result.IndVarStep = StepCI;
415*c9157d92SDimitry Andric Result.IndVarBase = LeftValue;
416*c9157d92SDimitry Andric Result.IndVarIncreasing = IsIncreasing;
417*c9157d92SDimitry Andric Result.LoopExitAt = RightValue;
418*c9157d92SDimitry Andric Result.IsSignedPredicate = IsSignedPredicate;
419*c9157d92SDimitry Andric Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
420*c9157d92SDimitry Andric
421*c9157d92SDimitry Andric FailureReason = nullptr;
422*c9157d92SDimitry Andric
423*c9157d92SDimitry Andric return Result;
424*c9157d92SDimitry Andric }
425*c9157d92SDimitry Andric
426*c9157d92SDimitry Andric // Add metadata to the loop L to disable loop optimizations. Callers need to
427*c9157d92SDimitry Andric // confirm that optimizing loop L is not beneficial.
DisableAllLoopOptsOnLoop(Loop & L)428*c9157d92SDimitry Andric static void DisableAllLoopOptsOnLoop(Loop &L) {
429*c9157d92SDimitry Andric // We do not care about any existing loopID related metadata for L, since we
430*c9157d92SDimitry Andric // are setting all loop metadata to false.
431*c9157d92SDimitry Andric LLVMContext &Context = L.getHeader()->getContext();
432*c9157d92SDimitry Andric // Reserve first location for self reference to the LoopID metadata node.
433*c9157d92SDimitry Andric MDNode *Dummy = MDNode::get(Context, {});
434*c9157d92SDimitry Andric MDNode *DisableUnroll = MDNode::get(
435*c9157d92SDimitry Andric Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
436*c9157d92SDimitry Andric Metadata *FalseVal =
437*c9157d92SDimitry Andric ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
438*c9157d92SDimitry Andric MDNode *DisableVectorize = MDNode::get(
439*c9157d92SDimitry Andric Context,
440*c9157d92SDimitry Andric {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
441*c9157d92SDimitry Andric MDNode *DisableLICMVersioning = MDNode::get(
442*c9157d92SDimitry Andric Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
443*c9157d92SDimitry Andric MDNode *DisableDistribution = MDNode::get(
444*c9157d92SDimitry Andric Context,
445*c9157d92SDimitry Andric {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
446*c9157d92SDimitry Andric MDNode *NewLoopID =
447*c9157d92SDimitry Andric MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
448*c9157d92SDimitry Andric DisableLICMVersioning, DisableDistribution});
449*c9157d92SDimitry Andric // Set operand 0 to refer to the loop id itself.
450*c9157d92SDimitry Andric NewLoopID->replaceOperandWith(0, NewLoopID);
451*c9157d92SDimitry Andric L.setLoopID(NewLoopID);
452*c9157d92SDimitry Andric }
453*c9157d92SDimitry Andric
LoopConstrainer(Loop & L,LoopInfo & LI,function_ref<void (Loop *,bool)> LPMAddNewLoop,const LoopStructure & LS,ScalarEvolution & SE,DominatorTree & DT,Type * T,SubRanges SR)454*c9157d92SDimitry Andric LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
455*c9157d92SDimitry Andric function_ref<void(Loop *, bool)> LPMAddNewLoop,
456*c9157d92SDimitry Andric const LoopStructure &LS, ScalarEvolution &SE,
457*c9157d92SDimitry Andric DominatorTree &DT, Type *T, SubRanges SR)
458*c9157d92SDimitry Andric : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
459*c9157d92SDimitry Andric DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
460*c9157d92SDimitry Andric MainLoopStructure(LS), SR(SR) {}
461*c9157d92SDimitry Andric
cloneLoop(LoopConstrainer::ClonedLoop & Result,const char * Tag) const462*c9157d92SDimitry Andric void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
463*c9157d92SDimitry Andric const char *Tag) const {
464*c9157d92SDimitry Andric for (BasicBlock *BB : OriginalLoop.getBlocks()) {
465*c9157d92SDimitry Andric BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
466*c9157d92SDimitry Andric Result.Blocks.push_back(Clone);
467*c9157d92SDimitry Andric Result.Map[BB] = Clone;
468*c9157d92SDimitry Andric }
469*c9157d92SDimitry Andric
470*c9157d92SDimitry Andric auto GetClonedValue = [&Result](Value *V) {
471*c9157d92SDimitry Andric assert(V && "null values not in domain!");
472*c9157d92SDimitry Andric auto It = Result.Map.find(V);
473*c9157d92SDimitry Andric if (It == Result.Map.end())
474*c9157d92SDimitry Andric return V;
475*c9157d92SDimitry Andric return static_cast<Value *>(It->second);
476*c9157d92SDimitry Andric };
477*c9157d92SDimitry Andric
478*c9157d92SDimitry Andric auto *ClonedLatch =
479*c9157d92SDimitry Andric cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
480*c9157d92SDimitry Andric ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
481*c9157d92SDimitry Andric MDNode::get(Ctx, {}));
482*c9157d92SDimitry Andric
483*c9157d92SDimitry Andric Result.Structure = MainLoopStructure.map(GetClonedValue);
484*c9157d92SDimitry Andric Result.Structure.Tag = Tag;
485*c9157d92SDimitry Andric
486*c9157d92SDimitry Andric for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
487*c9157d92SDimitry Andric BasicBlock *ClonedBB = Result.Blocks[i];
488*c9157d92SDimitry Andric BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
489*c9157d92SDimitry Andric
490*c9157d92SDimitry Andric assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
491*c9157d92SDimitry Andric
492*c9157d92SDimitry Andric for (Instruction &I : *ClonedBB)
493*c9157d92SDimitry Andric RemapInstruction(&I, Result.Map,
494*c9157d92SDimitry Andric RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
495*c9157d92SDimitry Andric
496*c9157d92SDimitry Andric // Exit blocks will now have one more predecessor and their PHI nodes need
497*c9157d92SDimitry Andric // to be edited to reflect that. No phi nodes need to be introduced because
498*c9157d92SDimitry Andric // the loop is in LCSSA.
499*c9157d92SDimitry Andric
500*c9157d92SDimitry Andric for (auto *SBB : successors(OriginalBB)) {
501*c9157d92SDimitry Andric if (OriginalLoop.contains(SBB))
502*c9157d92SDimitry Andric continue; // not an exit block
503*c9157d92SDimitry Andric
504*c9157d92SDimitry Andric for (PHINode &PN : SBB->phis()) {
505*c9157d92SDimitry Andric Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
506*c9157d92SDimitry Andric PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
507*c9157d92SDimitry Andric SE.forgetValue(&PN);
508*c9157d92SDimitry Andric }
509*c9157d92SDimitry Andric }
510*c9157d92SDimitry Andric }
511*c9157d92SDimitry Andric }
512*c9157d92SDimitry Andric
changeIterationSpaceEnd(const LoopStructure & LS,BasicBlock * Preheader,Value * ExitSubloopAt,BasicBlock * ContinuationBlock) const513*c9157d92SDimitry Andric LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
514*c9157d92SDimitry Andric const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
515*c9157d92SDimitry Andric BasicBlock *ContinuationBlock) const {
516*c9157d92SDimitry Andric // We start with a loop with a single latch:
517*c9157d92SDimitry Andric //
518*c9157d92SDimitry Andric // +--------------------+
519*c9157d92SDimitry Andric // | |
520*c9157d92SDimitry Andric // | preheader |
521*c9157d92SDimitry Andric // | |
522*c9157d92SDimitry Andric // +--------+-----------+
523*c9157d92SDimitry Andric // | ----------------\
524*c9157d92SDimitry Andric // | / |
525*c9157d92SDimitry Andric // +--------v----v------+ |
526*c9157d92SDimitry Andric // | | |
527*c9157d92SDimitry Andric // | header | |
528*c9157d92SDimitry Andric // | | |
529*c9157d92SDimitry Andric // +--------------------+ |
530*c9157d92SDimitry Andric // |
531*c9157d92SDimitry Andric // ..... |
532*c9157d92SDimitry Andric // |
533*c9157d92SDimitry Andric // +--------------------+ |
534*c9157d92SDimitry Andric // | | |
535*c9157d92SDimitry Andric // | latch >----------/
536*c9157d92SDimitry Andric // | |
537*c9157d92SDimitry Andric // +-------v------------+
538*c9157d92SDimitry Andric // |
539*c9157d92SDimitry Andric // |
540*c9157d92SDimitry Andric // | +--------------------+
541*c9157d92SDimitry Andric // | | |
542*c9157d92SDimitry Andric // +---> original exit |
543*c9157d92SDimitry Andric // | |
544*c9157d92SDimitry Andric // +--------------------+
545*c9157d92SDimitry Andric //
546*c9157d92SDimitry Andric // We change the control flow to look like
547*c9157d92SDimitry Andric //
548*c9157d92SDimitry Andric //
549*c9157d92SDimitry Andric // +--------------------+
550*c9157d92SDimitry Andric // | |
551*c9157d92SDimitry Andric // | preheader >-------------------------+
552*c9157d92SDimitry Andric // | | |
553*c9157d92SDimitry Andric // +--------v-----------+ |
554*c9157d92SDimitry Andric // | /-------------+ |
555*c9157d92SDimitry Andric // | / | |
556*c9157d92SDimitry Andric // +--------v--v--------+ | |
557*c9157d92SDimitry Andric // | | | |
558*c9157d92SDimitry Andric // | header | | +--------+ |
559*c9157d92SDimitry Andric // | | | | | |
560*c9157d92SDimitry Andric // +--------------------+ | | +-----v-----v-----------+
561*c9157d92SDimitry Andric // | | | |
562*c9157d92SDimitry Andric // | | | .pseudo.exit |
563*c9157d92SDimitry Andric // | | | |
564*c9157d92SDimitry Andric // | | +-----------v-----------+
565*c9157d92SDimitry Andric // | | |
566*c9157d92SDimitry Andric // ..... | | |
567*c9157d92SDimitry Andric // | | +--------v-------------+
568*c9157d92SDimitry Andric // +--------------------+ | | | |
569*c9157d92SDimitry Andric // | | | | | ContinuationBlock |
570*c9157d92SDimitry Andric // | latch >------+ | | |
571*c9157d92SDimitry Andric // | | | +----------------------+
572*c9157d92SDimitry Andric // +---------v----------+ |
573*c9157d92SDimitry Andric // | |
574*c9157d92SDimitry Andric // | |
575*c9157d92SDimitry Andric // | +---------------^-----+
576*c9157d92SDimitry Andric // | | |
577*c9157d92SDimitry Andric // +-----> .exit.selector |
578*c9157d92SDimitry Andric // | |
579*c9157d92SDimitry Andric // +----------v----------+
580*c9157d92SDimitry Andric // |
581*c9157d92SDimitry Andric // +--------------------+ |
582*c9157d92SDimitry Andric // | | |
583*c9157d92SDimitry Andric // | original exit <----+
584*c9157d92SDimitry Andric // | |
585*c9157d92SDimitry Andric // +--------------------+
586*c9157d92SDimitry Andric
587*c9157d92SDimitry Andric RewrittenRangeInfo RRI;
588*c9157d92SDimitry Andric
589*c9157d92SDimitry Andric BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
590*c9157d92SDimitry Andric RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
591*c9157d92SDimitry Andric &F, BBInsertLocation);
592*c9157d92SDimitry Andric RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
593*c9157d92SDimitry Andric BBInsertLocation);
594*c9157d92SDimitry Andric
595*c9157d92SDimitry Andric BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
596*c9157d92SDimitry Andric bool Increasing = LS.IndVarIncreasing;
597*c9157d92SDimitry Andric bool IsSignedPredicate = LS.IsSignedPredicate;
598*c9157d92SDimitry Andric
599*c9157d92SDimitry Andric IRBuilder<> B(PreheaderJump);
600*c9157d92SDimitry Andric auto NoopOrExt = [&](Value *V) {
601*c9157d92SDimitry Andric if (V->getType() == RangeTy)
602*c9157d92SDimitry Andric return V;
603*c9157d92SDimitry Andric return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
604*c9157d92SDimitry Andric : B.CreateZExt(V, RangeTy, "wide." + V->getName());
605*c9157d92SDimitry Andric };
606*c9157d92SDimitry Andric
607*c9157d92SDimitry Andric // EnterLoopCond - is it okay to start executing this `LS'?
608*c9157d92SDimitry Andric Value *EnterLoopCond = nullptr;
609*c9157d92SDimitry Andric auto Pred =
610*c9157d92SDimitry Andric Increasing
611*c9157d92SDimitry Andric ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
612*c9157d92SDimitry Andric : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
613*c9157d92SDimitry Andric Value *IndVarStart = NoopOrExt(LS.IndVarStart);
614*c9157d92SDimitry Andric EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
615*c9157d92SDimitry Andric
616*c9157d92SDimitry Andric B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
617*c9157d92SDimitry Andric PreheaderJump->eraseFromParent();
618*c9157d92SDimitry Andric
619*c9157d92SDimitry Andric LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
620*c9157d92SDimitry Andric B.SetInsertPoint(LS.LatchBr);
621*c9157d92SDimitry Andric Value *IndVarBase = NoopOrExt(LS.IndVarBase);
622*c9157d92SDimitry Andric Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
623*c9157d92SDimitry Andric
624*c9157d92SDimitry Andric Value *CondForBranch = LS.LatchBrExitIdx == 1
625*c9157d92SDimitry Andric ? TakeBackedgeLoopCond
626*c9157d92SDimitry Andric : B.CreateNot(TakeBackedgeLoopCond);
627*c9157d92SDimitry Andric
628*c9157d92SDimitry Andric LS.LatchBr->setCondition(CondForBranch);
629*c9157d92SDimitry Andric
630*c9157d92SDimitry Andric B.SetInsertPoint(RRI.ExitSelector);
631*c9157d92SDimitry Andric
632*c9157d92SDimitry Andric // IterationsLeft - are there any more iterations left, given the original
633*c9157d92SDimitry Andric // upper bound on the induction variable? If not, we branch to the "real"
634*c9157d92SDimitry Andric // exit.
635*c9157d92SDimitry Andric Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
636*c9157d92SDimitry Andric Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
637*c9157d92SDimitry Andric B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
638*c9157d92SDimitry Andric
639*c9157d92SDimitry Andric BranchInst *BranchToContinuation =
640*c9157d92SDimitry Andric BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
641*c9157d92SDimitry Andric
642*c9157d92SDimitry Andric // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
643*c9157d92SDimitry Andric // each of the PHI nodes in the loop header. This feeds into the initial
644*c9157d92SDimitry Andric // value of the same PHI nodes if/when we continue execution.
645*c9157d92SDimitry Andric for (PHINode &PN : LS.Header->phis()) {
646*c9157d92SDimitry Andric PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
647*c9157d92SDimitry Andric BranchToContinuation);
648*c9157d92SDimitry Andric
649*c9157d92SDimitry Andric NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
650*c9157d92SDimitry Andric NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
651*c9157d92SDimitry Andric RRI.ExitSelector);
652*c9157d92SDimitry Andric RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
653*c9157d92SDimitry Andric }
654*c9157d92SDimitry Andric
655*c9157d92SDimitry Andric RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
656*c9157d92SDimitry Andric BranchToContinuation);
657*c9157d92SDimitry Andric RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
658*c9157d92SDimitry Andric RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
659*c9157d92SDimitry Andric
660*c9157d92SDimitry Andric // The latch exit now has a branch from `RRI.ExitSelector' instead of
661*c9157d92SDimitry Andric // `LS.Latch'. The PHI nodes need to be updated to reflect that.
662*c9157d92SDimitry Andric LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
663*c9157d92SDimitry Andric
664*c9157d92SDimitry Andric return RRI;
665*c9157d92SDimitry Andric }
666*c9157d92SDimitry Andric
rewriteIncomingValuesForPHIs(LoopStructure & LS,BasicBlock * ContinuationBlock,const LoopConstrainer::RewrittenRangeInfo & RRI) const667*c9157d92SDimitry Andric void LoopConstrainer::rewriteIncomingValuesForPHIs(
668*c9157d92SDimitry Andric LoopStructure &LS, BasicBlock *ContinuationBlock,
669*c9157d92SDimitry Andric const LoopConstrainer::RewrittenRangeInfo &RRI) const {
670*c9157d92SDimitry Andric unsigned PHIIndex = 0;
671*c9157d92SDimitry Andric for (PHINode &PN : LS.Header->phis())
672*c9157d92SDimitry Andric PN.setIncomingValueForBlock(ContinuationBlock,
673*c9157d92SDimitry Andric RRI.PHIValuesAtPseudoExit[PHIIndex++]);
674*c9157d92SDimitry Andric
675*c9157d92SDimitry Andric LS.IndVarStart = RRI.IndVarEnd;
676*c9157d92SDimitry Andric }
677*c9157d92SDimitry Andric
createPreheader(const LoopStructure & LS,BasicBlock * OldPreheader,const char * Tag) const678*c9157d92SDimitry Andric BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
679*c9157d92SDimitry Andric BasicBlock *OldPreheader,
680*c9157d92SDimitry Andric const char *Tag) const {
681*c9157d92SDimitry Andric BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
682*c9157d92SDimitry Andric BranchInst::Create(LS.Header, Preheader);
683*c9157d92SDimitry Andric
684*c9157d92SDimitry Andric LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
685*c9157d92SDimitry Andric
686*c9157d92SDimitry Andric return Preheader;
687*c9157d92SDimitry Andric }
688*c9157d92SDimitry Andric
addToParentLoopIfNeeded(ArrayRef<BasicBlock * > BBs)689*c9157d92SDimitry Andric void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
690*c9157d92SDimitry Andric Loop *ParentLoop = OriginalLoop.getParentLoop();
691*c9157d92SDimitry Andric if (!ParentLoop)
692*c9157d92SDimitry Andric return;
693*c9157d92SDimitry Andric
694*c9157d92SDimitry Andric for (BasicBlock *BB : BBs)
695*c9157d92SDimitry Andric ParentLoop->addBasicBlockToLoop(BB, LI);
696*c9157d92SDimitry Andric }
697*c9157d92SDimitry Andric
createClonedLoopStructure(Loop * Original,Loop * Parent,ValueToValueMapTy & VM,bool IsSubloop)698*c9157d92SDimitry Andric Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
699*c9157d92SDimitry Andric ValueToValueMapTy &VM,
700*c9157d92SDimitry Andric bool IsSubloop) {
701*c9157d92SDimitry Andric Loop &New = *LI.AllocateLoop();
702*c9157d92SDimitry Andric if (Parent)
703*c9157d92SDimitry Andric Parent->addChildLoop(&New);
704*c9157d92SDimitry Andric else
705*c9157d92SDimitry Andric LI.addTopLevelLoop(&New);
706*c9157d92SDimitry Andric LPMAddNewLoop(&New, IsSubloop);
707*c9157d92SDimitry Andric
708*c9157d92SDimitry Andric // Add all of the blocks in Original to the new loop.
709*c9157d92SDimitry Andric for (auto *BB : Original->blocks())
710*c9157d92SDimitry Andric if (LI.getLoopFor(BB) == Original)
711*c9157d92SDimitry Andric New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
712*c9157d92SDimitry Andric
713*c9157d92SDimitry Andric // Add all of the subloops to the new loop.
714*c9157d92SDimitry Andric for (Loop *SubLoop : *Original)
715*c9157d92SDimitry Andric createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
716*c9157d92SDimitry Andric
717*c9157d92SDimitry Andric return &New;
718*c9157d92SDimitry Andric }
719*c9157d92SDimitry Andric
run()720*c9157d92SDimitry Andric bool LoopConstrainer::run() {
721*c9157d92SDimitry Andric BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
722*c9157d92SDimitry Andric assert(Preheader != nullptr && "precondition!");
723*c9157d92SDimitry Andric
724*c9157d92SDimitry Andric OriginalPreheader = Preheader;
725*c9157d92SDimitry Andric MainLoopPreheader = Preheader;
726*c9157d92SDimitry Andric bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
727*c9157d92SDimitry Andric bool Increasing = MainLoopStructure.IndVarIncreasing;
728*c9157d92SDimitry Andric IntegerType *IVTy = cast<IntegerType>(RangeTy);
729*c9157d92SDimitry Andric
730*c9157d92SDimitry Andric SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer");
731*c9157d92SDimitry Andric Instruction *InsertPt = OriginalPreheader->getTerminator();
732*c9157d92SDimitry Andric
733*c9157d92SDimitry Andric // It would have been better to make `PreLoop' and `PostLoop'
734*c9157d92SDimitry Andric // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
735*c9157d92SDimitry Andric // constructor.
736*c9157d92SDimitry Andric ClonedLoop PreLoop, PostLoop;
737*c9157d92SDimitry Andric bool NeedsPreLoop =
738*c9157d92SDimitry Andric Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
739*c9157d92SDimitry Andric bool NeedsPostLoop =
740*c9157d92SDimitry Andric Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
741*c9157d92SDimitry Andric
742*c9157d92SDimitry Andric Value *ExitPreLoopAt = nullptr;
743*c9157d92SDimitry Andric Value *ExitMainLoopAt = nullptr;
744*c9157d92SDimitry Andric const SCEVConstant *MinusOneS =
745*c9157d92SDimitry Andric cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
746*c9157d92SDimitry Andric
747*c9157d92SDimitry Andric if (NeedsPreLoop) {
748*c9157d92SDimitry Andric const SCEV *ExitPreLoopAtSCEV = nullptr;
749*c9157d92SDimitry Andric
750*c9157d92SDimitry Andric if (Increasing)
751*c9157d92SDimitry Andric ExitPreLoopAtSCEV = *SR.LowLimit;
752*c9157d92SDimitry Andric else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
753*c9157d92SDimitry Andric IsSignedPredicate))
754*c9157d92SDimitry Andric ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
755*c9157d92SDimitry Andric else {
756*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
757*c9157d92SDimitry Andric << "preloop exit limit. HighLimit = "
758*c9157d92SDimitry Andric << *(*SR.HighLimit) << "\n");
759*c9157d92SDimitry Andric return false;
760*c9157d92SDimitry Andric }
761*c9157d92SDimitry Andric
762*c9157d92SDimitry Andric if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
763*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
764*c9157d92SDimitry Andric << " preloop exit limit " << *ExitPreLoopAtSCEV
765*c9157d92SDimitry Andric << " at block " << InsertPt->getParent()->getName()
766*c9157d92SDimitry Andric << "\n");
767*c9157d92SDimitry Andric return false;
768*c9157d92SDimitry Andric }
769*c9157d92SDimitry Andric
770*c9157d92SDimitry Andric ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
771*c9157d92SDimitry Andric ExitPreLoopAt->setName("exit.preloop.at");
772*c9157d92SDimitry Andric }
773*c9157d92SDimitry Andric
774*c9157d92SDimitry Andric if (NeedsPostLoop) {
775*c9157d92SDimitry Andric const SCEV *ExitMainLoopAtSCEV = nullptr;
776*c9157d92SDimitry Andric
777*c9157d92SDimitry Andric if (Increasing)
778*c9157d92SDimitry Andric ExitMainLoopAtSCEV = *SR.HighLimit;
779*c9157d92SDimitry Andric else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
780*c9157d92SDimitry Andric IsSignedPredicate))
781*c9157d92SDimitry Andric ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
782*c9157d92SDimitry Andric else {
783*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
784*c9157d92SDimitry Andric << "mainloop exit limit. LowLimit = "
785*c9157d92SDimitry Andric << *(*SR.LowLimit) << "\n");
786*c9157d92SDimitry Andric return false;
787*c9157d92SDimitry Andric }
788*c9157d92SDimitry Andric
789*c9157d92SDimitry Andric if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
790*c9157d92SDimitry Andric LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
791*c9157d92SDimitry Andric << " main loop exit limit " << *ExitMainLoopAtSCEV
792*c9157d92SDimitry Andric << " at block " << InsertPt->getParent()->getName()
793*c9157d92SDimitry Andric << "\n");
794*c9157d92SDimitry Andric return false;
795*c9157d92SDimitry Andric }
796*c9157d92SDimitry Andric
797*c9157d92SDimitry Andric ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
798*c9157d92SDimitry Andric ExitMainLoopAt->setName("exit.mainloop.at");
799*c9157d92SDimitry Andric }
800*c9157d92SDimitry Andric
801*c9157d92SDimitry Andric // We clone these ahead of time so that we don't have to deal with changing
802*c9157d92SDimitry Andric // and temporarily invalid IR as we transform the loops.
803*c9157d92SDimitry Andric if (NeedsPreLoop)
804*c9157d92SDimitry Andric cloneLoop(PreLoop, "preloop");
805*c9157d92SDimitry Andric if (NeedsPostLoop)
806*c9157d92SDimitry Andric cloneLoop(PostLoop, "postloop");
807*c9157d92SDimitry Andric
808*c9157d92SDimitry Andric RewrittenRangeInfo PreLoopRRI;
809*c9157d92SDimitry Andric
810*c9157d92SDimitry Andric if (NeedsPreLoop) {
811*c9157d92SDimitry Andric Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
812*c9157d92SDimitry Andric PreLoop.Structure.Header);
813*c9157d92SDimitry Andric
814*c9157d92SDimitry Andric MainLoopPreheader =
815*c9157d92SDimitry Andric createPreheader(MainLoopStructure, Preheader, "mainloop");
816*c9157d92SDimitry Andric PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
817*c9157d92SDimitry Andric ExitPreLoopAt, MainLoopPreheader);
818*c9157d92SDimitry Andric rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
819*c9157d92SDimitry Andric PreLoopRRI);
820*c9157d92SDimitry Andric }
821*c9157d92SDimitry Andric
822*c9157d92SDimitry Andric BasicBlock *PostLoopPreheader = nullptr;
823*c9157d92SDimitry Andric RewrittenRangeInfo PostLoopRRI;
824*c9157d92SDimitry Andric
825*c9157d92SDimitry Andric if (NeedsPostLoop) {
826*c9157d92SDimitry Andric PostLoopPreheader =
827*c9157d92SDimitry Andric createPreheader(PostLoop.Structure, Preheader, "postloop");
828*c9157d92SDimitry Andric PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
829*c9157d92SDimitry Andric ExitMainLoopAt, PostLoopPreheader);
830*c9157d92SDimitry Andric rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
831*c9157d92SDimitry Andric PostLoopRRI);
832*c9157d92SDimitry Andric }
833*c9157d92SDimitry Andric
834*c9157d92SDimitry Andric BasicBlock *NewMainLoopPreheader =
835*c9157d92SDimitry Andric MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
836*c9157d92SDimitry Andric BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit,
837*c9157d92SDimitry Andric PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit,
838*c9157d92SDimitry Andric PostLoopRRI.ExitSelector, NewMainLoopPreheader};
839*c9157d92SDimitry Andric
840*c9157d92SDimitry Andric // Some of the above may be nullptr, filter them out before passing to
841*c9157d92SDimitry Andric // addToParentLoopIfNeeded.
842*c9157d92SDimitry Andric auto NewBlocksEnd =
843*c9157d92SDimitry Andric std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
844*c9157d92SDimitry Andric
845*c9157d92SDimitry Andric addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
846*c9157d92SDimitry Andric
847*c9157d92SDimitry Andric DT.recalculate(F);
848*c9157d92SDimitry Andric
849*c9157d92SDimitry Andric // We need to first add all the pre and post loop blocks into the loop
850*c9157d92SDimitry Andric // structures (as part of createClonedLoopStructure), and then update the
851*c9157d92SDimitry Andric // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
852*c9157d92SDimitry Andric // LI when LoopSimplifyForm is generated.
853*c9157d92SDimitry Andric Loop *PreL = nullptr, *PostL = nullptr;
854*c9157d92SDimitry Andric if (!PreLoop.Blocks.empty()) {
855*c9157d92SDimitry Andric PreL = createClonedLoopStructure(&OriginalLoop,
856*c9157d92SDimitry Andric OriginalLoop.getParentLoop(), PreLoop.Map,
857*c9157d92SDimitry Andric /* IsSubLoop */ false);
858*c9157d92SDimitry Andric }
859*c9157d92SDimitry Andric
860*c9157d92SDimitry Andric if (!PostLoop.Blocks.empty()) {
861*c9157d92SDimitry Andric PostL =
862*c9157d92SDimitry Andric createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
863*c9157d92SDimitry Andric PostLoop.Map, /* IsSubLoop */ false);
864*c9157d92SDimitry Andric }
865*c9157d92SDimitry Andric
866*c9157d92SDimitry Andric // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
867*c9157d92SDimitry Andric auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
868*c9157d92SDimitry Andric formLCSSARecursively(*L, DT, &LI, &SE);
869*c9157d92SDimitry Andric simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
870*c9157d92SDimitry Andric // Pre/post loops are slow paths, we do not need to perform any loop
871*c9157d92SDimitry Andric // optimizations on them.
872*c9157d92SDimitry Andric if (!IsOriginalLoop)
873*c9157d92SDimitry Andric DisableAllLoopOptsOnLoop(*L);
874*c9157d92SDimitry Andric };
875*c9157d92SDimitry Andric if (PreL)
876*c9157d92SDimitry Andric CanonicalizeLoop(PreL, false);
877*c9157d92SDimitry Andric if (PostL)
878*c9157d92SDimitry Andric CanonicalizeLoop(PostL, false);
879*c9157d92SDimitry Andric CanonicalizeLoop(&OriginalLoop, true);
880*c9157d92SDimitry Andric
881*c9157d92SDimitry Andric /// At this point:
882*c9157d92SDimitry Andric /// - We've broken a "main loop" out of the loop in a way that the "main loop"
883*c9157d92SDimitry Andric /// runs with the induction variable in a subset of [Begin, End).
884*c9157d92SDimitry Andric /// - There is no overflow when computing "main loop" exit limit.
885*c9157d92SDimitry Andric /// - Max latch taken count of the loop is limited.
886*c9157d92SDimitry Andric /// It guarantees that induction variable will not overflow iterating in the
887*c9157d92SDimitry Andric /// "main loop".
888*c9157d92SDimitry Andric if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
889*c9157d92SDimitry Andric if (IsSignedPredicate)
890*c9157d92SDimitry Andric cast<BinaryOperator>(MainLoopStructure.IndVarBase)
891*c9157d92SDimitry Andric ->setHasNoSignedWrap(true);
892*c9157d92SDimitry Andric /// TODO: support unsigned predicate.
893*c9157d92SDimitry Andric /// To add NUW flag we need to prove that both operands of BO are
894*c9157d92SDimitry Andric /// non-negative. E.g:
895*c9157d92SDimitry Andric /// ...
896*c9157d92SDimitry Andric /// %iv.next = add nsw i32 %iv, -1
897*c9157d92SDimitry Andric /// %cmp = icmp ult i32 %iv.next, %n
898*c9157d92SDimitry Andric /// br i1 %cmp, label %loopexit, label %loop
899*c9157d92SDimitry Andric ///
900*c9157d92SDimitry Andric /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
901*c9157d92SDimitry Andric /// overflow, therefore NUW flag is not legal here.
902*c9157d92SDimitry Andric
903*c9157d92SDimitry Andric return true;
904*c9157d92SDimitry Andric }
905