1 
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopDetection.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
8 
9 using namespace llvm;
10 using namespace polly;
11 
12 #define DEBUG_TYPE "polly-scev-validator"
13 
namespace SCEVType {
/// Classification the validator assigns to every SCEV expression.
///
/// The enumerators are ordered from most to least restrictive and the order is
/// significant: a SCEV of type X may only have subexpressions whose type is
/// less than or equal to X.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// Constant during the execution of the SCoP, but possibly depending on
  /// parameters that are unknown at compile time.
  PARAM = 1,

  /// May change during the execution of the SCoP.
  IV = 2,

  /// Not representable.
  INVALID = 3
};
} // namespace SCEVType
36 
37 /// The result the validator returns for a SCEV expression.
38 class ValidatorResult {
39   /// The type of the expression
40   SCEVType::TYPE Type;
41 
42   /// The set of Parameters in the expression.
43   ParameterSetTy Parameters;
44 
45 public:
46   /// The copy constructor
47   ValidatorResult(const ValidatorResult &Source) {
48     Type = Source.Type;
49     Parameters = Source.Parameters;
50   }
51 
52   /// Construct a result with a certain type and no parameters.
53   ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
54     assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
55   }
56 
57   /// Construct a result with a certain type and a single parameter.
58   ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
59     Parameters.insert(Expr);
60   }
61 
62   /// Get the type of the ValidatorResult.
63   SCEVType::TYPE getType() { return Type; }
64 
65   /// Is the analyzed SCEV constant during the execution of the SCoP.
66   bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
67 
68   /// Is the analyzed SCEV valid.
69   bool isValid() { return Type != SCEVType::INVALID; }
70 
71   /// Is the analyzed SCEV of Type IV.
72   bool isIV() { return Type == SCEVType::IV; }
73 
74   /// Is the analyzed SCEV of Type INT.
75   bool isINT() { return Type == SCEVType::INT; }
76 
77   /// Is the analyzed SCEV of Type PARAM.
78   bool isPARAM() { return Type == SCEVType::PARAM; }
79 
80   /// Get the parameters of this validator result.
81   const ParameterSetTy &getParameters() { return Parameters; }
82 
83   /// Add the parameters of Source to this result.
84   void addParamsFrom(const ValidatorResult &Source) {
85     Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
86   }
87 
88   /// Merge a result.
89   ///
90   /// This means to merge the parameters and to set the Type to the most
91   /// specific Type that matches both.
92   void merge(const ValidatorResult &ToMerge) {
93     Type = std::max(Type, ToMerge.Type);
94     addParamsFrom(ToMerge);
95   }
96 
97   void print(raw_ostream &OS) {
98     switch (Type) {
99     case SCEVType::INT:
100       OS << "SCEVType::INT";
101       break;
102     case SCEVType::PARAM:
103       OS << "SCEVType::PARAM";
104       break;
105     case SCEVType::IV:
106       OS << "SCEVType::IV";
107       break;
108     case SCEVType::INVALID:
109       OS << "SCEVType::INVALID";
110       break;
111     }
112   }
113 };
114 
115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
116   VR.print(OS);
117   return OS;
118 }
119 
120 /// Check if a SCEV is valid in a SCoP.
121 struct SCEVValidator
122     : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
123 private:
124   const Region *R;
125   Loop *Scope;
126   ScalarEvolution &SE;
127   InvariantLoadsSetTy *ILS;
128 
129 public:
130   SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
131                 InvariantLoadsSetTy *ILS)
132       : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
133 
134   class ValidatorResult visitConstant(const SCEVConstant *Constant) {
135     return ValidatorResult(SCEVType::INT);
136   }
137 
138   class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
139                                                       const SCEV *Operand) {
140     ValidatorResult Op = visit(Operand);
141     auto Type = Op.getType();
142 
143     // If unsigned operations are allowed return the operand, otherwise
144     // check if we can model the expression without unsigned assumptions.
145     if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
146       return Op;
147 
148     if (Type == SCEVType::IV)
149       return ValidatorResult(SCEVType::INVALID);
150     return ValidatorResult(SCEVType::PARAM, Expr);
151   }
152 
153   class ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
154     return visit(Expr->getOperand());
155   }
156 
157   class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
158     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
159   }
160 
161   class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
162     return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
163   }
164 
165   class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
166     return visit(Expr->getOperand());
167   }
168 
169   class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
170     ValidatorResult Return(SCEVType::INT);
171 
172     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
173       ValidatorResult Op = visit(Expr->getOperand(i));
174       Return.merge(Op);
175 
176       // Early exit.
177       if (!Return.isValid())
178         break;
179     }
180 
181     return Return;
182   }
183 
184   class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
185     ValidatorResult Return(SCEVType::INT);
186 
187     bool HasMultipleParams = false;
188 
189     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
190       ValidatorResult Op = visit(Expr->getOperand(i));
191 
192       if (Op.isINT())
193         continue;
194 
195       if (Op.isPARAM() && Return.isPARAM()) {
196         HasMultipleParams = true;
197         continue;
198       }
199 
200       if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
201         LLVM_DEBUG(
202             dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
203                    << "\tExpr: " << *Expr << "\n"
204                    << "\tPrevious expression type: " << Return << "\n"
205                    << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
206                    << "\n");
207 
208         return ValidatorResult(SCEVType::INVALID);
209       }
210 
211       Return.merge(Op);
212     }
213 
214     if (HasMultipleParams && Return.isValid())
215       return ValidatorResult(SCEVType::PARAM, Expr);
216 
217     return Return;
218   }
219 
220   class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
221     if (!Expr->isAffine()) {
222       LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
223       return ValidatorResult(SCEVType::INVALID);
224     }
225 
226     ValidatorResult Start = visit(Expr->getStart());
227     ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
228 
229     if (!Start.isValid())
230       return Start;
231 
232     if (!Recurrence.isValid())
233       return Recurrence;
234 
235     auto *L = Expr->getLoop();
236     if (R->contains(L) && (!Scope || !L->contains(Scope))) {
237       LLVM_DEBUG(
238           dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
239                     "non-affine subregion or has a non-synthesizable exit "
240                     "value.");
241       return ValidatorResult(SCEVType::INVALID);
242     }
243 
244     if (R->contains(L)) {
245       if (Recurrence.isINT()) {
246         ValidatorResult Result(SCEVType::IV);
247         Result.addParamsFrom(Start);
248         return Result;
249       }
250 
251       LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
252                            "recurrence part");
253       return ValidatorResult(SCEVType::INVALID);
254     }
255 
256     assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
257 
258     // Directly generate ValidatorResult for Expr if 'start' is zero.
259     if (Expr->getStart()->isZero())
260       return ValidatorResult(SCEVType::PARAM, Expr);
261 
262     // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
263     // if 'start' is not zero.
264     const SCEV *ZeroStartExpr = SE.getAddRecExpr(
265         SE.getConstant(Expr->getStart()->getType(), 0),
266         Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
267 
268     ValidatorResult ZeroStartResult =
269         ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
270     ZeroStartResult.addParamsFrom(Start);
271 
272     return ZeroStartResult;
273   }
274 
275   class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
276     ValidatorResult Return(SCEVType::INT);
277 
278     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
279       ValidatorResult Op = visit(Expr->getOperand(i));
280 
281       if (!Op.isValid())
282         return Op;
283 
284       Return.merge(Op);
285     }
286 
287     return Return;
288   }
289 
290   class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
291     ValidatorResult Return(SCEVType::INT);
292 
293     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
294       ValidatorResult Op = visit(Expr->getOperand(i));
295 
296       if (!Op.isValid())
297         return Op;
298 
299       Return.merge(Op);
300     }
301 
302     return Return;
303   }
304 
305   class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
306     // We do not support unsigned max operations. If 'Expr' is constant during
307     // Scop execution we treat this as a parameter, otherwise we bail out.
308     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
309       ValidatorResult Op = visit(Expr->getOperand(i));
310 
311       if (!Op.isConstant()) {
312         LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
313         return ValidatorResult(SCEVType::INVALID);
314       }
315     }
316 
317     return ValidatorResult(SCEVType::PARAM, Expr);
318   }
319 
320   class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
321     // We do not support unsigned min operations. If 'Expr' is constant during
322     // Scop execution we treat this as a parameter, otherwise we bail out.
323     for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
324       ValidatorResult Op = visit(Expr->getOperand(i));
325 
326       if (!Op.isConstant()) {
327         LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
328         return ValidatorResult(SCEVType::INVALID);
329       }
330     }
331 
332     return ValidatorResult(SCEVType::PARAM, Expr);
333   }
334 
335   ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
336     if (R->contains(I)) {
337       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
338                            "within the region\n");
339       return ValidatorResult(SCEVType::INVALID);
340     }
341 
342     return ValidatorResult(SCEVType::PARAM, S);
343   }
344 
345   ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
346     if (R->contains(I) && ILS) {
347       ILS->insert(cast<LoadInst>(I));
348       return ValidatorResult(SCEVType::PARAM, S);
349     }
350 
351     return visitGenericInst(I, S);
352   }
353 
354   ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
355                                 const SCEV *DivExpr,
356                                 Instruction *SDiv = nullptr) {
357 
358     // First check if we might be able to model the division, thus if the
359     // divisor is constant. If so, check the dividend, otherwise check if
360     // the whole division can be seen as a parameter.
361     if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
362       return visit(Dividend);
363 
364     // For signed divisions use the SDiv instruction to check for a parameter
365     // division, for unsigned divisions check the operands.
366     if (SDiv)
367       return visitGenericInst(SDiv, DivExpr);
368 
369     ValidatorResult LHS = visit(Dividend);
370     ValidatorResult RHS = visit(Divisor);
371     if (LHS.isConstant() && RHS.isConstant())
372       return ValidatorResult(SCEVType::PARAM, DivExpr);
373 
374     LLVM_DEBUG(
375         dbgs() << "INVALID: unsigned division of non-constant expressions");
376     return ValidatorResult(SCEVType::INVALID);
377   }
378 
379   ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
380     if (!PollyAllowUnsignedOperations)
381       return ValidatorResult(SCEVType::INVALID);
382 
383     auto *Dividend = Expr->getLHS();
384     auto *Divisor = Expr->getRHS();
385     return visitDivision(Dividend, Divisor, Expr);
386   }
387 
388   ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
389     assert(SDiv->getOpcode() == Instruction::SDiv &&
390            "Assumed SDiv instruction!");
391 
392     auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
393     auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
394     return visitDivision(Dividend, Divisor, Expr, SDiv);
395   }
396 
397   ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
398     assert(SRem->getOpcode() == Instruction::SRem &&
399            "Assumed SRem instruction!");
400 
401     auto *Divisor = SRem->getOperand(1);
402     auto *CI = dyn_cast<ConstantInt>(Divisor);
403     if (!CI || CI->isZeroValue())
404       return visitGenericInst(SRem, S);
405 
406     auto *Dividend = SRem->getOperand(0);
407     auto *DividendSCEV = SE.getSCEV(Dividend);
408     return visit(DividendSCEV);
409   }
410 
411   ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
412     Value *V = Expr->getValue();
413 
414     if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
415       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
416       return ValidatorResult(SCEVType::INVALID);
417     }
418 
419     if (isa<UndefValue>(V)) {
420       LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
421       return ValidatorResult(SCEVType::INVALID);
422     }
423 
424     if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
425       switch (I->getOpcode()) {
426       case Instruction::IntToPtr:
427         return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
428       case Instruction::Load:
429         return visitLoadInstruction(I, Expr);
430       case Instruction::SDiv:
431         return visitSDivInstruction(I, Expr);
432       case Instruction::SRem:
433         return visitSRemInstruction(I, Expr);
434       default:
435         return visitGenericInst(I, Expr);
436       }
437     }
438 
439     if (Expr->getType()->isPointerTy()) {
440       if (isa<ConstantPointerNull>(V))
441         return ValidatorResult(SCEVType::INT); // "int"
442     }
443 
444     return ValidatorResult(SCEVType::PARAM, Expr);
445   }
446 };
447 
448 /// Check whether a SCEV refers to an SSA name defined inside a region.
449 class SCEVInRegionDependences {
450   const Region *R;
451   Loop *Scope;
452   const InvariantLoadsSetTy &ILS;
453   bool AllowLoops;
454   bool HasInRegionDeps = false;
455 
456 public:
457   SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
458                           const InvariantLoadsSetTy &ILS)
459       : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
460 
461   bool follow(const SCEV *S) {
462     if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
463       Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
464 
465       if (Inst) {
466         // When we invariant load hoist a load, we first make sure that there
467         // can be no dependences created by it in the Scop region. So, we should
468         // not consider scalar dependences to `LoadInst`s that are invariant
469         // load hoisted.
470         //
471         // If this check is not present, then we create data dependences which
472         // are strictly not necessary by tracking the invariant load as a
473         // scalar.
474         LoadInst *LI = dyn_cast<LoadInst>(Inst);
475         if (LI && ILS.count(LI) > 0)
476           return false;
477       }
478 
479       // Return true when Inst is defined inside the region R.
480       if (!Inst || !R->contains(Inst))
481         return true;
482 
483       HasInRegionDeps = true;
484       return false;
485     }
486 
487     if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
488       if (AllowLoops)
489         return true;
490 
491       auto *L = AddRec->getLoop();
492       if (R->contains(L) && !L->contains(Scope)) {
493         HasInRegionDeps = true;
494         return false;
495       }
496     }
497 
498     return true;
499   }
500   bool isDone() { return false; }
501   bool hasDependences() { return HasInRegionDeps; }
502 };
503 
504 namespace polly {
505 /// Find all loops referenced in SCEVAddRecExprs.
506 class SCEVFindLoops {
507   SetVector<const Loop *> &Loops;
508 
509 public:
510   SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
511 
512   bool follow(const SCEV *S) {
513     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
514       Loops.insert(AddRec->getLoop());
515     return true;
516   }
517   bool isDone() { return false; }
518 };
519 
520 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
521   SCEVFindLoops FindLoops(Loops);
522   SCEVTraversal<SCEVFindLoops> ST(FindLoops);
523   ST.visitAll(Expr);
524 }
525 
526 /// Find all values referenced in SCEVUnknowns.
527 class SCEVFindValues {
528   ScalarEvolution &SE;
529   SetVector<Value *> &Values;
530 
531 public:
532   SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
533       : SE(SE), Values(Values) {}
534 
535   bool follow(const SCEV *S) {
536     const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
537     if (!Unknown)
538       return true;
539 
540     Values.insert(Unknown->getValue());
541     Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
542     if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
543                   Inst->getOpcode() != Instruction::SDiv))
544       return false;
545 
546     auto *Dividend = SE.getSCEV(Inst->getOperand(1));
547     if (!isa<SCEVConstant>(Dividend))
548       return false;
549 
550     auto *Divisor = SE.getSCEV(Inst->getOperand(0));
551     SCEVFindValues FindValues(SE, Values);
552     SCEVTraversal<SCEVFindValues> ST(FindValues);
553     ST.visitAll(Dividend);
554     ST.visitAll(Divisor);
555 
556     return false;
557   }
558   bool isDone() { return false; }
559 };
560 
561 void findValues(const SCEV *Expr, ScalarEvolution &SE,
562                 SetVector<Value *> &Values) {
563   SCEVFindValues FindValues(SE, Values);
564   SCEVTraversal<SCEVFindValues> ST(FindValues);
565   ST.visitAll(Expr);
566 }
567 
568 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
569                                llvm::Loop *Scope, bool AllowLoops,
570                                const InvariantLoadsSetTy &ILS) {
571   SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
572   SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
573   ST.visitAll(Expr);
574   return InRegionDeps.hasDependences();
575 }
576 
577 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
578                   ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
579   if (isa<SCEVCouldNotCompute>(Expr))
580     return false;
581 
582   SCEVValidator Validator(R, Scope, SE, ILS);
583   LLVM_DEBUG({
584     dbgs() << "\n";
585     dbgs() << "Expr: " << *Expr << "\n";
586     dbgs() << "Region: " << R->getNameStr() << "\n";
587     dbgs() << " -> ";
588   });
589 
590   ValidatorResult Result = Validator.visit(Expr);
591 
592   LLVM_DEBUG({
593     if (Result.isValid())
594       dbgs() << "VALID\n";
595     dbgs() << "\n";
596   });
597 
598   return Result.isValid();
599 }
600 
601 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
602                          ScalarEvolution &SE, ParameterSetTy &Params) {
603   auto *E = SE.getSCEV(V);
604   if (isa<SCEVCouldNotCompute>(E))
605     return false;
606 
607   SCEVValidator Validator(R, Scope, SE, nullptr);
608   ValidatorResult Result = Validator.visit(E);
609   if (!Result.isValid())
610     return false;
611 
612   auto ResultParams = Result.getParameters();
613   Params.insert(ResultParams.begin(), ResultParams.end());
614 
615   return true;
616 }
617 
618 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
619                         ScalarEvolution &SE, ParameterSetTy &Params,
620                         bool OrExpr) {
621   if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
622     return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
623                               true) &&
624            isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
625   } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
626     auto Opcode = BinOp->getOpcode();
627     if (Opcode == Instruction::And || Opcode == Instruction::Or)
628       return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
629                                 false) &&
630              isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
631                                 false);
632     /* Fall through */
633   }
634 
635   if (!OrExpr)
636     return false;
637 
638   return isAffineExpr(V, R, Scope, SE, Params);
639 }
640 
641 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
642                                      const SCEV *Expr, ScalarEvolution &SE) {
643   if (isa<SCEVCouldNotCompute>(Expr))
644     return ParameterSetTy();
645 
646   InvariantLoadsSetTy ILS;
647   SCEVValidator Validator(R, Scope, SE, &ILS);
648   ValidatorResult Result = Validator.visit(Expr);
649   assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
650 
651   return Result.getParameters();
652 }
653 
/// Split S into a constant factor and the remaining expression.
///
/// Returns a pair (C, E). The cases handled are:
///  - S is a SCEVConstant: (S, 1).
///  - S is an AddRec with zero start: the factor of its step recurrence is
///    extracted recursively, and E is the AddRec rebuilt with the left-over
///    step.
///  - S is an Add: the factor of operand 0 (made non-negative) must match the
///    factor of every other operand, possibly up to sign. E is then the sum of
///    the (possibly negated) left-overs. Otherwise (1, S) is returned.
///  - S is a Mul: C is the product of all constant operands and E the product
///    of the remaining operands.
///  - Anything else, including AddRecs with non-zero start: (1, S).
std::pair<const SCEVConstant *, const SCEV *>
extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
  // Default factor: the constant 1 of S's type.
  auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));

  if (auto *Constant = dyn_cast<SCEVConstant>(S))
    return std::make_pair(Constant, SE.getConstant(S->getType(), 1));

  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
  if (AddRec) {
    auto *StartExpr = AddRec->getStart();
    if (StartExpr->isZero()) {
      // {0, +, c*x} becomes c * {0, +, x}.
      auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
      auto *LeftOverAddRec =
          SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
                           AddRec->getNoWrapFlags());
      return std::make_pair(StepPair.first, LeftOverAddRec);
    }
    return std::make_pair(ConstPart, S);
  }

  if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
    SmallVector<const SCEV *, 4> LeftOvers;
    auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
    auto *Factor = Op0Pair.first;
    // Normalize the factor to be non-negative and move the sign into the
    // left-over term.
    if (SE.isKnownNegative(Factor)) {
      Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
      LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
    } else {
      LeftOvers.push_back(Op0Pair.second);
    }

    for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
      auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
      // The comparisons below compare SCEV pointers directly; ScalarEvolution
      // uniques SCEV nodes, so equal pointers mean equal constants.
      // TODO: Use something smarter than equality here, e.g., gcd.
      if (Factor == OpUPair.first)
        LeftOvers.push_back(OpUPair.second);
      else if (Factor == SE.getNegativeSCEV(OpUPair.first))
        LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
      else
        return std::make_pair(ConstPart, S);
    }

    // NOTE(review): the original Add's no-wrap flags are reused for the sum
    // of the (possibly negated) left-overs; confirm they remain valid.
    auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
    return std::make_pair(Factor, NewAdd);
  }

  auto *Mul = dyn_cast<SCEVMulExpr>(S);
  if (!Mul)
    return std::make_pair(ConstPart, S);

  // Multiply all constant operands into the factor; keep the rest.
  // NOTE(review): LeftOvers would be empty if every operand were constant;
  // presumably ScalarEvolution always folds such products — confirm.
  SmallVector<const SCEV *, 4> LeftOvers;
  for (auto *Op : Mul->operands())
    if (isa<SCEVConstant>(Op))
      ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
    else
      LeftOvers.push_back(Op);

  return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
}
713 
714 const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R,
715                                  ScalarEvolution &SE, ScopDetection *SD) {
716   if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) {
717     Value *V = Unknown->getValue();
718     auto *PHI = dyn_cast<PHINode>(V);
719     if (!PHI)
720       return Expr;
721 
722     Value *Final = nullptr;
723 
724     for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
725       BasicBlock *Incoming = PHI->getIncomingBlock(i);
726       if (SD->isErrorBlock(*Incoming, R) && R.contains(Incoming))
727         continue;
728       if (Final)
729         return Expr;
730       Final = PHI->getIncomingValue(i);
731     }
732 
733     if (Final)
734       return SE.getSCEV(Final);
735   }
736   return Expr;
737 }
738 
739 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, ScopDetection *SD) {
740   Value *V = nullptr;
741   for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
742     BasicBlock *BB = PHI->getIncomingBlock(i);
743     if (!SD->isErrorBlock(*BB, *R)) {
744       if (V)
745         return nullptr;
746       V = PHI->getIncomingValue(i);
747     }
748   }
749 
750   return V;
751 }
752 } // namespace polly
753