1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopInfo.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 
9 using namespace llvm;
10 using namespace polly;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// @brief Classification the validator assigns to every SCEV.
///
/// The enumerators are ordered from most to least restrictive, and the
/// validator relies on that order: merging two results keeps the larger of the
/// two types (via std::max). A SCEV of type X may therefore only have
/// subexpressions whose type is smaller than or equal to X. The values are
/// spelled out explicitly to make this ordering contract visible.
enum TYPE {
  /// A plain integer value.
  INT = 0,

  /// Constant while the SCoP executes, but possibly dependent on parameters
  /// that are unknown at compile time.
  PARAM = 1,

  /// May change while the SCoP executes (an induction variable).
  IV = 2,

  /// Not representable.
  INVALID = 3
};
}
36 
37 /// @brief The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// @brief The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// @brief The set of Parameters in the expression.
43   ParameterSetTy Parameters;
44 
45 public:
46   /// @brief The copy constructor
47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// @brief Construct a result with a certain type and no parameters.
53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// @brief Construct a result with a certain type and a single parameter.
58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.insert(Expr);
60   }
61 
62   /// @brief Get the type of the ValidatorResult.
63   SCEVType::TYPE getType() { return Type; }
64 
65   /// @brief Is the analyzed SCEV constant during the execution of the SCoP.
66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// @brief Is the analyzed SCEV valid.
69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// @brief Is the analyzed SCEV of Type IV.
72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// @brief Is the analyzed SCEV of Type INT.
75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// @brief Is the analyzed SCEV of Type PARAM.
78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// @brief Get the parameters of this validator result.
81   const ParameterSetTy &getParameters() { return Parameters; }
82 
83   /// @brief Add the parameters of Source to this result.
84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86   }
87 
88   /// @brief Merge a result.
89   ///
90   /// This means to merge the parameters and to set the Type to the most
91   /// specific Type that matches both.
92   void merge(const ValidatorResult &ToMerge) {
93     Type = std::max(Type, ToMerge.Type);
94     addParamsFrom(ToMerge);
95   }
96 
97   void print(raw_ostream &OS) {
98     switch (Type) {
99     case SCEVType::INT:
100       OS << "SCEVType::INT";
101       break;
102     case SCEVType::PARAM:
103       OS << "SCEVType::PARAM";
104       break;
105     case SCEVType::IV:
106       OS << "SCEVType::IV";
107       break;
108     case SCEVType::INVALID:
109       OS << "SCEVType::INVALID";
110       break;
111     }
112   }
113 };
114 
115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
116   VR.print(OS);
117   return OS;
118 }
119 
120 /// Check if a SCEV is valid in a SCoP.
121 struct SCEVValidator
122     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
123 private:
124   const Region *R;
125   Loop *Scope;
126   ScalarEvolution &SE;
127   InvariantLoadsSetTy *ILS;
128 
129 public:
130   SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
131                 InvariantLoadsSetTy *ILS)
132       : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
133 
134   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
135     return ValidatorResult(SCEVType::INT);
136   }
137 
138   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
139     return visit(Expr->getOperand());
140   }
141 
142   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
143     return visit(Expr->getOperand());
144   }
145 
146   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
147     return visit(Expr->getOperand());
148   }
149 
150   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
151     ValidatorResult Return(SCEVType::INT);
152 
153     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
154       ValidatorResult Op = visit(Expr->getOperand(i));
155       Return.merge(Op);
156 
157       // Early exit.
158       if (!Return.isValid())
159         break;
160     }
161 
162     return Return;
163   }
164 
165   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
166     ValidatorResult Return(SCEVType::INT);
167 
168     bool HasMultipleParams = false;
169 
170     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
171       ValidatorResult Op = visit(Expr->getOperand(i));
172 
173       if (Op.isINT())
174         continue;
175 
176       if (Op.isPARAM() && Return.isPARAM()) {
177         HasMultipleParams = true;
178         continue;
179       }
180 
181       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
182         DEBUG(dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
183                      << "\tExpr: " << *Expr << "\n"
184                      << "\tPrevious expression type: " << Return << "\n"
185                      << "\tNext operand (" << Op
186                      << "): " << *Expr->getOperand(i) << "\n");
187 
188         return ValidatorResult(SCEVType::INVALID);
189       }
190 
191       Return.merge(Op);
192     }
193 
194     if (HasMultipleParams && Return.isValid())
195       return ValidatorResult(SCEVType::PARAM, Expr);
196 
197     return Return;
198   }
199 
200   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
201     if (!Expr->isAffine()) {
202       DEBUG(dbgs() << "INVALID: AddRec is not affine");
203       return ValidatorResult(SCEVType::INVALID);
204     }
205 
206     ValidatorResult Start = visit(Expr->getStart());
207     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
208 
209     if (!Start.isValid())
210       return Start;
211 
212     if (!Recurrence.isValid())
213       return Recurrence;
214 
215     auto *L = Expr->getLoop();
216     if (R->contains(L) && (!Scope || !L->contains(Scope))) {
217       DEBUG(dbgs() << "INVALID: AddRec out of a loop whose exit value is not "
218                       "synthesizable");
219       return ValidatorResult(SCEVType::INVALID);
220     }
221 
222     if (R->contains(L)) {
223       if (Recurrence.isINT()) {
224         ValidatorResult Result(SCEVType::IV);
225         Result.addParamsFrom(Start);
226         return Result;
227       }
228 
229       DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
230                       "recurrence part");
231       return ValidatorResult(SCEVType::INVALID);
232     }
233 
234     assert(Start.isConstant() && Recurrence.isConstant() &&
235            "Expected 'Start' and 'Recurrence' to be constant");
236 
237     // Directly generate ValidatorResult for Expr if 'start' is zero.
238     if (Expr->getStart()->isZero())
239       return ValidatorResult(SCEVType::PARAM, Expr);
240 
241     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
242     // if 'start' is not zero.
243     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
244         SE.getConstant(Expr->getStart()->getType(), 0),
245         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
246 
247     ValidatorResult ZeroStartResult =
248         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
249     ZeroStartResult.addParamsFrom(Start);
250 
251     return ZeroStartResult;
252   }
253 
254   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
255     ValidatorResult Return(SCEVType::INT);
256 
257     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
258       ValidatorResult Op = visit(Expr->getOperand(i));
259 
260       if (!Op.isValid())
261         return Op;
262 
263       Return.merge(Op);
264     }
265 
266     return Return;
267   }
268 
269   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
270     // We do not support unsigned max operations. If 'Expr' is constant during
271     // Scop execution we treat this as a parameter, otherwise we bail out.
272     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
273       ValidatorResult Op = visit(Expr->getOperand(i));
274 
275       if (!Op.isConstant()) {
276         DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
277         return ValidatorResult(SCEVType::INVALID);
278       }
279     }
280 
281     return ValidatorResult(SCEVType::PARAM, Expr);
282   }
283 
284   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
285     if (R->contains(I)) {
286       DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
287                       "within the region\n");
288       return ValidatorResult(SCEVType::INVALID);
289     }
290 
291     return ValidatorResult(SCEVType::PARAM, S);
292   }
293 
294   ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
295     if (R->contains(I) && ILS) {
296       ILS->insert(cast<LoadInst>(I));
297       return ValidatorResult(SCEVType::PARAM, S);
298     }
299 
300     return visitGenericInst(I, S);
301   }
302 
303   ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
304                                 const SCEV *DivExpr,
305                                 Instruction *SDiv = nullptr) {
306 
307     // First check if we might be able to model the division, thus if the
308     // divisor is constant. If so, check the dividend, otherwise check if
309     // the whole division can be seen as a parameter.
310     if (isa<SCEVConstant>(Divisor))
311       return visit(Dividend);
312 
313     // For signed divisions use the SDiv instruction to check for a parameter
314     // division, for unsigned divisions check the operands.
315     if (SDiv)
316       return visitGenericInst(SDiv, DivExpr);
317 
318     ValidatorResult LHS = visit(Dividend);
319     ValidatorResult RHS = visit(Divisor);
320     if (LHS.isConstant() && RHS.isConstant())
321       return ValidatorResult(SCEVType::PARAM, DivExpr);
322 
323     DEBUG(dbgs() << "INVALID: unsigned division of non-constant expressions");
324     return ValidatorResult(SCEVType::INVALID);
325   }
326 
327   ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
328     auto *Dividend = Expr->getLHS();
329     auto *Divisor = Expr->getRHS();
330     return visitDivision(Dividend, Divisor, Expr);
331   }
332 
333   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
334     assert(SDiv->getOpcode() == Instruction::SDiv &&
335            "Assumed SDiv instruction!");
336 
337     auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
338     auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
339     return visitDivision(Dividend, Divisor, Expr, SDiv);
340   }
341 
342   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
343     assert(SRem->getOpcode() == Instruction::SRem &&
344            "Assumed SRem instruction!");
345 
346     auto *Divisor = SRem->getOperand(1);
347     auto *CI = dyn_cast<ConstantInt>(Divisor);
348     if (!CI)
349       return visitGenericInst(SRem, S);
350 
351     auto *Dividend = SRem->getOperand(0);
352     auto *DividendSCEV = SE.getSCEV(Dividend);
353     return visit(DividendSCEV);
354   }
355 
356   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
357     Value *V = Expr->getValue();
358 
359     if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
360       DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
361       return ValidatorResult(SCEVType::INVALID);
362     }
363 
364     if (isa<UndefValue>(V)) {
365       DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
366       return ValidatorResult(SCEVType::INVALID);
367     }
368 
369     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
370       switch (I->getOpcode()) {
371       case Instruction::Load:
372         return visitLoadInstruction(I, Expr);
373       case Instruction::SDiv:
374         return visitSDivInstruction(I, Expr);
375       case Instruction::SRem:
376         return visitSRemInstruction(I, Expr);
377       default:
378         return visitGenericInst(I, Expr);
379       }
380     }
381 
382     return ValidatorResult(SCEVType::PARAM, Expr);
383   }
384 };
385 
386 /// @brief Check whether a SCEV refers to an SSA name defined inside a region.
387 class SCEVInRegionDependences {
388   const Region *R;
389   Loop *Scope;
390   bool AllowLoops;
391   bool HasInRegionDeps = false;
392 
393 public:
394   SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops)
395       : R(R), Scope(Scope), AllowLoops(AllowLoops) {}
396 
397   bool follow(const SCEV *S) {
398     if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
399       Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
400 
401       // Return true when Inst is defined inside the region R.
402       if (Inst && R->contains(Inst)) {
403         HasInRegionDeps = true;
404         return false;
405       }
406     } else if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
407       if (!AllowLoops) {
408         if (!Scope) {
409           HasInRegionDeps = true;
410           return false;
411         }
412         auto *L = AddRec->getLoop();
413         if (R->contains(L) && !L->contains(Scope)) {
414           HasInRegionDeps = true;
415           return false;
416         }
417       }
418     }
419     return true;
420   }
421   bool isDone() { return false; }
422   bool hasDependences() { return HasInRegionDeps; }
423 };
424 
425 namespace polly {
426 /// Find all loops referenced in SCEVAddRecExprs.
427 class SCEVFindLoops {
428   SetVector<const Loop *> &Loops;
429 
430 public:
431   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
432 
433   bool follow(const SCEV *S) {
434     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
435       Loops.insert(AddRec->getLoop());
436     return true;
437   }
438   bool isDone() { return false; }
439 };
440 
441 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
442   SCEVFindLoops FindLoops(Loops);
443   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
444   ST.visitAll(Expr);
445 }
446 
447 /// Find all values referenced in SCEVUnknowns.
448 class SCEVFindValues {
449   ScalarEvolution &SE;
450   SetVector<Value *> &Values;
451 
452 public:
453   SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
454       : SE(SE), Values(Values) {}
455 
456   bool follow(const SCEV *S) {
457     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
458     if (!Unknown)
459       return true;
460 
461     Values.insert(Unknown->getValue());
462     Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
463     if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
464                   Inst->getOpcode() != Instruction::SDiv))
465       return false;
466 
467     auto *Dividend = SE.getSCEV(Inst->getOperand(1));
468     if (!isa<SCEVConstant>(Dividend))
469       return false;
470 
471     auto *Divisor = SE.getSCEV(Inst->getOperand(0));
472     SCEVFindValues FindValues(SE, Values);
473     SCEVTraversal<SCEVFindValues> ST(FindValues);
474     ST.visitAll(Dividend);
475     ST.visitAll(Divisor);
476 
477     return false;
478   }
479   bool isDone() { return false; }
480 };
481 
482 void findValues(const SCEV *Expr, ScalarEvolution &SE,
483                 SetVector<Value *> &Values) {
484   SCEVFindValues FindValues(SE, Values);
485   SCEVTraversal<SCEVFindValues> ST(FindValues);
486   ST.visitAll(Expr);
487 }
488 
489 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
490                                llvm::Loop *Scope, bool AllowLoops) {
491   SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops);
492   SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
493   ST.visitAll(Expr);
494   return InRegionDeps.hasDependences();
495 }
496 
497 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
498                   ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
499   if (isa<SCEVCouldNotCompute>(Expr))
500     return false;
501 
502   SCEVValidator Validator(R, Scope, SE, ILS);
503   DEBUG({
504     dbgs() << "\n";
505     dbgs() << "Expr: " << *Expr << "\n";
506     dbgs() << "Region: " << R->getNameStr() << "\n";
507     dbgs() << " -> ";
508   });
509 
510   ValidatorResult Result = Validator.visit(Expr);
511 
512   DEBUG({
513     if (Result.isValid())
514       dbgs() << "VALID\n";
515     dbgs() << "\n";
516   });
517 
518   return Result.isValid();
519 }
520 
521 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
522                          ScalarEvolution &SE, ParameterSetTy &Params) {
523   auto *E = SE.getSCEV(V);
524   if (isa<SCEVCouldNotCompute>(E))
525     return false;
526 
527   SCEVValidator Validator(R, Scope, SE, nullptr);
528   ValidatorResult Result = Validator.visit(E);
529   if (!Result.isValid())
530     return false;
531 
532   auto ResultParams = Result.getParameters();
533   Params.insert(ResultParams.begin(), ResultParams.end());
534 
535   return true;
536 }
537 
538 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
539                         ScalarEvolution &SE, ParameterSetTy &Params,
540                         bool OrExpr) {
541   if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
542     return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
543                               true) &&
544            isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
545   } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
546     auto Opcode = BinOp->getOpcode();
547     if (Opcode == Instruction::And || Opcode == Instruction::Or)
548       return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
549                                 false) &&
550              isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
551                                 false);
552     /* Fall through */
553   }
554 
555   if (!OrExpr)
556     return false;
557 
558   return isAffineExpr(V, R, Scope, SE, Params);
559 }
560 
561 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
562                                      const SCEV *Expr, ScalarEvolution &SE) {
563   if (isa<SCEVCouldNotCompute>(Expr))
564     return ParameterSetTy();
565 
566   InvariantLoadsSetTy ILS;
567   SCEVValidator Validator(R, Scope, SE, &ILS);
568   ValidatorResult Result = Validator.visit(Expr);
569   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
570 
571   return Result.getParameters();
572 }
573 
574 std::pair<const SCEVConstant *, const SCEV *>
575 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
576 
577   auto *LeftOver = SE.getConstant(S->getType(), 1);
578   auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
579 
580   if (auto *Constant = dyn_cast<SCEVConstant>(S))
581     return std::make_pair(Constant, LeftOver);
582 
583   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
584   if (AddRec) {
585     auto *StartExpr = AddRec->getStart();
586     if (StartExpr->isZero()) {
587       auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
588       auto *LeftOverAddRec =
589           SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
590                            AddRec->getNoWrapFlags());
591       return std::make_pair(StepPair.first, LeftOverAddRec);
592     }
593     return std::make_pair(ConstPart, S);
594   }
595 
596   if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
597     SmallVector<const SCEV *, 4> LeftOvers;
598     auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
599     auto *Factor = Op0Pair.first;
600     if (SE.isKnownNegative(Factor)) {
601       Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
602       LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
603     } else {
604       LeftOvers.push_back(Op0Pair.second);
605     }
606 
607     for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
608       auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
609       // TODO: Use something smarter than equality here, e.g., gcd.
610       if (Factor == OpUPair.first)
611         LeftOvers.push_back(OpUPair.second);
612       else if (Factor == SE.getNegativeSCEV(OpUPair.first))
613         LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
614       else
615         return std::make_pair(ConstPart, S);
616     }
617 
618     auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
619     return std::make_pair(Factor, NewAdd);
620   }
621 
622   auto *Mul = dyn_cast<SCEVMulExpr>(S);
623   if (!Mul)
624     return std::make_pair(ConstPart, S);
625 
626   for (auto *Op : Mul->operands())
627     if (isa<SCEVConstant>(Op))
628       ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
629     else
630       LeftOver = SE.getMulExpr(LeftOver, Op);
631 
632   return std::make_pair(ConstPart, LeftOver);
633 }
634 }
635