1
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopDetection.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8
9 using namespace llvm;
10 using namespace polly;
11
12 #define DEBUG_TYPE "polly-scev-validator"
13
namespace SCEVType {
/// Classification the validator assigns to every SCEV.
///
/// Enumerators are listed from most to least restrictive, and that order is
/// significant: a SCEV of type X may only have sub-expressions whose type is
/// less than or equal to X. Results are combined by taking the larger of two
/// types.
enum TYPE {
  INT = 0,    ///< An integer value.
  PARAM = 1,  ///< Constant while the SCoP executes, but possibly dependent on
              ///< parameters that are unknown at compile time.
  IV = 2,     ///< May change during the execution of the SCoP.
  INVALID = 3 ///< Cannot be represented.
};
} // namespace SCEVType
36
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult final {
39 /// The type of the expression
40 SCEVType::TYPE Type;
41
42 /// The set of Parameters in the expression.
43 ParameterSetTy Parameters;
44
45 public:
46 /// The copy constructor
ValidatorResult(const ValidatorResult & Source)47 ValidatorResult(const ValidatorResult &Source) {
48 Type = Source.Type;
49 Parameters = Source.Parameters;
50 }
51
52 /// Construct a result with a certain type and no parameters.
ValidatorResult(SCEVType::TYPE Type)53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55 }
56
57 /// Construct a result with a certain type and a single parameter.
ValidatorResult(SCEVType::TYPE Type,const SCEV * Expr)58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59 Parameters.insert(Expr);
60 }
61
62 /// Get the type of the ValidatorResult.
getType()63 SCEVType::TYPE getType() { return Type; }
64
65 /// Is the analyzed SCEV constant during the execution of the SCoP.
isConstant()66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67
68 /// Is the analyzed SCEV valid.
isValid()69 bool isValid() { return Type != SCEVType::INVALID; }
70
71 /// Is the analyzed SCEV of Type IV.
isIV()72 bool isIV() { return Type == SCEVType::IV; }
73
74 /// Is the analyzed SCEV of Type INT.
isINT()75 bool isINT() { return Type == SCEVType::INT; }
76
77 /// Is the analyzed SCEV of Type PARAM.
isPARAM()78 bool isPARAM() { return Type == SCEVType::PARAM; }
79
80 /// Get the parameters of this validator result.
getParameters()81 const ParameterSetTy &getParameters() { return Parameters; }
82
83 /// Add the parameters of Source to this result.
addParamsFrom(const ValidatorResult & Source)84 void addParamsFrom(const ValidatorResult &Source) {
85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86 }
87
88 /// Merge a result.
89 ///
90 /// This means to merge the parameters and to set the Type to the most
91 /// specific Type that matches both.
merge(const ValidatorResult & ToMerge)92 void merge(const ValidatorResult &ToMerge) {
93 Type = std::max(Type, ToMerge.Type);
94 addParamsFrom(ToMerge);
95 }
96
print(raw_ostream & OS)97 void print(raw_ostream &OS) {
98 switch (Type) {
99 case SCEVType::INT:
100 OS << "SCEVType::INT";
101 break;
102 case SCEVType::PARAM:
103 OS << "SCEVType::PARAM";
104 break;
105 case SCEVType::IV:
106 OS << "SCEVType::IV";
107 break;
108 case SCEVType::INVALID:
109 OS << "SCEVType::INVALID";
110 break;
111 }
112 }
113 };
114
operator <<(raw_ostream & OS,ValidatorResult & VR)115 raw_ostream &operator<<(raw_ostream &OS, ValidatorResult &VR) {
116 VR.print(OS);
117 return OS;
118 }
119
120 /// Check if a SCEV is valid in a SCoP.
121 class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
122 private:
123 const Region *R;
124 Loop *Scope;
125 ScalarEvolution &SE;
126 InvariantLoadsSetTy *ILS;
127
128 public:
SCEVValidator(const Region * R,Loop * Scope,ScalarEvolution & SE,InvariantLoadsSetTy * ILS)129 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
130 InvariantLoadsSetTy *ILS)
131 : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
132
visitConstant(const SCEVConstant * Constant)133 ValidatorResult visitConstant(const SCEVConstant *Constant) {
134 return ValidatorResult(SCEVType::INT);
135 }
136
visitZeroExtendOrTruncateExpr(const SCEV * Expr,const SCEV * Operand)137 ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
138 const SCEV *Operand) {
139 ValidatorResult Op = visit(Operand);
140 auto Type = Op.getType();
141
142 // If unsigned operations are allowed return the operand, otherwise
143 // check if we can model the expression without unsigned assumptions.
144 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
145 return Op;
146
147 if (Type == SCEVType::IV)
148 return ValidatorResult(SCEVType::INVALID);
149 return ValidatorResult(SCEVType::PARAM, Expr);
150 }
151
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)152 ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
153 return visit(Expr->getOperand());
154 }
155
visitTruncateExpr(const SCEVTruncateExpr * Expr)156 ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
157 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
158 }
159
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)160 ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
161 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
162 }
163
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)164 ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
165 return visit(Expr->getOperand());
166 }
167
visitAddExpr(const SCEVAddExpr * Expr)168 ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
169 ValidatorResult Return(SCEVType::INT);
170
171 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
172 ValidatorResult Op = visit(Expr->getOperand(i));
173 Return.merge(Op);
174
175 // Early exit.
176 if (!Return.isValid())
177 break;
178 }
179
180 return Return;
181 }
182
visitMulExpr(const SCEVMulExpr * Expr)183 ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
184 ValidatorResult Return(SCEVType::INT);
185
186 bool HasMultipleParams = false;
187
188 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
189 ValidatorResult Op = visit(Expr->getOperand(i));
190
191 if (Op.isINT())
192 continue;
193
194 if (Op.isPARAM() && Return.isPARAM()) {
195 HasMultipleParams = true;
196 continue;
197 }
198
199 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
200 LLVM_DEBUG(
201 dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
202 << "\tExpr: " << *Expr << "\n"
203 << "\tPrevious expression type: " << Return << "\n"
204 << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
205 << "\n");
206
207 return ValidatorResult(SCEVType::INVALID);
208 }
209
210 Return.merge(Op);
211 }
212
213 if (HasMultipleParams && Return.isValid())
214 return ValidatorResult(SCEVType::PARAM, Expr);
215
216 return Return;
217 }
218
visitAddRecExpr(const SCEVAddRecExpr * Expr)219 ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
220 if (!Expr->isAffine()) {
221 LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
222 return ValidatorResult(SCEVType::INVALID);
223 }
224
225 ValidatorResult Start = visit(Expr->getStart());
226 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
227
228 if (!Start.isValid())
229 return Start;
230
231 if (!Recurrence.isValid())
232 return Recurrence;
233
234 auto *L = Expr->getLoop();
235 if (R->contains(L) && (!Scope || !L->contains(Scope))) {
236 LLVM_DEBUG(
237 dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
238 "non-affine subregion or has a non-synthesizable exit "
239 "value.");
240 return ValidatorResult(SCEVType::INVALID);
241 }
242
243 if (R->contains(L)) {
244 if (Recurrence.isINT()) {
245 ValidatorResult Result(SCEVType::IV);
246 Result.addParamsFrom(Start);
247 return Result;
248 }
249
250 LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
251 "recurrence part");
252 return ValidatorResult(SCEVType::INVALID);
253 }
254
255 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
256
257 // Directly generate ValidatorResult for Expr if 'start' is zero.
258 if (Expr->getStart()->isZero())
259 return ValidatorResult(SCEVType::PARAM, Expr);
260
261 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
262 // if 'start' is not zero.
263 const SCEV *ZeroStartExpr = SE.getAddRecExpr(
264 SE.getConstant(Expr->getStart()->getType(), 0),
265 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
266
267 ValidatorResult ZeroStartResult =
268 ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
269 ZeroStartResult.addParamsFrom(Start);
270
271 return ZeroStartResult;
272 }
273
visitSMaxExpr(const SCEVSMaxExpr * Expr)274 ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
275 ValidatorResult Return(SCEVType::INT);
276
277 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
278 ValidatorResult Op = visit(Expr->getOperand(i));
279
280 if (!Op.isValid())
281 return Op;
282
283 Return.merge(Op);
284 }
285
286 return Return;
287 }
288
visitSMinExpr(const SCEVSMinExpr * Expr)289 ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
290 ValidatorResult Return(SCEVType::INT);
291
292 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
293 ValidatorResult Op = visit(Expr->getOperand(i));
294
295 if (!Op.isValid())
296 return Op;
297
298 Return.merge(Op);
299 }
300
301 return Return;
302 }
303
visitUMaxExpr(const SCEVUMaxExpr * Expr)304 ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
305 // We do not support unsigned max operations. If 'Expr' is constant during
306 // Scop execution we treat this as a parameter, otherwise we bail out.
307 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
308 ValidatorResult Op = visit(Expr->getOperand(i));
309
310 if (!Op.isConstant()) {
311 LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
312 return ValidatorResult(SCEVType::INVALID);
313 }
314 }
315
316 return ValidatorResult(SCEVType::PARAM, Expr);
317 }
318
visitUMinExpr(const SCEVUMinExpr * Expr)319 ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
320 // We do not support unsigned min operations. If 'Expr' is constant during
321 // Scop execution we treat this as a parameter, otherwise we bail out.
322 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
323 ValidatorResult Op = visit(Expr->getOperand(i));
324
325 if (!Op.isConstant()) {
326 LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
327 return ValidatorResult(SCEVType::INVALID);
328 }
329 }
330
331 return ValidatorResult(SCEVType::PARAM, Expr);
332 }
333
visitSequentialUMinExpr(const SCEVSequentialUMinExpr * Expr)334 ValidatorResult visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
335 // We do not support unsigned min operations. If 'Expr' is constant during
336 // Scop execution we treat this as a parameter, otherwise we bail out.
337 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
338 ValidatorResult Op = visit(Expr->getOperand(i));
339
340 if (!Op.isConstant()) {
341 LLVM_DEBUG(
342 dbgs()
343 << "INVALID: SCEVSequentialUMinExpr has a non-constant operand");
344 return ValidatorResult(SCEVType::INVALID);
345 }
346 }
347
348 return ValidatorResult(SCEVType::PARAM, Expr);
349 }
350
visitGenericInst(Instruction * I,const SCEV * S)351 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
352 if (R->contains(I)) {
353 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
354 "within the region\n");
355 return ValidatorResult(SCEVType::INVALID);
356 }
357
358 return ValidatorResult(SCEVType::PARAM, S);
359 }
360
visitLoadInstruction(Instruction * I,const SCEV * S)361 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
362 if (R->contains(I) && ILS) {
363 ILS->insert(cast<LoadInst>(I));
364 return ValidatorResult(SCEVType::PARAM, S);
365 }
366
367 return visitGenericInst(I, S);
368 }
369
visitDivision(const SCEV * Dividend,const SCEV * Divisor,const SCEV * DivExpr,Instruction * SDiv=nullptr)370 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
371 const SCEV *DivExpr,
372 Instruction *SDiv = nullptr) {
373
374 // First check if we might be able to model the division, thus if the
375 // divisor is constant. If so, check the dividend, otherwise check if
376 // the whole division can be seen as a parameter.
377 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
378 return visit(Dividend);
379
380 // For signed divisions use the SDiv instruction to check for a parameter
381 // division, for unsigned divisions check the operands.
382 if (SDiv)
383 return visitGenericInst(SDiv, DivExpr);
384
385 ValidatorResult LHS = visit(Dividend);
386 ValidatorResult RHS = visit(Divisor);
387 if (LHS.isConstant() && RHS.isConstant())
388 return ValidatorResult(SCEVType::PARAM, DivExpr);
389
390 LLVM_DEBUG(
391 dbgs() << "INVALID: unsigned division of non-constant expressions");
392 return ValidatorResult(SCEVType::INVALID);
393 }
394
visitUDivExpr(const SCEVUDivExpr * Expr)395 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
396 if (!PollyAllowUnsignedOperations)
397 return ValidatorResult(SCEVType::INVALID);
398
399 auto *Dividend = Expr->getLHS();
400 auto *Divisor = Expr->getRHS();
401 return visitDivision(Dividend, Divisor, Expr);
402 }
403
visitSDivInstruction(Instruction * SDiv,const SCEV * Expr)404 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
405 assert(SDiv->getOpcode() == Instruction::SDiv &&
406 "Assumed SDiv instruction!");
407
408 auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
409 auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
410 return visitDivision(Dividend, Divisor, Expr, SDiv);
411 }
412
visitSRemInstruction(Instruction * SRem,const SCEV * S)413 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
414 assert(SRem->getOpcode() == Instruction::SRem &&
415 "Assumed SRem instruction!");
416
417 auto *Divisor = SRem->getOperand(1);
418 auto *CI = dyn_cast<ConstantInt>(Divisor);
419 if (!CI || CI->isZeroValue())
420 return visitGenericInst(SRem, S);
421
422 auto *Dividend = SRem->getOperand(0);
423 auto *DividendSCEV = SE.getSCEV(Dividend);
424 return visit(DividendSCEV);
425 }
426
visitUnknown(const SCEVUnknown * Expr)427 ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
428 Value *V = Expr->getValue();
429
430 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
431 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
432 return ValidatorResult(SCEVType::INVALID);
433 }
434
435 if (isa<UndefValue>(V)) {
436 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
437 return ValidatorResult(SCEVType::INVALID);
438 }
439
440 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
441 switch (I->getOpcode()) {
442 case Instruction::IntToPtr:
443 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
444 case Instruction::Load:
445 return visitLoadInstruction(I, Expr);
446 case Instruction::SDiv:
447 return visitSDivInstruction(I, Expr);
448 case Instruction::SRem:
449 return visitSRemInstruction(I, Expr);
450 default:
451 return visitGenericInst(I, Expr);
452 }
453 }
454
455 if (Expr->getType()->isPointerTy()) {
456 if (isa<ConstantPointerNull>(V))
457 return ValidatorResult(SCEVType::INT); // "int"
458 }
459
460 return ValidatorResult(SCEVType::PARAM, Expr);
461 }
462 };
463
464 /// Check whether a SCEV refers to an SSA name defined inside a region.
465 class SCEVInRegionDependences final {
466 const Region *R;
467 Loop *Scope;
468 const InvariantLoadsSetTy &ILS;
469 bool AllowLoops;
470 bool HasInRegionDeps = false;
471
472 public:
SCEVInRegionDependences(const Region * R,Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)473 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
474 const InvariantLoadsSetTy &ILS)
475 : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
476
follow(const SCEV * S)477 bool follow(const SCEV *S) {
478 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
479 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
480
481 if (Inst) {
482 // When we invariant load hoist a load, we first make sure that there
483 // can be no dependences created by it in the Scop region. So, we should
484 // not consider scalar dependences to `LoadInst`s that are invariant
485 // load hoisted.
486 //
487 // If this check is not present, then we create data dependences which
488 // are strictly not necessary by tracking the invariant load as a
489 // scalar.
490 LoadInst *LI = dyn_cast<LoadInst>(Inst);
491 if (LI && ILS.contains(LI))
492 return false;
493 }
494
495 // Return true when Inst is defined inside the region R.
496 if (!Inst || !R->contains(Inst))
497 return true;
498
499 HasInRegionDeps = true;
500 return false;
501 }
502
503 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
504 if (AllowLoops)
505 return true;
506
507 auto *L = AddRec->getLoop();
508 if (R->contains(L) && !L->contains(Scope)) {
509 HasInRegionDeps = true;
510 return false;
511 }
512 }
513
514 return true;
515 }
isDone()516 bool isDone() { return false; }
hasDependences()517 bool hasDependences() { return HasInRegionDeps; }
518 };
519
520 /// Find all loops referenced in SCEVAddRecExprs.
521 class SCEVFindLoops final {
522 SetVector<const Loop *> &Loops;
523
524 public:
SCEVFindLoops(SetVector<const Loop * > & Loops)525 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
526
follow(const SCEV * S)527 bool follow(const SCEV *S) {
528 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
529 Loops.insert(AddRec->getLoop());
530 return true;
531 }
isDone()532 bool isDone() { return false; }
533 };
534
findLoops(const SCEV * Expr,SetVector<const Loop * > & Loops)535 void polly::findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
536 SCEVFindLoops FindLoops(Loops);
537 SCEVTraversal<SCEVFindLoops> ST(FindLoops);
538 ST.visitAll(Expr);
539 }
540
541 /// Find all values referenced in SCEVUnknowns.
542 class SCEVFindValues final {
543 ScalarEvolution &SE;
544 SetVector<Value *> &Values;
545
546 public:
SCEVFindValues(ScalarEvolution & SE,SetVector<Value * > & Values)547 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
548 : SE(SE), Values(Values) {}
549
follow(const SCEV * S)550 bool follow(const SCEV *S) {
551 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
552 if (!Unknown)
553 return true;
554
555 Values.insert(Unknown->getValue());
556 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
557 if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
558 Inst->getOpcode() != Instruction::SDiv))
559 return false;
560
561 auto *Dividend = SE.getSCEV(Inst->getOperand(1));
562 if (!isa<SCEVConstant>(Dividend))
563 return false;
564
565 auto *Divisor = SE.getSCEV(Inst->getOperand(0));
566 SCEVFindValues FindValues(SE, Values);
567 SCEVTraversal<SCEVFindValues> ST(FindValues);
568 ST.visitAll(Dividend);
569 ST.visitAll(Divisor);
570
571 return false;
572 }
isDone()573 bool isDone() { return false; }
574 };
575
findValues(const SCEV * Expr,ScalarEvolution & SE,SetVector<Value * > & Values)576 void polly::findValues(const SCEV *Expr, ScalarEvolution &SE,
577 SetVector<Value *> &Values) {
578 SCEVFindValues FindValues(SE, Values);
579 SCEVTraversal<SCEVFindValues> ST(FindValues);
580 ST.visitAll(Expr);
581 }
582
hasScalarDepsInsideRegion(const SCEV * Expr,const Region * R,llvm::Loop * Scope,bool AllowLoops,const InvariantLoadsSetTy & ILS)583 bool polly::hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
584 llvm::Loop *Scope, bool AllowLoops,
585 const InvariantLoadsSetTy &ILS) {
586 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
587 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
588 ST.visitAll(Expr);
589 return InRegionDeps.hasDependences();
590 }
591
isAffineExpr(const Region * R,llvm::Loop * Scope,const SCEV * Expr,ScalarEvolution & SE,InvariantLoadsSetTy * ILS)592 bool polly::isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
593 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
594 if (isa<SCEVCouldNotCompute>(Expr))
595 return false;
596
597 SCEVValidator Validator(R, Scope, SE, ILS);
598 LLVM_DEBUG({
599 dbgs() << "\n";
600 dbgs() << "Expr: " << *Expr << "\n";
601 dbgs() << "Region: " << R->getNameStr() << "\n";
602 dbgs() << " -> ";
603 });
604
605 ValidatorResult Result = Validator.visit(Expr);
606
607 LLVM_DEBUG({
608 if (Result.isValid())
609 dbgs() << "VALID\n";
610 dbgs() << "\n";
611 });
612
613 return Result.isValid();
614 }
615
isAffineExpr(Value * V,const Region * R,Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params)616 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
617 ScalarEvolution &SE, ParameterSetTy &Params) {
618 auto *E = SE.getSCEV(V);
619 if (isa<SCEVCouldNotCompute>(E))
620 return false;
621
622 SCEVValidator Validator(R, Scope, SE, nullptr);
623 ValidatorResult Result = Validator.visit(E);
624 if (!Result.isValid())
625 return false;
626
627 auto ResultParams = Result.getParameters();
628 Params.insert(ResultParams.begin(), ResultParams.end());
629
630 return true;
631 }
632
isAffineConstraint(Value * V,const Region * R,Loop * Scope,ScalarEvolution & SE,ParameterSetTy & Params,bool OrExpr)633 bool polly::isAffineConstraint(Value *V, const Region *R, Loop *Scope,
634 ScalarEvolution &SE, ParameterSetTy &Params,
635 bool OrExpr) {
636 if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
637 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
638 true) &&
639 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
640 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
641 auto Opcode = BinOp->getOpcode();
642 if (Opcode == Instruction::And || Opcode == Instruction::Or)
643 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
644 false) &&
645 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
646 false);
647 /* Fall through */
648 }
649
650 if (!OrExpr)
651 return false;
652
653 return ::isAffineExpr(V, R, Scope, SE, Params);
654 }
655
getParamsInAffineExpr(const Region * R,Loop * Scope,const SCEV * Expr,ScalarEvolution & SE)656 ParameterSetTy polly::getParamsInAffineExpr(const Region *R, Loop *Scope,
657 const SCEV *Expr,
658 ScalarEvolution &SE) {
659 if (isa<SCEVCouldNotCompute>(Expr))
660 return ParameterSetTy();
661
662 InvariantLoadsSetTy ILS;
663 SCEVValidator Validator(R, Scope, SE, &ILS);
664 ValidatorResult Result = Validator.visit(Expr);
665 assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
666
667 return Result.getParameters();
668 }
669
670 std::pair<const SCEVConstant *, const SCEV *>
extractConstantFactor(const SCEV * S,ScalarEvolution & SE)671 polly::extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
672 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
673
674 if (auto *Constant = dyn_cast<SCEVConstant>(S))
675 return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
676
677 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
678 if (AddRec) {
679 auto *StartExpr = AddRec->getStart();
680 if (StartExpr->isZero()) {
681 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
682 auto *LeftOverAddRec =
683 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
684 AddRec->getNoWrapFlags());
685 return std::make_pair(StepPair.first, LeftOverAddRec);
686 }
687 return std::make_pair(ConstPart, S);
688 }
689
690 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
691 SmallVector<const SCEV *, 4> LeftOvers;
692 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
693 auto *Factor = Op0Pair.first;
694 if (SE.isKnownNegative(Factor)) {
695 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
696 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
697 } else {
698 LeftOvers.push_back(Op0Pair.second);
699 }
700
701 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
702 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
703 // TODO: Use something smarter than equality here, e.g., gcd.
704 if (Factor == OpUPair.first)
705 LeftOvers.push_back(OpUPair.second);
706 else if (Factor == SE.getNegativeSCEV(OpUPair.first))
707 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
708 else
709 return std::make_pair(ConstPart, S);
710 }
711
712 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
713 return std::make_pair(Factor, NewAdd);
714 }
715
716 auto *Mul = dyn_cast<SCEVMulExpr>(S);
717 if (!Mul)
718 return std::make_pair(ConstPart, S);
719
720 SmallVector<const SCEV *, 4> LeftOvers;
721 for (auto *Op : Mul->operands())
722 if (isa<SCEVConstant>(Op))
723 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
724 else
725 LeftOvers.push_back(Op);
726
727 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
728 }
729
/// If \p Expr wraps a PHI node that has exactly one incoming value not coming
/// from an error block inside \p R, return the SCEV of that value. Otherwise
/// return \p Expr unchanged.
const SCEV *polly::tryForwardThroughPHI(const SCEV *Expr, Region &R,
                                        ScalarEvolution &SE,
                                        ScopDetection *SD) {
  const auto *Unknown = dyn_cast<SCEVUnknown>(Expr);
  if (!Unknown)
    return Expr;

  auto *PHI = dyn_cast<PHINode>(Unknown->getValue());
  if (!PHI)
    return Expr;

  Value *Forwarded = nullptr;
  for (unsigned Idx = 0, NumIncoming = PHI->getNumIncomingValues();
       Idx != NumIncoming; ++Idx) {
    BasicBlock *Pred = PHI->getIncomingBlock(Idx);
    // Edges from error blocks inside the region are ignored.
    if (SD->isErrorBlock(*Pred, R) && R.contains(Pred))
      continue;
    // A second candidate means the value is not unique.
    if (Forwarded)
      return Expr;
    Forwarded = PHI->getIncomingValue(Idx);
  }

  return Forwarded ? SE.getSCEV(Forwarded) : Expr;
}
755
getUniqueNonErrorValue(PHINode * PHI,Region * R,ScopDetection * SD)756 Value *polly::getUniqueNonErrorValue(PHINode *PHI, Region *R,
757 ScopDetection *SD) {
758 Value *V = nullptr;
759 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
760 BasicBlock *BB = PHI->getIncomingBlock(i);
761 if (!SD->isErrorBlock(*BB, *R)) {
762 if (V)
763 return nullptr;
764 V = PHI->getIncomingValue(i);
765 }
766 }
767
768 return V;
769 }
770