//===- ScalarEvolution.cpp - Scalar Evolution Analysis -------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle.  We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
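// As an illustrative (non-normative) example of the representation described
// above: for a loop such as
//
//   for (int i = 0; i != n; ++i)
//     a[i] += 1;
//
// the induction variable i is typically represented by the add recurrence
// {0,+,1}<%loop>, and the address &a[i] by something like {%a,+,4}<%loop>
// (assuming a 4-byte element type).  Because SCEVs are uniqued, two
// syntactically identical expressions built independently compare equal by
// pointer.
//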
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/SaveAndRestore.h"
#include <algorithm>
using namespace llvm;

#define DEBUG_TYPE "scalar-evolution"

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant "
                                 "derived loop"),
                        cl::init(100));
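// Note: like the other cl::opt thresholds declared in this file, the flag
// above can be overridden for experimentation from the opt command line,
// e.g. something along the lines of
//
//   opt -analyze -scalar-evolution -scalar-evolution-max-iterations=200 foo.ll
//
// (flag names as registered here; the exact driver invocation depends on the
// opt version in use).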

// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
static cl::opt<bool>
VerifySCEV("verify-scev",
           cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool>
    VerifySCEVMap("verify-scev-maps",
                  cl::desc("Verify no dangling value in ScalarEvolution's "
                           "ExprValueMap (slow)"));

static cl::opt<unsigned> MulOpsInlineThreshold(
    "scev-mulops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
    cl::init(1000));

static cl::opt<unsigned>
    MaxCompareDepth("scalar-evolution-max-compare-depth", cl::Hidden,
                    cl::desc("Maximum depth of recursive compare complexity"),
                    cl::init(32));

static cl::opt<unsigned> MaxConstantEvolvingDepth(
    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

LLVM_DUMP_METHOD
void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}

void SCEV::print(raw_ostream &OS) const {
  switch (static_cast<SCEVTypes>(getSCEVType())) {
  case scConstant:
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->hasNoUnsignedWrap())
      OS << "nuw><";
    if (AR->hasNoSignedWrap())
      OS << "nsw><";
    if (AR->hasNoSelfWrap() &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    }
    OS << "(";
    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
         I != E; ++I) {
      OS << **I;
      if (std::next(I) != E)
        OS << OpStr;
    }
    OS << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(this);
    Type *AllocTy;
    if (U->isSizeOf(AllocTy)) {
      OS << "sizeof(" << *AllocTy << ")";
      return;
    }
    if (U->isAlignOf(AllocTy)) {
      OS << "alignof(" << *AllocTy << ")";
      return;
    }

    Type *CTy;
    Constant *FieldNo;
    if (U->isOffsetOf(CTy, FieldNo)) {
      OS << "offsetof(" << *CTy << ", ";
      FieldNo->printAsOperand(OS, false);
      OS << ")";
      return;
    }

    // Otherwise just print it normally.
    U->getValue()->printAsOperand(OS, false);
    return;
  }
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}
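// For reference (an informal illustration, not an exhaustive grammar), the
// printer above produces strings such as:
//
//   (zext i8 %x to i32)         - a zero-extension of an unknown value
//   (4 + %a)<nuw>               - an add expression with a no-unsigned-wrap flag
//   {0,+,4}<nuw><%for.body>     - an add recurrence over the loop %for.body
//   ***COULDNOTCOMPUTE***       - the SCEVCouldNotCompute sentinel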

Type *SCEV::getType() const {
  switch (static_cast<SCEVTypes>(getSCEVType())) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
    return cast<SCEVNAryExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative, this matches things like (-42 * V).
  return SC->getAPInt().isNegative();
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
                           unsigned SCEVTy, const SCEV *op, Type *ty)
  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
                                   const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scTruncate, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
  : SCEVCastExpr(ID, scSignExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot sign extend non-integer value!");
}

void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Update this SCEVUnknown to point to the new value. This is needed
  // because there may still be outstanding SCEVs which still point to
  // this SCEVUnknown.
  setValPtr(New);
}

bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
                            ->getElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          if (StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}

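// As an informal illustration of the idioms matched above (the exact
// constant-expression spelling depends on the frontend): a "sizeof"
// expression typically looks like
//
//   ptrtoint (%T* getelementptr (%T, %T* null, i32 1) to i64)
//
// i.e. the offset of element 1 of a GEP from a null base pointer, which
// isSizeOf() reports as sizeof(%T).  isAlignOf() and isOffsetOf() recognize
// analogous null-based GEP patterns for alignment and field offsets.
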
//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

/// Compare the two values \p LV and \p RV in terms of their "complexity" where
/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
/// operands in SCEV expressions.  \p EqCache is a set of pairs of values that
/// have been previously deemed to be "equally complex" by this routine.  It is
/// intended to avoid exponential time complexity in cases like:
///
///   %a = f(%x, %y)
///   %b = f(%a, %a)
///   %c = f(%b, %b)
///
///   %d = f(%x, %y)
///   %e = f(%d, %d)
///   %f = f(%e, %e)
///
///   CompareValueComplexity(%f, %c)
///
/// Since we do not continue running this routine on expression trees once we
/// have seen unequal values, there is no need to track them in the cache.
static int
CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
                       const LoopInfo *const LI, Value *LV, Value *RV,
                       unsigned Depth) {
  if (Depth > MaxCompareDepth || EqCache.count({LV, RV}))
    return 0;

  // Order pointer values after integer values. This helps SCEVExpander form
  // GEPs.
  bool LIsPointer = LV->getType()->isPointerTy(),
       RIsPointer = RV->getType()->isPointerTy();
  if (LIsPointer != RIsPointer)
    return (int)LIsPointer - (int)RIsPointer;

  // Compare getValueID values.
  unsigned LID = LV->getValueID(), RID = RV->getValueID();
  if (LID != RID)
    return (int)LID - (int)RID;

  // Sort arguments by their position.
  if (const auto *LA = dyn_cast<Argument>(LV)) {
    const auto *RA = cast<Argument>(RV);
    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
    return (int)LArgNo - (int)RArgNo;
  }

  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
    const auto *RGV = cast<GlobalValue>(RV);

    const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
      auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
               GlobalValue::isInternalLinkage(LT));
    };

    // Use the names to distinguish the two values, but only if the
    // names are semantically important.
    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
      return LGV->getName().compare(RGV->getName());
  }

  // For instructions, compare their loop depth, and their operand count.  This
  // is pretty loose.
  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
    const auto *RInst = cast<Instruction>(RV);

    // Compare loop depths.
    const BasicBlock *LParent = LInst->getParent(),
                     *RParent = RInst->getParent();
    if (LParent != RParent) {
      unsigned LDepth = LI->getLoopDepth(LParent),
               RDepth = LI->getLoopDepth(RParent);
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Compare the number of operands.
    unsigned LNumOps = LInst->getNumOperands(),
             RNumOps = RInst->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned Idx : seq(0u, LNumOps)) {
      int Result =
          CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
                                 RInst->getOperand(Idx), Depth + 1);
      if (Result != 0)
        return Result;
    }
  }

  EqCache.insert({LV, RV});
  return 0;
}

// Return negative, zero, or positive, if LHS is less than, equal to, or
// greater than RHS, respectively.  A three-way result allows recursive
// comparisons to be more efficient.
static int CompareSCEVComplexity(
    SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV,
    const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
    unsigned Depth = 0) {
  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
  if (LHS == RHS)
    return 0;

  // Primarily, sort the SCEVs by their getSCEVType().
  unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
  if (LType != RType)
    return (int)LType - (int)RType;

  if (Depth > MaxCompareDepth || EqCacheSCEV.count({LHS, RHS}))
    return 0;
  // Aside from the getSCEVType() ordering, the particular ordering
  // isn't very important except that it's beneficial to be consistent,
  // so that (a + b) and (b + a) don't end up as different expressions.
  switch (static_cast<SCEVTypes>(LType)) {
  case scUnknown: {
    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

    SmallSet<std::pair<Value *, Value *>, 8> EqCache;
    int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(),
                                   Depth + 1);
    if (X == 0)
      EqCacheSCEV.insert({LHS, RHS});
    return X;
  }

  case scConstant: {
    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
    const SCEVConstant *RC = cast<SCEVConstant>(RHS);

    // Compare constant values.
    const APInt &LA = LC->getAPInt();
    const APInt &RA = RC->getAPInt();
    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
    if (LBitWidth != RBitWidth)
      return (int)LBitWidth - (int)RBitWidth;
    return LA.ult(RA) ? -1 : 1;
  }

  case scAddRecExpr: {
    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

    // Compare addrec loop depths.
    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
    if (LLoop != RLoop) {
      unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth();
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Addrec complexity grows with operand count.
    unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    // Lexicographically compare.
    for (unsigned i = 0; i != LNumOps; ++i) {
      int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i),
                                    RA->getOperand(i), Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.insert({LHS, RHS});
    return 0;
  }

  case scAddExpr:
  case scMulExpr:
  case scSMaxExpr:
  case scUMaxExpr: {
    const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
    const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);

    // Lexicographically compare n-ary expressions.
    unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned i = 0; i != LNumOps; ++i) {
      if (i >= RNumOps)
        return 1;
      int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i),
                                    RC->getOperand(i), Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.insert({LHS, RHS});
    return 0;
  }

  case scUDivExpr: {
    const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
    const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);

    // Lexicographically compare udiv expressions.
    int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(),
                                  Depth + 1);
    if (X != 0)
      return X;
    X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(),
                              Depth + 1);
    if (X == 0)
      EqCacheSCEV.insert({LHS, RHS});
    return X;
  }

  case scTruncate:
  case scZeroExtend:
  case scSignExtend: {
    const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
    const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);

    // Compare cast expressions by operand.
    int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(),
                                  RC->getOperand(), Depth + 1);
    if (X == 0)
      EqCacheSCEV.insert({LHS, RHS});
    return X;
  }

  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value.  When this routine is
/// finished, we know that any duplicates in the vector are consecutive and
/// that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this routine to depend on where the addresses of various SCEV objects
/// happened to land in memory.
///
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop

  SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache;
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0)
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(),
                   [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) {
                     return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0;
                   });

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely
  // to be extremely short in practice.  Note that we take this approach
  // because we do not want to depend on the addresses of the objects we are
  // grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}
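// An informal example of the grouping above (operand values are purely
// illustrative): given add operands in the order (%a, 2, %b, %a), the sort
// places the constant first (scConstant orders before scUnknown) and the
// duplicate-grouping pass then moves the repeated %a operands next to each
// other, yielding something like (2, %a, %a, %b).  The expression folders
// rely on duplicates being adjacent, e.g. to combine %a + %a into (2 * %a).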

// Returns the size of the SCEV S.
static inline int sizeOfSCEV(const SCEV *S) {
  struct FindSCEVSize {
    int Size;
    FindSCEVSize() : Size(0) {}

    bool follow(const SCEV *S) {
      ++Size;
      // Keep looking at all operands of S.
      return true;
    }
    bool isDone() const {
      return false;
    }
  };

  FindSCEVSize F;
  SCEVTraversal<FindSCEVSize> ST(F);
  ST.visitAll(S);
  return F.Size;
}

namespace {

struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
public:
  // Computes the Quotient and Remainder of the division of Numerator by
  // Denominator.
  static void divide(ScalarEvolution &SE, const SCEV *Numerator,
                     const SCEV *Denominator, const SCEV **Quotient,
                     const SCEV **Remainder) {
    assert(Numerator && Denominator && "Uninitialized SCEV");

    SCEVDivision D(SE, Numerator, Denominator);

    // Check for the trivial case here to avoid having to check for it in the
    // rest of the code.
    if (Numerator == Denominator) {
      *Quotient = D.One;
      *Remainder = D.Zero;
      return;
    }

    if (Numerator->isZero()) {
      *Quotient = D.Zero;
      *Remainder = D.Zero;
      return;
    }

    // A simple case when N/1. The quotient is N.
    if (Denominator->isOne()) {
      *Quotient = Numerator;
      *Remainder = D.Zero;
      return;
    }

    // Split the Denominator when it is a product.
    if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
      const SCEV *Q, *R;
      *Quotient = Numerator;
      for (const SCEV *Op : T->operands()) {
        divide(SE, *Quotient, Op, &Q, &R);
        *Quotient = Q;

        // Bail out when the Numerator is not divisible by one of the terms of
        // the Denominator.
        if (!R->isZero()) {
          *Quotient = D.Zero;
          *Remainder = Numerator;
          return;
        }
      }
      *Remainder = D.Zero;
      return;
    }

    D.visit(Numerator);
    *Quotient = D.Quotient;
    *Remainder = D.Remainder;
  }

  // Except in the trivial case described above, we do not know how to divide
  // Expr by Denominator for the following functions with empty implementation.
  void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
  void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
  void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
  void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
  void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
  void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
  void visitUnknown(const SCEVUnknown *Numerator) {}
  void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}

  void visitConstant(const SCEVConstant *Numerator) {
    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
      APInt NumeratorVal = Numerator->getAPInt();
      APInt DenominatorVal = D->getAPInt();
      uint32_t NumeratorBW = NumeratorVal.getBitWidth();
      uint32_t DenominatorBW = DenominatorVal.getBitWidth();

      if (NumeratorBW > DenominatorBW)
        DenominatorVal = DenominatorVal.sext(NumeratorBW);
      else if (NumeratorBW < DenominatorBW)
        NumeratorVal = NumeratorVal.sext(DenominatorBW);

      APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
      APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
      APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
      Quotient = SE.getConstant(QuotientVal);
      Remainder = SE.getConstant(RemainderVal);
      return;
    }
  }

  void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
    const SCEV *StartQ, *StartR, *StepQ, *StepR;
    if (!Numerator->isAffine())
      return cannotDivide(Numerator);
    divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
    divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
    // Bail out if the types do not match.
    Type *Ty = Denominator->getType();
    if (Ty != StartQ->getType() || Ty != StartR->getType() ||
        Ty != StepQ->getType() || Ty != StepR->getType())
      return cannotDivide(Numerator);
    Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
                                Numerator->getNoWrapFlags());
    Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
                                 Numerator->getNoWrapFlags());
  }

  void visitAddExpr(const SCEVAddExpr *Numerator) {
    SmallVector<const SCEV *, 2> Qs, Rs;
    Type *Ty = Denominator->getType();

    for (const SCEV *Op : Numerator->operands()) {
      const SCEV *Q, *R;
      divide(SE, Op, Denominator, &Q, &R);

      // Bail out if types do not match.
      if (Ty != Q->getType() || Ty != R->getType())
        return cannotDivide(Numerator);

      Qs.push_back(Q);
      Rs.push_back(R);
    }

    if (Qs.size() == 1) {
      Quotient = Qs[0];
      Remainder = Rs[0];
      return;
    }

    Quotient = SE.getAddExpr(Qs);
    Remainder = SE.getAddExpr(Rs);
  }

  void visitMulExpr(const SCEVMulExpr *Numerator) {
    SmallVector<const SCEV *, 2> Qs;
    Type *Ty = Denominator->getType();

    bool FoundDenominatorTerm = false;
    for (const SCEV *Op : Numerator->operands()) {
      // Bail out if types do not match.
      if (Ty != Op->getType())
        return cannotDivide(Numerator);

      if (FoundDenominatorTerm) {
        Qs.push_back(Op);
        continue;
      }

      // Check whether Denominator divides one of the product operands.
      const SCEV *Q, *R;
      divide(SE, Op, Denominator, &Q, &R);
      if (!R->isZero()) {
        Qs.push_back(Op);
        continue;
      }

      // Bail out if types do not match.
      if (Ty != Q->getType())
        return cannotDivide(Numerator);

      FoundDenominatorTerm = true;
      Qs.push_back(Q);
    }

    if (FoundDenominatorTerm) {
      Remainder = Zero;
      if (Qs.size() == 1)
        Quotient = Qs[0];
      else
        Quotient = SE.getMulExpr(Qs);
      return;
    }

    if (!isa<SCEVUnknown>(Denominator))
      return cannotDivide(Numerator);

    // The Remainder is obtained by replacing Denominator by 0 in Numerator.
    ValueToValueMap RewriteMap;
    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
        cast<SCEVConstant>(Zero)->getValue();
    Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);

    if (Remainder->isZero()) {
      // The Quotient is obtained by replacing Denominator by 1 in Numerator.
      RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
          cast<SCEVConstant>(One)->getValue();
      Quotient =
          SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
      return;
    }

    // Quotient is (Numerator - Remainder) divided by Denominator.
    const SCEV *Q, *R;
    const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
    // This SCEV does not seem to simplify: fail the division here.
    if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
      return cannotDivide(Numerator);
    divide(SE, Diff, Denominator, &Q, &R);
    if (R != Zero)
      return cannotDivide(Numerator);
    Quotient = Q;
  }

private:
  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
               const SCEV *Denominator)
      : SE(S), Denominator(Denominator) {
    Zero = SE.getZero(Denominator->getType());
    One = SE.getOne(Denominator->getType());

    // We generally do not know how to divide Expr by Denominator. We
    // initialize the division to a "cannot divide" state to simplify the rest
    // of the code.
    cannotDivide(Numerator);
  }

  // Convenience function for giving up on the division. We set the quotient to
  // be equal to zero and the remainder to be equal to the numerator.
  void cannotDivide(const SCEV *Numerator) {
    Quotient = Zero;
    Remainder = Numerator;
  }

  ScalarEvolution &SE;
  const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
};

}

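// An informal example of what the divider above computes (operand names are
// illustrative, results are expectations rather than normative output):
// dividing the SCEV (8 * %n * %m) by (4 * %n) would be expected to produce
// quotient (2 * %m) and remainder 0, while dividing (3 + %n) by %n would
// produce quotient 1 and remainder 3.
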
//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// Compute BC(It, K).  The result has width W.  Assume K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate.  Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula only requires W + K bits.  Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1).  However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// Return the value of this chain of recurrences at the specified iteration
/// number.  We can evaluate this recurrence by multiplying each element in the
/// chain by the binomial coefficient corresponding to it.  In other words, we
/// can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
///
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

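// A small worked example of the evaluation above (purely illustrative): for
// the quadratic recurrence {5,+,3,+,2} evaluated at iteration It, the sum is
//
//   5*BC(It, 0) + 3*BC(It, 1) + 2*BC(It, 2)
//     = 5 + 3*It + 2*(It*(It - 1)/2)
//     = It^2 + 2*It + 5
//
// where BinomialCoefficient() computes each BC(It, k) term modulo 2^W using
// the odd-factorial / multiplicative-inverse trick described above.
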
//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
                                             Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
  // eliminate all the truncates, or we replace other casts with truncates.
  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
      if (!isa<SCEVCastExpr>(SA->getOperand(i)))
        hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getAddExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
  // eliminate all the truncates, or we replace other casts with truncates.
  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
      if (!isa<SCEVCastExpr>(SM->getOperand(i)))
        hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getMulExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}
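// Illustrative examples of the truncation folds above (expectations written
// in the printer's notation, not normative results):
//
//   trunc of (zext i32 %x to i64) to i32  -->  %x
//   trunc of (zext i32 %x to i64) to i16  -->  (trunc i32 %x to i16)
//   trunc of {0,+,8}<%loop> (i64) to i32  -->  {0,+,8}<%loop> over i32,
//                                              with wrap flags dropped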

// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMax());
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMin());
  }
  return nullptr;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop
// does not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRange(Step).getUnsignedMax());
}

namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
}

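// In other words (an informal summary, not additional behavior): instantiating
// the helpers below with ExtendOpTy = SCEVZeroExtendExpr makes GetExtendExpr
// resolve to &ScalarEvolution::getZeroExtendExpr, WrapType to SCEV::FlagNUW,
// and getOverflowLimitForStep to the unsigned variant; SCEVSignExtendExpr
// selects the signed counterparts.
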
// The recurrence AR has been shown to have no signed/unsigned wrap or
// something close to it.  Typically, if we can prove NSW/NUW for AR, then we
// can just as easily prove NSW/NUW for its preincrement or postincrement
// sibling.  This allows normalizing a sign/zero extended AddRec as such:
//   {sext/zext(Step + Start),+,Step} => {(Step + sext/zext(Start)),+,Step}
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list.
  SmallVector<const SCEV *, 4> DiffOps;
  for (const SCEV *Op : SA->operands())
    if (Op != Step)
      DiffOps.push_back(Op);

  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
    ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once"
  // implies "S+X does not sign/unsign-overflow".
  //

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy),
                     (SE->*GetExtendExpr)(Step, WideTy));
  if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
      const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}
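// An informal example of the normalization above: suppose AR is
// {(4 + %x),+,4}<%L> and we are asked for zext(AR).  getPreStartForExtend()
// strips one copy of the step from the start, giving PreStart = %x, and then
// tries to prove that %x + 4 cannot unsigned-overflow (via the nearby
// recurrence {%x,+,4}, a direct check on the widened start, or a loop-entry
// guard).  If any of these succeed, the extended start can be built as
// (4 + zext(%x)) instead of the opaque zext(4 + %x).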

// Get the normalized zero or sign extended expression for this AddRec's Start.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE) {
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE);
  if (!PreStart)
    return (SE->*GetExtendExpr)(AR->getStart(), Ty);

  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty),
                        (SE->*GetExtendExpr)(PreStart, Ty));
}

// Try to prove away overflow by looking at "nearby" add recurrences.  A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//     {S,+,X} == {S-T,+,X} + T
//  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
// If ({S-T,+,X} + T) does not overflow  ... (1)
//
//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
// If {S-T,+,X} does not overflow  ... (2)
//
//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//      == {Ext(S-T)+Ext(T),+,Ext(X)}
//
// If (S-T)+T does not overflow  ... (3)
//
//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//      == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration.  Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
//
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here.  It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  //
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1)
        return true;
    }
  }

  return false;
}

const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
                                               Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty);
  }

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      if (!AR->hasNoUnsignedWrap()) {
        auto NewFlags = proveNoWrapViaConstantRanges(AR);
        const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
      }

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->hasNoUnsignedWrap())
        return getAddRecExpr(
            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
            getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
In the latter case, the analysis code will 1546 // cope with a conservative value, and it will take care to purge 1547 // that value once it has finished. 1548 const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); 1549 if (!isa<SCEVCouldNotCompute>(MaxBECount)) { 1550 // Manually compute the final value for AR, checking for 1551 // overflow. 1552 1553 // Check whether the backedge-taken count can be losslessly cast to 1554 // the addrec's type. The count is always unsigned. 1555 const SCEV *CastedMaxBECount = 1556 getTruncateOrZeroExtend(MaxBECount, Start->getType()); 1557 const SCEV *RecastedMaxBECount = 1558 getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); 1559 if (MaxBECount == RecastedMaxBECount) { 1560 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); 1561 // Check whether Start+Step*MaxBECount has no unsigned overflow. 1562 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step); 1563 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy); 1564 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy); 1565 const SCEV *WideMaxBECount = 1566 getZeroExtendExpr(CastedMaxBECount, WideTy); 1567 const SCEV *OperandExtendedAdd = 1568 getAddExpr(WideStart, 1569 getMulExpr(WideMaxBECount, 1570 getZeroExtendExpr(Step, WideTy))); 1571 if (ZAdd == OperandExtendedAdd) { 1572 // Cache knowledge of AR NUW, which is propagated to this AddRec. 1573 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); 1574 // Return the expression with the addrec on the outside. 1575 return getAddRecExpr( 1576 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), 1577 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1578 } 1579 // Similar to above, only this time treat the step value as signed. 1580 // This covers loops that count down. 1581 OperandExtendedAdd = 1582 getAddExpr(WideStart, 1583 getMulExpr(WideMaxBECount, 1584 getSignExtendExpr(Step, WideTy))); 1585 if (ZAdd == OperandExtendedAdd) { 1586 // Cache knowledge of AR NW, which is propagated to this AddRec. 1587 // Negative step causes unsigned wrap, but it still can't self-wrap. 1588 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); 1589 // Return the expression with the addrec on the outside. 1590 return getAddRecExpr( 1591 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), 1592 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1593 } 1594 } 1595 } 1596 1597 // Normally, in the cases we can prove no-overflow via a 1598 // backedge guarding condition, we can also compute a backedge 1599 // taken count for the loop. The exceptions are assumptions and 1600 // guards present in the loop -- SCEV is not great at exploiting 1601 // these to compute max backedge taken counts, but can still use 1602 // these to prove lack of overflow. Use this fact to avoid 1603 // doing extra work that may not pay off. 1604 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || 1605 !AC.assumptions().empty()) { 1606 // If the backedge is guarded by a comparison with the pre-inc 1607 // value, the addrec is safe. Also, if the entry is guarded by 1608 // a comparison with the start value and the backedge is 1609 // guarded by a comparison with the post-inc value, the addrec 1610 // is safe.
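        // As a concrete (purely illustrative) instance of the positive-step
        // case handled below: with Step == 1 the limit N computed below is the
        // unsigned maximum, so a backedge guarded by `AR ult N` means the
        // pre-inc value can always be incremented by one without unsigned
        // wrap, and the zero extension can be pushed into the recurrence.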
1611 if (isKnownPositive(Step)) { 1612 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - 1613 getUnsignedRange(Step).getUnsignedMax()); 1614 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || 1615 (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && 1616 isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, 1617 AR->getPostIncExpr(*this), N))) { 1618 // Cache knowledge of AR NUW, which is propagated to this 1619 // AddRec. 1620 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); 1621 // Return the expression with the addrec on the outside. 1622 return getAddRecExpr( 1623 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), 1624 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1625 } 1626 } else if (isKnownNegative(Step)) { 1627 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - 1628 getSignedRange(Step).getSignedMin()); 1629 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || 1630 (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) && 1631 isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, 1632 AR->getPostIncExpr(*this), N))) { 1633 // Cache knowledge of AR NW, which is propagated to this 1634 // AddRec. Negative step causes unsigned wrap, but it 1635 // still can't self-wrap. 1636 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); 1637 // Return the expression with the addrec on the outside. 1638 return getAddRecExpr( 1639 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), 1640 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1641 } 1642 } 1643 } 1644 1645 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { 1646 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); 1647 return getAddRecExpr( 1648 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), 1649 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1650 } 1651 } 1652 1653 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { 1654 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw> 1655 if (SA->hasNoUnsignedWrap()) { 1656 // If the addition does not unsign overflow then we can, by definition, 1657 // commute the zero extension with the addition operation. 1658 SmallVector<const SCEV *, 4> Ops; 1659 for (const auto *Op : SA->operands()) 1660 Ops.push_back(getZeroExtendExpr(Op, Ty)); 1661 return getAddExpr(Ops, SCEV::FlagNUW); 1662 } 1663 } 1664 1665 // The cast wasn't folded; create an explicit cast node. 1666 // Recompute the insert position, as it may have been invalidated. 1667 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 1668 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), 1669 Op, Ty); 1670 UniqueSCEVs.InsertNode(S, IP); 1671 return S; 1672 } 1673 1674 const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, 1675 Type *Ty) { 1676 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && 1677 "This is not an extending conversion!"); 1678 assert(isSCEVable(Ty) && 1679 "This is not a conversion to a SCEVable type!"); 1680 Ty = getEffectiveSCEVType(Ty); 1681 1682 // Fold if the operand is constant. 
1683 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) 1684 return getConstant( 1685 cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty))); 1686 1687 // sext(sext(x)) --> sext(x) 1688 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) 1689 return getSignExtendExpr(SS->getOperand(), Ty); 1690 1691 // sext(zext(x)) --> zext(x) 1692 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) 1693 return getZeroExtendExpr(SZ->getOperand(), Ty); 1694 1695 // Before doing any expensive analysis, check to see if we've already 1696 // computed a SCEV for this Op and Ty. 1697 FoldingSetNodeID ID; 1698 ID.AddInteger(scSignExtend); 1699 ID.AddPointer(Op); 1700 ID.AddPointer(Ty); 1701 void *IP = nullptr; 1702 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 1703 1704 // sext(trunc(x)) --> sext(x) or x or trunc(x) 1705 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { 1706 // It's possible the bits taken off by the truncate were all sign bits. If 1707 // so, we should be able to simplify this further. 1708 const SCEV *X = ST->getOperand(); 1709 ConstantRange CR = getSignedRange(X); 1710 unsigned TruncBits = getTypeSizeInBits(ST->getType()); 1711 unsigned NewBits = getTypeSizeInBits(Ty); 1712 if (CR.truncate(TruncBits).signExtend(NewBits).contains( 1713 CR.sextOrTrunc(NewBits))) 1714 return getTruncateOrSignExtend(X, Ty); 1715 } 1716 1717 // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2 1718 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { 1719 if (SA->getNumOperands() == 2) { 1720 auto *SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0)); 1721 auto *SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1)); 1722 if (SMul && SC1) { 1723 if (auto *SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) { 1724 const APInt &C1 = SC1->getAPInt(); 1725 const APInt &C2 = SC2->getAPInt(); 1726 if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && 1727 C2.ugt(C1) && C2.isPowerOf2()) 1728 return getAddExpr(getSignExtendExpr(SC1, Ty), 1729 getSignExtendExpr(SMul, Ty)); 1730 } 1731 } 1732 } 1733 1734 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> 1735 if (SA->hasNoSignedWrap()) { 1736 // If the addition does not sign overflow then we can, by definition, 1737 // commute the sign extension with the addition operation. 1738 SmallVector<const SCEV *, 4> Ops; 1739 for (const auto *Op : SA->operands()) 1740 Ops.push_back(getSignExtendExpr(Op, Ty)); 1741 return getAddExpr(Ops, SCEV::FlagNSW); 1742 } 1743 } 1744 // If the input value is a chrec scev, and we can prove that the value 1745 // did not overflow the old, smaller, value, we can sign extend all of the 1746 // operands (often constants). This allows analysis of something like 1747 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } 1748 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) 1749 if (AR->isAffine()) { 1750 const SCEV *Start = AR->getStart(); 1751 const SCEV *Step = AR->getStepRecurrence(*this); 1752 unsigned BitWidth = getTypeSizeInBits(AR->getType()); 1753 const Loop *L = AR->getLoop(); 1754 1755 if (!AR->hasNoSignedWrap()) { 1756 auto NewFlags = proveNoWrapViaConstantRanges(AR); 1757 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags); 1758 } 1759 1760 // If we have special knowledge that this addrec won't overflow, 1761 // we don't need to do any further analysis. 
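      // For instance (illustrative): sign-extending {0,+,1}<nsw> from i8 to
      // i32 yields the i32 recurrence {0,+,1} with the <nsw> flag preserved,
      // and no further analysis is needed.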
1762 if (AR->hasNoSignedWrap()) 1763 return getAddRecExpr( 1764 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1765 getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); 1766 1767 // Check whether the backedge-taken count is SCEVCouldNotCompute. 1768 // Note that this serves two purposes: It filters out loops that are 1769 // simply not analyzable, and it covers the case where this code is 1770 // being called from within backedge-taken count analysis, such that 1771 // attempting to ask for the backedge-taken count would likely result 1772 // in infinite recursion. In the latter case, the analysis code will 1773 // cope with a conservative value, and it will take care to purge 1774 // that value once it has finished. 1775 const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); 1776 if (!isa<SCEVCouldNotCompute>(MaxBECount)) { 1777 // Manually compute the final value for AR, checking for 1778 // overflow. 1779 1780 // Check whether the backedge-taken count can be losslessly cast to 1781 // the addrec's type. The count is always unsigned. 1782 const SCEV *CastedMaxBECount = 1783 getTruncateOrZeroExtend(MaxBECount, Start->getType()); 1784 const SCEV *RecastedMaxBECount = 1785 getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); 1786 if (MaxBECount == RecastedMaxBECount) { 1787 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); 1788 // Check whether Start+Step*MaxBECount has no signed overflow. 1789 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step); 1790 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy); 1791 const SCEV *WideStart = getSignExtendExpr(Start, WideTy); 1792 const SCEV *WideMaxBECount = 1793 getZeroExtendExpr(CastedMaxBECount, WideTy); 1794 const SCEV *OperandExtendedAdd = 1795 getAddExpr(WideStart, 1796 getMulExpr(WideMaxBECount, 1797 getSignExtendExpr(Step, WideTy))); 1798 if (SAdd == OperandExtendedAdd) { 1799 // Cache knowledge of AR NSW, which is propagated to this AddRec. 1800 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); 1801 // Return the expression with the addrec on the outside. 1802 return getAddRecExpr( 1803 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1804 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1805 } 1806 // Similar to above, only this time treat the step value as unsigned. 1807 // This covers loops that count up with an unsigned step. 1808 OperandExtendedAdd = 1809 getAddExpr(WideStart, 1810 getMulExpr(WideMaxBECount, 1811 getZeroExtendExpr(Step, WideTy))); 1812 if (SAdd == OperandExtendedAdd) { 1813 // If AR wraps around then 1814 // 1815 // abs(Step) * MaxBECount > unsigned-max(AR->getType()) 1816 // => SAdd != OperandExtendedAdd 1817 // 1818 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> 1819 // (SAdd == OperandExtendedAdd => AR is NW) 1820 1821 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); 1822 1823 // Return the expression with the addrec on the outside. 1824 return getAddRecExpr( 1825 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1826 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1827 } 1828 } 1829 } 1830 1831 // Normally, in the cases we can prove no-overflow via a 1832 // backedge guarding condition, we can also compute a backedge 1833 // taken count for the loop. The exceptions are assumptions and 1834 // guards present in the loop -- SCEV is not great at exploiting 1835 // these to compute max backedge taken counts, but can still use 1836 // these to prove lack of overflow.
Use this fact to avoid 1837 // doing extra work that may not pay off. 1838 1839 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || 1840 !AC.assumptions().empty()) { 1841 // If the backedge is guarded by a comparison with the pre-inc 1842 // value the addrec is safe. Also, if the entry is guarded by 1843 // a comparison with the start value and the backedge is 1844 // guarded by a comparison with the post-inc value, the addrec 1845 // is safe. 1846 ICmpInst::Predicate Pred; 1847 const SCEV *OverflowLimit = 1848 getSignedOverflowLimitForStep(Step, &Pred, this); 1849 if (OverflowLimit && 1850 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || 1851 (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) && 1852 isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this), 1853 OverflowLimit)))) { 1854 // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. 1855 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); 1856 return getAddRecExpr( 1857 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1858 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1859 } 1860 } 1861 1862 // If Start and Step are constants, check if we can apply this 1863 // transformation: 1864 // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2 1865 auto *SC1 = dyn_cast<SCEVConstant>(Start); 1866 auto *SC2 = dyn_cast<SCEVConstant>(Step); 1867 if (SC1 && SC2) { 1868 const APInt &C1 = SC1->getAPInt(); 1869 const APInt &C2 = SC2->getAPInt(); 1870 if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && 1871 C2.isPowerOf2()) { 1872 Start = getSignExtendExpr(Start, Ty); 1873 const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L, 1874 AR->getNoWrapFlags()); 1875 return getAddExpr(Start, getSignExtendExpr(NewAR, Ty)); 1876 } 1877 } 1878 1879 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { 1880 const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); 1881 return getAddRecExpr( 1882 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), 1883 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); 1884 } 1885 } 1886 1887 // If the input value is provably positive and we could not simplify 1888 // away the sext build a zext instead. 1889 if (isKnownNonNegative(Op)) 1890 return getZeroExtendExpr(Op, Ty); 1891 1892 // The cast wasn't folded; create an explicit cast node. 1893 // Recompute the insert position, as it may have been invalidated. 1894 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 1895 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), 1896 Op, Ty); 1897 UniqueSCEVs.InsertNode(S, IP); 1898 return S; 1899 } 1900 1901 /// getAnyExtendExpr - Return a SCEV for the given operand extended with 1902 /// unspecified bits out to the given type. 1903 /// 1904 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, 1905 Type *Ty) { 1906 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && 1907 "This is not an extending conversion!"); 1908 assert(isSCEVable(Ty) && 1909 "This is not a conversion to a SCEVable type!"); 1910 Ty = getEffectiveSCEVType(Ty); 1911 1912 // Sign-extend negative constants. 1913 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) 1914 if (SC->getAPInt().isNegative()) 1915 return getSignExtendExpr(Op, Ty); 1916 1917 // Peel off a truncate cast. 
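  // For example (illustrative): anyext(trunc i64 %x to i32) back to i64 can
  // simply reuse %x, since the reintroduced high bits are unspecified anyway;
  // extending to a type narrower than %x just becomes a smaller truncate.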
1918 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) { 1919 const SCEV *NewOp = T->getOperand(); 1920 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) 1921 return getAnyExtendExpr(NewOp, Ty); 1922 return getTruncateOrNoop(NewOp, Ty); 1923 } 1924 1925 // Next try a zext cast. If the cast is folded, use it. 1926 const SCEV *ZExt = getZeroExtendExpr(Op, Ty); 1927 if (!isa<SCEVZeroExtendExpr>(ZExt)) 1928 return ZExt; 1929 1930 // Next try a sext cast. If the cast is folded, use it. 1931 const SCEV *SExt = getSignExtendExpr(Op, Ty); 1932 if (!isa<SCEVSignExtendExpr>(SExt)) 1933 return SExt; 1934 1935 // Force the cast to be folded into the operands of an addrec. 1936 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) { 1937 SmallVector<const SCEV *, 4> Ops; 1938 for (const SCEV *Op : AR->operands()) 1939 Ops.push_back(getAnyExtendExpr(Op, Ty)); 1940 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); 1941 } 1942 1943 // If the expression is obviously signed, use the sext cast value. 1944 if (isa<SCEVSMaxExpr>(Op)) 1945 return SExt; 1946 1947 // Absent any other information, use the zext cast value. 1948 return ZExt; 1949 } 1950 1951 /// Process the given Ops list, which is a list of operands to be added under 1952 /// the given scale, update the given map. This is a helper function for 1953 /// getAddRecExpr. As an example of what it does, given a sequence of operands 1954 /// that would form an add expression like this: 1955 /// 1956 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r) 1957 /// 1958 /// where A and B are constants, update the map with these values: 1959 /// 1960 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0) 1961 /// 1962 /// and add 13 + A*B*29 to AccumulatedConstant. 1963 /// This will allow getAddRecExpr to produce this: 1964 /// 1965 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B) 1966 /// 1967 /// This form often exposes folding opportunities that are hidden in 1968 /// the original operand list. 1969 /// 1970 /// Return true iff it appears that any interesting folding opportunities 1971 /// may be exposed. This helps getAddRecExpr short-circuit extra work in 1972 /// the common case where no interesting opportunities are present, and 1973 /// is also used as a check to avoid infinite recursion. 1974 /// 1975 static bool 1976 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, 1977 SmallVectorImpl<const SCEV *> &NewOps, 1978 APInt &AccumulatedConstant, 1979 const SCEV *const *Ops, size_t NumOperands, 1980 const APInt &Scale, 1981 ScalarEvolution &SE) { 1982 bool Interesting = false; 1983 1984 // Iterate over the add operands. They are sorted, with constants first. 1985 unsigned i = 0; 1986 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { 1987 ++i; 1988 // Pull a buried constant out to the outside. 1989 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) 1990 Interesting = true; 1991 AccumulatedConstant += Scale * C->getAPInt(); 1992 } 1993 1994 // Next comes everything else. We're especially interested in multiplies 1995 // here, but they're in the middle, so just visit the rest with one loop. 
1996 for (; i != NumOperands; ++i) { 1997 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]); 1998 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) { 1999 APInt NewScale = 2000 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt(); 2001 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) { 2002 // A multiplication of a constant with another add; recurse. 2003 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1)); 2004 Interesting |= 2005 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, 2006 Add->op_begin(), Add->getNumOperands(), 2007 NewScale, SE); 2008 } else { 2009 // A multiplication of a constant with some other value. Update 2010 // the map. 2011 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end()); 2012 const SCEV *Key = SE.getMulExpr(MulOps); 2013 auto Pair = M.insert({Key, NewScale}); 2014 if (Pair.second) { 2015 NewOps.push_back(Pair.first->first); 2016 } else { 2017 Pair.first->second += NewScale; 2018 // The map already had an entry for this value, which may indicate 2019 // a folding opportunity. 2020 Interesting = true; 2021 } 2022 } 2023 } else { 2024 // An ordinary operand. Update the map. 2025 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = 2026 M.insert({Ops[i], Scale}); 2027 if (Pair.second) { 2028 NewOps.push_back(Pair.first->first); 2029 } else { 2030 Pair.first->second += Scale; 2031 // The map already had an entry for this value, which may indicate 2032 // a folding opportunity. 2033 Interesting = true; 2034 } 2035 } 2036 } 2037 2038 return Interesting; 2039 } 2040 2041 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and 2042 // `Flags' as can't-wrap behavior. Infer a more aggressive set of 2043 // can't-overflow flags for the operation if possible. 2044 static SCEV::NoWrapFlags 2045 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, 2046 const SmallVectorImpl<const SCEV *> &Ops, 2047 SCEV::NoWrapFlags Flags) { 2048 using namespace std::placeholders; 2049 typedef OverflowingBinaryOperator OBO; 2050 2051 bool CanAnalyze = 2052 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; 2053 (void)CanAnalyze; 2054 assert(CanAnalyze && "don't call from other places!"); 2055 2056 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; 2057 SCEV::NoWrapFlags SignOrUnsignWrap = 2058 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); 2059 2060 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
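  // (Roughly: if all the operands are signed non-negative and the expression
  // does not wrap in the signed sense, every value it takes lies in
  // [0, 2^(BitWidth-1)), so it cannot wrap in the unsigned sense either.)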
2061 auto IsKnownNonNegative = [&](const SCEV *S) { 2062 return SE->isKnownNonNegative(S); 2063 }; 2064 2065 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) 2066 Flags = 2067 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); 2068 2069 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); 2070 2071 if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr && 2072 Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) { 2073 2074 // (A + C) --> (A + C)<nsw> if the addition does not sign overflow 2075 // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow 2076 2077 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt(); 2078 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { 2079 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( 2080 Instruction::Add, C, OBO::NoSignedWrap); 2081 if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) 2082 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); 2083 } 2084 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { 2085 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( 2086 Instruction::Add, C, OBO::NoUnsignedWrap); 2087 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) 2088 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 2089 } 2090 } 2091 2092 return Flags; 2093 } 2094 2095 /// Get a canonical add expression, or something simpler if possible. 2096 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, 2097 SCEV::NoWrapFlags Flags) { 2098 assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && 2099 "only nuw or nsw allowed"); 2100 assert(!Ops.empty() && "Cannot get empty add!"); 2101 if (Ops.size() == 1) return Ops[0]; 2102 #ifndef NDEBUG 2103 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 2104 for (unsigned i = 1, e = Ops.size(); i != e; ++i) 2105 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 2106 "SCEVAddExpr operand types don't match!"); 2107 #endif 2108 2109 // Sort by complexity, this groups all similar expression types together. 2110 GroupByComplexity(Ops, &LI); 2111 2112 Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); 2113 2114 // If there are any constants, fold them together. 2115 unsigned Idx = 0; 2116 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { 2117 ++Idx; 2118 assert(Idx < Ops.size()); 2119 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { 2120 // We found two constants, fold them together! 2121 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt()); 2122 if (Ops.size() == 2) return Ops[0]; 2123 Ops.erase(Ops.begin()+1); // Erase the folded element 2124 LHSC = cast<SCEVConstant>(Ops[0]); 2125 } 2126 2127 // If we are left with a constant zero being added, strip it off. 2128 if (LHSC->getValue()->isZero()) { 2129 Ops.erase(Ops.begin()); 2130 --Idx; 2131 } 2132 2133 if (Ops.size() == 1) return Ops[0]; 2134 } 2135 2136 // Okay, check to see if the same value occurs in the operand list more than 2137 // once. If so, merge them together into a multiply expression. Since we 2138 // sorted the list, these values are required to be adjacent. 2139 Type *Ty = Ops[0]->getType(); 2140 bool FoundMatch = false; 2141 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i) 2142 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 2143 // Scan ahead to count how many equal operands there are. 2144 unsigned Count = 2; 2145 while (i+Count != e && Ops[i+Count] == Ops[i]) 2146 ++Count; 2147 // Merge the values into a multiply.
2148 const SCEV *Scale = getConstant(Ty, Count); 2149 const SCEV *Mul = getMulExpr(Scale, Ops[i]); 2150 if (Ops.size() == Count) 2151 return Mul; 2152 Ops[i] = Mul; 2153 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count); 2154 --i; e -= Count - 1; 2155 FoundMatch = true; 2156 } 2157 if (FoundMatch) 2158 return getAddExpr(Ops, Flags); 2159 2160 // Check for truncates. If all the operands are truncated from the same 2161 // type, see if factoring out the truncate would permit the result to be 2162 // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n) 2163 // if the contents of the resulting outer trunc fold to something simple. 2164 for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) { 2165 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]); 2166 Type *DstType = Trunc->getType(); 2167 Type *SrcType = Trunc->getOperand()->getType(); 2168 SmallVector<const SCEV *, 8> LargeOps; 2169 bool Ok = true; 2170 // Check all the operands to see if they can be represented in the 2171 // source type of the truncate. 2172 for (unsigned i = 0, e = Ops.size(); i != e; ++i) { 2173 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) { 2174 if (T->getOperand()->getType() != SrcType) { 2175 Ok = false; 2176 break; 2177 } 2178 LargeOps.push_back(T->getOperand()); 2179 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { 2180 LargeOps.push_back(getAnyExtendExpr(C, SrcType)); 2181 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) { 2182 SmallVector<const SCEV *, 8> LargeMulOps; 2183 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { 2184 if (const SCEVTruncateExpr *T = 2185 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) { 2186 if (T->getOperand()->getType() != SrcType) { 2187 Ok = false; 2188 break; 2189 } 2190 LargeMulOps.push_back(T->getOperand()); 2191 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) { 2192 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType)); 2193 } else { 2194 Ok = false; 2195 break; 2196 } 2197 } 2198 if (Ok) 2199 LargeOps.push_back(getMulExpr(LargeMulOps)); 2200 } else { 2201 Ok = false; 2202 break; 2203 } 2204 } 2205 if (Ok) { 2206 // Evaluate the expression in the larger type. 2207 const SCEV *Fold = getAddExpr(LargeOps, Flags); 2208 // If it folds to something simple, use it. Otherwise, don't. 2209 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) 2210 return getTruncateExpr(Fold, DstType); 2211 } 2212 } 2213 2214 // Skip past any other cast SCEVs. 2215 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) 2216 ++Idx; 2217 2218 // If there are add operands they would be next. 2219 if (Idx < Ops.size()) { 2220 bool DeletedAdd = false; 2221 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) { 2222 // If we have an add, expand the add operands onto the end of the operands 2223 // list. 2224 Ops.erase(Ops.begin()+Idx); 2225 Ops.append(Add->op_begin(), Add->op_end()); 2226 DeletedAdd = true; 2227 } 2228 2229 // If we deleted at least one add, we added operands to the end of the list, 2230 // and they are not necessarily sorted. Recurse to resort and resimplify 2231 // any operands we just acquired. 2232 if (DeletedAdd) 2233 return getAddExpr(Ops); 2234 } 2235 2236 // Skip over the add expression until we get to a multiply. 2237 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) 2238 ++Idx; 2239 2240 // Check to see if there are any folding opportunities present with 2241 // operands multiplied by constant values. 
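  // For example (illustrative): (2 * %x) + (3 * %x) + %y accumulates the
  // scales for %x and re-emits the sum as (5 * %x) + %y.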
2242 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) { 2243 uint64_t BitWidth = getTypeSizeInBits(Ty); 2244 DenseMap<const SCEV *, APInt> M; 2245 SmallVector<const SCEV *, 8> NewOps; 2246 APInt AccumulatedConstant(BitWidth, 0); 2247 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, 2248 Ops.data(), Ops.size(), 2249 APInt(BitWidth, 1), *this)) { 2250 struct APIntCompare { 2251 bool operator()(const APInt &LHS, const APInt &RHS) const { 2252 return LHS.ult(RHS); 2253 } 2254 }; 2255 2256 // Some interesting folding opportunity is present, so it's worthwhile to 2257 // re-generate the operands list. Group the operands by constant scale, 2258 // to avoid multiplying by the same constant scale multiple times. 2259 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists; 2260 for (const SCEV *NewOp : NewOps) 2261 MulOpLists[M.find(NewOp)->second].push_back(NewOp); 2262 // Re-generate the operands list. 2263 Ops.clear(); 2264 if (AccumulatedConstant != 0) 2265 Ops.push_back(getConstant(AccumulatedConstant)); 2266 for (auto &MulOp : MulOpLists) 2267 if (MulOp.first != 0) 2268 Ops.push_back(getMulExpr(getConstant(MulOp.first), 2269 getAddExpr(MulOp.second))); 2270 if (Ops.empty()) 2271 return getZero(Ty); 2272 if (Ops.size() == 1) 2273 return Ops[0]; 2274 return getAddExpr(Ops); 2275 } 2276 } 2277 2278 // If we are adding something to a multiply expression, make sure the 2279 // something is not already an operand of the multiply. If so, merge it into 2280 // the multiply. 2281 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) { 2282 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]); 2283 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { 2284 const SCEV *MulOpSCEV = Mul->getOperand(MulOp); 2285 if (isa<SCEVConstant>(MulOpSCEV)) 2286 continue; 2287 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) 2288 if (MulOpSCEV == Ops[AddOp]) { 2289 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) 2290 const SCEV *InnerMul = Mul->getOperand(MulOp == 0); 2291 if (Mul->getNumOperands() != 2) { 2292 // If the multiply has more than two operands, we must get the 2293 // Y*Z term. 2294 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), 2295 Mul->op_begin()+MulOp); 2296 MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); 2297 InnerMul = getMulExpr(MulOps); 2298 } 2299 const SCEV *One = getOne(Ty); 2300 const SCEV *AddOne = getAddExpr(One, InnerMul); 2301 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV); 2302 if (Ops.size() == 2) return OuterMul; 2303 if (AddOp < Idx) { 2304 Ops.erase(Ops.begin()+AddOp); 2305 Ops.erase(Ops.begin()+Idx-1); 2306 } else { 2307 Ops.erase(Ops.begin()+Idx); 2308 Ops.erase(Ops.begin()+AddOp-1); 2309 } 2310 Ops.push_back(OuterMul); 2311 return getAddExpr(Ops); 2312 } 2313 2314 // Check this multiply against other multiplies being added together. 2315 for (unsigned OtherMulIdx = Idx+1; 2316 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]); 2317 ++OtherMulIdx) { 2318 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]); 2319 // If MulOp occurs in OtherMul, we can fold the two multiplies 2320 // together.
2321 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands(); 2322 OMulOp != e; ++OMulOp) 2323 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { 2324 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) 2325 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); 2326 if (Mul->getNumOperands() != 2) { 2327 SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), 2328 Mul->op_begin()+MulOp); 2329 MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); 2330 InnerMul1 = getMulExpr(MulOps); 2331 } 2332 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); 2333 if (OtherMul->getNumOperands() != 2) { 2334 SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(), 2335 OtherMul->op_begin()+OMulOp); 2336 MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); 2337 InnerMul2 = getMulExpr(MulOps); 2338 } 2339 const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2); 2340 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); 2341 if (Ops.size() == 2) return OuterMul; 2342 Ops.erase(Ops.begin()+Idx); 2343 Ops.erase(Ops.begin()+OtherMulIdx-1); 2344 Ops.push_back(OuterMul); 2345 return getAddExpr(Ops); 2346 } 2347 } 2348 } 2349 } 2350 2351 // If there are any add recurrences in the operands list, see if any other 2352 // added values are loop invariant. If so, we can fold them into the 2353 // recurrence. 2354 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) 2355 ++Idx; 2356 2357 // Scan over all recurrences, trying to fold loop invariants into them. 2358 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { 2359 // Scan all of the other operands to this add and add them to the vector if 2360 // they are loop invariant w.r.t. the recurrence. 2361 SmallVector<const SCEV *, 8> LIOps; 2362 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); 2363 const Loop *AddRecLoop = AddRec->getLoop(); 2364 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 2365 if (isLoopInvariant(Ops[i], AddRecLoop)) { 2366 LIOps.push_back(Ops[i]); 2367 Ops.erase(Ops.begin()+i); 2368 --i; --e; 2369 } 2370 2371 // If we found some loop invariants, fold them into the recurrence. 2372 if (!LIOps.empty()) { 2373 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} 2374 LIOps.push_back(AddRec->getStart()); 2375 2376 SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), 2377 AddRec->op_end()); 2378 // This follows from the fact that the no-wrap flags on the outer add 2379 // expression are applicable on the 0th iteration, when the add recurrence 2380 // will be equal to its start value. 2381 AddRecOps[0] = getAddExpr(LIOps, Flags); 2382 2383 // Build the new addrec. Propagate the NUW and NSW flags if both the 2384 // outer add and the inner addrec are guaranteed to have no overflow. 2385 // Always propagate NW. 2386 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); 2387 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); 2388 2389 // If all of the other operands were loop invariant, we are done. 2390 if (Ops.size() == 1) return NewRec; 2391 2392 // Otherwise, add the folded AddRec by the non-invariant parts. 2393 for (unsigned i = 0;; ++i) 2394 if (Ops[i] == AddRec) { 2395 Ops[i] = NewRec; 2396 break; 2397 } 2398 return getAddExpr(Ops); 2399 } 2400 2401 // Okay, if there weren't any loop invariants to be folded, check to see if 2402 // there are multiple AddRec's with the same loop induction variable being 2403 // added together. If so, we can fold them. 
2404 for (unsigned OtherIdx = Idx+1; 2405 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); 2406 ++OtherIdx) 2407 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) { 2408 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L> 2409 SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), 2410 AddRec->op_end()); 2411 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); 2412 ++OtherIdx) 2413 if (const auto *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx])) 2414 if (OtherAddRec->getLoop() == AddRecLoop) { 2415 for (unsigned i = 0, e = OtherAddRec->getNumOperands(); 2416 i != e; ++i) { 2417 if (i >= AddRecOps.size()) { 2418 AddRecOps.append(OtherAddRec->op_begin()+i, 2419 OtherAddRec->op_end()); 2420 break; 2421 } 2422 AddRecOps[i] = getAddExpr(AddRecOps[i], 2423 OtherAddRec->getOperand(i)); 2424 } 2425 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; 2426 } 2427 // Step size has changed, so we cannot guarantee no self-wraparound. 2428 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); 2429 return getAddExpr(Ops); 2430 } 2431 2432 // Otherwise couldn't fold anything into this recurrence. Move onto the 2433 // next one. 2434 } 2435 2436 // Okay, it looks like we really DO need an add expr. Check to see if we 2437 // already have one, otherwise create a new one. 2438 FoldingSetNodeID ID; 2439 ID.AddInteger(scAddExpr); 2440 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 2441 ID.AddPointer(Ops[i]); 2442 void *IP = nullptr; 2443 SCEVAddExpr *S = 2444 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); 2445 if (!S) { 2446 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 2447 std::uninitialized_copy(Ops.begin(), Ops.end(), O); 2448 S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), 2449 O, Ops.size()); 2450 UniqueSCEVs.InsertNode(S, IP); 2451 } 2452 S->setNoWrapFlags(Flags); 2453 return S; 2454 } 2455 2456 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) { 2457 uint64_t k = i*j; 2458 if (j > 1 && k / j != i) Overflow = true; 2459 return k; 2460 } 2461 2462 /// Compute the result of "n choose k", the binomial coefficient. If an 2463 /// intermediate computation overflows, Overflow will be set and the return will 2464 /// be garbage. Overflow is not cleared on absence of overflow. 2465 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { 2466 // We use the multiplicative formula: 2467 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 . 2468 // At each iteration, we multiply by the next term of the numerator and 2469 // divide by the next term of the denominator. This division will always produce an 2470 // integral result, and helps reduce the chance of overflow in the 2471 // intermediate computations. However, we can still overflow even when the 2472 // final result would fit. 2473 2474 if (n == 0 || n == k) return 1; 2475 if (k > n) return 0; 2476 2477 if (k > n/2) 2478 k = n-k; 2479 2480 uint64_t r = 1; 2481 for (uint64_t i = 1; i <= k; ++i) { 2482 r = umul_ov(r, n-(i-1), Overflow); 2483 r /= i; 2484 } 2485 return r; 2486 } 2487 2488 /// Determine if any of the operands in this SCEV are a constant or if 2489 /// any of the add or multiply expressions in this SCEV contain a constant.
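/// For example (illustrative), (3 + %x) and ((2 * %y) + %z) both contain a
/// constant, while (%x * %y) does not; only add and multiply expressions are
/// traversed when looking for one.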
2490 static bool containsConstantSomewhere(const SCEV *StartExpr) { 2491 SmallVector<const SCEV *, 4> Ops; 2492 Ops.push_back(StartExpr); 2493 while (!Ops.empty()) { 2494 const SCEV *CurrentExpr = Ops.pop_back_val(); 2495 if (isa<SCEVConstant>(*CurrentExpr)) 2496 return true; 2497 2498 if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) { 2499 const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr); 2500 Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end()); 2501 } 2502 } 2503 return false; 2504 } 2505 2506 /// Get a canonical multiply expression, or something simpler if possible. 2507 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, 2508 SCEV::NoWrapFlags Flags) { 2509 assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && 2510 "only nuw or nsw allowed"); 2511 assert(!Ops.empty() && "Cannot get empty mul!"); 2512 if (Ops.size() == 1) return Ops[0]; 2513 #ifndef NDEBUG 2514 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 2515 for (unsigned i = 1, e = Ops.size(); i != e; ++i) 2516 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 2517 "SCEVMulExpr operand types don't match!"); 2518 #endif 2519 2520 // Sort by complexity, this groups all similar expression types together. 2521 GroupByComplexity(Ops, &LI); 2522 2523 Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); 2524 2525 // If there are any constants, fold them together. 2526 unsigned Idx = 0; 2527 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { 2528 2529 // C1*(C2+V) -> C1*C2 + C1*V 2530 if (Ops.size() == 2) 2531 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) 2532 // If any of Add's ops are Adds or Muls with a constant, 2533 // apply this transformation as well. 2534 if (Add->getNumOperands() == 2) 2535 if (containsConstantSomewhere(Add)) 2536 return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), 2537 getMulExpr(LHSC, Add->getOperand(1))); 2538 2539 ++Idx; 2540 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { 2541 // We found two constants, fold them together! 2542 ConstantInt *Fold = 2543 ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt()); 2544 Ops[0] = getConstant(Fold); 2545 Ops.erase(Ops.begin()+1); // Erase the folded element 2546 if (Ops.size() == 1) return Ops[0]; 2547 LHSC = cast<SCEVConstant>(Ops[0]); 2548 } 2549 2550 // If we are left with a constant one being multiplied, strip it off. 2551 if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) { 2552 Ops.erase(Ops.begin()); 2553 --Idx; 2554 } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) { 2555 // If we have a multiply of zero, it will always be zero. 2556 return Ops[0]; 2557 } else if (Ops[0]->isAllOnesValue()) { 2558 // If we have a mul by -1 of an add, try distributing the -1 among the 2559 // add operands. 2560 if (Ops.size() == 2) { 2561 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) { 2562 SmallVector<const SCEV *, 4> NewOps; 2563 bool AnyFolded = false; 2564 for (const SCEV *AddOp : Add->operands()) { 2565 const SCEV *Mul = getMulExpr(Ops[0], AddOp); 2566 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true; 2567 NewOps.push_back(Mul); 2568 } 2569 if (AnyFolded) 2570 return getAddExpr(NewOps); 2571 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) { 2572 // Negation preserves a recurrence's no self-wrap property. 
2573 SmallVector<const SCEV *, 4> Operands; 2574 for (const SCEV *AddRecOp : AddRec->operands()) 2575 Operands.push_back(getMulExpr(Ops[0], AddRecOp)); 2576 2577 return getAddRecExpr(Operands, AddRec->getLoop(), 2578 AddRec->getNoWrapFlags(SCEV::FlagNW)); 2579 } 2580 } 2581 } 2582 2583 if (Ops.size() == 1) 2584 return Ops[0]; 2585 } 2586 2587 // Skip over the add expression until we get to a multiply. 2588 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) 2589 ++Idx; 2590 2591 // If there are mul operands inline them all into this expression. 2592 if (Idx < Ops.size()) { 2593 bool DeletedMul = false; 2594 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { 2595 if (Ops.size() > MulOpsInlineThreshold) 2596 break; 2597 // If we have an mul, expand the mul operands onto the end of the operands 2598 // list. 2599 Ops.erase(Ops.begin()+Idx); 2600 Ops.append(Mul->op_begin(), Mul->op_end()); 2601 DeletedMul = true; 2602 } 2603 2604 // If we deleted at least one mul, we added operands to the end of the list, 2605 // and they are not necessarily sorted. Recurse to resort and resimplify 2606 // any operands we just acquired. 2607 if (DeletedMul) 2608 return getMulExpr(Ops); 2609 } 2610 2611 // If there are any add recurrences in the operands list, see if any other 2612 // added values are loop invariant. If so, we can fold them into the 2613 // recurrence. 2614 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) 2615 ++Idx; 2616 2617 // Scan over all recurrences, trying to fold loop invariants into them. 2618 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { 2619 // Scan all of the other operands to this mul and add them to the vector if 2620 // they are loop invariant w.r.t. the recurrence. 2621 SmallVector<const SCEV *, 8> LIOps; 2622 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); 2623 const Loop *AddRecLoop = AddRec->getLoop(); 2624 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 2625 if (isLoopInvariant(Ops[i], AddRecLoop)) { 2626 LIOps.push_back(Ops[i]); 2627 Ops.erase(Ops.begin()+i); 2628 --i; --e; 2629 } 2630 2631 // If we found some loop invariants, fold them into the recurrence. 2632 if (!LIOps.empty()) { 2633 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} 2634 SmallVector<const SCEV *, 4> NewOps; 2635 NewOps.reserve(AddRec->getNumOperands()); 2636 const SCEV *Scale = getMulExpr(LIOps); 2637 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) 2638 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); 2639 2640 // Build the new addrec. Propagate the NUW and NSW flags if both the 2641 // outer mul and the inner addrec are guaranteed to have no overflow. 2642 // 2643 // No self-wrap cannot be guaranteed after changing the step size, but 2644 // will be inferred if either NUW or NSW is true. 2645 Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW)); 2646 const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags); 2647 2648 // If all of the other operands were loop invariant, we are done. 2649 if (Ops.size() == 1) return NewRec; 2650 2651 // Otherwise, multiply the folded AddRec by the non-invariant parts. 2652 for (unsigned i = 0;; ++i) 2653 if (Ops[i] == AddRec) { 2654 Ops[i] = NewRec; 2655 break; 2656 } 2657 return getMulExpr(Ops); 2658 } 2659 2660 // Okay, if there weren't any loop invariants to be folded, check to see if 2661 // there are multiple AddRec's with the same loop induction variable being 2662 // multiplied together. If so, we can fold them. 
2663 2664 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> 2665 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ 2666 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z 2667 // ]]],+,...up to x=2n}. 2668 // Note that the arguments to choose() are always integers with values 2669 // known at compile time, never SCEV objects. 2670 // 2671 // The implementation avoids pointless extra computations when the two 2672 // addrec's are of different length (mathematically, it's equivalent to 2673 // an infinite stream of zeros on the right). 2674 bool OpsModified = false; 2675 for (unsigned OtherIdx = Idx+1; 2676 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); 2677 ++OtherIdx) { 2678 const SCEVAddRecExpr *OtherAddRec = 2679 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); 2680 if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) 2681 continue; 2682 2683 bool Overflow = false; 2684 Type *Ty = AddRec->getType(); 2685 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; 2686 SmallVector<const SCEV*, 7> AddRecOps; 2687 for (int x = 0, xe = AddRec->getNumOperands() + 2688 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { 2689 const SCEV *Term = getZero(Ty); 2690 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { 2691 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); 2692 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), 2693 ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); 2694 z < ze && !Overflow; ++z) { 2695 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); 2696 uint64_t Coeff; 2697 if (LargerThan64Bits) 2698 Coeff = umul_ov(Coeff1, Coeff2, Overflow); 2699 else 2700 Coeff = Coeff1*Coeff2; 2701 const SCEV *CoeffTerm = getConstant(Ty, Coeff); 2702 const SCEV *Term1 = AddRec->getOperand(y-z); 2703 const SCEV *Term2 = OtherAddRec->getOperand(z); 2704 Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); 2705 } 2706 } 2707 AddRecOps.push_back(Term); 2708 } 2709 if (!Overflow) { 2710 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), 2711 SCEV::FlagAnyWrap); 2712 if (Ops.size() == 2) return NewAddRec; 2713 Ops[Idx] = NewAddRec; 2714 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; 2715 OpsModified = true; 2716 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); 2717 if (!AddRec) 2718 break; 2719 } 2720 } 2721 if (OpsModified) 2722 return getMulExpr(Ops); 2723 2724 // Otherwise couldn't fold anything into this recurrence. Move onto the 2725 // next one. 2726 } 2727 2728 // Okay, it looks like we really DO need an mul expr. Check to see if we 2729 // already have one, otherwise create a new one. 2730 FoldingSetNodeID ID; 2731 ID.AddInteger(scMulExpr); 2732 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 2733 ID.AddPointer(Ops[i]); 2734 void *IP = nullptr; 2735 SCEVMulExpr *S = 2736 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); 2737 if (!S) { 2738 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 2739 std::uninitialized_copy(Ops.begin(), Ops.end(), O); 2740 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), 2741 O, Ops.size()); 2742 UniqueSCEVs.InsertNode(S, IP); 2743 } 2744 S->setNoWrapFlags(Flags); 2745 return S; 2746 } 2747 2748 /// Get a canonical unsigned division expression, or something simpler if 2749 /// possible. 
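/// As a rough illustration: a udiv of two constants folds to a constant (for
/// a non-zero divisor), and {X,+,4}<L> /u 2 can fold to {X /u 2,+,2}<L> when
/// the recurrence provably does not wrap in the unsigned sense; otherwise an
/// explicit SCEVUDivExpr node is returned.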
2750 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, 2751 const SCEV *RHS) { 2752 assert(getEffectiveSCEVType(LHS->getType()) == 2753 getEffectiveSCEVType(RHS->getType()) && 2754 "SCEVUDivExpr operand types don't match!"); 2755 2756 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { 2757 if (RHSC->getValue()->equalsInt(1)) 2758 return LHS; // X udiv 1 --> x 2759 // If the denominator is zero, the result of the udiv is undefined. Don't 2760 // try to analyze it, because the resolution chosen here may differ from 2761 // the resolution chosen in other parts of the compiler. 2762 if (!RHSC->getValue()->isZero()) { 2763 // Determine if the division can be folded into the operands of 2764 // its operands. 2765 // TODO: Generalize this to non-constants by using known-bits information. 2766 Type *Ty = LHS->getType(); 2767 unsigned LZ = RHSC->getAPInt().countLeadingZeros(); 2768 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1; 2769 // For non-power-of-two values, effectively round the value up to the 2770 // nearest power of two. 2771 if (!RHSC->getAPInt().isPowerOf2()) 2772 ++MaxShiftAmt; 2773 IntegerType *ExtTy = 2774 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); 2775 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS)) 2776 if (const SCEVConstant *Step = 2777 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) { 2778 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. 2779 const APInt &StepInt = Step->getAPInt(); 2780 const APInt &DivInt = RHSC->getAPInt(); 2781 if (!StepInt.urem(DivInt) && 2782 getZeroExtendExpr(AR, ExtTy) == 2783 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), 2784 getZeroExtendExpr(Step, ExtTy), 2785 AR->getLoop(), SCEV::FlagAnyWrap)) { 2786 SmallVector<const SCEV *, 4> Operands; 2787 for (const SCEV *Op : AR->operands()) 2788 Operands.push_back(getUDivExpr(Op, RHS)); 2789 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW); 2790 } 2791 /// Get a canonical UDivExpr for a recurrence. 2792 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0. 2793 // We can currently only fold X%N if X is constant. 2794 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart()); 2795 if (StartC && !DivInt.urem(StepInt) && 2796 getZeroExtendExpr(AR, ExtTy) == 2797 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), 2798 getZeroExtendExpr(Step, ExtTy), 2799 AR->getLoop(), SCEV::FlagAnyWrap)) { 2800 const APInt &StartInt = StartC->getAPInt(); 2801 const APInt &StartRem = StartInt.urem(StepInt); 2802 if (StartRem != 0) 2803 LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, 2804 AR->getLoop(), SCEV::FlagNW); 2805 } 2806 } 2807 // (A*B)/C --> A*(B/C) if safe and B/C can be folded. 2808 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) { 2809 SmallVector<const SCEV *, 4> Operands; 2810 for (const SCEV *Op : M->operands()) 2811 Operands.push_back(getZeroExtendExpr(Op, ExtTy)); 2812 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) 2813 // Find an operand that's safely divisible. 2814 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { 2815 const SCEV *Op = M->getOperand(i); 2816 const SCEV *Div = getUDivExpr(Op, RHSC); 2817 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) { 2818 Operands = SmallVector<const SCEV *, 4>(M->op_begin(), 2819 M->op_end()); 2820 Operands[i] = Div; 2821 return getMulExpr(Operands); 2822 } 2823 } 2824 } 2825 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. 
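      // For example (illustrative): ((4 * %x) + 8) /u 4 can simplify to
      // %x + 2 when the zero-extend checks show that neither the addition nor
      // the per-operand divisions lose any bits.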
2826 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) { 2827 SmallVector<const SCEV *, 4> Operands; 2828 for (const SCEV *Op : A->operands()) 2829 Operands.push_back(getZeroExtendExpr(Op, ExtTy)); 2830 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { 2831 Operands.clear(); 2832 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { 2833 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); 2834 if (isa<SCEVUDivExpr>(Op) || 2835 getMulExpr(Op, RHS) != A->getOperand(i)) 2836 break; 2837 Operands.push_back(Op); 2838 } 2839 if (Operands.size() == A->getNumOperands()) 2840 return getAddExpr(Operands); 2841 } 2842 } 2843 2844 // Fold if both operands are constant. 2845 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { 2846 Constant *LHSCV = LHSC->getValue(); 2847 Constant *RHSCV = RHSC->getValue(); 2848 return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV, 2849 RHSCV))); 2850 } 2851 } 2852 } 2853 2854 FoldingSetNodeID ID; 2855 ID.AddInteger(scUDivExpr); 2856 ID.AddPointer(LHS); 2857 ID.AddPointer(RHS); 2858 void *IP = nullptr; 2859 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 2860 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), 2861 LHS, RHS); 2862 UniqueSCEVs.InsertNode(S, IP); 2863 return S; 2864 } 2865 2866 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { 2867 APInt A = C1->getAPInt().abs(); 2868 APInt B = C2->getAPInt().abs(); 2869 uint32_t ABW = A.getBitWidth(); 2870 uint32_t BBW = B.getBitWidth(); 2871 2872 if (ABW > BBW) 2873 B = B.zext(ABW); 2874 else if (ABW < BBW) 2875 A = A.zext(BBW); 2876 2877 return APIntOps::GreatestCommonDivisor(A, B); 2878 } 2879 2880 /// Get a canonical unsigned division expression, or something simpler if 2881 /// possible. There is no representation for an exact udiv in SCEV IR, but we 2882 /// can attempt to remove factors from the LHS and RHS. We can't do this when 2883 /// it's not exact because the udiv may be clearing bits. 2884 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, 2885 const SCEV *RHS) { 2886 // TODO: we could try to find factors in all sorts of things, but for now we 2887 // just deal with u/exact (multiply, constant). See SCEVDivision towards the 2888 // end of this file for inspiration. 2889 2890 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); 2891 if (!Mul || !Mul->hasNoUnsignedWrap()) 2892 return getUDivExpr(LHS, RHS); 2893 2894 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { 2895 // If the mulexpr multiplies by a constant, then that constant must be the 2896 // first element of the mulexpr. 2897 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) { 2898 if (LHSCst == RHSCst) { 2899 SmallVector<const SCEV *, 2> Operands; 2900 Operands.append(Mul->op_begin() + 1, Mul->op_end()); 2901 return getMulExpr(Operands); 2902 } 2903 2904 // We can't just assume that LHSCst divides RHSCst cleanly, it could be 2905 // that there's a factor provided by one of the other terms. We need to 2906 // check. 
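      // For example (illustrative): for (6 * %x) /u 4 (with the multiply
      // marked <nuw> as required above), gcd(6, 4) == 2, so the query is
      // rewritten as (3 * %x) /u 2 before looking for an operand equal to
      // the divisor.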
2907 APInt Factor = gcd(LHSCst, RHSCst); 2908 if (!Factor.isIntN(1)) { 2909 LHSCst = 2910 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor))); 2911 RHSCst = 2912 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor))); 2913 SmallVector<const SCEV *, 2> Operands; 2914 Operands.push_back(LHSCst); 2915 Operands.append(Mul->op_begin() + 1, Mul->op_end()); 2916 LHS = getMulExpr(Operands); 2917 RHS = RHSCst; 2918 Mul = dyn_cast<SCEVMulExpr>(LHS); 2919 if (!Mul) 2920 return getUDivExactExpr(LHS, RHS); 2921 } 2922 } 2923 } 2924 2925 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) { 2926 if (Mul->getOperand(i) == RHS) { 2927 SmallVector<const SCEV *, 2> Operands; 2928 Operands.append(Mul->op_begin(), Mul->op_begin() + i); 2929 Operands.append(Mul->op_begin() + i + 1, Mul->op_end()); 2930 return getMulExpr(Operands); 2931 } 2932 } 2933 2934 return getUDivExpr(LHS, RHS); 2935 } 2936 2937 /// Get an add recurrence expression for the specified loop. Simplify the 2938 /// expression as much as possible. 2939 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, 2940 const Loop *L, 2941 SCEV::NoWrapFlags Flags) { 2942 SmallVector<const SCEV *, 4> Operands; 2943 Operands.push_back(Start); 2944 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step)) 2945 if (StepChrec->getLoop() == L) { 2946 Operands.append(StepChrec->op_begin(), StepChrec->op_end()); 2947 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW)); 2948 } 2949 2950 Operands.push_back(Step); 2951 return getAddRecExpr(Operands, L, Flags); 2952 } 2953 2954 /// Get an add recurrence expression for the specified loop. Simplify the 2955 /// expression as much as possible. 2956 const SCEV * 2957 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, 2958 const Loop *L, SCEV::NoWrapFlags Flags) { 2959 if (Operands.size() == 1) return Operands[0]; 2960 #ifndef NDEBUG 2961 Type *ETy = getEffectiveSCEVType(Operands[0]->getType()); 2962 for (unsigned i = 1, e = Operands.size(); i != e; ++i) 2963 assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy && 2964 "SCEVAddRecExpr operand types don't match!"); 2965 for (unsigned i = 0, e = Operands.size(); i != e; ++i) 2966 assert(isLoopInvariant(Operands[i], L) && 2967 "SCEVAddRecExpr operand is not loop-invariant!"); 2968 #endif 2969 2970 if (Operands.back()->isZero()) { 2971 Operands.pop_back(); 2972 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X 2973 } 2974 2975 // It's tempting to want to call getMaxBackedgeTakenCount here and 2976 // use that information to infer NUW and NSW flags. However, computing a 2977 // BE count requires calling getAddRecExpr, so we may not yet have a 2978 // meaningful BE count at this point (and if we don't, we'd be stuck 2979 // with a SCEVCouldNotCompute as the cached BE count). 2980 2981 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); 2982 2983 // Canonicalize nested AddRecs by nesting them in order of loop depth. 2984 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) { 2985 const Loop *NestedLoop = NestedAR->getLoop(); 2986 if (L->contains(NestedLoop) 2987 ?
(L->getLoopDepth() < NestedLoop->getLoopDepth()) 2988 : (!NestedLoop->contains(L) && 2989 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) { 2990 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(), 2991 NestedAR->op_end()); 2992 Operands[0] = NestedAR->getStart(); 2993 // AddRecs require their operands be loop-invariant with respect to their 2994 // loops. Don't perform this transformation if it would break this 2995 // requirement. 2996 bool AllInvariant = all_of( 2997 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); 2998 2999 if (AllInvariant) { 3000 // Create a recurrence for the outer loop with the same step size. 3001 // 3002 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the 3003 // inner recurrence has the same property. 3004 SCEV::NoWrapFlags OuterFlags = 3005 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); 3006 3007 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); 3008 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { 3009 return isLoopInvariant(Op, NestedLoop); 3010 }); 3011 3012 if (AllInvariant) { 3013 // Ok, both add recurrences are valid after the transformation. 3014 // 3015 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if 3016 // the outer recurrence has the same property. 3017 SCEV::NoWrapFlags InnerFlags = 3018 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); 3019 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); 3020 } 3021 } 3022 // Reset Operands to its original state. 3023 Operands[0] = NestedAR; 3024 } 3025 } 3026 3027 // Okay, it looks like we really DO need an addrec expr. Check to see if we 3028 // already have one, otherwise create a new one. 3029 FoldingSetNodeID ID; 3030 ID.AddInteger(scAddRecExpr); 3031 for (unsigned i = 0, e = Operands.size(); i != e; ++i) 3032 ID.AddPointer(Operands[i]); 3033 ID.AddPointer(L); 3034 void *IP = nullptr; 3035 SCEVAddRecExpr *S = 3036 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); 3037 if (!S) { 3038 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size()); 3039 std::uninitialized_copy(Operands.begin(), Operands.end(), O); 3040 S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), 3041 O, Operands.size(), L); 3042 UniqueSCEVs.InsertNode(S, IP); 3043 } 3044 S->setNoWrapFlags(Flags); 3045 return S; 3046 } 3047 3048 const SCEV * 3049 ScalarEvolution::getGEPExpr(GEPOperator *GEP, 3050 const SmallVectorImpl<const SCEV *> &IndexExprs) { 3051 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand()); 3052 // getSCEV(Base)->getType() has the same address space as Base->getType() 3053 // because SCEV::getType() preserves the address space. 3054 Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); 3055 // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP 3056 // instruction to its SCEV, because the Instruction may be guarded by control 3057 // flow and the no-overflow bits may not be valid for the expression in any 3058 // context. This can be fixed similarly to how these flags are handled for 3059 // adds. 3060 SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW 3061 : SCEV::FlagAnyWrap; 3062 3063 const SCEV *TotalOffset = getZero(IntPtrTy); 3064 // The array size is unimportant. The first thing we do on CurTy is getting 3065 // its element type. 
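  // A rough sketch of what the loop below produces (hypothetical IR, sizes
  // per a typical DataLayout): for
  //   getelementptr [10 x i32], [10 x i32]* %p, i64 %i, i64 %j
  // the per-index offsets are %i * 40 and %j * 4, so the result is the SCEV
  // for %p + %i * 40 + %j * 4, built from getMulExpr/getAddExpr calls.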
3066 Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0); 3067 for (const SCEV *IndexExpr : IndexExprs) { 3068 // Compute the (potentially symbolic) offset in bytes for this index. 3069 if (StructType *STy = dyn_cast<StructType>(CurTy)) { 3070 // For a struct, add the member offset. 3071 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue(); 3072 unsigned FieldNo = Index->getZExtValue(); 3073 const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); 3074 3075 // Add the field offset to the running total offset. 3076 TotalOffset = getAddExpr(TotalOffset, FieldOffset); 3077 3078 // Update CurTy to the type of the field at Index. 3079 CurTy = STy->getTypeAtIndex(Index); 3080 } else { 3081 // Update CurTy to its element type. 3082 CurTy = cast<SequentialType>(CurTy)->getElementType(); 3083 // For an array, add the element offset, explicitly scaled. 3084 const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy); 3085 // Getelementptr indices are signed. 3086 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy); 3087 3088 // Multiply the index by the element size to compute the element offset. 3089 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap); 3090 3091 // Add the element offset to the running total offset. 3092 TotalOffset = getAddExpr(TotalOffset, LocalOffset); 3093 } 3094 } 3095 3096 // Add the total offset from all the GEP indices to the base. 3097 return getAddExpr(BaseExpr, TotalOffset, Wrap); 3098 } 3099 3100 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, 3101 const SCEV *RHS) { 3102 SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; 3103 return getSMaxExpr(Ops); 3104 } 3105 3106 const SCEV * 3107 ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { 3108 assert(!Ops.empty() && "Cannot get empty smax!"); 3109 if (Ops.size() == 1) return Ops[0]; 3110 #ifndef NDEBUG 3111 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 3112 for (unsigned i = 1, e = Ops.size(); i != e; ++i) 3113 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 3114 "SCEVSMaxExpr operand types don't match!"); 3115 #endif 3116 3117 // Sort by complexity, this groups all similar expression types together. 3118 GroupByComplexity(Ops, &LI); 3119 3120 // If there are any constants, fold them together. 3121 unsigned Idx = 0; 3122 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { 3123 ++Idx; 3124 assert(Idx < Ops.size()); 3125 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { 3126 // We found two constants, fold them together! 3127 ConstantInt *Fold = ConstantInt::get( 3128 getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt())); 3129 Ops[0] = getConstant(Fold); 3130 Ops.erase(Ops.begin()+1); // Erase the folded element 3131 if (Ops.size() == 1) return Ops[0]; 3132 LHSC = cast<SCEVConstant>(Ops[0]); 3133 } 3134 3135 // If we are left with a constant minimum-int, strip it off. 3136 if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) { 3137 Ops.erase(Ops.begin()); 3138 --Idx; 3139 } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) { 3140 // If we have an smax with a constant maximum-int, it will always be 3141 // maximum-int. 3142 return Ops[0]; 3143 } 3144 3145 if (Ops.size() == 1) return Ops[0]; 3146 } 3147 3148 // Find the first SMax 3149 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) 3150 ++Idx; 3151 3152 // Check to see if one of the operands is an SMax. If so, expand its operands 3153 // onto our operand list, and recurse to simplify. 
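  // For example, smax(%a, smax(%b, %c)) is flattened to smax(%a, %b, %c)
  // before re-simplifying (operand names hypothetical).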
3154 if (Idx < Ops.size()) { 3155 bool DeletedSMax = false; 3156 while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) { 3157 Ops.erase(Ops.begin()+Idx); 3158 Ops.append(SMax->op_begin(), SMax->op_end()); 3159 DeletedSMax = true; 3160 } 3161 3162 if (DeletedSMax) 3163 return getSMaxExpr(Ops); 3164 } 3165 3166 // Okay, check to see if the same value occurs in the operand list twice. If 3167 // so, delete one. Since we sorted the list, these values are required to 3168 // be adjacent. 3169 for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) 3170 // X smax Y smax Y --> X smax Y 3171 // X smax Y --> X, if X is always greater than Y 3172 if (Ops[i] == Ops[i+1] || 3173 isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) { 3174 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2); 3175 --i; --e; 3176 } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) { 3177 Ops.erase(Ops.begin()+i, Ops.begin()+i+1); 3178 --i; --e; 3179 } 3180 3181 if (Ops.size() == 1) return Ops[0]; 3182 3183 assert(!Ops.empty() && "Reduced smax down to nothing!"); 3184 3185 // Okay, it looks like we really DO need an smax expr. Check to see if we 3186 // already have one, otherwise create a new one. 3187 FoldingSetNodeID ID; 3188 ID.AddInteger(scSMaxExpr); 3189 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 3190 ID.AddPointer(Ops[i]); 3191 void *IP = nullptr; 3192 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 3193 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 3194 std::uninitialized_copy(Ops.begin(), Ops.end(), O); 3195 SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), 3196 O, Ops.size()); 3197 UniqueSCEVs.InsertNode(S, IP); 3198 return S; 3199 } 3200 3201 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, 3202 const SCEV *RHS) { 3203 SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; 3204 return getUMaxExpr(Ops); 3205 } 3206 3207 const SCEV * 3208 ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { 3209 assert(!Ops.empty() && "Cannot get empty umax!"); 3210 if (Ops.size() == 1) return Ops[0]; 3211 #ifndef NDEBUG 3212 Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); 3213 for (unsigned i = 1, e = Ops.size(); i != e; ++i) 3214 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && 3215 "SCEVUMaxExpr operand types don't match!"); 3216 #endif 3217 3218 // Sort by complexity, this groups all similar expression types together. 3219 GroupByComplexity(Ops, &LI); 3220 3221 // If there are any constants, fold them together. 3222 unsigned Idx = 0; 3223 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { 3224 ++Idx; 3225 assert(Idx < Ops.size()); 3226 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { 3227 // We found two constants, fold them together! 3228 ConstantInt *Fold = ConstantInt::get( 3229 getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt())); 3230 Ops[0] = getConstant(Fold); 3231 Ops.erase(Ops.begin()+1); // Erase the folded element 3232 if (Ops.size() == 1) return Ops[0]; 3233 LHSC = cast<SCEVConstant>(Ops[0]); 3234 } 3235 3236 // If we are left with a constant minimum-int, strip it off. 3237 if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) { 3238 Ops.erase(Ops.begin()); 3239 --Idx; 3240 } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) { 3241 // If we have an umax with a constant maximum-int, it will always be 3242 // maximum-int. 
3243 return Ops[0]; 3244 } 3245 3246 if (Ops.size() == 1) return Ops[0]; 3247 } 3248 3249 // Find the first UMax 3250 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) 3251 ++Idx; 3252 3253 // Check to see if one of the operands is a UMax. If so, expand its operands 3254 // onto our operand list, and recurse to simplify. 3255 if (Idx < Ops.size()) { 3256 bool DeletedUMax = false; 3257 while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) { 3258 Ops.erase(Ops.begin()+Idx); 3259 Ops.append(UMax->op_begin(), UMax->op_end()); 3260 DeletedUMax = true; 3261 } 3262 3263 if (DeletedUMax) 3264 return getUMaxExpr(Ops); 3265 } 3266 3267 // Okay, check to see if the same value occurs in the operand list twice. If 3268 // so, delete one. Since we sorted the list, these values are required to 3269 // be adjacent. 3270 for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) 3271 // X umax Y umax Y --> X umax Y 3272 // X umax Y --> X, if X is always greater than Y 3273 if (Ops[i] == Ops[i+1] || 3274 isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) { 3275 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2); 3276 --i; --e; 3277 } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) { 3278 Ops.erase(Ops.begin()+i, Ops.begin()+i+1); 3279 --i; --e; 3280 } 3281 3282 if (Ops.size() == 1) return Ops[0]; 3283 3284 assert(!Ops.empty() && "Reduced umax down to nothing!"); 3285 3286 // Okay, it looks like we really DO need a umax expr. Check to see if we 3287 // already have one, otherwise create a new one. 3288 FoldingSetNodeID ID; 3289 ID.AddInteger(scUMaxExpr); 3290 for (unsigned i = 0, e = Ops.size(); i != e; ++i) 3291 ID.AddPointer(Ops[i]); 3292 void *IP = nullptr; 3293 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; 3294 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); 3295 std::uninitialized_copy(Ops.begin(), Ops.end(), O); 3296 SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator), 3297 O, Ops.size()); 3298 UniqueSCEVs.InsertNode(S, IP); 3299 return S; 3300 } 3301 3302 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, 3303 const SCEV *RHS) { 3304 // ~smax(~x, ~y) == smin(x, y). 3305 return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); 3306 } 3307 3308 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, 3309 const SCEV *RHS) { 3310 // ~umax(~x, ~y) == umin(x, y) 3311 return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS))); 3312 } 3313 3314 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { 3315 // We can bypass creating a target-independent 3316 // constant expression and then folding it back into a ConstantInt. 3317 // This is just a compile-time optimization. 3318 return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); 3319 } 3320 3321 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, 3322 StructType *STy, 3323 unsigned FieldNo) { 3324 // We can bypass creating a target-independent 3325 // constant expression and then folding it back into a ConstantInt. 3326 // This is just a compile-time optimization. 3327 return getConstant( 3328 IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); 3329 } 3330 3331 const SCEV *ScalarEvolution::getUnknown(Value *V) { 3332 // Don't attempt to do anything other than create a SCEVUnknown object 3333 // here. 
createSCEV only calls getUnknown after checking for all other 3334 // interesting possibilities, and any other code that calls getUnknown 3335 // is doing so in order to hide a value from SCEV canonicalization. 3336 3337 FoldingSetNodeID ID; 3338 ID.AddInteger(scUnknown); 3339 ID.AddPointer(V); 3340 void *IP = nullptr; 3341 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) { 3342 assert(cast<SCEVUnknown>(S)->getValue() == V && 3343 "Stale SCEVUnknown in uniquing map!"); 3344 return S; 3345 } 3346 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this, 3347 FirstUnknown); 3348 FirstUnknown = cast<SCEVUnknown>(S); 3349 UniqueSCEVs.InsertNode(S, IP); 3350 return S; 3351 } 3352 3353 //===----------------------------------------------------------------------===// 3354 // Basic SCEV Analysis and PHI Idiom Recognition Code 3355 // 3356 3357 /// Test if values of the given type are analyzable within the SCEV 3358 /// framework. This primarily includes integer types, and it can optionally 3359 /// include pointer types if the ScalarEvolution class has access to 3360 /// target-specific information. 3361 bool ScalarEvolution::isSCEVable(Type *Ty) const { 3362 // Integers and pointers are always SCEVable. 3363 return Ty->isIntegerTy() || Ty->isPointerTy(); 3364 } 3365 3366 /// Return the size in bits of the specified type, for which isSCEVable must 3367 /// return true. 3368 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { 3369 assert(isSCEVable(Ty) && "Type is not SCEVable!"); 3370 return getDataLayout().getTypeSizeInBits(Ty); 3371 } 3372 3373 /// Return a type with the same bitwidth as the given type and which represents 3374 /// how SCEV will treat the given type, for which isSCEVable must return 3375 /// true. For pointer types, this is the pointer-sized integer type. 3376 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { 3377 assert(isSCEVable(Ty) && "Type is not SCEVable!"); 3378 3379 if (Ty->isIntegerTy()) 3380 return Ty; 3381 3382 // The only other support type is pointer. 3383 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); 3384 return getDataLayout().getIntPtrType(Ty); 3385 } 3386 3387 const SCEV *ScalarEvolution::getCouldNotCompute() { 3388 return CouldNotCompute.get(); 3389 } 3390 3391 bool ScalarEvolution::checkValidity(const SCEV *S) const { 3392 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) { 3393 auto *SU = dyn_cast<SCEVUnknown>(S); 3394 return SU && SU->getValue() == nullptr; 3395 }); 3396 3397 return !ContainsNulls; 3398 } 3399 3400 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { 3401 HasRecMapType::iterator I = HasRecMap.find(S); 3402 if (I != HasRecMap.end()) 3403 return I->second; 3404 3405 bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>); 3406 HasRecMap.insert({S, FoundAddRec}); 3407 return FoundAddRec; 3408 } 3409 3410 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}. 3411 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an 3412 /// offset I, then return {S', I}, else return {\p S, nullptr}. 
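/// For example (with a hypothetical value %x), a SCEV of the form (5 + %x)
/// (constant first, as produced by add-expression canonicalization) splits
/// into {%x, 5}; anything not matching that two-operand shape is returned
/// unchanged.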
3413 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) { 3414 const auto *Add = dyn_cast<SCEVAddExpr>(S); 3415 if (!Add) 3416 return {S, nullptr}; 3417 3418 if (Add->getNumOperands() != 2) 3419 return {S, nullptr}; 3420 3421 auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0)); 3422 if (!ConstOp) 3423 return {S, nullptr}; 3424 3425 return {Add->getOperand(1), ConstOp->getValue()}; 3426 } 3427 3428 /// Return the ValueOffsetPair set for \p S. \p S can be represented 3429 /// by the value and offset from any ValueOffsetPair in the set. 3430 SetVector<ScalarEvolution::ValueOffsetPair> * 3431 ScalarEvolution::getSCEVValues(const SCEV *S) { 3432 ExprValueMapType::iterator SI = ExprValueMap.find_as(S); 3433 if (SI == ExprValueMap.end()) 3434 return nullptr; 3435 #ifndef NDEBUG 3436 if (VerifySCEVMap) { 3437 // Check there is no dangling Value in the set returned. 3438 for (const auto &VE : SI->second) 3439 assert(ValueExprMap.count(VE.first)); 3440 } 3441 #endif 3442 return &SI->second; 3443 } 3444 3445 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V) 3446 /// cannot be used separately. eraseValueFromMap should be used to remove 3447 /// V from ValueExprMap and ExprValueMap at the same time. 3448 void ScalarEvolution::eraseValueFromMap(Value *V) { 3449 ValueExprMapType::iterator I = ValueExprMap.find_as(V); 3450 if (I != ValueExprMap.end()) { 3451 const SCEV *S = I->second; 3452 // Remove {V, 0} from the set of ExprValueMap[S] 3453 if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S)) 3454 SV->remove({V, nullptr}); 3455 3456 // Remove {V, Offset} from the set of ExprValueMap[Stripped] 3457 const SCEV *Stripped; 3458 ConstantInt *Offset; 3459 std::tie(Stripped, Offset) = splitAddExpr(S); 3460 if (Offset != nullptr) { 3461 if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped)) 3462 SV->remove({V, Offset}); 3463 } 3464 ValueExprMap.erase(V); 3465 } 3466 } 3467 3468 /// Return an existing SCEV if it exists, otherwise analyze the expression and 3469 /// create a new one. 3470 const SCEV *ScalarEvolution::getSCEV(Value *V) { 3471 assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); 3472 3473 const SCEV *S = getExistingSCEV(V); 3474 if (S == nullptr) { 3475 S = createSCEV(V); 3476 // During PHI resolution, it is possible to create two SCEVs for the same 3477 // V, so it is needed to double check whether V->S is inserted into 3478 // ValueExprMap before insert S->{V, 0} into ExprValueMap. 3479 std::pair<ValueExprMapType::iterator, bool> Pair = 3480 ValueExprMap.insert({SCEVCallbackVH(V, this), S}); 3481 if (Pair.second) { 3482 ExprValueMap[S].insert({V, nullptr}); 3483 3484 // If S == Stripped + Offset, add Stripped -> {V, Offset} into 3485 // ExprValueMap. 3486 const SCEV *Stripped = S; 3487 ConstantInt *Offset = nullptr; 3488 std::tie(Stripped, Offset) = splitAddExpr(S); 3489 // If stripped is SCEVUnknown, don't bother to save 3490 // Stripped -> {V, offset}. It doesn't simplify and sometimes even 3491 // increase the complexity of the expansion code. 3492 // If V is GetElementPtrInst, don't save Stripped -> {V, offset} 3493 // because it may generate add/sub instead of GEP in SCEV expansion. 
3494 if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) && 3495 !isa<GetElementPtrInst>(V)) 3496 ExprValueMap[Stripped].insert({V, Offset}); 3497 } 3498 } 3499 return S; 3500 } 3501 3502 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { 3503 assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); 3504 3505 ValueExprMapType::iterator I = ValueExprMap.find_as(V); 3506 if (I != ValueExprMap.end()) { 3507 const SCEV *S = I->second; 3508 if (checkValidity(S)) 3509 return S; 3510 eraseValueFromMap(V); 3511 forgetMemoizedResults(S); 3512 } 3513 return nullptr; 3514 } 3515 3516 /// Return a SCEV corresponding to -V = -1*V 3517 /// 3518 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, 3519 SCEV::NoWrapFlags Flags) { 3520 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) 3521 return getConstant( 3522 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue()))); 3523 3524 Type *Ty = V->getType(); 3525 Ty = getEffectiveSCEVType(Ty); 3526 return getMulExpr( 3527 V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags); 3528 } 3529 3530 /// Return a SCEV corresponding to ~V = -1-V 3531 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { 3532 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) 3533 return getConstant( 3534 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue()))); 3535 3536 Type *Ty = V->getType(); 3537 Ty = getEffectiveSCEVType(Ty); 3538 const SCEV *AllOnes = 3539 getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))); 3540 return getMinusSCEV(AllOnes, V); 3541 } 3542 3543 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, 3544 SCEV::NoWrapFlags Flags) { 3545 // Fast path: X - X --> 0. 3546 if (LHS == RHS) 3547 return getZero(LHS->getType()); 3548 3549 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation 3550 // makes it so that we cannot make much use of NUW. 3551 auto AddFlags = SCEV::FlagAnyWrap; 3552 const bool RHSIsNotMinSigned = 3553 !getSignedRange(RHS).getSignedMin().isMinSignedValue(); 3554 if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) { 3555 // Let M be the minimum representable signed value. Then (-1)*RHS 3556 // signed-wraps if and only if RHS is M. That can happen even for 3557 // a NSW subtraction because e.g. (-1)*M signed-wraps even though 3558 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS + 3559 // (-1)*RHS, we need to prove that RHS != M. 3560 // 3561 // If LHS is non-negative and we know that LHS - RHS does not 3562 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap 3563 // either by proving that RHS > M or that LHS >= 0. 3564 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) { 3565 AddFlags = SCEV::FlagNSW; 3566 } 3567 } 3568 3569 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS - 3570 // RHS is NSW and LHS >= 0. 3571 // 3572 // The difficulty here is that the NSW flag may have been proven 3573 // relative to a loop that is to be found in a recurrence in LHS and 3574 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a 3575 // larger scope than intended. 3576 auto NegFlags = RHSIsNotMinSigned ? 
SCEV::FlagNSW : SCEV::FlagAnyWrap; 3577 3578 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags); 3579 } 3580 3581 const SCEV * 3582 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { 3583 Type *SrcTy = V->getType(); 3584 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3585 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3586 "Cannot truncate or zero extend with non-integer arguments!"); 3587 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3588 return V; // No conversion 3589 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) 3590 return getTruncateExpr(V, Ty); 3591 return getZeroExtendExpr(V, Ty); 3592 } 3593 3594 const SCEV * 3595 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, 3596 Type *Ty) { 3597 Type *SrcTy = V->getType(); 3598 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3599 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3600 "Cannot truncate or zero extend with non-integer arguments!"); 3601 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3602 return V; // No conversion 3603 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) 3604 return getTruncateExpr(V, Ty); 3605 return getSignExtendExpr(V, Ty); 3606 } 3607 3608 const SCEV * 3609 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { 3610 Type *SrcTy = V->getType(); 3611 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3612 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3613 "Cannot noop or zero extend with non-integer arguments!"); 3614 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && 3615 "getNoopOrZeroExtend cannot truncate!"); 3616 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3617 return V; // No conversion 3618 return getZeroExtendExpr(V, Ty); 3619 } 3620 3621 const SCEV * 3622 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { 3623 Type *SrcTy = V->getType(); 3624 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3625 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3626 "Cannot noop or sign extend with non-integer arguments!"); 3627 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && 3628 "getNoopOrSignExtend cannot truncate!"); 3629 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3630 return V; // No conversion 3631 return getSignExtendExpr(V, Ty); 3632 } 3633 3634 const SCEV * 3635 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { 3636 Type *SrcTy = V->getType(); 3637 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3638 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3639 "Cannot noop or any extend with non-integer arguments!"); 3640 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && 3641 "getNoopOrAnyExtend cannot truncate!"); 3642 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3643 return V; // No conversion 3644 return getAnyExtendExpr(V, Ty); 3645 } 3646 3647 const SCEV * 3648 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { 3649 Type *SrcTy = V->getType(); 3650 assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && 3651 (Ty->isIntegerTy() || Ty->isPointerTy()) && 3652 "Cannot truncate or noop with non-integer arguments!"); 3653 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) && 3654 "getTruncateOrNoop cannot extend!"); 3655 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) 3656 return V; // No conversion 3657 return getTruncateExpr(V, Ty); 3658 } 3659 3660 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, 3661 const SCEV *RHS) { 3662 const SCEV *PromotedLHS = LHS; 3663 const SCEV *PromotedRHS = 
RHS; 3664 3665 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) 3666 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); 3667 else 3668 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); 3669 3670 return getUMaxExpr(PromotedLHS, PromotedRHS); 3671 } 3672 3673 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, 3674 const SCEV *RHS) { 3675 const SCEV *PromotedLHS = LHS; 3676 const SCEV *PromotedRHS = RHS; 3677 3678 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) 3679 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); 3680 else 3681 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType()); 3682 3683 return getUMinExpr(PromotedLHS, PromotedRHS); 3684 } 3685 3686 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { 3687 // A pointer operand may evaluate to a nonpointer expression, such as null. 3688 if (!V->getType()->isPointerTy()) 3689 return V; 3690 3691 if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) { 3692 return getPointerBase(Cast->getOperand()); 3693 } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) { 3694 const SCEV *PtrOp = nullptr; 3695 for (const SCEV *NAryOp : NAry->operands()) { 3696 if (NAryOp->getType()->isPointerTy()) { 3697 // Cannot find the base of an expression with multiple pointer operands. 3698 if (PtrOp) 3699 return V; 3700 PtrOp = NAryOp; 3701 } 3702 } 3703 if (!PtrOp) 3704 return V; 3705 return getPointerBase(PtrOp); 3706 } 3707 return V; 3708 } 3709 3710 /// Push users of the given Instruction onto the given Worklist. 3711 static void 3712 PushDefUseChildren(Instruction *I, 3713 SmallVectorImpl<Instruction *> &Worklist) { 3714 // Push the def-use children onto the Worklist stack. 3715 for (User *U : I->users()) 3716 Worklist.push_back(cast<Instruction>(U)); 3717 } 3718 3719 void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { 3720 SmallVector<Instruction *, 16> Worklist; 3721 PushDefUseChildren(PN, Worklist); 3722 3723 SmallPtrSet<Instruction *, 8> Visited; 3724 Visited.insert(PN); 3725 while (!Worklist.empty()) { 3726 Instruction *I = Worklist.pop_back_val(); 3727 if (!Visited.insert(I).second) 3728 continue; 3729 3730 auto It = ValueExprMap.find_as(static_cast<Value *>(I)); 3731 if (It != ValueExprMap.end()) { 3732 const SCEV *Old = It->second; 3733 3734 // Short-circuit the def-use traversal if the symbolic name 3735 // ceases to appear in expressions. 3736 if (Old != SymName && !hasOperand(Old, SymName)) 3737 continue; 3738 3739 // SCEVUnknown for a PHI either means that it has an unrecognized 3740 // structure, it's a PHI that's in the progress of being computed 3741 // by createNodeForPHI, or it's a single-value PHI. In the first case, 3742 // additional loop trip count information isn't going to change anything. 3743 // In the second case, createNodeForPHI will perform the necessary 3744 // updates on its own when it gets to that point. In the third, we do 3745 // want to forget the SCEVUnknown. 
3746 if (!isa<PHINode>(I) || 3747 !isa<SCEVUnknown>(Old) || 3748 (I != PN && Old == SymName)) { 3749 eraseValueFromMap(It->first); 3750 forgetMemoizedResults(Old); 3751 } 3752 } 3753 3754 PushDefUseChildren(I, Worklist); 3755 } 3756 } 3757 3758 namespace { 3759 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> { 3760 public: 3761 static const SCEV *rewrite(const SCEV *S, const Loop *L, 3762 ScalarEvolution &SE) { 3763 SCEVInitRewriter Rewriter(L, SE); 3764 const SCEV *Result = Rewriter.visit(S); 3765 return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); 3766 } 3767 3768 SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) 3769 : SCEVRewriteVisitor(SE), L(L), Valid(true) {} 3770 3771 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 3772 if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) 3773 Valid = false; 3774 return Expr; 3775 } 3776 3777 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 3778 // Only allow AddRecExprs for this loop. 3779 if (Expr->getLoop() == L) 3780 return Expr->getStart(); 3781 Valid = false; 3782 return Expr; 3783 } 3784 3785 bool isValid() { return Valid; } 3786 3787 private: 3788 const Loop *L; 3789 bool Valid; 3790 }; 3791 3792 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> { 3793 public: 3794 static const SCEV *rewrite(const SCEV *S, const Loop *L, 3795 ScalarEvolution &SE) { 3796 SCEVShiftRewriter Rewriter(L, SE); 3797 const SCEV *Result = Rewriter.visit(S); 3798 return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); 3799 } 3800 3801 SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) 3802 : SCEVRewriteVisitor(SE), L(L), Valid(true) {} 3803 3804 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 3805 // Only allow AddRecExprs for this loop. 3806 if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) 3807 Valid = false; 3808 return Expr; 3809 } 3810 3811 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 3812 if (Expr->getLoop() == L && Expr->isAffine()) 3813 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); 3814 Valid = false; 3815 return Expr; 3816 } 3817 bool isValid() { return Valid; } 3818 3819 private: 3820 const Loop *L; 3821 bool Valid; 3822 }; 3823 } // end anonymous namespace 3824 3825 SCEV::NoWrapFlags 3826 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { 3827 if (!AR->isAffine()) 3828 return SCEV::FlagAnyWrap; 3829 3830 typedef OverflowingBinaryOperator OBO; 3831 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; 3832 3833 if (!AR->hasNoSignedWrap()) { 3834 ConstantRange AddRecRange = getSignedRange(AR); 3835 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this)); 3836 3837 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( 3838 Instruction::Add, IncRange, OBO::NoSignedWrap); 3839 if (NSWRegion.contains(AddRecRange)) 3840 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW); 3841 } 3842 3843 if (!AR->hasNoUnsignedWrap()) { 3844 ConstantRange AddRecRange = getUnsignedRange(AR); 3845 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this)); 3846 3847 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( 3848 Instruction::Add, IncRange, OBO::NoUnsignedWrap); 3849 if (NUWRegion.contains(AddRecRange)) 3850 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW); 3851 } 3852 3853 return Result; 3854 } 3855 3856 namespace { 3857 /// Represents an abstract binary operation. 
This may exist as a 3858 /// normal instruction or constant expression, or may have been 3859 /// derived from an expression tree. 3860 struct BinaryOp { 3861 unsigned Opcode; 3862 Value *LHS; 3863 Value *RHS; 3864 bool IsNSW; 3865 bool IsNUW; 3866 3867 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or 3868 /// constant expression. 3869 Operator *Op; 3870 3871 explicit BinaryOp(Operator *Op) 3872 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)), 3873 IsNSW(false), IsNUW(false), Op(Op) { 3874 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) { 3875 IsNSW = OBO->hasNoSignedWrap(); 3876 IsNUW = OBO->hasNoUnsignedWrap(); 3877 } 3878 } 3879 3880 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, 3881 bool IsNUW = false) 3882 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW), 3883 Op(nullptr) {} 3884 }; 3885 } 3886 3887 3888 /// Try to map \p V into a BinaryOp, and return \c None on failure. 3889 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { 3890 auto *Op = dyn_cast<Operator>(V); 3891 if (!Op) 3892 return None; 3893 3894 // Implementation detail: all the cleverness here should happen without 3895 // creating new SCEV expressions -- our caller knowns tricks to avoid creating 3896 // SCEV expressions when possible, and we should not break that. 3897 3898 switch (Op->getOpcode()) { 3899 case Instruction::Add: 3900 case Instruction::Sub: 3901 case Instruction::Mul: 3902 case Instruction::UDiv: 3903 case Instruction::And: 3904 case Instruction::Or: 3905 case Instruction::AShr: 3906 case Instruction::Shl: 3907 return BinaryOp(Op); 3908 3909 case Instruction::Xor: 3910 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1))) 3911 // If the RHS of the xor is a signbit, then this is just an add. 3912 // Instcombine turns add of signbit into xor as a strength reduction step. 3913 if (RHSC->getValue().isSignBit()) 3914 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); 3915 return BinaryOp(Op); 3916 3917 case Instruction::LShr: 3918 // Turn logical shift right of a constant into a unsigned divide. 3919 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) { 3920 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth(); 3921 3922 // If the shift count is not less than the bitwidth, the result of 3923 // the shift is undefined. Don't try to analyze it, because the 3924 // resolution chosen here may differ from the resolution chosen in 3925 // other parts of the compiler. 
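      // For example (hypothetical operand %x, 32-bit type): %x lshr 3 is
      // mapped to the binary op %x /u 8, since 8 == 1 << 3.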
3926 if (SA->getValue().ult(BitWidth)) { 3927 Constant *X = 3928 ConstantInt::get(SA->getContext(), 3929 APInt::getOneBitSet(BitWidth, SA->getZExtValue())); 3930 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X); 3931 } 3932 } 3933 return BinaryOp(Op); 3934 3935 case Instruction::ExtractValue: { 3936 auto *EVI = cast<ExtractValueInst>(Op); 3937 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0) 3938 break; 3939 3940 auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand()); 3941 if (!CI) 3942 break; 3943 3944 if (auto *F = CI->getCalledFunction()) 3945 switch (F->getIntrinsicID()) { 3946 case Intrinsic::sadd_with_overflow: 3947 case Intrinsic::uadd_with_overflow: { 3948 if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT)) 3949 return BinaryOp(Instruction::Add, CI->getArgOperand(0), 3950 CI->getArgOperand(1)); 3951 3952 // Now that we know that all uses of the arithmetic-result component of 3953 // CI are guarded by the overflow check, we can go ahead and pretend 3954 // that the arithmetic is non-overflowing. 3955 if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow) 3956 return BinaryOp(Instruction::Add, CI->getArgOperand(0), 3957 CI->getArgOperand(1), /* IsNSW = */ true, 3958 /* IsNUW = */ false); 3959 else 3960 return BinaryOp(Instruction::Add, CI->getArgOperand(0), 3961 CI->getArgOperand(1), /* IsNSW = */ false, 3962 /* IsNUW*/ true); 3963 } 3964 3965 case Intrinsic::ssub_with_overflow: 3966 case Intrinsic::usub_with_overflow: 3967 return BinaryOp(Instruction::Sub, CI->getArgOperand(0), 3968 CI->getArgOperand(1)); 3969 3970 case Intrinsic::smul_with_overflow: 3971 case Intrinsic::umul_with_overflow: 3972 return BinaryOp(Instruction::Mul, CI->getArgOperand(0), 3973 CI->getArgOperand(1)); 3974 default: 3975 break; 3976 } 3977 } 3978 3979 default: 3980 break; 3981 } 3982 3983 return None; 3984 } 3985 3986 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { 3987 const Loop *L = LI.getLoopFor(PN->getParent()); 3988 if (!L || L->getHeader() != PN->getParent()) 3989 return nullptr; 3990 3991 // The loop may have multiple entrances or multiple exits; we can analyze 3992 // this phi as an addrec if it has a unique entry value and a unique 3993 // backedge value. 3994 Value *BEValueV = nullptr, *StartValueV = nullptr; 3995 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { 3996 Value *V = PN->getIncomingValue(i); 3997 if (L->contains(PN->getIncomingBlock(i))) { 3998 if (!BEValueV) { 3999 BEValueV = V; 4000 } else if (BEValueV != V) { 4001 BEValueV = nullptr; 4002 break; 4003 } 4004 } else if (!StartValueV) { 4005 StartValueV = V; 4006 } else if (StartValueV != V) { 4007 StartValueV = nullptr; 4008 break; 4009 } 4010 } 4011 if (BEValueV && StartValueV) { 4012 // While we are analyzing this PHI node, handle its value symbolically. 4013 const SCEV *SymbolicName = getUnknown(PN); 4014 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && 4015 "PHI node already processed?"); 4016 ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); 4017 4018 // Using this symbolic name for the PHI, analyze the value coming around 4019 // the back-edge. 4020 const SCEV *BEValue = getSCEV(BEValueV); 4021 4022 // NOTE: If BEValue is loop invariant, we know that the PHI node just 4023 // has a special value for the first iteration of the loop. 4024 4025 // If the value coming around the backedge is an add with the symbolic 4026 // value we just inserted, then we found a simple induction variable! 
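    // For example, for a (hypothetical) canonical induction variable
    //   %iv = phi i64 [ 0, %preheader ], [ %iv.next, %latch ]
    //   %iv.next = add nuw nsw i64 %iv, 1
    // BEValue is (1 + %iv), the symbolic name occurs exactly once, the
    // remaining operand 1 is the loop-invariant step, and the PHI becomes
    // {0,+,1}<%loop> (plus whatever no-wrap flags can be transferred below).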
4027 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) { 4028 // If there is a single occurrence of the symbolic value, replace it 4029 // with a recurrence. 4030 unsigned FoundIndex = Add->getNumOperands(); 4031 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) 4032 if (Add->getOperand(i) == SymbolicName) 4033 if (FoundIndex == e) { 4034 FoundIndex = i; 4035 break; 4036 } 4037 4038 if (FoundIndex != Add->getNumOperands()) { 4039 // Create an add with everything but the specified operand. 4040 SmallVector<const SCEV *, 8> Ops; 4041 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) 4042 if (i != FoundIndex) 4043 Ops.push_back(Add->getOperand(i)); 4044 const SCEV *Accum = getAddExpr(Ops); 4045 4046 // This is not a valid addrec if the step amount is varying each 4047 // loop iteration, but is not itself an addrec in this loop. 4048 if (isLoopInvariant(Accum, L) || 4049 (isa<SCEVAddRecExpr>(Accum) && 4050 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { 4051 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; 4052 4053 if (auto BO = MatchBinaryOp(BEValueV, DT)) { 4054 if (BO->Opcode == Instruction::Add && BO->LHS == PN) { 4055 if (BO->IsNUW) 4056 Flags = setFlags(Flags, SCEV::FlagNUW); 4057 if (BO->IsNSW) 4058 Flags = setFlags(Flags, SCEV::FlagNSW); 4059 } 4060 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { 4061 // If the increment is an inbounds GEP, then we know the address 4062 // space cannot be wrapped around. We cannot make any guarantee 4063 // about signed or unsigned overflow because pointers are 4064 // unsigned but we may have a negative index from the base 4065 // pointer. We can guarantee that no unsigned wrap occurs if the 4066 // indices form a positive value. 4067 if (GEP->isInBounds() && GEP->getOperand(0) == PN) { 4068 Flags = setFlags(Flags, SCEV::FlagNW); 4069 4070 const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); 4071 if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) 4072 Flags = setFlags(Flags, SCEV::FlagNUW); 4073 } 4074 4075 // We cannot transfer nuw and nsw flags from subtraction 4076 // operations -- sub nuw X, Y is not the same as add nuw X, -Y 4077 // for instance. 4078 } 4079 4080 const SCEV *StartVal = getSCEV(StartValueV); 4081 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); 4082 4083 // Okay, for the entire analysis of this edge we assumed the PHI 4084 // to be symbolic. We now need to go back and purge all of the 4085 // entries for the scalars that use the symbolic expression. 4086 forgetSymbolicName(PN, SymbolicName); 4087 ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; 4088 4089 // We can add Flags to the post-inc expression only if we 4090 // know that it us *undefined behavior* for BEValueV to 4091 // overflow. 4092 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) 4093 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) 4094 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); 4095 4096 return PHISCEV; 4097 } 4098 } 4099 } else { 4100 // Otherwise, this could be a loop like this: 4101 // i = 0; for (j = 1; ..; ++j) { .... i = j; } 4102 // In this case, j = {1,+,1} and BEValue is j. 4103 // Because the other in-value of i (0) fits the evolution of BEValue 4104 // i really is an addrec evolution. 
4105 // 4106 // We can generalize this saying that i is the shifted value of BEValue 4107 // by one iteration: 4108 // PHI(f(0), f({1,+,1})) --> f({0,+,1}) 4109 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); 4110 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); 4111 if (Shifted != getCouldNotCompute() && 4112 Start != getCouldNotCompute()) { 4113 const SCEV *StartVal = getSCEV(StartValueV); 4114 if (Start == StartVal) { 4115 // Okay, for the entire analysis of this edge we assumed the PHI 4116 // to be symbolic. We now need to go back and purge all of the 4117 // entries for the scalars that use the symbolic expression. 4118 forgetSymbolicName(PN, SymbolicName); 4119 ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; 4120 return Shifted; 4121 } 4122 } 4123 } 4124 4125 // Remove the temporary PHI node SCEV that has been inserted while intending 4126 // to create an AddRecExpr for this PHI node. We can not keep this temporary 4127 // as it will prevent later (possibly simpler) SCEV expressions to be added 4128 // to the ValueExprMap. 4129 eraseValueFromMap(PN); 4130 } 4131 4132 return nullptr; 4133 } 4134 4135 // Checks if the SCEV S is available at BB. S is considered available at BB 4136 // if S can be materialized at BB without introducing a fault. 4137 static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, 4138 BasicBlock *BB) { 4139 struct CheckAvailable { 4140 bool TraversalDone = false; 4141 bool Available = true; 4142 4143 const Loop *L = nullptr; // The loop BB is in (can be nullptr) 4144 BasicBlock *BB = nullptr; 4145 DominatorTree &DT; 4146 4147 CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT) 4148 : L(L), BB(BB), DT(DT) {} 4149 4150 bool setUnavailable() { 4151 TraversalDone = true; 4152 Available = false; 4153 return false; 4154 } 4155 4156 bool follow(const SCEV *S) { 4157 switch (S->getSCEVType()) { 4158 case scConstant: case scTruncate: case scZeroExtend: case scSignExtend: 4159 case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: 4160 // These expressions are available if their operand(s) is/are. 4161 return true; 4162 4163 case scAddRecExpr: { 4164 // We allow add recurrences that are on the loop BB is in, or some 4165 // outer loop. This guarantees availability because the value of the 4166 // add recurrence at BB is simply the "current" value of the induction 4167 // variable. We can relax this in the future; for instance an add 4168 // recurrence on a sibling dominating loop is also available at BB. 4169 const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop(); 4170 if (L && (ARLoop == L || ARLoop->contains(L))) 4171 return true; 4172 4173 return setUnavailable(); 4174 } 4175 4176 case scUnknown: { 4177 // For SCEVUnknown, we check for simple dominance. 4178 const auto *SU = cast<SCEVUnknown>(S); 4179 Value *V = SU->getValue(); 4180 4181 if (isa<Argument>(V)) 4182 return false; 4183 4184 if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB)) 4185 return false; 4186 4187 return setUnavailable(); 4188 } 4189 4190 case scUDivExpr: 4191 case scCouldNotCompute: 4192 // We do not try to smart about these at all. 
4193 return setUnavailable(); 4194 } 4195 llvm_unreachable("switch should be fully covered!"); 4196 } 4197 4198 bool isDone() { return TraversalDone; } 4199 }; 4200 4201 CheckAvailable CA(L, BB, DT); 4202 SCEVTraversal<CheckAvailable> ST(CA); 4203 4204 ST.visitAll(S); 4205 return CA.Available; 4206 } 4207 4208 // Try to match a control flow sequence that branches out at BI and merges back 4209 // at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful 4210 // match. 4211 static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, 4212 Value *&C, Value *&LHS, Value *&RHS) { 4213 C = BI->getCondition(); 4214 4215 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0)); 4216 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1)); 4217 4218 if (!LeftEdge.isSingleEdge()) 4219 return false; 4220 4221 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()"); 4222 4223 Use &LeftUse = Merge->getOperandUse(0); 4224 Use &RightUse = Merge->getOperandUse(1); 4225 4226 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) { 4227 LHS = LeftUse; 4228 RHS = RightUse; 4229 return true; 4230 } 4231 4232 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) { 4233 LHS = RightUse; 4234 RHS = LeftUse; 4235 return true; 4236 } 4237 4238 return false; 4239 } 4240 4241 const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { 4242 auto IsReachable = 4243 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); }; 4244 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) { 4245 const Loop *L = LI.getLoopFor(PN->getParent()); 4246 4247 // We don't want to break LCSSA, even in a SCEV expression tree. 4248 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 4249 if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) 4250 return nullptr; 4251 4252 // Try to match 4253 // 4254 // br %cond, label %left, label %right 4255 // left: 4256 // br label %merge 4257 // right: 4258 // br label %merge 4259 // merge: 4260 // V = phi [ %x, %left ], [ %y, %right ] 4261 // 4262 // as "select %cond, %x, %y" 4263 4264 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock(); 4265 assert(IDom && "At least the entry block should dominate PN"); 4266 4267 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator()); 4268 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr; 4269 4270 if (BI && BI->isConditional() && 4271 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) && 4272 IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) && 4273 IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent())) 4274 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS); 4275 } 4276 4277 return nullptr; 4278 } 4279 4280 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { 4281 if (const SCEV *S = createAddRecFromPHI(PN)) 4282 return S; 4283 4284 if (const SCEV *S = createNodeFromSelectLikePHI(PN)) 4285 return S; 4286 4287 // If the PHI has a single incoming value, follow that value, unless the 4288 // PHI's incoming blocks are in a different loop, in which case doing so 4289 // risks breaking LCSSA form. Instcombine would normally zap these, but 4290 // it doesn't have DominatorTree information, so it may miss cases. 4291 if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC)) 4292 if (LI.replacementPreservesLCSSAForm(PN, V)) 4293 return getSCEV(V); 4294 4295 // If it's not a loop phi, we can't handle it yet. 
4296 return getUnknown(PN); 4297 } 4298 4299 const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, 4300 Value *Cond, 4301 Value *TrueVal, 4302 Value *FalseVal) { 4303 // Handle "constant" branch or select. This can occur for instance when a 4304 // loop pass transforms an inner loop and moves on to process the outer loop. 4305 if (auto *CI = dyn_cast<ConstantInt>(Cond)) 4306 return getSCEV(CI->isOne() ? TrueVal : FalseVal); 4307 4308 // Try to match some simple smax or umax patterns. 4309 auto *ICI = dyn_cast<ICmpInst>(Cond); 4310 if (!ICI) 4311 return getUnknown(I); 4312 4313 Value *LHS = ICI->getOperand(0); 4314 Value *RHS = ICI->getOperand(1); 4315 4316 switch (ICI->getPredicate()) { 4317 case ICmpInst::ICMP_SLT: 4318 case ICmpInst::ICMP_SLE: 4319 std::swap(LHS, RHS); 4320 LLVM_FALLTHROUGH; 4321 case ICmpInst::ICMP_SGT: 4322 case ICmpInst::ICMP_SGE: 4323 // a >s b ? a+x : b+x -> smax(a, b)+x 4324 // a >s b ? b+x : a+x -> smin(a, b)+x 4325 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { 4326 const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType()); 4327 const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType()); 4328 const SCEV *LA = getSCEV(TrueVal); 4329 const SCEV *RA = getSCEV(FalseVal); 4330 const SCEV *LDiff = getMinusSCEV(LA, LS); 4331 const SCEV *RDiff = getMinusSCEV(RA, RS); 4332 if (LDiff == RDiff) 4333 return getAddExpr(getSMaxExpr(LS, RS), LDiff); 4334 LDiff = getMinusSCEV(LA, RS); 4335 RDiff = getMinusSCEV(RA, LS); 4336 if (LDiff == RDiff) 4337 return getAddExpr(getSMinExpr(LS, RS), LDiff); 4338 } 4339 break; 4340 case ICmpInst::ICMP_ULT: 4341 case ICmpInst::ICMP_ULE: 4342 std::swap(LHS, RHS); 4343 LLVM_FALLTHROUGH; 4344 case ICmpInst::ICMP_UGT: 4345 case ICmpInst::ICMP_UGE: 4346 // a >u b ? a+x : b+x -> umax(a, b)+x 4347 // a >u b ? b+x : a+x -> umin(a, b)+x 4348 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { 4349 const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); 4350 const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType()); 4351 const SCEV *LA = getSCEV(TrueVal); 4352 const SCEV *RA = getSCEV(FalseVal); 4353 const SCEV *LDiff = getMinusSCEV(LA, LS); 4354 const SCEV *RDiff = getMinusSCEV(RA, RS); 4355 if (LDiff == RDiff) 4356 return getAddExpr(getUMaxExpr(LS, RS), LDiff); 4357 LDiff = getMinusSCEV(LA, RS); 4358 RDiff = getMinusSCEV(RA, LS); 4359 if (LDiff == RDiff) 4360 return getAddExpr(getUMinExpr(LS, RS), LDiff); 4361 } 4362 break; 4363 case ICmpInst::ICMP_NE: 4364 // n != 0 ? n+x : 1+x -> umax(n, 1)+x 4365 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && 4366 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { 4367 const SCEV *One = getOne(I->getType()); 4368 const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); 4369 const SCEV *LA = getSCEV(TrueVal); 4370 const SCEV *RA = getSCEV(FalseVal); 4371 const SCEV *LDiff = getMinusSCEV(LA, LS); 4372 const SCEV *RDiff = getMinusSCEV(RA, One); 4373 if (LDiff == RDiff) 4374 return getAddExpr(getUMaxExpr(One, LS), LDiff); 4375 } 4376 break; 4377 case ICmpInst::ICMP_EQ: 4378 // n == 0 ? 
1+x : n+x -> umax(n, 1)+x 4379 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && 4380 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { 4381 const SCEV *One = getOne(I->getType()); 4382 const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); 4383 const SCEV *LA = getSCEV(TrueVal); 4384 const SCEV *RA = getSCEV(FalseVal); 4385 const SCEV *LDiff = getMinusSCEV(LA, One); 4386 const SCEV *RDiff = getMinusSCEV(RA, LS); 4387 if (LDiff == RDiff) 4388 return getAddExpr(getUMaxExpr(One, LS), LDiff); 4389 } 4390 break; 4391 default: 4392 break; 4393 } 4394 4395 return getUnknown(I); 4396 } 4397 4398 /// Expand GEP instructions into add and multiply operations. This allows them 4399 /// to be analyzed by regular SCEV code. 4400 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { 4401 // Don't attempt to analyze GEPs over unsized objects. 4402 if (!GEP->getSourceElementType()->isSized()) 4403 return getUnknown(GEP); 4404 4405 SmallVector<const SCEV *, 4> IndexExprs; 4406 for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) 4407 IndexExprs.push_back(getSCEV(*Index)); 4408 return getGEPExpr(GEP, IndexExprs); 4409 } 4410 4411 uint32_t 4412 ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { 4413 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) 4414 return C->getAPInt().countTrailingZeros(); 4415 4416 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S)) 4417 return std::min(GetMinTrailingZeros(T->getOperand()), 4418 (uint32_t)getTypeSizeInBits(T->getType())); 4419 4420 if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { 4421 uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); 4422 return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? 4423 getTypeSizeInBits(E->getType()) : OpRes; 4424 } 4425 4426 if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { 4427 uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); 4428 return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ? 4429 getTypeSizeInBits(E->getType()) : OpRes; 4430 } 4431 4432 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { 4433 // The result is the min of all operands results. 4434 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); 4435 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) 4436 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); 4437 return MinOpRes; 4438 } 4439 4440 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { 4441 // The result is the sum of all operands results. 4442 uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); 4443 uint32_t BitWidth = getTypeSizeInBits(M->getType()); 4444 for (unsigned i = 1, e = M->getNumOperands(); 4445 SumOpRes != BitWidth && i != e; ++i) 4446 SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), 4447 BitWidth); 4448 return SumOpRes; 4449 } 4450 4451 if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { 4452 // The result is the min of all operands results. 4453 uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); 4454 for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) 4455 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); 4456 return MinOpRes; 4457 } 4458 4459 if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) { 4460 // The result is the min of all operands results. 
4461 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); 4462 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) 4463 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); 4464 return MinOpRes; 4465 } 4466 4467 if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) { 4468 // The result is the min of all operands results. 4469 uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); 4470 for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) 4471 MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); 4472 return MinOpRes; 4473 } 4474 4475 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { 4476 // For a SCEVUnknown, ask ValueTracking. 4477 unsigned BitWidth = getTypeSizeInBits(U->getType()); 4478 APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); 4479 computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC, 4480 nullptr, &DT); 4481 return Zeros.countTrailingOnes(); 4482 } 4483 4484 // SCEVUDivExpr 4485 return 0; 4486 } 4487 4488 /// Helper method to assign a range to V from metadata present in the IR. 4489 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { 4490 if (Instruction *I = dyn_cast<Instruction>(V)) 4491 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) 4492 return getConstantRangeFromMetadata(*MD); 4493 4494 return None; 4495 } 4496 4497 /// Determine the range for a particular SCEV. If SignHint is 4498 /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges 4499 /// with a "cleaner" unsigned (resp. signed) representation. 4500 ConstantRange 4501 ScalarEvolution::getRange(const SCEV *S, 4502 ScalarEvolution::RangeSignHint SignHint) { 4503 DenseMap<const SCEV *, ConstantRange> &Cache = 4504 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges 4505 : SignedRanges; 4506 4507 // See if we've computed this range already. 4508 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S); 4509 if (I != Cache.end()) 4510 return I->second; 4511 4512 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) 4513 return setRange(C, SignHint, ConstantRange(C->getAPInt())); 4514 4515 unsigned BitWidth = getTypeSizeInBits(S->getType()); 4516 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); 4517 4518 // If the value has known zeros, the maximum value will have those known zeros 4519 // as well. 
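  // For example, if S is known to have at least two trailing zero bits, the
  // unsigned bound below is tightened so its maximum is a multiple of 4 (and
  // the signed bound analogously).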
4520 uint32_t TZ = GetMinTrailingZeros(S); 4521 if (TZ != 0) { 4522 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) 4523 ConservativeResult = 4524 ConstantRange(APInt::getMinValue(BitWidth), 4525 APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); 4526 else 4527 ConservativeResult = ConstantRange( 4528 APInt::getSignedMinValue(BitWidth), 4529 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); 4530 } 4531 4532 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { 4533 ConstantRange X = getRange(Add->getOperand(0), SignHint); 4534 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) 4535 X = X.add(getRange(Add->getOperand(i), SignHint)); 4536 return setRange(Add, SignHint, ConservativeResult.intersectWith(X)); 4537 } 4538 4539 if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { 4540 ConstantRange X = getRange(Mul->getOperand(0), SignHint); 4541 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) 4542 X = X.multiply(getRange(Mul->getOperand(i), SignHint)); 4543 return setRange(Mul, SignHint, ConservativeResult.intersectWith(X)); 4544 } 4545 4546 if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { 4547 ConstantRange X = getRange(SMax->getOperand(0), SignHint); 4548 for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) 4549 X = X.smax(getRange(SMax->getOperand(i), SignHint)); 4550 return setRange(SMax, SignHint, ConservativeResult.intersectWith(X)); 4551 } 4552 4553 if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { 4554 ConstantRange X = getRange(UMax->getOperand(0), SignHint); 4555 for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) 4556 X = X.umax(getRange(UMax->getOperand(i), SignHint)); 4557 return setRange(UMax, SignHint, ConservativeResult.intersectWith(X)); 4558 } 4559 4560 if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { 4561 ConstantRange X = getRange(UDiv->getLHS(), SignHint); 4562 ConstantRange Y = getRange(UDiv->getRHS(), SignHint); 4563 return setRange(UDiv, SignHint, 4564 ConservativeResult.intersectWith(X.udiv(Y))); 4565 } 4566 4567 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { 4568 ConstantRange X = getRange(ZExt->getOperand(), SignHint); 4569 return setRange(ZExt, SignHint, 4570 ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); 4571 } 4572 4573 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { 4574 ConstantRange X = getRange(SExt->getOperand(), SignHint); 4575 return setRange(SExt, SignHint, 4576 ConservativeResult.intersectWith(X.signExtend(BitWidth))); 4577 } 4578 4579 if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { 4580 ConstantRange X = getRange(Trunc->getOperand(), SignHint); 4581 return setRange(Trunc, SignHint, 4582 ConservativeResult.intersectWith(X.truncate(BitWidth))); 4583 } 4584 4585 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 4586 // If there's no unsigned wrap, the value will never be less than its 4587 // initial value. 4588 if (AddRec->hasNoUnsignedWrap()) 4589 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart())) 4590 if (!C->getValue()->isZero()) 4591 ConservativeResult = ConservativeResult.intersectWith( 4592 ConstantRange(C->getAPInt(), APInt(BitWidth, 0))); 4593 4594 // If there's no signed wrap, and all the operands have the same sign or 4595 // zero, the value won't ever change sign. 
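// For example (purely illustrative), {1,+,2}<nsw> has a non-negative start
// and step, so every value it takes stays in the non-negative half of the
// signed range.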
4596 if (AddRec->hasNoSignedWrap()) { 4597 bool AllNonNeg = true; 4598 bool AllNonPos = true; 4599 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { 4600 if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false; 4601 if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false; 4602 } 4603 if (AllNonNeg) 4604 ConservativeResult = ConservativeResult.intersectWith( 4605 ConstantRange(APInt(BitWidth, 0), 4606 APInt::getSignedMinValue(BitWidth))); 4607 else if (AllNonPos) 4608 ConservativeResult = ConservativeResult.intersectWith( 4609 ConstantRange(APInt::getSignedMinValue(BitWidth), 4610 APInt(BitWidth, 1))); 4611 } 4612 4613 // TODO: non-affine addrec 4614 if (AddRec->isAffine()) { 4615 const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); 4616 if (!isa<SCEVCouldNotCompute>(MaxBECount) && 4617 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { 4618 auto RangeFromAffine = getRangeForAffineAR( 4619 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, 4620 BitWidth); 4621 if (!RangeFromAffine.isFullSet()) 4622 ConservativeResult = 4623 ConservativeResult.intersectWith(RangeFromAffine); 4624 4625 auto RangeFromFactoring = getRangeViaFactoring( 4626 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, 4627 BitWidth); 4628 if (!RangeFromFactoring.isFullSet()) 4629 ConservativeResult = 4630 ConservativeResult.intersectWith(RangeFromFactoring); 4631 } 4632 } 4633 4634 return setRange(AddRec, SignHint, ConservativeResult); 4635 } 4636 4637 if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { 4638 // Check if the IR explicitly contains !range metadata. 4639 Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); 4640 if (MDRange.hasValue()) 4641 ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); 4642 4643 // Split here to avoid paying the compile-time cost of calling both 4644 // computeKnownBits and ComputeNumSignBits. This restriction can be lifted 4645 // if needed. 4646 const DataLayout &DL = getDataLayout(); 4647 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { 4648 // For a SCEVUnknown, ask ValueTracking. 4649 APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); 4650 computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, &AC, nullptr, &DT); 4651 if (Ones != ~Zeros + 1) 4652 ConservativeResult = 4653 ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)); 4654 } else { 4655 assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED && 4656 "generalize as needed!"); 4657 unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT); 4658 if (NS > 1) 4659 ConservativeResult = ConservativeResult.intersectWith( 4660 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), 4661 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); 4662 } 4663 4664 return setRange(U, SignHint, ConservativeResult); 4665 } 4666 4667 return setRange(S, SignHint, ConservativeResult); 4668 } 4669 4670 ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, 4671 const SCEV *Step, 4672 const SCEV *MaxBECount, 4673 unsigned BitWidth) { 4674 assert(!isa<SCEVCouldNotCompute>(MaxBECount) && 4675 getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && 4676 "Precondition!"); 4677 4678 ConstantRange Result(BitWidth, /* isFullSet = */ true); 4679 4680 // Check for overflow. This must be done with ConstantRange arithmetic 4681 // because we could be called from within the ScalarEvolution overflow 4682 // checking code. 
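// Informal sketch of the computation below: End = Start + MaxBECount * Step
// is evaluated both at the original bit width and with the operands extended
// to 2 * BitWidth; if the two results agree, the narrow computation cannot
// have wrapped, and the min/max of Start and End bound the recurrence.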
4683 4684 MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); 4685 ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); 4686 ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2); 4687 4688 ConstantRange StepSRange = getSignedRange(Step); 4689 ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2); 4690 4691 ConstantRange StartURange = getUnsignedRange(Start); 4692 ConstantRange EndURange = 4693 StartURange.add(MaxBECountRange.multiply(StepSRange)); 4694 4695 // Check for unsigned overflow. 4696 ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2); 4697 ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2); 4698 if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == 4699 ZExtEndURange) { 4700 APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), 4701 EndURange.getUnsignedMin()); 4702 APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), 4703 EndURange.getUnsignedMax()); 4704 bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); 4705 if (!IsFullRange) 4706 Result = 4707 Result.intersectWith(ConstantRange(Min, Max + 1)); 4708 } 4709 4710 ConstantRange StartSRange = getSignedRange(Start); 4711 ConstantRange EndSRange = 4712 StartSRange.add(MaxBECountRange.multiply(StepSRange)); 4713 4714 // Check for signed overflow. This must be done with ConstantRange 4715 // arithmetic because we could be called from within the ScalarEvolution 4716 // overflow checking code. 4717 ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2); 4718 ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2); 4719 if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == 4720 SExtEndSRange) { 4721 APInt Min = 4722 APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin()); 4723 APInt Max = 4724 APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax()); 4725 bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); 4726 if (!IsFullRange) 4727 Result = 4728 Result.intersectWith(ConstantRange(Min, Max + 1)); 4729 } 4730 4731 return Result; 4732 } 4733 4734 ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, 4735 const SCEV *Step, 4736 const SCEV *MaxBECount, 4737 unsigned BitWidth) { 4738 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) 4739 // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) 4740 4741 struct SelectPattern { 4742 Value *Condition = nullptr; 4743 APInt TrueValue; 4744 APInt FalseValue; 4745 4746 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, 4747 const SCEV *S) { 4748 Optional<unsigned> CastOp; 4749 APInt Offset(BitWidth, 0); 4750 4751 assert(SE.getTypeSizeInBits(S->getType()) == BitWidth && 4752 "Should be!"); 4753 4754 // Peel off a constant offset: 4755 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) { 4756 // In the future we could consider being smarter here and handle 4757 // {Start+Step,+,Step} too. 
4758 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0))) 4759 return; 4760 4761 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt(); 4762 S = SA->getOperand(1); 4763 } 4764 4765 // Peel off a cast operation 4766 if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) { 4767 CastOp = SCast->getSCEVType(); 4768 S = SCast->getOperand(); 4769 } 4770 4771 using namespace llvm::PatternMatch; 4772 4773 auto *SU = dyn_cast<SCEVUnknown>(S); 4774 const APInt *TrueVal, *FalseVal; 4775 if (!SU || 4776 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal), 4777 m_APInt(FalseVal)))) { 4778 Condition = nullptr; 4779 return; 4780 } 4781 4782 TrueValue = *TrueVal; 4783 FalseValue = *FalseVal; 4784 4785 // Re-apply the cast we peeled off earlier 4786 if (CastOp.hasValue()) 4787 switch (*CastOp) { 4788 default: 4789 llvm_unreachable("Unknown SCEV cast type!"); 4790 4791 case scTruncate: 4792 TrueValue = TrueValue.trunc(BitWidth); 4793 FalseValue = FalseValue.trunc(BitWidth); 4794 break; 4795 case scZeroExtend: 4796 TrueValue = TrueValue.zext(BitWidth); 4797 FalseValue = FalseValue.zext(BitWidth); 4798 break; 4799 case scSignExtend: 4800 TrueValue = TrueValue.sext(BitWidth); 4801 FalseValue = FalseValue.sext(BitWidth); 4802 break; 4803 } 4804 4805 // Re-apply the constant offset we peeled off earlier 4806 TrueValue += Offset; 4807 FalseValue += Offset; 4808 } 4809 4810 bool isRecognized() { return Condition != nullptr; } 4811 }; 4812 4813 SelectPattern StartPattern(*this, BitWidth, Start); 4814 if (!StartPattern.isRecognized()) 4815 return ConstantRange(BitWidth, /* isFullSet = */ true); 4816 4817 SelectPattern StepPattern(*this, BitWidth, Step); 4818 if (!StepPattern.isRecognized()) 4819 return ConstantRange(BitWidth, /* isFullSet = */ true); 4820 4821 if (StartPattern.Condition != StepPattern.Condition) { 4822 // We don't handle this case today; but we could, by considering four 4823 // possibilities below instead of two. I'm not sure if there are cases where 4824 // that will help over what getRange already does, though. 4825 return ConstantRange(BitWidth, /* isFullSet = */ true); 4826 } 4827 4828 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to 4829 // construct arbitrary general SCEV expressions here. This function is called 4830 // from deep in the call stack, and calling getSCEV (on a sext instruction, 4831 // say) can end up caching a suboptimal value. 4832 4833 // FIXME: without the explicit `this` receiver below, MSVC errors out with 4834 // C2352 and C2512 (otherwise it isn't needed). 4835 4836 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); 4837 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); 4838 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); 4839 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); 4840 4841 ConstantRange TrueRange = 4842 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); 4843 ConstantRange FalseRange = 4844 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); 4845 4846 return TrueRange.unionWith(FalseRange); 4847 } 4848 4849 SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { 4850 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap; 4851 const BinaryOperator *BinOp = cast<BinaryOperator>(V); 4852 4853 // Return early if there are no flags to propagate to the SCEV. 
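// For example, an `add nuw nsw` instruction contributes both FlagNUW and
// FlagNSW here, but the flags are only kept if isSCEVExprNeverPoison (below)
// can justify transferring them to the (possibly shared) SCEV expression.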
4854 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; 4855 if (BinOp->hasNoUnsignedWrap()) 4856 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); 4857 if (BinOp->hasNoSignedWrap()) 4858 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); 4859 if (Flags == SCEV::FlagAnyWrap) 4860 return SCEV::FlagAnyWrap; 4861 4862 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; 4863 } 4864 4865 bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { 4866 // Here we check that I is in the header of the innermost loop containing I, 4867 // since we only deal with instructions in the loop header. The actual loop we 4868 // need to check later will come from an add recurrence, but getting that 4869 // requires computing the SCEV of the operands, which can be expensive. This 4870 // check we can do cheaply to rule out some cases early. 4871 Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent()); 4872 if (InnermostContainingLoop == nullptr || 4873 InnermostContainingLoop->getHeader() != I->getParent()) 4874 return false; 4875 4876 // Only proceed if we can prove that I does not yield poison. 4877 if (!isKnownNotFullPoison(I)) return false; 4878 4879 // At this point we know that if I is executed, then it does not wrap 4880 // according to at least one of NSW or NUW. If I is not executed, then we do 4881 // not know if the calculation that I represents would wrap. Multiple 4882 // instructions can map to the same SCEV. If we apply NSW or NUW from I to 4883 // the SCEV, we must guarantee no wrapping for that SCEV also when it is 4884 // derived from other instructions that map to the same SCEV. We cannot make 4885 // that guarantee for cases where I is not executed. So we need to find the 4886 // loop that I is considered in relation to and prove that I is executed for 4887 // every iteration of that loop. That implies that the value that I 4888 // calculates does not wrap anywhere in the loop, so then we can apply the 4889 // flags to the SCEV. 4890 // 4891 // We check isLoopInvariant to disambiguate in case we are adding recurrences 4892 // from different loops, so that we know which loop to prove that I is 4893 // executed in. 4894 for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) { 4895 // I could be an extractvalue from a call to an overflow intrinsic. 4896 // TODO: We can do better here in some cases. 4897 if (!isSCEVable(I->getOperand(OpIndex)->getType())) 4898 return false; 4899 const SCEV *Op = getSCEV(I->getOperand(OpIndex)); 4900 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { 4901 bool AllOtherOpsLoopInvariant = true; 4902 for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands(); 4903 ++OtherOpIndex) { 4904 if (OtherOpIndex != OpIndex) { 4905 const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex)); 4906 if (!isLoopInvariant(OtherOp, AddRec->getLoop())) { 4907 AllOtherOpsLoopInvariant = false; 4908 break; 4909 } 4910 } 4911 } 4912 if (AllOtherOpsLoopInvariant && 4913 isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop())) 4914 return true; 4915 } 4916 } 4917 return false; 4918 } 4919 4920 bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { 4921 // If we know that \c I can never be poison period, then that's enough. 
4922 if (isSCEVExprNeverPoison(I))
4923 return true;
4924
4925 // For an add recurrence specifically, we assume that infinite loops without
4926 // side effects are undefined behavior, and then reason as follows:
4927 //
4928 // If the add recurrence is poison in any iteration, it is poison on all
4929 // future iterations (since incrementing poison yields poison). If the result
4930 // of the add recurrence is fed into the loop latch condition and the loop
4931 // does not contain any throws or exiting blocks other than the latch, we now
4932 // have the ability to "choose" whether the backedge is taken or not (by
4933 // choosing a sufficiently evil value for the poison feeding into the branch)
4934 // for every iteration including and after the one in which \p I first became
4935 // poison. There are two possibilities (let's call the iteration in which \p
4936 // I first became poison K):
4937 //
4938 // 1. In the set of iterations including and after K, the loop body executes
4939 // no side effects. In this case executing the backedge an infinite number
4940 // of times will yield undefined behavior.
4941 //
4942 // 2. In the set of iterations including and after K, the loop body executes
4943 // at least one side effect. In this case, that specific instance of side
4944 // effect is control dependent on poison, which also yields undefined
4945 // behavior.
4946
4947 auto *ExitingBB = L->getExitingBlock();
4948 auto *LatchBB = L->getLoopLatch();
4949 if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
4950 return false;
4951
4952 SmallPtrSet<const Instruction *, 16> Pushed;
4953 SmallVector<const Instruction *, 8> PoisonStack;
4954
4955 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
4956 // things that are known to be fully poison under that assumption go on the
4957 // PoisonStack.
4958 Pushed.insert(I);
4959 PoisonStack.push_back(I);
4960
4961 bool LatchControlDependentOnPoison = false;
4962 while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
4963 const Instruction *Poison = PoisonStack.pop_back_val();
4964
4965 for (auto *PoisonUser : Poison->users()) {
4966 if (propagatesFullPoison(cast<Instruction>(PoisonUser))) {
4967 if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
4968 PoisonStack.push_back(cast<Instruction>(PoisonUser));
4969 } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
4970 assert(BI->isConditional() && "Only possibility!");
4971 if (BI->getParent() == LatchBB) {
4972 LatchControlDependentOnPoison = true;
4973 break;
4974 }
4975 }
4976 }
4977 }
4978
4979 return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
4980 }
4981
4982 ScalarEvolution::LoopProperties
4983 ScalarEvolution::getLoopProperties(const Loop *L) {
4984 typedef ScalarEvolution::LoopProperties LoopProperties;
4985
4986 auto Itr = LoopPropertiesCache.find(L);
4987 if (Itr == LoopPropertiesCache.end()) {
4988 auto HasSideEffects = [](Instruction *I) {
4989 if (auto *SI = dyn_cast<StoreInst>(I))
4990 return !SI->isSimple();
4991
4992 return I->mayHaveSideEffects();
4993 };
4994
4995 LoopProperties LP = {/* HasNoAbnormalExits */ true,
4996 /*HasNoSideEffects*/ true};
4997
4998 for (auto *BB : L->getBlocks())
4999 for (auto &I : *BB) {
5000 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
5001 LP.HasNoAbnormalExits = false;
5002 if (HasSideEffects(&I))
5003 LP.HasNoSideEffects = false;
5004 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
5005 break; // We're already as pessimistic as we can get.
5006 } 5007 5008 auto InsertPair = LoopPropertiesCache.insert({L, LP}); 5009 assert(InsertPair.second && "We just checked!"); 5010 Itr = InsertPair.first; 5011 } 5012 5013 return Itr->second; 5014 } 5015 5016 const SCEV *ScalarEvolution::createSCEV(Value *V) { 5017 if (!isSCEVable(V->getType())) 5018 return getUnknown(V); 5019 5020 if (Instruction *I = dyn_cast<Instruction>(V)) { 5021 // Don't attempt to analyze instructions in blocks that aren't 5022 // reachable. Such instructions don't matter, and they aren't required 5023 // to obey basic rules for definitions dominating uses which this 5024 // analysis depends on. 5025 if (!DT.isReachableFromEntry(I->getParent())) 5026 return getUnknown(V); 5027 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) 5028 return getConstant(CI); 5029 else if (isa<ConstantPointerNull>(V)) 5030 return getZero(V->getType()); 5031 else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) 5032 return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); 5033 else if (!isa<ConstantExpr>(V)) 5034 return getUnknown(V); 5035 5036 Operator *U = cast<Operator>(V); 5037 if (auto BO = MatchBinaryOp(U, DT)) { 5038 switch (BO->Opcode) { 5039 case Instruction::Add: { 5040 // The simple thing to do would be to just call getSCEV on both operands 5041 // and call getAddExpr with the result. However if we're looking at a 5042 // bunch of things all added together, this can be quite inefficient, 5043 // because it leads to N-1 getAddExpr calls for N ultimate operands. 5044 // Instead, gather up all the operands and make a single getAddExpr call. 5045 // LLVM IR canonical form means we need only traverse the left operands. 5046 SmallVector<const SCEV *, 4> AddOps; 5047 do { 5048 if (BO->Op) { 5049 if (auto *OpSCEV = getExistingSCEV(BO->Op)) { 5050 AddOps.push_back(OpSCEV); 5051 break; 5052 } 5053 5054 // If a NUW or NSW flag can be applied to the SCEV for this 5055 // addition, then compute the SCEV for this addition by itself 5056 // with a separate call to getAddExpr. We need to do that 5057 // instead of pushing the operands of the addition onto AddOps, 5058 // since the flags are only known to apply to this particular 5059 // addition - they may not apply to other additions that can be 5060 // formed with operands from AddOps. 
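// For instance, when analyzing ((a +nsw b) + c), the nsw flag only justifies
// getAddExpr(getSCEV(a), getSCEV(b), FlagNSW) for the inner addition; pushing
// a, b, and c into one AddOps list would leave no sound flag to attach to the
// combined three-operand add.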
5061 const SCEV *RHS = getSCEV(BO->RHS); 5062 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); 5063 if (Flags != SCEV::FlagAnyWrap) { 5064 const SCEV *LHS = getSCEV(BO->LHS); 5065 if (BO->Opcode == Instruction::Sub) 5066 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); 5067 else 5068 AddOps.push_back(getAddExpr(LHS, RHS, Flags)); 5069 break; 5070 } 5071 } 5072 5073 if (BO->Opcode == Instruction::Sub) 5074 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); 5075 else 5076 AddOps.push_back(getSCEV(BO->RHS)); 5077 5078 auto NewBO = MatchBinaryOp(BO->LHS, DT); 5079 if (!NewBO || (NewBO->Opcode != Instruction::Add && 5080 NewBO->Opcode != Instruction::Sub)) { 5081 AddOps.push_back(getSCEV(BO->LHS)); 5082 break; 5083 } 5084 BO = NewBO; 5085 } while (true); 5086 5087 return getAddExpr(AddOps); 5088 } 5089 5090 case Instruction::Mul: { 5091 SmallVector<const SCEV *, 4> MulOps; 5092 do { 5093 if (BO->Op) { 5094 if (auto *OpSCEV = getExistingSCEV(BO->Op)) { 5095 MulOps.push_back(OpSCEV); 5096 break; 5097 } 5098 5099 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); 5100 if (Flags != SCEV::FlagAnyWrap) { 5101 MulOps.push_back( 5102 getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); 5103 break; 5104 } 5105 } 5106 5107 MulOps.push_back(getSCEV(BO->RHS)); 5108 auto NewBO = MatchBinaryOp(BO->LHS, DT); 5109 if (!NewBO || NewBO->Opcode != Instruction::Mul) { 5110 MulOps.push_back(getSCEV(BO->LHS)); 5111 break; 5112 } 5113 BO = NewBO; 5114 } while (true); 5115 5116 return getMulExpr(MulOps); 5117 } 5118 case Instruction::UDiv: 5119 return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); 5120 case Instruction::Sub: { 5121 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; 5122 if (BO->Op) 5123 Flags = getNoWrapFlagsFromUB(BO->Op); 5124 return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); 5125 } 5126 case Instruction::And: 5127 // For an expression like x&255 that merely masks off the high bits, 5128 // use zext(trunc(x)) as the SCEV expression. 5129 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { 5130 if (CI->isNullValue()) 5131 return getSCEV(BO->RHS); 5132 if (CI->isAllOnesValue()) 5133 return getSCEV(BO->LHS); 5134 const APInt &A = CI->getValue(); 5135 5136 // Instcombine's ShrinkDemandedConstant may strip bits out of 5137 // constants, obscuring what would otherwise be a low-bits mask. 5138 // Use computeKnownBits to compute what ShrinkDemandedConstant 5139 // knew about to reconstruct a low-bits mask value. 5140 unsigned LZ = A.countLeadingZeros(); 5141 unsigned TZ = A.countTrailingZeros(); 5142 unsigned BitWidth = A.getBitWidth(); 5143 APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); 5144 computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(), 5145 0, &AC, nullptr, &DT); 5146 5147 APInt EffectiveMask = 5148 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); 5149 if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { 5150 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); 5151 const SCEV *LHS = getSCEV(BO->LHS); 5152 const SCEV *ShiftedLHS = nullptr; 5153 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { 5154 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { 5155 // For an expression like (x * 8) & 8, simplify the multiply. 
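// (Continuing that example: the multiply constant 8 and the mask both have
// three trailing zeros, so GCD == TZ, DivAmt becomes 1, and the product
// collapses to plain x before the zext(trunc(...)) below rebuilds the mask.)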
5156 unsigned MulZeros = OpC->getAPInt().countTrailingZeros(); 5157 unsigned GCD = std::min(MulZeros, TZ); 5158 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); 5159 SmallVector<const SCEV*, 4> MulOps; 5160 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); 5161 MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end()); 5162 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); 5163 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); 5164 } 5165 } 5166 if (!ShiftedLHS) 5167 ShiftedLHS = getUDivExpr(LHS, MulCount); 5168 return getMulExpr( 5169 getZeroExtendExpr( 5170 getTruncateExpr(ShiftedLHS, 5171 IntegerType::get(getContext(), BitWidth - LZ - TZ)), 5172 BO->LHS->getType()), 5173 MulCount); 5174 } 5175 } 5176 break; 5177 5178 case Instruction::Or: 5179 // If the RHS of the Or is a constant, we may have something like: 5180 // X*4+1 which got turned into X*4|1. Handle this as an Add so loop 5181 // optimizations will transparently handle this case. 5182 // 5183 // In order for this transformation to be safe, the LHS must be of the 5184 // form X*(2^n) and the Or constant must be less than 2^n. 5185 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { 5186 const SCEV *LHS = getSCEV(BO->LHS); 5187 const APInt &CIVal = CI->getValue(); 5188 if (GetMinTrailingZeros(LHS) >= 5189 (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { 5190 // Build a plain add SCEV. 5191 const SCEV *S = getAddExpr(LHS, getSCEV(CI)); 5192 // If the LHS of the add was an addrec and it has no-wrap flags, 5193 // transfer the no-wrap flags, since an or won't introduce a wrap. 5194 if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { 5195 const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); 5196 const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( 5197 OldAR->getNoWrapFlags()); 5198 } 5199 return S; 5200 } 5201 } 5202 break; 5203 5204 case Instruction::Xor: 5205 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { 5206 // If the RHS of xor is -1, then this is a not operation. 5207 if (CI->isAllOnesValue()) 5208 return getNotSCEV(getSCEV(BO->LHS)); 5209 5210 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. 5211 // This is a variant of the check for xor with -1, and it handles 5212 // the case where instcombine has trimmed non-demanded bits out 5213 // of an xor with -1. 5214 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS)) 5215 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1))) 5216 if (LBO->getOpcode() == Instruction::And && 5217 LCI->getValue() == CI->getValue()) 5218 if (const SCEVZeroExtendExpr *Z = 5219 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) { 5220 Type *UTy = BO->LHS->getType(); 5221 const SCEV *Z0 = Z->getOperand(); 5222 Type *Z0Ty = Z0->getType(); 5223 unsigned Z0TySize = getTypeSizeInBits(Z0Ty); 5224 5225 // If C is a low-bits mask, the zero extend is serving to 5226 // mask off the high bits. Complement the operand and 5227 // re-apply the zext. 5228 if (APIntOps::isMask(Z0TySize, CI->getValue())) 5229 return getZeroExtendExpr(getNotSCEV(Z0), UTy); 5230 5231 // If C is a single bit, it may be in the sign-bit position 5232 // before the zero-extend. In this case, represent the xor 5233 // using an add, which is equivalent, and re-apply the zext. 
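// (For example, with an i4 value zero-extended to i8 and C == 8: adding 8
// modulo 16 flips the same bit as xor'ing with 8, so
// (zext i4 %x to i8) ^ 8 can be modeled as zext(%x + 8) to i8.)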
5234 APInt Trunc = CI->getValue().trunc(Z0TySize); 5235 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && 5236 Trunc.isSignBit()) 5237 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), 5238 UTy); 5239 } 5240 } 5241 break; 5242 5243 case Instruction::Shl: 5244 // Turn shift left of a constant amount into a multiply. 5245 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) { 5246 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth(); 5247 5248 // If the shift count is not less than the bitwidth, the result of 5249 // the shift is undefined. Don't try to analyze it, because the 5250 // resolution chosen here may differ from the resolution chosen in 5251 // other parts of the compiler. 5252 if (SA->getValue().uge(BitWidth)) 5253 break; 5254 5255 // It is currently not resolved how to interpret NSW for left 5256 // shift by BitWidth - 1, so we avoid applying flags in that 5257 // case. Remove this check (or this comment) once the situation 5258 // is resolved. See 5259 // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html 5260 // and http://reviews.llvm.org/D8890 . 5261 auto Flags = SCEV::FlagAnyWrap; 5262 if (BO->Op && SA->getValue().ult(BitWidth - 1)) 5263 Flags = getNoWrapFlagsFromUB(BO->Op); 5264 5265 Constant *X = ConstantInt::get(getContext(), 5266 APInt::getOneBitSet(BitWidth, SA->getZExtValue())); 5267 return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags); 5268 } 5269 break; 5270 5271 case Instruction::AShr: 5272 // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. 5273 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) 5274 if (Operator *L = dyn_cast<Operator>(BO->LHS)) 5275 if (L->getOpcode() == Instruction::Shl && 5276 L->getOperand(1) == BO->RHS) { 5277 uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); 5278 5279 // If the shift count is not less than the bitwidth, the result of 5280 // the shift is undefined. Don't try to analyze it, because the 5281 // resolution chosen here may differ from the resolution chosen in 5282 // other parts of the compiler. 5283 if (CI->getValue().uge(BitWidth)) 5284 break; 5285 5286 uint64_t Amt = BitWidth - CI->getZExtValue(); 5287 if (Amt == BitWidth) 5288 return getSCEV(L->getOperand(0)); // shift by zero --> noop 5289 return getSignExtendExpr( 5290 getTruncateExpr(getSCEV(L->getOperand(0)), 5291 IntegerType::get(getContext(), Amt)), 5292 BO->LHS->getType()); 5293 } 5294 break; 5295 } 5296 } 5297 5298 switch (U->getOpcode()) { 5299 case Instruction::Trunc: 5300 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); 5301 5302 case Instruction::ZExt: 5303 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); 5304 5305 case Instruction::SExt: 5306 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); 5307 5308 case Instruction::BitCast: 5309 // BitCasts are no-op casts so we just eliminate the cast. 5310 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) 5311 return getSCEV(U->getOperand(0)); 5312 break; 5313 5314 // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can 5315 // lead to pointer expressions which cannot safely be expanded to GEPs, 5316 // because ScalarEvolution doesn't respect the GEP aliasing rules when 5317 // simplifying integer expressions. 
5318
5319 case Instruction::GetElementPtr:
5320 return createNodeForGEP(cast<GEPOperator>(U));
5321
5322 case Instruction::PHI:
5323 return createNodeForPHI(cast<PHINode>(U));
5324
5325 case Instruction::Select:
5326 // U can also be a select constant expr, which we let fall through. Since
5327 // createNodeForSelect only works for a condition that is an `ICmpInst`, and
5328 // constant expressions cannot have instructions as operands, we'd have
5329 // returned getUnknown for a select constant expression anyway.
5330 if (isa<Instruction>(U))
5331 return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
5332 U->getOperand(1), U->getOperand(2));
5333 break;
5334
5335 case Instruction::Call:
5336 case Instruction::Invoke:
5337 if (Value *RV = CallSite(U).getReturnedArgOperand())
5338 return getSCEV(RV);
5339 break;
5340 }
5341
5342 return getUnknown(V);
5343 }
5344
5345
5346
5347 //===----------------------------------------------------------------------===//
5348 // Iteration Count Computation Code
5349 //
5350
5351 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
5352 if (!ExitCount)
5353 return 0;
5354
5355 ConstantInt *ExitConst = ExitCount->getValue();
5356
5357 // Guard against huge trip counts.
5358 if (ExitConst->getValue().getActiveBits() > 32)
5359 return 0;
5360
5361 // In case of integer overflow, this returns 0, which is correct.
5362 return ((unsigned)ExitConst->getZExtValue()) + 1;
5363 }
5364
5365 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
5366 if (BasicBlock *ExitingBB = L->getExitingBlock())
5367 return getSmallConstantTripCount(L, ExitingBB);
5368
5369 // No trip count information for multiple exits.
5370 return 0;
5371 }
5372
5373 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
5374 BasicBlock *ExitingBlock) {
5375 assert(ExitingBlock && "Must pass a non-null exiting block!");
5376 assert(L->isLoopExiting(ExitingBlock) &&
5377 "Exiting block must actually branch out of the loop!");
5378 const SCEVConstant *ExitCount =
5379 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
5380 return getConstantTripCount(ExitCount);
5381 }
5382
5383 unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) {
5384 const auto *MaxExitCount =
5385 dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L));
5386 return getConstantTripCount(MaxExitCount);
5387 }
5388
5389 unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
5390 if (BasicBlock *ExitingBB = L->getExitingBlock())
5391 return getSmallConstantTripMultiple(L, ExitingBB);
5392
5393 // No trip multiple information for multiple exits.
5394 return 0;
5395 }
5396
5397 /// Returns the largest constant divisor of the trip count of this loop as a
5398 /// normal unsigned value, if possible. This means that the actual trip count is
5399 /// always a multiple of the returned value (don't forget the trip count could
5400 /// very well be zero as well!).
5401 ///
5402 /// Returns 1 if the trip count is unknown or not guaranteed to be a
5403 /// multiple of a constant (which is also the case if the trip count is simply
5404 /// constant; use getSmallConstantTripCount for that case). It will also return 1
5405 /// if the trip count is very large (>= 2^32).
5406 ///
5407 /// As explained in the comments for getSmallConstantTripCount, this assumes
5408 /// that control exits the loop via ExitingBlock.
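/// For example (illustrative only): if the trip-count expression folds to
/// (8 * %n) this returns 8, whereas for (%n + 7), or any product without a
/// leading constant, it falls back to returning 1.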
5409 unsigned 5410 ScalarEvolution::getSmallConstantTripMultiple(Loop *L, 5411 BasicBlock *ExitingBlock) { 5412 assert(ExitingBlock && "Must pass a non-null exiting block!"); 5413 assert(L->isLoopExiting(ExitingBlock) && 5414 "Exiting block must actually branch out of the loop!"); 5415 const SCEV *ExitCount = getExitCount(L, ExitingBlock); 5416 if (ExitCount == getCouldNotCompute()) 5417 return 1; 5418 5419 // Get the trip count from the BE count by adding 1. 5420 const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType())); 5421 // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt 5422 // to factor simple cases. 5423 if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul)) 5424 TCMul = Mul->getOperand(0); 5425 5426 const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul); 5427 if (!MulC) 5428 return 1; 5429 5430 ConstantInt *Result = MulC->getValue(); 5431 5432 // Guard against huge trip counts (this requires checking 5433 // for zero to handle the case where the trip count == -1 and the 5434 // addition wraps). 5435 if (!Result || Result->getValue().getActiveBits() > 32 || 5436 Result->getValue().getActiveBits() == 0) 5437 return 1; 5438 5439 return (unsigned)Result->getZExtValue(); 5440 } 5441 5442 /// Get the expression for the number of loop iterations for which this loop is 5443 /// guaranteed not to exit via ExitingBlock. Otherwise return 5444 /// SCEVCouldNotCompute. 5445 const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { 5446 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); 5447 } 5448 5449 const SCEV * 5450 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, 5451 SCEVUnionPredicate &Preds) { 5452 return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds); 5453 } 5454 5455 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) { 5456 return getBackedgeTakenInfo(L).getExact(this); 5457 } 5458 5459 /// Similar to getBackedgeTakenCount, except return the least SCEV value that is 5460 /// known never to be less than the actual backedge taken count. 5461 const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { 5462 return getBackedgeTakenInfo(L).getMax(this); 5463 } 5464 5465 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) { 5466 return getBackedgeTakenInfo(L).isMaxOrZero(this); 5467 } 5468 5469 /// Push PHI nodes in the header of the given loop onto the given Worklist. 5470 static void 5471 PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { 5472 BasicBlock *Header = L->getHeader(); 5473 5474 // Push all Loop-header PHIs onto the Worklist stack. 5475 for (BasicBlock::iterator I = Header->begin(); 5476 PHINode *PN = dyn_cast<PHINode>(I); ++I) 5477 Worklist.push_back(PN); 5478 } 5479 5480 const ScalarEvolution::BackedgeTakenInfo & 5481 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { 5482 auto &BTI = getBackedgeTakenInfo(L); 5483 if (BTI.hasFullInfo()) 5484 return BTI; 5485 5486 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); 5487 5488 if (!Pair.second) 5489 return Pair.first->second; 5490 5491 BackedgeTakenInfo Result = 5492 computeBackedgeTakenCount(L, /*AllowPredicates=*/true); 5493 5494 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result); 5495 } 5496 5497 const ScalarEvolution::BackedgeTakenInfo & 5498 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { 5499 // Initially insert an invalid entry for this loop. 
If the insertion 5500 // succeeds, proceed to actually compute a backedge-taken count and
5501 // update the value. The temporary CouldNotCompute value tells SCEV
5502 // code elsewhere that it shouldn't attempt to request a new
5503 // backedge-taken count, which could result in infinite recursion.
5504 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
5505 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
5506 if (!Pair.second)
5507 return Pair.first->second;
5508
5509 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
5510 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
5511 // must be cleared in this scope.
5512 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
5513
5514 if (Result.getExact(this) != getCouldNotCompute()) {
5515 assert(isLoopInvariant(Result.getExact(this), L) &&
5516 isLoopInvariant(Result.getMax(this), L) &&
5517 "Computed backedge-taken count isn't loop invariant for loop!");
5518 ++NumTripCountsComputed;
5519 }
5520 else if (Result.getMax(this) == getCouldNotCompute() &&
5521 isa<PHINode>(L->getHeader()->begin())) {
5522 // Only count loops that have phi nodes as not being computable.
5523 ++NumTripCountsNotComputed;
5524 }
5525
5526 // Now that we know more about the trip count for this loop, forget any
5527 // existing SCEV values for PHI nodes in this loop since they are only
5528 // conservative estimates made without the benefit of trip count
5529 // information. This is similar to the code in forgetLoop, except that
5530 // it handles SCEVUnknown PHI nodes specially.
5531 if (Result.hasAnyInfo()) {
5532 SmallVector<Instruction *, 16> Worklist;
5533 PushLoopPHIs(L, Worklist);
5534
5535 SmallPtrSet<Instruction *, 8> Visited;
5536 while (!Worklist.empty()) {
5537 Instruction *I = Worklist.pop_back_val();
5538 if (!Visited.insert(I).second)
5539 continue;
5540
5541 ValueExprMapType::iterator It =
5542 ValueExprMap.find_as(static_cast<Value *>(I));
5543 if (It != ValueExprMap.end()) {
5544 const SCEV *Old = It->second;
5545
5546 // SCEVUnknown for a PHI either means that it has an unrecognized
5547 // structure, or it's a PHI that's in the process of being computed
5548 // by createNodeForPHI. In the former case, additional loop trip
5549 // count information isn't going to change anything. In the latter
5550 // case, createNodeForPHI will perform the necessary updates on its
5551 // own when it gets to that point.
5552 if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
5553 eraseValueFromMap(It->first);
5554 forgetMemoizedResults(Old);
5555 }
5556 if (PHINode *PN = dyn_cast<PHINode>(I))
5557 ConstantEvolutionLoopExitValue.erase(PN);
5558 }
5559
5560 PushDefUseChildren(I, Worklist);
5561 }
5562 }
5563
5564 // Re-lookup the insert position, since the call to
5565 // computeBackedgeTakenCount above could result in a
5566 // recursive call to getBackedgeTakenInfo (on a different
5567 // loop), which would invalidate the iterator computed
5568 // earlier.
5569 return BackedgeTakenCounts.find(L)->second = std::move(Result);
5570 }
5571
5572 void ScalarEvolution::forgetLoop(const Loop *L) {
5573 // Drop any stored trip count value.
5574 auto RemoveLoopFromBackedgeMap =
5575 [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
5576 auto BTCPos = Map.find(L);
5577 if (BTCPos != Map.end()) {
5578 BTCPos->second.clear();
5579 Map.erase(BTCPos);
5580 }
5581 };
5582
5583 RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
5584 RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
5585
5586 // Drop information about expressions based on loop-header PHIs.
5587 SmallVector<Instruction *, 16> Worklist;
5588 PushLoopPHIs(L, Worklist);
5589
5590 SmallPtrSet<Instruction *, 8> Visited;
5591 while (!Worklist.empty()) {
5592 Instruction *I = Worklist.pop_back_val();
5593 if (!Visited.insert(I).second)
5594 continue;
5595
5596 ValueExprMapType::iterator It =
5597 ValueExprMap.find_as(static_cast<Value *>(I));
5598 if (It != ValueExprMap.end()) {
5599 eraseValueFromMap(It->first);
5600 forgetMemoizedResults(It->second);
5601 if (PHINode *PN = dyn_cast<PHINode>(I))
5602 ConstantEvolutionLoopExitValue.erase(PN);
5603 }
5604
5605 PushDefUseChildren(I, Worklist);
5606 }
5607
5608 // Forget all contained loops too, to avoid dangling entries in the
5609 // ValuesAtScopes map.
5610 for (Loop *I : *L)
5611 forgetLoop(I);
5612
5613 LoopPropertiesCache.erase(L);
5614 }
5615
5616 void ScalarEvolution::forgetValue(Value *V) {
5617 Instruction *I = dyn_cast<Instruction>(V);
5618 if (!I) return;
5619
5620 // Drop information about expressions based on loop-header PHIs.
5621 SmallVector<Instruction *, 16> Worklist;
5622 Worklist.push_back(I);
5623
5624 SmallPtrSet<Instruction *, 8> Visited;
5625 while (!Worklist.empty()) {
5626 I = Worklist.pop_back_val();
5627 if (!Visited.insert(I).second)
5628 continue;
5629
5630 ValueExprMapType::iterator It =
5631 ValueExprMap.find_as(static_cast<Value *>(I));
5632 if (It != ValueExprMap.end()) {
5633 eraseValueFromMap(It->first);
5634 forgetMemoizedResults(It->second);
5635 if (PHINode *PN = dyn_cast<PHINode>(I))
5636 ConstantEvolutionLoopExitValue.erase(PN);
5637 }
5638
5639 PushDefUseChildren(I, Worklist);
5640 }
5641 }
5642
5643 /// Get the exact loop backedge taken count considering all loop exits. A
5644 /// computable result can only be returned for loops with a single exit.
5645 /// Returning the minimum taken count among all exits is incorrect because one
5646 /// of the loop's exit limits may have been skipped. howFarToZero assumes that
5647 /// the limit of each loop test is never skipped. This is a valid assumption as
5648 /// long as the loop exits via that test. For precise results, it is the
5649 /// caller's responsibility to specify the relevant loop exit using
5650 /// getExact(ExitingBlock, SE).
5651 const SCEV *
5652 ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
5653 SCEVUnionPredicate *Preds) const {
5654 // If any exits were not computable, the loop is not computable.
5655 if (!isComplete() || ExitNotTaken.empty()) 5656 return SE->getCouldNotCompute(); 5657 5658 const SCEV *BECount = nullptr; 5659 for (auto &ENT : ExitNotTaken) { 5660 assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); 5661 5662 if (!BECount) 5663 BECount = ENT.ExactNotTaken; 5664 else if (BECount != ENT.ExactNotTaken) 5665 return SE->getCouldNotCompute(); 5666 if (Preds && !ENT.hasAlwaysTruePredicate()) 5667 Preds->add(ENT.Predicate.get()); 5668 5669 assert((Preds || ENT.hasAlwaysTruePredicate()) && 5670 "Predicate should be always true!"); 5671 } 5672 5673 assert(BECount && "Invalid not taken count for loop exit"); 5674 return BECount; 5675 } 5676 5677 /// Get the exact not taken count for this loop exit. 5678 const SCEV * 5679 ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, 5680 ScalarEvolution *SE) const { 5681 for (auto &ENT : ExitNotTaken) 5682 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) 5683 return ENT.ExactNotTaken; 5684 5685 return SE->getCouldNotCompute(); 5686 } 5687 5688 /// getMax - Get the max backedge taken count for the loop. 5689 const SCEV * 5690 ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { 5691 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { 5692 return !ENT.hasAlwaysTruePredicate(); 5693 }; 5694 5695 if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax()) 5696 return SE->getCouldNotCompute(); 5697 5698 return getMax(); 5699 } 5700 5701 bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const { 5702 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { 5703 return !ENT.hasAlwaysTruePredicate(); 5704 }; 5705 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); 5706 } 5707 5708 bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, 5709 ScalarEvolution *SE) const { 5710 if (getMax() && getMax() != SE->getCouldNotCompute() && 5711 SE->hasOperand(getMax(), S)) 5712 return true; 5713 5714 for (auto &ENT : ExitNotTaken) 5715 if (ENT.ExactNotTaken != SE->getCouldNotCompute() && 5716 SE->hasOperand(ENT.ExactNotTaken, S)) 5717 return true; 5718 5719 return false; 5720 } 5721 5722 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each 5723 /// computable exit into a persistent ExitNotTakenInfo array. 5724 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( 5725 SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> 5726 &&ExitCounts, 5727 bool Complete, const SCEV *MaxCount, bool MaxOrZero) 5728 : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) { 5729 typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; 5730 ExitNotTaken.reserve(ExitCounts.size()); 5731 std::transform( 5732 ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken), 5733 [&](const EdgeExitInfo &EEI) { 5734 BasicBlock *ExitBB = EEI.first; 5735 const ExitLimit &EL = EEI.second; 5736 if (EL.Predicates.empty()) 5737 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr); 5738 5739 std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate); 5740 for (auto *Pred : EL.Predicates) 5741 Predicate->add(Pred); 5742 5743 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate)); 5744 }); 5745 } 5746 5747 /// Invalidate this result and free the ExitNotTakenInfo array. 
5748 void ScalarEvolution::BackedgeTakenInfo::clear() { 5749 ExitNotTaken.clear(); 5750 } 5751 5752 /// Compute the number of times the backedge of the specified loop will execute. 5753 ScalarEvolution::BackedgeTakenInfo 5754 ScalarEvolution::computeBackedgeTakenCount(const Loop *L, 5755 bool AllowPredicates) { 5756 SmallVector<BasicBlock *, 8> ExitingBlocks; 5757 L->getExitingBlocks(ExitingBlocks); 5758 5759 typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; 5760 5761 SmallVector<EdgeExitInfo, 4> ExitCounts; 5762 bool CouldComputeBECount = true; 5763 BasicBlock *Latch = L->getLoopLatch(); // may be NULL. 5764 const SCEV *MustExitMaxBECount = nullptr; 5765 const SCEV *MayExitMaxBECount = nullptr; 5766 bool MustExitMaxOrZero = false; 5767 5768 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts 5769 // and compute maxBECount. 5770 // Do a union of all the predicates here. 5771 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { 5772 BasicBlock *ExitBB = ExitingBlocks[i]; 5773 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); 5774 5775 assert((AllowPredicates || EL.Predicates.empty()) && 5776 "Predicated exit limit when predicates are not allowed!"); 5777 5778 // 1. For each exit that can be computed, add an entry to ExitCounts. 5779 // CouldComputeBECount is true only if all exits can be computed. 5780 if (EL.ExactNotTaken == getCouldNotCompute()) 5781 // We couldn't compute an exact value for this exit, so 5782 // we won't be able to compute an exact value for the loop. 5783 CouldComputeBECount = false; 5784 else 5785 ExitCounts.emplace_back(ExitBB, EL); 5786 5787 // 2. Derive the loop's MaxBECount from each exit's max number of 5788 // non-exiting iterations. Partition the loop exits into two kinds: 5789 // LoopMustExits and LoopMayExits. 5790 // 5791 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it 5792 // is a LoopMayExit. If any computable LoopMustExit is found, then 5793 // MaxBECount is the minimum EL.MaxNotTaken of computable 5794 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum 5795 // EL.MaxNotTaken, where CouldNotCompute is considered greater than any 5796 // computable EL.MaxNotTaken. 5797 if (EL.MaxNotTaken != getCouldNotCompute() && Latch && 5798 DT.dominates(ExitBB, Latch)) { 5799 if (!MustExitMaxBECount) { 5800 MustExitMaxBECount = EL.MaxNotTaken; 5801 MustExitMaxOrZero = EL.MaxOrZero; 5802 } else { 5803 MustExitMaxBECount = 5804 getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken); 5805 } 5806 } else if (MayExitMaxBECount != getCouldNotCompute()) { 5807 if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute()) 5808 MayExitMaxBECount = EL.MaxNotTaken; 5809 else { 5810 MayExitMaxBECount = 5811 getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken); 5812 } 5813 } 5814 } 5815 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : 5816 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); 5817 // The loop backedge will be taken the maximum or zero times if there's 5818 // a single exit that must be taken the maximum or zero times. 5819 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); 5820 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount, 5821 MaxBECount, MaxOrZero); 5822 } 5823 5824 ScalarEvolution::ExitLimit 5825 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, 5826 bool AllowPredicates) { 5827 5828 // Okay, we've chosen an exiting block. 
See what condition causes us to exit 5829 // at this block and remember the exit block and whether all other targets 5830 // lead to the loop header. 5831 bool MustExecuteLoopHeader = true; 5832 BasicBlock *Exit = nullptr; 5833 for (auto *SBB : successors(ExitingBlock)) 5834 if (!L->contains(SBB)) { 5835 if (Exit) // Multiple exit successors. 5836 return getCouldNotCompute(); 5837 Exit = SBB; 5838 } else if (SBB != L->getHeader()) { 5839 MustExecuteLoopHeader = false; 5840 } 5841 5842 // At this point, we know we have a conditional branch that determines whether 5843 // the loop is exited. However, we don't know if the branch is executed each 5844 // time through the loop. If not, then the execution count of the branch will 5845 // not be equal to the trip count of the loop. 5846 // 5847 // Currently we check for this by checking to see if the Exit branch goes to 5848 // the loop header. If so, we know it will always execute the same number of 5849 // times as the loop. We also handle the case where the exit block *is* the 5850 // loop header. This is common for un-rotated loops. 5851 // 5852 // If both of those tests fail, walk up the unique predecessor chain to the 5853 // header, stopping if there is an edge that doesn't exit the loop. If the 5854 // header is reached, the execution count of the branch will be equal to the 5855 // trip count of the loop. 5856 // 5857 // More extensive analysis could be done to handle more cases here. 5858 // 5859 if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) { 5860 // The simple checks failed, try climbing the unique predecessor chain 5861 // up to the header. 5862 bool Ok = false; 5863 for (BasicBlock *BB = ExitingBlock; BB; ) { 5864 BasicBlock *Pred = BB->getUniquePredecessor(); 5865 if (!Pred) 5866 return getCouldNotCompute(); 5867 TerminatorInst *PredTerm = Pred->getTerminator(); 5868 for (const BasicBlock *PredSucc : PredTerm->successors()) { 5869 if (PredSucc == BB) 5870 continue; 5871 // If the predecessor has a successor that isn't BB and isn't 5872 // outside the loop, assume the worst. 5873 if (L->contains(PredSucc)) 5874 return getCouldNotCompute(); 5875 } 5876 if (Pred == L->getHeader()) { 5877 Ok = true; 5878 break; 5879 } 5880 BB = Pred; 5881 } 5882 if (!Ok) 5883 return getCouldNotCompute(); 5884 } 5885 5886 bool IsOnlyExit = (L->getExitingBlock() != nullptr); 5887 TerminatorInst *Term = ExitingBlock->getTerminator(); 5888 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { 5889 assert(BI->isConditional() && "If unconditional, it can't be in loop!"); 5890 // Proceed to the next level to examine the exit condition expression. 5891 return computeExitLimitFromCond( 5892 L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), 5893 /*ControlsExit=*/IsOnlyExit, AllowPredicates); 5894 } 5895 5896 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) 5897 return computeExitLimitFromSingleExitSwitch(L, SI, Exit, 5898 /*ControlsExit=*/IsOnlyExit); 5899 5900 return getCouldNotCompute(); 5901 } 5902 5903 ScalarEvolution::ExitLimit 5904 ScalarEvolution::computeExitLimitFromCond(const Loop *L, 5905 Value *ExitCond, 5906 BasicBlock *TBB, 5907 BasicBlock *FBB, 5908 bool ControlsExit, 5909 bool AllowPredicates) { 5910 // Check if the controlling expression for this loop is an And or Or. 5911 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { 5912 if (BO->getOpcode() == Instruction::And) { 5913 // Recurse on the operands of the and. 
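// For example, for an exit test like (i != n && *p != 0), where the loop
// leaves as soon as either operand is false, the backedge-taken count below
// is the umin of the counts computed for the two operands (when both are
// computable).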
5914 bool EitherMayExit = L->contains(TBB); 5915 ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, 5916 ControlsExit && !EitherMayExit, 5917 AllowPredicates); 5918 ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, 5919 ControlsExit && !EitherMayExit, 5920 AllowPredicates); 5921 const SCEV *BECount = getCouldNotCompute(); 5922 const SCEV *MaxBECount = getCouldNotCompute(); 5923 if (EitherMayExit) { 5924 // Both conditions must be true for the loop to continue executing. 5925 // Choose the less conservative count. 5926 if (EL0.ExactNotTaken == getCouldNotCompute() || 5927 EL1.ExactNotTaken == getCouldNotCompute()) 5928 BECount = getCouldNotCompute(); 5929 else 5930 BECount = 5931 getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); 5932 if (EL0.MaxNotTaken == getCouldNotCompute()) 5933 MaxBECount = EL1.MaxNotTaken; 5934 else if (EL1.MaxNotTaken == getCouldNotCompute()) 5935 MaxBECount = EL0.MaxNotTaken; 5936 else 5937 MaxBECount = 5938 getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); 5939 } else { 5940 // Both conditions must be true at the same time for the loop to exit. 5941 // For now, be conservative. 5942 assert(L->contains(FBB) && "Loop block has no successor in loop!"); 5943 if (EL0.MaxNotTaken == EL1.MaxNotTaken) 5944 MaxBECount = EL0.MaxNotTaken; 5945 if (EL0.ExactNotTaken == EL1.ExactNotTaken) 5946 BECount = EL0.ExactNotTaken; 5947 } 5948 5949 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able 5950 // to be more aggressive when computing BECount than when computing 5951 // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and 5952 // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken 5953 // to not. 5954 if (isa<SCEVCouldNotCompute>(MaxBECount) && 5955 !isa<SCEVCouldNotCompute>(BECount)) 5956 MaxBECount = BECount; 5957 5958 return ExitLimit(BECount, MaxBECount, false, 5959 {&EL0.Predicates, &EL1.Predicates}); 5960 } 5961 if (BO->getOpcode() == Instruction::Or) { 5962 // Recurse on the operands of the or. 5963 bool EitherMayExit = L->contains(FBB); 5964 ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, 5965 ControlsExit && !EitherMayExit, 5966 AllowPredicates); 5967 ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, 5968 ControlsExit && !EitherMayExit, 5969 AllowPredicates); 5970 const SCEV *BECount = getCouldNotCompute(); 5971 const SCEV *MaxBECount = getCouldNotCompute(); 5972 if (EitherMayExit) { 5973 // Both conditions must be false for the loop to continue executing. 5974 // Choose the less conservative count. 5975 if (EL0.ExactNotTaken == getCouldNotCompute() || 5976 EL1.ExactNotTaken == getCouldNotCompute()) 5977 BECount = getCouldNotCompute(); 5978 else 5979 BECount = 5980 getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); 5981 if (EL0.MaxNotTaken == getCouldNotCompute()) 5982 MaxBECount = EL1.MaxNotTaken; 5983 else if (EL1.MaxNotTaken == getCouldNotCompute()) 5984 MaxBECount = EL0.MaxNotTaken; 5985 else 5986 MaxBECount = 5987 getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); 5988 } else { 5989 // Both conditions must be false at the same time for the loop to exit. 5990 // For now, be conservative. 
5991 assert(L->contains(TBB) && "Loop block has no successor in loop!"); 5992 if (EL0.MaxNotTaken == EL1.MaxNotTaken) 5993 MaxBECount = EL0.MaxNotTaken; 5994 if (EL0.ExactNotTaken == EL1.ExactNotTaken) 5995 BECount = EL0.ExactNotTaken; 5996 } 5997 5998 return ExitLimit(BECount, MaxBECount, false, 5999 {&EL0.Predicates, &EL1.Predicates}); 6000 } 6001 } 6002 6003 // With an icmp, it may be feasible to compute an exact backedge-taken count. 6004 // Proceed to the next level to examine the icmp. 6005 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) { 6006 ExitLimit EL = 6007 computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); 6008 if (EL.hasFullInfo() || !AllowPredicates) 6009 return EL; 6010 6011 // Try again, but use SCEV predicates this time. 6012 return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit, 6013 /*AllowPredicates=*/true); 6014 } 6015 6016 // Check for a constant condition. These are normally stripped out by 6017 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to 6018 // preserve the CFG and is temporarily leaving constant conditions 6019 // in place. 6020 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) { 6021 if (L->contains(FBB) == !CI->getZExtValue()) 6022 // The backedge is always taken. 6023 return getCouldNotCompute(); 6024 else 6025 // The backedge is never taken. 6026 return getZero(CI->getType()); 6027 } 6028 6029 // If it's not an integer or pointer comparison then compute it the hard way. 6030 return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); 6031 } 6032 6033 ScalarEvolution::ExitLimit 6034 ScalarEvolution::computeExitLimitFromICmp(const Loop *L, 6035 ICmpInst *ExitCond, 6036 BasicBlock *TBB, 6037 BasicBlock *FBB, 6038 bool ControlsExit, 6039 bool AllowPredicates) { 6040 6041 // If the condition was exit on true, convert the condition to exit on false 6042 ICmpInst::Predicate Cond; 6043 if (!L->contains(FBB)) 6044 Cond = ExitCond->getPredicate(); 6045 else 6046 Cond = ExitCond->getInversePredicate(); 6047 6048 // Handle common loops like: for (X = "string"; *X; ++X) 6049 if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0))) 6050 if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) { 6051 ExitLimit ItCnt = 6052 computeLoadConstantCompareExitLimit(LI, RHS, L, Cond); 6053 if (ItCnt.hasAnyInfo()) 6054 return ItCnt; 6055 } 6056 6057 const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); 6058 const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); 6059 6060 // Try to evaluate any dependencies out of the loop. 6061 LHS = getSCEVAtScope(LHS, L); 6062 RHS = getSCEVAtScope(RHS, L); 6063 6064 // At this point, we would like to compute how many iterations of the 6065 // loop the predicate will return true for these inputs. 6066 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { 6067 // If there is a loop-invariant, force it into the RHS. 6068 std::swap(LHS, RHS); 6069 Cond = ICmpInst::getSwappedPredicate(Cond); 6070 } 6071 6072 // Simplify the operands before analyzing them. 6073 (void)SimplifyICmpOperands(Cond, LHS, RHS); 6074 6075 // If we have a comparison of a chrec against a constant, try to use value 6076 // ranges to answer this query. 6077 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) 6078 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS)) 6079 if (AddRec->getLoop() == L) { 6080 // Form the constant range. 
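        // For example, if the loop continues while {0,+,4} slt 100, the
        // "continue" region is [INT_MIN, 100) and the AddRec itself can
        // report on which iteration it first leaves that region.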
6081 ConstantRange CompRange = 6082 ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt()); 6083 6084 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); 6085 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; 6086 } 6087 6088 switch (Cond) { 6089 case ICmpInst::ICMP_NE: { // while (X != Y) 6090 // Convert to: while (X-Y != 0) 6091 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, 6092 AllowPredicates); 6093 if (EL.hasAnyInfo()) return EL; 6094 break; 6095 } 6096 case ICmpInst::ICMP_EQ: { // while (X == Y) 6097 // Convert to: while (X-Y == 0) 6098 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); 6099 if (EL.hasAnyInfo()) return EL; 6100 break; 6101 } 6102 case ICmpInst::ICMP_SLT: 6103 case ICmpInst::ICMP_ULT: { // while (X < Y) 6104 bool IsSigned = Cond == ICmpInst::ICMP_SLT; 6105 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, 6106 AllowPredicates); 6107 if (EL.hasAnyInfo()) return EL; 6108 break; 6109 } 6110 case ICmpInst::ICMP_SGT: 6111 case ICmpInst::ICMP_UGT: { // while (X > Y) 6112 bool IsSigned = Cond == ICmpInst::ICMP_SGT; 6113 ExitLimit EL = 6114 howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, 6115 AllowPredicates); 6116 if (EL.hasAnyInfo()) return EL; 6117 break; 6118 } 6119 default: 6120 break; 6121 } 6122 6123 auto *ExhaustiveCount = 6124 computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); 6125 6126 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount)) 6127 return ExhaustiveCount; 6128 6129 return computeShiftCompareExitLimit(ExitCond->getOperand(0), 6130 ExitCond->getOperand(1), L, Cond); 6131 } 6132 6133 ScalarEvolution::ExitLimit 6134 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, 6135 SwitchInst *Switch, 6136 BasicBlock *ExitingBlock, 6137 bool ControlsExit) { 6138 assert(!L->contains(ExitingBlock) && "Not an exiting block!"); 6139 6140 // Give up if the exit is the default dest of a switch. 6141 if (Switch->getDefaultDest() == ExitingBlock) 6142 return getCouldNotCompute(); 6143 6144 assert(L->contains(Switch->getDefaultDest()) && 6145 "Default case must not exit the loop!"); 6146 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); 6147 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); 6148 6149 // while (X != Y) --> while (X-Y != 0) 6150 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); 6151 if (EL.hasAnyInfo()) 6152 return EL; 6153 6154 return getCouldNotCompute(); 6155 } 6156 6157 static ConstantInt * 6158 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, 6159 ScalarEvolution &SE) { 6160 const SCEV *InVal = SE.getConstant(C); 6161 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); 6162 assert(isa<SCEVConstant>(Val) && 6163 "Evaluation of SCEV at constant didn't fold correctly?"); 6164 return cast<SCEVConstant>(Val)->getValue(); 6165 } 6166 6167 /// Given an exit condition of 'icmp op load X, cst', try to see if we can 6168 /// compute the backedge execution count. 6169 ScalarEvolution::ExitLimit 6170 ScalarEvolution::computeLoadConstantCompareExitLimit( 6171 LoadInst *LI, 6172 Constant *RHS, 6173 const Loop *L, 6174 ICmpInst::Predicate predicate) { 6175 6176 if (LI->isVolatile()) return getCouldNotCompute(); 6177 6178 // Check to see if the loaded pointer is a getelementptr of a global. 6179 // TODO: Use SCEV instead of manually grubbing with GEPs. 
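  // The approach below is brute force: we require a load of the form
  // (load (gep @constant_global, 0, ..., X, ...)) with exactly one
  // non-constant index X, then fold the load and the comparison for each
  // candidate iteration (up to MaxBruteForceIterations) until the
  // comparison first evaluates to false, which is the iteration on which
  // the loop exits.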
6180 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)); 6181 if (!GEP) return getCouldNotCompute(); 6182 6183 // Make sure that it is really a constant global we are gepping, with an 6184 // initializer, and make sure the first IDX is really 0. 6185 GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)); 6186 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || 6187 GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) || 6188 !cast<Constant>(GEP->getOperand(1))->isNullValue()) 6189 return getCouldNotCompute(); 6190 6191 // Okay, we allow one non-constant index into the GEP instruction. 6192 Value *VarIdx = nullptr; 6193 std::vector<Constant*> Indexes; 6194 unsigned VarIdxNum = 0; 6195 for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i) 6196 if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) { 6197 Indexes.push_back(CI); 6198 } else if (!isa<ConstantInt>(GEP->getOperand(i))) { 6199 if (VarIdx) return getCouldNotCompute(); // Multiple non-constant idx's. 6200 VarIdx = GEP->getOperand(i); 6201 VarIdxNum = i-2; 6202 Indexes.push_back(nullptr); 6203 } 6204 6205 // Loop-invariant loads may be a byproduct of loop optimization. Skip them. 6206 if (!VarIdx) 6207 return getCouldNotCompute(); 6208 6209 // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant. 6210 // Check to see if X is a loop variant variable value now. 6211 const SCEV *Idx = getSCEV(VarIdx); 6212 Idx = getSCEVAtScope(Idx, L); 6213 6214 // We can only recognize very limited forms of loop index expressions, in 6215 // particular, only affine AddRec's like {C1,+,C2}. 6216 const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx); 6217 if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) || 6218 !isa<SCEVConstant>(IdxExpr->getOperand(0)) || 6219 !isa<SCEVConstant>(IdxExpr->getOperand(1))) 6220 return getCouldNotCompute(); 6221 6222 unsigned MaxSteps = MaxBruteForceIterations; 6223 for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { 6224 ConstantInt *ItCst = ConstantInt::get( 6225 cast<IntegerType>(IdxExpr->getType()), IterationNum); 6226 ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); 6227 6228 // Form the GEP offset. 6229 Indexes[VarIdxNum] = Val; 6230 6231 Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(), 6232 Indexes); 6233 if (!Result) break; // Cannot compute! 6234 6235 // Evaluate the condition for this iteration. 6236 Result = ConstantExpr::getICmp(predicate, Result, RHS); 6237 if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure 6238 if (cast<ConstantInt>(Result)->getValue().isMinValue()) { 6239 ++NumArrayLenItCounts; 6240 return getConstant(ItCst); // Found terminating iteration! 6241 } 6242 } 6243 return getCouldNotCompute(); 6244 } 6245 6246 ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( 6247 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) { 6248 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV); 6249 if (!RHS) 6250 return getCouldNotCompute(); 6251 6252 const BasicBlock *Latch = L->getLoopLatch(); 6253 if (!Latch) 6254 return getCouldNotCompute(); 6255 6256 const BasicBlock *Predecessor = L->getLoopPredecessor(); 6257 if (!Predecessor) 6258 return getCouldNotCompute(); 6259 6260 // Return true if V is of the form "LHS `shift_op` <positive constant>". 6261 // Return LHS in OutLHS and shift_opt in OutOpCode. 
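  // For example, "%s = lshr i32 %a, 3" matches with OutLHS = %a and
  // OutOpCode = Instruction::LShr, whereas a shift by zero (or by a
  // non-constant amount) does not match.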
6262 auto MatchPositiveShift = 6263 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) { 6264 6265 using namespace PatternMatch; 6266 6267 ConstantInt *ShiftAmt; 6268 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) 6269 OutOpCode = Instruction::LShr; 6270 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) 6271 OutOpCode = Instruction::AShr; 6272 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) 6273 OutOpCode = Instruction::Shl; 6274 else 6275 return false; 6276 6277 return ShiftAmt->getValue().isStrictlyPositive(); 6278 }; 6279 6280 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in 6281 // 6282 // loop: 6283 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ] 6284 // %iv.shifted = lshr i32 %iv, <positive constant> 6285 // 6286 // Return true on a successful match. Return the corresponding PHI node (%iv 6287 // above) in PNOut and the opcode of the shift operation in OpCodeOut. 6288 auto MatchShiftRecurrence = 6289 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) { 6290 Optional<Instruction::BinaryOps> PostShiftOpCode; 6291 6292 { 6293 Instruction::BinaryOps OpC; 6294 Value *V; 6295 6296 // If we encounter a shift instruction, "peel off" the shift operation, 6297 // and remember that we did so. Later when we inspect %iv's backedge 6298 // value, we will make sure that the backedge value uses the same 6299 // operation. 6300 // 6301 // Note: the peeled shift operation does not have to be the same 6302 // instruction as the one feeding into the PHI's backedge value. We only 6303 // really care about it being the same *kind* of shift instruction -- 6304 // that's all that is required for our later inferences to hold. 6305 if (MatchPositiveShift(LHS, V, OpC)) { 6306 PostShiftOpCode = OpC; 6307 LHS = V; 6308 } 6309 } 6310 6311 PNOut = dyn_cast<PHINode>(LHS); 6312 if (!PNOut || PNOut->getParent() != L->getHeader()) 6313 return false; 6314 6315 Value *BEValue = PNOut->getIncomingValueForBlock(Latch); 6316 Value *OpLHS; 6317 6318 return 6319 // The backedge value for the PHI node must be a shift by a positive 6320 // amount 6321 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) && 6322 6323 // of the PHI node itself 6324 OpLHS == PNOut && 6325 6326 // and the kind of shift should be match the kind of shift we peeled 6327 // off, if any. 6328 (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut); 6329 }; 6330 6331 PHINode *PN; 6332 Instruction::BinaryOps OpCode; 6333 if (!MatchShiftRecurrence(LHS, PN, OpCode)) 6334 return getCouldNotCompute(); 6335 6336 const DataLayout &DL = getDataLayout(); 6337 6338 // The key rationale for this optimization is that for some kinds of shift 6339 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1 6340 // within a finite number of iterations. If the condition guarding the 6341 // backedge (in the sense that the backedge is taken if the condition is true) 6342 // is false for the value the shift recurrence stabilizes to, then we know 6343 // that the backedge is taken only a finite number of times. 6344 6345 ConstantInt *StableValue = nullptr; 6346 switch (OpCode) { 6347 default: 6348 llvm_unreachable("Impossible case!"); 6349 6350 case Instruction::AShr: { 6351 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most 6352 // bitwidth(K) iterations. 
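    // It therefore suffices to look at the sign of the value entering the
    // loop: a known non-negative start stabilizes to 0, a known negative
    // start stabilizes to -1, and if the sign is unknown we give up.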
6353 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); 6354 bool KnownZero, KnownOne; 6355 ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr, 6356 Predecessor->getTerminator(), &DT); 6357 auto *Ty = cast<IntegerType>(RHS->getType()); 6358 if (KnownZero) 6359 StableValue = ConstantInt::get(Ty, 0); 6360 else if (KnownOne) 6361 StableValue = ConstantInt::get(Ty, -1, true); 6362 else 6363 return getCouldNotCompute(); 6364 6365 break; 6366 } 6367 case Instruction::LShr: 6368 case Instruction::Shl: 6369 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>} 6370 // stabilize to 0 in at most bitwidth(K) iterations. 6371 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0); 6372 break; 6373 } 6374 6375 auto *Result = 6376 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI); 6377 assert(Result->getType()->isIntegerTy(1) && 6378 "Otherwise cannot be an operand to a branch instruction"); 6379 6380 if (Result->isZeroValue()) { 6381 unsigned BitWidth = getTypeSizeInBits(RHS->getType()); 6382 const SCEV *UpperBound = 6383 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); 6384 return ExitLimit(getCouldNotCompute(), UpperBound, false); 6385 } 6386 6387 return getCouldNotCompute(); 6388 } 6389 6390 /// Return true if we can constant fold an instruction of the specified type, 6391 /// assuming that all operands were constants. 6392 static bool CanConstantFold(const Instruction *I) { 6393 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || 6394 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) || 6395 isa<LoadInst>(I)) 6396 return true; 6397 6398 if (const CallInst *CI = dyn_cast<CallInst>(I)) 6399 if (const Function *F = CI->getCalledFunction()) 6400 return canConstantFoldCallTo(F); 6401 return false; 6402 } 6403 6404 /// Determine whether this instruction can constant evolve within this loop 6405 /// assuming its operands can all constant evolve. 6406 static bool canConstantEvolve(Instruction *I, const Loop *L) { 6407 // An instruction outside of the loop can't be derived from a loop PHI. 6408 if (!L->contains(I)) return false; 6409 6410 if (isa<PHINode>(I)) { 6411 // We don't currently keep track of the control flow needed to evaluate 6412 // PHIs, so we cannot handle PHIs inside of loops. 6413 return L->getHeader() == I->getParent(); 6414 } 6415 6416 // If we won't be able to constant fold this expression even if the operands 6417 // are constants, bail early. 6418 return CanConstantFold(I); 6419 } 6420 6421 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by 6422 /// recursing through each instruction operand until reaching a loop header phi. 6423 static PHINode * 6424 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, 6425 DenseMap<Instruction *, PHINode *> &PHIMap, 6426 unsigned Depth) { 6427 if (Depth > MaxConstantEvolvingDepth) 6428 return nullptr; 6429 6430 // Otherwise, we can evaluate this instruction if all of its operands are 6431 // constant or derived from a PHI node themselves. 6432 PHINode *PHI = nullptr; 6433 for (Value *Op : UseInst->operands()) { 6434 if (isa<Constant>(Op)) continue; 6435 6436 Instruction *OpInst = dyn_cast<Instruction>(Op); 6437 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr; 6438 6439 PHINode *P = dyn_cast<PHINode>(OpInst); 6440 if (!P) 6441 // If this operand is already visited, reuse the prior result. 6442 // We may have P != PHI if this is the deepest point at which the 6443 // inconsistent paths meet. 
6444 P = PHIMap.lookup(OpInst); 6445 if (!P) { 6446 // Recurse and memoize the results, whether a phi is found or not. 6447 // This recursive call invalidates pointers into PHIMap. 6448 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1); 6449 PHIMap[OpInst] = P; 6450 } 6451 if (!P) 6452 return nullptr; // Not evolving from PHI 6453 if (PHI && PHI != P) 6454 return nullptr; // Evolving from multiple different PHIs. 6455 PHI = P; 6456 } 6457 // This is a expression evolving from a constant PHI! 6458 return PHI; 6459 } 6460 6461 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node 6462 /// in the loop that V is derived from. We allow arbitrary operations along the 6463 /// way, but the operands of an operation must either be constants or a value 6464 /// derived from a constant PHI. If this expression does not fit with these 6465 /// constraints, return null. 6466 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { 6467 Instruction *I = dyn_cast<Instruction>(V); 6468 if (!I || !canConstantEvolve(I, L)) return nullptr; 6469 6470 if (PHINode *PN = dyn_cast<PHINode>(I)) 6471 return PN; 6472 6473 // Record non-constant instructions contained by the loop. 6474 DenseMap<Instruction *, PHINode *> PHIMap; 6475 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0); 6476 } 6477 6478 /// EvaluateExpression - Given an expression that passes the 6479 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node 6480 /// in the loop has the value PHIVal. If we can't fold this expression for some 6481 /// reason, return null. 6482 static Constant *EvaluateExpression(Value *V, const Loop *L, 6483 DenseMap<Instruction *, Constant *> &Vals, 6484 const DataLayout &DL, 6485 const TargetLibraryInfo *TLI) { 6486 // Convenient constant check, but redundant for recursive calls. 6487 if (Constant *C = dyn_cast<Constant>(V)) return C; 6488 Instruction *I = dyn_cast<Instruction>(V); 6489 if (!I) return nullptr; 6490 6491 if (Constant *C = Vals.lookup(I)) return C; 6492 6493 // An instruction inside the loop depends on a value outside the loop that we 6494 // weren't given a mapping for, or a value such as a call inside the loop. 6495 if (!canConstantEvolve(I, L)) return nullptr; 6496 6497 // An unmapped PHI can be due to a branch or another loop inside this loop, 6498 // or due to this not being the initial iteration through a loop where we 6499 // couldn't compute the evolution of this particular PHI last time. 6500 if (isa<PHINode>(I)) return nullptr; 6501 6502 std::vector<Constant*> Operands(I->getNumOperands()); 6503 6504 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { 6505 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i)); 6506 if (!Operand) { 6507 Operands[i] = dyn_cast<Constant>(I->getOperand(i)); 6508 if (!Operands[i]) return nullptr; 6509 continue; 6510 } 6511 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI); 6512 Vals[Operand] = C; 6513 if (!C) return nullptr; 6514 Operands[i] = C; 6515 } 6516 6517 if (CmpInst *CI = dyn_cast<CmpInst>(I)) 6518 return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], 6519 Operands[1], DL, TLI); 6520 if (LoadInst *LI = dyn_cast<LoadInst>(I)) { 6521 if (!LI->isVolatile()) 6522 return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); 6523 } 6524 return ConstantFoldInstOperands(I, Operands, DL, TLI); 6525 } 6526 6527 6528 // If every incoming value to PN except the one for BB is a specific Constant, 6529 // return that, else return nullptr. 
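// For a header PHI such as "%p = phi i32 [ 1, %preheader ], [ %p.next, %latch ]"
// with BB == %latch, this returns the i32 constant 1, i.e. the value %p
// starts with.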
6530 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) { 6531 Constant *IncomingVal = nullptr; 6532 6533 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { 6534 if (PN->getIncomingBlock(i) == BB) 6535 continue; 6536 6537 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i)); 6538 if (!CurrentVal) 6539 return nullptr; 6540 6541 if (IncomingVal != CurrentVal) { 6542 if (IncomingVal) 6543 return nullptr; 6544 IncomingVal = CurrentVal; 6545 } 6546 } 6547 6548 return IncomingVal; 6549 } 6550 6551 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is 6552 /// in the header of its containing loop, we know the loop executes a 6553 /// constant number of times, and the PHI node is just a recurrence 6554 /// involving constants, fold it. 6555 Constant * 6556 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, 6557 const APInt &BEs, 6558 const Loop *L) { 6559 auto I = ConstantEvolutionLoopExitValue.find(PN); 6560 if (I != ConstantEvolutionLoopExitValue.end()) 6561 return I->second; 6562 6563 if (BEs.ugt(MaxBruteForceIterations)) 6564 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it. 6565 6566 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; 6567 6568 DenseMap<Instruction *, Constant *> CurrentIterVals; 6569 BasicBlock *Header = L->getHeader(); 6570 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); 6571 6572 BasicBlock *Latch = L->getLoopLatch(); 6573 if (!Latch) 6574 return nullptr; 6575 6576 for (auto &I : *Header) { 6577 PHINode *PHI = dyn_cast<PHINode>(&I); 6578 if (!PHI) break; 6579 auto *StartCST = getOtherIncomingValue(PHI, Latch); 6580 if (!StartCST) continue; 6581 CurrentIterVals[PHI] = StartCST; 6582 } 6583 if (!CurrentIterVals.count(PN)) 6584 return RetVal = nullptr; 6585 6586 Value *BEValue = PN->getIncomingValueForBlock(Latch); 6587 6588 // Execute the loop symbolically to determine the exit value. 6589 if (BEs.getActiveBits() >= 32) 6590 return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it! 6591 6592 unsigned NumIterations = BEs.getZExtValue(); // must be in range 6593 unsigned IterationNum = 0; 6594 const DataLayout &DL = getDataLayout(); 6595 for (; ; ++IterationNum) { 6596 if (IterationNum == NumIterations) 6597 return RetVal = CurrentIterVals[PN]; // Got exit value! 6598 6599 // Compute the value of the PHIs for the next iteration. 6600 // EvaluateExpression adds non-phi values to the CurrentIterVals map. 6601 DenseMap<Instruction *, Constant *> NextIterVals; 6602 Constant *NextPHI = 6603 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); 6604 if (!NextPHI) 6605 return nullptr; // Couldn't evaluate! 6606 NextIterVals[PN] = NextPHI; 6607 6608 bool StoppedEvolving = NextPHI == CurrentIterVals[PN]; 6609 6610 // Also evaluate the other PHI nodes. However, we don't get to stop if we 6611 // cease to be able to evaluate one of them or if they stop evolving, 6612 // because that doesn't necessarily prevent us from computing PN. 6613 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute; 6614 for (const auto &I : CurrentIterVals) { 6615 PHINode *PHI = dyn_cast<PHINode>(I.first); 6616 if (!PHI || PHI == PN || PHI->getParent() != Header) continue; 6617 PHIsToCompute.emplace_back(PHI, I.second); 6618 } 6619 // We use two distinct loops because EvaluateExpression may invalidate any 6620 // iterators into CurrentIterVals. 
6621 for (const auto &I : PHIsToCompute) { 6622 PHINode *PHI = I.first; 6623 Constant *&NextPHI = NextIterVals[PHI]; 6624 if (!NextPHI) { // Not already computed. 6625 Value *BEValue = PHI->getIncomingValueForBlock(Latch); 6626 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); 6627 } 6628 if (NextPHI != I.second) 6629 StoppedEvolving = false; 6630 } 6631 6632 // If all entries in CurrentIterVals == NextIterVals then we can stop 6633 // iterating, the loop can't continue to change. 6634 if (StoppedEvolving) 6635 return RetVal = CurrentIterVals[PN]; 6636 6637 CurrentIterVals.swap(NextIterVals); 6638 } 6639 } 6640 6641 const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, 6642 Value *Cond, 6643 bool ExitWhen) { 6644 PHINode *PN = getConstantEvolvingPHI(Cond, L); 6645 if (!PN) return getCouldNotCompute(); 6646 6647 // If the loop is canonicalized, the PHI will have exactly two entries. 6648 // That's the only form we support here. 6649 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute(); 6650 6651 DenseMap<Instruction *, Constant *> CurrentIterVals; 6652 BasicBlock *Header = L->getHeader(); 6653 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); 6654 6655 BasicBlock *Latch = L->getLoopLatch(); 6656 assert(Latch && "Should follow from NumIncomingValues == 2!"); 6657 6658 for (auto &I : *Header) { 6659 PHINode *PHI = dyn_cast<PHINode>(&I); 6660 if (!PHI) 6661 break; 6662 auto *StartCST = getOtherIncomingValue(PHI, Latch); 6663 if (!StartCST) continue; 6664 CurrentIterVals[PHI] = StartCST; 6665 } 6666 if (!CurrentIterVals.count(PN)) 6667 return getCouldNotCompute(); 6668 6669 // Okay, we find a PHI node that defines the trip count of this loop. Execute 6670 // the loop symbolically to determine when the condition gets a value of 6671 // "ExitWhen". 6672 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. 6673 const DataLayout &DL = getDataLayout(); 6674 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ 6675 auto *CondVal = dyn_cast_or_null<ConstantInt>( 6676 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI)); 6677 6678 // Couldn't symbolically evaluate. 6679 if (!CondVal) return getCouldNotCompute(); 6680 6681 if (CondVal->getValue() == uint64_t(ExitWhen)) { 6682 ++NumBruteForceTripCountsComputed; 6683 return getConstant(Type::getInt32Ty(getContext()), IterationNum); 6684 } 6685 6686 // Update all the PHI nodes for the next iteration. 6687 DenseMap<Instruction *, Constant *> NextIterVals; 6688 6689 // Create a list of which PHIs we need to compute. We want to do this before 6690 // calling EvaluateExpression on them because that may invalidate iterators 6691 // into CurrentIterVals. 6692 SmallVector<PHINode *, 8> PHIsToCompute; 6693 for (const auto &I : CurrentIterVals) { 6694 PHINode *PHI = dyn_cast<PHINode>(I.first); 6695 if (!PHI || PHI->getParent() != Header) continue; 6696 PHIsToCompute.push_back(PHI); 6697 } 6698 for (PHINode *PHI : PHIsToCompute) { 6699 Constant *&NextPHI = NextIterVals[PHI]; 6700 if (NextPHI) continue; // Already computed! 6701 6702 Value *BEValue = PHI->getIncomingValueForBlock(Latch); 6703 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); 6704 } 6705 CurrentIterVals.swap(NextIterVals); 6706 } 6707 6708 // Too many iterations were needed to evaluate. 
6709 return getCouldNotCompute(); 6710 } 6711 6712 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { 6713 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = 6714 ValuesAtScopes[V]; 6715 // Check to see if we've folded this expression at this loop before. 6716 for (auto &LS : Values) 6717 if (LS.first == L) 6718 return LS.second ? LS.second : V; 6719 6720 Values.emplace_back(L, nullptr); 6721 6722 // Otherwise compute it. 6723 const SCEV *C = computeSCEVAtScope(V, L); 6724 for (auto &LS : reverse(ValuesAtScopes[V])) 6725 if (LS.first == L) { 6726 LS.second = C; 6727 break; 6728 } 6729 return C; 6730 } 6731 6732 /// This builds up a Constant using the ConstantExpr interface. That way, we 6733 /// will return Constants for objects which aren't represented by a 6734 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt. 6735 /// Returns NULL if the SCEV isn't representable as a Constant. 6736 static Constant *BuildConstantFromSCEV(const SCEV *V) { 6737 switch (static_cast<SCEVTypes>(V->getSCEVType())) { 6738 case scCouldNotCompute: 6739 case scAddRecExpr: 6740 break; 6741 case scConstant: 6742 return cast<SCEVConstant>(V)->getValue(); 6743 case scUnknown: 6744 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue()); 6745 case scSignExtend: { 6746 const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V); 6747 if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) 6748 return ConstantExpr::getSExt(CastOp, SS->getType()); 6749 break; 6750 } 6751 case scZeroExtend: { 6752 const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V); 6753 if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) 6754 return ConstantExpr::getZExt(CastOp, SZ->getType()); 6755 break; 6756 } 6757 case scTruncate: { 6758 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V); 6759 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) 6760 return ConstantExpr::getTrunc(CastOp, ST->getType()); 6761 break; 6762 } 6763 case scAddExpr: { 6764 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V); 6765 if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { 6766 if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { 6767 unsigned AS = PTy->getAddressSpace(); 6768 Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); 6769 C = ConstantExpr::getBitCast(C, DestPtrTy); 6770 } 6771 for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { 6772 Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); 6773 if (!C2) return nullptr; 6774 6775 // First pointer! 6776 if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { 6777 unsigned AS = C2->getType()->getPointerAddressSpace(); 6778 std::swap(C, C2); 6779 Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); 6780 // The offsets have been converted to bytes. We can add bytes to an 6781 // i8* by GEP with the byte count in the first index. 6782 C = ConstantExpr::getBitCast(C, DestPtrTy); 6783 } 6784 6785 // Don't bother trying to sum two pointers. We probably can't 6786 // statically compute a load that results from it anyway. 
6787 if (C2->getType()->isPointerTy()) 6788 return nullptr; 6789 6790 if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { 6791 if (PTy->getElementType()->isStructTy()) 6792 C2 = ConstantExpr::getIntegerCast( 6793 C2, Type::getInt32Ty(C->getContext()), true); 6794 C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); 6795 } else 6796 C = ConstantExpr::getAdd(C, C2); 6797 } 6798 return C; 6799 } 6800 break; 6801 } 6802 case scMulExpr: { 6803 const SCEVMulExpr *SM = cast<SCEVMulExpr>(V); 6804 if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { 6805 // Don't bother with pointers at all. 6806 if (C->getType()->isPointerTy()) return nullptr; 6807 for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { 6808 Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); 6809 if (!C2 || C2->getType()->isPointerTy()) return nullptr; 6810 C = ConstantExpr::getMul(C, C2); 6811 } 6812 return C; 6813 } 6814 break; 6815 } 6816 case scUDivExpr: { 6817 const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V); 6818 if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) 6819 if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) 6820 if (LHS->getType() == RHS->getType()) 6821 return ConstantExpr::getUDiv(LHS, RHS); 6822 break; 6823 } 6824 case scSMaxExpr: 6825 case scUMaxExpr: 6826 break; // TODO: smax, umax. 6827 } 6828 return nullptr; 6829 } 6830 6831 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { 6832 if (isa<SCEVConstant>(V)) return V; 6833 6834 // If this instruction is evolved from a constant-evolving PHI, compute the 6835 // exit value from the loop without using SCEVs. 6836 if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) { 6837 if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) { 6838 const Loop *LI = this->LI[I->getParent()]; 6839 if (LI && LI->getParentLoop() == L) // Looking for loop exit value. 6840 if (PHINode *PN = dyn_cast<PHINode>(I)) 6841 if (PN->getParent() == LI->getHeader()) { 6842 // Okay, there is no closed form solution for the PHI node. Check 6843 // to see if the loop that contains it has a known backedge-taken 6844 // count. If so, we may be able to force computation of the exit 6845 // value. 6846 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI); 6847 if (const SCEVConstant *BTCC = 6848 dyn_cast<SCEVConstant>(BackedgeTakenCount)) { 6849 // Okay, we know how many times the containing loop executes. If 6850 // this is a constant evolving PHI node, get the final value at 6851 // the specified iteration number. 6852 Constant *RV = 6853 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI); 6854 if (RV) return getSCEV(RV); 6855 } 6856 } 6857 6858 // Okay, this is an expression that we cannot symbolically evaluate 6859 // into a SCEV. Check to see if it's possible to symbolically evaluate 6860 // the arguments into constants, and if so, try to constant propagate the 6861 // result. This is particularly useful for computing loop exit values. 6862 if (CanConstantFold(I)) { 6863 SmallVector<Constant *, 4> Operands; 6864 bool MadeImprovement = false; 6865 for (Value *Op : I->operands()) { 6866 if (Constant *C = dyn_cast<Constant>(Op)) { 6867 Operands.push_back(C); 6868 continue; 6869 } 6870 6871 // If any of the operands is non-constant and if they are 6872 // non-integer and non-pointer, don't even try to analyze them 6873 // with scev techniques. 
6874 if (!isSCEVable(Op->getType())) 6875 return V; 6876 6877 const SCEV *OrigV = getSCEV(Op); 6878 const SCEV *OpV = getSCEVAtScope(OrigV, L); 6879 MadeImprovement |= OrigV != OpV; 6880 6881 Constant *C = BuildConstantFromSCEV(OpV); 6882 if (!C) return V; 6883 if (C->getType() != Op->getType()) 6884 C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, 6885 Op->getType(), 6886 false), 6887 C, Op->getType()); 6888 Operands.push_back(C); 6889 } 6890 6891 // Check to see if getSCEVAtScope actually made an improvement. 6892 if (MadeImprovement) { 6893 Constant *C = nullptr; 6894 const DataLayout &DL = getDataLayout(); 6895 if (const CmpInst *CI = dyn_cast<CmpInst>(I)) 6896 C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], 6897 Operands[1], DL, &TLI); 6898 else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { 6899 if (!LI->isVolatile()) 6900 C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); 6901 } else 6902 C = ConstantFoldInstOperands(I, Operands, DL, &TLI); 6903 if (!C) return V; 6904 return getSCEV(C); 6905 } 6906 } 6907 } 6908 6909 // This is some other type of SCEVUnknown, just return it. 6910 return V; 6911 } 6912 6913 if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) { 6914 // Avoid performing the look-up in the common case where the specified 6915 // expression has no loop-variant portions. 6916 for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { 6917 const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); 6918 if (OpAtScope != Comm->getOperand(i)) { 6919 // Okay, at least one of these operands is loop variant but might be 6920 // foldable. Build a new instance of the folded commutative expression. 6921 SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(), 6922 Comm->op_begin()+i); 6923 NewOps.push_back(OpAtScope); 6924 6925 for (++i; i != e; ++i) { 6926 OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); 6927 NewOps.push_back(OpAtScope); 6928 } 6929 if (isa<SCEVAddExpr>(Comm)) 6930 return getAddExpr(NewOps); 6931 if (isa<SCEVMulExpr>(Comm)) 6932 return getMulExpr(NewOps); 6933 if (isa<SCEVSMaxExpr>(Comm)) 6934 return getSMaxExpr(NewOps); 6935 if (isa<SCEVUMaxExpr>(Comm)) 6936 return getUMaxExpr(NewOps); 6937 llvm_unreachable("Unknown commutative SCEV type!"); 6938 } 6939 } 6940 // If we got here, all operands are loop invariant. 6941 return Comm; 6942 } 6943 6944 if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) { 6945 const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); 6946 const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); 6947 if (LHS == Div->getLHS() && RHS == Div->getRHS()) 6948 return Div; // must be loop invariant 6949 return getUDivExpr(LHS, RHS); 6950 } 6951 6952 // If this is a loop recurrence for a loop that does not contain L, then we 6953 // are dealing with the final value computed by the loop. 6954 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) { 6955 // First, attempt to evaluate each operand. 6956 // Avoid performing the look-up in the common case where the specified 6957 // expression has no loop-variant portions. 6958 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { 6959 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); 6960 if (OpAtScope == AddRec->getOperand(i)) 6961 continue; 6962 6963 // Okay, at least one of these operands is loop variant but might be 6964 // foldable. Build a new instance of the folded commutative expression. 
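      // (Here the "commutative expression" is in fact an add recurrence; it
      // is rebuilt below via getAddRecExpr over the same loop, with its
      // no-wrap flags masked down to at most FlagNW.)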
6965 SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(), 6966 AddRec->op_begin()+i); 6967 NewOps.push_back(OpAtScope); 6968 for (++i; i != e; ++i) 6969 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); 6970 6971 const SCEV *FoldedRec = 6972 getAddRecExpr(NewOps, AddRec->getLoop(), 6973 AddRec->getNoWrapFlags(SCEV::FlagNW)); 6974 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec); 6975 // The addrec may be folded to a nonrecurrence, for example, if the 6976 // induction variable is multiplied by zero after constant folding. Go 6977 // ahead and return the folded value. 6978 if (!AddRec) 6979 return FoldedRec; 6980 break; 6981 } 6982 6983 // If the scope is outside the addrec's loop, evaluate it by using the 6984 // loop exit value of the addrec. 6985 if (!AddRec->getLoop()->contains(L)) { 6986 // To evaluate this recurrence, we need to know how many times the AddRec 6987 // loop iterates. Compute this now. 6988 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); 6989 if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; 6990 6991 // Then, evaluate the AddRec. 6992 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); 6993 } 6994 6995 return AddRec; 6996 } 6997 6998 if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) { 6999 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); 7000 if (Op == Cast->getOperand()) 7001 return Cast; // must be loop invariant 7002 return getZeroExtendExpr(Op, Cast->getType()); 7003 } 7004 7005 if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) { 7006 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); 7007 if (Op == Cast->getOperand()) 7008 return Cast; // must be loop invariant 7009 return getSignExtendExpr(Op, Cast->getType()); 7010 } 7011 7012 if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) { 7013 const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); 7014 if (Op == Cast->getOperand()) 7015 return Cast; // must be loop invariant 7016 return getTruncateExpr(Op, Cast->getType()); 7017 } 7018 7019 llvm_unreachable("Unknown SCEV type!"); 7020 } 7021 7022 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { 7023 return getSCEVAtScope(getSCEV(V), L); 7024 } 7025 7026 /// Finds the minimum unsigned root of the following equation: 7027 /// 7028 /// A * X = B (mod N) 7029 /// 7030 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of 7031 /// A and B isn't important. 7032 /// 7033 /// If the equation does not have a solution, SCEVCouldNotCompute is returned. 7034 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, 7035 ScalarEvolution &SE) { 7036 uint32_t BW = A.getBitWidth(); 7037 assert(BW == B.getBitWidth() && "Bit widths must be the same."); 7038 assert(A != 0 && "A must be non-zero."); 7039 7040 // 1. D = gcd(A, N) 7041 // 7042 // The gcd of A and N may have only one prime factor: 2. The number of 7043 // trailing zeros in A is its multiplicity 7044 uint32_t Mult2 = A.countTrailingZeros(); 7045 // D = 2^Mult2 7046 7047 // 2. Check if B is divisible by D. 7048 // 7049 // B is divisible by D if and only if the multiplicity of prime factor 2 for B 7050 // is not less than multiplicity of this prime factor for D. 7051 if (B.countTrailingZeros() < Mult2) 7052 return SE.getCouldNotCompute(); 7053 7054 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic 7055 // modulo (N / D). 
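  //
  //    For example, to solve 6 * X = 10 (mod 16): D = 2, A / D = 3, and the
  //    inverse of 3 modulo 8 is 3 (since 3 * 3 = 9 = 1 (mod 8)); step 4
  //    below then yields X = (3 * 10 mod 16) / 2 = 14 / 2 = 7, and indeed
  //    6 * 7 = 42 = 10 (mod 16).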
7056 // 7057 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent 7058 // (N / D) in general. The inverse itself always fits into BW bits, though, 7059 // so we immediately truncate it. 7060 APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D 7061 APInt Mod(BW + 1, 0); 7062 Mod.setBit(BW - Mult2); // Mod = N / D 7063 APInt I = AD.multiplicativeInverse(Mod).trunc(BW); 7064 7065 // 4. Compute the minimum unsigned root of the equation: 7066 // I * (B / D) mod (N / D) 7067 // To simplify the computation, we factor out the divide by D: 7068 // (I * B mod N) / D 7069 APInt Result = (I * B).lshr(Mult2); 7070 7071 return SE.getConstant(Result); 7072 } 7073 7074 /// Find the roots of the quadratic equation for the given quadratic chrec 7075 /// {L,+,M,+,N}. This returns either the two roots (which might be the same) or 7076 /// two SCEVCouldNotCompute objects. 7077 /// 7078 static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> 7079 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { 7080 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); 7081 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); 7082 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1)); 7083 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); 7084 7085 // We currently can only solve this if the coefficients are constants. 7086 if (!LC || !MC || !NC) 7087 return None; 7088 7089 uint32_t BitWidth = LC->getAPInt().getBitWidth(); 7090 const APInt &L = LC->getAPInt(); 7091 const APInt &M = MC->getAPInt(); 7092 const APInt &N = NC->getAPInt(); 7093 APInt Two(BitWidth, 2); 7094 APInt Four(BitWidth, 4); 7095 7096 { 7097 using namespace APIntOps; 7098 const APInt& C = L; 7099 // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C 7100 // The B coefficient is M-N/2 7101 APInt B(M); 7102 B -= sdiv(N,Two); 7103 7104 // The A coefficient is N/2 7105 APInt A(N.sdiv(Two)); 7106 7107 // Compute the B^2-4ac term. 7108 APInt SqrtTerm(B); 7109 SqrtTerm *= B; 7110 SqrtTerm -= Four * (A * C); 7111 7112 if (SqrtTerm.isNegative()) { 7113 // The loop is provably infinite. 7114 return None; 7115 } 7116 7117 // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest 7118 // integer value or else APInt::sqrt() will assert. 7119 APInt SqrtVal(SqrtTerm.sqrt()); 7120 7121 // Compute the two solutions for the quadratic formula. 7122 // The divisions must be performed as signed divisions. 7123 APInt NegB(-B); 7124 APInt TwoA(A << 1); 7125 if (TwoA.isMinValue()) 7126 return None; 7127 7128 LLVMContext &Context = SE.getContext(); 7129 7130 ConstantInt *Solution1 = 7131 ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); 7132 ConstantInt *Solution2 = 7133 ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); 7134 7135 return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), 7136 cast<SCEVConstant>(SE.getConstant(Solution2))); 7137 } // end APIntOps namespace 7138 } 7139 7140 ScalarEvolution::ExitLimit 7141 ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, 7142 bool AllowPredicates) { 7143 7144 // This is only used for loops with a "x != y" exit test. The exit condition 7145 // is now expressed as a single expression, V = x-y. So the exit test is 7146 // effectively V != 0. We know and take advantage of the fact that this 7147 // expression only being used in a comparison by zero context. 
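  //
  // For example, "for (i = start; i != limit; i += stride)" reaches this
  // point with V = {start-limit,+,stride}, and we are looking for the first
  // iteration on which that expression evaluates to zero.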
7148 7149 SmallPtrSet<const SCEVPredicate *, 4> Predicates; 7150 // If the value is a constant 7151 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { 7152 // If the value is already zero, the branch will execute zero times. 7153 if (C->getValue()->isZero()) return C; 7154 return getCouldNotCompute(); // Otherwise it will loop infinitely. 7155 } 7156 7157 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V); 7158 if (!AddRec && AllowPredicates) 7159 // Try to make this an AddRec using runtime tests, in the first X 7160 // iterations of this loop, where X is the SCEV expression found by the 7161 // algorithm below. 7162 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates); 7163 7164 if (!AddRec || AddRec->getLoop() != L) 7165 return getCouldNotCompute(); 7166 7167 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of 7168 // the quadratic equation to solve it. 7169 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { 7170 if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { 7171 const SCEVConstant *R1 = Roots->first; 7172 const SCEVConstant *R2 = Roots->second; 7173 // Pick the smallest positive root value. 7174 if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( 7175 CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { 7176 if (!CB->getZExtValue()) 7177 std::swap(R1, R2); // R1 is the minimum root now. 7178 7179 // We can only use this value if the chrec ends up with an exact zero 7180 // value at this index. When solving for "X*X != 5", for example, we 7181 // should not accept a root of 2. 7182 const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); 7183 if (Val->isZero()) 7184 // We found a quadratic root! 7185 return ExitLimit(R1, R1, false, Predicates); 7186 } 7187 } 7188 return getCouldNotCompute(); 7189 } 7190 7191 // Otherwise we can only handle this if it is affine. 7192 if (!AddRec->isAffine()) 7193 return getCouldNotCompute(); 7194 7195 // If this is an affine expression, the execution count of this branch is 7196 // the minimum unsigned root of the following equation: 7197 // 7198 // Start + Step*N = 0 (mod 2^BW) 7199 // 7200 // equivalent to: 7201 // 7202 // Step*N = -Start (mod 2^BW) 7203 // 7204 // where BW is the common bit width of Start and Step. 7205 7206 // Get the initial value for the loop. 7207 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); 7208 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); 7209 7210 // For now we handle only constant steps. 7211 // 7212 // TODO: Handle a nonconstant Step given AddRec<NUW>. If the 7213 // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap 7214 // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step. 7215 // We have not yet seen any such cases. 7216 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step); 7217 if (!StepC || StepC->getValue()->equalsInt(0)) 7218 return getCouldNotCompute(); 7219 7220 // For positive steps (counting up until unsigned overflow): 7221 // N = -Start/Step (as unsigned) 7222 // For negative steps (counting down to zero): 7223 // N = Start/-Step 7224 // First compute the unsigned distance from zero in the direction of Step. 7225 bool CountDown = StepC->getAPInt().isNegative(); 7226 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start); 7227 7228 // Handle unitary steps, which cannot wraparound. 
7229 // 1*N = -Start; -1*N = Start (mod 2^BW), so: 7230 // N = Distance (as unsigned) 7231 if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) { 7232 APInt MaxBECount = getUnsignedRange(Distance).getUnsignedMax(); 7233 7234 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, 7235 // we end up with a loop whose backedge-taken count is n - 1. Detect this 7236 // case, and see if we can improve the bound. 7237 // 7238 // Explicitly handling this here is necessary because getUnsignedRange 7239 // isn't context-sensitive; it doesn't know that we only care about the 7240 // range inside the loop. 7241 const SCEV *Zero = getZero(Distance->getType()); 7242 const SCEV *One = getOne(Distance->getType()); 7243 const SCEV *DistancePlusOne = getAddExpr(Distance, One); 7244 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) { 7245 // If Distance + 1 doesn't overflow, we can compute the maximum distance 7246 // as "unsigned_max(Distance + 1) - 1". 7247 ConstantRange CR = getUnsignedRange(DistancePlusOne); 7248 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1); 7249 } 7250 return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); 7251 } 7252 7253 // As a special case, handle the instance where Step is a positive power of 7254 // two. In this case, determining whether Step divides Distance evenly can be 7255 // done by counting and comparing the number of trailing zeros of Step and 7256 // Distance. 7257 if (!CountDown) { 7258 const APInt &StepV = StepC->getAPInt(); 7259 // StepV.isPowerOf2() returns true if StepV is an positive power of two. It 7260 // also returns true if StepV is maximally negative (eg, INT_MIN), but that 7261 // case is not handled as this code is guarded by !CountDown. 7262 if (StepV.isPowerOf2() && 7263 GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros()) { 7264 // Here we've constrained the equation to be of the form 7265 // 7266 // 2^(N + k) * Distance' = (StepV == 2^N) * X (mod 2^W) ... (0) 7267 // 7268 // where we're operating on a W bit wide integer domain and k is 7269 // non-negative. The smallest unsigned solution for X is the trip count. 7270 // 7271 // (0) is equivalent to: 7272 // 7273 // 2^(N + k) * Distance' - 2^N * X = L * 2^W 7274 // <=> 2^N(2^k * Distance' - X) = L * 2^(W - N) * 2^N 7275 // <=> 2^k * Distance' - X = L * 2^(W - N) 7276 // <=> 2^k * Distance' = L * 2^(W - N) + X ... (1) 7277 // 7278 // The smallest X satisfying (1) is unsigned remainder of dividing the LHS 7279 // by 2^(W - N). 7280 // 7281 // <=> X = 2^k * Distance' URem 2^(W - N) ... (2) 7282 // 7283 // E.g. say we're solving 7284 // 7285 // 2 * Val = 2 * X (in i8) ... (3) 7286 // 7287 // then from (2), we get X = Val URem i8 128 (k = 0 in this case). 7288 // 7289 // Note: It is tempting to solve (3) by setting X = Val, but Val is not 7290 // necessarily the smallest unsigned value of X that satisfies (3). 7291 // E.g. if Val is i8 -127 then the smallest value of X that satisfies (3) 7292 // is i8 1, not i8 -127 7293 7294 const auto *Limit = getUDivExactExpr(Distance, Step); 7295 return ExitLimit(Limit, Limit, false, Predicates); 7296 } 7297 } 7298 7299 // If the condition controls loop exit (the loop exits only if the expression 7300 // is true) and the addition is no-wrap we can use unsigned divide to 7301 // compute the backedge count. 
In this case, the step may not divide the 7302 // distance, but we don't care because if the condition is "missed" the loop 7303 // will have undefined behavior due to wrapping. 7304 if (ControlsExit && AddRec->hasNoSelfWrap() && 7305 loopHasNoAbnormalExits(AddRec->getLoop())) { 7306 const SCEV *Exact = 7307 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); 7308 return ExitLimit(Exact, Exact, false, Predicates); 7309 } 7310 7311 // Then, try to solve the above equation provided that Start is constant. 7312 if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { 7313 const SCEV *E = SolveLinEquationWithOverflow( 7314 StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); 7315 return ExitLimit(E, E, false, Predicates); 7316 } 7317 return getCouldNotCompute(); 7318 } 7319 7320 ScalarEvolution::ExitLimit 7321 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { 7322 // Loops that look like: while (X == 0) are very strange indeed. We don't 7323 // handle them yet except for the trivial case. This could be expanded in the 7324 // future as needed. 7325 7326 // If the value is a constant, check to see if it is known to be non-zero 7327 // already. If so, the backedge will execute zero times. 7328 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { 7329 if (!C->getValue()->isNullValue()) 7330 return getZero(C->getType()); 7331 return getCouldNotCompute(); // Otherwise it will loop infinitely. 7332 } 7333 7334 // We could implement others, but I really doubt anyone writes loops like 7335 // this, and if they did, they would already be constant folded. 7336 return getCouldNotCompute(); 7337 } 7338 7339 std::pair<BasicBlock *, BasicBlock *> 7340 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { 7341 // If the block has a unique predecessor, then there is no path from the 7342 // predecessor to the block that does not go through the direct edge 7343 // from the predecessor to the block. 7344 if (BasicBlock *Pred = BB->getSinglePredecessor()) 7345 return {Pred, BB}; 7346 7347 // A loop's header is defined to be a block that dominates the loop. 7348 // If the header has a unique predecessor outside the loop, it must be 7349 // a block that has exactly one successor that can reach the loop. 7350 if (Loop *L = LI.getLoopFor(BB)) 7351 return {L->getLoopPredecessor(), L->getHeader()}; 7352 7353 return {nullptr, nullptr}; 7354 } 7355 7356 /// SCEV structural equivalence is usually sufficient for testing whether two 7357 /// expressions are equal, however for the purposes of looking for a condition 7358 /// guarding a loop, it can be useful to be a little more general, since a 7359 /// front-end may have replicated the controlling expression. 7360 /// 7361 static bool HasSameValue(const SCEV *A, const SCEV *B) { 7362 // Quick check to see if they are the same SCEV. 7363 if (A == B) return true; 7364 7365 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) { 7366 // Not all instructions that are "identical" compute the same value. For 7367 // instance, two distinct alloca instructions allocating the same type are 7368 // identical and do not read memory; but compute distinct values. 7369 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A)); 7370 }; 7371 7372 // Otherwise, if they're both SCEVUnknown, it's possible that they hold 7373 // two different instructions with the same value. Check for this case. 
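  // E.g. two syntactically identical instructions that SCEV could not break
  // down further (and so wrapped in separate SCEVUnknowns) still compute
  // the same value, provided the result depends only on their operands;
  // hence the restriction to binary operators and GEPs above.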
7374 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A)) 7375 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B)) 7376 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue())) 7377 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue())) 7378 if (ComputesEqualValues(AI, BI)) 7379 return true; 7380 7381 // Otherwise assume they may have a different value. 7382 return false; 7383 } 7384 7385 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, 7386 const SCEV *&LHS, const SCEV *&RHS, 7387 unsigned Depth) { 7388 bool Changed = false; 7389 7390 // If we hit the max recursion limit bail out. 7391 if (Depth >= 3) 7392 return false; 7393 7394 // Canonicalize a constant to the right side. 7395 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { 7396 // Check for both operands constant. 7397 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { 7398 if (ConstantExpr::getICmp(Pred, 7399 LHSC->getValue(), 7400 RHSC->getValue())->isNullValue()) 7401 goto trivially_false; 7402 else 7403 goto trivially_true; 7404 } 7405 // Otherwise swap the operands to put the constant on the right. 7406 std::swap(LHS, RHS); 7407 Pred = ICmpInst::getSwappedPredicate(Pred); 7408 Changed = true; 7409 } 7410 7411 // If we're comparing an addrec with a value which is loop-invariant in the 7412 // addrec's loop, put the addrec on the left. Also make a dominance check, 7413 // as both operands could be addrecs loop-invariant in each other's loop. 7414 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) { 7415 const Loop *L = AR->getLoop(); 7416 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) { 7417 std::swap(LHS, RHS); 7418 Pred = ICmpInst::getSwappedPredicate(Pred); 7419 Changed = true; 7420 } 7421 } 7422 7423 // If there's a constant operand, canonicalize comparisons with boundary 7424 // cases, and canonicalize *-or-equal comparisons to regular comparisons. 7425 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) { 7426 const APInt &RA = RC->getAPInt(); 7427 7428 bool SimplifiedByConstantRange = false; 7429 7430 if (!ICmpInst::isEquality(Pred)) { 7431 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); 7432 if (ExactCR.isFullSet()) 7433 goto trivially_true; 7434 else if (ExactCR.isEmptySet()) 7435 goto trivially_false; 7436 7437 APInt NewRHS; 7438 CmpInst::Predicate NewPred; 7439 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) && 7440 ICmpInst::isEquality(NewPred)) { 7441 // We were able to convert an inequality to an equality. 7442 Pred = NewPred; 7443 RHS = getConstant(NewRHS); 7444 Changed = SimplifiedByConstantRange = true; 7445 } 7446 } 7447 7448 if (!SimplifiedByConstantRange) { 7449 switch (Pred) { 7450 default: 7451 break; 7452 case ICmpInst::ICMP_EQ: 7453 case ICmpInst::ICMP_NE: 7454 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. 7455 if (!RA) 7456 if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS)) 7457 if (const SCEVMulExpr *ME = 7458 dyn_cast<SCEVMulExpr>(AE->getOperand(0))) 7459 if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && 7460 ME->getOperand(0)->isAllOnesValue()) { 7461 RHS = AE->getOperand(1); 7462 LHS = ME->getOperand(1); 7463 Changed = true; 7464 } 7465 break; 7466 7467 7468 // The "Should have been caught earlier!" messages refer to the fact 7469 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above 7470 // should have fired on the corresponding cases, and canonicalized the 7471 // check to trivially_true or trivially_false. 
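    // For example, "x u>= 0" is always true, so ExactCR.isFullSet() would
    // already have sent it to trivially_true; by the time we reach the
    // ICMP_UGE case below, the constant is known not to be the unsigned
    // minimum and "RA - 1" cannot wrap.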
7472 7473 case ICmpInst::ICMP_UGE: 7474 assert(!RA.isMinValue() && "Should have been caught earlier!"); 7475 Pred = ICmpInst::ICMP_UGT; 7476 RHS = getConstant(RA - 1); 7477 Changed = true; 7478 break; 7479 case ICmpInst::ICMP_ULE: 7480 assert(!RA.isMaxValue() && "Should have been caught earlier!"); 7481 Pred = ICmpInst::ICMP_ULT; 7482 RHS = getConstant(RA + 1); 7483 Changed = true; 7484 break; 7485 case ICmpInst::ICMP_SGE: 7486 assert(!RA.isMinSignedValue() && "Should have been caught earlier!"); 7487 Pred = ICmpInst::ICMP_SGT; 7488 RHS = getConstant(RA - 1); 7489 Changed = true; 7490 break; 7491 case ICmpInst::ICMP_SLE: 7492 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!"); 7493 Pred = ICmpInst::ICMP_SLT; 7494 RHS = getConstant(RA + 1); 7495 Changed = true; 7496 break; 7497 } 7498 } 7499 } 7500 7501 // Check for obvious equality. 7502 if (HasSameValue(LHS, RHS)) { 7503 if (ICmpInst::isTrueWhenEqual(Pred)) 7504 goto trivially_true; 7505 if (ICmpInst::isFalseWhenEqual(Pred)) 7506 goto trivially_false; 7507 } 7508 7509 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by 7510 // adding or subtracting 1 from one of the operands. 7511 switch (Pred) { 7512 case ICmpInst::ICMP_SLE: 7513 if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) { 7514 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, 7515 SCEV::FlagNSW); 7516 Pred = ICmpInst::ICMP_SLT; 7517 Changed = true; 7518 } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) { 7519 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, 7520 SCEV::FlagNSW); 7521 Pred = ICmpInst::ICMP_SLT; 7522 Changed = true; 7523 } 7524 break; 7525 case ICmpInst::ICMP_SGE: 7526 if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) { 7527 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, 7528 SCEV::FlagNSW); 7529 Pred = ICmpInst::ICMP_SGT; 7530 Changed = true; 7531 } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) { 7532 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, 7533 SCEV::FlagNSW); 7534 Pred = ICmpInst::ICMP_SGT; 7535 Changed = true; 7536 } 7537 break; 7538 case ICmpInst::ICMP_ULE: 7539 if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) { 7540 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, 7541 SCEV::FlagNUW); 7542 Pred = ICmpInst::ICMP_ULT; 7543 Changed = true; 7544 } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { 7545 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); 7546 Pred = ICmpInst::ICMP_ULT; 7547 Changed = true; 7548 } 7549 break; 7550 case ICmpInst::ICMP_UGE: 7551 if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { 7552 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); 7553 Pred = ICmpInst::ICMP_UGT; 7554 Changed = true; 7555 } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { 7556 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, 7557 SCEV::FlagNUW); 7558 Pred = ICmpInst::ICMP_UGT; 7559 Changed = true; 7560 } 7561 break; 7562 default: 7563 break; 7564 } 7565 7566 // TODO: More simplifications are possible here. 7567 7568 // Recursively simplify until we either hit a recursion limit or nothing 7569 // changes. 7570 if (Changed) 7571 return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); 7572 7573 return Changed; 7574 7575 trivially_true: 7576 // Return 0 == 0. 
7577 LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); 7578 Pred = ICmpInst::ICMP_EQ; 7579 return true; 7580 7581 trivially_false: 7582 // Return 0 != 0. 7583 LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); 7584 Pred = ICmpInst::ICMP_NE; 7585 return true; 7586 } 7587 7588 bool ScalarEvolution::isKnownNegative(const SCEV *S) { 7589 return getSignedRange(S).getSignedMax().isNegative(); 7590 } 7591 7592 bool ScalarEvolution::isKnownPositive(const SCEV *S) { 7593 return getSignedRange(S).getSignedMin().isStrictlyPositive(); 7594 } 7595 7596 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { 7597 return !getSignedRange(S).getSignedMin().isNegative(); 7598 } 7599 7600 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { 7601 return !getSignedRange(S).getSignedMax().isStrictlyPositive(); 7602 } 7603 7604 bool ScalarEvolution::isKnownNonZero(const SCEV *S) { 7605 return isKnownNegative(S) || isKnownPositive(S); 7606 } 7607 7608 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, 7609 const SCEV *LHS, const SCEV *RHS) { 7610 // Canonicalize the inputs first. 7611 (void)SimplifyICmpOperands(Pred, LHS, RHS); 7612 7613 // If LHS or RHS is an addrec, check to see if the condition is true in 7614 // every iteration of the loop. 7615 // If LHS and RHS are both addrec, both conditions must be true in 7616 // every iteration of the loop. 7617 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS); 7618 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS); 7619 bool LeftGuarded = false; 7620 bool RightGuarded = false; 7621 if (LAR) { 7622 const Loop *L = LAR->getLoop(); 7623 if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) && 7624 isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) { 7625 if (!RAR) return true; 7626 LeftGuarded = true; 7627 } 7628 } 7629 if (RAR) { 7630 const Loop *L = RAR->getLoop(); 7631 if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) && 7632 isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) { 7633 if (!LAR) return true; 7634 RightGuarded = true; 7635 } 7636 } 7637 if (LeftGuarded && RightGuarded) 7638 return true; 7639 7640 if (isKnownPredicateViaSplitting(Pred, LHS, RHS)) 7641 return true; 7642 7643 // Otherwise see what can be done with known constant ranges. 7644 return isKnownPredicateViaConstantRanges(Pred, LHS, RHS); 7645 } 7646 7647 bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS, 7648 ICmpInst::Predicate Pred, 7649 bool &Increasing) { 7650 bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing); 7651 7652 #ifndef NDEBUG 7653 // Verify an invariant: inverting the predicate should turn a monotonically 7654 // increasing change to a monotonically decreasing one, and vice versa. 7655 bool IncreasingSwapped; 7656 bool ResultSwapped = isMonotonicPredicateImpl( 7657 LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped); 7658 7659 assert(Result == ResultSwapped && "should be able to analyze both!"); 7660 if (ResultSwapped) 7661 assert(Increasing == !IncreasingSwapped && 7662 "monotonicity should flip as we flip the predicate"); 7663 #endif 7664 7665 return Result; 7666 } 7667 7668 bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, 7669 ICmpInst::Predicate Pred, 7670 bool &Increasing) { 7671 7672 // A zero step value for LHS means the induction variable is essentially a 7673 // loop invariant value. 
We don't really depend on the predicate actually 7674 // flipping from false to true (for increasing predicates, and the other way 7675 // around for decreasing predicates), all we care about is that *if* the 7676 // predicate changes then it only changes from false to true. 7677 // 7678 // A zero step value in itself is not very useful, but there may be places 7679 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be 7680 // as general as possible. 7681 7682 switch (Pred) { 7683 default: 7684 return false; // Conservative answer 7685 7686 case ICmpInst::ICMP_UGT: 7687 case ICmpInst::ICMP_UGE: 7688 case ICmpInst::ICMP_ULT: 7689 case ICmpInst::ICMP_ULE: 7690 if (!LHS->hasNoUnsignedWrap()) 7691 return false; 7692 7693 Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE; 7694 return true; 7695 7696 case ICmpInst::ICMP_SGT: 7697 case ICmpInst::ICMP_SGE: 7698 case ICmpInst::ICMP_SLT: 7699 case ICmpInst::ICMP_SLE: { 7700 if (!LHS->hasNoSignedWrap()) 7701 return false; 7702 7703 const SCEV *Step = LHS->getStepRecurrence(*this); 7704 7705 if (isKnownNonNegative(Step)) { 7706 Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE; 7707 return true; 7708 } 7709 7710 if (isKnownNonPositive(Step)) { 7711 Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; 7712 return true; 7713 } 7714 7715 return false; 7716 } 7717 7718 } 7719 7720 llvm_unreachable("switch has default clause!"); 7721 } 7722 7723 bool ScalarEvolution::isLoopInvariantPredicate( 7724 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, 7725 ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS, 7726 const SCEV *&InvariantRHS) { 7727 7728 // If there is a loop-invariant, force it into the RHS, otherwise bail out. 7729 if (!isLoopInvariant(RHS, L)) { 7730 if (!isLoopInvariant(LHS, L)) 7731 return false; 7732 7733 std::swap(LHS, RHS); 7734 Pred = ICmpInst::getSwappedPredicate(Pred); 7735 } 7736 7737 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS); 7738 if (!ArLHS || ArLHS->getLoop() != L) 7739 return false; 7740 7741 bool Increasing; 7742 if (!isMonotonicPredicate(ArLHS, Pred, Increasing)) 7743 return false; 7744 7745 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to 7746 // true as the loop iterates, and the backedge is control dependent on 7747 // "ArLHS `Pred` RHS" == true then we can reason as follows: 7748 // 7749 // * if the predicate was false in the first iteration then the predicate 7750 // is never evaluated again, since the loop exits without taking the 7751 // backedge. 7752 // * if the predicate was true in the first iteration then it will 7753 // continue to be true for all future iterations since it is 7754 // monotonically increasing. 7755 // 7756 // For both the above possibilities, we can replace the loop varying 7757 // predicate with its value on the first iteration of the loop (which is 7758 // loop invariant). 7759 // 7760 // A similar reasoning applies for a monotonically decreasing predicate, by 7761 // replacing true with false and false with true in the above two bullets. 7762 7763 auto P = Increasing ? 
Pred : ICmpInst::getInversePredicate(Pred); 7764 7765 if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) 7766 return false; 7767 7768 InvariantPred = Pred; 7769 InvariantLHS = ArLHS->getStart(); 7770 InvariantRHS = RHS; 7771 return true; 7772 } 7773 7774 bool ScalarEvolution::isKnownPredicateViaConstantRanges( 7775 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { 7776 if (HasSameValue(LHS, RHS)) 7777 return ICmpInst::isTrueWhenEqual(Pred); 7778 7779 // This code is split out from isKnownPredicate because it is called from 7780 // within isLoopEntryGuardedByCond. 7781 7782 auto CheckRanges = 7783 [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) { 7784 return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS) 7785 .contains(RangeLHS); 7786 }; 7787 7788 // The check at the top of the function catches the case where the values are 7789 // known to be equal. 7790 if (Pred == CmpInst::ICMP_EQ) 7791 return false; 7792 7793 if (Pred == CmpInst::ICMP_NE) 7794 return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) || 7795 CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) || 7796 isKnownNonZero(getMinusSCEV(LHS, RHS)); 7797 7798 if (CmpInst::isSigned(Pred)) 7799 return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)); 7800 7801 return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)); 7802 } 7803 7804 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, 7805 const SCEV *LHS, 7806 const SCEV *RHS) { 7807 7808 // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer. 7809 // Return Y via OutY. 7810 auto MatchBinaryAddToConst = 7811 [this](const SCEV *Result, const SCEV *X, APInt &OutY, 7812 SCEV::NoWrapFlags ExpectedFlags) { 7813 const SCEV *NonConstOp, *ConstOp; 7814 SCEV::NoWrapFlags FlagsPresent; 7815 7816 if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) || 7817 !isa<SCEVConstant>(ConstOp) || NonConstOp != X) 7818 return false; 7819 7820 OutY = cast<SCEVConstant>(ConstOp)->getAPInt(); 7821 return (FlagsPresent & ExpectedFlags) == ExpectedFlags; 7822 }; 7823 7824 APInt C; 7825 7826 switch (Pred) { 7827 default: 7828 break; 7829 7830 case ICmpInst::ICMP_SGE: 7831 std::swap(LHS, RHS); 7832 case ICmpInst::ICMP_SLE: 7833 // X s<= (X + C)<nsw> if C >= 0 7834 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative()) 7835 return true; 7836 7837 // (X + C)<nsw> s<= X if C <= 0 7838 if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && 7839 !C.isStrictlyPositive()) 7840 return true; 7841 break; 7842 7843 case ICmpInst::ICMP_SGT: 7844 std::swap(LHS, RHS); 7845 case ICmpInst::ICMP_SLT: 7846 // X s< (X + C)<nsw> if C > 0 7847 if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && 7848 C.isStrictlyPositive()) 7849 return true; 7850 7851 // (X + C)<nsw> s< X if C < 0 7852 if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative()) 7853 return true; 7854 break; 7855 } 7856 7857 return false; 7858 } 7859 7860 bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, 7861 const SCEV *LHS, 7862 const SCEV *RHS) { 7863 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate) 7864 return false; 7865 7866 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on 7867 // the stack can result in exponential time complexity. 
7868 SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
7869
7870 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
7871 //
7872 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
7873 // isKnownPredicate. isKnownPredicate is more powerful, but also more
7874 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
7875 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
7876 // use isKnownPredicate later if needed.
7877 return isKnownNonNegative(RHS) &&
7878 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
7879 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
7880 }
7881
7882 bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
7883 ICmpInst::Predicate Pred,
7884 const SCEV *LHS, const SCEV *RHS) {
7885 // No need to even try if we know the module has no guards.
7886 if (!HasGuards)
7887 return false;
7888
7889 return any_of(*BB, [&](Instruction &I) {
7890 using namespace llvm::PatternMatch;
7891
7892 Value *Condition;
7893 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
7894 m_Value(Condition))) &&
7895 isImpliedCond(Pred, LHS, RHS, Condition, false);
7896 });
7897 }
7898
7899 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
7900 /// protected by a conditional between LHS and RHS. This is used to
7901 /// eliminate casts.
7902 bool
7903 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
7904 ICmpInst::Predicate Pred,
7905 const SCEV *LHS, const SCEV *RHS) {
7906 // Interpret a null as meaning no loop, where there is obviously no guard
7907 // (interprocedural conditions notwithstanding).
7908 if (!L) return true;
7909
7910 if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
7911 return true;
7912
7913 BasicBlock *Latch = L->getLoopLatch();
7914 if (!Latch)
7915 return false;
7916
7917 BranchInst *LoopContinuePredicate =
7918 dyn_cast<BranchInst>(Latch->getTerminator());
7919 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
7920 isImpliedCond(Pred, LHS, RHS,
7921 LoopContinuePredicate->getCondition(),
7922 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
7923 return true;
7924
7925 // We don't want more than one activation of the following loops on the stack
7926 // -- that can lead to O(n!) time complexity.
7927 if (WalkingBEDominatingConds)
7928 return false;
7929
7930 SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
7931
7932 // See if we can exploit a trip count to prove the predicate.
7933 const auto &BETakenInfo = getBackedgeTakenInfo(L);
7934 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
7935 if (LatchBECount != getCouldNotCompute()) {
7936 // We know that Latch branches back to the loop header exactly
7937 // LatchBECount times. This means the backedge condition at Latch is
7938 // equivalent to "{0,+,1} u< LatchBECount".
7939 Type *Ty = LatchBECount->getType();
7940 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
7941 const SCEV *LoopCounter =
7942 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
7943 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
7944 LatchBECount))
7945 return true;
7946 }
7947
7948 // Check conditions due to any @llvm.assume intrinsics.
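  // For example, a dominating `call void @llvm.assume(i1 %cmp)` where %cmp is
  // something like `icmp ult i64 %i, %n` can supply the fact needed to imply
  // the queried predicate even when no conditional branch in the loop tests it
  // directly. (The IR names here are purely illustrative.)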
7949 for (auto &AssumeVH : AC.assumptions()) { 7950 if (!AssumeVH) 7951 continue; 7952 auto *CI = cast<CallInst>(AssumeVH); 7953 if (!DT.dominates(CI, Latch->getTerminator())) 7954 continue; 7955 7956 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) 7957 return true; 7958 } 7959 7960 // If the loop is not reachable from the entry block, we risk running into an 7961 // infinite loop as we walk up into the dom tree. These loops do not matter 7962 // anyway, so we just return a conservative answer when we see them. 7963 if (!DT.isReachableFromEntry(L->getHeader())) 7964 return false; 7965 7966 if (isImpliedViaGuard(Latch, Pred, LHS, RHS)) 7967 return true; 7968 7969 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()]; 7970 DTN != HeaderDTN; DTN = DTN->getIDom()) { 7971 7972 assert(DTN && "should reach the loop header before reaching the root!"); 7973 7974 BasicBlock *BB = DTN->getBlock(); 7975 if (isImpliedViaGuard(BB, Pred, LHS, RHS)) 7976 return true; 7977 7978 BasicBlock *PBB = BB->getSinglePredecessor(); 7979 if (!PBB) 7980 continue; 7981 7982 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator()); 7983 if (!ContinuePredicate || !ContinuePredicate->isConditional()) 7984 continue; 7985 7986 Value *Condition = ContinuePredicate->getCondition(); 7987 7988 // If we have an edge `E` within the loop body that dominates the only 7989 // latch, the condition guarding `E` also guards the backedge. This 7990 // reasoning works only for loops with a single latch. 7991 7992 BasicBlockEdge DominatingEdge(PBB, BB); 7993 if (DominatingEdge.isSingleEdge()) { 7994 // We're constructively (and conservatively) enumerating edges within the 7995 // loop body that dominate the latch. The dominator tree better agree 7996 // with us on this: 7997 assert(DT.dominates(DominatingEdge, Latch) && "should be!"); 7998 7999 if (isImpliedCond(Pred, LHS, RHS, Condition, 8000 BB != ContinuePredicate->getSuccessor(0))) 8001 return true; 8002 } 8003 } 8004 8005 return false; 8006 } 8007 8008 bool 8009 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, 8010 ICmpInst::Predicate Pred, 8011 const SCEV *LHS, const SCEV *RHS) { 8012 // Interpret a null as meaning no loop, where there is obviously no guard 8013 // (interprocedural conditions notwithstanding). 8014 if (!L) return false; 8015 8016 if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS)) 8017 return true; 8018 8019 // Starting at the loop predecessor, climb up the predecessor chain, as long 8020 // as there are predecessors that can be found that have unique successors 8021 // leading to the original header. 8022 for (std::pair<BasicBlock *, BasicBlock *> 8023 Pair(L->getLoopPredecessor(), L->getHeader()); 8024 Pair.first; 8025 Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { 8026 8027 if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS)) 8028 return true; 8029 8030 BranchInst *LoopEntryPredicate = 8031 dyn_cast<BranchInst>(Pair.first->getTerminator()); 8032 if (!LoopEntryPredicate || 8033 LoopEntryPredicate->isUnconditional()) 8034 continue; 8035 8036 if (isImpliedCond(Pred, LHS, RHS, 8037 LoopEntryPredicate->getCondition(), 8038 LoopEntryPredicate->getSuccessor(0) != Pair.second)) 8039 return true; 8040 } 8041 8042 // Check conditions due to any @llvm.assume intrinsics. 
8043 for (auto &AssumeVH : AC.assumptions()) { 8044 if (!AssumeVH) 8045 continue; 8046 auto *CI = cast<CallInst>(AssumeVH); 8047 if (!DT.dominates(CI, L->getHeader())) 8048 continue; 8049 8050 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) 8051 return true; 8052 } 8053 8054 return false; 8055 } 8056 8057 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, 8058 const SCEV *LHS, const SCEV *RHS, 8059 Value *FoundCondValue, 8060 bool Inverse) { 8061 if (!PendingLoopPredicates.insert(FoundCondValue).second) 8062 return false; 8063 8064 auto ClearOnExit = 8065 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); }); 8066 8067 // Recursively handle And and Or conditions. 8068 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) { 8069 if (BO->getOpcode() == Instruction::And) { 8070 if (!Inverse) 8071 return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || 8072 isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); 8073 } else if (BO->getOpcode() == Instruction::Or) { 8074 if (Inverse) 8075 return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || 8076 isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); 8077 } 8078 } 8079 8080 ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue); 8081 if (!ICI) return false; 8082 8083 // Now that we found a conditional branch that dominates the loop or controls 8084 // the loop latch. Check to see if it is the comparison we are looking for. 8085 ICmpInst::Predicate FoundPred; 8086 if (Inverse) 8087 FoundPred = ICI->getInversePredicate(); 8088 else 8089 FoundPred = ICI->getPredicate(); 8090 8091 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); 8092 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); 8093 8094 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS); 8095 } 8096 8097 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, 8098 const SCEV *RHS, 8099 ICmpInst::Predicate FoundPred, 8100 const SCEV *FoundLHS, 8101 const SCEV *FoundRHS) { 8102 // Balance the types. 8103 if (getTypeSizeInBits(LHS->getType()) < 8104 getTypeSizeInBits(FoundLHS->getType())) { 8105 if (CmpInst::isSigned(Pred)) { 8106 LHS = getSignExtendExpr(LHS, FoundLHS->getType()); 8107 RHS = getSignExtendExpr(RHS, FoundLHS->getType()); 8108 } else { 8109 LHS = getZeroExtendExpr(LHS, FoundLHS->getType()); 8110 RHS = getZeroExtendExpr(RHS, FoundLHS->getType()); 8111 } 8112 } else if (getTypeSizeInBits(LHS->getType()) > 8113 getTypeSizeInBits(FoundLHS->getType())) { 8114 if (CmpInst::isSigned(FoundPred)) { 8115 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); 8116 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType()); 8117 } else { 8118 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType()); 8119 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType()); 8120 } 8121 } 8122 8123 // Canonicalize the query to match the way instcombine will have 8124 // canonicalized the comparison. 8125 if (SimplifyICmpOperands(Pred, LHS, RHS)) 8126 if (LHS == RHS) 8127 return CmpInst::isTrueWhenEqual(Pred); 8128 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS)) 8129 if (FoundLHS == FoundRHS) 8130 return CmpInst::isFalseWhenEqual(FoundPred); 8131 8132 // Check to see if we can make the LHS or RHS match. 
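  // For example (with A, B, C standing for arbitrary SCEVs), if the query is
  // (A s< B) and the found condition was recorded as (B s> C), then
  // RHS == FoundLHS, so the query is swapped to (B s> A) and the checks below
  // can compare the two conditions operand-for-operand. If B happened to be a
  // constant, the found condition would be swapped instead.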
8133 if (LHS == FoundRHS || RHS == FoundLHS) { 8134 if (isa<SCEVConstant>(RHS)) { 8135 std::swap(FoundLHS, FoundRHS); 8136 FoundPred = ICmpInst::getSwappedPredicate(FoundPred); 8137 } else { 8138 std::swap(LHS, RHS); 8139 Pred = ICmpInst::getSwappedPredicate(Pred); 8140 } 8141 } 8142 8143 // Check whether the found predicate is the same as the desired predicate. 8144 if (FoundPred == Pred) 8145 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); 8146 8147 // Check whether swapping the found predicate makes it the same as the 8148 // desired predicate. 8149 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { 8150 if (isa<SCEVConstant>(RHS)) 8151 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS); 8152 else 8153 return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), 8154 RHS, LHS, FoundLHS, FoundRHS); 8155 } 8156 8157 // Unsigned comparison is the same as signed comparison when both the operands 8158 // are non-negative. 8159 if (CmpInst::isUnsigned(FoundPred) && 8160 CmpInst::getSignedPredicate(FoundPred) == Pred && 8161 isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) 8162 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); 8163 8164 // Check if we can make progress by sharpening ranges. 8165 if (FoundPred == ICmpInst::ICMP_NE && 8166 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) { 8167 8168 const SCEVConstant *C = nullptr; 8169 const SCEV *V = nullptr; 8170 8171 if (isa<SCEVConstant>(FoundLHS)) { 8172 C = cast<SCEVConstant>(FoundLHS); 8173 V = FoundRHS; 8174 } else { 8175 C = cast<SCEVConstant>(FoundRHS); 8176 V = FoundLHS; 8177 } 8178 8179 // The guarding predicate tells us that C != V. If the known range 8180 // of V is [C, t), we can sharpen the range to [C + 1, t). The 8181 // range we consider has to correspond to same signedness as the 8182 // predicate we're interested in folding. 8183 8184 APInt Min = ICmpInst::isSigned(Pred) ? 8185 getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); 8186 8187 if (Min == C->getAPInt()) { 8188 // Given (V >= Min && V != Min) we conclude V >= (Min + 1). 8189 // This is true even if (Min + 1) wraps around -- in case of 8190 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). 8191 8192 APInt SharperMin = Min + 1; 8193 8194 switch (Pred) { 8195 case ICmpInst::ICMP_SGE: 8196 case ICmpInst::ICMP_UGE: 8197 // We know V `Pred` SharperMin. If this implies LHS `Pred` 8198 // RHS, we're done. 8199 if (isImpliedCondOperands(Pred, LHS, RHS, V, 8200 getConstant(SharperMin))) 8201 return true; 8202 8203 case ICmpInst::ICMP_SGT: 8204 case ICmpInst::ICMP_UGT: 8205 // We know from the range information that (V `Pred` Min || 8206 // V == Min). We know from the guarding condition that !(V 8207 // == Min). This gives us 8208 // 8209 // V `Pred` Min || V == Min && !(V == Min) 8210 // => V `Pred` Min 8211 // 8212 // If V `Pred` Min implies LHS `Pred` RHS, we're done. 8213 8214 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) 8215 return true; 8216 8217 default: 8218 // No change 8219 break; 8220 } 8221 } 8222 } 8223 8224 // Check whether the actual condition is beyond sufficient. 
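  // For instance, having found X == 13 is enough to establish X u>= 13 (or any
  // other predicate that is true when its operands are equal), and conversely
  // a found strict comparison such as X u< Y already rules out X == Y, which
  // is what the ICMP_NE query below needs.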
8225 if (FoundPred == ICmpInst::ICMP_EQ) 8226 if (ICmpInst::isTrueWhenEqual(Pred)) 8227 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS)) 8228 return true; 8229 if (Pred == ICmpInst::ICMP_NE) 8230 if (!ICmpInst::isTrueWhenEqual(FoundPred)) 8231 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS)) 8232 return true; 8233 8234 // Otherwise assume the worst. 8235 return false; 8236 } 8237 8238 bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, 8239 const SCEV *&L, const SCEV *&R, 8240 SCEV::NoWrapFlags &Flags) { 8241 const auto *AE = dyn_cast<SCEVAddExpr>(Expr); 8242 if (!AE || AE->getNumOperands() != 2) 8243 return false; 8244 8245 L = AE->getOperand(0); 8246 R = AE->getOperand(1); 8247 Flags = AE->getNoWrapFlags(); 8248 return true; 8249 } 8250 8251 Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More, 8252 const SCEV *Less) { 8253 // We avoid subtracting expressions here because this function is usually 8254 // fairly deep in the call stack (i.e. is called many times). 8255 8256 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) { 8257 const auto *LAR = cast<SCEVAddRecExpr>(Less); 8258 const auto *MAR = cast<SCEVAddRecExpr>(More); 8259 8260 if (LAR->getLoop() != MAR->getLoop()) 8261 return None; 8262 8263 // We look at affine expressions only; not for correctness but to keep 8264 // getStepRecurrence cheap. 8265 if (!LAR->isAffine() || !MAR->isAffine()) 8266 return None; 8267 8268 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) 8269 return None; 8270 8271 Less = LAR->getStart(); 8272 More = MAR->getStart(); 8273 8274 // fall through 8275 } 8276 8277 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) { 8278 const auto &M = cast<SCEVConstant>(More)->getAPInt(); 8279 const auto &L = cast<SCEVConstant>(Less)->getAPInt(); 8280 return M - L; 8281 } 8282 8283 const SCEV *L, *R; 8284 SCEV::NoWrapFlags Flags; 8285 if (splitBinaryAdd(Less, L, R, Flags)) 8286 if (const auto *LC = dyn_cast<SCEVConstant>(L)) 8287 if (R == More) 8288 return -(LC->getAPInt()); 8289 8290 if (splitBinaryAdd(More, L, R, Flags)) 8291 if (const auto *LC = dyn_cast<SCEVConstant>(L)) 8292 if (R == Less) 8293 return LC->getAPInt(); 8294 8295 return None; 8296 } 8297 8298 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( 8299 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, 8300 const SCEV *FoundLHS, const SCEV *FoundRHS) { 8301 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT) 8302 return false; 8303 8304 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS); 8305 if (!AddRecLHS) 8306 return false; 8307 8308 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS); 8309 if (!AddRecFoundLHS) 8310 return false; 8311 8312 // We'd like to let SCEV reason about control dependencies, so we constrain 8313 // both the inequalities to be about add recurrences on the same loop. This 8314 // way we can use isLoopEntryGuardedByCond later. 8315 8316 const Loop *L = AddRecFoundLHS->getLoop(); 8317 if (L != AddRecLHS->getLoop()) 8318 return false; 8319 8320 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1) 8321 // 8322 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C) 8323 // ... (2) 8324 // 8325 // Informal proof for (2), assuming (1) [*]: 8326 // 8327 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... 
(3)[**] 8328 // 8329 // Then 8330 // 8331 // FoundLHS s< FoundRHS s< INT_MIN - C 8332 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ] 8333 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ] 8334 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s< 8335 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ] 8336 // <=> FoundLHS + C s< FoundRHS + C 8337 // 8338 // [*]: (1) can be proved by ruling out overflow. 8339 // 8340 // [**]: This can be proved by analyzing all the four possibilities: 8341 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and 8342 // (A s>= 0, B s>= 0). 8343 // 8344 // Note: 8345 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C" 8346 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS 8347 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS 8348 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is 8349 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS + 8350 // C)". 8351 8352 Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS); 8353 Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS); 8354 if (!LDiff || !RDiff || *LDiff != *RDiff) 8355 return false; 8356 8357 if (LDiff->isMinValue()) 8358 return true; 8359 8360 APInt FoundRHSLimit; 8361 8362 if (Pred == CmpInst::ICMP_ULT) { 8363 FoundRHSLimit = -(*RDiff); 8364 } else { 8365 assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); 8366 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff; 8367 } 8368 8369 // Try to prove (1) or (2), as needed. 8370 return isLoopEntryGuardedByCond(L, Pred, FoundRHS, 8371 getConstant(FoundRHSLimit)); 8372 } 8373 8374 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, 8375 const SCEV *LHS, const SCEV *RHS, 8376 const SCEV *FoundLHS, 8377 const SCEV *FoundRHS) { 8378 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) 8379 return true; 8380 8381 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS)) 8382 return true; 8383 8384 return isImpliedCondOperandsHelper(Pred, LHS, RHS, 8385 FoundLHS, FoundRHS) || 8386 // ~x < ~y --> x > y 8387 isImpliedCondOperandsHelper(Pred, LHS, RHS, 8388 getNotSCEV(FoundRHS), 8389 getNotSCEV(FoundLHS)); 8390 } 8391 8392 8393 /// If Expr computes ~A, return A else return nullptr 8394 static const SCEV *MatchNotExpr(const SCEV *Expr) { 8395 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr); 8396 if (!Add || Add->getNumOperands() != 2 || 8397 !Add->getOperand(0)->isAllOnesValue()) 8398 return nullptr; 8399 8400 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); 8401 if (!AddRHS || AddRHS->getNumOperands() != 2 || 8402 !AddRHS->getOperand(0)->isAllOnesValue()) 8403 return nullptr; 8404 8405 return AddRHS->getOperand(1); 8406 } 8407 8408 8409 /// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values? 8410 template<typename MaxExprType> 8411 static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr, 8412 const SCEV *Candidate) { 8413 const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr); 8414 if (!MaxExpr) return false; 8415 8416 return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end(); 8417 } 8418 8419 8420 /// Is MaybeMinExpr an SMin or UMin of Candidate and some other values? 
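/// SCEV has no dedicated min expression: smin(A, B) is represented as
/// ~smax(~A, ~B) (where ~X is the all-ones constant minus X), and likewise
/// umin in terms of umax. So we strip the outer "not" with MatchNotExpr and
/// then look for ~Candidate among the operands of the remaining max.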
8421 template<typename MaxExprType> 8422 static bool IsMinConsistingOf(ScalarEvolution &SE, 8423 const SCEV *MaybeMinExpr, 8424 const SCEV *Candidate) { 8425 const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr); 8426 if (!MaybeMaxExpr) 8427 return false; 8428 8429 return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate)); 8430 } 8431 8432 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, 8433 ICmpInst::Predicate Pred, 8434 const SCEV *LHS, const SCEV *RHS) { 8435 8436 // If both sides are affine addrecs for the same loop, with equal 8437 // steps, and we know the recurrences don't wrap, then we only 8438 // need to check the predicate on the starting values. 8439 8440 if (!ICmpInst::isRelational(Pred)) 8441 return false; 8442 8443 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS); 8444 if (!LAR) 8445 return false; 8446 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS); 8447 if (!RAR) 8448 return false; 8449 if (LAR->getLoop() != RAR->getLoop()) 8450 return false; 8451 if (!LAR->isAffine() || !RAR->isAffine()) 8452 return false; 8453 8454 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE)) 8455 return false; 8456 8457 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? 8458 SCEV::FlagNSW : SCEV::FlagNUW; 8459 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW)) 8460 return false; 8461 8462 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart()); 8463 } 8464 8465 /// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max 8466 /// expression? 8467 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, 8468 ICmpInst::Predicate Pred, 8469 const SCEV *LHS, const SCEV *RHS) { 8470 switch (Pred) { 8471 default: 8472 return false; 8473 8474 case ICmpInst::ICMP_SGE: 8475 std::swap(LHS, RHS); 8476 LLVM_FALLTHROUGH; 8477 case ICmpInst::ICMP_SLE: 8478 return 8479 // min(A, ...) <= A 8480 IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) || 8481 // A <= max(A, ...) 8482 IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS); 8483 8484 case ICmpInst::ICMP_UGE: 8485 std::swap(LHS, RHS); 8486 LLVM_FALLTHROUGH; 8487 case ICmpInst::ICMP_ULE: 8488 return 8489 // min(A, ...) <= A 8490 IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) || 8491 // A <= max(A, ...) 
8492 IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS); 8493 } 8494 8495 llvm_unreachable("covered switch fell through?!"); 8496 } 8497 8498 bool 8499 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, 8500 const SCEV *LHS, const SCEV *RHS, 8501 const SCEV *FoundLHS, 8502 const SCEV *FoundRHS) { 8503 auto IsKnownPredicateFull = 8504 [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { 8505 return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || 8506 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || 8507 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || 8508 isKnownPredicateViaNoOverflow(Pred, LHS, RHS); 8509 }; 8510 8511 switch (Pred) { 8512 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); 8513 case ICmpInst::ICMP_EQ: 8514 case ICmpInst::ICMP_NE: 8515 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS)) 8516 return true; 8517 break; 8518 case ICmpInst::ICMP_SLT: 8519 case ICmpInst::ICMP_SLE: 8520 if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && 8521 IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) 8522 return true; 8523 break; 8524 case ICmpInst::ICMP_SGT: 8525 case ICmpInst::ICMP_SGE: 8526 if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && 8527 IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) 8528 return true; 8529 break; 8530 case ICmpInst::ICMP_ULT: 8531 case ICmpInst::ICMP_ULE: 8532 if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && 8533 IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) 8534 return true; 8535 break; 8536 case ICmpInst::ICMP_UGT: 8537 case ICmpInst::ICMP_UGE: 8538 if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && 8539 IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) 8540 return true; 8541 break; 8542 } 8543 8544 return false; 8545 } 8546 8547 bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, 8548 const SCEV *LHS, 8549 const SCEV *RHS, 8550 const SCEV *FoundLHS, 8551 const SCEV *FoundRHS) { 8552 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS)) 8553 // The restriction on `FoundRHS` be lifted easily -- it exists only to 8554 // reduce the compile time impact of this optimization. 8555 return false; 8556 8557 Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS); 8558 if (!Addend) 8559 return false; 8560 8561 APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt(); 8562 8563 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the 8564 // antecedent "`FoundLHS` `Pred` `FoundRHS`". 8565 ConstantRange FoundLHSRange = 8566 ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); 8567 8568 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`: 8569 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend)); 8570 8571 // We can also compute the range of values for `LHS` that satisfy the 8572 // consequent, "`LHS` `Pred` `RHS`": 8573 APInt ConstRHS = cast<SCEVConstant>(RHS)->getAPInt(); 8574 ConstantRange SatisfyingLHSRange = 8575 ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); 8576 8577 // The antecedent implies the consequent if every value of `LHS` that 8578 // satisfies the antecedent also satisfies the consequent. 
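  // Worked example with made-up numbers: if the antecedent is FoundLHS u< 8
  // and LHS is FoundLHS + 2, then FoundLHSRange is [0, 8), LHSRange is
  // [2, 10), and a consequent of LHS u< 16 has the satisfying range [0, 16),
  // which contains [2, 10) -- so the implication holds.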
8579 return SatisfyingLHSRange.contains(LHSRange); 8580 } 8581 8582 bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, 8583 bool IsSigned, bool NoWrap) { 8584 assert(isKnownPositive(Stride) && "Positive stride expected!"); 8585 8586 if (NoWrap) return false; 8587 8588 unsigned BitWidth = getTypeSizeInBits(RHS->getType()); 8589 const SCEV *One = getOne(Stride->getType()); 8590 8591 if (IsSigned) { 8592 APInt MaxRHS = getSignedRange(RHS).getSignedMax(); 8593 APInt MaxValue = APInt::getSignedMaxValue(BitWidth); 8594 APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One)) 8595 .getSignedMax(); 8596 8597 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow! 8598 return (MaxValue - MaxStrideMinusOne).slt(MaxRHS); 8599 } 8600 8601 APInt MaxRHS = getUnsignedRange(RHS).getUnsignedMax(); 8602 APInt MaxValue = APInt::getMaxValue(BitWidth); 8603 APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One)) 8604 .getUnsignedMax(); 8605 8606 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow! 8607 return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); 8608 } 8609 8610 bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, 8611 bool IsSigned, bool NoWrap) { 8612 if (NoWrap) return false; 8613 8614 unsigned BitWidth = getTypeSizeInBits(RHS->getType()); 8615 const SCEV *One = getOne(Stride->getType()); 8616 8617 if (IsSigned) { 8618 APInt MinRHS = getSignedRange(RHS).getSignedMin(); 8619 APInt MinValue = APInt::getSignedMinValue(BitWidth); 8620 APInt MaxStrideMinusOne = getSignedRange(getMinusSCEV(Stride, One)) 8621 .getSignedMax(); 8622 8623 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! 8624 return (MinValue + MaxStrideMinusOne).sgt(MinRHS); 8625 } 8626 8627 APInt MinRHS = getUnsignedRange(RHS).getUnsignedMin(); 8628 APInt MinValue = APInt::getMinValue(BitWidth); 8629 APInt MaxStrideMinusOne = getUnsignedRange(getMinusSCEV(Stride, One)) 8630 .getUnsignedMax(); 8631 8632 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! 8633 return (MinValue + MaxStrideMinusOne).ugt(MinRHS); 8634 } 8635 8636 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, 8637 bool Equality) { 8638 const SCEV *One = getOne(Step->getType()); 8639 Delta = Equality ? getAddExpr(Delta, Step) 8640 : getAddExpr(Delta, getMinusSCEV(Step, One)); 8641 return getUDivExpr(Delta, Step); 8642 } 8643 8644 ScalarEvolution::ExitLimit 8645 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, 8646 const Loop *L, bool IsSigned, 8647 bool ControlsExit, bool AllowPredicates) { 8648 SmallPtrSet<const SCEVPredicate *, 4> Predicates; 8649 // We handle only IV < Invariant 8650 if (!isLoopInvariant(RHS, L)) 8651 return getCouldNotCompute(); 8652 8653 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); 8654 bool PredicatedIV = false; 8655 8656 if (!IV && AllowPredicates) { 8657 // Try to make this an AddRec using runtime tests, in the first X 8658 // iterations of this loop, where X is the SCEV expression found by the 8659 // algorithm below. 8660 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); 8661 PredicatedIV = true; 8662 } 8663 8664 // Avoid weird loops 8665 if (!IV || IV->getLoop() != L || !IV->isAffine()) 8666 return getCouldNotCompute(); 8667 8668 bool NoWrap = ControlsExit && 8669 IV->getNoWrapFlags(IsSigned ? 
SCEV::FlagNSW : SCEV::FlagNUW);
8670
8671 const SCEV *Stride = IV->getStepRecurrence(*this);
8672
8673 bool PositiveStride = isKnownPositive(Stride);
8674
8675 // Avoid negative or zero stride values.
8676 if (!PositiveStride) {
8677 // We can compute the correct backedge taken count for loops with unknown
8678 // strides if we can prove that the loop is not an infinite loop with side
8679 // effects. Here's the loop structure we are trying to handle -
8680 //
8681 // i = start
8682 // do {
8683 // A[i] = i;
8684 // i += s;
8685 // } while (i < end);
8686 //
8687 // The backedge taken count for such loops is evaluated as -
8688 // (max(end, start + stride) - start - 1) /u stride
8689 //
8690 // The additional preconditions that we need to check to prove correctness
8691 // of the above formula are as follows -
8692 //
8693 // a) IV is either nuw or nsw depending upon signedness (indicated by the
8694 // NoWrap flag).
8695 // b) loop is single exit with no side effects.
8696 //
8697 //
8698 // Precondition a) implies that if the stride is negative, this is a single
8699 // trip loop. The backedge taken count formula reduces to zero in this case.
8700 //
8701 // Precondition b) implies that the unknown stride cannot be zero, otherwise
8702 // we have UB.
8703 //
8704 // The positive stride case is the same as isKnownPositive(Stride) returning
8705 // true (original behavior of the function).
8706 //
8707 // We want to make sure that the stride is truly unknown as there are edge
8708 // cases where ScalarEvolution propagates no wrap flags to the
8709 // post-increment/decrement IV even though the increment/decrement operation
8710 // itself is wrapping. The computed backedge taken count may be wrong in
8711 // such cases. This is prevented by checking that the stride is not known to
8712 // be either positive or non-positive. For example, no wrap flags are
8713 // propagated to the post-increment IV of this loop with a trip count of 2 -
8714 //
8715 // unsigned char i;
8716 // for(i=127; i<128; i+=129)
8717 // A[i] = i;
8718 //
8719 if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
8720 !loopHasNoSideEffects(L))
8721 return getCouldNotCompute();
8722
8723 } else if (!Stride->isOne() &&
8724 doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
8725 // Avoid proven overflow cases: this will ensure that the backedge taken
8726 // count will not generate any unsigned overflow. Relaxed no-overflow
8727 // conditions exploit NoWrapFlags, allowing us to optimize in the presence
8728 // of undefined behavior, as in the C language.
8729 return getCouldNotCompute();
8730
8731 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT
8732 : ICmpInst::ICMP_ULT;
8733 const SCEV *Start = IV->getStart();
8734 const SCEV *End = RHS;
8735 // If the backedge is taken at least once, then it will be taken
8736 // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
8737 // is the LHS value of the less-than comparison the first time it is evaluated
8738 // and End is the RHS.
8739 const SCEV *BECountIfBackedgeTaken =
8740 computeBECount(getMinusSCEV(End, Start), Stride, false);
8741 // If the loop entry is guarded by the result of the backedge test of the
8742 // first loop iteration, then we know the backedge will be taken at least
8743 // once and so the backedge taken count is as above.
If not then we use the 8744 // expression (max(End,Start)-Start)/Stride to describe the backedge count, 8745 // as if the backedge is taken at least once max(End,Start) is End and so the 8746 // result is as above, and if not max(End,Start) is Start so we get a backedge 8747 // count of zero. 8748 const SCEV *BECount; 8749 if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) 8750 BECount = BECountIfBackedgeTaken; 8751 else { 8752 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); 8753 BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); 8754 } 8755 8756 const SCEV *MaxBECount; 8757 bool MaxOrZero = false; 8758 if (isa<SCEVConstant>(BECount)) 8759 MaxBECount = BECount; 8760 else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) { 8761 // If we know exactly how many times the backedge will be taken if it's 8762 // taken at least once, then the backedge count will either be that or 8763 // zero. 8764 MaxBECount = BECountIfBackedgeTaken; 8765 MaxOrZero = true; 8766 } else { 8767 // Calculate the maximum backedge count based on the range of values 8768 // permitted by Start, End, and Stride. 8769 APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin() 8770 : getUnsignedRange(Start).getUnsignedMin(); 8771 8772 unsigned BitWidth = getTypeSizeInBits(LHS->getType()); 8773 8774 APInt StrideForMaxBECount; 8775 8776 if (PositiveStride) 8777 StrideForMaxBECount = 8778 IsSigned ? getSignedRange(Stride).getSignedMin() 8779 : getUnsignedRange(Stride).getUnsignedMin(); 8780 else 8781 // Using a stride of 1 is safe when computing max backedge taken count for 8782 // a loop with unknown stride. 8783 StrideForMaxBECount = APInt(BitWidth, 1, IsSigned); 8784 8785 APInt Limit = 8786 IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1) 8787 : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1); 8788 8789 // Although End can be a MAX expression we estimate MaxEnd considering only 8790 // the case End = RHS. This is safe because in the other case (End - Start) 8791 // is zero, leading to a zero maximum backedge taken count. 8792 APInt MaxEnd = 8793 IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit) 8794 : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit); 8795 8796 MaxBECount = computeBECount(getConstant(MaxEnd - MinStart), 8797 getConstant(StrideForMaxBECount), false); 8798 } 8799 8800 if (isa<SCEVCouldNotCompute>(MaxBECount)) 8801 MaxBECount = BECount; 8802 8803 return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); 8804 } 8805 8806 ScalarEvolution::ExitLimit 8807 ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, 8808 const Loop *L, bool IsSigned, 8809 bool ControlsExit, bool AllowPredicates) { 8810 SmallPtrSet<const SCEVPredicate *, 4> Predicates; 8811 // We handle only IV > Invariant 8812 if (!isLoopInvariant(RHS, L)) 8813 return getCouldNotCompute(); 8814 8815 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); 8816 if (!IV && AllowPredicates) 8817 // Try to make this an AddRec using runtime tests, in the first X 8818 // iterations of this loop, where X is the SCEV expression found by the 8819 // algorithm below. 8820 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); 8821 8822 // Avoid weird loops 8823 if (!IV || IV->getLoop() != L || !IV->isAffine()) 8824 return getCouldNotCompute(); 8825 8826 bool NoWrap = ControlsExit && 8827 IV->getNoWrapFlags(IsSigned ? 
SCEV::FlagNSW : SCEV::FlagNUW); 8828 8829 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); 8830 8831 // Avoid negative or zero stride values 8832 if (!isKnownPositive(Stride)) 8833 return getCouldNotCompute(); 8834 8835 // Avoid proven overflow cases: this will ensure that the backedge taken count 8836 // will not generate any unsigned overflow. Relaxed no-overflow conditions 8837 // exploit NoWrapFlags, allowing to optimize in presence of undefined 8838 // behaviors like the case of C language. 8839 if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) 8840 return getCouldNotCompute(); 8841 8842 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT 8843 : ICmpInst::ICMP_UGT; 8844 8845 const SCEV *Start = IV->getStart(); 8846 const SCEV *End = RHS; 8847 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) 8848 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start); 8849 8850 const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); 8851 8852 APInt MaxStart = IsSigned ? getSignedRange(Start).getSignedMax() 8853 : getUnsignedRange(Start).getUnsignedMax(); 8854 8855 APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin() 8856 : getUnsignedRange(Stride).getUnsignedMin(); 8857 8858 unsigned BitWidth = getTypeSizeInBits(LHS->getType()); 8859 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1) 8860 : APInt::getMinValue(BitWidth) + (MinStride - 1); 8861 8862 // Although End can be a MIN expression we estimate MinEnd considering only 8863 // the case End = RHS. This is safe because in the other case (Start - End) 8864 // is zero, leading to a zero maximum backedge taken count. 8865 APInt MinEnd = 8866 IsSigned ? APIntOps::smax(getSignedRange(RHS).getSignedMin(), Limit) 8867 : APIntOps::umax(getUnsignedRange(RHS).getUnsignedMin(), Limit); 8868 8869 8870 const SCEV *MaxBECount = getCouldNotCompute(); 8871 if (isa<SCEVConstant>(BECount)) 8872 MaxBECount = BECount; 8873 else 8874 MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), 8875 getConstant(MinStride), false); 8876 8877 if (isa<SCEVCouldNotCompute>(MaxBECount)) 8878 MaxBECount = BECount; 8879 8880 return ExitLimit(BECount, MaxBECount, false, Predicates); 8881 } 8882 8883 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, 8884 ScalarEvolution &SE) const { 8885 if (Range.isFullSet()) // Infinite loop. 8886 return SE.getCouldNotCompute(); 8887 8888 // If the start is a non-zero constant, shift the range to simplify things. 8889 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart())) 8890 if (!SC->getValue()->isZero()) { 8891 SmallVector<const SCEV *, 4> Operands(op_begin(), op_end()); 8892 Operands[0] = SE.getZero(SC->getType()); 8893 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), 8894 getNoWrapFlags(FlagNW)); 8895 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted)) 8896 return ShiftedAddRec->getNumIterationsInRange( 8897 Range.subtract(SC->getAPInt()), SE); 8898 // This is strange and shouldn't happen. 8899 return SE.getCouldNotCompute(); 8900 } 8901 8902 // The only time we can solve this is when we have all constant indices. 8903 // Otherwise, we cannot determine the overflow conditions. 8904 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); })) 8905 return SE.getCouldNotCompute(); 8906 8907 // Okay at this point we know that all elements of the chrec are constants and 8908 // that the start element is zero. 
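  // (For example, a chrec that started out as {3,+,4} queried against the
  // range [3, 13) has, after the shift above, become {0,+,4} queried against
  // [0, 10).)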
8909 8910 // First check to see if the range contains zero. If not, the first 8911 // iteration exits. 8912 unsigned BitWidth = SE.getTypeSizeInBits(getType()); 8913 if (!Range.contains(APInt(BitWidth, 0))) 8914 return SE.getZero(getType()); 8915 8916 if (isAffine()) { 8917 // If this is an affine expression then we have this situation: 8918 // Solve {0,+,A} in Range === Ax in Range 8919 8920 // We know that zero is in the range. If A is positive then we know that 8921 // the upper value of the range must be the first possible exit value. 8922 // If A is negative then the lower of the range is the last possible loop 8923 // value. Also note that we already checked for a full range. 8924 APInt One(BitWidth,1); 8925 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt(); 8926 APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); 8927 8928 // The exit value should be (End+A)/A. 8929 APInt ExitVal = (End + A).udiv(A); 8930 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal); 8931 8932 // Evaluate at the exit value. If we really did fall out of the valid 8933 // range, then we computed our trip count, otherwise wrap around or other 8934 // things must have happened. 8935 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE); 8936 if (Range.contains(Val->getValue())) 8937 return SE.getCouldNotCompute(); // Something strange happened 8938 8939 // Ensure that the previous value is in the range. This is a sanity check. 8940 assert(Range.contains( 8941 EvaluateConstantChrecAtConstant(this, 8942 ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) && 8943 "Linear scev computation is off in a bad way!"); 8944 return SE.getConstant(ExitValue); 8945 } else if (isQuadratic()) { 8946 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the 8947 // quadratic equation to solve it. To do this, we must frame our problem in 8948 // terms of figuring out when zero is crossed, instead of when 8949 // Range.getUpper() is crossed. 8950 SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end()); 8951 NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); 8952 const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap); 8953 8954 // Next, solve the constructed addrec 8955 if (auto Roots = 8956 SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) { 8957 const SCEVConstant *R1 = Roots->first; 8958 const SCEVConstant *R2 = Roots->second; 8959 // Pick the smallest positive root value. 8960 if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( 8961 ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { 8962 if (!CB->getZExtValue()) 8963 std::swap(R1, R2); // R1 is the minimum root now. 8964 8965 // Make sure the root is not off by one. The returned iteration should 8966 // not be in the range, but the previous one should be. When solving 8967 // for "X*X < 5", for example, we should not return a root of 2. 8968 ConstantInt *R1Val = 8969 EvaluateConstantChrecAtConstant(this, R1->getValue(), SE); 8970 if (Range.contains(R1Val->getValue())) { 8971 // The next iteration must be out of the range... 8972 ConstantInt *NextVal = 8973 ConstantInt::get(SE.getContext(), R1->getAPInt() + 1); 8974 8975 R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); 8976 if (!Range.contains(R1Val->getValue())) 8977 return SE.getConstant(NextVal); 8978 return SE.getCouldNotCompute(); // Something strange happened 8979 } 8980 8981 // If R1 was not in the range, then it is a good return value. 
Make 8982 // sure that R1-1 WAS in the range though, just in case. 8983 ConstantInt *NextVal = 8984 ConstantInt::get(SE.getContext(), R1->getAPInt() - 1); 8985 R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); 8986 if (Range.contains(R1Val->getValue())) 8987 return R1; 8988 return SE.getCouldNotCompute(); // Something strange happened 8989 } 8990 } 8991 } 8992 8993 return SE.getCouldNotCompute(); 8994 } 8995 8996 // Return true when S contains at least an undef value. 8997 static inline bool containsUndefs(const SCEV *S) { 8998 return SCEVExprContains(S, [](const SCEV *S) { 8999 if (const auto *SU = dyn_cast<SCEVUnknown>(S)) 9000 return isa<UndefValue>(SU->getValue()); 9001 else if (const auto *SC = dyn_cast<SCEVConstant>(S)) 9002 return isa<UndefValue>(SC->getValue()); 9003 return false; 9004 }); 9005 } 9006 9007 namespace { 9008 // Collect all steps of SCEV expressions. 9009 struct SCEVCollectStrides { 9010 ScalarEvolution &SE; 9011 SmallVectorImpl<const SCEV *> &Strides; 9012 9013 SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S) 9014 : SE(SE), Strides(S) {} 9015 9016 bool follow(const SCEV *S) { 9017 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) 9018 Strides.push_back(AR->getStepRecurrence(SE)); 9019 return true; 9020 } 9021 bool isDone() const { return false; } 9022 }; 9023 9024 // Collect all SCEVUnknown and SCEVMulExpr expressions. 9025 struct SCEVCollectTerms { 9026 SmallVectorImpl<const SCEV *> &Terms; 9027 9028 SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) 9029 : Terms(T) {} 9030 9031 bool follow(const SCEV *S) { 9032 if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) || 9033 isa<SCEVSignExtendExpr>(S)) { 9034 if (!containsUndefs(S)) 9035 Terms.push_back(S); 9036 9037 // Stop recursion: once we collected a term, do not walk its operands. 9038 return false; 9039 } 9040 9041 // Keep looking. 9042 return true; 9043 } 9044 bool isDone() const { return false; } 9045 }; 9046 9047 // Check if a SCEV contains an AddRecExpr. 9048 struct SCEVHasAddRec { 9049 bool &ContainsAddRec; 9050 9051 SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) { 9052 ContainsAddRec = false; 9053 } 9054 9055 bool follow(const SCEV *S) { 9056 if (isa<SCEVAddRecExpr>(S)) { 9057 ContainsAddRec = true; 9058 9059 // Stop recursion: once we collected a term, do not walk its operands. 9060 return false; 9061 } 9062 9063 // Keep looking. 9064 return true; 9065 } 9066 bool isDone() const { return false; } 9067 }; 9068 9069 // Find factors that are multiplied with an expression that (possibly as a 9070 // subexpression) contains an AddRecExpr. In the expression: 9071 // 9072 // 8 * (100 + %p * %q * (%a + {0, +, 1}_loop)) 9073 // 9074 // "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)" 9075 // that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size 9076 // parameters as they form a product with an induction variable. 9077 // 9078 // This collector expects all array size parameters to be in the same MulExpr. 9079 // It might be necessary to later add support for collecting parameters that are 9080 // spread over different nested MulExpr. 
9081 struct SCEVCollectAddRecMultiplies {
9082   SmallVectorImpl<const SCEV *> &Terms;
9083   ScalarEvolution &SE;
9084
9085   SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T, ScalarEvolution &SE)
9086       : Terms(T), SE(SE) {}
9087
9088   bool follow(const SCEV *S) {
9089     if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
9090       bool HasAddRec = false;
9091       SmallVector<const SCEV *, 0> Operands;
9092       for (auto Op : Mul->operands()) {
9093         if (isa<SCEVUnknown>(Op)) {
9094           Operands.push_back(Op);
9095         } else {
9096           bool ContainsAddRec;
9097           SCEVHasAddRec ContainsAddRecVisitor(ContainsAddRec);
9098           visitAll(Op, ContainsAddRecVisitor);
9099           HasAddRec |= ContainsAddRec;
9100         }
9101       }
9102       if (Operands.empty())
9103         return true;
9104
9105       if (!HasAddRec)
9106         return false;
9107
9108       Terms.push_back(SE.getMulExpr(Operands));
9109       // Stop recursion: once we collected a term, do not walk its operands.
9110       return false;
9111     }
9112
9113     // Keep looking.
9114     return true;
9115   }
9116   bool isDone() const { return false; }
9117 };
9118 }
9119
9120 /// Find parametric terms in this SCEVAddRecExpr. We first look for parameters in
9121 /// two places:
9122 /// 1) The strides of AddRec expressions.
9123 /// 2) Unknowns that are multiplied with AddRec expressions.
9124 void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
9125                                              SmallVectorImpl<const SCEV *> &Terms) {
9126   SmallVector<const SCEV *, 4> Strides;
9127   SCEVCollectStrides StrideCollector(*this, Strides);
9128   visitAll(Expr, StrideCollector);
9129
9130   DEBUG({
9131     dbgs() << "Strides:\n";
9132     for (const SCEV *S : Strides)
9133       dbgs() << *S << "\n";
9134   });
9135
9136   for (const SCEV *S : Strides) {
9137     SCEVCollectTerms TermCollector(Terms);
9138     visitAll(S, TermCollector);
9139   }
9140
9141   DEBUG({
9142     dbgs() << "Terms:\n";
9143     for (const SCEV *T : Terms)
9144       dbgs() << *T << "\n";
9145   });
9146
9147   SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
9148   visitAll(Expr, MulCollector);
9149 }
9150
9151 static bool findArrayDimensionsRec(ScalarEvolution &SE,
9152                                    SmallVectorImpl<const SCEV *> &Terms,
9153                                    SmallVectorImpl<const SCEV *> &Sizes) {
9154   int Last = Terms.size() - 1;
9155   const SCEV *Step = Terms[Last];
9156
9157   // End of recursion.
9158   if (Last == 0) {
9159     if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) {
9160       SmallVector<const SCEV *, 2> Qs;
9161       for (const SCEV *Op : M->operands())
9162         if (!isa<SCEVConstant>(Op))
9163           Qs.push_back(Op);
9164
9165       Step = SE.getMulExpr(Qs);
9166     }
9167
9168     Sizes.push_back(Step);
9169     return true;
9170   }
9171
9172   for (const SCEV *&Term : Terms) {
9173     // Normalize the terms before the next call to findArrayDimensionsRec.
9174     const SCEV *Q, *R;
9175     SCEVDivision::divide(SE, Term, Step, &Q, &R);
9176
9177     // Bail out when GCD does not evenly divide one of the terms.
9178     if (!R->isZero())
9179       return false;
9180
9181     Term = Q;
9182   }
9183
9184   // Remove all SCEVConstants.
9185   Terms.erase(
9186       remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }),
9187       Terms.end());
9188
9189   if (!Terms.empty())
9190     if (!findArrayDimensionsRec(SE, Terms, Sizes))
9191       return false;
9192
9193   Sizes.push_back(Step);
9194   return true;
9195 }
9196
9197
9198 // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
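// For example (illustrative only): a term such as (4 * %n) is parametric
// because it contains the SCEVUnknown %n, whereas a purely constant term
// such as 8 is not.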
9199 static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) { 9200 for (const SCEV *T : Terms) 9201 if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>)) 9202 return true; 9203 return false; 9204 } 9205 9206 // Return the number of product terms in S. 9207 static inline int numberOfTerms(const SCEV *S) { 9208 if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S)) 9209 return Expr->getNumOperands(); 9210 return 1; 9211 } 9212 9213 static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) { 9214 if (isa<SCEVConstant>(T)) 9215 return nullptr; 9216 9217 if (isa<SCEVUnknown>(T)) 9218 return T; 9219 9220 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) { 9221 SmallVector<const SCEV *, 2> Factors; 9222 for (const SCEV *Op : M->operands()) 9223 if (!isa<SCEVConstant>(Op)) 9224 Factors.push_back(Op); 9225 9226 return SE.getMulExpr(Factors); 9227 } 9228 9229 return T; 9230 } 9231 9232 /// Return the size of an element read or written by Inst. 9233 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { 9234 Type *Ty; 9235 if (StoreInst *Store = dyn_cast<StoreInst>(Inst)) 9236 Ty = Store->getValueOperand()->getType(); 9237 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) 9238 Ty = Load->getType(); 9239 else 9240 return nullptr; 9241 9242 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty)); 9243 return getSizeOfExpr(ETy, Ty); 9244 } 9245 9246 void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, 9247 SmallVectorImpl<const SCEV *> &Sizes, 9248 const SCEV *ElementSize) const { 9249 if (Terms.size() < 1 || !ElementSize) 9250 return; 9251 9252 // Early return when Terms do not contain parameters: we do not delinearize 9253 // non parametric SCEVs. 9254 if (!containsParameters(Terms)) 9255 return; 9256 9257 DEBUG({ 9258 dbgs() << "Terms:\n"; 9259 for (const SCEV *T : Terms) 9260 dbgs() << *T << "\n"; 9261 }); 9262 9263 // Remove duplicates. 9264 std::sort(Terms.begin(), Terms.end()); 9265 Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); 9266 9267 // Put larger terms first. 9268 std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) { 9269 return numberOfTerms(LHS) > numberOfTerms(RHS); 9270 }); 9271 9272 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); 9273 9274 // Try to divide all terms by the element size. If term is not divisible by 9275 // element size, proceed with the original term. 9276 for (const SCEV *&Term : Terms) { 9277 const SCEV *Q, *R; 9278 SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); 9279 if (!Q->isZero()) 9280 Term = Q; 9281 } 9282 9283 SmallVector<const SCEV *, 4> NewTerms; 9284 9285 // Remove constant factors. 9286 for (const SCEV *T : Terms) 9287 if (const SCEV *NewT = removeConstantFactors(SE, T)) 9288 NewTerms.push_back(NewT); 9289 9290 DEBUG({ 9291 dbgs() << "Terms after sorting:\n"; 9292 for (const SCEV *T : NewTerms) 9293 dbgs() << *T << "\n"; 9294 }); 9295 9296 if (NewTerms.empty() || 9297 !findArrayDimensionsRec(SE, NewTerms, Sizes)) { 9298 Sizes.clear(); 9299 return; 9300 } 9301 9302 // The last element to be pushed into Sizes is the size of an element. 
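  // For the A[n][m][o] access discussed in the delinearize comment further
  // down, Sizes would end up roughly as { %m, %o, sizeof(double) }; the
  // outermost dimension %n cannot be recovered from the strides (sketch only).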
9303 Sizes.push_back(ElementSize); 9304 9305 DEBUG({ 9306 dbgs() << "Sizes:\n"; 9307 for (const SCEV *S : Sizes) 9308 dbgs() << *S << "\n"; 9309 }); 9310 } 9311 9312 void ScalarEvolution::computeAccessFunctions( 9313 const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts, 9314 SmallVectorImpl<const SCEV *> &Sizes) { 9315 9316 // Early exit in case this SCEV is not an affine multivariate function. 9317 if (Sizes.empty()) 9318 return; 9319 9320 if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr)) 9321 if (!AR->isAffine()) 9322 return; 9323 9324 const SCEV *Res = Expr; 9325 int Last = Sizes.size() - 1; 9326 for (int i = Last; i >= 0; i--) { 9327 const SCEV *Q, *R; 9328 SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R); 9329 9330 DEBUG({ 9331 dbgs() << "Res: " << *Res << "\n"; 9332 dbgs() << "Sizes[i]: " << *Sizes[i] << "\n"; 9333 dbgs() << "Res divided by Sizes[i]:\n"; 9334 dbgs() << "Quotient: " << *Q << "\n"; 9335 dbgs() << "Remainder: " << *R << "\n"; 9336 }); 9337 9338 Res = Q; 9339 9340 // Do not record the last subscript corresponding to the size of elements in 9341 // the array. 9342 if (i == Last) { 9343 9344 // Bail out if the remainder is too complex. 9345 if (isa<SCEVAddRecExpr>(R)) { 9346 Subscripts.clear(); 9347 Sizes.clear(); 9348 return; 9349 } 9350 9351 continue; 9352 } 9353 9354 // Record the access function for the current subscript. 9355 Subscripts.push_back(R); 9356 } 9357 9358 // Also push in last position the remainder of the last division: it will be 9359 // the access function of the innermost dimension. 9360 Subscripts.push_back(Res); 9361 9362 std::reverse(Subscripts.begin(), Subscripts.end()); 9363 9364 DEBUG({ 9365 dbgs() << "Subscripts:\n"; 9366 for (const SCEV *S : Subscripts) 9367 dbgs() << *S << "\n"; 9368 }); 9369 } 9370 9371 /// Splits the SCEV into two vectors of SCEVs representing the subscripts and 9372 /// sizes of an array access. Returns the remainder of the delinearization that 9373 /// is the offset start of the array. The SCEV->delinearize algorithm computes 9374 /// the multiples of SCEV coefficients: that is a pattern matching of sub 9375 /// expressions in the stride and base of a SCEV corresponding to the 9376 /// computation of a GCD (greatest common divisor) of base and stride. When 9377 /// SCEV->delinearize fails, it returns the SCEV unchanged. 9378 /// 9379 /// For example: when analyzing the memory access A[i][j][k] in this loop nest 9380 /// 9381 /// void foo(long n, long m, long o, double A[n][m][o]) { 9382 /// 9383 /// for (long i = 0; i < n; i++) 9384 /// for (long j = 0; j < m; j++) 9385 /// for (long k = 0; k < o; k++) 9386 /// A[i][j][k] = 1.0; 9387 /// } 9388 /// 9389 /// the delinearization input is the following AddRec SCEV: 9390 /// 9391 /// AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k> 9392 /// 9393 /// From this SCEV, we are able to say that the base offset of the access is %A 9394 /// because it appears as an offset that does not divide any of the strides in 9395 /// the loops: 9396 /// 9397 /// CHECK: Base offset: %A 9398 /// 9399 /// and then SCEV->delinearize determines the size of some of the dimensions of 9400 /// the array as these are the multiples by which the strides are happening: 9401 /// 9402 /// CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes. 
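///
/// (The sizes %m and %o can be recovered because each stride above,
/// (8 * %m * %o), (8 * %o) and 8, is the product of the element size and the
/// sizes of all inner dimensions, so dividing consecutive strides yields %m
/// and %o.)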
9403 /// 9404 /// Note that the outermost dimension remains of UnknownSize because there are 9405 /// no strides that would help identifying the size of the last dimension: when 9406 /// the array has been statically allocated, one could compute the size of that 9407 /// dimension by dividing the overall size of the array by the size of the known 9408 /// dimensions: %m * %o * 8. 9409 /// 9410 /// Finally delinearize provides the access functions for the array reference 9411 /// that does correspond to A[i][j][k] of the above C testcase: 9412 /// 9413 /// CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>] 9414 /// 9415 /// The testcases are checking the output of a function pass: 9416 /// DelinearizationPass that walks through all loads and stores of a function 9417 /// asking for the SCEV of the memory access with respect to all enclosing 9418 /// loops, calling SCEV->delinearize on that and printing the results. 9419 9420 void ScalarEvolution::delinearize(const SCEV *Expr, 9421 SmallVectorImpl<const SCEV *> &Subscripts, 9422 SmallVectorImpl<const SCEV *> &Sizes, 9423 const SCEV *ElementSize) { 9424 // First step: collect parametric terms. 9425 SmallVector<const SCEV *, 4> Terms; 9426 collectParametricTerms(Expr, Terms); 9427 9428 if (Terms.empty()) 9429 return; 9430 9431 // Second step: find subscript sizes. 9432 findArrayDimensions(Terms, Sizes, ElementSize); 9433 9434 if (Sizes.empty()) 9435 return; 9436 9437 // Third step: compute the access functions for each subscript. 9438 computeAccessFunctions(Expr, Subscripts, Sizes); 9439 9440 if (Subscripts.empty()) 9441 return; 9442 9443 DEBUG({ 9444 dbgs() << "succeeded to delinearize " << *Expr << "\n"; 9445 dbgs() << "ArrayDecl[UnknownSize]"; 9446 for (const SCEV *S : Sizes) 9447 dbgs() << "[" << *S << "]"; 9448 9449 dbgs() << "\nArrayRef"; 9450 for (const SCEV *S : Subscripts) 9451 dbgs() << "[" << *S << "]"; 9452 dbgs() << "\n"; 9453 }); 9454 } 9455 9456 //===----------------------------------------------------------------------===// 9457 // SCEVCallbackVH Class Implementation 9458 //===----------------------------------------------------------------------===// 9459 9460 void ScalarEvolution::SCEVCallbackVH::deleted() { 9461 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); 9462 if (PHINode *PN = dyn_cast<PHINode>(getValPtr())) 9463 SE->ConstantEvolutionLoopExitValue.erase(PN); 9464 SE->eraseValueFromMap(getValPtr()); 9465 // this now dangles! 9466 } 9467 9468 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { 9469 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); 9470 9471 // Forget all the expressions associated with users of the old value, 9472 // so that future queries will recompute the expressions using the new 9473 // value. 9474 Value *Old = getValPtr(); 9475 SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end()); 9476 SmallPtrSet<User *, 8> Visited; 9477 while (!Worklist.empty()) { 9478 User *U = Worklist.pop_back_val(); 9479 // Deleting the Old value will cause this to dangle. Postpone 9480 // that until everything else is done. 9481 if (U == Old) 9482 continue; 9483 if (!Visited.insert(U).second) 9484 continue; 9485 if (PHINode *PN = dyn_cast<PHINode>(U)) 9486 SE->ConstantEvolutionLoopExitValue.erase(PN); 9487 SE->eraseValueFromMap(U); 9488 Worklist.insert(Worklist.end(), U->user_begin(), U->user_end()); 9489 } 9490 // Delete the Old value. 
9491 if (PHINode *PN = dyn_cast<PHINode>(Old)) 9492 SE->ConstantEvolutionLoopExitValue.erase(PN); 9493 SE->eraseValueFromMap(Old); 9494 // this now dangles! 9495 } 9496 9497 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) 9498 : CallbackVH(V), SE(se) {} 9499 9500 //===----------------------------------------------------------------------===// 9501 // ScalarEvolution Class Implementation 9502 //===----------------------------------------------------------------------===// 9503 9504 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, 9505 AssumptionCache &AC, DominatorTree &DT, 9506 LoopInfo &LI) 9507 : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), 9508 CouldNotCompute(new SCEVCouldNotCompute()), 9509 WalkingBEDominatingConds(false), ProvingSplitPredicate(false), 9510 ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64), 9511 FirstUnknown(nullptr) { 9512 9513 // To use guards for proving predicates, we need to scan every instruction in 9514 // relevant basic blocks, and not just terminators. Doing this is a waste of 9515 // time if the IR does not actually contain any calls to 9516 // @llvm.experimental.guard, so do a quick check and remember this beforehand. 9517 // 9518 // This pessimizes the case where a pass that preserves ScalarEvolution wants 9519 // to _add_ guards to the module when there weren't any before, and wants 9520 // ScalarEvolution to optimize based on those guards. For now we prefer to be 9521 // efficient in lieu of being smart in that rather obscure case. 9522 9523 auto *GuardDecl = F.getParent()->getFunction( 9524 Intrinsic::getName(Intrinsic::experimental_guard)); 9525 HasGuards = GuardDecl && !GuardDecl->use_empty(); 9526 } 9527 9528 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) 9529 : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), 9530 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), 9531 ValueExprMap(std::move(Arg.ValueExprMap)), 9532 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), 9533 WalkingBEDominatingConds(false), ProvingSplitPredicate(false), 9534 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), 9535 PredicatedBackedgeTakenCounts( 9536 std::move(Arg.PredicatedBackedgeTakenCounts)), 9537 ConstantEvolutionLoopExitValue( 9538 std::move(Arg.ConstantEvolutionLoopExitValue)), 9539 ValuesAtScopes(std::move(Arg.ValuesAtScopes)), 9540 LoopDispositions(std::move(Arg.LoopDispositions)), 9541 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)), 9542 BlockDispositions(std::move(Arg.BlockDispositions)), 9543 UnsignedRanges(std::move(Arg.UnsignedRanges)), 9544 SignedRanges(std::move(Arg.SignedRanges)), 9545 UniqueSCEVs(std::move(Arg.UniqueSCEVs)), 9546 UniquePreds(std::move(Arg.UniquePreds)), 9547 SCEVAllocator(std::move(Arg.SCEVAllocator)), 9548 FirstUnknown(Arg.FirstUnknown) { 9549 Arg.FirstUnknown = nullptr; 9550 } 9551 9552 ScalarEvolution::~ScalarEvolution() { 9553 // Iterate through all the SCEVUnknown instances and call their 9554 // destructors, so that they release their references to their values. 9555 for (SCEVUnknown *U = FirstUnknown; U;) { 9556 SCEVUnknown *Tmp = U; 9557 U = U->Next; 9558 Tmp->~SCEVUnknown(); 9559 } 9560 FirstUnknown = nullptr; 9561 9562 ExprValueMap.clear(); 9563 ValueExprMap.clear(); 9564 HasRecMap.clear(); 9565 9566 // Free any extra memory created for ExitNotTakenInfo in the unlikely event 9567 // that a loop had multiple computable exits. 
9568 for (auto &BTCI : BackedgeTakenCounts) 9569 BTCI.second.clear(); 9570 for (auto &BTCI : PredicatedBackedgeTakenCounts) 9571 BTCI.second.clear(); 9572 9573 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); 9574 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); 9575 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!"); 9576 } 9577 9578 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { 9579 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L)); 9580 } 9581 9582 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, 9583 const Loop *L) { 9584 // Print all inner loops first 9585 for (Loop *I : *L) 9586 PrintLoopInfo(OS, SE, I); 9587 9588 OS << "Loop "; 9589 L->getHeader()->printAsOperand(OS, /*PrintType=*/false); 9590 OS << ": "; 9591 9592 SmallVector<BasicBlock *, 8> ExitBlocks; 9593 L->getExitBlocks(ExitBlocks); 9594 if (ExitBlocks.size() != 1) 9595 OS << "<multiple exits> "; 9596 9597 if (SE->hasLoopInvariantBackedgeTakenCount(L)) { 9598 OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L); 9599 } else { 9600 OS << "Unpredictable backedge-taken count. "; 9601 } 9602 9603 OS << "\n" 9604 "Loop "; 9605 L->getHeader()->printAsOperand(OS, /*PrintType=*/false); 9606 OS << ": "; 9607 9608 if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) { 9609 OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L); 9610 if (SE->isBackedgeTakenCountMaxOrZero(L)) 9611 OS << ", actual taken count either this or zero."; 9612 } else { 9613 OS << "Unpredictable max backedge-taken count. "; 9614 } 9615 9616 OS << "\n" 9617 "Loop "; 9618 L->getHeader()->printAsOperand(OS, /*PrintType=*/false); 9619 OS << ": "; 9620 9621 SCEVUnionPredicate Pred; 9622 auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred); 9623 if (!isa<SCEVCouldNotCompute>(PBT)) { 9624 OS << "Predicated backedge-taken count is " << *PBT << "\n"; 9625 OS << " Predicates:\n"; 9626 Pred.print(OS, 4); 9627 } else { 9628 OS << "Unpredictable predicated backedge-taken count. "; 9629 } 9630 OS << "\n"; 9631 } 9632 9633 static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { 9634 switch (LD) { 9635 case ScalarEvolution::LoopVariant: 9636 return "Variant"; 9637 case ScalarEvolution::LoopInvariant: 9638 return "Invariant"; 9639 case ScalarEvolution::LoopComputable: 9640 return "Computable"; 9641 } 9642 llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!"); 9643 } 9644 9645 void ScalarEvolution::print(raw_ostream &OS) const { 9646 // ScalarEvolution's implementation of the print method is to print 9647 // out SCEV values of all instructions that are interesting. Doing 9648 // this potentially causes it to create new SCEV objects though, 9649 // which technically conflicts with the const qualifier. This isn't 9650 // observable from outside the class though, so casting away the 9651 // const isn't dangerous. 
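  //
  // For each SCEVable instruction the output below is the instruction itself,
  // then " --> " followed by its SCEV and (when computable) its unsigned and
  // signed ranges, then the value the SCEV takes on loop exit and its
  // disposition with respect to each enclosing loop.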
9652 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); 9653 9654 OS << "Classifying expressions for: "; 9655 F.printAsOperand(OS, /*PrintType=*/false); 9656 OS << "\n"; 9657 for (Instruction &I : instructions(F)) 9658 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) { 9659 OS << I << '\n'; 9660 OS << " --> "; 9661 const SCEV *SV = SE.getSCEV(&I); 9662 SV->print(OS); 9663 if (!isa<SCEVCouldNotCompute>(SV)) { 9664 OS << " U: "; 9665 SE.getUnsignedRange(SV).print(OS); 9666 OS << " S: "; 9667 SE.getSignedRange(SV).print(OS); 9668 } 9669 9670 const Loop *L = LI.getLoopFor(I.getParent()); 9671 9672 const SCEV *AtUse = SE.getSCEVAtScope(SV, L); 9673 if (AtUse != SV) { 9674 OS << " --> "; 9675 AtUse->print(OS); 9676 if (!isa<SCEVCouldNotCompute>(AtUse)) { 9677 OS << " U: "; 9678 SE.getUnsignedRange(AtUse).print(OS); 9679 OS << " S: "; 9680 SE.getSignedRange(AtUse).print(OS); 9681 } 9682 } 9683 9684 if (L) { 9685 OS << "\t\t" "Exits: "; 9686 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); 9687 if (!SE.isLoopInvariant(ExitValue, L)) { 9688 OS << "<<Unknown>>"; 9689 } else { 9690 OS << *ExitValue; 9691 } 9692 9693 bool First = true; 9694 for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) { 9695 if (First) { 9696 OS << "\t\t" "LoopDispositions: { "; 9697 First = false; 9698 } else { 9699 OS << ", "; 9700 } 9701 9702 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false); 9703 OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter)); 9704 } 9705 9706 for (auto *InnerL : depth_first(L)) { 9707 if (InnerL == L) 9708 continue; 9709 if (First) { 9710 OS << "\t\t" "LoopDispositions: { "; 9711 First = false; 9712 } else { 9713 OS << ", "; 9714 } 9715 9716 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false); 9717 OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL)); 9718 } 9719 9720 OS << " }"; 9721 } 9722 9723 OS << "\n"; 9724 } 9725 9726 OS << "Determining loop execution counts for: "; 9727 F.printAsOperand(OS, /*PrintType=*/false); 9728 OS << "\n"; 9729 for (Loop *I : LI) 9730 PrintLoopInfo(OS, &SE, I); 9731 } 9732 9733 ScalarEvolution::LoopDisposition 9734 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { 9735 auto &Values = LoopDispositions[S]; 9736 for (auto &V : Values) { 9737 if (V.getPointer() == L) 9738 return V.getInt(); 9739 } 9740 Values.emplace_back(L, LoopVariant); 9741 LoopDisposition D = computeLoopDisposition(S, L); 9742 auto &Values2 = LoopDispositions[S]; 9743 for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { 9744 if (V.getPointer() == L) { 9745 V.setInt(D); 9746 break; 9747 } 9748 } 9749 return D; 9750 } 9751 9752 ScalarEvolution::LoopDisposition 9753 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { 9754 switch (static_cast<SCEVTypes>(S->getSCEVType())) { 9755 case scConstant: 9756 return LoopInvariant; 9757 case scTruncate: 9758 case scZeroExtend: 9759 case scSignExtend: 9760 return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L); 9761 case scAddRecExpr: { 9762 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); 9763 9764 // If L is the addrec's loop, it's computable. 9765 if (AR->getLoop() == L) 9766 return LoopComputable; 9767 9768 // Add recurrences are never invariant in the function-body (null loop). 9769 if (!L) 9770 return LoopVariant; 9771 9772 // This recurrence is variant w.r.t. L if L contains AR's loop. 9773 if (L->contains(AR->getLoop())) 9774 return LoopVariant; 9775 9776 // This recurrence is invariant w.r.t. 
L if AR's loop contains L. 9777 if (AR->getLoop()->contains(L)) 9778 return LoopInvariant; 9779 9780 // This recurrence is variant w.r.t. L if any of its operands 9781 // are variant. 9782 for (auto *Op : AR->operands()) 9783 if (!isLoopInvariant(Op, L)) 9784 return LoopVariant; 9785 9786 // Otherwise it's loop-invariant. 9787 return LoopInvariant; 9788 } 9789 case scAddExpr: 9790 case scMulExpr: 9791 case scUMaxExpr: 9792 case scSMaxExpr: { 9793 bool HasVarying = false; 9794 for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) { 9795 LoopDisposition D = getLoopDisposition(Op, L); 9796 if (D == LoopVariant) 9797 return LoopVariant; 9798 if (D == LoopComputable) 9799 HasVarying = true; 9800 } 9801 return HasVarying ? LoopComputable : LoopInvariant; 9802 } 9803 case scUDivExpr: { 9804 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); 9805 LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L); 9806 if (LD == LoopVariant) 9807 return LoopVariant; 9808 LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L); 9809 if (RD == LoopVariant) 9810 return LoopVariant; 9811 return (LD == LoopInvariant && RD == LoopInvariant) ? 9812 LoopInvariant : LoopComputable; 9813 } 9814 case scUnknown: 9815 // All non-instruction values are loop invariant. All instructions are loop 9816 // invariant if they are not contained in the specified loop. 9817 // Instructions are never considered invariant in the function body 9818 // (null loop) because they are defined within the "loop". 9819 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) 9820 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant; 9821 return LoopInvariant; 9822 case scCouldNotCompute: 9823 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 9824 } 9825 llvm_unreachable("Unknown SCEV kind!"); 9826 } 9827 9828 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) { 9829 return getLoopDisposition(S, L) == LoopInvariant; 9830 } 9831 9832 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { 9833 return getLoopDisposition(S, L) == LoopComputable; 9834 } 9835 9836 ScalarEvolution::BlockDisposition 9837 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { 9838 auto &Values = BlockDispositions[S]; 9839 for (auto &V : Values) { 9840 if (V.getPointer() == BB) 9841 return V.getInt(); 9842 } 9843 Values.emplace_back(BB, DoesNotDominateBlock); 9844 BlockDisposition D = computeBlockDisposition(S, BB); 9845 auto &Values2 = BlockDispositions[S]; 9846 for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { 9847 if (V.getPointer() == BB) { 9848 V.setInt(D); 9849 break; 9850 } 9851 } 9852 return D; 9853 } 9854 9855 ScalarEvolution::BlockDisposition 9856 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { 9857 switch (static_cast<SCEVTypes>(S->getSCEVType())) { 9858 case scConstant: 9859 return ProperlyDominatesBlock; 9860 case scTruncate: 9861 case scZeroExtend: 9862 case scSignExtend: 9863 return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB); 9864 case scAddRecExpr: { 9865 // This uses a "dominates" query instead of "properly dominates" query 9866 // to test for proper dominance too, because the instruction which 9867 // produces the addrec's value is a PHI, and a PHI effectively properly 9868 // dominates its entire containing block. 
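    // Consequently it is enough to ask whether the loop header dominates BB:
    // if it does, the addrec's value (the header PHI) is available in BB even
    // when BB is the header block itself.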
9869 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); 9870 if (!DT.dominates(AR->getLoop()->getHeader(), BB)) 9871 return DoesNotDominateBlock; 9872 9873 // Fall through into SCEVNAryExpr handling. 9874 LLVM_FALLTHROUGH; 9875 } 9876 case scAddExpr: 9877 case scMulExpr: 9878 case scUMaxExpr: 9879 case scSMaxExpr: { 9880 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S); 9881 bool Proper = true; 9882 for (const SCEV *NAryOp : NAry->operands()) { 9883 BlockDisposition D = getBlockDisposition(NAryOp, BB); 9884 if (D == DoesNotDominateBlock) 9885 return DoesNotDominateBlock; 9886 if (D == DominatesBlock) 9887 Proper = false; 9888 } 9889 return Proper ? ProperlyDominatesBlock : DominatesBlock; 9890 } 9891 case scUDivExpr: { 9892 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); 9893 const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS(); 9894 BlockDisposition LD = getBlockDisposition(LHS, BB); 9895 if (LD == DoesNotDominateBlock) 9896 return DoesNotDominateBlock; 9897 BlockDisposition RD = getBlockDisposition(RHS, BB); 9898 if (RD == DoesNotDominateBlock) 9899 return DoesNotDominateBlock; 9900 return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ? 9901 ProperlyDominatesBlock : DominatesBlock; 9902 } 9903 case scUnknown: 9904 if (Instruction *I = 9905 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) { 9906 if (I->getParent() == BB) 9907 return DominatesBlock; 9908 if (DT.properlyDominates(I->getParent(), BB)) 9909 return ProperlyDominatesBlock; 9910 return DoesNotDominateBlock; 9911 } 9912 return ProperlyDominatesBlock; 9913 case scCouldNotCompute: 9914 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 9915 } 9916 llvm_unreachable("Unknown SCEV kind!"); 9917 } 9918 9919 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) { 9920 return getBlockDisposition(S, BB) >= DominatesBlock; 9921 } 9922 9923 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { 9924 return getBlockDisposition(S, BB) == ProperlyDominatesBlock; 9925 } 9926 9927 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { 9928 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; }); 9929 } 9930 9931 void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { 9932 ValuesAtScopes.erase(S); 9933 LoopDispositions.erase(S); 9934 BlockDispositions.erase(S); 9935 UnsignedRanges.erase(S); 9936 SignedRanges.erase(S); 9937 ExprValueMap.erase(S); 9938 HasRecMap.erase(S); 9939 9940 auto RemoveSCEVFromBackedgeMap = 9941 [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { 9942 for (auto I = Map.begin(), E = Map.end(); I != E;) { 9943 BackedgeTakenInfo &BEInfo = I->second; 9944 if (BEInfo.hasOperand(S, this)) { 9945 BEInfo.clear(); 9946 Map.erase(I++); 9947 } else 9948 ++I; 9949 } 9950 }; 9951 9952 RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); 9953 RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); 9954 } 9955 9956 typedef DenseMap<const Loop *, std::string> VerifyMap; 9957 9958 /// replaceSubString - Replaces all occurrences of From in Str with To. 9959 static void replaceSubString(std::string &Str, StringRef From, StringRef To) { 9960 size_t Pos = 0; 9961 while ((Pos = Str.find(From, Pos)) != std::string::npos) { 9962 Str.replace(Pos, From.size(), To.data(), To.size()); 9963 Pos += To.size(); 9964 } 9965 } 9966 9967 /// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis. 
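/// Stringifies the backedge-taken count of L and, recursively, of every loop
/// nested inside L, storing the resulting text in Map keyed by loop.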
9968 static void
9969 getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) {
9970   std::string &S = Map[L];
9971   if (S.empty()) {
9972     raw_string_ostream OS(S);
9973     SE.getBackedgeTakenCount(L)->print(OS);
9974
9975     // false and 0 are semantically equivalent. This can happen in dead loops.
9976     replaceSubString(OS.str(), "false", "0");
9977     // Remove wrap flags; their use in SCEV is highly fragile.
9978     // FIXME: Remove this when SCEV gets smarter about them.
9979     replaceSubString(OS.str(), "<nw>", "");
9980     replaceSubString(OS.str(), "<nsw>", "");
9981     replaceSubString(OS.str(), "<nuw>", "");
9982   }
9983
9984   for (auto *R : reverse(*L))
9985     getLoopBackedgeTakenCounts(R, Map, SE); // recurse.
9986 }
9987
9988 void ScalarEvolution::verify() const {
9989   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
9990
9991   // Gather stringified backedge taken counts for all loops using SCEV's caches.
9992   // FIXME: It would be much better to store actual values instead of strings,
9993   // but SCEV pointers will change if we drop the caches.
9994   VerifyMap BackedgeDumpsOld, BackedgeDumpsNew;
9995   for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I)
9996     getLoopBackedgeTakenCounts(*I, BackedgeDumpsOld, SE);
9997
9998   // Gather stringified backedge taken counts for all loops using a fresh
9999   // ScalarEvolution object.
10000   ScalarEvolution SE2(F, TLI, AC, DT, LI);
10001   for (LoopInfo::reverse_iterator I = LI.rbegin(), E = LI.rend(); I != E; ++I)
10002     getLoopBackedgeTakenCounts(*I, BackedgeDumpsNew, SE2);
10003
10004   // Now compare whether they're the same with and without caches. This allows
10005   // verifying that no pass changed the cache.
10006   assert(BackedgeDumpsOld.size() == BackedgeDumpsNew.size() &&
10007          "New loops suddenly appeared!");
10008
10009   for (VerifyMap::iterator OldI = BackedgeDumpsOld.begin(),
10010                            OldE = BackedgeDumpsOld.end(),
10011                            NewI = BackedgeDumpsNew.begin();
10012        OldI != OldE; ++OldI, ++NewI) {
10013     assert(OldI->first == NewI->first && "Loop order changed!");
10014
10015     // Compare the stringified SCEVs. We don't care if the undef backedge-taken
10016     // count changes.
10017     // FIXME: We currently ignore SCEV changes from/to CouldNotCompute. This
10018     // may mean that a pass is buggy or that SCEV has to learn a new pattern,
10019     // but it is usually not harmful.
10020     if (OldI->second != NewI->second &&
10021         OldI->second.find("undef") == std::string::npos &&
10022         NewI->second.find("undef") == std::string::npos &&
10023         OldI->second != "***COULDNOTCOMPUTE***" &&
10024         NewI->second != "***COULDNOTCOMPUTE***") {
10025       dbgs() << "SCEVValidator: SCEV for loop '"
10026              << OldI->first->getHeader()->getName()
10027              << "' changed from '" << OldI->second
10028              << "' to '" << NewI->second << "'!\n";
10029       std::abort();
10030     }
10031   }
10032
10033   // TODO: Verify more things.
10034 }
10035
10036 bool ScalarEvolution::invalidate(
10037     Function &F, const PreservedAnalyses &PA,
10038     FunctionAnalysisManager::Invalidator &Inv) {
10039   // Invalidate the ScalarEvolution object whenever it isn't preserved or one
10040   // of its dependencies is invalidated.
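  // ScalarEvolution's cached SCEVs, ranges and dispositions were derived
  // using the assumption cache, the dominator tree and loop info, so the
  // cached state has to be thrown away when any of those analyses changes.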
10041 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>(); 10042 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || 10043 Inv.invalidate<AssumptionAnalysis>(F, PA) || 10044 Inv.invalidate<DominatorTreeAnalysis>(F, PA) || 10045 Inv.invalidate<LoopAnalysis>(F, PA); 10046 } 10047 10048 AnalysisKey ScalarEvolutionAnalysis::Key; 10049 10050 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, 10051 FunctionAnalysisManager &AM) { 10052 return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F), 10053 AM.getResult<AssumptionAnalysis>(F), 10054 AM.getResult<DominatorTreeAnalysis>(F), 10055 AM.getResult<LoopAnalysis>(F)); 10056 } 10057 10058 PreservedAnalyses 10059 ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { 10060 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS); 10061 return PreservedAnalyses::all(); 10062 } 10063 10064 INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution", 10065 "Scalar Evolution Analysis", false, true) 10066 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 10067 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 10068 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 10069 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 10070 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution", 10071 "Scalar Evolution Analysis", false, true) 10072 char ScalarEvolutionWrapperPass::ID = 0; 10073 10074 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) { 10075 initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry()); 10076 } 10077 10078 bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) { 10079 SE.reset(new ScalarEvolution( 10080 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), 10081 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), 10082 getAnalysis<DominatorTreeWrapperPass>().getDomTree(), 10083 getAnalysis<LoopInfoWrapperPass>().getLoopInfo())); 10084 return false; 10085 } 10086 10087 void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); } 10088 10089 void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const { 10090 SE->print(OS); 10091 } 10092 10093 void ScalarEvolutionWrapperPass::verifyAnalysis() const { 10094 if (!VerifySCEV) 10095 return; 10096 10097 SE->verify(); 10098 } 10099 10100 void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { 10101 AU.setPreservesAll(); 10102 AU.addRequiredTransitive<AssumptionCacheTracker>(); 10103 AU.addRequiredTransitive<LoopInfoWrapperPass>(); 10104 AU.addRequiredTransitive<DominatorTreeWrapperPass>(); 10105 AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); 10106 } 10107 10108 const SCEVPredicate * 10109 ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, 10110 const SCEVConstant *RHS) { 10111 FoldingSetNodeID ID; 10112 // Unique this node based on the arguments 10113 ID.AddInteger(SCEVPredicate::P_Equal); 10114 ID.AddPointer(LHS); 10115 ID.AddPointer(RHS); 10116 void *IP = nullptr; 10117 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) 10118 return S; 10119 SCEVEqualPredicate *Eq = new (SCEVAllocator) 10120 SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); 10121 UniquePreds.InsertNode(Eq, IP); 10122 return Eq; 10123 } 10124 10125 const SCEVPredicate *ScalarEvolution::getWrapPredicate( 10126 const SCEVAddRecExpr *AR, 10127 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { 10128 FoldingSetNodeID ID; 10129 // Unique this node based on the arguments 10130 ID.AddInteger(SCEVPredicate::P_Wrap); 10131 
ID.AddPointer(AR); 10132 ID.AddInteger(AddedFlags); 10133 void *IP = nullptr; 10134 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) 10135 return S; 10136 auto *OF = new (SCEVAllocator) 10137 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); 10138 UniquePreds.InsertNode(OF, IP); 10139 return OF; 10140 } 10141 10142 namespace { 10143 10144 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { 10145 public: 10146 /// Rewrites \p S in the context of a loop L and the SCEV predication 10147 /// infrastructure. 10148 /// 10149 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the 10150 /// equivalences present in \p Pred. 10151 /// 10152 /// If \p NewPreds is non-null, rewrite is free to add further predicates to 10153 /// \p NewPreds such that the result will be an AddRecExpr. 10154 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, 10155 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, 10156 SCEVUnionPredicate *Pred) { 10157 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); 10158 return Rewriter.visit(S); 10159 } 10160 10161 SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, 10162 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, 10163 SCEVUnionPredicate *Pred) 10164 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} 10165 10166 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 10167 if (Pred) { 10168 auto ExprPreds = Pred->getPredicatesForExpr(Expr); 10169 for (auto *Pred : ExprPreds) 10170 if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) 10171 if (IPred->getLHS() == Expr) 10172 return IPred->getRHS(); 10173 } 10174 10175 return Expr; 10176 } 10177 10178 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 10179 const SCEV *Operand = visit(Expr->getOperand()); 10180 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); 10181 if (AR && AR->getLoop() == L && AR->isAffine()) { 10182 // This couldn't be folded because the operand didn't have the nuw 10183 // flag. Add the nusw flag as an assumption that we could make. 10184 const SCEV *Step = AR->getStepRecurrence(SE); 10185 Type *Ty = Expr->getType(); 10186 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) 10187 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), 10188 SE.getSignExtendExpr(Step, Ty), L, 10189 AR->getNoWrapFlags()); 10190 } 10191 return SE.getZeroExtendExpr(Operand, Expr->getType()); 10192 } 10193 10194 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 10195 const SCEV *Operand = visit(Expr->getOperand()); 10196 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); 10197 if (AR && AR->getLoop() == L && AR->isAffine()) { 10198 // This couldn't be folded because the operand didn't have the nsw 10199 // flag. Add the nssw flag as an assumption that we could make. 10200 const SCEV *Step = AR->getStepRecurrence(SE); 10201 Type *Ty = Expr->getType(); 10202 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) 10203 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), 10204 SE.getSignExtendExpr(Step, Ty), L, 10205 AR->getNoWrapFlags()); 10206 } 10207 return SE.getSignExtendExpr(Operand, Expr->getType()); 10208 } 10209 10210 private: 10211 bool addOverflowAssumption(const SCEVAddRecExpr *AR, 10212 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { 10213 auto *A = SE.getWrapPredicate(AR, AddedFlags); 10214 if (!NewPreds) { 10215 // Check if we've already made this assumption. 
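      // Without a predicate set to extend (the "check only" mode), the rewrite
      // is valid only if the required wrap predicate is already implied by
      // Pred; otherwise the new predicate is recorded in NewPreds and
      // optimistically assumed to hold.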
10216 return Pred && Pred->implies(A); 10217 } 10218 NewPreds->insert(A); 10219 return true; 10220 } 10221 10222 SmallPtrSetImpl<const SCEVPredicate *> *NewPreds; 10223 SCEVUnionPredicate *Pred; 10224 const Loop *L; 10225 }; 10226 } // end anonymous namespace 10227 10228 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, 10229 SCEVUnionPredicate &Preds) { 10230 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); 10231 } 10232 10233 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( 10234 const SCEV *S, const Loop *L, 10235 SmallPtrSetImpl<const SCEVPredicate *> &Preds) { 10236 10237 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds; 10238 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr); 10239 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); 10240 10241 if (!AddRec) 10242 return nullptr; 10243 10244 // Since the transformation was successful, we can now transfer the SCEV 10245 // predicates. 10246 for (auto *P : TransformPreds) 10247 Preds.insert(P); 10248 10249 return AddRec; 10250 } 10251 10252 /// SCEV predicates 10253 SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, 10254 SCEVPredicateKind Kind) 10255 : FastID(ID), Kind(Kind) {} 10256 10257 SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, 10258 const SCEVUnknown *LHS, 10259 const SCEVConstant *RHS) 10260 : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {} 10261 10262 bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { 10263 const auto *Op = dyn_cast<SCEVEqualPredicate>(N); 10264 10265 if (!Op) 10266 return false; 10267 10268 return Op->LHS == LHS && Op->RHS == RHS; 10269 } 10270 10271 bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } 10272 10273 const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } 10274 10275 void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { 10276 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; 10277 } 10278 10279 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, 10280 const SCEVAddRecExpr *AR, 10281 IncrementWrapFlags Flags) 10282 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {} 10283 10284 const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } 10285 10286 bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { 10287 const auto *Op = dyn_cast<SCEVWrapPredicate>(N); 10288 10289 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; 10290 } 10291 10292 bool SCEVWrapPredicate::isAlwaysTrue() const { 10293 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags(); 10294 IncrementWrapFlags IFlags = Flags; 10295 10296 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags) 10297 IFlags = clearFlags(IFlags, IncrementNSSW); 10298 10299 return IFlags == IncrementAnyWrap; 10300 } 10301 10302 void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const { 10303 OS.indent(Depth) << *getExpr() << " Added Flags: "; 10304 if (SCEVWrapPredicate::IncrementNUSW & getFlags()) 10305 OS << "<nusw>"; 10306 if (SCEVWrapPredicate::IncrementNSSW & getFlags()) 10307 OS << "<nssw>"; 10308 OS << "\n"; 10309 } 10310 10311 SCEVWrapPredicate::IncrementWrapFlags 10312 SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, 10313 ScalarEvolution &SE) { 10314 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap; 10315 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags(); 10316 10317 // We can safely transfer the NSW flag as NSSW. 
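  // (NSW on the addrec therefore directly yields NSSW below; NUW alone is not
  // enough for NUSW, the step must additionally be known non-negative.)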
10318 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags) 10319 ImpliedFlags = IncrementNSSW; 10320 10321 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) { 10322 // If the increment is positive, the SCEV NUW flag will also imply the 10323 // WrapPredicate NUSW flag. 10324 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) 10325 if (Step->getValue()->getValue().isNonNegative()) 10326 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW); 10327 } 10328 10329 return ImpliedFlags; 10330 } 10331 10332 /// Union predicates don't get cached so create a dummy set ID for it. 10333 SCEVUnionPredicate::SCEVUnionPredicate() 10334 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} 10335 10336 bool SCEVUnionPredicate::isAlwaysTrue() const { 10337 return all_of(Preds, 10338 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); 10339 } 10340 10341 ArrayRef<const SCEVPredicate *> 10342 SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { 10343 auto I = SCEVToPreds.find(Expr); 10344 if (I == SCEVToPreds.end()) 10345 return ArrayRef<const SCEVPredicate *>(); 10346 return I->second; 10347 } 10348 10349 bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { 10350 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) 10351 return all_of(Set->Preds, 10352 [this](const SCEVPredicate *I) { return this->implies(I); }); 10353 10354 auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); 10355 if (ScevPredsIt == SCEVToPreds.end()) 10356 return false; 10357 auto &SCEVPreds = ScevPredsIt->second; 10358 10359 return any_of(SCEVPreds, 10360 [N](const SCEVPredicate *I) { return I->implies(N); }); 10361 } 10362 10363 const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } 10364 10365 void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { 10366 for (auto Pred : Preds) 10367 Pred->print(OS, Depth); 10368 } 10369 10370 void SCEVUnionPredicate::add(const SCEVPredicate *N) { 10371 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) { 10372 for (auto Pred : Set->Preds) 10373 add(Pred); 10374 return; 10375 } 10376 10377 if (implies(N)) 10378 return; 10379 10380 const SCEV *Key = N->getExpr(); 10381 assert(Key && "Only SCEVUnionPredicate doesn't have an " 10382 " associated expression!"); 10383 10384 SCEVToPreds[Key].push_back(N); 10385 Preds.push_back(N); 10386 } 10387 10388 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, 10389 Loop &L) 10390 : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {} 10391 10392 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { 10393 const SCEV *Expr = SE.getSCEV(V); 10394 RewriteEntry &Entry = RewriteMap[Expr]; 10395 10396 // If we already have an entry and the version matches, return it. 10397 if (Entry.second && Generation == Entry.first) 10398 return Entry.second; 10399 10400 // We found an entry but it's stale. Rewrite the stale entry 10401 // according to the current predicate. 
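  // An entry is stale when its generation tag is older than Generation, i.e.
  // predicates have been added since it was last rewritten; in that case the
  // cached SCEV is reused as the starting point for an incremental rewrite.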
10402 if (Entry.second) 10403 Expr = Entry.second; 10404 10405 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); 10406 Entry = {Generation, NewSCEV}; 10407 10408 return NewSCEV; 10409 } 10410 10411 const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() { 10412 if (!BackedgeCount) { 10413 SCEVUnionPredicate BackedgePred; 10414 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred); 10415 addPredicate(BackedgePred); 10416 } 10417 return BackedgeCount; 10418 } 10419 10420 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { 10421 if (Preds.implies(&Pred)) 10422 return; 10423 Preds.add(&Pred); 10424 updateGeneration(); 10425 } 10426 10427 const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { 10428 return Preds; 10429 } 10430 10431 void PredicatedScalarEvolution::updateGeneration() { 10432 // If the generation number wrapped recompute everything. 10433 if (++Generation == 0) { 10434 for (auto &II : RewriteMap) { 10435 const SCEV *Rewritten = II.second.second; 10436 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; 10437 } 10438 } 10439 } 10440 10441 void PredicatedScalarEvolution::setNoOverflow( 10442 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { 10443 const SCEV *Expr = getSCEV(V); 10444 const auto *AR = cast<SCEVAddRecExpr>(Expr); 10445 10446 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); 10447 10448 // Clear the statically implied flags. 10449 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); 10450 addPredicate(*SE.getWrapPredicate(AR, Flags)); 10451 10452 auto II = FlagsMap.insert({V, Flags}); 10453 if (!II.second) 10454 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); 10455 } 10456 10457 bool PredicatedScalarEvolution::hasNoOverflow( 10458 Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { 10459 const SCEV *Expr = getSCEV(V); 10460 const auto *AR = cast<SCEVAddRecExpr>(Expr); 10461 10462 Flags = SCEVWrapPredicate::clearFlags( 10463 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); 10464 10465 auto II = FlagsMap.find(V); 10466 10467 if (II != FlagsMap.end()) 10468 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); 10469 10470 return Flags == SCEVWrapPredicate::IncrementAnyWrap; 10471 } 10472 10473 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { 10474 const SCEV *Expr = this->getSCEV(V); 10475 SmallPtrSet<const SCEVPredicate *, 4> NewPreds; 10476 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds); 10477 10478 if (!New) 10479 return nullptr; 10480 10481 for (auto *P : NewPreds) 10482 Preds.add(P); 10483 10484 updateGeneration(); 10485 RewriteMap[SE.getSCEV(V)] = {Generation, New}; 10486 return New; 10487 } 10488 10489 PredicatedScalarEvolution::PredicatedScalarEvolution( 10490 const PredicatedScalarEvolution &Init) 10491 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds), 10492 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { 10493 for (const auto &I : Init.FlagsMap) 10494 FlagsMap.insert(I); 10495 } 10496 10497 void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { 10498 // For each block. 10499 for (auto *BB : L.getBlocks()) 10500 for (auto &I : *BB) { 10501 if (!SE.isSCEVable(I.getType())) 10502 continue; 10503 10504 auto *Expr = SE.getSCEV(&I); 10505 auto II = RewriteMap.find(Expr); 10506 10507 if (II == RewriteMap.end()) 10508 continue; 10509 10510 // Don't print things that are not interesting. 
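      // ("Not interesting" simply means that predication left the expression
      // unchanged, i.e. the rewritten SCEV equals the original one.)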
10511 if (II->second.second == Expr) 10512 continue; 10513 10514 OS.indent(Depth) << "[PSE]" << I << ":\n"; 10515 OS.indent(Depth + 2) << *Expr << "\n"; 10516 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n"; 10517 } 10518 } 10519