1 //===-- InductiveRangeCheckElimination.cpp - ------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // The InductiveRangeCheckElimination pass splits a loop's iteration space into 10 // three disjoint ranges. It does that in a way such that the loop running in 11 // the middle loop provably does not need range checks. As an example, it will 12 // convert 13 // 14 // len = < known positive > 15 // for (i = 0; i < n; i++) { 16 // if (0 <= i && i < len) { 17 // do_something(); 18 // } else { 19 // throw_out_of_bounds(); 20 // } 21 // } 22 // 23 // to 24 // 25 // len = < known positive > 26 // limit = smin(n, len) 27 // // no first segment 28 // for (i = 0; i < limit; i++) { 29 // if (0 <= i && i < len) { // this check is fully redundant 30 // do_something(); 31 // } else { 32 // throw_out_of_bounds(); 33 // } 34 // } 35 // for (i = limit; i < n; i++) { 36 // if (0 <= i && i < len) { 37 // do_something(); 38 // } else { 39 // throw_out_of_bounds(); 40 // } 41 // } 42 //===----------------------------------------------------------------------===// 43 44 #include "llvm/ADT/Optional.h" 45 46 #include "llvm/Analysis/InstructionSimplify.h" 47 #include "llvm/Analysis/LoopInfo.h" 48 #include "llvm/Analysis/LoopPass.h" 49 #include "llvm/Analysis/ScalarEvolution.h" 50 #include "llvm/Analysis/ScalarEvolutionExpander.h" 51 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 52 #include "llvm/Analysis/ValueTracking.h" 53 54 #include "llvm/IR/Dominators.h" 55 #include "llvm/IR/Function.h" 56 #include "llvm/IR/Instructions.h" 57 #include "llvm/IR/IRBuilder.h" 58 #include "llvm/IR/Module.h" 59 #include "llvm/IR/PatternMatch.h" 60 #include "llvm/IR/ValueHandle.h" 61 #include "llvm/IR/Verifier.h" 62 63 #include "llvm/Support/Debug.h" 64 65 #include 
"llvm/Transforms/Scalar.h" 66 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 67 #include "llvm/Transforms/Utils/Cloning.h" 68 #include "llvm/Transforms/Utils/LoopUtils.h" 69 #include "llvm/Transforms/Utils/SimplifyIndVar.h" 70 #include "llvm/Transforms/Utils/UnrollLoop.h" 71 72 #include "llvm/Pass.h" 73 74 #include <array> 75 76 using namespace llvm; 77 78 cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden, 79 cl::init(64)); 80 81 cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden, 82 cl::init(false)); 83 84 #define DEBUG_TYPE "irce" 85 86 namespace { 87 88 /// An inductive range check is conditional branch in a loop with 89 /// 90 /// 1. a very cold successor (i.e. the branch jumps to that successor very 91 /// rarely) 92 /// 93 /// and 94 /// 95 /// 2. a condition that is provably true for some range of values taken by the 96 /// containing loop's induction variable. 97 /// 98 /// Currently all inductive range checks are branches conditional on an 99 /// expression of the form 100 /// 101 /// 0 <= (Offset + Scale * I) < Length 102 /// 103 /// where `I' is the canonical induction variable of a loop to which Offset and 104 /// Scale are loop invariant, and Length is >= 0. Currently the 'false' branch 105 /// is considered cold, looking at profiling data to verify that is a TODO. 
class InductiveRangeCheck {
  // Parsed components of the check: the branch condition is provably true for
  // induction variable values I with 0 <= Offset + Scale * I < Length.
  const SCEV *Offset;
  const SCEV *Scale;
  Value *Length;
  // The conditional branch the check was parsed out of.
  BranchInst *Branch;

  // Instances are only created via `create' (below), out of a bump pointer
  // allocator; the default constructor is therefore private.
  InductiveRangeCheck() :
    Offset(nullptr), Scale(nullptr), Length(nullptr), Branch(nullptr) { }

public:
  const SCEV *getOffset() const { return Offset; }
  const SCEV *getScale() const { return Scale; }
  Value *getLength() const { return Length; }

  /// Pretty-print this range check (for debugging).
  void print(raw_ostream &OS) const {
    OS << "InductiveRangeCheck:\n";
    OS << "  Offset: ";
    Offset->print(OS);
    OS << "  Scale: ";
    Scale->print(OS);
    OS << "  Length: ";
    Length->print(OS);
    OS << "  Branch: ";
    getBranch()->print(OS);
  }

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
  void dump() {
    print(dbgs());
  }
#endif

  BranchInst *getBranch() const { return Branch; }

  /// Represents an integer range [Range.first, Range.second).  If Range.second
  /// < Range.first, then the value denotes the empty range.
  typedef std::pair<Value *, Value *> Range;
  typedef SpecificBumpPtrAllocator<InductiveRangeCheck> AllocatorTy;

  /// This is the value the condition of the branch needs to evaluate to for the
  /// branch to take the hot successor (see (1) above).
  bool getPassingDirection() { return true; }

  /// Computes a range for the induction variable in which the range check is
  /// redundant and can be constant-folded away.
  Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
                                            IRBuilder<> &B) const;

  /// Create an inductive range check out of BI if possible, else return
  /// nullptr.
  static InductiveRangeCheck *create(AllocatorTy &Alloc, BranchInst *BI,
                                     Loop *L, ScalarEvolution &SE);
};

/// The pass itself: walks the branches of a loop and, for every inductive
/// range check found, clones the loop into pre/main/post segments so the main
/// segment's checks fold away.
class InductiveRangeCheckElimination : public LoopPass {
  InductiveRangeCheck::AllocatorTy Allocator;

public:
  static char ID;
  InductiveRangeCheckElimination() : LoopPass(ID) {
    initializeInductiveRangeCheckEliminationPass(
        *PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequiredID(LoopSimplifyID);
    AU.addRequiredID(LCSSAID);
    AU.addRequired<ScalarEvolution>();
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
};

char InductiveRangeCheckElimination::ID = 0;
}

// NOTE(review): getAnalysisUsage above requires LoopSimplify / LCSSA /
// ScalarEvolution, but the plain INITIALIZE_PASS macro does not declare those
// dependencies to the pass registry -- confirm whether the
// INITIALIZE_PASS_BEGIN/_DEPENDENCY/_END form should be used here.
INITIALIZE_PASS(InductiveRangeCheckElimination, "irce",
                "Inductive range check elimination", false, false)

/// Return true if `Check' is an icmp of the form "0 <= IndexV" (allowing the
/// equivalent spellings 0 sle V, V sge 0, -1 slt V and V sgt -1).  On success,
/// the compared value is returned in `IndexV'.
static bool IsLowerBoundCheck(Value *Check, Value *&IndexV) {
  using namespace llvm::PatternMatch;

  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
  Value *LHS = nullptr, *RHS = nullptr;

  if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
    return false;

  switch (Pred) {
  default:
    return false;

  case ICmpInst::ICMP_SLE:
    // Canonicalize "0 sle V" to "V sge 0" and share the SGE logic.
    std::swap(LHS, RHS);
    // fallthrough
  case ICmpInst::ICMP_SGE:
    if (!match(RHS, m_ConstantInt<0>()))
      return false;
    IndexV = LHS;
    return true;

  case ICmpInst::ICMP_SLT:
    // Canonicalize "-1 slt V" to "V sgt -1" and share the SGT logic.
    std::swap(LHS, RHS);
    // fallthrough
  case ICmpInst::ICMP_SGT:
    // "V sgt -1" is equivalent to "V sge 0".
    if (!match(RHS, m_ConstantInt<-1>()))
      return false;
    IndexV = LHS;
    return true;
  }
}

/// Return true if `Check' is an icmp of the form "Index < Limit" (signed or
/// unsigned, either operand order).  On success the limit is returned in
/// `UpperLimit'.
static bool IsUpperBoundCheck(Value *Check, Value *Index, Value *&UpperLimit) {
  using namespace llvm::PatternMatch;

  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
  Value *LHS = nullptr, *RHS = nullptr;

  if (!match(Check, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
    return false;

  switch (Pred) {
  default:
    return false;

  case ICmpInst::ICMP_SGT:
    // Canonicalize "Limit sgt Index" to "Index slt Limit".
    std::swap(LHS, RHS);
    // fallthrough
  case ICmpInst::ICMP_SLT:
    if (LHS != Index)
      return false;
    UpperLimit = RHS;
    return true;

  case ICmpInst::ICMP_UGT:
    // Canonicalize "Limit ugt Index" to "Index ult Limit".
    std::swap(LHS, RHS);
    // fallthrough
  case ICmpInst::ICMP_ULT:
    if (LHS != Index)
      return false;
    UpperLimit = RHS;
    return true;
  }
}

/// Split a condition into something semantically equivalent to (0 <= I <
/// Limit), both comparisons signed and Len loop invariant on L and positive.
/// On success, return true and set Index to I and UpperLimit to Limit.  Return
/// false on failure (we may still write to UpperLimit and Index on failure).
/// It does not try to interpret I as a loop index.
///
static bool SplitRangeCheckCondition(Loop *L, ScalarEvolution &SE,
                                     Value *Condition, const SCEV *&Index,
                                     Value *&UpperLimit) {

  // TODO: currently this catches some silly cases like comparing "%idx slt 1".
  // Our transformations are still correct, but less likely to be profitable in
  // those cases.  We have to come up with some heuristics that pick out the
  // range checks that are more profitable to clone a loop for.  This function
  // in general can be made more robust.

  using namespace llvm::PatternMatch;

  Value *A = nullptr;
  Value *B = nullptr;
  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;

  // In these early checks we assume that the matched UpperLimit is positive.
  // We'll verify that fact later, before returning true.
  if (match(Condition, m_And(m_Value(A), m_Value(B)))) {
    // Form 1: "(0 <= I) and (I < Limit)" -- either conjunct may come first.
    Value *IndexV = nullptr;
    Value *ExpectedUpperBoundCheck = nullptr;

    if (IsLowerBoundCheck(A, IndexV))
      ExpectedUpperBoundCheck = B;
    else if (IsLowerBoundCheck(B, IndexV))
      ExpectedUpperBoundCheck = A;
    else
      return false;

    if (!IsUpperBoundCheck(ExpectedUpperBoundCheck, IndexV, UpperLimit))
      return false;

    Index = SE.getSCEV(IndexV);

    if (isa<SCEVCouldNotCompute>(Index))
      return false;

  } else if (match(Condition, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
    // Form 2: a single comparison "I < Limit".
    switch (Pred) {
    default:
      return false;

    case ICmpInst::ICMP_SGT:
      std::swap(A, B);
      // fall through
    case ICmpInst::ICMP_SLT:
      UpperLimit = B;
      Index = SE.getSCEV(A);
      // A signed "I slt Limit" only implies 0 <= I if SCEV can prove the
      // index non-negative.
      if (isa<SCEVCouldNotCompute>(Index) || !SE.isKnownNonNegative(Index))
        return false;
      break;

    case ICmpInst::ICMP_UGT:
      std::swap(A, B);
      // fall through
    case ICmpInst::ICMP_ULT:
      // An unsigned "I ult Limit" subsumes the lower bound check: with Limit
      // later proven non-negative, I must lie in [0, Limit).
      UpperLimit = B;
      Index = SE.getSCEV(A);
      if (isa<SCEVCouldNotCompute>(Index))
        return false;
      break;
    }
  } else {
    return false;
  }

  // Now verify the assumption made earlier: UpperLimit must be provably
  // non-negative and invariant in L.
  const SCEV *UpperLimitSCEV = SE.getSCEV(UpperLimit);
  if (isa<SCEVCouldNotCompute>(UpperLimitSCEV) ||
      !SE.isKnownNonNegative(UpperLimitSCEV))
    return false;

  if (SE.getLoopDisposition(UpperLimitSCEV, L) !=
      ScalarEvolution::LoopInvariant) {
    DEBUG(dbgs() << " in function: " << L->getHeader()->getParent()->getName()
                 << " ";
          dbgs() << " UpperLimit is not loop invariant: "
                 << UpperLimit->getName() << "\n";);
    return false;
  }

  return true;
}

InductiveRangeCheck *
InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
                            Loop *L, ScalarEvolution &SE) {

  // Unconditional branches carry no check; a check in the latch controls the
  // loop itself and is not an eliminable range check.
  if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
    return nullptr;

  Value *Length = nullptr;
  const SCEV *IndexSCEV = nullptr;

  if (!SplitRangeCheckCondition(L, SE, BI->getCondition(), IndexSCEV, Length))
    return nullptr;

  assert(IndexSCEV && Length && "contract with SplitRangeCheckCondition!");

  // We can only reason about indices of the form Offset + Scale * CIV, i.e.
  // affine add recurrences over this loop.
  const SCEVAddRecExpr *IndexAddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
  bool IsAffineIndex =
      IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine();

  if (!IsAffineIndex)
    return nullptr;

  InductiveRangeCheck *IRC = new (A.Allocate()) InductiveRangeCheck;
  IRC->Length = Length;
  IRC->Offset = IndexAddRec->getStart();
  IRC->Scale = IndexAddRec->getStepRecurrence(SE);
  IRC->Branch = BI;
  return IRC;
}

/// Fold `V' to a simpler value if InstructionSimplify can; otherwise return
/// `V' unchanged.
static Value *MaybeSimplify(Value *V) {
  if (Instruction *I = dyn_cast<Instruction>(V))
    if (Value *Simplified = SimplifyInstruction(I))
      return Simplified;
  return V;
}

/// Emit IR computing the signed minimum of X and Y.
static Value *ConstructSMinOf(Value *X, Value *Y, IRBuilder<> &B) {
  return MaybeSimplify(B.CreateSelect(B.CreateICmpSLT(X, Y), X, Y));
}

/// Emit IR computing the signed maximum of X and Y.
static Value *ConstructSMaxOf(Value *X, Value *Y, IRBuilder<> &B) {
  return MaybeSimplify(B.CreateSelect(B.CreateICmpSGT(X, Y), X, Y));
}

namespace {

/// This class is used to constrain loops to run within a given iteration space.
/// The algorithm this class implements is given a Loop and a range [Begin,
/// End).  The algorithm then tries to break out a "main loop" out of the loop
/// it is given in a way that the "main loop" runs with the induction variable
/// in a subset of [Begin, End).  The algorithm emits appropriate pre and post
/// loops to run any remaining iterations.  The pre loop runs any iterations in
/// which the induction variable is < Begin, and the post loop runs any
/// iterations in which the induction variable is >= End.
///
class LoopConstrainer {

  // Keeps track of the structure of a loop.
  // This is similar to llvm::Loop,
  // except that it is more lightweight and can track the state of a loop
  // through changing and potentially invalid IR.  This structure also
  // formalizes the kinds of loops we can deal with -- ones that have a single
  // latch that is also an exiting block *and* have a canonical induction
  // variable.
  struct LoopStructure {
    const char *Tag;

    BasicBlock *Header;
    BasicBlock *Latch;

    // `Latch's terminator instruction is `LatchBr', and its `LatchBrExitIdx'th
    // successor is `LatchExit', the exit block of the loop.
    BranchInst *LatchBr;
    BasicBlock *LatchExit;
    unsigned LatchBrExitIdx;

    // The canonical induction variable.  Its value is `CIVStart` on the 0th
    // iteration and `CIVNext` for all iterations after that.
    PHINode *CIV;
    Value *CIVStart;
    Value *CIVNext;

    LoopStructure() : Tag(""), Header(nullptr), Latch(nullptr),
                      LatchBr(nullptr), LatchExit(nullptr),
                      LatchBrExitIdx(-1), CIV(nullptr),
                      CIVStart(nullptr), CIVNext(nullptr) { }

    // Return a copy of this structure with every IR entity translated through
    // `Map' (used to derive the structure of a cloned loop from the original).
    template <typename M> LoopStructure map(M Map) const {
      LoopStructure Result;
      Result.Tag = Tag;
      Result.Header = cast<BasicBlock>(Map(Header));
      Result.Latch = cast<BasicBlock>(Map(Latch));
      Result.LatchBr = cast<BranchInst>(Map(LatchBr));
      Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
      Result.LatchBrExitIdx = LatchBrExitIdx;
      Result.CIV = cast<PHINode>(Map(CIV));
      Result.CIVNext = Map(CIVNext);
      Result.CIVStart = Map(CIVStart);
      return Result;
    }
  };

  // The representation of a clone of the original loop we started out with.
  struct ClonedLoop {
    // The cloned blocks
    std::vector<BasicBlock *> Blocks;

    // `Map` maps values in the clonee into values in the cloned version
    ValueToValueMapTy Map;

    // An instance of `LoopStructure` for the cloned loop
    LoopStructure Structure;
  };

  // Result of rewriting the range of a loop.  See changeIterationSpaceEnd for
  // more details on what these fields mean.
  struct RewrittenRangeInfo {
    BasicBlock *PseudoExit;
    BasicBlock *ExitSelector;
    std::vector<PHINode *> PHIValuesAtPseudoExit;

    RewrittenRangeInfo() : PseudoExit(nullptr), ExitSelector(nullptr) { }
  };

  // Calculated subranges we restrict the iteration space of the main loop to.
  // See the implementation of `calculateSubRanges' for more details on how
  // these fields are computed.  `ExitPreLoopAt' is `None' if we don't need a
  // pre loop.  `ExitMainLoopAt' is `None' if we don't need a post loop.
  struct SubRanges {
    Optional<Value *> ExitPreLoopAt;
    Optional<Value *> ExitMainLoopAt;
  };

  // A utility function that does a `replaceUsesOfWith' on the incoming block
  // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's
  // incoming block list with `ReplaceBy'.
  static void replacePHIBlock(PHINode *PN, BasicBlock *Block,
                              BasicBlock *ReplaceBy);

  // Try to "parse" `OriginalLoop' and populate the various out parameters.
  // Returns true on success, false on failure.
  //
  bool recognizeLoop(LoopStructure &LoopStructureOut,
                     const SCEV *&LatchCountOut, BasicBlock *&PreHeaderOut,
                     const char *&FailureReasonOut) const;

  // Compute a safe set of limits for the main loop to run in -- effectively the
  // intersection of `Range' and the iteration space of the original loop.
  // Return the header count (1 + the latch taken count) in `HeaderCount'.
  //
  SubRanges calculateSubRanges(Value *&HeaderCount) const;

  // Clone `OriginalLoop' and return the result in CLResult.  The IR after
  // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
  // the PHI nodes say that there is an incoming edge from `OriginalPreheader`
  // but there is no such edge.
  //
  void cloneLoop(ClonedLoop &CLResult, const char *Tag) const;

  // Rewrite the iteration space of the loop denoted by (LS, Preheader).  The
  // iteration space of the rewritten loop ends at ExitLoopAt.  The start of the
  // iteration space is not changed.  `ExitLoopAt' is assumed to be slt
  // `OriginalHeaderCount'.
  //
  // If there are iterations left to execute, control is made to jump to
  // `ContinuationBlock', otherwise they take the normal loop exit.  The
  // returned `RewrittenRangeInfo' object is populated as follows:
  //
  //  .PseudoExit is a basic block that unconditionally branches to
  //  `ContinuationBlock'.
  //
  //  .ExitSelector is a basic block that decides, on exit from the loop,
  //  whether to branch to the "true" exit or to `PseudoExit'.
  //
  //  .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value
  //  for each PHINode in the loop header on taking the pseudo exit.
  //
  // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate
  // preheader because it is made to branch to the loop header only
  // conditionally.
  //
  RewrittenRangeInfo
  changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader,
                          Value *ExitLoopAt,
                          BasicBlock *ContinuationBlock) const;

  // The loop denoted by `LS' has `OldPreheader' as its preheader.  This
  // function creates a new preheader for `LS' and returns it.
  //
  BasicBlock *createPreheader(const LoopConstrainer::LoopStructure &LS,
                              BasicBlock *OldPreheader, const char *Tag) const;

  // `ContinuationBlockAndPreheader' was the continuation block for some call to
  // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
  // This function rewrites the PHI nodes in `LS.Header' to start with the
  // correct value.
  void rewriteIncomingValuesForPHIs(
      LoopConstrainer::LoopStructure &LS,
      BasicBlock *ContinuationBlockAndPreheader,
      const LoopConstrainer::RewrittenRangeInfo &RRI) const;

  // Even though we do not preserve any passes at this time, we at least need to
  // keep the parent loop structure consistent.  The `LPPassManager' seems to
  // verify this after running a loop pass.  This function adds the list of
  // blocks denoted by the iterator range [BlocksBegin, BlocksEnd) to this
  // loop's parent loop if required.
  template<typename IteratorTy>
  void addToParentLoopIfNeeded(IteratorTy BlocksBegin, IteratorTy BlocksEnd);

  // Some global state.
  Function &F;
  LLVMContext &Ctx;
  ScalarEvolution &SE;

  // Information about the original loop we started out with.
  Loop &OriginalLoop;
  LoopInfo &OriginalLoopInfo;
  const SCEV *LatchTakenCount;
  BasicBlock *OriginalPreheader;
  Value *OriginalHeaderCount;

  // The preheader of the main loop.  This may or may not be different from
  // `OriginalPreheader'.
  BasicBlock *MainLoopPreheader;

  // The range we need to run the main loop in.
569 InductiveRangeCheck::Range Range; 570 571 // The structure of the main loop (see comment at the beginning of this class 572 // for a definition) 573 LoopStructure MainLoopStructure; 574 575 public: 576 LoopConstrainer(Loop &L, LoopInfo &LI, ScalarEvolution &SE, 577 InductiveRangeCheck::Range R) 578 : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE), 579 OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr), 580 OriginalPreheader(nullptr), OriginalHeaderCount(nullptr), 581 MainLoopPreheader(nullptr), Range(R) { } 582 583 // Entry point for the algorithm. Returns true on success. 584 bool run(); 585 }; 586 587 } 588 589 void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, 590 BasicBlock *ReplaceBy) { 591 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 592 if (PN->getIncomingBlock(i) == Block) 593 PN->setIncomingBlock(i, ReplaceBy); 594 } 595 596 bool LoopConstrainer::recognizeLoop(LoopStructure &LoopStructureOut, 597 const SCEV *&LatchCountOut, 598 BasicBlock *&PreheaderOut, 599 const char *&FailureReason) const { 600 using namespace llvm::PatternMatch; 601 602 assert(OriginalLoop.isLoopSimplifyForm() && 603 "should follow from addRequired<>"); 604 605 BasicBlock *Latch = OriginalLoop.getLoopLatch(); 606 if (!OriginalLoop.isLoopExiting(Latch)) { 607 FailureReason = "no loop latch"; 608 return false; 609 } 610 611 PHINode *CIV = OriginalLoop.getCanonicalInductionVariable(); 612 if (!CIV) { 613 FailureReason = "no CIV"; 614 return false; 615 } 616 617 BasicBlock *Header = OriginalLoop.getHeader(); 618 BasicBlock *Preheader = OriginalLoop.getLoopPreheader(); 619 if (!Preheader) { 620 FailureReason = "no preheader"; 621 return false; 622 } 623 624 Value *CIVNext = CIV->getIncomingValueForBlock(Latch); 625 Value *CIVStart = CIV->getIncomingValueForBlock(Preheader); 626 627 const SCEV *LatchCount = SE.getExitCount(&OriginalLoop, Latch); 628 if (isa<SCEVCouldNotCompute>(LatchCount)) { 629 FailureReason 
= "could not compute latch count"; 630 return false; 631 } 632 633 // While SCEV does most of the analysis for us, we still have to 634 // modify the latch; and currently we can only deal with certain 635 // kinds of latches. This can be made more sophisticated as needed. 636 637 BranchInst *LatchBr = dyn_cast<BranchInst>(&*Latch->rbegin()); 638 639 if (!LatchBr || LatchBr->isUnconditional()) { 640 FailureReason = "latch terminator not conditional branch"; 641 return false; 642 } 643 644 // Currently we only support a latch condition of the form: 645 // 646 // %condition = icmp slt %civNext, %limit 647 // br i1 %condition, label %header, label %exit 648 649 if (LatchBr->getSuccessor(0) != Header) { 650 FailureReason = "unknown latch form (header not first successor)"; 651 return false; 652 } 653 654 Value *CIVComparedTo = nullptr; 655 ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; 656 if (!(match(LatchBr->getCondition(), 657 m_ICmp(Pred, m_Specific(CIVNext), m_Value(CIVComparedTo))) && 658 Pred == ICmpInst::ICMP_SLT)) { 659 FailureReason = "unknown latch form (not slt)"; 660 return false; 661 } 662 663 const SCEV *CIVComparedToSCEV = SE.getSCEV(CIVComparedTo); 664 if (isa<SCEVCouldNotCompute>(CIVComparedToSCEV)) { 665 FailureReason = "could not relate CIV to latch expression"; 666 return false; 667 } 668 669 const SCEV *ShouldBeOne = SE.getMinusSCEV(CIVComparedToSCEV, LatchCount); 670 const SCEVConstant *SCEVOne = dyn_cast<SCEVConstant>(ShouldBeOne); 671 if (!SCEVOne || SCEVOne->getValue()->getValue() != 1) { 672 FailureReason = "unexpected header count in latch"; 673 return false; 674 } 675 676 unsigned LatchBrExitIdx = 1; 677 BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); 678 679 assert(SE.getLoopDisposition(LatchCount, &OriginalLoop) == 680 ScalarEvolution::LoopInvariant && 681 "loop variant exit count doesn't make sense!"); 682 683 assert(!OriginalLoop.contains(LatchExit) && "expected an exit block!"); 684 685 LoopStructureOut.Tag = 
"main"; 686 LoopStructureOut.Header = Header; 687 LoopStructureOut.Latch = Latch; 688 LoopStructureOut.LatchBr = LatchBr; 689 LoopStructureOut.LatchExit = LatchExit; 690 LoopStructureOut.LatchBrExitIdx = LatchBrExitIdx; 691 LoopStructureOut.CIV = CIV; 692 LoopStructureOut.CIVNext = CIVNext; 693 LoopStructureOut.CIVStart = CIVStart; 694 695 LatchCountOut = LatchCount; 696 PreheaderOut = Preheader; 697 FailureReason = nullptr; 698 699 return true; 700 } 701 702 LoopConstrainer::SubRanges 703 LoopConstrainer::calculateSubRanges(Value *&HeaderCountOut) const { 704 IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); 705 706 SCEVExpander Expander(SE, "irce"); 707 Instruction *InsertPt = OriginalPreheader->getTerminator(); 708 709 Value *LatchCountV = 710 MaybeSimplify(Expander.expandCodeFor(LatchTakenCount, Ty, InsertPt)); 711 712 IRBuilder<> B(InsertPt); 713 714 LoopConstrainer::SubRanges Result; 715 716 // I think we can be more aggressive here and make this nuw / nsw if the 717 // addition that feeds into the icmp for the latch's terminating branch is nuw 718 // / nsw. In any case, a wrapping 2's complement addition is safe. 
  // Header count = latch taken count + 1 (the header runs once more than the
  // latch is taken).
  ConstantInt *One = ConstantInt::get(Ty, 1);
  HeaderCountOut = MaybeSimplify(B.CreateAdd(LatchCountV, One, "header.count"));

  const SCEV *RangeBegin = SE.getSCEV(Range.first);
  const SCEV *RangeEnd = SE.getSCEV(Range.second);
  const SCEV *HeaderCountSCEV = SE.getSCEV(HeaderCountOut);
  const SCEV *Zero = SE.getConstant(Ty, 0);

  // In some cases we can prove that we don't need a pre or post loop

  bool ProvablyNoPreloop =
      SE.isKnownPredicate(ICmpInst::ICMP_SLE, RangeBegin, Zero);
  if (!ProvablyNoPreloop)
    Result.ExitPreLoopAt = ConstructSMinOf(HeaderCountOut, Range.first, B);

  bool ProvablyNoPostLoop =
      SE.isKnownPredicate(ICmpInst::ICMP_SLE, HeaderCountSCEV, RangeEnd);
  if (!ProvablyNoPostLoop)
    Result.ExitMainLoopAt = ConstructSMinOf(HeaderCountOut, Range.second, B);

  return Result;
}

// Clone every block of `OriginalLoop' (suffixing names with `Tag'), remap the
// cloned instructions to refer to their cloned operands, and patch the PHI
// nodes of the loop's exit blocks to account for the new incoming edges.
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
                                const char *Tag) const {
  for (BasicBlock *BB : OriginalLoop.getBlocks()) {
    BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
    Result.Blocks.push_back(Clone);
    Result.Map[BB] = Clone;
  }

  // Translate a value through the clone map; values defined outside the loop
  // (hence not in the map) are returned unchanged.
  auto GetClonedValue = [&Result](Value *V) {
    assert(V && "null values not in domain!");
    auto It = Result.Map.find(V);
    if (It == Result.Map.end())
      return V;
    return static_cast<Value *>(It->second);
  };

  Result.Structure = MainLoopStructure.map(GetClonedValue);
  Result.Structure.Tag = Tag;

  for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
    BasicBlock *ClonedBB = Result.Blocks[i];
    BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];

    assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");

    for (Instruction &I : *ClonedBB)
      RemapInstruction(&I, Result.Map,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingEntries);

    // Exit blocks will now have one more predecessor and their PHI nodes need
    // to be edited to reflect that.  No phi nodes need to be introduced because
    // the loop is in LCSSA.

    for (auto SBBI = succ_begin(OriginalBB), SBBE = succ_end(OriginalBB);
         SBBI != SBBE; ++SBBI) {

      if (OriginalLoop.contains(*SBBI))
        continue; // not an exit block

      for (Instruction &I : **SBBI) {
        // PHI nodes are grouped at the top of the block; stop at the first
        // non-PHI.
        if (!isa<PHINode>(&I))
          break;

        PHINode *PN = cast<PHINode>(&I);
        Value *OldIncoming = PN->getIncomingValueForBlock(OriginalBB);
        PN->addIncoming(GetClonedValue(OldIncoming), ClonedBB);
      }
    }
  }
}

LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitLoopAt,
    BasicBlock *ContinuationBlock) const {

  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |  ContinuationBlock   |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+
  //

  RewrittenRangeInfo RRI;

  auto BBInsertLocation = std::next(Function::iterator(LS.Latch));
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(&*Preheader->rbegin());

  IRBuilder<> B(PreheaderJump);

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = B.CreateICmpSLT(LS.CIVStart, ExitLoopAt);
  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  assert(LS.LatchBrExitIdx == 1 && "generalize this as needed!");

  B.SetInsertPoint(LS.LatchBr);

  // ContinueCond - is it okay to execute the next iteration in `LS'?
  Value *ContinueCond = B.CreateICmpSLT(LS.CIVNext, ExitLoopAt);

  LS.LatchBr->setCondition(ContinueCond);
  assert(LS.LatchBr->getSuccessor(LS.LatchBrExitIdx) == LS.LatchExit &&
         "invariant!");
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
903 Value *IterationsLeft = B.CreateICmpSLT(LS.CIVNext, OriginalHeaderCount); 904 B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); 905 906 BranchInst *BranchToContinuation = 907 BranchInst::Create(ContinuationBlock, RRI.PseudoExit); 908 909 // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of 910 // each of the PHI nodes in the loop header. This feeds into the initial 911 // value of the same PHI nodes if/when we continue execution. 912 for (Instruction &I : *LS.Header) { 913 if (!isa<PHINode>(&I)) 914 break; 915 916 PHINode *PN = cast<PHINode>(&I); 917 918 PHINode *NewPHI = PHINode::Create(PN->getType(), 2, PN->getName() + ".copy", 919 BranchToContinuation); 920 921 NewPHI->addIncoming(PN->getIncomingValueForBlock(Preheader), Preheader); 922 NewPHI->addIncoming(PN->getIncomingValueForBlock(LS.Latch), 923 RRI.ExitSelector); 924 RRI.PHIValuesAtPseudoExit.push_back(NewPHI); 925 } 926 927 // The latch exit now has a branch from `RRI.ExitSelector' instead of 928 // `LS.Latch'. The PHI nodes need to be updated to reflect that. 
929 for (Instruction &I : *LS.LatchExit) { 930 if (PHINode *PN = dyn_cast<PHINode>(&I)) 931 replacePHIBlock(PN, LS.Latch, RRI.ExitSelector); 932 else 933 break; 934 } 935 936 return RRI; 937 } 938 939 void LoopConstrainer::rewriteIncomingValuesForPHIs( 940 LoopConstrainer::LoopStructure &LS, BasicBlock *ContinuationBlock, 941 const LoopConstrainer::RewrittenRangeInfo &RRI) const { 942 943 unsigned PHIIndex = 0; 944 for (Instruction &I : *LS.Header) { 945 if (!isa<PHINode>(&I)) 946 break; 947 948 PHINode *PN = cast<PHINode>(&I); 949 950 for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) 951 if (PN->getIncomingBlock(i) == ContinuationBlock) 952 PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); 953 } 954 955 LS.CIVStart = LS.CIV->getIncomingValueForBlock(ContinuationBlock); 956 } 957 958 BasicBlock * 959 LoopConstrainer::createPreheader(const LoopConstrainer::LoopStructure &LS, 960 BasicBlock *OldPreheader, 961 const char *Tag) const { 962 963 BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); 964 BranchInst::Create(LS.Header, Preheader); 965 966 for (Instruction &I : *LS.Header) { 967 if (!isa<PHINode>(&I)) 968 break; 969 970 PHINode *PN = cast<PHINode>(&I); 971 for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) 972 replacePHIBlock(PN, OldPreheader, Preheader); 973 } 974 975 return Preheader; 976 } 977 978 template<typename IteratorTy> 979 void LoopConstrainer::addToParentLoopIfNeeded(IteratorTy Begin, 980 IteratorTy End) { 981 Loop *ParentLoop = OriginalLoop.getParentLoop(); 982 if (!ParentLoop) 983 return; 984 985 for (; Begin != End; Begin++) 986 ParentLoop->addBasicBlockToLoop(*Begin, OriginalLoopInfo); 987 } 988 989 bool LoopConstrainer::run() { 990 BasicBlock *Preheader = nullptr; 991 const char *CouldNotProceedBecause = nullptr; 992 if (!recognizeLoop(MainLoopStructure, LatchTakenCount, Preheader, 993 CouldNotProceedBecause)) { 994 DEBUG(dbgs() << "irce: could not recognize loop, " << 
CouldNotProceedBecause 995 << "\n";); 996 return false; 997 } 998 999 OriginalPreheader = Preheader; 1000 MainLoopPreheader = Preheader; 1001 1002 SubRanges SR = calculateSubRanges(OriginalHeaderCount); 1003 1004 // It would have been better to make `PreLoop' and `PostLoop' 1005 // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy 1006 // constructor. 1007 ClonedLoop PreLoop, PostLoop; 1008 bool NeedsPreLoop = SR.ExitPreLoopAt.hasValue(); 1009 bool NeedsPostLoop = SR.ExitMainLoopAt.hasValue(); 1010 1011 // We clone these ahead of time so that we don't have to deal with changing 1012 // and temporarily invalid IR as we transform the loops. 1013 if (NeedsPreLoop) 1014 cloneLoop(PreLoop, "preloop"); 1015 if (NeedsPostLoop) 1016 cloneLoop(PostLoop, "postloop"); 1017 1018 RewrittenRangeInfo PreLoopRRI; 1019 1020 if (NeedsPreLoop) { 1021 Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, 1022 PreLoop.Structure.Header); 1023 1024 MainLoopPreheader = 1025 createPreheader(MainLoopStructure, Preheader, "mainloop"); 1026 PreLoopRRI = 1027 changeIterationSpaceEnd(PreLoop.Structure, Preheader, 1028 SR.ExitPreLoopAt.getValue(), MainLoopPreheader); 1029 rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, 1030 PreLoopRRI); 1031 } 1032 1033 BasicBlock *PostLoopPreheader = nullptr; 1034 RewrittenRangeInfo PostLoopRRI; 1035 1036 if (NeedsPostLoop) { 1037 PostLoopPreheader = 1038 createPreheader(PostLoop.Structure, Preheader, "postloop"); 1039 PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, 1040 SR.ExitMainLoopAt.getValue(), 1041 PostLoopPreheader); 1042 rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, 1043 PostLoopRRI); 1044 } 1045 1046 SmallVector<BasicBlock *, 6> NewBlocks; 1047 NewBlocks.push_back(PostLoopPreheader); 1048 NewBlocks.push_back(PreLoopRRI.PseudoExit); 1049 NewBlocks.push_back(PreLoopRRI.ExitSelector); 1050 NewBlocks.push_back(PostLoopRRI.PseudoExit); 1051 
NewBlocks.push_back(PostLoopRRI.ExitSelector); 1052 if (MainLoopPreheader != Preheader) 1053 NewBlocks.push_back(MainLoopPreheader); 1054 1055 // Some of the above may be nullptr, filter them out before passing to 1056 // addToParentLoopIfNeeded. 1057 auto NewBlocksEnd = std::remove(NewBlocks.begin(), NewBlocks.end(), nullptr); 1058 1059 typedef SmallVector<BasicBlock *, 6>::iterator SmallVectItTy; 1060 typedef std::vector<BasicBlock *>::iterator StdVectItTy; 1061 1062 addToParentLoopIfNeeded<SmallVectItTy>(NewBlocks.begin(), NewBlocksEnd); 1063 addToParentLoopIfNeeded<StdVectItTy>(PreLoop.Blocks.begin(), 1064 PreLoop.Blocks.end()); 1065 addToParentLoopIfNeeded<StdVectItTy>(PostLoop.Blocks.begin(), 1066 PostLoop.Blocks.end()); 1067 1068 return true; 1069 } 1070 1071 /// Computes and returns a range of values for the induction variable in which 1072 /// the range check can be safely elided. If it cannot compute such a range, 1073 /// returns None. 1074 Optional<InductiveRangeCheck::Range> 1075 InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE, 1076 IRBuilder<> &B) const { 1077 1078 // Currently we support inequalities of the form: 1079 // 1080 // 0 <= Offset + 1 * CIV < L given L >= 0 1081 // 1082 // The inequality is satisfied by -Offset <= CIV < (L - Offset) [^1]. All 1083 // additions and subtractions are twos-complement wrapping and comparisons are 1084 // signed. 1085 // 1086 // Proof: 1087 // 1088 // If there exists CIV such that -Offset <= CIV < (L - Offset) then it 1089 // follows that -Offset <= (-Offset + L) [== Eq. 1]. Since L >= 0, if 1090 // (-Offset + L) sign-overflows then (-Offset + L) < (-Offset). Hence by 1091 // [Eq. 1], (-Offset + L) could not have overflown. 1092 // 1093 // This means CIV = t + (-Offset) for t in [0, L). Hence (CIV + Offset) = 1094 // t. Hence 0 <= (CIV + Offset) < L 1095 1096 // [^1]: Note that the solution does _not_ apply if L < 0; consider values 1097 // Offset = 127, CIV = 126 and L = -2 in an i8 world. 
1098 1099 const SCEVConstant *ScaleC = dyn_cast<SCEVConstant>(getScale()); 1100 if (!(ScaleC && ScaleC->getValue()->getValue() == 1)) { 1101 DEBUG(dbgs() << "irce: could not compute safe iteration space for:\n"; 1102 print(dbgs())); 1103 return None; 1104 } 1105 1106 Value *OffsetV = SCEVExpander(SE, "safe.itr.space").expandCodeFor( 1107 getOffset(), getOffset()->getType(), B.GetInsertPoint()); 1108 OffsetV = MaybeSimplify(OffsetV); 1109 1110 Value *Begin = MaybeSimplify(B.CreateNeg(OffsetV)); 1111 Value *End = MaybeSimplify(B.CreateSub(getLength(), OffsetV)); 1112 1113 return std::make_pair(Begin, End); 1114 } 1115 1116 static InductiveRangeCheck::Range 1117 IntersectRange(const Optional<InductiveRangeCheck::Range> &R1, 1118 const InductiveRangeCheck::Range &R2, IRBuilder<> &B) { 1119 if (!R1.hasValue()) 1120 return R2; 1121 auto &R1Value = R1.getValue(); 1122 1123 Value *NewMin = ConstructSMaxOf(R1Value.first, R2.first, B); 1124 Value *NewMax = ConstructSMinOf(R1Value.second, R2.second, B); 1125 return std::make_pair(NewMin, NewMax); 1126 } 1127 1128 bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { 1129 if (L->getBlocks().size() >= LoopSizeCutoff) { 1130 DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";); 1131 return false; 1132 } 1133 1134 BasicBlock *Preheader = L->getLoopPreheader(); 1135 if (!Preheader) { 1136 DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); 1137 return false; 1138 } 1139 1140 LLVMContext &Context = Preheader->getContext(); 1141 InductiveRangeCheck::AllocatorTy IRCAlloc; 1142 SmallVector<InductiveRangeCheck *, 16> RangeChecks; 1143 ScalarEvolution &SE = getAnalysis<ScalarEvolution>(); 1144 1145 for (auto BBI : L->getBlocks()) 1146 if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) 1147 if (InductiveRangeCheck *IRC = 1148 InductiveRangeCheck::create(IRCAlloc, TBI, L, SE)) 1149 RangeChecks.push_back(IRC); 1150 1151 if (RangeChecks.empty()) 1152 return false; 1153 1154 
DEBUG(dbgs() << "irce: looking at loop "; L->print(dbgs()); 1155 dbgs() << "irce: loop has " << RangeChecks.size() 1156 << " inductive range checks: \n"; 1157 for (InductiveRangeCheck *IRC : RangeChecks) 1158 IRC->print(dbgs()); 1159 ); 1160 1161 Optional<InductiveRangeCheck::Range> SafeIterRange; 1162 Instruction *ExprInsertPt = Preheader->getTerminator(); 1163 1164 SmallVector<InductiveRangeCheck *, 4> RangeChecksToEliminate; 1165 1166 IRBuilder<> B(ExprInsertPt); 1167 for (InductiveRangeCheck *IRC : RangeChecks) { 1168 auto Result = IRC->computeSafeIterationSpace(SE, B); 1169 if (Result.hasValue()) { 1170 SafeIterRange = IntersectRange(SafeIterRange, Result.getValue(), B); 1171 RangeChecksToEliminate.push_back(IRC); 1172 } 1173 } 1174 1175 if (!SafeIterRange.hasValue()) 1176 return false; 1177 1178 LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), SE, 1179 SafeIterRange.getValue()); 1180 bool Changed = LC.run(); 1181 1182 if (Changed) { 1183 auto PrintConstrainedLoopInfo = [L]() { 1184 dbgs() << "irce: in function "; 1185 dbgs() << L->getHeader()->getParent()->getName() << ": "; 1186 dbgs() << "constrained "; 1187 L->print(dbgs()); 1188 }; 1189 1190 DEBUG(PrintConstrainedLoopInfo()); 1191 1192 if (PrintChangedLoops) 1193 PrintConstrainedLoopInfo(); 1194 1195 // Optimize away the now-redundant range checks. 1196 1197 for (InductiveRangeCheck *IRC : RangeChecksToEliminate) { 1198 ConstantInt *FoldedRangeCheck = IRC->getPassingDirection() 1199 ? ConstantInt::getTrue(Context) 1200 : ConstantInt::getFalse(Context); 1201 IRC->getBranch()->setCondition(FoldedRangeCheck); 1202 } 1203 } 1204 1205 return Changed; 1206 } 1207 1208 Pass *llvm::createInductiveRangeCheckEliminationPass() { 1209 return new InductiveRangeCheckElimination; 1210 } 1211