1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopDetection.h" 4 #include "llvm/Analysis/RegionInfo.h" 5 #include "llvm/Analysis/ScalarEvolution.h" 6 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 7 #include "llvm/Support/Debug.h" 8 9 using namespace llvm; 10 using namespace polly; 11 12 #define DEBUG_TYPE "polly-scev-validator" 13 14 namespace SCEVType { 15 /// The type of a SCEV 16 /// 17 /// To check for the validity of a SCEV we assign to each SCEV a type. The 18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 19 /// important. The subexpressions of SCEV with a type X can only have a type 20 /// that is smaller or equal than X. 21 enum TYPE { 22 // An integer value. 23 INT, 24 25 // An expression that is constant during the execution of the Scop, 26 // but that may depend on parameters unknown at compile time. 27 PARAM, 28 29 // An expression that may change during the execution of the SCoP. 30 IV, 31 32 // An invalid expression. 33 INVALID 34 }; 35 } // namespace SCEVType 36 37 /// The result the validator returns for a SCEV expression. 38 class ValidatorResult { 39 /// The type of the expression 40 SCEVType::TYPE Type; 41 42 /// The set of Parameters in the expression. 43 ParameterSetTy Parameters; 44 45 public: 46 /// The copy constructor 47 ValidatorResult(const ValidatorResult &Source) { 48 Type = Source.Type; 49 Parameters = Source.Parameters; 50 } 51 52 /// Construct a result with a certain type and no parameters. 53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 55 } 56 57 /// Construct a result with a certain type and a single parameter. 58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 59 Parameters.insert(Expr); 60 } 61 62 /// Get the type of the ValidatorResult. 63 SCEVType::TYPE getType() { return Type; } 64 65 /// Is the analyzed SCEV constant during the execution of the SCoP. 66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 67 68 /// Is the analyzed SCEV valid. 69 bool isValid() { return Type != SCEVType::INVALID; } 70 71 /// Is the analyzed SCEV of Type IV. 72 bool isIV() { return Type == SCEVType::IV; } 73 74 /// Is the analyzed SCEV of Type INT. 75 bool isINT() { return Type == SCEVType::INT; } 76 77 /// Is the analyzed SCEV of Type PARAM. 78 bool isPARAM() { return Type == SCEVType::PARAM; } 79 80 /// Get the parameters of this validator result. 81 const ParameterSetTy &getParameters() { return Parameters; } 82 83 /// Add the parameters of Source to this result. 84 void addParamsFrom(const ValidatorResult &Source) { 85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end()); 86 } 87 88 /// Merge a result. 89 /// 90 /// This means to merge the parameters and to set the Type to the most 91 /// specific Type that matches both. 92 void merge(const ValidatorResult &ToMerge) { 93 Type = std::max(Type, ToMerge.Type); 94 addParamsFrom(ToMerge); 95 } 96 97 void print(raw_ostream &OS) { 98 switch (Type) { 99 case SCEVType::INT: 100 OS << "SCEVType::INT"; 101 break; 102 case SCEVType::PARAM: 103 OS << "SCEVType::PARAM"; 104 break; 105 case SCEVType::IV: 106 OS << "SCEVType::IV"; 107 break; 108 case SCEVType::INVALID: 109 OS << "SCEVType::INVALID"; 110 break; 111 } 112 } 113 }; 114 115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 116 VR.print(OS); 117 return OS; 118 } 119 120 /// Check if a SCEV is valid in a SCoP. 121 struct SCEVValidator 122 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 123 private: 124 const Region *R; 125 Loop *Scope; 126 ScalarEvolution &SE; 127 InvariantLoadsSetTy *ILS; 128 129 public: 130 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE, 131 InvariantLoadsSetTy *ILS) 132 : R(R), Scope(Scope), SE(SE), ILS(ILS) {} 133 134 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 135 return ValidatorResult(SCEVType::INT); 136 } 137 138 class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr, 139 const SCEV *Operand) { 140 ValidatorResult Op = visit(Operand); 141 auto Type = Op.getType(); 142 143 // If unsigned operations are allowed return the operand, otherwise 144 // check if we can model the expression without unsigned assumptions. 145 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID) 146 return Op; 147 148 if (Type == SCEVType::IV) 149 return ValidatorResult(SCEVType::INVALID); 150 return ValidatorResult(SCEVType::PARAM, Expr); 151 } 152 153 class ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { 154 return visit(Expr->getOperand()); 155 } 156 157 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 158 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 159 } 160 161 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 162 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand()); 163 } 164 165 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 166 return visit(Expr->getOperand()); 167 } 168 169 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 170 ValidatorResult Return(SCEVType::INT); 171 172 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 173 ValidatorResult Op = visit(Expr->getOperand(i)); 174 Return.merge(Op); 175 176 // Early exit. 177 if (!Return.isValid()) 178 break; 179 } 180 181 return Return; 182 } 183 184 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 185 ValidatorResult Return(SCEVType::INT); 186 187 bool HasMultipleParams = false; 188 189 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 190 ValidatorResult Op = visit(Expr->getOperand(i)); 191 192 if (Op.isINT()) 193 continue; 194 195 if (Op.isPARAM() && Return.isPARAM()) { 196 HasMultipleParams = true; 197 continue; 198 } 199 200 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 201 LLVM_DEBUG( 202 dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 203 << "\tExpr: " << *Expr << "\n" 204 << "\tPrevious expression type: " << Return << "\n" 205 << "\tNext operand (" << Op << "): " << *Expr->getOperand(i) 206 << "\n"); 207 208 return ValidatorResult(SCEVType::INVALID); 209 } 210 211 Return.merge(Op); 212 } 213 214 if (HasMultipleParams && Return.isValid()) 215 return ValidatorResult(SCEVType::PARAM, Expr); 216 217 return Return; 218 } 219 220 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 221 if (!Expr->isAffine()) { 222 LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine"); 223 return ValidatorResult(SCEVType::INVALID); 224 } 225 226 ValidatorResult Start = visit(Expr->getStart()); 227 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 228 229 if (!Start.isValid()) 230 return Start; 231 232 if (!Recurrence.isValid()) 233 return Recurrence; 234 235 auto *L = Expr->getLoop(); 236 if (R->contains(L) && (!Scope || !L->contains(Scope))) { 237 LLVM_DEBUG( 238 dbgs() << "INVALID: Loop of AddRec expression boxed in an a " 239 "non-affine subregion or has a non-synthesizable exit " 240 "value."); 241 return ValidatorResult(SCEVType::INVALID); 242 } 243 244 if (R->contains(L)) { 245 if (Recurrence.isINT()) { 246 ValidatorResult Result(SCEVType::IV); 247 Result.addParamsFrom(Start); 248 return Result; 249 } 250 251 LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 252 "recurrence part"); 253 return ValidatorResult(SCEVType::INVALID); 254 } 255 256 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant"); 257 258 // Directly generate ValidatorResult for Expr if 'start' is zero. 259 if (Expr->getStart()->isZero()) 260 return ValidatorResult(SCEVType::PARAM, Expr); 261 262 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 263 // if 'start' is not zero. 264 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 265 SE.getConstant(Expr->getStart()->getType(), 0), 266 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 267 268 ValidatorResult ZeroStartResult = 269 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 270 ZeroStartResult.addParamsFrom(Start); 271 272 return ZeroStartResult; 273 } 274 275 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 276 ValidatorResult Return(SCEVType::INT); 277 278 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 279 ValidatorResult Op = visit(Expr->getOperand(i)); 280 281 if (!Op.isValid()) 282 return Op; 283 284 Return.merge(Op); 285 } 286 287 return Return; 288 } 289 290 class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) { 291 ValidatorResult Return(SCEVType::INT); 292 293 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 294 ValidatorResult Op = visit(Expr->getOperand(i)); 295 296 if (!Op.isValid()) 297 return Op; 298 299 Return.merge(Op); 300 } 301 302 return Return; 303 } 304 305 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 306 // We do not support unsigned max operations. If 'Expr' is constant during 307 // Scop execution we treat this as a parameter, otherwise we bail out. 308 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 309 ValidatorResult Op = visit(Expr->getOperand(i)); 310 311 if (!Op.isConstant()) { 312 LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 313 return ValidatorResult(SCEVType::INVALID); 314 } 315 } 316 317 return ValidatorResult(SCEVType::PARAM, Expr); 318 } 319 320 class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) { 321 // We do not support unsigned min operations. If 'Expr' is constant during 322 // Scop execution we treat this as a parameter, otherwise we bail out. 323 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 324 ValidatorResult Op = visit(Expr->getOperand(i)); 325 326 if (!Op.isConstant()) { 327 LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand"); 328 return ValidatorResult(SCEVType::INVALID); 329 } 330 } 331 332 return ValidatorResult(SCEVType::PARAM, Expr); 333 } 334 335 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 336 if (R->contains(I)) { 337 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 338 "within the region\n"); 339 return ValidatorResult(SCEVType::INVALID); 340 } 341 342 return ValidatorResult(SCEVType::PARAM, S); 343 } 344 345 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) { 346 if (R->contains(I) && ILS) { 347 ILS->insert(cast<LoadInst>(I)); 348 return ValidatorResult(SCEVType::PARAM, S); 349 } 350 351 return visitGenericInst(I, S); 352 } 353 354 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor, 355 const SCEV *DivExpr, 356 Instruction *SDiv = nullptr) { 357 358 // First check if we might be able to model the division, thus if the 359 // divisor is constant. If so, check the dividend, otherwise check if 360 // the whole division can be seen as a parameter. 361 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero()) 362 return visit(Dividend); 363 364 // For signed divisions use the SDiv instruction to check for a parameter 365 // division, for unsigned divisions check the operands. 366 if (SDiv) 367 return visitGenericInst(SDiv, DivExpr); 368 369 ValidatorResult LHS = visit(Dividend); 370 ValidatorResult RHS = visit(Divisor); 371 if (LHS.isConstant() && RHS.isConstant()) 372 return ValidatorResult(SCEVType::PARAM, DivExpr); 373 374 LLVM_DEBUG( 375 dbgs() << "INVALID: unsigned division of non-constant expressions"); 376 return ValidatorResult(SCEVType::INVALID); 377 } 378 379 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 380 if (!PollyAllowUnsignedOperations) 381 return ValidatorResult(SCEVType::INVALID); 382 383 auto *Dividend = Expr->getLHS(); 384 auto *Divisor = Expr->getRHS(); 385 return visitDivision(Dividend, Divisor, Expr); 386 } 387 388 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) { 389 assert(SDiv->getOpcode() == Instruction::SDiv && 390 "Assumed SDiv instruction!"); 391 392 auto *Dividend = SE.getSCEV(SDiv->getOperand(0)); 393 auto *Divisor = SE.getSCEV(SDiv->getOperand(1)); 394 return visitDivision(Dividend, Divisor, Expr, SDiv); 395 } 396 397 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) { 398 assert(SRem->getOpcode() == Instruction::SRem && 399 "Assumed SRem instruction!"); 400 401 auto *Divisor = SRem->getOperand(1); 402 auto *CI = dyn_cast<ConstantInt>(Divisor); 403 if (!CI || CI->isZeroValue()) 404 return visitGenericInst(SRem, S); 405 406 auto *Dividend = SRem->getOperand(0); 407 auto *DividendSCEV = SE.getSCEV(Dividend); 408 return visit(DividendSCEV); 409 } 410 411 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 412 Value *V = Expr->getValue(); 413 414 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) { 415 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer"); 416 return ValidatorResult(SCEVType::INVALID); 417 } 418 419 if (isa<UndefValue>(V)) { 420 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 421 return ValidatorResult(SCEVType::INVALID); 422 } 423 424 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 425 switch (I->getOpcode()) { 426 case Instruction::IntToPtr: 427 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope)); 428 case Instruction::Load: 429 return visitLoadInstruction(I, Expr); 430 case Instruction::SDiv: 431 return visitSDivInstruction(I, Expr); 432 case Instruction::SRem: 433 return visitSRemInstruction(I, Expr); 434 default: 435 return visitGenericInst(I, Expr); 436 } 437 } 438 439 if (Expr->getType()->isPointerTy()) { 440 if (isa<ConstantPointerNull>(V)) 441 return ValidatorResult(SCEVType::INT); // "int" 442 } 443 444 return ValidatorResult(SCEVType::PARAM, Expr); 445 } 446 }; 447 448 /// Check whether a SCEV refers to an SSA name defined inside a region. 449 class SCEVInRegionDependences { 450 const Region *R; 451 Loop *Scope; 452 const InvariantLoadsSetTy &ILS; 453 bool AllowLoops; 454 bool HasInRegionDeps = false; 455 456 public: 457 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops, 458 const InvariantLoadsSetTy &ILS) 459 : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {} 460 461 bool follow(const SCEV *S) { 462 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) { 463 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 464 465 if (Inst) { 466 // When we invariant load hoist a load, we first make sure that there 467 // can be no dependences created by it in the Scop region. So, we should 468 // not consider scalar dependences to `LoadInst`s that are invariant 469 // load hoisted. 470 // 471 // If this check is not present, then we create data dependences which 472 // are strictly not necessary by tracking the invariant load as a 473 // scalar. 474 LoadInst *LI = dyn_cast<LoadInst>(Inst); 475 if (LI && ILS.count(LI) > 0) 476 return false; 477 } 478 479 // Return true when Inst is defined inside the region R. 480 if (!Inst || !R->contains(Inst)) 481 return true; 482 483 HasInRegionDeps = true; 484 return false; 485 } 486 487 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 488 if (AllowLoops) 489 return true; 490 491 auto *L = AddRec->getLoop(); 492 if (R->contains(L) && !L->contains(Scope)) { 493 HasInRegionDeps = true; 494 return false; 495 } 496 } 497 498 return true; 499 } 500 bool isDone() { return false; } 501 bool hasDependences() { return HasInRegionDeps; } 502 }; 503 504 namespace polly { 505 /// Find all loops referenced in SCEVAddRecExprs. 506 class SCEVFindLoops { 507 SetVector<const Loop *> &Loops; 508 509 public: 510 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 511 512 bool follow(const SCEV *S) { 513 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 514 Loops.insert(AddRec->getLoop()); 515 return true; 516 } 517 bool isDone() { return false; } 518 }; 519 520 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 521 SCEVFindLoops FindLoops(Loops); 522 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 523 ST.visitAll(Expr); 524 } 525 526 /// Find all values referenced in SCEVUnknowns. 527 class SCEVFindValues { 528 ScalarEvolution &SE; 529 SetVector<Value *> &Values; 530 531 public: 532 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values) 533 : SE(SE), Values(Values) {} 534 535 bool follow(const SCEV *S) { 536 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 537 if (!Unknown) 538 return true; 539 540 Values.insert(Unknown->getValue()); 541 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 542 if (!Inst || (Inst->getOpcode() != Instruction::SRem && 543 Inst->getOpcode() != Instruction::SDiv)) 544 return false; 545 546 auto *Dividend = SE.getSCEV(Inst->getOperand(1)); 547 if (!isa<SCEVConstant>(Dividend)) 548 return false; 549 550 auto *Divisor = SE.getSCEV(Inst->getOperand(0)); 551 SCEVFindValues FindValues(SE, Values); 552 SCEVTraversal<SCEVFindValues> ST(FindValues); 553 ST.visitAll(Dividend); 554 ST.visitAll(Divisor); 555 556 return false; 557 } 558 bool isDone() { return false; } 559 }; 560 561 void findValues(const SCEV *Expr, ScalarEvolution &SE, 562 SetVector<Value *> &Values) { 563 SCEVFindValues FindValues(SE, Values); 564 SCEVTraversal<SCEVFindValues> ST(FindValues); 565 ST.visitAll(Expr); 566 } 567 568 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R, 569 llvm::Loop *Scope, bool AllowLoops, 570 const InvariantLoadsSetTy &ILS) { 571 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS); 572 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps); 573 ST.visitAll(Expr); 574 return InRegionDeps.hasDependences(); 575 } 576 577 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr, 578 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) { 579 if (isa<SCEVCouldNotCompute>(Expr)) 580 return false; 581 582 SCEVValidator Validator(R, Scope, SE, ILS); 583 LLVM_DEBUG({ 584 dbgs() << "\n"; 585 dbgs() << "Expr: " << *Expr << "\n"; 586 dbgs() << "Region: " << R->getNameStr() << "\n"; 587 dbgs() << " -> "; 588 }); 589 590 ValidatorResult Result = Validator.visit(Expr); 591 592 LLVM_DEBUG({ 593 if (Result.isValid()) 594 dbgs() << "VALID\n"; 595 dbgs() << "\n"; 596 }); 597 598 return Result.isValid(); 599 } 600 601 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope, 602 ScalarEvolution &SE, ParameterSetTy &Params) { 603 auto *E = SE.getSCEV(V); 604 if (isa<SCEVCouldNotCompute>(E)) 605 return false; 606 607 SCEVValidator Validator(R, Scope, SE, nullptr); 608 ValidatorResult Result = Validator.visit(E); 609 if (!Result.isValid()) 610 return false; 611 612 auto ResultParams = Result.getParameters(); 613 Params.insert(ResultParams.begin(), ResultParams.end()); 614 615 return true; 616 } 617 618 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope, 619 ScalarEvolution &SE, ParameterSetTy &Params, 620 bool OrExpr) { 621 if (auto *ICmp = dyn_cast<ICmpInst>(V)) { 622 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params, 623 true) && 624 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true); 625 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) { 626 auto Opcode = BinOp->getOpcode(); 627 if (Opcode == Instruction::And || Opcode == Instruction::Or) 628 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params, 629 false) && 630 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params, 631 false); 632 /* Fall through */ 633 } 634 635 if (!OrExpr) 636 return false; 637 638 return isAffineExpr(V, R, Scope, SE, Params); 639 } 640 641 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope, 642 const SCEV *Expr, ScalarEvolution &SE) { 643 if (isa<SCEVCouldNotCompute>(Expr)) 644 return ParameterSetTy(); 645 646 InvariantLoadsSetTy ILS; 647 SCEVValidator Validator(R, Scope, SE, &ILS); 648 ValidatorResult Result = Validator.visit(Expr); 649 assert(Result.isValid() && "Requested parameters for an invalid SCEV!"); 650 651 return Result.getParameters(); 652 } 653 654 std::pair<const SCEVConstant *, const SCEV *> 655 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 656 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1)); 657 658 if (auto *Constant = dyn_cast<SCEVConstant>(S)) 659 return std::make_pair(Constant, SE.getConstant(S->getType(), 1)); 660 661 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); 662 if (AddRec) { 663 auto *StartExpr = AddRec->getStart(); 664 if (StartExpr->isZero()) { 665 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE); 666 auto *LeftOverAddRec = 667 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(), 668 AddRec->getNoWrapFlags()); 669 return std::make_pair(StepPair.first, LeftOverAddRec); 670 } 671 return std::make_pair(ConstPart, S); 672 } 673 674 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) { 675 SmallVector<const SCEV *, 4> LeftOvers; 676 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE); 677 auto *Factor = Op0Pair.first; 678 if (SE.isKnownNegative(Factor)) { 679 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor)); 680 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second)); 681 } else { 682 LeftOvers.push_back(Op0Pair.second); 683 } 684 685 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) { 686 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE); 687 // TODO: Use something smarter than equality here, e.g., gcd. 688 if (Factor == OpUPair.first) 689 LeftOvers.push_back(OpUPair.second); 690 else if (Factor == SE.getNegativeSCEV(OpUPair.first)) 691 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second)); 692 else 693 return std::make_pair(ConstPart, S); 694 } 695 696 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags()); 697 return std::make_pair(Factor, NewAdd); 698 } 699 700 auto *Mul = dyn_cast<SCEVMulExpr>(S); 701 if (!Mul) 702 return std::make_pair(ConstPart, S); 703 704 SmallVector<const SCEV *, 4> LeftOvers; 705 for (auto *Op : Mul->operands()) 706 if (isa<SCEVConstant>(Op)) 707 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op)); 708 else 709 LeftOvers.push_back(Op); 710 711 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers)); 712 } 713 714 const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R, 715 ScalarEvolution &SE, ScopDetection *SD) { 716 if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) { 717 Value *V = Unknown->getValue(); 718 auto *PHI = dyn_cast<PHINode>(V); 719 if (!PHI) 720 return Expr; 721 722 Value *Final = nullptr; 723 724 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) { 725 BasicBlock *Incoming = PHI->getIncomingBlock(i); 726 if (SD->isErrorBlock(*Incoming, R) && R.contains(Incoming)) 727 continue; 728 if (Final) 729 return Expr; 730 Final = PHI->getIncomingValue(i); 731 } 732 733 if (Final) 734 return SE.getSCEV(Final); 735 } 736 return Expr; 737 } 738 739 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, ScopDetection *SD) { 740 Value *V = nullptr; 741 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) { 742 BasicBlock *BB = PHI->getIncomingBlock(i); 743 if (!SD->isErrorBlock(*BB, *R)) { 744 if (V) 745 return nullptr; 746 V = PHI->getIncomingValue(i); 747 } 748 } 749 750 return V; 751 } 752 } // namespace polly 753