1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Interfaces/InferIntRangeInterface.h"
11 
12 #include "llvm/Support/Debug.h"
13 
14 #define DEBUG_TYPE "int-range-analysis"
15 
16 using namespace mlir;
17 using namespace mlir::arith;
18 
19 /// Function that evaluates the result of doing something on arithmetic
20 /// constants and returns None on overflow.
21 using ConstArithFn =
22     function_ref<Optional<APInt>(const APInt &, const APInt &)>;
23 
/// Return the maximally wide signed or unsigned range for a given bitwidth.
25 
26 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
27 /// If either computation overflows, make the result unbounded.
computeBoundsBy(ConstArithFn op,const APInt & minLeft,const APInt & minRight,const APInt & maxLeft,const APInt & maxRight,bool isSigned)28 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
29                                          const APInt &minRight,
30                                          const APInt &maxLeft,
31                                          const APInt &maxRight, bool isSigned) {
32   Optional<APInt> maybeMin = op(minLeft, minRight);
33   Optional<APInt> maybeMax = op(maxLeft, maxRight);
34   if (maybeMin && maybeMax)
35     return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
36   return ConstantIntRanges::maxRange(minLeft.getBitWidth());
37 }
38 
39 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
40 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
minMaxBy(ConstArithFn op,ArrayRef<APInt> lhs,ArrayRef<APInt> rhs,bool isSigned)41 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
42                                   ArrayRef<APInt> rhs, bool isSigned) {
43   unsigned width = lhs[0].getBitWidth();
44   APInt min =
45       isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
46   APInt max =
47       isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
48   for (const APInt &left : lhs) {
49     for (const APInt &right : rhs) {
50       Optional<APInt> maybeThisResult = op(left, right);
51       if (!maybeThisResult)
52         return ConstantIntRanges::maxRange(width);
53       APInt result = std::move(*maybeThisResult);
54       min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
55       max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
56     }
57   }
58   return ConstantIntRanges::range(min, max, isSigned);
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // ConstantOp
63 //===----------------------------------------------------------------------===//
64 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)65 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
66                                           SetIntRangeFn setResultRange) {
67   auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>();
68   if (constAttr) {
69     const APInt &value = constAttr.getValue();
70     setResultRange(getResult(), ConstantIntRanges::constant(value));
71   }
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // AddIOp
76 //===----------------------------------------------------------------------===//
77 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)78 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
79                                       SetIntRangeFn setResultRange) {
80   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
81   ConstArithFn uadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
82     bool overflowed = false;
83     APInt result = a.uadd_ov(b, overflowed);
84     return overflowed ? Optional<APInt>() : result;
85   };
86   ConstArithFn sadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
87     bool overflowed = false;
88     APInt result = a.sadd_ov(b, overflowed);
89     return overflowed ? Optional<APInt>() : result;
90   };
91 
92   ConstantIntRanges urange = computeBoundsBy(
93       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
94   ConstantIntRanges srange = computeBoundsBy(
95       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
96   setResultRange(getResult(), urange.intersection(srange));
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // SubIOp
101 //===----------------------------------------------------------------------===//
102 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)103 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
104                                       SetIntRangeFn setResultRange) {
105   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
106 
107   ConstArithFn usub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
108     bool overflowed = false;
109     APInt result = a.usub_ov(b, overflowed);
110     return overflowed ? Optional<APInt>() : result;
111   };
112   ConstArithFn ssub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
113     bool overflowed = false;
114     APInt result = a.ssub_ov(b, overflowed);
115     return overflowed ? Optional<APInt>() : result;
116   };
117   ConstantIntRanges urange = computeBoundsBy(
118       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
119   ConstantIntRanges srange = computeBoundsBy(
120       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
121   setResultRange(getResult(), urange.intersection(srange));
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // MulIOp
126 //===----------------------------------------------------------------------===//
127 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)128 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129                                       SetIntRangeFn setResultRange) {
130   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
131 
132   ConstArithFn umul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
133     bool overflowed = false;
134     APInt result = a.umul_ov(b, overflowed);
135     return overflowed ? Optional<APInt>() : result;
136   };
137   ConstArithFn smul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
138     bool overflowed = false;
139     APInt result = a.smul_ov(b, overflowed);
140     return overflowed ? Optional<APInt>() : result;
141   };
142 
143   ConstantIntRanges urange =
144       minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
145                /*isSigned=*/false);
146   ConstantIntRanges srange =
147       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
148                /*isSigned=*/true);
149 
150   setResultRange(getResult(), urange.intersection(srange));
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // DivUIOp
155 //===----------------------------------------------------------------------===//
156 
157 /// Fix up division results (ex. for ceiling and floor), returning an APInt
158 /// if there has been no overflow
159 using DivisionFixupFn = function_ref<Optional<APInt>(
160     const APInt &lhs, const APInt &rhs, const APInt &result)>;
161 
/// Infer the range of an unsigned division `lhs / rhs` whose truncated
/// quotient is post-processed by `fixup`.
///
/// The quotient is evaluated at every combination of the operands' unsigned
/// bounds. The result is the maximal range if the divisor range includes
/// zero, or if `fixup` returns None.
static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
                                         const ConstantIntRanges &rhs,
                                         DivisionFixupFn fixup) {
  const APInt &divisorMin = rhs.umin();
  // A divisor range starting at zero admits division by zero.
  if (divisorMin.isZero())
    return ConstantIntRanges::maxRange(divisorMin.getBitWidth());

  auto udivFixed = [&fixup](const APInt &a,
                            const APInt &b) -> Optional<APInt> {
    return fixup(a, b, a.udiv(b));
  };
  return minMaxBy(udivFixed, {lhs.umin(), lhs.umax()},
                  {divisorMin, rhs.umax()}, /*isSigned=*/false);
}
178 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)179 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
180                                        SetIntRangeFn setResultRange) {
181   setResultRange(getResult(),
182                  inferDivUIRange(argRanges[0], argRanges[1],
183                                  [](const APInt &lhs, const APInt &rhs,
184                                     const APInt &result) { return result; }));
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // DivSIOp
189 //===----------------------------------------------------------------------===//
190 
/// Infer the range of a signed division `lhs / rhs` whose truncated quotient
/// is post-processed by `fixup`.
///
/// The divisor's signed range must lie strictly on one side of zero. The
/// result is the maximal range otherwise, or if any corner quotient overflows
/// (`sdiv_ov`), or if `fixup` returns None.
static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
                                         const ConstantIntRanges &rhs,
                                         DivisionFixupFn fixup) {
  const APInt &divisorMin = rhs.smin(), &divisorMax = rhs.smax();
  bool divisorExcludesZero =
      divisorMin.isStrictlyPositive() || divisorMax.isNegative();
  if (!divisorExcludesZero)
    return ConstantIntRanges::maxRange(divisorMin.getBitWidth());

  auto sdivFixed = [&fixup](const APInt &a,
                            const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt quotient = a.sdiv_ov(b, overflowed);
    if (overflowed)
      return Optional<APInt>();
    return fixup(a, b, quotient);
  };
  return minMaxBy(sdivFixed, {lhs.smin(), lhs.smax()},
                  {divisorMin, divisorMax}, /*isSigned=*/true);
}
209 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)210 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
211                                        SetIntRangeFn setResultRange) {
212   setResultRange(getResult(),
213                  inferDivSIRange(argRanges[0], argRanges[1],
214                                  [](const APInt &lhs, const APInt &rhs,
215                                     const APInt &result) { return result; }));
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // CeilDivUIOp
220 //===----------------------------------------------------------------------===//
221 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)222 void arith::CeilDivUIOp::inferResultRanges(
223     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
224   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
225 
226   DivisionFixupFn ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
227                                     const APInt &result) -> Optional<APInt> {
228     if (!lhs.urem(rhs).isZero()) {
229       bool overflowed = false;
230       APInt corrected =
231           result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
232       return overflowed ? Optional<APInt>() : corrected;
233     }
234     return result;
235   };
236   setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // CeilDivSIOp
241 //===----------------------------------------------------------------------===//
242 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)243 void arith::CeilDivSIOp::inferResultRanges(
244     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
245   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
246 
247   DivisionFixupFn ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
248                                     const APInt &result) -> Optional<APInt> {
249     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
250       bool overflowed = false;
251       APInt corrected =
252           result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
253       return overflowed ? Optional<APInt>() : corrected;
254     }
255     return result;
256   };
257   setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // FloorDivSIOp
262 //===----------------------------------------------------------------------===//
263 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)264 void arith::FloorDivSIOp::inferResultRanges(
265     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
266   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
267 
268   DivisionFixupFn floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
269                                      const APInt &result) -> Optional<APInt> {
270     if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
271       bool overflowed = false;
272       APInt corrected =
273           result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
274       return overflowed ? Optional<APInt>() : corrected;
275     }
276     return result;
277   };
278   setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // RemUIOp
283 //===----------------------------------------------------------------------===//
284 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)285 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
286                                        SetIntRangeFn setResultRange) {
287   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
288   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
289 
290   unsigned width = rhsMin.getBitWidth();
291   APInt umin = APInt::getZero(width);
292   APInt umax = APInt::getMaxValue(width);
293 
294   if (!rhsMin.isZero()) {
295     umax = rhsMax - 1;
296     // Special case: sweeping out a contiguous range in N/[modulus]
297     if (rhsMin == rhsMax) {
298       const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
299       if ((lhsMax - lhsMin).ult(rhsMax)) {
300         APInt minRem = lhsMin.urem(rhsMax);
301         APInt maxRem = lhsMax.urem(rhsMax);
302         if (minRem.ule(maxRem)) {
303           umin = minRem;
304           umax = maxRem;
305         }
306       }
307     }
308   }
309   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // RemSIOp
314 //===----------------------------------------------------------------------===//
315 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)316 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
317                                        SetIntRangeFn setResultRange) {
318   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
319   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
320               &rhsMax = rhs.smax();
321 
322   unsigned width = rhsMax.getBitWidth();
323   APInt smin = APInt::getSignedMinValue(width);
324   APInt smax = APInt::getSignedMaxValue(width);
325   // No bounds if zero could be a divisor.
326   bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
327   if (canBound) {
328     APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
329     bool canNegativeDividend = lhsMin.isNegative();
330     bool canPositiveDividend = lhsMax.isStrictlyPositive();
331     APInt zero = APInt::getZero(maxDivisor.getBitWidth());
332     APInt maxPositiveResult = maxDivisor - 1;
333     APInt minNegativeResult = -maxPositiveResult;
334     smin = canNegativeDividend ? minNegativeResult : zero;
335     smax = canPositiveDividend ? maxPositiveResult : zero;
336     // Special case: sweeping out a contiguous range in N/[modulus].
337     if (rhsMin == rhsMax) {
338       if ((lhsMax - lhsMin).ult(maxDivisor)) {
339         APInt minRem = lhsMin.srem(maxDivisor);
340         APInt maxRem = lhsMax.srem(maxDivisor);
341         if (minRem.sle(maxRem)) {
342           smin = minRem;
343           smax = maxRem;
344         }
345       }
346     }
347   }
348   setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // AndIOp
353 //===----------------------------------------------------------------------===//
354 
355 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
356 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
357 /// that both bonuds have in common. This gives us a consertive approximation
358 /// for what values can be passed to bitwise operations.
359 static std::tuple<APInt, APInt>
widenBitwiseBounds(const ConstantIntRanges & bound)360 widenBitwiseBounds(const ConstantIntRanges &bound) {
361   APInt leftVal = bound.umin(), rightVal = bound.umax();
362   unsigned bitwidth = leftVal.getBitWidth();
363   unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
364   leftVal.clearLowBits(differingBits);
365   rightVal.setLowBits(differingBits);
366   return std::make_tuple(std::move(leftVal), std::move(rightVal));
367 }
368 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)369 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
370                                       SetIntRangeFn setResultRange) {
371   APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
372   std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
373   std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
374   auto andi = [](const APInt &a, const APInt &b) -> Optional<APInt> {
375     return a & b;
376   };
377   setResultRange(getResult(),
378                  minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
379                           /*isSigned=*/false));
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // OrIOp
384 //===----------------------------------------------------------------------===//
385 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)386 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
387                                      SetIntRangeFn setResultRange) {
388   APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
389   std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
390   std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
391   auto ori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
392     return a | b;
393   };
394   setResultRange(getResult(),
395                  minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
396                           /*isSigned=*/false));
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // XOrIOp
401 //===----------------------------------------------------------------------===//
402 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)403 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
404                                       SetIntRangeFn setResultRange) {
405   APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
406   std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
407   std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
408   auto xori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
409     return a ^ b;
410   };
411   setResultRange(getResult(),
412                  minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
413                           /*isSigned=*/false));
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // MaxSIOp
418 //===----------------------------------------------------------------------===//
419 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)420 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
421                                        SetIntRangeFn setResultRange) {
422   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
423 
424   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
425   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
426   setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // MaxUIOp
431 //===----------------------------------------------------------------------===//
432 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)433 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
434                                        SetIntRangeFn setResultRange) {
435   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
436 
437   const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
438   const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
439   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // MinSIOp
444 //===----------------------------------------------------------------------===//
445 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)446 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
447                                        SetIntRangeFn setResultRange) {
448   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
449 
450   const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
451   const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
452   setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // MinUIOp
457 //===----------------------------------------------------------------------===//
458 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)459 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
460                                        SetIntRangeFn setResultRange) {
461   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
462 
463   const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
464   const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
465   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
466 }
467 
468 //===----------------------------------------------------------------------===//
469 // ExtUIOp
470 //===----------------------------------------------------------------------===//
471 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)472 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
473                                        SetIntRangeFn setResultRange) {
474   Type destType = getResult().getType();
475   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
476   APInt umin = argRanges[0].umin().zext(destWidth);
477   APInt umax = argRanges[0].umax().zext(destWidth);
478   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // ExtSIOp
483 //===----------------------------------------------------------------------===//
484 
/// Sign-extend the signed bounds of `range` to the storage width of
/// `destType`. Sign extension preserves signed order.
static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
                                    Type destType) {
  unsigned width = ConstantIntRanges::getStorageBitwidth(destType);
  return ConstantIntRanges::fromSigned(range.smin().sext(width),
                                       range.smax().sext(width));
}
492 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)493 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
494                                        SetIntRangeFn setResultRange) {
495   Type destType = getResult().getType();
496   setResultRange(getResult(), extSIRange(argRanges[0], destType));
497 }
498 
499 //===----------------------------------------------------------------------===//
500 // TruncIOp
501 //===----------------------------------------------------------------------===//
502 
/// Compute the range of `trunc(x)` for `x` in `range`, truncating to the
/// storage width of `destType`.
///
/// Truncation is only order-preserving across values that agree on the bits
/// being discarded. If the input range straddles such a boundary, the
/// truncated values wrap around. In that case the affected bounds (unsigned
/// and/or signed) are widened to the destination type's full range. Simply
/// truncating the bounds there is unsound: i16 [0, 300] would give i8
/// umax 44, although 255 is reachable.
static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
                                     Type destType) {
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);

  // Unsigned: the bits at and above `destWidth` must match at both ends.
  // E.g. [256, 258] (i16) truncates to [0, 2] (i8), but [255, 257] wraps.
  bool unsignedWraps =
      range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
  APInt umin = unsignedWraps ? APInt::getZero(destWidth)
                             : range.umin().trunc(destWidth);
  APInt umax = unsignedWraps ? APInt::getMaxValue(destWidth)
                             : range.umax().trunc(destWidth);

  // Signed: the bits from the destination's sign bit upwards must match at
  // both ends, or the whole range must already fit the destination type
  // (high part all ones at the low end, all zeros at the high end).
  APInt sminHigh = range.smin().ashr(destWidth - 1);
  APInt smaxHigh = range.smax().ashr(destWidth - 1);
  bool fitsInDest = sminHigh.isAllOnes() && smaxHigh.isZero();
  bool signedWraps = sminHigh != smaxHigh && !fitsInDest;
  APInt smin = signedWraps ? APInt::getSignedMinValue(destWidth)
                           : range.smin().trunc(destWidth);
  APInt smax = signedWraps ? APInt::getSignedMaxValue(destWidth)
                           : range.smax().trunc(destWidth);
  return {umin, umax, smin, smax};
}
512 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)513 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
514                                         SetIntRangeFn setResultRange) {
515   Type destType = getResult().getType();
516   setResultRange(getResult(), truncIRange(argRanges[0], destType));
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // IndexCastOp
521 //===----------------------------------------------------------------------===//
522 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)523 void arith::IndexCastOp::inferResultRanges(
524     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
525   Type sourceType = getOperand().getType();
526   Type destType = getResult().getType();
527   unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
528   unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
529 
530   if (srcWidth < destWidth)
531     setResultRange(getResult(), extSIRange(argRanges[0], destType));
532   else if (srcWidth > destWidth)
533     setResultRange(getResult(), truncIRange(argRanges[0], destType));
534   else
535     setResultRange(getResult(), argRanges[0]);
536 }
537 
538 //===----------------------------------------------------------------------===//
539 // CmpIOp
540 //===----------------------------------------------------------------------===//
541 
isStaticallyTrue(arith::CmpIPredicate pred,const ConstantIntRanges & lhs,const ConstantIntRanges & rhs)542 bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
543                       const ConstantIntRanges &rhs) {
544   switch (pred) {
545   case arith::CmpIPredicate::sle:
546   case arith::CmpIPredicate::slt:
547     return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
548   case arith::CmpIPredicate::ule:
549   case arith::CmpIPredicate::ult:
550     return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
551   case arith::CmpIPredicate::sge:
552   case arith::CmpIPredicate::sgt:
553     return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
554   case arith::CmpIPredicate::uge:
555   case arith::CmpIPredicate::ugt:
556     return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
557   case arith::CmpIPredicate::eq: {
558     Optional<APInt> lhsConst = lhs.getConstantValue();
559     Optional<APInt> rhsConst = rhs.getConstantValue();
560     return lhsConst && rhsConst && lhsConst == rhsConst;
561   }
562   case arith::CmpIPredicate::ne: {
563     // While equality requires that there is an interpration of the preceeding
564     // computations that produces equal constants, whether that be signed or
565     // unsigned, statically determining inequality requires that neither
566     // interpretation produce potentially overlapping ranges.
567     bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
568                isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
569     bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
570                isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
571     return sne && une;
572   }
573   }
574   return false;
575 }
576 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)577 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
578                                       SetIntRangeFn setResultRange) {
579   arith::CmpIPredicate pred = getPredicate();
580   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
581 
582   APInt min = APInt::getZero(1);
583   APInt max = APInt::getAllOnesValue(1);
584   if (isStaticallyTrue(pred, lhs, rhs))
585     min = max;
586   else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
587     max = min;
588 
589   setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // SelectOp
594 //===----------------------------------------------------------------------===//
595 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)596 void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
597                                         SetIntRangeFn setResultRange) {
598   Optional<APInt> mbCondVal = argRanges[0].getConstantValue();
599 
600   if (mbCondVal) {
601     if (mbCondVal->isZero())
602       setResultRange(getResult(), argRanges[2]);
603     else
604       setResultRange(getResult(), argRanges[1]);
605     return;
606   }
607   setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // ShLIOp
612 //===----------------------------------------------------------------------===//
613 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)614 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
615                                       SetIntRangeFn setResultRange) {
616   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
617   ConstArithFn shl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
618     return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.shl(r);
619   };
620   ConstantIntRanges urange =
621       minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
622                /*isSigned=*/false);
623   ConstantIntRanges srange =
624       minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
625                /*isSigned=*/true);
626   setResultRange(getResult(), urange.intersection(srange));
627 }
628 
629 //===----------------------------------------------------------------------===//
630 // ShRUIOp
631 //===----------------------------------------------------------------------===//
632 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)633 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
634                                        SetIntRangeFn setResultRange) {
635   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
636 
637   ConstArithFn lshr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
638     return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.lshr(r);
639   };
640   setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
641                                        {rhs.umin(), rhs.umax()},
642                                        /*isSigned=*/false));
643 }
644 
645 //===----------------------------------------------------------------------===//
646 // ShRSIOp
647 //===----------------------------------------------------------------------===//
648 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)649 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
650                                        SetIntRangeFn setResultRange) {
651   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
652 
653   ConstArithFn ashr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
654     return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.ashr(r);
655   };
656 
657   setResultRange(getResult(),
658                  minMaxBy(ashr, {lhs.smin(), lhs.smax()},
659                           {rhs.umin(), rhs.umax()}, /*isSigned=*/true));
660 }
661