//===-- InductiveRangeCheckElimination.cpp - ------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// The InductiveRangeCheckElimination pass splits a loop's iteration space into
// three disjoint ranges.  It does so in such a way that the middle loop
// provably does not need range checks.  As an example, it will convert
//
//   len = < known positive >
//   for (i = 0; i < n; i++) {
//     if (0 <= i && i < len) {
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//
// to
//
//   len = < known positive >
//   limit = smin(n, len)
//   // no first segment
//   for (i = 0; i < limit; i++) {
//     if (0 <= i && i < len) { // this check is fully redundant
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//   for (i = limit; i < n; i++) {
//     if (0 <= i && i < len) {
//       do_something();
//     } else {
//       throw_out_of_bounds();
//     }
//   }
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Optional.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"

using namespace llvm;

static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
                                        cl::init(64));

static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
                                       cl::init(false));

static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
                                      cl::init(false));

static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal",
                                          cl::Hidden, cl::init(10));

static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks",
                                             cl::Hidden, cl::init(false));

#define DEBUG_TYPE "irce"

namespace {

/// An inductive range check is a conditional branch in a loop with
///
///  1. a very cold successor (i.e. the branch jumps to that successor very
///     rarely)
///
///  and
///
///  2. a condition that is provably true for some contiguous range of values
///     taken by the containing loop's induction variable.
///
class InductiveRangeCheck {
  // Classifies a range check
  enum RangeCheckKind : unsigned {
    // Range check of the form "0 <= I".
    RANGE_CHECK_LOWER = 1,

    // Range check of the form "I < L" where L is a loop-invariant value known
    // to be non-negative.
    RANGE_CHECK_UPPER = 2,

    // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER
    // conditions.
    RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER,

    // Unrecognized range check condition.
    RANGE_CHECK_UNKNOWN = (unsigned)-1
  };

  static StringRef rangeCheckKindToStr(RangeCheckKind);

  // The range-checked index is the add recurrence {Offset,+,Scale} of the loop.
  const SCEV *Offset = nullptr;
  const SCEV *Scale = nullptr;
  // Upper limit of the check; null for RANGE_CHECK_LOWER checks.
  Value *Length = nullptr;
  // The use of the condition this range check was parsed from.
  Use *CheckUse = nullptr;
  RangeCheckKind Kind = RANGE_CHECK_UNKNOWN;

  static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
                                            ScalarEvolution &SE, Value *&Index,
                                            Value *&Length);

  static void
  extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse,
                             SmallVectorImpl<InductiveRangeCheck> &Checks,
                             SmallPtrSetImpl<Value *> &Visited);

public:
  const SCEV *getOffset() const { return Offset; }
  const SCEV *getScale() const { return Scale; }
  Value *getLength() const { return Length; }

  /// Print a human-readable description of this range check to \p OS.
  void print(raw_ostream &OS) const {
    OS << "InductiveRangeCheck:\n";
    OS << " Kind: " << rangeCheckKindToStr(Kind) << "\n";
    OS << " Offset: ";
    Offset->print(OS);
    OS << " Scale: ";
    Scale->print(OS);
    OS << " Length: ";
    if (Length)
      Length->print(OS);
    else
      OS << "(null)";
    OS << "\n CheckUse: ";
    getCheckUse()->getUser()->print(OS);
    OS << " Operand: " << getCheckUse()->getOperandNo() << "\n";
  }

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
  // Printing does not mutate the check, so dump() is const like print().
  void dump() const {
    print(dbgs());
  }
#endif

  Use *getCheckUse() const { return CheckUse; }

  /// Represents a signed integer range [Range.getBegin(), Range.getEnd()).  If
  /// R.getEnd() sle R.getBegin(), then R denotes the empty range.

  class Range {
    const SCEV *Begin;
    const SCEV *End;

  public:
    Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
      assert(Begin->getType() == End->getType() && "ill-typed range!");
    }

    Type *getType() const { return Begin->getType(); }
    const SCEV *getBegin() const { return Begin; }
    const SCEV *getEnd() const { return End; }
  };

  /// This is the value the condition of the branch needs to evaluate to for the
  /// branch to take the hot successor (see (1) above).
  bool getPassingDirection() const { return true; }

  /// Computes a range for the induction variable (IndVar) in which the range
  /// check is redundant and can be constant-folded away.  The induction
  /// variable is not required to be the canonical {0,+,1} induction variable.
  Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
                                            const SCEVAddRecExpr *IndVar) const;

  /// Parse out a set of inductive range checks from \p BI and append them to \p
  /// Checks.
  ///
  /// NB! There may be conditions feeding into \p BI that aren't inductive range
  /// checks, and hence don't end up in \p Checks.
  static void
  extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE,
                               BranchProbabilityInfo &BPI,
                               SmallVectorImpl<InductiveRangeCheck> &Checks);
};

class InductiveRangeCheckElimination : public LoopPass {
public:
  static char ID;
  InductiveRangeCheckElimination() : LoopPass(ID) {
    initializeInductiveRangeCheckEliminationPass(
        *PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<BranchProbabilityInfoWrapperPass>();
    getLoopAnalysisUsage(AU);
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
};

char InductiveRangeCheckElimination::ID = 0;
} // end anonymous namespace

INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce",
                      "Inductive range check elimination", false, false)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_END(InductiveRangeCheckElimination, "irce",
                    "Inductive range check elimination", false, false)

/// Return the enumerator name of \p RCK as a string, for debug printing.
StringRef InductiveRangeCheck::rangeCheckKindToStr(
    InductiveRangeCheck::RangeCheckKind RCK) {
  switch (RCK) {
  case InductiveRangeCheck::RANGE_CHECK_UNKNOWN:
    return "RANGE_CHECK_UNKNOWN";

  case InductiveRangeCheck::RANGE_CHECK_UPPER:
    return "RANGE_CHECK_UPPER";

  case InductiveRangeCheck::RANGE_CHECK_LOWER:
    return "RANGE_CHECK_LOWER";

  case InductiveRangeCheck::RANGE_CHECK_BOTH:
    return "RANGE_CHECK_BOTH";
  }

  llvm_unreachable("unknown range check type!");
}

/// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI` cannot
/// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set
/// `Index` and `Length` to `nullptr`.  Otherwise set `Index` to the value being
/// range checked, and set `Length` to the upper limit `Index` is being range
/// checked with if (and only if) the range check type is stronger or equal to
/// RANGE_CHECK_UPPER; for RANGE_CHECK_LOWER `Length` is set to `nullptr`.
///
InductiveRangeCheck::RangeCheckKind
InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
                                         ScalarEvolution &SE, Value *&Index,
                                         Value *&Length) {
  // Put both out-parameters into the documented "no range check" state up
  // front, so every return path below leaves them well defined regardless of
  // what the caller passed in.
  Index = nullptr;
  Length = nullptr;

  // True if V has a computable SCEV that is invariant in L and provably >= 0.
  auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) {
    const SCEV *S = SE.getSCEV(V);
    if (isa<SCEVCouldNotCompute>(S))
      return false;

    return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant &&
           SE.isKnownNonNegative(S);
  };

  using namespace llvm::PatternMatch;

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LHS = ICI->getOperand(0);
  Value *RHS = ICI->getOperand(1);

  switch (Pred) {
  default:
    return RANGE_CHECK_UNKNOWN;

  case ICmpInst::ICMP_SLE:
    std::swap(LHS, RHS);
    // fallthrough
  case ICmpInst::ICMP_SGE:
    // "I >= 0"
    if (match(RHS, m_ConstantInt<0>())) {
      Index = LHS;
      return RANGE_CHECK_LOWER;
    }
    return RANGE_CHECK_UNKNOWN;

  case ICmpInst::ICMP_SLT:
    std::swap(LHS, RHS);
    // fallthrough
  case ICmpInst::ICMP_SGT:
    // "I > -1", i.e. "I >= 0"
    if (match(RHS, m_ConstantInt<-1>())) {
      Index = LHS;
      return RANGE_CHECK_LOWER;
    }

    // "L > I" with L loop-invariant and non-negative
    if (IsNonNegativeAndNotLoopVarying(LHS)) {
      Index = RHS;
      Length = LHS;
      return RANGE_CHECK_UPPER;
    }
    return RANGE_CHECK_UNKNOWN;

  case ICmpInst::ICMP_ULT:
    std::swap(LHS, RHS);
    // fallthrough
  case ICmpInst::ICMP_UGT:
    // "L u> I" with L non-negative implies both "0 <= I" and "I < L".
    if (IsNonNegativeAndNotLoopVarying(LHS)) {
      Index = RHS;
      Length = LHS;
      return RANGE_CHECK_BOTH;
    }
    return RANGE_CHECK_UNKNOWN;
  }

  llvm_unreachable("default clause returns!");
}

/// Recursively parse range checks out of the condition held by \p
/// ConditionUse, looking through `and`s, and append them to \p Checks.
/// \p Visited prevents processing the same condition value twice.
void InductiveRangeCheck::extractRangeChecksFromCond(
    Loop *L, ScalarEvolution &SE, Use &ConditionUse,
    SmallVectorImpl<InductiveRangeCheck> &Checks,
    SmallPtrSetImpl<Value *> &Visited) {
  using namespace llvm::PatternMatch;

  Value *Condition = ConditionUse.get();
  if (!Visited.insert(Condition).second)
    return;

  if (match(Condition, m_And(m_Value(), m_Value()))) {
    SmallVector<InductiveRangeCheck, 8> SubChecks;
    extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0),
                               SubChecks, Visited);
    extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1),
                               SubChecks, Visited);

    if (SubChecks.size() == 2) {
      // Handle a special case where we know how to merge two checks separately
      // checking the upper and lower bounds into a full range check.
      const auto &RChkA = SubChecks[0];
      const auto &RChkB = SubChecks[1];
      if ((RChkA.Length == RChkB.Length || !RChkA.Length || !RChkB.Length) &&
          RChkA.Offset == RChkB.Offset && RChkA.Scale == RChkB.Scale) {

        // If RChkA.Kind == RChkB.Kind then we just found two identical checks.
        // But if one of them is a RANGE_CHECK_LOWER and the other is a
        // RANGE_CHECK_UPPER (only possibility if they're different) then
        // together they form a RANGE_CHECK_BOTH.
        SubChecks[0].Kind =
            (InductiveRangeCheck::RangeCheckKind)(RChkA.Kind | RChkB.Kind);
        SubChecks[0].Length = RChkA.Length ? RChkA.Length : RChkB.Length;
        SubChecks[0].CheckUse = &ConditionUse;

        // We updated one of the checks in place, now erase the other.
        SubChecks.pop_back();
      }
    }

    Checks.insert(Checks.end(), SubChecks.begin(), SubChecks.end());
    return;
  }

  ICmpInst *ICI = dyn_cast<ICmpInst>(Condition);
  if (!ICI)
    return;

  Value *Length = nullptr, *Index = nullptr;
  auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length);
  if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN)
    return;

  // Only affine add recurrences of this very loop can be reasoned about.
  const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index));
  bool IsAffineIndex =
      IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine();

  if (!IsAffineIndex)
    return;

  InductiveRangeCheck IRC;
  IRC.Length = Length;
  IRC.Offset = IndexAddRec->getStart();
  IRC.Scale = IndexAddRec->getStepRecurrence(SE);
  IRC.CheckUse = &ConditionUse;
  IRC.Kind = RCKind;
  Checks.push_back(IRC);
}

void InductiveRangeCheck::extractRangeChecksFromBranch(
    BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo &BPI,
    SmallVectorImpl<InductiveRangeCheck> &Checks) {

  // The latch branch controls the loop itself and is never a range check.
  if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
    return;

  BranchProbability LikelyTaken(15, 16);

  // Successor 0 is taken when the condition is true (the passing direction);
  // require it to be hot unless profitability checks are disabled.
  if (!SkipProfitabilityChecks &&
      BPI.getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken)
    return;

  SmallPtrSet<Value *, 8> Visited;
  InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0),
                                                  Checks, Visited);
}

namespace {

// Keeps track of the structure of a loop.  This is similar to llvm::Loop,
// except that it is more lightweight and can track the state of a loop through
// changing and potentially invalid IR.  This structure also formalizes the
// kinds of loops we can deal with -- ones that have a single latch that is also
// an exiting block *and* have an induction variable that steps by +1 or -1
// (see LoopStructure::parseLoopStructure).
408 struct LoopStructure { 409 const char *Tag; 410 411 BasicBlock *Header; 412 BasicBlock *Latch; 413 414 // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th 415 // successor is `LatchExit', the exit block of the loop. 416 BranchInst *LatchBr; 417 BasicBlock *LatchExit; 418 unsigned LatchBrExitIdx; 419 420 Value *IndVarNext; 421 Value *IndVarStart; 422 Value *LoopExitAt; 423 bool IndVarIncreasing; 424 425 LoopStructure() 426 : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr), 427 LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr), 428 IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {} 429 430 template <typename M> LoopStructure map(M Map) const { 431 LoopStructure Result; 432 Result.Tag = Tag; 433 Result.Header = cast<BasicBlock>(Map(Header)); 434 Result.Latch = cast<BasicBlock>(Map(Latch)); 435 Result.LatchBr = cast<BranchInst>(Map(LatchBr)); 436 Result.LatchExit = cast<BasicBlock>(Map(LatchExit)); 437 Result.LatchBrExitIdx = LatchBrExitIdx; 438 Result.IndVarNext = Map(IndVarNext); 439 Result.IndVarStart = Map(IndVarStart); 440 Result.LoopExitAt = Map(LoopExitAt); 441 Result.IndVarIncreasing = IndVarIncreasing; 442 return Result; 443 } 444 445 static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &, 446 BranchProbabilityInfo &BPI, 447 Loop &, 448 const char *&); 449 }; 450 451 /// This class is used to constrain loops to run within a given iteration space. 452 /// The algorithm this class implements is given a Loop and a range [Begin, 453 /// End). The algorithm then tries to break out a "main loop" out of the loop 454 /// it is given in a way that the "main loop" runs with the induction variable 455 /// in a subset of [Begin, End). The algorithm emits appropriate pre and post 456 /// loops to run any remaining iterations. 
The pre loop runs any iterations in 457 /// which the induction variable is < Begin, and the post loop runs any 458 /// iterations in which the induction variable is >= End. 459 /// 460 class LoopConstrainer { 461 // The representation of a clone of the original loop we started out with. 462 struct ClonedLoop { 463 // The cloned blocks 464 std::vector<BasicBlock *> Blocks; 465 466 // `Map` maps values in the clonee into values in the cloned version 467 ValueToValueMapTy Map; 468 469 // An instance of `LoopStructure` for the cloned loop 470 LoopStructure Structure; 471 }; 472 473 // Result of rewriting the range of a loop. See changeIterationSpaceEnd for 474 // more details on what these fields mean. 475 struct RewrittenRangeInfo { 476 BasicBlock *PseudoExit; 477 BasicBlock *ExitSelector; 478 std::vector<PHINode *> PHIValuesAtPseudoExit; 479 PHINode *IndVarEnd; 480 481 RewrittenRangeInfo() 482 : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {} 483 }; 484 485 // Calculated subranges we restrict the iteration space of the main loop to. 486 // See the implementation of `calculateSubRanges' for more details on how 487 // these fields are computed. `LowLimit` is None if there is no restriction 488 // on low end of the restricted iteration space of the main loop. `HighLimit` 489 // is None if there is no restriction on high end of the restricted iteration 490 // space of the main loop. 491 492 struct SubRanges { 493 Optional<const SCEV *> LowLimit; 494 Optional<const SCEV *> HighLimit; 495 }; 496 497 // A utility function that does a `replaceUsesOfWith' on the incoming block 498 // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's 499 // incoming block list with `ReplaceBy'. 500 static void replacePHIBlock(PHINode *PN, BasicBlock *Block, 501 BasicBlock *ReplaceBy); 502 503 // Compute a safe set of limits for the main loop to run in -- effectively the 504 // intersection of `Range' and the iteration space of the original loop. 
505 // Return None if unable to compute the set of subranges. 506 // 507 Optional<SubRanges> calculateSubRanges() const; 508 509 // Clone `OriginalLoop' and return the result in CLResult. The IR after 510 // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- 511 // the PHI nodes say that there is an incoming edge from `OriginalPreheader` 512 // but there is no such edge. 513 // 514 void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; 515 516 // Rewrite the iteration space of the loop denoted by (LS, Preheader). The 517 // iteration space of the rewritten loop ends at ExitLoopAt. The start of the 518 // iteration space is not changed. `ExitLoopAt' is assumed to be slt 519 // `OriginalHeaderCount'. 520 // 521 // If there are iterations left to execute, control is made to jump to 522 // `ContinuationBlock', otherwise they take the normal loop exit. The 523 // returned `RewrittenRangeInfo' object is populated as follows: 524 // 525 // .PseudoExit is a basic block that unconditionally branches to 526 // `ContinuationBlock'. 527 // 528 // .ExitSelector is a basic block that decides, on exit from the loop, 529 // whether to branch to the "true" exit or to `PseudoExit'. 530 // 531 // .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value 532 // for each PHINode in the loop header on taking the pseudo exit. 533 // 534 // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate 535 // preheader because it is made to branch to the loop header only 536 // conditionally. 537 // 538 RewrittenRangeInfo 539 changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader, 540 Value *ExitLoopAt, 541 BasicBlock *ContinuationBlock) const; 542 543 // The loop denoted by `LS' has `OldPreheader' as its preheader. This 544 // function creates a new preheader for `LS' and returns it. 
545 // 546 BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, 547 const char *Tag) const; 548 549 // `ContinuationBlockAndPreheader' was the continuation block for some call to 550 // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'. 551 // This function rewrites the PHI nodes in `LS.Header' to start with the 552 // correct value. 553 void rewriteIncomingValuesForPHIs( 554 LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader, 555 const LoopConstrainer::RewrittenRangeInfo &RRI) const; 556 557 // Even though we do not preserve any passes at this time, we at least need to 558 // keep the parent loop structure consistent. The `LPPassManager' seems to 559 // verify this after running a loop pass. This function adds the list of 560 // blocks denoted by BBs to this loops parent loop if required. 561 void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs); 562 563 // Some global state. 564 Function &F; 565 LLVMContext &Ctx; 566 ScalarEvolution &SE; 567 DominatorTree &DT; 568 569 // Information about the original loop we started out with. 570 Loop &OriginalLoop; 571 LoopInfo &LI; 572 const SCEV *LatchTakenCount; 573 BasicBlock *OriginalPreheader; 574 575 // The preheader of the main loop. This may or may not be different from 576 // `OriginalPreheader'. 577 BasicBlock *MainLoopPreheader; 578 579 // The range we need to run the main loop in. 
580 InductiveRangeCheck::Range Range; 581 582 // The structure of the main loop (see comment at the beginning of this class 583 // for a definition) 584 LoopStructure MainLoopStructure; 585 586 public: 587 LoopConstrainer(Loop &L, LoopInfo &LI, const LoopStructure &LS, 588 ScalarEvolution &SE, DominatorTree &DT, 589 InductiveRangeCheck::Range R) 590 : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), 591 SE(SE), DT(DT), OriginalLoop(L), LI(LI), LatchTakenCount(nullptr), 592 OriginalPreheader(nullptr), MainLoopPreheader(nullptr), Range(R), 593 MainLoopStructure(LS) {} 594 595 // Entry point for the algorithm. Returns true on success. 596 bool run(); 597 }; 598 599 } 600 601 void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, 602 BasicBlock *ReplaceBy) { 603 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 604 if (PN->getIncomingBlock(i) == Block) 605 PN->setIncomingBlock(i, ReplaceBy); 606 } 607 608 static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) { 609 APInt SMax = 610 APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); 611 return SE.getSignedRange(S).contains(SMax) && 612 SE.getUnsignedRange(S).contains(SMax); 613 } 614 615 static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) { 616 APInt SMin = 617 APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()); 618 return SE.getSignedRange(S).contains(SMin) && 619 SE.getUnsignedRange(S).contains(SMin); 620 } 621 622 Optional<LoopStructure> 623 LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI, 624 Loop &L, const char *&FailureReason) { 625 assert(L.isLoopSimplifyForm() && "should follow from addRequired<>"); 626 627 BasicBlock *Latch = L.getLoopLatch(); 628 if (!L.isLoopExiting(Latch)) { 629 FailureReason = "no loop latch"; 630 return None; 631 } 632 633 BasicBlock *Header = L.getHeader(); 634 BasicBlock *Preheader = L.getLoopPreheader(); 635 if (!Preheader) { 636 FailureReason = "no 
preheader"; 637 return None; 638 } 639 640 BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); 641 if (!LatchBr || LatchBr->isUnconditional()) { 642 FailureReason = "latch terminator not conditional branch"; 643 return None; 644 } 645 646 unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; 647 648 BranchProbability ExitProbability = 649 BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx); 650 651 if (!SkipProfitabilityChecks && 652 ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { 653 FailureReason = "short running loop, not profitable"; 654 return None; 655 } 656 657 ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); 658 if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { 659 FailureReason = "latch terminator branch not conditional on integral icmp"; 660 return None; 661 } 662 663 const SCEV *LatchCount = SE.getExitCount(&L, Latch); 664 if (isa<SCEVCouldNotCompute>(LatchCount)) { 665 FailureReason = "could not compute latch count"; 666 return None; 667 } 668 669 ICmpInst::Predicate Pred = ICI->getPredicate(); 670 Value *LeftValue = ICI->getOperand(0); 671 const SCEV *LeftSCEV = SE.getSCEV(LeftValue); 672 IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType()); 673 674 Value *RightValue = ICI->getOperand(1); 675 const SCEV *RightSCEV = SE.getSCEV(RightValue); 676 677 // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. 
678 if (!isa<SCEVAddRecExpr>(LeftSCEV)) { 679 if (isa<SCEVAddRecExpr>(RightSCEV)) { 680 std::swap(LeftSCEV, RightSCEV); 681 std::swap(LeftValue, RightValue); 682 Pred = ICmpInst::getSwappedPredicate(Pred); 683 } else { 684 FailureReason = "no add recurrences in the icmp"; 685 return None; 686 } 687 } 688 689 auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { 690 if (AR->getNoWrapFlags(SCEV::FlagNSW)) 691 return true; 692 693 IntegerType *Ty = cast<IntegerType>(AR->getType()); 694 IntegerType *WideTy = 695 IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); 696 697 const SCEVAddRecExpr *ExtendAfterOp = 698 dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); 699 if (ExtendAfterOp) { 700 const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); 701 const SCEV *ExtendedStep = 702 SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); 703 704 bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && 705 ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; 706 707 if (NoSignedWrap) 708 return true; 709 } 710 711 // We may have proved this when computing the sign extension above. 712 return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; 713 }; 714 715 auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) { 716 if (!AR->isAffine()) 717 return false; 718 719 // Currently we only work with induction variables that have been proved to 720 // not wrap. This restriction can potentially be lifted in the future. 721 722 if (!HasNoSignedWrap(AR)) 723 return false; 724 725 if (const SCEVConstant *StepExpr = 726 dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) { 727 ConstantInt *StepCI = StepExpr->getValue(); 728 if (StepCI->isOne() || StepCI->isMinusOne()) { 729 IsIncreasing = StepCI->isOne(); 730 return true; 731 } 732 } 733 734 return false; 735 }; 736 737 // `ICI` is interpreted as taking the backedge if the *next* value of the 738 // induction variable satisfies some constraint. 
739 740 const SCEVAddRecExpr *IndVarNext = cast<SCEVAddRecExpr>(LeftSCEV); 741 bool IsIncreasing = false; 742 if (!IsInductionVar(IndVarNext, IsIncreasing)) { 743 FailureReason = "LHS in icmp not induction variable"; 744 return None; 745 } 746 747 ConstantInt *One = ConstantInt::get(IndVarTy, 1); 748 // TODO: generalize the predicates here to also match their unsigned variants. 749 if (IsIncreasing) { 750 bool FoundExpectedPred = 751 (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) || 752 (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0); 753 754 if (!FoundExpectedPred) { 755 FailureReason = "expected icmp slt semantically, found something else"; 756 return None; 757 } 758 759 if (LatchBrExitIdx == 0) { 760 if (CanBeSMax(SE, RightSCEV)) { 761 // TODO: this restriction is easily removable -- we just have to 762 // remember that the icmp was an slt and not an sle. 763 FailureReason = "limit may overflow when coercing sle to slt"; 764 return None; 765 } 766 767 IRBuilder<> B(Preheader->getTerminator()); 768 RightValue = B.CreateAdd(RightValue, One); 769 } 770 771 } else { 772 bool FoundExpectedPred = 773 (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) || 774 (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0); 775 776 if (!FoundExpectedPred) { 777 FailureReason = "expected icmp sgt semantically, found something else"; 778 return None; 779 } 780 781 if (LatchBrExitIdx == 0) { 782 if (CanBeSMin(SE, RightSCEV)) { 783 // TODO: this restriction is easily removable -- we just have to 784 // remember that the icmp was an sgt and not an sge. 
785 FailureReason = "limit may overflow when coercing sge to sgt"; 786 return None; 787 } 788 789 IRBuilder<> B(Preheader->getTerminator()); 790 RightValue = B.CreateSub(RightValue, One); 791 } 792 } 793 794 const SCEV *StartNext = IndVarNext->getStart(); 795 const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); 796 const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); 797 798 BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); 799 800 assert(SE.getLoopDisposition(LatchCount, &L) == 801 ScalarEvolution::LoopInvariant && 802 "loop variant exit count doesn't make sense!"); 803 804 assert(!L.contains(LatchExit) && "expected an exit block!"); 805 const DataLayout &DL = Preheader->getModule()->getDataLayout(); 806 Value *IndVarStartV = 807 SCEVExpander(SE, DL, "irce") 808 .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator()); 809 IndVarStartV->setName("indvar.start"); 810 811 LoopStructure Result; 812 813 Result.Tag = "main"; 814 Result.Header = Header; 815 Result.Latch = Latch; 816 Result.LatchBr = LatchBr; 817 Result.LatchExit = LatchExit; 818 Result.LatchBrExitIdx = LatchBrExitIdx; 819 Result.IndVarStart = IndVarStartV; 820 Result.IndVarNext = LeftValue; 821 Result.IndVarIncreasing = IsIncreasing; 822 Result.LoopExitAt = RightValue; 823 824 FailureReason = nullptr; 825 826 return Result; 827 } 828 829 Optional<LoopConstrainer::SubRanges> 830 LoopConstrainer::calculateSubRanges() const { 831 IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); 832 833 if (Range.getType() != Ty) 834 return None; 835 836 LoopConstrainer::SubRanges Result; 837 838 // I think we can be more aggressive here and make this nuw / nsw if the 839 // addition that feeds into the icmp for the latch's terminating branch is nuw 840 // / nsw. In any case, a wrapping 2's complement addition is safe. 
841 ConstantInt *One = ConstantInt::get(Ty, 1); 842 const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); 843 const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); 844 845 bool Increasing = MainLoopStructure.IndVarIncreasing; 846 847 // We compute `Smallest` and `Greatest` such that [Smallest, Greatest) is the 848 // range of values the induction variable takes. 849 850 const SCEV *Smallest = nullptr, *Greatest = nullptr; 851 852 if (Increasing) { 853 Smallest = Start; 854 Greatest = End; 855 } else { 856 // These two computations may sign-overflow. Here is why that is okay: 857 // 858 // We know that the induction variable does not sign-overflow on any 859 // iteration except the last one, and it starts at `Start` and ends at 860 // `End`, decrementing by one every time. 861 // 862 // * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the 863 // induction variable is decreasing we know that that the smallest value 864 // the loop body is actually executed with is `INT_SMIN` == `Smallest`. 865 // 866 // * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. In 867 // that case, `Clamp` will always return `Smallest` and 868 // [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`) 869 // will be an empty range. Returning an empty range is always safe. 
870 // 871 872 Smallest = SE.getAddExpr(End, SE.getSCEV(One)); 873 Greatest = SE.getAddExpr(Start, SE.getSCEV(One)); 874 } 875 876 auto Clamp = [this, Smallest, Greatest](const SCEV *S) { 877 return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)); 878 }; 879 880 // In some cases we can prove that we don't need a pre or post loop 881 882 bool ProvablyNoPreloop = 883 SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest); 884 if (!ProvablyNoPreloop) 885 Result.LowLimit = Clamp(Range.getBegin()); 886 887 bool ProvablyNoPostLoop = 888 SE.isKnownPredicate(ICmpInst::ICMP_SLE, Greatest, Range.getEnd()); 889 if (!ProvablyNoPostLoop) 890 Result.HighLimit = Clamp(Range.getEnd()); 891 892 return Result; 893 } 894 895 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, 896 const char *Tag) const { 897 for (BasicBlock *BB : OriginalLoop.getBlocks()) { 898 BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F); 899 Result.Blocks.push_back(Clone); 900 Result.Map[BB] = Clone; 901 } 902 903 auto GetClonedValue = [&Result](Value *V) { 904 assert(V && "null values not in domain!"); 905 auto It = Result.Map.find(V); 906 if (It == Result.Map.end()) 907 return V; 908 return static_cast<Value *>(It->second); 909 }; 910 911 Result.Structure = MainLoopStructure.map(GetClonedValue); 912 Result.Structure.Tag = Tag; 913 914 for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) { 915 BasicBlock *ClonedBB = Result.Blocks[i]; 916 BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i]; 917 918 assert(Result.Map[OriginalBB] == ClonedBB && "invariant!"); 919 920 for (Instruction &I : *ClonedBB) 921 RemapInstruction(&I, Result.Map, 922 RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); 923 924 // Exit blocks will now have one more predecessor and their PHI nodes need 925 // to be edited to reflect that. No phi nodes need to be introduced because 926 // the loop is in LCSSA. 
927 928 for (auto SBBI = succ_begin(OriginalBB), SBBE = succ_end(OriginalBB); 929 SBBI != SBBE; ++SBBI) { 930 931 if (OriginalLoop.contains(*SBBI)) 932 continue; // not an exit block 933 934 for (Instruction &I : **SBBI) { 935 if (!isa<PHINode>(&I)) 936 break; 937 938 PHINode *PN = cast<PHINode>(&I); 939 Value *OldIncoming = PN->getIncomingValueForBlock(OriginalBB); 940 PN->addIncoming(GetClonedValue(OldIncoming), ClonedBB); 941 } 942 } 943 } 944 } 945 946 LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( 947 const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, 948 BasicBlock *ContinuationBlock) const { 949 950 // We start with a loop with a single latch: 951 // 952 // +--------------------+ 953 // | | 954 // | preheader | 955 // | | 956 // +--------+-----------+ 957 // | ----------------\ 958 // | / | 959 // +--------v----v------+ | 960 // | | | 961 // | header | | 962 // | | | 963 // +--------------------+ | 964 // | 965 // ..... | 966 // | 967 // +--------------------+ | 968 // | | | 969 // | latch >----------/ 970 // | | 971 // +-------v------------+ 972 // | 973 // | 974 // | +--------------------+ 975 // | | | 976 // +---> original exit | 977 // | | 978 // +--------------------+ 979 // 980 // We change the control flow to look like 981 // 982 // 983 // +--------------------+ 984 // | | 985 // | preheader >-------------------------+ 986 // | | | 987 // +--------v-----------+ | 988 // | /-------------+ | 989 // | / | | 990 // +--------v--v--------+ | | 991 // | | | | 992 // | header | | +--------+ | 993 // | | | | | | 994 // +--------------------+ | | +-----v-----v-----------+ 995 // | | | | 996 // | | | .pseudo.exit | 997 // | | | | 998 // | | +-----------v-----------+ 999 // | | | 1000 // ..... 
| | | 1001 // | | +--------v-------------+ 1002 // +--------------------+ | | | | 1003 // | | | | | ContinuationBlock | 1004 // | latch >------+ | | | 1005 // | | | +----------------------+ 1006 // +---------v----------+ | 1007 // | | 1008 // | | 1009 // | +---------------^-----+ 1010 // | | | 1011 // +-----> .exit.selector | 1012 // | | 1013 // +----------v----------+ 1014 // | 1015 // +--------------------+ | 1016 // | | | 1017 // | original exit <----+ 1018 // | | 1019 // +--------------------+ 1020 // 1021 1022 RewrittenRangeInfo RRI; 1023 1024 auto BBInsertLocation = std::next(Function::iterator(LS.Latch)); 1025 RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", 1026 &F, &*BBInsertLocation); 1027 RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, 1028 &*BBInsertLocation); 1029 1030 BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); 1031 bool Increasing = LS.IndVarIncreasing; 1032 1033 IRBuilder<> B(PreheaderJump); 1034 1035 // EnterLoopCond - is it okay to start executing this `LS'? 1036 Value *EnterLoopCond = Increasing 1037 ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) 1038 : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt); 1039 1040 B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); 1041 PreheaderJump->eraseFromParent(); 1042 1043 LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); 1044 B.SetInsertPoint(LS.LatchBr); 1045 Value *TakeBackedgeLoopCond = 1046 Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt) 1047 : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt); 1048 Value *CondForBranch = LS.LatchBrExitIdx == 1 1049 ? TakeBackedgeLoopCond 1050 : B.CreateNot(TakeBackedgeLoopCond); 1051 1052 LS.LatchBr->setCondition(CondForBranch); 1053 1054 B.SetInsertPoint(RRI.ExitSelector); 1055 1056 // IterationsLeft - are there any more iterations left, given the original 1057 // upper bound on the induction variable? If not, we branch to the "real" 1058 // exit. 
1059 Value *IterationsLeft = Increasing 1060 ? B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt) 1061 : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt); 1062 B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); 1063 1064 BranchInst *BranchToContinuation = 1065 BranchInst::Create(ContinuationBlock, RRI.PseudoExit); 1066 1067 // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of 1068 // each of the PHI nodes in the loop header. This feeds into the initial 1069 // value of the same PHI nodes if/when we continue execution. 1070 for (Instruction &I : *LS.Header) { 1071 if (!isa<PHINode>(&I)) 1072 break; 1073 1074 PHINode *PN = cast<PHINode>(&I); 1075 1076 PHINode *NewPHI = PHINode::Create(PN->getType(), 2, PN->getName() + ".copy", 1077 BranchToContinuation); 1078 1079 NewPHI->addIncoming(PN->getIncomingValueForBlock(Preheader), Preheader); 1080 NewPHI->addIncoming(PN->getIncomingValueForBlock(LS.Latch), 1081 RRI.ExitSelector); 1082 RRI.PHIValuesAtPseudoExit.push_back(NewPHI); 1083 } 1084 1085 RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end", 1086 BranchToContinuation); 1087 RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); 1088 RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector); 1089 1090 // The latch exit now has a branch from `RRI.ExitSelector' instead of 1091 // `LS.Latch'. The PHI nodes need to be updated to reflect that. 
1092 for (Instruction &I : *LS.LatchExit) { 1093 if (PHINode *PN = dyn_cast<PHINode>(&I)) 1094 replacePHIBlock(PN, LS.Latch, RRI.ExitSelector); 1095 else 1096 break; 1097 } 1098 1099 return RRI; 1100 } 1101 1102 void LoopConstrainer::rewriteIncomingValuesForPHIs( 1103 LoopStructure &LS, BasicBlock *ContinuationBlock, 1104 const LoopConstrainer::RewrittenRangeInfo &RRI) const { 1105 1106 unsigned PHIIndex = 0; 1107 for (Instruction &I : *LS.Header) { 1108 if (!isa<PHINode>(&I)) 1109 break; 1110 1111 PHINode *PN = cast<PHINode>(&I); 1112 1113 for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) 1114 if (PN->getIncomingBlock(i) == ContinuationBlock) 1115 PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); 1116 } 1117 1118 LS.IndVarStart = RRI.IndVarEnd; 1119 } 1120 1121 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, 1122 BasicBlock *OldPreheader, 1123 const char *Tag) const { 1124 1125 BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); 1126 BranchInst::Create(LS.Header, Preheader); 1127 1128 for (Instruction &I : *LS.Header) { 1129 if (!isa<PHINode>(&I)) 1130 break; 1131 1132 PHINode *PN = cast<PHINode>(&I); 1133 for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) 1134 replacePHIBlock(PN, OldPreheader, Preheader); 1135 } 1136 1137 return Preheader; 1138 } 1139 1140 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { 1141 Loop *ParentLoop = OriginalLoop.getParentLoop(); 1142 if (!ParentLoop) 1143 return; 1144 1145 for (BasicBlock *BB : BBs) 1146 ParentLoop->addBasicBlockToLoop(BB, LI); 1147 } 1148 1149 bool LoopConstrainer::run() { 1150 BasicBlock *Preheader = nullptr; 1151 LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch); 1152 Preheader = OriginalLoop.getLoopPreheader(); 1153 assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr && 1154 "preconditions!"); 1155 1156 OriginalPreheader = Preheader; 1157 MainLoopPreheader = 
Preheader; 1158 1159 Optional<SubRanges> MaybeSR = calculateSubRanges(); 1160 if (!MaybeSR.hasValue()) { 1161 DEBUG(dbgs() << "irce: could not compute subranges\n"); 1162 return false; 1163 } 1164 1165 SubRanges SR = MaybeSR.getValue(); 1166 bool Increasing = MainLoopStructure.IndVarIncreasing; 1167 IntegerType *IVTy = 1168 cast<IntegerType>(MainLoopStructure.IndVarNext->getType()); 1169 1170 SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); 1171 Instruction *InsertPt = OriginalPreheader->getTerminator(); 1172 1173 // It would have been better to make `PreLoop' and `PostLoop' 1174 // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy 1175 // constructor. 1176 ClonedLoop PreLoop, PostLoop; 1177 bool NeedsPreLoop = 1178 Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue(); 1179 bool NeedsPostLoop = 1180 Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue(); 1181 1182 Value *ExitPreLoopAt = nullptr; 1183 Value *ExitMainLoopAt = nullptr; 1184 const SCEVConstant *MinusOneS = 1185 cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */)); 1186 1187 if (NeedsPreLoop) { 1188 const SCEV *ExitPreLoopAtSCEV = nullptr; 1189 1190 if (Increasing) 1191 ExitPreLoopAtSCEV = *SR.LowLimit; 1192 else { 1193 if (CanBeSMin(SE, *SR.HighLimit)) { 1194 DEBUG(dbgs() << "irce: could not prove no-overflow when computing " 1195 << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) 1196 << "\n"); 1197 return false; 1198 } 1199 ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); 1200 } 1201 1202 ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); 1203 ExitPreLoopAt->setName("exit.preloop.at"); 1204 } 1205 1206 if (NeedsPostLoop) { 1207 const SCEV *ExitMainLoopAtSCEV = nullptr; 1208 1209 if (Increasing) 1210 ExitMainLoopAtSCEV = *SR.HighLimit; 1211 else { 1212 if (CanBeSMin(SE, *SR.LowLimit)) { 1213 DEBUG(dbgs() << "irce: could not prove no-overflow when computing " 1214 << "mainloop exit limit. 
LowLimit = " << *(*SR.LowLimit) 1215 << "\n"); 1216 return false; 1217 } 1218 ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); 1219 } 1220 1221 ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); 1222 ExitMainLoopAt->setName("exit.mainloop.at"); 1223 } 1224 1225 // We clone these ahead of time so that we don't have to deal with changing 1226 // and temporarily invalid IR as we transform the loops. 1227 if (NeedsPreLoop) 1228 cloneLoop(PreLoop, "preloop"); 1229 if (NeedsPostLoop) 1230 cloneLoop(PostLoop, "postloop"); 1231 1232 RewrittenRangeInfo PreLoopRRI; 1233 1234 if (NeedsPreLoop) { 1235 Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, 1236 PreLoop.Structure.Header); 1237 1238 MainLoopPreheader = 1239 createPreheader(MainLoopStructure, Preheader, "mainloop"); 1240 PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, 1241 ExitPreLoopAt, MainLoopPreheader); 1242 rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, 1243 PreLoopRRI); 1244 } 1245 1246 BasicBlock *PostLoopPreheader = nullptr; 1247 RewrittenRangeInfo PostLoopRRI; 1248 1249 if (NeedsPostLoop) { 1250 PostLoopPreheader = 1251 createPreheader(PostLoop.Structure, Preheader, "postloop"); 1252 PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, 1253 ExitMainLoopAt, PostLoopPreheader); 1254 rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, 1255 PostLoopRRI); 1256 } 1257 1258 BasicBlock *NewMainLoopPreheader = 1259 MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr; 1260 BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit, 1261 PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit, 1262 PostLoopRRI.ExitSelector, NewMainLoopPreheader}; 1263 1264 // Some of the above may be nullptr, filter them out before passing to 1265 // addToParentLoopIfNeeded. 
1266 auto NewBlocksEnd = 1267 std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); 1268 1269 addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd)); 1270 addToParentLoopIfNeeded(PreLoop.Blocks); 1271 addToParentLoopIfNeeded(PostLoop.Blocks); 1272 1273 DT.recalculate(F); 1274 formLCSSARecursively(OriginalLoop, DT, &LI, &SE); 1275 simplifyLoop(&OriginalLoop, &DT, &LI, &SE, nullptr, true); 1276 1277 return true; 1278 } 1279 1280 /// Computes and returns a range of values for the induction variable (IndVar) 1281 /// in which the range check can be safely elided. If it cannot compute such a 1282 /// range, returns None. 1283 Optional<InductiveRangeCheck::Range> 1284 InductiveRangeCheck::computeSafeIterationSpace( 1285 ScalarEvolution &SE, const SCEVAddRecExpr *IndVar) const { 1286 // IndVar is of the form "A + B * I" (where "I" is the canonical induction 1287 // variable, that may or may not exist as a real llvm::Value in the loop) and 1288 // this inductive range check is a range check on the "C + D * I" ("C" is 1289 // getOffset() and "D" is getScale()). We rewrite the value being range 1290 // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA". 1291 // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code 1292 // can be generalized as needed. 1293 // 1294 // The actual inequalities we solve are of the form 1295 // 1296 // 0 <= M + 1 * IndVar < L given L >= 0 (i.e. N == 1) 1297 // 1298 // The inequality is satisfied by -M <= IndVar < (L - M) [^1]. All additions 1299 // and subtractions are twos-complement wrapping and comparisons are signed. 1300 // 1301 // Proof: 1302 // 1303 // If there exists IndVar such that -M <= IndVar < (L - M) then it follows 1304 // that -M <= (-M + L) [== Eq. 1]. Since L >= 0, if (-M + L) sign-overflows 1305 // then (-M + L) < (-M). Hence by [Eq. 1], (-M + L) could not have 1306 // overflown. 1307 // 1308 // This means IndVar = t + (-M) for t in [0, L). 
Hence (IndVar + M) = t. 1309 // Hence 0 <= (IndVar + M) < L 1310 1311 // [^1]: Note that the solution does _not_ apply if L < 0; consider values M = 1312 // 127, IndVar = 126 and L = -2 in an i8 world. 1313 1314 if (!IndVar->isAffine()) 1315 return None; 1316 1317 const SCEV *A = IndVar->getStart(); 1318 const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); 1319 if (!B) 1320 return None; 1321 1322 const SCEV *C = getOffset(); 1323 const SCEVConstant *D = dyn_cast<SCEVConstant>(getScale()); 1324 if (D != B) 1325 return None; 1326 1327 ConstantInt *ConstD = D->getValue(); 1328 if (!(ConstD->isMinusOne() || ConstD->isOne())) 1329 return None; 1330 1331 const SCEV *M = SE.getMinusSCEV(C, A); 1332 1333 const SCEV *Begin = SE.getNegativeSCEV(M); 1334 const SCEV *UpperLimit = nullptr; 1335 1336 // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". 1337 // We can potentially do much better here. 1338 if (Value *V = getLength()) { 1339 UpperLimit = SE.getSCEV(V); 1340 } else { 1341 assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); 1342 unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); 1343 UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); 1344 } 1345 1346 const SCEV *End = SE.getMinusSCEV(UpperLimit, M); 1347 return InductiveRangeCheck::Range(Begin, End); 1348 } 1349 1350 static Optional<InductiveRangeCheck::Range> 1351 IntersectRange(ScalarEvolution &SE, 1352 const Optional<InductiveRangeCheck::Range> &R1, 1353 const InductiveRangeCheck::Range &R2) { 1354 if (!R1.hasValue()) 1355 return R2; 1356 auto &R1Value = R1.getValue(); 1357 1358 // TODO: we could widen the smaller range and have this work; but for now we 1359 // bail out to keep things simple. 
1360 if (R1Value.getType() != R2.getType()) 1361 return None; 1362 1363 const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin()); 1364 const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd()); 1365 1366 return InductiveRangeCheck::Range(NewBegin, NewEnd); 1367 } 1368 1369 bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { 1370 if (skipLoop(L)) 1371 return false; 1372 1373 if (L->getBlocks().size() >= LoopSizeCutoff) { 1374 DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";); 1375 return false; 1376 } 1377 1378 BasicBlock *Preheader = L->getLoopPreheader(); 1379 if (!Preheader) { 1380 DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); 1381 return false; 1382 } 1383 1384 LLVMContext &Context = Preheader->getContext(); 1385 SmallVector<InductiveRangeCheck, 16> RangeChecks; 1386 ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 1387 BranchProbabilityInfo &BPI = 1388 getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); 1389 1390 for (auto BBI : L->getBlocks()) 1391 if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) 1392 InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI, 1393 RangeChecks); 1394 1395 if (RangeChecks.empty()) 1396 return false; 1397 1398 auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) { 1399 OS << "irce: looking at loop "; L->print(OS); 1400 OS << "irce: loop has " << RangeChecks.size() 1401 << " inductive range checks: \n"; 1402 for (InductiveRangeCheck &IRC : RangeChecks) 1403 IRC.print(OS); 1404 }; 1405 1406 DEBUG(PrintRecognizedRangeChecks(dbgs())); 1407 1408 if (PrintRangeChecks) 1409 PrintRecognizedRangeChecks(errs()); 1410 1411 const char *FailureReason = nullptr; 1412 Optional<LoopStructure> MaybeLoopStructure = 1413 LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason); 1414 if (!MaybeLoopStructure.hasValue()) { 1415 DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason 1416 << 
"\n";); 1417 return false; 1418 } 1419 LoopStructure LS = MaybeLoopStructure.getValue(); 1420 bool Increasing = LS.IndVarIncreasing; 1421 const SCEV *MinusOne = 1422 SE.getConstant(LS.IndVarNext->getType(), Increasing ? -1 : 1, true); 1423 const SCEVAddRecExpr *IndVar = 1424 cast<SCEVAddRecExpr>(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne)); 1425 1426 Optional<InductiveRangeCheck::Range> SafeIterRange; 1427 Instruction *ExprInsertPt = Preheader->getTerminator(); 1428 1429 SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate; 1430 1431 IRBuilder<> B(ExprInsertPt); 1432 for (InductiveRangeCheck &IRC : RangeChecks) { 1433 auto Result = IRC.computeSafeIterationSpace(SE, IndVar); 1434 if (Result.hasValue()) { 1435 auto MaybeSafeIterRange = 1436 IntersectRange(SE, SafeIterRange, Result.getValue()); 1437 if (MaybeSafeIterRange.hasValue()) { 1438 RangeChecksToEliminate.push_back(IRC); 1439 SafeIterRange = MaybeSafeIterRange.getValue(); 1440 } 1441 } 1442 } 1443 1444 if (!SafeIterRange.hasValue()) 1445 return false; 1446 1447 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 1448 LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LS, 1449 SE, DT, SafeIterRange.getValue()); 1450 bool Changed = LC.run(); 1451 1452 if (Changed) { 1453 auto PrintConstrainedLoopInfo = [L]() { 1454 dbgs() << "irce: in function "; 1455 dbgs() << L->getHeader()->getParent()->getName() << ": "; 1456 dbgs() << "constrained "; 1457 L->print(dbgs()); 1458 }; 1459 1460 DEBUG(PrintConstrainedLoopInfo()); 1461 1462 if (PrintChangedLoops) 1463 PrintConstrainedLoopInfo(); 1464 1465 // Optimize away the now-redundant range checks. 1466 1467 for (InductiveRangeCheck &IRC : RangeChecksToEliminate) { 1468 ConstantInt *FoldedRangeCheck = IRC.getPassingDirection() 1469 ? 
ConstantInt::getTrue(Context) 1470 : ConstantInt::getFalse(Context); 1471 IRC.getCheckUse()->set(FoldedRangeCheck); 1472 } 1473 } 1474 1475 return Changed; 1476 } 1477 1478 Pass *llvm::createInductiveRangeCheckEliminationPass() { 1479 return new InductiveRangeCheckElimination; 1480 } 1481