1 2 #include "polly/Support/SCEVValidator.h" 3 #include "polly/ScopInfo.h" 4 #include "llvm/Analysis/RegionInfo.h" 5 #include "llvm/Analysis/ScalarEvolution.h" 6 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 7 #include "llvm/Support/Debug.h" 8 9 using namespace llvm; 10 using namespace polly; 11 12 #define DEBUG_TYPE "polly-scev-validator" 13 14 namespace SCEVType { 15 /// @brief The type of a SCEV 16 /// 17 /// To check for the validity of a SCEV we assign to each SCEV a type. The 18 /// possible types are INT, PARAM, IV and INVALID. The order of the types is 19 /// important. The subexpressions of SCEV with a type X can only have a type 20 /// that is smaller or equal than X. 21 enum TYPE { 22 // An integer value. 23 INT, 24 25 // An expression that is constant during the execution of the Scop, 26 // but that may depend on parameters unknown at compile time. 27 PARAM, 28 29 // An expression that may change during the execution of the SCoP. 30 IV, 31 32 // An invalid expression. 33 INVALID 34 }; 35 } 36 37 /// @brief The result the validator returns for a SCEV expression. 38 class ValidatorResult { 39 /// @brief The type of the expression 40 SCEVType::TYPE Type; 41 42 /// @brief The set of Parameters in the expression. 43 ParameterSetTy Parameters; 44 45 public: 46 /// @brief The copy constructor 47 ValidatorResult(const ValidatorResult &Source) { 48 Type = Source.Type; 49 Parameters = Source.Parameters; 50 } 51 52 /// @brief Construct a result with a certain type and no parameters. 53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) { 54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter"); 55 } 56 57 /// @brief Construct a result with a certain type and a single parameter. 58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) { 59 Parameters.insert(Expr); 60 } 61 62 /// @brief Get the type of the ValidatorResult. 63 SCEVType::TYPE getType() { return Type; } 64 65 /// @brief Is the analyzed SCEV constant during the execution of the SCoP. 66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; } 67 68 /// @brief Is the analyzed SCEV valid. 69 bool isValid() { return Type != SCEVType::INVALID; } 70 71 /// @brief Is the analyzed SCEV of Type IV. 72 bool isIV() { return Type == SCEVType::IV; } 73 74 /// @brief Is the analyzed SCEV of Type INT. 75 bool isINT() { return Type == SCEVType::INT; } 76 77 /// @brief Is the analyzed SCEV of Type PARAM. 78 bool isPARAM() { return Type == SCEVType::PARAM; } 79 80 /// @brief Get the parameters of this validator result. 81 const ParameterSetTy &getParameters() { return Parameters; } 82 83 /// @brief Add the parameters of Source to this result. 84 void addParamsFrom(const ValidatorResult &Source) { 85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end()); 86 } 87 88 /// @brief Merge a result. 89 /// 90 /// This means to merge the parameters and to set the Type to the most 91 /// specific Type that matches both. 92 void merge(const ValidatorResult &ToMerge) { 93 Type = std::max(Type, ToMerge.Type); 94 addParamsFrom(ToMerge); 95 } 96 97 void print(raw_ostream &OS) { 98 switch (Type) { 99 case SCEVType::INT: 100 OS << "SCEVType::INT"; 101 break; 102 case SCEVType::PARAM: 103 OS << "SCEVType::PARAM"; 104 break; 105 case SCEVType::IV: 106 OS << "SCEVType::IV"; 107 break; 108 case SCEVType::INVALID: 109 OS << "SCEVType::INVALID"; 110 break; 111 } 112 } 113 }; 114 115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) { 116 VR.print(OS); 117 return OS; 118 } 119 120 /// Check if a SCEV is valid in a SCoP. 121 struct SCEVValidator 122 : public SCEVVisitor<SCEVValidator, class ValidatorResult> { 123 private: 124 const Region *R; 125 Loop *Scope; 126 ScalarEvolution &SE; 127 InvariantLoadsSetTy *ILS; 128 129 public: 130 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE, 131 InvariantLoadsSetTy *ILS) 132 : R(R), Scope(Scope), SE(SE), ILS(ILS) {} 133 134 class ValidatorResult visitConstant(const SCEVConstant *Constant) { 135 return ValidatorResult(SCEVType::INT); 136 } 137 138 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) { 139 return visit(Expr->getOperand()); 140 } 141 142 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 143 return visit(Expr->getOperand()); 144 } 145 146 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 147 return visit(Expr->getOperand()); 148 } 149 150 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) { 151 ValidatorResult Return(SCEVType::INT); 152 153 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 154 ValidatorResult Op = visit(Expr->getOperand(i)); 155 Return.merge(Op); 156 157 // Early exit. 158 if (!Return.isValid()) 159 break; 160 } 161 162 return Return; 163 } 164 165 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) { 166 ValidatorResult Return(SCEVType::INT); 167 168 bool HasMultipleParams = false; 169 170 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 171 ValidatorResult Op = visit(Expr->getOperand(i)); 172 173 if (Op.isINT()) 174 continue; 175 176 if (Op.isPARAM() && Return.isPARAM()) { 177 HasMultipleParams = true; 178 continue; 179 } 180 181 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) { 182 DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n" 183 << "\tExpr: " << *Expr << "\n" 184 << "\tPrevious expression type: " << Return << "\n" 185 << "\tNext operand (" << Op 186 << "): " << *Expr->getOperand(i) << "\n"); 187 188 return ValidatorResult(SCEVType::INVALID); 189 } 190 191 Return.merge(Op); 192 } 193 194 if (HasMultipleParams && Return.isValid()) 195 return ValidatorResult(SCEVType::PARAM, Expr); 196 197 return Return; 198 } 199 200 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) { 201 if (!Expr->isAffine()) { 202 DEBUG(dbgs() << "INVALID: AddRec is not affine"); 203 return ValidatorResult(SCEVType::INVALID); 204 } 205 206 ValidatorResult Start = visit(Expr->getStart()); 207 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE)); 208 209 if (!Start.isValid()) 210 return Start; 211 212 if (!Recurrence.isValid()) 213 return Recurrence; 214 215 auto *L = Expr->getLoop(); 216 if (R->contains(L) && (!Scope || !L->contains(Scope))) { 217 DEBUG(dbgs() << "INVALID: AddRec out of a loop whose exit value is not " 218 "synthesizable"); 219 return ValidatorResult(SCEVType::INVALID); 220 } 221 222 if (R->contains(L)) { 223 if (Recurrence.isINT()) { 224 ValidatorResult Result(SCEVType::IV); 225 Result.addParamsFrom(Start); 226 return Result; 227 } 228 229 DEBUG(dbgs() << "INVALID: AddRec within scop has non-int" 230 "recurrence part"); 231 return ValidatorResult(SCEVType::INVALID); 232 } 233 234 assert(Start.isConstant() && Recurrence.isConstant() && 235 "Expected 'Start' and 'Recurrence' to be constant"); 236 237 // Directly generate ValidatorResult for Expr if 'start' is zero. 238 if (Expr->getStart()->isZero()) 239 return ValidatorResult(SCEVType::PARAM, Expr); 240 241 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}' 242 // if 'start' is not zero. 243 const SCEV *ZeroStartExpr = SE.getAddRecExpr( 244 SE.getConstant(Expr->getStart()->getType(), 0), 245 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags()); 246 247 ValidatorResult ZeroStartResult = 248 ValidatorResult(SCEVType::PARAM, ZeroStartExpr); 249 ZeroStartResult.addParamsFrom(Start); 250 251 return ZeroStartResult; 252 } 253 254 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) { 255 ValidatorResult Return(SCEVType::INT); 256 257 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 258 ValidatorResult Op = visit(Expr->getOperand(i)); 259 260 if (!Op.isValid()) 261 return Op; 262 263 Return.merge(Op); 264 } 265 266 return Return; 267 } 268 269 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) { 270 // We do not support unsigned max operations. If 'Expr' is constant during 271 // Scop execution we treat this as a parameter, otherwise we bail out. 272 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) { 273 ValidatorResult Op = visit(Expr->getOperand(i)); 274 275 if (!Op.isConstant()) { 276 DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand"); 277 return ValidatorResult(SCEVType::INVALID); 278 } 279 } 280 281 return ValidatorResult(SCEVType::PARAM, Expr); 282 } 283 284 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) { 285 if (R->contains(I)) { 286 DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction " 287 "within the region\n"); 288 return ValidatorResult(SCEVType::INVALID); 289 } 290 291 return ValidatorResult(SCEVType::PARAM, S); 292 } 293 294 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) { 295 if (R->contains(I) && ILS) { 296 ILS->insert(cast<LoadInst>(I)); 297 return ValidatorResult(SCEVType::PARAM, S); 298 } 299 300 return visitGenericInst(I, S); 301 } 302 303 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor, 304 const SCEV *DivExpr, 305 Instruction *SDiv = nullptr) { 306 307 // First check if we might be able to model the division, thus if the 308 // divisor is constant. If so, check the dividend, otherwise check if 309 // the whole division can be seen as a parameter. 310 if (isa<SCEVConstant>(Divisor)) 311 return visit(Dividend); 312 313 // For signed divisions use the SDiv instruction to check for a parameter 314 // division, for unsigned divisions check the operands. 315 if (SDiv) 316 return visitGenericInst(SDiv, DivExpr); 317 318 ValidatorResult LHS = visit(Dividend); 319 ValidatorResult RHS = visit(Divisor); 320 if (LHS.isConstant() && RHS.isConstant()) 321 return ValidatorResult(SCEVType::PARAM, DivExpr); 322 323 DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions"); 324 return ValidatorResult(SCEVType::INVALID); 325 } 326 327 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) { 328 auto *Dividend = Expr->getLHS(); 329 auto *Divisor = Expr->getRHS(); 330 return visitDivision(Dividend, Divisor, Expr); 331 } 332 333 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) { 334 assert(SDiv->getOpcode() == Instruction::SDiv && 335 "Assumed SDiv instruction!"); 336 337 auto *Dividend = SE.getSCEV(SDiv->getOperand(0)); 338 auto *Divisor = SE.getSCEV(SDiv->getOperand(1)); 339 return visitDivision(Dividend, Divisor, Expr, SDiv); 340 } 341 342 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) { 343 assert(SRem->getOpcode() == Instruction::SRem && 344 "Assumed SRem instruction!"); 345 346 auto *Divisor = SRem->getOperand(1); 347 auto *CI = dyn_cast<ConstantInt>(Divisor); 348 if (!CI) 349 return visitGenericInst(SRem, S); 350 351 auto *Dividend = SRem->getOperand(0); 352 auto *DividendSCEV = SE.getSCEV(Dividend); 353 return visit(DividendSCEV); 354 } 355 356 ValidatorResult visitUnknown(const SCEVUnknown *Expr) { 357 Value *V = Expr->getValue(); 358 359 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) { 360 DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer"); 361 return ValidatorResult(SCEVType::INVALID); 362 } 363 364 if (isa<UndefValue>(V)) { 365 DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value"); 366 return ValidatorResult(SCEVType::INVALID); 367 } 368 369 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) { 370 switch (I->getOpcode()) { 371 case Instruction::Load: 372 return visitLoadInstruction(I, Expr); 373 case Instruction::SDiv: 374 return visitSDivInstruction(I, Expr); 375 case Instruction::SRem: 376 return visitSRemInstruction(I, Expr); 377 default: 378 return visitGenericInst(I, Expr); 379 } 380 } 381 382 return ValidatorResult(SCEVType::PARAM, Expr); 383 } 384 }; 385 386 /// @brief Check whether a SCEV refers to an SSA name defined inside a region. 387 class SCEVInRegionDependences { 388 const Region *R; 389 Loop *Scope; 390 bool AllowLoops; 391 bool HasInRegionDeps = false; 392 393 public: 394 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops) 395 : R(R), Scope(Scope), AllowLoops(AllowLoops) {} 396 397 bool follow(const SCEV *S) { 398 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) { 399 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 400 401 // Return true when Inst is defined inside the region R. 402 if (Inst && R->contains(Inst)) { 403 HasInRegionDeps = true; 404 return false; 405 } 406 } else if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) { 407 if (!AllowLoops) { 408 if (!Scope) { 409 HasInRegionDeps = true; 410 return false; 411 } 412 auto *L = AddRec->getLoop(); 413 if (R->contains(L) && !L->contains(Scope)) { 414 HasInRegionDeps = true; 415 return false; 416 } 417 } 418 } 419 return true; 420 } 421 bool isDone() { return false; } 422 bool hasDependences() { return HasInRegionDeps; } 423 }; 424 425 namespace polly { 426 /// Find all loops referenced in SCEVAddRecExprs. 427 class SCEVFindLoops { 428 SetVector<const Loop *> &Loops; 429 430 public: 431 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {} 432 433 bool follow(const SCEV *S) { 434 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) 435 Loops.insert(AddRec->getLoop()); 436 return true; 437 } 438 bool isDone() { return false; } 439 }; 440 441 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) { 442 SCEVFindLoops FindLoops(Loops); 443 SCEVTraversal<SCEVFindLoops> ST(FindLoops); 444 ST.visitAll(Expr); 445 } 446 447 /// Find all values referenced in SCEVUnknowns. 448 class SCEVFindValues { 449 ScalarEvolution &SE; 450 SetVector<Value *> &Values; 451 452 public: 453 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values) 454 : SE(SE), Values(Values) {} 455 456 bool follow(const SCEV *S) { 457 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S); 458 if (!Unknown) 459 return true; 460 461 Values.insert(Unknown->getValue()); 462 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue()); 463 if (!Inst || (Inst->getOpcode() != Instruction::SRem && 464 Inst->getOpcode() != Instruction::SDiv)) 465 return false; 466 467 auto *Dividend = SE.getSCEV(Inst->getOperand(1)); 468 if (!isa<SCEVConstant>(Dividend)) 469 return false; 470 471 auto *Divisor = SE.getSCEV(Inst->getOperand(0)); 472 SCEVFindValues FindValues(SE, Values); 473 SCEVTraversal<SCEVFindValues> ST(FindValues); 474 ST.visitAll(Dividend); 475 ST.visitAll(Divisor); 476 477 return false; 478 } 479 bool isDone() { return false; } 480 }; 481 482 void findValues(const SCEV *Expr, ScalarEvolution &SE, 483 SetVector<Value *> &Values) { 484 SCEVFindValues FindValues(SE, Values); 485 SCEVTraversal<SCEVFindValues> ST(FindValues); 486 ST.visitAll(Expr); 487 } 488 489 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R, 490 llvm::Loop *Scope, bool AllowLoops) { 491 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops); 492 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps); 493 ST.visitAll(Expr); 494 return InRegionDeps.hasDependences(); 495 } 496 497 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr, 498 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) { 499 if (isa<SCEVCouldNotCompute>(Expr)) 500 return false; 501 502 SCEVValidator Validator(R, Scope, SE, ILS); 503 DEBUG({ 504 dbgs() << "\n"; 505 dbgs() << "Expr: " << *Expr << "\n"; 506 dbgs() << "Region: " << R->getNameStr() << "\n"; 507 dbgs() << " -> "; 508 }); 509 510 ValidatorResult Result = Validator.visit(Expr); 511 512 DEBUG({ 513 if (Result.isValid()) 514 dbgs() << "VALID\n"; 515 dbgs() << "\n"; 516 }); 517 518 return Result.isValid(); 519 } 520 521 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope, 522 ScalarEvolution &SE, ParameterSetTy &Params) { 523 auto *E = SE.getSCEV(V); 524 if (isa<SCEVCouldNotCompute>(E)) 525 return false; 526 527 SCEVValidator Validator(R, Scope, SE, nullptr); 528 ValidatorResult Result = Validator.visit(E); 529 if (!Result.isValid()) 530 return false; 531 532 auto ResultParams = Result.getParameters(); 533 Params.insert(ResultParams.begin(), ResultParams.end()); 534 535 return true; 536 } 537 538 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope, 539 ScalarEvolution &SE, ParameterSetTy &Params, 540 bool OrExpr) { 541 if (auto *ICmp = dyn_cast<ICmpInst>(V)) { 542 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params, 543 true) && 544 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true); 545 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) { 546 auto Opcode = BinOp->getOpcode(); 547 if (Opcode == Instruction::And || Opcode == Instruction::Or) 548 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params, 549 false) && 550 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params, 551 false); 552 /* Fall through */ 553 } 554 555 if (!OrExpr) 556 return false; 557 558 return isAffineExpr(V, R, Scope, SE, Params); 559 } 560 561 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope, 562 const SCEV *Expr, ScalarEvolution &SE) { 563 if (isa<SCEVCouldNotCompute>(Expr)) 564 return ParameterSetTy(); 565 566 InvariantLoadsSetTy ILS; 567 SCEVValidator Validator(R, Scope, SE, &ILS); 568 ValidatorResult Result = Validator.visit(Expr); 569 assert(Result.isValid() && "Requested parameters for an invalid SCEV!"); 570 571 return Result.getParameters(); 572 } 573 574 std::pair<const SCEVConstant *, const SCEV *> 575 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) { 576 577 auto *LeftOver = SE.getConstant(S->getType(), 1); 578 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1)); 579 580 if (auto *Constant = dyn_cast<SCEVConstant>(S)) 581 return std::make_pair(Constant, LeftOver); 582 583 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); 584 if (AddRec) { 585 auto *StartExpr = AddRec->getStart(); 586 if (StartExpr->isZero()) { 587 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE); 588 auto *LeftOverAddRec = 589 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(), 590 AddRec->getNoWrapFlags()); 591 return std::make_pair(StepPair.first, LeftOverAddRec); 592 } 593 return std::make_pair(ConstPart, S); 594 } 595 596 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) { 597 SmallVector<const SCEV *, 4> LeftOvers; 598 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE); 599 auto *Factor = Op0Pair.first; 600 if (SE.isKnownNegative(Factor)) { 601 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor)); 602 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second)); 603 } else { 604 LeftOvers.push_back(Op0Pair.second); 605 } 606 607 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) { 608 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE); 609 // TODO: Use something smarter than equality here, e.g., gcd. 610 if (Factor == OpUPair.first) 611 LeftOvers.push_back(OpUPair.second); 612 else if (Factor == SE.getNegativeSCEV(OpUPair.first)) 613 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second)); 614 else 615 return std::make_pair(ConstPart, S); 616 } 617 618 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags()); 619 return std::make_pair(Factor, NewAdd); 620 } 621 622 auto *Mul = dyn_cast<SCEVMulExpr>(S); 623 if (!Mul) 624 return std::make_pair(ConstPart, S); 625 626 for (auto *Op : Mul->operands()) 627 if (isa<SCEVConstant>(Op)) 628 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op)); 629 else 630 LeftOver = SE.getMulExpr(LeftOver, Op); 631 632 return std::make_pair(ConstPart, LeftOver); 633 } 634 } 635