1 //===-- InductiveRangeCheckElimination.cpp - ------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
// The InductiveRangeCheckElimination pass splits a loop's iteration space into
// three disjoint ranges.  It does so in such a way that the loop running over
// the middle range provably does not need range checks.  As an example, it
// will convert
13 //
14 //   len = < known positive >
15 //   for (i = 0; i < n; i++) {
16 //     if (0 <= i && i < len) {
17 //       do_something();
18 //     } else {
19 //       throw_out_of_bounds();
20 //     }
21 //   }
22 //
23 // to
24 //
25 //   len = < known positive >
26 //   limit = smin(n, len)
27 //   // no first segment
28 //   for (i = 0; i < limit; i++) {
29 //     if (0 <= i && i < len) { // this check is fully redundant
30 //       do_something();
31 //     } else {
32 //       throw_out_of_bounds();
33 //     }
34 //   }
35 //   for (i = limit; i < n; i++) {
36 //     if (0 <= i && i < len) {
37 //       do_something();
38 //     } else {
39 //       throw_out_of_bounds();
40 //     }
41 //   }
42 //===----------------------------------------------------------------------===//
43 
44 #include "llvm/ADT/Optional.h"
45 #include "llvm/Analysis/BranchProbabilityInfo.h"
46 #include "llvm/Analysis/InstructionSimplify.h"
47 #include "llvm/Analysis/LoopInfo.h"
48 #include "llvm/Analysis/LoopPass.h"
49 #include "llvm/Analysis/ScalarEvolution.h"
50 #include "llvm/Analysis/ScalarEvolutionExpander.h"
51 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
52 #include "llvm/Analysis/ValueTracking.h"
53 #include "llvm/IR/Dominators.h"
54 #include "llvm/IR/Function.h"
55 #include "llvm/IR/IRBuilder.h"
56 #include "llvm/IR/Instructions.h"
57 #include "llvm/IR/Module.h"
58 #include "llvm/IR/PatternMatch.h"
59 #include "llvm/IR/ValueHandle.h"
60 #include "llvm/IR/Verifier.h"
61 #include "llvm/Pass.h"
62 #include "llvm/Support/Debug.h"
63 #include "llvm/Support/raw_ostream.h"
64 #include "llvm/Transforms/Scalar.h"
65 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
66 #include "llvm/Transforms/Utils/Cloning.h"
67 #include "llvm/Transforms/Utils/LoopUtils.h"
68 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
69 #include "llvm/Transforms/Utils/UnrollLoop.h"
70 #include <array>
71 
72 using namespace llvm;
73 
74 static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
75                                         cl::init(64));
76 
77 static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
78                                        cl::init(false));
79 
80 static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
81                                       cl::init(false));
82 
83 static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal",
84                                           cl::Hidden, cl::init(10));
85 
86 #define DEBUG_TYPE "irce"
87 
88 namespace {
89 
90 /// An inductive range check is conditional branch in a loop with
91 ///
92 ///  1. a very cold successor (i.e. the branch jumps to that successor very
93 ///     rarely)
94 ///
95 ///  and
96 ///
97 ///  2. a condition that is provably true for some contiguous range of values
98 ///     taken by the containing loop's induction variable.
99 ///
100 class InductiveRangeCheck {
101   // Classifies a range check
102   enum RangeCheckKind : unsigned {
103     // Range check of the form "0 <= I".
104     RANGE_CHECK_LOWER = 1,
105 
106     // Range check of the form "I < L" where L is known positive.
107     RANGE_CHECK_UPPER = 2,
108 
109     // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER
110     // conditions.
111     RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER,
112 
113     // Unrecognized range check condition.
114     RANGE_CHECK_UNKNOWN = (unsigned)-1
115   };
116 
117   static const char *rangeCheckKindToStr(RangeCheckKind);
118 
119   const SCEV *Offset;
120   const SCEV *Scale;
121   Value *Length;
122   BranchInst *Branch;
123   RangeCheckKind Kind;
124 
125   static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
126                                             ScalarEvolution &SE, Value *&Index,
127                                             Value *&Length);
128 
129   static InductiveRangeCheck::RangeCheckKind
130   parseRangeCheck(Loop *L, ScalarEvolution &SE, Value *Condition,
131                   const SCEV *&Index, Value *&UpperLimit);
132 
133   InductiveRangeCheck() :
134     Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { }
135 
136 public:
137   const SCEV *getOffset() const { return Offset; }
138   const SCEV *getScale() const { return Scale; }
139   Value *getLength() const { return Length; }
140 
141   void print(raw_ostream &OS) const {
142     OS << "InductiveRangeCheck:\n";
143     OS << "  Kind: " << rangeCheckKindToStr(Kind) << "\n";
144     OS << "  Offset: ";
145     Offset->print(OS);
146     OS << "  Scale: ";
147     Scale->print(OS);
148     OS << "  Length: ";
149     if (Length)
150       Length->print(OS);
151     else
152       OS << "(null)";
153     OS << "\n  Branch: ";
154     getBranch()->print(OS);
155     OS << "\n";
156   }
157 
158 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
159   void dump() {
160     print(dbgs());
161   }
162 #endif
163 
164   BranchInst *getBranch() const { return Branch; }
165 
166   /// Represents an signed integer range [Range.getBegin(), Range.getEnd()).  If
167   /// R.getEnd() sle R.getBegin(), then R denotes the empty range.
168 
169   class Range {
170     const SCEV *Begin;
171     const SCEV *End;
172 
173   public:
174     Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
175       assert(Begin->getType() == End->getType() && "ill-typed range!");
176     }
177 
178     Type *getType() const { return Begin->getType(); }
179     const SCEV *getBegin() const { return Begin; }
180     const SCEV *getEnd() const { return End; }
181   };
182 
183   typedef SpecificBumpPtrAllocator<InductiveRangeCheck> AllocatorTy;
184 
185   /// This is the value the condition of the branch needs to evaluate to for the
186   /// branch to take the hot successor (see (1) above).
187   bool getPassingDirection() { return true; }
188 
189   /// Computes a range for the induction variable (IndVar) in which the range
190   /// check is redundant and can be constant-folded away.  The induction
191   /// variable is not required to be the canonical {0,+,1} induction variable.
192   Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
193                                             const SCEVAddRecExpr *IndVar,
194                                             IRBuilder<> &B) const;
195 
196   /// Create an inductive range check out of BI if possible, else return
197   /// nullptr.
198   static InductiveRangeCheck *create(AllocatorTy &Alloc, BranchInst *BI,
199                                      Loop *L, ScalarEvolution &SE,
200                                      BranchProbabilityInfo &BPI);
201 };
202 
203 class InductiveRangeCheckElimination : public LoopPass {
204   InductiveRangeCheck::AllocatorTy Allocator;
205 
206 public:
207   static char ID;
208   InductiveRangeCheckElimination() : LoopPass(ID) {
209     initializeInductiveRangeCheckEliminationPass(
210         *PassRegistry::getPassRegistry());
211   }
212 
213   void getAnalysisUsage(AnalysisUsage &AU) const override {
214     AU.addRequired<BranchProbabilityInfoWrapperPass>();
215     getLoopAnalysisUsage(AU);
216   }
217 
218   bool runOnLoop(Loop *L, LPPassManager &LPM) override;
219 };
220 
221 char InductiveRangeCheckElimination::ID = 0;
222 }
223 
// Register InductiveRangeCheckElimination with the legacy pass registry under
// the command-line name "irce", declaring BranchProbabilityInfoWrapperPass and
// LoopPass as dependencies that must be initialized first.
INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce",
                      "Inductive range check elimination", false, false)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_END(InductiveRangeCheckElimination, "irce",
                    "Inductive range check elimination", false, false)
230 
231 const char *InductiveRangeCheck::rangeCheckKindToStr(
232     InductiveRangeCheck::RangeCheckKind RCK) {
233   switch (RCK) {
234   case InductiveRangeCheck::RANGE_CHECK_UNKNOWN:
235     return "RANGE_CHECK_UNKNOWN";
236 
237   case InductiveRangeCheck::RANGE_CHECK_UPPER:
238     return "RANGE_CHECK_UPPER";
239 
240   case InductiveRangeCheck::RANGE_CHECK_LOWER:
241     return "RANGE_CHECK_LOWER";
242 
243   case InductiveRangeCheck::RANGE_CHECK_BOTH:
244     return "RANGE_CHECK_BOTH";
245   }
246 
247   llvm_unreachable("unknown range check type!");
248 }
249 
250 /// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI`
251 /// cannot
252 /// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set
253 /// `Index` and `Length` to `nullptr`.  Otherwise set `Index` to the value
254 /// being
255 /// range checked, and set `Length` to the upper limit `Index` is being range
256 /// checked with if (and only if) the range check type is stronger or equal to
257 /// RANGE_CHECK_UPPER.
258 ///
259 InductiveRangeCheck::RangeCheckKind
260 InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
261                                          ScalarEvolution &SE, Value *&Index,
262                                          Value *&Length) {
263 
264   auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) {
265     const SCEV *S = SE.getSCEV(V);
266     if (isa<SCEVCouldNotCompute>(S))
267       return false;
268 
269     return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant &&
270            SE.isKnownNonNegative(S);
271   };
272 
273   using namespace llvm::PatternMatch;
274 
275   ICmpInst::Predicate Pred = ICI->getPredicate();
276   Value *LHS = ICI->getOperand(0);
277   Value *RHS = ICI->getOperand(1);
278 
279   switch (Pred) {
280   default:
281     return RANGE_CHECK_UNKNOWN;
282 
283   case ICmpInst::ICMP_SLE:
284     std::swap(LHS, RHS);
285   // fallthrough
286   case ICmpInst::ICMP_SGE:
287     if (match(RHS, m_ConstantInt<0>())) {
288       Index = LHS;
289       return RANGE_CHECK_LOWER;
290     }
291     return RANGE_CHECK_UNKNOWN;
292 
293   case ICmpInst::ICMP_SLT:
294     std::swap(LHS, RHS);
295   // fallthrough
296   case ICmpInst::ICMP_SGT:
297     if (match(RHS, m_ConstantInt<-1>())) {
298       Index = LHS;
299       return RANGE_CHECK_LOWER;
300     }
301 
302     if (IsNonNegativeAndNotLoopVarying(LHS)) {
303       Index = RHS;
304       Length = LHS;
305       return RANGE_CHECK_UPPER;
306     }
307     return RANGE_CHECK_UNKNOWN;
308 
309   case ICmpInst::ICMP_ULT:
310     std::swap(LHS, RHS);
311   // fallthrough
312   case ICmpInst::ICMP_UGT:
313     if (IsNonNegativeAndNotLoopVarying(LHS)) {
314       Index = RHS;
315       Length = LHS;
316       return RANGE_CHECK_BOTH;
317     }
318     return RANGE_CHECK_UNKNOWN;
319   }
320 
321   llvm_unreachable("default clause returns!");
322 }
323 
324 /// Parses an arbitrary condition into a range check.  `Length` is set only if
325 /// the range check is recognized to be `RANGE_CHECK_UPPER` or stronger.
326 InductiveRangeCheck::RangeCheckKind
327 InductiveRangeCheck::parseRangeCheck(Loop *L, ScalarEvolution &SE,
328                                      Value *Condition, const SCEV *&Index,
329                                      Value *&Length) {
330   using namespace llvm::PatternMatch;
331 
332   Value *A = nullptr;
333   Value *B = nullptr;
334 
335   if (match(Condition, m_And(m_Value(A), m_Value(B)))) {
336     Value *IndexA = nullptr, *IndexB = nullptr;
337     Value *LengthA = nullptr, *LengthB = nullptr;
338     ICmpInst *ICmpA = dyn_cast<ICmpInst>(A), *ICmpB = dyn_cast<ICmpInst>(B);
339 
340     if (!ICmpA || !ICmpB)
341       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
342 
343     auto RCKindA = parseRangeCheckICmp(L, ICmpA, SE, IndexA, LengthA);
344     auto RCKindB = parseRangeCheckICmp(L, ICmpB, SE, IndexB, LengthB);
345 
346     if (RCKindA == InductiveRangeCheck::RANGE_CHECK_UNKNOWN ||
347         RCKindB == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
348       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
349 
350     if (IndexA != IndexB)
351       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
352 
353     if (LengthA != nullptr && LengthB != nullptr && LengthA != LengthB)
354       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
355 
356     Index = SE.getSCEV(IndexA);
357     if (isa<SCEVCouldNotCompute>(Index))
358       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
359 
360     Length = LengthA == nullptr ? LengthB : LengthA;
361 
362     return (InductiveRangeCheck::RangeCheckKind)(RCKindA | RCKindB);
363   }
364 
365   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
366     Value *IndexVal = nullptr;
367 
368     auto RCKind = parseRangeCheckICmp(L, ICI, SE, IndexVal, Length);
369 
370     if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
371       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
372 
373     Index = SE.getSCEV(IndexVal);
374     if (isa<SCEVCouldNotCompute>(Index))
375       return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
376 
377     return RCKind;
378   }
379 
380   return InductiveRangeCheck::RANGE_CHECK_UNKNOWN;
381 }
382 
383 
384 InductiveRangeCheck *
385 InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
386                             Loop *L, ScalarEvolution &SE,
387                             BranchProbabilityInfo &BPI) {
388 
389   if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
390     return nullptr;
391 
392   BranchProbability LikelyTaken(15, 16);
393 
394   if (BPI.getEdgeProbability(BI->getParent(), (unsigned) 0) < LikelyTaken)
395     return nullptr;
396 
397   Value *Length = nullptr;
398   const SCEV *IndexSCEV = nullptr;
399 
400   auto RCKind = InductiveRangeCheck::parseRangeCheck(L, SE, BI->getCondition(),
401                                                      IndexSCEV, Length);
402 
403   if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
404     return nullptr;
405 
406   assert(IndexSCEV && "contract with SplitRangeCheckCondition!");
407   assert((!(RCKind & InductiveRangeCheck::RANGE_CHECK_UPPER) || Length) &&
408          "contract with SplitRangeCheckCondition!");
409 
410   const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
411   bool IsAffineIndex =
412       IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine();
413 
414   if (!IsAffineIndex)
415     return nullptr;
416 
417   InductiveRangeCheck *IRC = new (A.Allocate()) InductiveRangeCheck;
418   IRC->Length = Length;
419   IRC->Offset = IndexAddRec->getStart();
420   IRC->Scale = IndexAddRec->getStepRecurrence(SE);
421   IRC->Branch = BI;
422   IRC->Kind = RCKind;
423   return IRC;
424 }
425 
426 namespace {
427 
428 // Keeps track of the structure of a loop.  This is similar to llvm::Loop,
429 // except that it is more lightweight and can track the state of a loop through
430 // changing and potentially invalid IR.  This structure also formalizes the
431 // kinds of loops we can deal with -- ones that have a single latch that is also
432 // an exiting block *and* have a canonical induction variable.
433 struct LoopStructure {
434   const char *Tag;
435 
436   BasicBlock *Header;
437   BasicBlock *Latch;
438 
439   // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th
440   // successor is `LatchExit', the exit block of the loop.
441   BranchInst *LatchBr;
442   BasicBlock *LatchExit;
443   unsigned LatchBrExitIdx;
444 
445   Value *IndVarNext;
446   Value *IndVarStart;
447   Value *LoopExitAt;
448   bool IndVarIncreasing;
449 
450   LoopStructure()
451       : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr),
452         LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr),
453         IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {}
454 
455   template <typename M> LoopStructure map(M Map) const {
456     LoopStructure Result;
457     Result.Tag = Tag;
458     Result.Header = cast<BasicBlock>(Map(Header));
459     Result.Latch = cast<BasicBlock>(Map(Latch));
460     Result.LatchBr = cast<BranchInst>(Map(LatchBr));
461     Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
462     Result.LatchBrExitIdx = LatchBrExitIdx;
463     Result.IndVarNext = Map(IndVarNext);
464     Result.IndVarStart = Map(IndVarStart);
465     Result.LoopExitAt = Map(LoopExitAt);
466     Result.IndVarIncreasing = IndVarIncreasing;
467     return Result;
468   }
469 
470   static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
471                                                     BranchProbabilityInfo &BPI,
472                                                     Loop &,
473                                                     const char *&);
474 };
475 
476 /// This class is used to constrain loops to run within a given iteration space.
477 /// The algorithm this class implements is given a Loop and a range [Begin,
478 /// End).  The algorithm then tries to break out a "main loop" out of the loop
479 /// it is given in a way that the "main loop" runs with the induction variable
480 /// in a subset of [Begin, End).  The algorithm emits appropriate pre and post
481 /// loops to run any remaining iterations.  The pre loop runs any iterations in
482 /// which the induction variable is < Begin, and the post loop runs any
483 /// iterations in which the induction variable is >= End.
484 ///
485 class LoopConstrainer {
486   // The representation of a clone of the original loop we started out with.
487   struct ClonedLoop {
488     // The cloned blocks
489     std::vector<BasicBlock *> Blocks;
490 
491     // `Map` maps values in the clonee into values in the cloned version
492     ValueToValueMapTy Map;
493 
494     // An instance of `LoopStructure` for the cloned loop
495     LoopStructure Structure;
496   };
497 
498   // Result of rewriting the range of a loop.  See changeIterationSpaceEnd for
499   // more details on what these fields mean.
500   struct RewrittenRangeInfo {
501     BasicBlock *PseudoExit;
502     BasicBlock *ExitSelector;
503     std::vector<PHINode *> PHIValuesAtPseudoExit;
504     PHINode *IndVarEnd;
505 
506     RewrittenRangeInfo()
507         : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {}
508   };
509 
510   // Calculated subranges we restrict the iteration space of the main loop to.
511   // See the implementation of `calculateSubRanges' for more details on how
512   // these fields are computed.  `LowLimit` is None if there is no restriction
513   // on low end of the restricted iteration space of the main loop.  `HighLimit`
514   // is None if there is no restriction on high end of the restricted iteration
515   // space of the main loop.
516 
517   struct SubRanges {
518     Optional<const SCEV *> LowLimit;
519     Optional<const SCEV *> HighLimit;
520   };
521 
522   // A utility function that does a `replaceUsesOfWith' on the incoming block
523   // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's
524   // incoming block list with `ReplaceBy'.
525   static void replacePHIBlock(PHINode *PN, BasicBlock *Block,
526                               BasicBlock *ReplaceBy);
527 
528   // Compute a safe set of limits for the main loop to run in -- effectively the
529   // intersection of `Range' and the iteration space of the original loop.
530   // Return None if unable to compute the set of subranges.
531   //
532   Optional<SubRanges> calculateSubRanges() const;
533 
534   // Clone `OriginalLoop' and return the result in CLResult.  The IR after
535   // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
536   // the PHI nodes say that there is an incoming edge from `OriginalPreheader`
537   // but there is no such edge.
538   //
539   void cloneLoop(ClonedLoop &CLResult, const char *Tag) const;
540 
541   // Rewrite the iteration space of the loop denoted by (LS, Preheader). The
542   // iteration space of the rewritten loop ends at ExitLoopAt.  The start of the
543   // iteration space is not changed.  `ExitLoopAt' is assumed to be slt
544   // `OriginalHeaderCount'.
545   //
546   // If there are iterations left to execute, control is made to jump to
547   // `ContinuationBlock', otherwise they take the normal loop exit.  The
548   // returned `RewrittenRangeInfo' object is populated as follows:
549   //
550   //  .PseudoExit is a basic block that unconditionally branches to
551   //      `ContinuationBlock'.
552   //
553   //  .ExitSelector is a basic block that decides, on exit from the loop,
554   //      whether to branch to the "true" exit or to `PseudoExit'.
555   //
556   //  .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value
557   //      for each PHINode in the loop header on taking the pseudo exit.
558   //
559   // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate
560   // preheader because it is made to branch to the loop header only
561   // conditionally.
562   //
563   RewrittenRangeInfo
564   changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader,
565                           Value *ExitLoopAt,
566                           BasicBlock *ContinuationBlock) const;
567 
568   // The loop denoted by `LS' has `OldPreheader' as its preheader.  This
569   // function creates a new preheader for `LS' and returns it.
570   //
571   BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader,
572                               const char *Tag) const;
573 
574   // `ContinuationBlockAndPreheader' was the continuation block for some call to
575   // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
576   // This function rewrites the PHI nodes in `LS.Header' to start with the
577   // correct value.
578   void rewriteIncomingValuesForPHIs(
579       LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader,
580       const LoopConstrainer::RewrittenRangeInfo &RRI) const;
581 
582   // Even though we do not preserve any passes at this time, we at least need to
583   // keep the parent loop structure consistent.  The `LPPassManager' seems to
584   // verify this after running a loop pass.  This function adds the list of
585   // blocks denoted by BBs to this loops parent loop if required.
586   void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs);
587 
588   // Some global state.
589   Function &F;
590   LLVMContext &Ctx;
591   ScalarEvolution &SE;
592 
593   // Information about the original loop we started out with.
594   Loop &OriginalLoop;
595   LoopInfo &OriginalLoopInfo;
596   const SCEV *LatchTakenCount;
597   BasicBlock *OriginalPreheader;
598 
599   // The preheader of the main loop.  This may or may not be different from
600   // `OriginalPreheader'.
601   BasicBlock *MainLoopPreheader;
602 
603   // The range we need to run the main loop in.
604   InductiveRangeCheck::Range Range;
605 
606   // The structure of the main loop (see comment at the beginning of this class
607   // for a definition)
608   LoopStructure MainLoopStructure;
609 
610 public:
611   LoopConstrainer(Loop &L, LoopInfo &LI, const LoopStructure &LS,
612                   ScalarEvolution &SE, InductiveRangeCheck::Range R)
613       : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()),
614         SE(SE), OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr),
615         OriginalPreheader(nullptr), MainLoopPreheader(nullptr), Range(R),
616         MainLoopStructure(LS) {}
617 
618   // Entry point for the algorithm.  Returns true on success.
619   bool run();
620 };
621 
622 }
623 
624 void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block,
625                                       BasicBlock *ReplaceBy) {
626   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
627     if (PN->getIncomingBlock(i) == Block)
628       PN->setIncomingBlock(i, ReplaceBy);
629 }
630 
631 static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) {
632   APInt SMax =
633       APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth());
634   return SE.getSignedRange(S).contains(SMax) &&
635          SE.getUnsignedRange(S).contains(SMax);
636 }
637 
638 static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) {
639   APInt SMin =
640       APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth());
641   return SE.getSignedRange(S).contains(SMin) &&
642          SE.getUnsignedRange(S).contains(SMin);
643 }
644 
// Tries to describe `L` as a LoopStructure IRCE can work with.  The loop must
// have an exiting latch, a preheader, and a latch terminated by a conditional
// branch on an integer icmp.  One side of that icmp must be an affine add
// recurrence with constant step +1 or -1 that is known not to sign-wrap.
// The predicate must amount to "slt" (increasing IV) or "sgt" (decreasing IV)
// for taking the backedge.  Loops whose latch exit probability exceeds
// 1 / MaxExitProbReciprocal are rejected as unprofitable.
//
// On success, FailureReason is set to nullptr and the structure is returned,
// tagged "main".  Instructions are inserted before the preheader's terminator:
// the expansion of the start value and, for non-strict predicates, the
// adjusted limit.  On failure, None is returned and FailureReason points at a
// static description; no IR has been emitted at that point.
Optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI,
                                  Loop &L, const char *&FailureReason) {
  assert(L.isLoopSimplifyForm() && "should follow from addRequired<>");

  // NOTE(review): this rejects loops whose latch is not an exiting block; the
  // reason string below says "no loop latch".
  BasicBlock *Latch = L.getLoopLatch();
  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return None;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return None;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(&*Latch->rbegin());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return None;
  }

  // The latch branch has the header as one successor; the other one exits.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  BranchProbability ExitProbability =
    BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx);

  // A loop that leaves through the latch too often runs too few iterations for
  // the transformation to pay off.
  if (ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) {
    FailureReason = "short running loop, not profitable";
    return None;
  }

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return None;
  }

  const SCEV *LatchCount = SE.getExitCount(&L, Latch);
  if (isa<SCEVCouldNotCompute>(LatchCount)) {
    FailureReason = "could not compute latch count";
    return None;
  }

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return None;
    }
  }

  // True if `AR` is known not to sign-wrap.  Either SCEV already carries the
  // NSW flag, or sign-extending AR to twice its width yields the recurrence
  // of the sign-extended start and step.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // True if `AR` is affine, does not sign-wrap, and has constant step +1 or
  // -1; sets `IsIncreasing` accordingly.
  auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) {
    if (!AR->isAffine())
      return false;

    // Currently we only work with induction variables that have been proved to
    // not wrap.  This restriction can potentially be lifted in the future.

    if (!HasNoSignedWrap(AR))
      return false;

    if (const SCEVConstant *StepExpr =
            dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) {
      ConstantInt *StepCI = StepExpr->getValue();
      if (StepCI->isOne() || StepCI->isMinusOne()) {
        IsIncreasing = StepCI->isOne();
        return true;
      }
    }

    return false;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarNext = cast<SCEVAddRecExpr>(LeftSCEV);
  bool IsIncreasing = false;
  if (!IsInductionVar(IndVarNext, IsIncreasing)) {
    FailureReason = "LHS in icmp not induction variable";
    return None;
  }

  // Below, the predicate is normalized so that the backedge is taken while
  // `IndVarNext slt RightValue` (increasing) or `IndVarNext sgt RightValue`
  // (decreasing).
  ConstantInt *One = ConstantInt::get(IndVarTy, 1);
  // TODO: generalize the predicates here to also match their unsigned variants.
  if (IsIncreasing) {
    bool FoundExpectedPred =
        (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) ||
        (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return None;
    }

    // Exit on `next sgt R` means the backedge is taken while `next sle R`,
    // i.e. while `next slt R + 1`.  This is only valid if R + 1 cannot
    // overflow.
    if (LatchBrExitIdx == 0) {
      if (CanBeSMax(SE, RightSCEV)) {
        // TODO: this restriction is easily removable -- we just have to
        // remember that the icmp was an slt and not an sle.
        FailureReason = "limit may overflow when coercing sle to slt";
        return None;
      }

      IRBuilder<> B(&*Preheader->rbegin());
      RightValue = B.CreateAdd(RightValue, One);
    }

  } else {
    bool FoundExpectedPred =
        (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) ||
        (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return None;
    }

    // Exit on `next slt R` means the backedge is taken while `next sge R`,
    // i.e. while `next sgt R - 1`.  This is only valid if R - 1 cannot
    // overflow.
    if (LatchBrExitIdx == 0) {
      if (CanBeSMin(SE, RightSCEV)) {
        // TODO: this restriction is easily removable -- we just have to
        // remember that the icmp was an sgt and not an sge.
        FailureReason = "limit may overflow when coercing sge to sgt";
        return None;
      }

      IRBuilder<> B(&*Preheader->rbegin());
      RightValue = B.CreateSub(RightValue, One);
    }
  }

  // The icmp tests the post-increment value, so the value on loop entry is
  // the recurrence's start minus one step.
  const SCEV *StartNext = IndVarNext->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);

  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(SE.getLoopDisposition(LatchCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  assert(!L.contains(LatchExit) && "expected an exit block!");
  // Materialize the start value before the preheader's terminator.
  const DataLayout &DL = Preheader->getModule()->getDataLayout();
  Value *IndVarStartV =
      SCEVExpander(SE, DL, "irce")
          .expandCodeFor(IndVarStart, IndVarTy, &*Preheader->rbegin());
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarNext = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;

  FailureReason = nullptr;

  return Result;
}
850 
851 Optional<LoopConstrainer::SubRanges>
852 LoopConstrainer::calculateSubRanges() const {
853   IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType());
854 
855   if (Range.getType() != Ty)
856     return None;
857 
858   LoopConstrainer::SubRanges Result;
859 
860   // I think we can be more aggressive here and make this nuw / nsw if the
861   // addition that feeds into the icmp for the latch's terminating branch is nuw
862   // / nsw.  In any case, a wrapping 2's complement addition is safe.
863   ConstantInt *One = ConstantInt::get(Ty, 1);
864   const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart);
865   const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt);
866 
867   bool Increasing = MainLoopStructure.IndVarIncreasing;
868 
869   // We compute `Smallest` and `Greatest` such that [Smallest, Greatest) is the
870   // range of values the induction variable takes.
871 
872   const SCEV *Smallest = nullptr, *Greatest = nullptr;
873 
874   if (Increasing) {
875     Smallest = Start;
876     Greatest = End;
877   } else {
878     // These two computations may sign-overflow.  Here is why that is okay:
879     //
880     // We know that the induction variable does not sign-overflow on any
881     // iteration except the last one, and it starts at `Start` and ends at
882     // `End`, decrementing by one every time.
883     //
884     //  * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
885     //    induction variable is decreasing we know that that the smallest value
886     //    the loop body is actually executed with is `INT_SMIN` == `Smallest`.
887     //
888     //  * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`.  In
889     //    that case, `Clamp` will always return `Smallest` and
890     //    [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
891     //    will be an empty range.  Returning an empty range is always safe.
892     //
893 
894     Smallest = SE.getAddExpr(End, SE.getSCEV(One));
895     Greatest = SE.getAddExpr(Start, SE.getSCEV(One));
896   }
897 
898   auto Clamp = [this, Smallest, Greatest](const SCEV *S) {
899     return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S));
900   };
901 
902   // In some cases we can prove that we don't need a pre or post loop
903 
904   bool ProvablyNoPreloop =
905       SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest);
906   if (!ProvablyNoPreloop)
907     Result.LowLimit = Clamp(Range.getBegin());
908 
909   bool ProvablyNoPostLoop =
910       SE.isKnownPredicate(ICmpInst::ICMP_SLE, Greatest, Range.getEnd());
911   if (!ProvablyNoPostLoop)
912     Result.HighLimit = Clamp(Range.getEnd());
913 
914   return Result;
915 }
916 
917 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
918                                 const char *Tag) const {
919   for (BasicBlock *BB : OriginalLoop.getBlocks()) {
920     BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
921     Result.Blocks.push_back(Clone);
922     Result.Map[BB] = Clone;
923   }
924 
925   auto GetClonedValue = [&Result](Value *V) {
926     assert(V && "null values not in domain!");
927     auto It = Result.Map.find(V);
928     if (It == Result.Map.end())
929       return V;
930     return static_cast<Value *>(It->second);
931   };
932 
933   Result.Structure = MainLoopStructure.map(GetClonedValue);
934   Result.Structure.Tag = Tag;
935 
936   for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
937     BasicBlock *ClonedBB = Result.Blocks[i];
938     BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
939 
940     assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
941 
942     for (Instruction &I : *ClonedBB)
943       RemapInstruction(&I, Result.Map,
944                        RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);
945 
946     // Exit blocks will now have one more predecessor and their PHI nodes need
947     // to be edited to reflect that.  No phi nodes need to be introduced because
948     // the loop is in LCSSA.
949 
950     for (auto SBBI = succ_begin(OriginalBB), SBBE = succ_end(OriginalBB);
951          SBBI != SBBE; ++SBBI) {
952 
953       if (OriginalLoop.contains(*SBBI))
954         continue; // not an exit block
955 
956       for (Instruction &I : **SBBI) {
957         if (!isa<PHINode>(&I))
958           break;
959 
960         PHINode *PN = cast<PHINode>(&I);
961         Value *OldIncoming = PN->getIncomingValueForBlock(OriginalBB);
962         PN->addIncoming(GetClonedValue(OldIncoming), ClonedBB);
963       }
964     }
965   }
966 }
967 
/// Rewrites loop `LS` so that it stops iterating once its induction variable
/// is no longer strictly below (increasing) / above (decreasing)
/// `ExitSubloopAt`. When that happens, control continues at `ContinuationBlock`
/// through a new "<Tag>.pseudo.exit" block. If the original bound
/// `LS.LoopExitAt` has also been reached, the loop instead leaves through its
/// original latch exit, via a new "<Tag>.exit.selector" block.
///
/// The last instruction of `Preheader` is expected to be a BranchInst (it is
/// `cast<>` to one). It is replaced by a guard that branches straight to the
/// pseudo-exit when `LS.IndVarStart` is already past `ExitSubloopAt`.
///
/// The two new blocks are placed right after `LS.Latch` in the function. The
/// returned RewrittenRangeInfo holds them. It also holds PHIs in the
/// pseudo-exit that carry the latest value of every header PHI
/// (`PHIValuesAtPseudoExit`, in header order) and of the induction variable
/// (`IndVarEnd`).
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {

  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+
  //

  RewrittenRangeInfo RRI;

  // Both new blocks are inserted immediately after the latch.
  auto BBInsertLocation = std::next(Function::iterator(LS.Latch));
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, &*BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      &*BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(&*Preheader->rbegin());
  bool Increasing = LS.IndVarIncreasing;

  IRBuilder<> B(PreheaderJump);

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = Increasing
                             ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt)
                             : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt);

  // Replace the preheader's jump with a guard that skips the loop entirely
  // (straight to the pseudo-exit) when there is nothing to execute.
  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // The latch's exit edge now leads to the exit selector, and the backedge is
  // only taken while the next IV value is still before `ExitSubloopAt`.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *TakeBackedgeLoopCond =
      Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt)
                 : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt);
  // A conditional branch goes to successor 0 when its condition is true, so
  // the condition must be inverted when the exit is successor 0.
  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *IterationsLeft = Increasing
                              ? B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt)
                              : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (Instruction &I : *LS.Header) {
    if (!isa<PHINode>(&I))
      break;

    PHINode *PN = cast<PHINode>(&I);

    PHINode *NewPHI = PHINode::Create(PN->getType(), 2, PN->getName() + ".copy",
                                      BranchToContinuation);

    // Coming from the preheader guard the loop never ran, so the initial value
    // is still current; coming from the exit selector, use the value the
    // header PHI would have received along the backedge.
    NewPHI->addIncoming(PN->getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN->getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  // Same for the induction variable itself: its start value if the loop was
  // skipped, otherwise the incremented value computed in the latch.
  RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end",
                                  BranchToContinuation);
  RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  for (Instruction &I : *LS.LatchExit) {
    if (PHINode *PN = dyn_cast<PHINode>(&I))
      replacePHIBlock(PN, LS.Latch, RRI.ExitSelector);
    else
      break;
  }

  return RRI;
}
1123 
1124 void LoopConstrainer::rewriteIncomingValuesForPHIs(
1125     LoopStructure &LS, BasicBlock *ContinuationBlock,
1126     const LoopConstrainer::RewrittenRangeInfo &RRI) const {
1127 
1128   unsigned PHIIndex = 0;
1129   for (Instruction &I : *LS.Header) {
1130     if (!isa<PHINode>(&I))
1131       break;
1132 
1133     PHINode *PN = cast<PHINode>(&I);
1134 
1135     for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i)
1136       if (PN->getIncomingBlock(i) == ContinuationBlock)
1137         PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]);
1138   }
1139 
1140   LS.IndVarStart = RRI.IndVarEnd;
1141 }
1142 
1143 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
1144                                              BasicBlock *OldPreheader,
1145                                              const char *Tag) const {
1146 
1147   BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
1148   BranchInst::Create(LS.Header, Preheader);
1149 
1150   for (Instruction &I : *LS.Header) {
1151     if (!isa<PHINode>(&I))
1152       break;
1153 
1154     PHINode *PN = cast<PHINode>(&I);
1155     for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i)
1156       replacePHIBlock(PN, OldPreheader, Preheader);
1157   }
1158 
1159   return Preheader;
1160 }
1161 
1162 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
1163   Loop *ParentLoop = OriginalLoop.getParentLoop();
1164   if (!ParentLoop)
1165     return;
1166 
1167   for (BasicBlock *BB : BBs)
1168     ParentLoop->addBasicBlockToLoop(BB, OriginalLoopInfo);
1169 }
1170 
/// Splits `OriginalLoop` into up to three loops: an optional pre-loop, the main
/// loop and an optional post-loop.  Afterwards the main loop only runs over the
/// iterations computed by calculateSubRanges().
///
/// Returns true on success.  Returns false, with a debug message, if the
/// subranges cannot be computed or if an exit limit might overflow.  Note that
/// by then an exit limit may already have been expanded into the preheader.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = nullptr;
  LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch);
  Preheader = OriginalLoop.getLoopPreheader();
  assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr &&
         "preconditions!");

  OriginalPreheader = Preheader;
  MainLoopPreheader = Preheader;

  Optional<SubRanges> MaybeSR = calculateSubRanges();
  if (!MaybeSR.hasValue()) {
    DEBUG(dbgs() << "irce: could not compute subranges\n");
    return false;
  }

  SubRanges SR = MaybeSR.getValue();
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy =
      cast<IntegerType>(MainLoopStructure.IndVarNext->getType());

  SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  // For a decreasing induction variable, the iterations above HighLimit are
  // executed first.  The roles of the two limits therefore swap.
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else {
      // A decreasing loop keeps iterating while IV > exit value.  To run
      // exactly down to HighLimit, the exit value must be HighLimit - 1, and
      // that subtraction must not wrap.
      if (CanBeSMin(SE, *SR.HighLimit)) {
        DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
                     << "preloop exit limit.  HighLimit = " << *(*SR.HighLimit)
                     << "\n");
        return false;
      }
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else {
      // Same reasoning as for the pre-loop exit above, applied to LowLimit.
      if (CanBeSMin(SE, *SR.LowLimit)) {
        DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
                     << "mainloop exit limit.  LowLimit = " << *(*SR.LowLimit)
                     << "\n");
        return false;
      }
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  if (NeedsPreLoop) {
    // The original preheader now enters the pre-loop.  The main loop gets a
    // fresh "mainloop" preheader, which the pre-loop's pseudo-exit branches to.
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  if (NeedsPostLoop) {
    // The main loop's pseudo-exit continues into a fresh "postloop" preheader.
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd));
  addToParentLoopIfNeeded(PreLoop.Blocks);
  addToParentLoopIfNeeded(PostLoop.Blocks);

  return true;
}
1297 
1298 /// Computes and returns a range of values for the induction variable (IndVar)
1299 /// in which the range check can be safely elided.  If it cannot compute such a
1300 /// range, returns None.
1301 Optional<InductiveRangeCheck::Range>
1302 InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
1303                                                const SCEVAddRecExpr *IndVar,
1304                                                IRBuilder<> &) const {
1305   // IndVar is of the form "A + B * I" (where "I" is the canonical induction
1306   // variable, that may or may not exist as a real llvm::Value in the loop) and
1307   // this inductive range check is a range check on the "C + D * I" ("C" is
1308   // getOffset() and "D" is getScale()).  We rewrite the value being range
1309   // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA".
1310   // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code
1311   // can be generalized as needed.
1312   //
1313   // The actual inequalities we solve are of the form
1314   //
1315   //   0 <= M + 1 * IndVar < L given L >= 0  (i.e. N == 1)
1316   //
1317   // The inequality is satisfied by -M <= IndVar < (L - M) [^1].  All additions
1318   // and subtractions are twos-complement wrapping and comparisons are signed.
1319   //
1320   // Proof:
1321   //
1322   //   If there exists IndVar such that -M <= IndVar < (L - M) then it follows
1323   //   that -M <= (-M + L) [== Eq. 1].  Since L >= 0, if (-M + L) sign-overflows
1324   //   then (-M + L) < (-M).  Hence by [Eq. 1], (-M + L) could not have
1325   //   overflown.
1326   //
1327   //   This means IndVar = t + (-M) for t in [0, L).  Hence (IndVar + M) = t.
1328   //   Hence 0 <= (IndVar + M) < L
1329 
1330   // [^1]: Note that the solution does _not_ apply if L < 0; consider values M =
1331   // 127, IndVar = 126 and L = -2 in an i8 world.
1332 
1333   if (!IndVar->isAffine())
1334     return None;
1335 
1336   const SCEV *A = IndVar->getStart();
1337   const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE));
1338   if (!B)
1339     return None;
1340 
1341   const SCEV *C = getOffset();
1342   const SCEVConstant *D = dyn_cast<SCEVConstant>(getScale());
1343   if (D != B)
1344     return None;
1345 
1346   ConstantInt *ConstD = D->getValue();
1347   if (!(ConstD->isMinusOne() || ConstD->isOne()))
1348     return None;
1349 
1350   const SCEV *M = SE.getMinusSCEV(C, A);
1351 
1352   const SCEV *Begin = SE.getNegativeSCEV(M);
1353   const SCEV *UpperLimit = nullptr;
1354 
1355   // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
1356   // We can potentially do much better here.
1357   if (Value *V = getLength()) {
1358     UpperLimit = SE.getSCEV(V);
1359   } else {
1360     assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!");
1361     unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth();
1362     UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
1363   }
1364 
1365   const SCEV *End = SE.getMinusSCEV(UpperLimit, M);
1366   return InductiveRangeCheck::Range(Begin, End);
1367 }
1368 
1369 static Optional<InductiveRangeCheck::Range>
1370 IntersectRange(ScalarEvolution &SE,
1371                const Optional<InductiveRangeCheck::Range> &R1,
1372                const InductiveRangeCheck::Range &R2, IRBuilder<> &B) {
1373   if (!R1.hasValue())
1374     return R2;
1375   auto &R1Value = R1.getValue();
1376 
1377   // TODO: we could widen the smaller range and have this work; but for now we
1378   // bail out to keep things simple.
1379   if (R1Value.getType() != R2.getType())
1380     return None;
1381 
1382   const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
1383   const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());
1384 
1385   return InductiveRangeCheck::Range(NewBegin, NewEnd);
1386 }
1387 
1388 bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
1389   if (L->getBlocks().size() >= LoopSizeCutoff) {
1390     DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";);
1391     return false;
1392   }
1393 
1394   BasicBlock *Preheader = L->getLoopPreheader();
1395   if (!Preheader) {
1396     DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
1397     return false;
1398   }
1399 
1400   LLVMContext &Context = Preheader->getContext();
1401   InductiveRangeCheck::AllocatorTy IRCAlloc;
1402   SmallVector<InductiveRangeCheck *, 16> RangeChecks;
1403   ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1404   BranchProbabilityInfo &BPI =
1405       getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
1406 
1407   for (auto BBI : L->getBlocks())
1408     if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
1409       if (InductiveRangeCheck *IRC =
1410           InductiveRangeCheck::create(IRCAlloc, TBI, L, SE, BPI))
1411         RangeChecks.push_back(IRC);
1412 
1413   if (RangeChecks.empty())
1414     return false;
1415 
1416   auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
1417     OS << "irce: looking at loop "; L->print(OS);
1418     OS << "irce: loop has " << RangeChecks.size()
1419        << " inductive range checks: \n";
1420     for (InductiveRangeCheck *IRC : RangeChecks)
1421       IRC->print(OS);
1422   };
1423 
1424   DEBUG(PrintRecognizedRangeChecks(dbgs()));
1425 
1426   if (PrintRangeChecks)
1427     PrintRecognizedRangeChecks(errs());
1428 
1429   const char *FailureReason = nullptr;
1430   Optional<LoopStructure> MaybeLoopStructure =
1431       LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason);
1432   if (!MaybeLoopStructure.hasValue()) {
1433     DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason
1434                  << "\n";);
1435     return false;
1436   }
1437   LoopStructure LS = MaybeLoopStructure.getValue();
1438   bool Increasing = LS.IndVarIncreasing;
1439   const SCEV *MinusOne =
1440       SE.getConstant(LS.IndVarNext->getType(), Increasing ? -1 : 1, true);
1441   const SCEVAddRecExpr *IndVar =
1442       cast<SCEVAddRecExpr>(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne));
1443 
1444   Optional<InductiveRangeCheck::Range> SafeIterRange;
1445   Instruction *ExprInsertPt = Preheader->getTerminator();
1446 
1447   SmallVector<InductiveRangeCheck *, 4> RangeChecksToEliminate;
1448 
1449   IRBuilder<> B(ExprInsertPt);
1450   for (InductiveRangeCheck *IRC : RangeChecks) {
1451     auto Result = IRC->computeSafeIterationSpace(SE, IndVar, B);
1452     if (Result.hasValue()) {
1453       auto MaybeSafeIterRange =
1454         IntersectRange(SE, SafeIterRange, Result.getValue(), B);
1455       if (MaybeSafeIterRange.hasValue()) {
1456         RangeChecksToEliminate.push_back(IRC);
1457         SafeIterRange = MaybeSafeIterRange.getValue();
1458       }
1459     }
1460   }
1461 
1462   if (!SafeIterRange.hasValue())
1463     return false;
1464 
1465   LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LS,
1466                      SE, SafeIterRange.getValue());
1467   bool Changed = LC.run();
1468 
1469   if (Changed) {
1470     auto PrintConstrainedLoopInfo = [L]() {
1471       dbgs() << "irce: in function ";
1472       dbgs() << L->getHeader()->getParent()->getName() << ": ";
1473       dbgs() << "constrained ";
1474       L->print(dbgs());
1475     };
1476 
1477     DEBUG(PrintConstrainedLoopInfo());
1478 
1479     if (PrintChangedLoops)
1480       PrintConstrainedLoopInfo();
1481 
1482     // Optimize away the now-redundant range checks.
1483 
1484     for (InductiveRangeCheck *IRC : RangeChecksToEliminate) {
1485       ConstantInt *FoldedRangeCheck = IRC->getPassingDirection()
1486                                           ? ConstantInt::getTrue(Context)
1487                                           : ConstantInt::getFalse(Context);
1488       IRC->getBranch()->setCondition(FoldedRangeCheck);
1489     }
1490   }
1491 
1492   return Changed;
1493 }
1494 
1495 Pass *llvm::createInductiveRangeCheckEliminationPass() {
1496   return new InductiveRangeCheckElimination;
1497 }
1498