1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the classes used to represent and build scalar expressions.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
27
28 namespace llvm {
29
30 class APInt;
31 class Constant;
32 class ConstantInt;
33 class ConstantRange;
34 class Loop;
35 class Type;
36 class Value;
37
/// Discriminator stored in every SCEV node and returned by
/// SCEV::getSCEVType(). The enumerators are meant to be listed in order of
/// increasing complexity, because the folders rely on that ordering.
/// NOTE(review): reordering these changes that complexity ordering, so append
/// or move entries only deliberately.
enum SCEVTypes : unsigned short {
  // Leaf constant.
  scConstant,
  // Integral casts.
  scTruncate,
  scZeroExtend,
  scSignExtend,
  // Arithmetic.
  scAddExpr,
  scMulExpr,
  scUDivExpr,
  // Loop recurrences.
  scAddRecExpr,
  // Commutative min/max selections.
  scUMaxExpr,
  scSMaxExpr,
  scUMinExpr,
  scSMinExpr,
  // Sequential (short-circuiting) min.
  scSequentialUMinExpr,
  // Pointer-to-integer cast.
  scPtrToInt,
  // Opaque IR value.
  scUnknown,
  // Sentinel for "analysis failed".
  scCouldNotCompute
};
58
59 /// This class represents a constant integer value.
60 class SCEVConstant : public SCEV {
61 friend class ScalarEvolution;
62
63 ConstantInt *V;
64
SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66 : SCEV(ID, scConstant, 1), V(v) {}
67
68 public:
getValue()69 ConstantInt *getValue() const { return V; }
getAPInt()70 const APInt &getAPInt() const { return getValue()->getValue(); }
71
getType()72 Type *getType() const { return V->getType(); }
73
74 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76 };
77
computeExpressionSize(ArrayRef<const SCEV * > Args)78 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
79 APInt Size(16, 1);
80 for (auto *Arg : Args)
81 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
82 return (unsigned short)Size.getZExtValue();
83 }
84
85 /// This is the base class for unary cast operator classes.
86 class SCEVCastExpr : public SCEV {
87 protected:
88 std::array<const SCEV *, 1> Operands;
89 Type *Ty;
90
91 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
92 Type *ty);
93
94 public:
getOperand()95 const SCEV *getOperand() const { return Operands[0]; }
getOperand(unsigned i)96 const SCEV *getOperand(unsigned i) const {
97 assert(i == 0 && "Operand index out of range!");
98 return Operands[0];
99 }
100 using op_iterator = std::array<const SCEV *, 1>::const_iterator;
101 using op_range = iterator_range<op_iterator>;
102
operands()103 op_range operands() const {
104 return make_range(Operands.begin(), Operands.end());
105 }
getNumOperands()106 size_t getNumOperands() const { return 1; }
getType()107 Type *getType() const { return Ty; }
108
109 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)110 static bool classof(const SCEV *S) {
111 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
112 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
113 }
114 };
115
116 /// This class represents a cast from a pointer to a pointer-sized integer
117 /// value.
118 class SCEVPtrToIntExpr : public SCEVCastExpr {
119 friend class ScalarEvolution;
120
121 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
122
123 public:
124 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)125 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
126 };
127
128 /// This is the base class for unary integral cast operator classes.
129 class SCEVIntegralCastExpr : public SCEVCastExpr {
130 protected:
131 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
132 const SCEV *op, Type *ty);
133
134 public:
135 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)136 static bool classof(const SCEV *S) {
137 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
138 S->getSCEVType() == scSignExtend;
139 }
140 };
141
142 /// This class represents a truncation of an integer value to a
143 /// smaller integer value.
144 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
145 friend class ScalarEvolution;
146
147 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
148
149 public:
150 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)151 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
152 };
153
154 /// This class represents a zero extension of a small integer value
155 /// to a larger integer value.
156 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
157 friend class ScalarEvolution;
158
159 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160
161 public:
162 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)163 static bool classof(const SCEV *S) {
164 return S->getSCEVType() == scZeroExtend;
165 }
166 };
167
168 /// This class represents a sign extension of a small integer value
169 /// to a larger integer value.
170 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
171 friend class ScalarEvolution;
172
173 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
174
175 public:
176 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)177 static bool classof(const SCEV *S) {
178 return S->getSCEVType() == scSignExtend;
179 }
180 };
181
182 /// This node is a base class providing common functionality for
183 /// n'ary operators.
184 class SCEVNAryExpr : public SCEV {
185 protected:
186 // Since SCEVs are immutable, ScalarEvolution allocates operand
187 // arrays with its SCEVAllocator, so this class just needs a simple
188 // pointer rather than a more elaborate vector-like data structure.
189 // This also avoids the need for a non-trivial destructor.
190 const SCEV *const *Operands;
191 size_t NumOperands;
192
SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)193 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
194 const SCEV *const *O, size_t N)
195 : SCEV(ID, T, computeExpressionSize(makeArrayRef(O, N))), Operands(O),
196 NumOperands(N) {}
197
198 public:
getNumOperands()199 size_t getNumOperands() const { return NumOperands; }
200
getOperand(unsigned i)201 const SCEV *getOperand(unsigned i) const {
202 assert(i < NumOperands && "Operand index out of range!");
203 return Operands[i];
204 }
205
206 using op_iterator = const SCEV *const *;
207 using op_range = iterator_range<op_iterator>;
208
op_begin()209 op_iterator op_begin() const { return Operands; }
op_end()210 op_iterator op_end() const { return Operands + NumOperands; }
operands()211 op_range operands() const { return make_range(op_begin(), op_end()); }
212
213 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
214 return (NoWrapFlags)(SubclassData & Mask);
215 }
216
hasNoUnsignedWrap()217 bool hasNoUnsignedWrap() const {
218 return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
219 }
220
hasNoSignedWrap()221 bool hasNoSignedWrap() const {
222 return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
223 }
224
hasNoSelfWrap()225 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
226
227 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)228 static bool classof(const SCEV *S) {
229 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
230 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
231 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
232 S->getSCEVType() == scSequentialUMinExpr ||
233 S->getSCEVType() == scAddRecExpr;
234 }
235 };
236
237 /// This node is the base class for n'ary commutative operators.
238 class SCEVCommutativeExpr : public SCEVNAryExpr {
239 protected:
SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)240 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
241 const SCEV *const *O, size_t N)
242 : SCEVNAryExpr(ID, T, O, N) {}
243
244 public:
245 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)246 static bool classof(const SCEV *S) {
247 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
248 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
249 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
250 }
251
252 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)253 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
254 };
255
256 /// This node represents an addition of some number of SCEVs.
257 class SCEVAddExpr : public SCEVCommutativeExpr {
258 friend class ScalarEvolution;
259
260 Type *Ty;
261
SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)262 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
263 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
264 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
265 return Op->getType()->isPointerTy();
266 });
267 if (FirstPointerTypedOp != operands().end())
268 Ty = (*FirstPointerTypedOp)->getType();
269 else
270 Ty = getOperand(0)->getType();
271 }
272
273 public:
getType()274 Type *getType() const { return Ty; }
275
276 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)277 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
278 };
279
280 /// This node represents multiplication of some number of SCEVs.
281 class SCEVMulExpr : public SCEVCommutativeExpr {
282 friend class ScalarEvolution;
283
SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)284 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
285 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
286
287 public:
getType()288 Type *getType() const { return getOperand(0)->getType(); }
289
290 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)291 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
292 };
293
294 /// This class represents a binary unsigned division operation.
295 class SCEVUDivExpr : public SCEV {
296 friend class ScalarEvolution;
297
298 std::array<const SCEV *, 2> Operands;
299
SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)300 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
301 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
302 Operands[0] = lhs;
303 Operands[1] = rhs;
304 }
305
306 public:
getLHS()307 const SCEV *getLHS() const { return Operands[0]; }
getRHS()308 const SCEV *getRHS() const { return Operands[1]; }
getNumOperands()309 size_t getNumOperands() const { return 2; }
getOperand(unsigned i)310 const SCEV *getOperand(unsigned i) const {
311 assert((i == 0 || i == 1) && "Operand index out of range!");
312 return i == 0 ? getLHS() : getRHS();
313 }
314
315 using op_iterator = std::array<const SCEV *, 2>::const_iterator;
316 using op_range = iterator_range<op_iterator>;
operands()317 op_range operands() const {
318 return make_range(Operands.begin(), Operands.end());
319 }
320
getType()321 Type *getType() const {
322 // In most cases the types of LHS and RHS will be the same, but in some
323 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
324 // depend on the type for correctness, but handling types carefully can
325 // avoid extra casts in the SCEVExpander. The LHS is more likely to be
326 // a pointer type than the RHS, so use the RHS' type here.
327 return getRHS()->getType();
328 }
329
330 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)331 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
332 };
333
334 /// This node represents a polynomial recurrence on the trip count
335 /// of the specified loop. This is the primary focus of the
336 /// ScalarEvolution framework; all the other SCEV subclasses are
337 /// mostly just supporting infrastructure to allow SCEVAddRecExpr
338 /// expressions to be created and analyzed.
339 ///
340 /// All operands of an AddRec are required to be loop invariant.
341 ///
342 class SCEVAddRecExpr : public SCEVNAryExpr {
343 friend class ScalarEvolution;
344
345 const Loop *L;
346
SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)347 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
348 const Loop *l)
349 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
350
351 public:
getType()352 Type *getType() const { return getStart()->getType(); }
getStart()353 const SCEV *getStart() const { return Operands[0]; }
getLoop()354 const Loop *getLoop() const { return L; }
355
356 /// Constructs and returns the recurrence indicating how much this
357 /// expression steps by. If this is a polynomial of degree N, it
358 /// returns a chrec of degree N-1. We cannot determine whether
359 /// the step recurrence has self-wraparound.
getStepRecurrence(ScalarEvolution & SE)360 const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
361 if (isAffine())
362 return getOperand(1);
363 return SE.getAddRecExpr(
364 SmallVector<const SCEV *, 3>(op_begin() + 1, op_end()), getLoop(),
365 FlagAnyWrap);
366 }
367
368 /// Return true if this represents an expression A + B*x where A
369 /// and B are loop invariant values.
isAffine()370 bool isAffine() const {
371 // We know that the start value is invariant. This expression is thus
372 // affine iff the step is also invariant.
373 return getNumOperands() == 2;
374 }
375
376 /// Return true if this represents an expression A + B*x + C*x^2
377 /// where A, B and C are loop invariant values. This corresponds
378 /// to an addrec of the form {L,+,M,+,N}
isQuadratic()379 bool isQuadratic() const { return getNumOperands() == 3; }
380
381 /// Set flags for a recurrence without clearing any previously set flags.
382 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
383 /// to make it easier to propagate flags.
setNoWrapFlags(NoWrapFlags Flags)384 void setNoWrapFlags(NoWrapFlags Flags) {
385 if (Flags & (FlagNUW | FlagNSW))
386 Flags = ScalarEvolution::setFlags(Flags, FlagNW);
387 SubclassData |= Flags;
388 }
389
390 /// Return the value of this chain of recurrences at the specified
391 /// iteration number.
392 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
393
394 /// Return the value of this chain of recurrences at the specified iteration
395 /// number. Takes an explicit list of operands to represent an AddRec.
396 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
397 const SCEV *It, ScalarEvolution &SE);
398
399 /// Return the number of iterations of this loop that produce
400 /// values in the specified constant range. Another way of
401 /// looking at this is that it returns the first iteration number
402 /// where the value is not in the condition, thus computing the
403 /// exit count. If the iteration count can't be computed, an
404 /// instance of SCEVCouldNotCompute is returned.
405 const SCEV *getNumIterationsInRange(const ConstantRange &Range,
406 ScalarEvolution &SE) const;
407
408 /// Return an expression representing the value of this expression
409 /// one iteration of the loop ahead.
410 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
411
412 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)413 static bool classof(const SCEV *S) {
414 return S->getSCEVType() == scAddRecExpr;
415 }
416 };
417
418 /// This node is the base class min/max selections.
419 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
420 friend class ScalarEvolution;
421
isMinMaxType(enum SCEVTypes T)422 static bool isMinMaxType(enum SCEVTypes T) {
423 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
424 T == scUMinExpr;
425 }
426
427 protected:
428 /// Note: Constructing subclasses via this constructor is allowed
SCEVMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)429 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
430 const SCEV *const *O, size_t N)
431 : SCEVCommutativeExpr(ID, T, O, N) {
432 assert(isMinMaxType(T));
433 // Min and max never overflow
434 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
435 }
436
437 public:
getType()438 Type *getType() const { return getOperand(0)->getType(); }
439
classof(const SCEV * S)440 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
441
negate(enum SCEVTypes T)442 static enum SCEVTypes negate(enum SCEVTypes T) {
443 switch (T) {
444 case scSMaxExpr:
445 return scSMinExpr;
446 case scSMinExpr:
447 return scSMaxExpr;
448 case scUMaxExpr:
449 return scUMinExpr;
450 case scUMinExpr:
451 return scUMaxExpr;
452 default:
453 llvm_unreachable("Not a min or max SCEV type!");
454 }
455 }
456 };
457
458 /// This class represents a signed maximum selection.
459 class SCEVSMaxExpr : public SCEVMinMaxExpr {
460 friend class ScalarEvolution;
461
SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)462 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
463 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
464
465 public:
466 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)467 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
468 };
469
470 /// This class represents an unsigned maximum selection.
471 class SCEVUMaxExpr : public SCEVMinMaxExpr {
472 friend class ScalarEvolution;
473
SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)474 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
475 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
476
477 public:
478 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)479 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
480 };
481
482 /// This class represents a signed minimum selection.
483 class SCEVSMinExpr : public SCEVMinMaxExpr {
484 friend class ScalarEvolution;
485
SCEVSMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)486 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
487 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
488
489 public:
490 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)491 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
492 };
493
494 /// This class represents an unsigned minimum selection.
495 class SCEVUMinExpr : public SCEVMinMaxExpr {
496 friend class ScalarEvolution;
497
SCEVUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)498 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
499 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
500
501 public:
502 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)503 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
504 };
505
506 /// This node is the base class for sequential/in-order min/max selections.
507 /// Note that their fundamental difference from SCEVMinMaxExpr's is that they
508 /// are early-returning upon reaching saturation point.
509 /// I.e. given `0 umin_seq poison`, the result will be `0`,
510 /// while the result of `0 umin poison` is `poison`.
511 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
512 friend class ScalarEvolution;
513
isSequentialMinMaxType(enum SCEVTypes T)514 static bool isSequentialMinMaxType(enum SCEVTypes T) {
515 return T == scSequentialUMinExpr;
516 }
517
518 /// Set flags for a non-recurrence without clearing previously set flags.
setNoWrapFlags(NoWrapFlags Flags)519 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
520
521 protected:
522 /// Note: Constructing subclasses via this constructor is allowed
SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)523 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
524 const SCEV *const *O, size_t N)
525 : SCEVNAryExpr(ID, T, O, N) {
526 assert(isSequentialMinMaxType(T));
527 // Min and max never overflow
528 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
529 }
530
531 public:
getType()532 Type *getType() const { return getOperand(0)->getType(); }
533
getEquivalentNonSequentialSCEVType(SCEVTypes Ty)534 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
535 assert(isSequentialMinMaxType(Ty));
536 switch (Ty) {
537 case scSequentialUMinExpr:
538 return scUMinExpr;
539 default:
540 llvm_unreachable("Not a sequential min/max type.");
541 }
542 }
543
getEquivalentNonSequentialSCEVType()544 SCEVTypes getEquivalentNonSequentialSCEVType() const {
545 return getEquivalentNonSequentialSCEVType(getSCEVType());
546 }
547
classof(const SCEV * S)548 static bool classof(const SCEV *S) {
549 return isSequentialMinMaxType(S->getSCEVType());
550 }
551 };
552
553 /// This class represents a sequential/in-order unsigned minimum selection.
554 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
555 friend class ScalarEvolution;
556
SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)557 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
558 size_t N)
559 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
560
561 public:
562 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)563 static bool classof(const SCEV *S) {
564 return S->getSCEVType() == scSequentialUMinExpr;
565 }
566 };
567
568 /// This means that we are dealing with an entirely unknown SCEV
569 /// value, and only represent it as its LLVM Value. This is the
570 /// "bottom" value for the analysis.
571 class SCEVUnknown final : public SCEV, private CallbackVH {
572 friend class ScalarEvolution;
573
574 /// The parent ScalarEvolution value. This is used to update the
575 /// parent's maps when the value associated with a SCEVUnknown is
576 /// deleted or RAUW'd.
577 ScalarEvolution *SE;
578
579 /// The next pointer in the linked list of all SCEVUnknown
580 /// instances owned by a ScalarEvolution.
581 SCEVUnknown *Next;
582
SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)583 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
584 SCEVUnknown *next)
585 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
586
587 // Implement CallbackVH.
588 void deleted() override;
589 void allUsesReplacedWith(Value *New) override;
590
591 public:
getValue()592 Value *getValue() const { return getValPtr(); }
593
594 /// @{
595 /// Test whether this is a special constant representing a type
596 /// size, alignment, or field offset in a target-independent
597 /// manner, and hasn't happened to have been folded with other
598 /// operations into something unrecognizable. This is mainly only
599 /// useful for pretty-printing and other situations where it isn't
600 /// absolutely required for these to succeed.
601 bool isSizeOf(Type *&AllocTy) const;
602 bool isAlignOf(Type *&AllocTy) const;
603 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const;
604 /// @}
605
getType()606 Type *getType() const { return getValPtr()->getType(); }
607
608 /// Methods for support type inquiry through isa, cast, and dyn_cast:
classof(const SCEV * S)609 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
610 };
611
612 /// This class defines a simple visitor class that may be used for
613 /// various SCEV analysis purposes.
614 template <typename SC, typename RetVal = void> struct SCEVVisitor {
visitSCEVVisitor615 RetVal visit(const SCEV *S) {
616 switch (S->getSCEVType()) {
617 case scConstant:
618 return ((SC *)this)->visitConstant((const SCEVConstant *)S);
619 case scPtrToInt:
620 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
621 case scTruncate:
622 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
623 case scZeroExtend:
624 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
625 case scSignExtend:
626 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
627 case scAddExpr:
628 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
629 case scMulExpr:
630 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
631 case scUDivExpr:
632 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
633 case scAddRecExpr:
634 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
635 case scSMaxExpr:
636 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
637 case scUMaxExpr:
638 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
639 case scSMinExpr:
640 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
641 case scUMinExpr:
642 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
643 case scSequentialUMinExpr:
644 return ((SC *)this)
645 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
646 case scUnknown:
647 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
648 case scCouldNotCompute:
649 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
650 }
651 llvm_unreachable("Unknown SCEV kind!");
652 }
653
visitCouldNotComputeSCEVVisitor654 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
655 llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
656 }
657 };
658
659 /// Visit all nodes in the expression tree using worklist traversal.
660 ///
661 /// Visitor implements:
662 /// // return true to follow this node.
663 /// bool follow(const SCEV *S);
664 /// // return true to terminate the search.
665 /// bool isDone();
666 template <typename SV> class SCEVTraversal {
667 SV &Visitor;
668 SmallVector<const SCEV *, 8> Worklist;
669 SmallPtrSet<const SCEV *, 8> Visited;
670
push(const SCEV * S)671 void push(const SCEV *S) {
672 if (Visited.insert(S).second && Visitor.follow(S))
673 Worklist.push_back(S);
674 }
675
676 public:
SCEVTraversal(SV & V)677 SCEVTraversal(SV &V) : Visitor(V) {}
678
visitAll(const SCEV * Root)679 void visitAll(const SCEV *Root) {
680 push(Root);
681 while (!Worklist.empty() && !Visitor.isDone()) {
682 const SCEV *S = Worklist.pop_back_val();
683
684 switch (S->getSCEVType()) {
685 case scConstant:
686 case scUnknown:
687 continue;
688 case scPtrToInt:
689 case scTruncate:
690 case scZeroExtend:
691 case scSignExtend:
692 push(cast<SCEVCastExpr>(S)->getOperand());
693 continue;
694 case scAddExpr:
695 case scMulExpr:
696 case scSMaxExpr:
697 case scUMaxExpr:
698 case scSMinExpr:
699 case scUMinExpr:
700 case scSequentialUMinExpr:
701 case scAddRecExpr:
702 for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
703 push(Op);
704 if (Visitor.isDone())
705 break;
706 }
707 continue;
708 case scUDivExpr: {
709 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
710 push(UDiv->getLHS());
711 push(UDiv->getRHS());
712 continue;
713 }
714 case scCouldNotCompute:
715 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
716 }
717 llvm_unreachable("Unknown SCEV kind!");
718 }
719 }
720 };
721
722 /// Use SCEVTraversal to visit all nodes in the given expression tree.
visitAll(const SCEV * Root,SV & Visitor)723 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
724 SCEVTraversal<SV> T(Visitor);
725 T.visitAll(Root);
726 }
727
728 /// Return true if any node in \p Root satisfies the predicate \p Pred.
729 template <typename PredTy>
SCEVExprContains(const SCEV * Root,PredTy Pred)730 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
731 struct FindClosure {
732 bool Found = false;
733 PredTy Pred;
734
735 FindClosure(PredTy Pred) : Pred(Pred) {}
736
737 bool follow(const SCEV *S) {
738 if (!Pred(S))
739 return true;
740
741 Found = true;
742 return false;
743 }
744
745 bool isDone() const { return Found; }
746 };
747
748 FindClosure FC(Pred);
749 visitAll(Root, FC);
750 return FC.Found;
751 }
752
753 /// This visitor recursively visits a SCEV expression and re-writes it.
754 /// The result from each visit is cached, so it will return the same
755 /// SCEV for the same input.
756 template <typename SC>
757 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
758 protected:
759 ScalarEvolution &SE;
760 // Memoize the result of each visit so that we only compute once for
761 // the same input SCEV. This is to avoid redundant computations when
762 // a SCEV is referenced by multiple SCEVs. Without memoization, this
763 // visit algorithm would have exponential time complexity in the worst
764 // case, causing the compiler to hang on certain tests.
765 DenseMap<const SCEV *, const SCEV *> RewriteResults;
766
767 public:
/// Bind the rewriter to the ScalarEvolution instance that builds new nodes.
SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
769
/// Rewrite \p S, memoizing the result so that shared subexpressions are
/// rewritten only once.
const SCEV *visit(const SCEV *S) {
  auto Cached = RewriteResults.find(S);
  if (Cached != RewriteResults.end())
    return Cached->second;
  const SCEV *Rewritten = SCEVVisitor<SC, const SCEV *>::visit(S);
  auto Inserted = RewriteResults.try_emplace(S, Rewritten);
  assert(Inserted.second && "Should insert a new entry");
  return Inserted.first->second;
}
779
visitConstant(const SCEVConstant * Constant)780 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
781
visitPtrToIntExpr(const SCEVPtrToIntExpr * Expr)782 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
783 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
784 return Operand == Expr->getOperand()
785 ? Expr
786 : SE.getPtrToIntExpr(Operand, Expr->getType());
787 }
788
visitTruncateExpr(const SCEVTruncateExpr * Expr)789 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
790 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
791 return Operand == Expr->getOperand()
792 ? Expr
793 : SE.getTruncateExpr(Operand, Expr->getType());
794 }
795
visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)796 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
797 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
798 return Operand == Expr->getOperand()
799 ? Expr
800 : SE.getZeroExtendExpr(Operand, Expr->getType());
801 }
802
visitSignExtendExpr(const SCEVSignExtendExpr * Expr)803 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
804 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
805 return Operand == Expr->getOperand()
806 ? Expr
807 : SE.getSignExtendExpr(Operand, Expr->getType());
808 }
809
visitAddExpr(const SCEVAddExpr * Expr)810 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
811 SmallVector<const SCEV *, 2> Operands;
812 bool Changed = false;
813 for (auto *Op : Expr->operands()) {
814 Operands.push_back(((SC *)this)->visit(Op));
815 Changed |= Op != Operands.back();
816 }
817 return !Changed ? Expr : SE.getAddExpr(Operands);
818 }
819
visitMulExpr(const SCEVMulExpr * Expr)820 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
821 SmallVector<const SCEV *, 2> Operands;
822 bool Changed = false;
823 for (auto *Op : Expr->operands()) {
824 Operands.push_back(((SC *)this)->visit(Op));
825 Changed |= Op != Operands.back();
826 }
827 return !Changed ? Expr : SE.getMulExpr(Operands);
828 }
829
visitUDivExpr(const SCEVUDivExpr * Expr)830 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
831 auto *LHS = ((SC *)this)->visit(Expr->getLHS());
832 auto *RHS = ((SC *)this)->visit(Expr->getRHS());
833 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
834 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
835 }
836
visitAddRecExpr(const SCEVAddRecExpr * Expr)837 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
838 SmallVector<const SCEV *, 2> Operands;
839 bool Changed = false;
840 for (auto *Op : Expr->operands()) {
841 Operands.push_back(((SC *)this)->visit(Op));
842 Changed |= Op != Operands.back();
843 }
844 return !Changed ? Expr
845 : SE.getAddRecExpr(Operands, Expr->getLoop(),
846 Expr->getNoWrapFlags());
847 }
848
visitSMaxExpr(const SCEVSMaxExpr * Expr)849 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
850 SmallVector<const SCEV *, 2> Operands;
851 bool Changed = false;
852 for (auto *Op : Expr->operands()) {
853 Operands.push_back(((SC *)this)->visit(Op));
854 Changed |= Op != Operands.back();
855 }
856 return !Changed ? Expr : SE.getSMaxExpr(Operands);
857 }
858
visitUMaxExpr(const SCEVUMaxExpr * Expr)859 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
860 SmallVector<const SCEV *, 2> Operands;
861 bool Changed = false;
862 for (auto *Op : Expr->operands()) {
863 Operands.push_back(((SC *)this)->visit(Op));
864 Changed |= Op != Operands.back();
865 }
866 return !Changed ? Expr : SE.getUMaxExpr(Operands);
867 }
868
visitSMinExpr(const SCEVSMinExpr * Expr)869 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
870 SmallVector<const SCEV *, 2> Operands;
871 bool Changed = false;
872 for (auto *Op : Expr->operands()) {
873 Operands.push_back(((SC *)this)->visit(Op));
874 Changed |= Op != Operands.back();
875 }
876 return !Changed ? Expr : SE.getSMinExpr(Operands);
877 }
878
visitUMinExpr(const SCEVUMinExpr * Expr)879 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
880 SmallVector<const SCEV *, 2> Operands;
881 bool Changed = false;
882 for (auto *Op : Expr->operands()) {
883 Operands.push_back(((SC *)this)->visit(Op));
884 Changed |= Op != Operands.back();
885 }
886 return !Changed ? Expr : SE.getUMinExpr(Operands);
887 }
888
visitSequentialUMinExpr(const SCEVSequentialUMinExpr * Expr)889 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
890 SmallVector<const SCEV *, 2> Operands;
891 bool Changed = false;
892 for (auto *Op : Expr->operands()) {
893 Operands.push_back(((SC *)this)->visit(Op));
894 Changed |= Op != Operands.back();
895 }
896 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
897 }
898
visitUnknown(const SCEVUnknown * Expr)899 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
900
visitCouldNotCompute(const SCEVCouldNotCompute * Expr)901 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
902 return Expr;
903 }
904 };
905
906 using ValueToValueMap = DenseMap<const Value *, Value *>;
907 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
908
909 /// The SCEVParameterRewriter takes a scalar evolution expression and updates
910 /// the SCEVUnknown components following the Map (Value -> SCEV).
911 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
912 public:
rewrite(const SCEV * Scev,ScalarEvolution & SE,ValueToSCEVMapTy & Map)913 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
914 ValueToSCEVMapTy &Map) {
915 SCEVParameterRewriter Rewriter(SE, Map);
916 return Rewriter.visit(Scev);
917 }
918
SCEVParameterRewriter(ScalarEvolution & SE,ValueToSCEVMapTy & M)919 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
920 : SCEVRewriteVisitor(SE), Map(M) {}
921
visitUnknown(const SCEVUnknown * Expr)922 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
923 auto I = Map.find(Expr->getValue());
924 if (I == Map.end())
925 return Expr;
926 return I->second;
927 }
928
929 private:
930 ValueToSCEVMapTy ⤅
931 };
932
933 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
934
935 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
936 /// the Map (Loop -> SCEV) to all AddRecExprs.
937 class SCEVLoopAddRecRewriter
938 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
939 public:
SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)940 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
941 : SCEVRewriteVisitor(SE), Map(M) {}
942
rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)943 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
944 ScalarEvolution &SE) {
945 SCEVLoopAddRecRewriter Rewriter(SE, Map);
946 return Rewriter.visit(Scev);
947 }
948
visitAddRecExpr(const SCEVAddRecExpr * Expr)949 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
950 SmallVector<const SCEV *, 2> Operands;
951 for (const SCEV *Op : Expr->operands())
952 Operands.push_back(visit(Op));
953
954 const Loop *L = Expr->getLoop();
955 if (0 == Map.count(L))
956 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
957
958 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
959 }
960
961 private:
962 LoopToScevMapT ⤅
963 };
964
965 } // end namespace llvm
966
967 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
968