1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Interfaces/InferIntRangeInterface.h"
11
12 #include "llvm/Support/Debug.h"
13
14 #define DEBUG_TYPE "int-range-analysis"
15
16 using namespace mlir;
17 using namespace mlir::arith;
18
19 /// Function that evaluates the result of doing something on arithmetic
20 /// constants and returns None on overflow.
21 using ConstArithFn =
22 function_ref<Optional<APInt>(const APInt &, const APInt &)>;
23
/// When a bound cannot be computed, the helpers below fall back to the maximally wide signed and unsigned range for the bitwidth (ConstantIntRanges::maxRange).
25
26 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
27 /// If either computation overflows, make the result unbounded.
computeBoundsBy(ConstArithFn op,const APInt & minLeft,const APInt & minRight,const APInt & maxLeft,const APInt & maxRight,bool isSigned)28 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
29 const APInt &minRight,
30 const APInt &maxLeft,
31 const APInt &maxRight, bool isSigned) {
32 Optional<APInt> maybeMin = op(minLeft, minRight);
33 Optional<APInt> maybeMax = op(maxLeft, maxRight);
34 if (maybeMin && maybeMax)
35 return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
36 return ConstantIntRanges::maxRange(minLeft.getBitWidth());
37 }
38
39 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
40 /// ignoring unbounded values. Returns the maximal range if `op` overflows.
minMaxBy(ConstArithFn op,ArrayRef<APInt> lhs,ArrayRef<APInt> rhs,bool isSigned)41 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
42 ArrayRef<APInt> rhs, bool isSigned) {
43 unsigned width = lhs[0].getBitWidth();
44 APInt min =
45 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
46 APInt max =
47 isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
48 for (const APInt &left : lhs) {
49 for (const APInt &right : rhs) {
50 Optional<APInt> maybeThisResult = op(left, right);
51 if (!maybeThisResult)
52 return ConstantIntRanges::maxRange(width);
53 APInt result = std::move(*maybeThisResult);
54 min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
55 max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
56 }
57 }
58 return ConstantIntRanges::range(min, max, isSigned);
59 }
60
61 //===----------------------------------------------------------------------===//
62 // ConstantOp
63 //===----------------------------------------------------------------------===//
64
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)65 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
66 SetIntRangeFn setResultRange) {
67 auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>();
68 if (constAttr) {
69 const APInt &value = constAttr.getValue();
70 setResultRange(getResult(), ConstantIntRanges::constant(value));
71 }
72 }
73
74 //===----------------------------------------------------------------------===//
75 // AddIOp
76 //===----------------------------------------------------------------------===//
77
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)78 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
79 SetIntRangeFn setResultRange) {
80 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
81 ConstArithFn uadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
82 bool overflowed = false;
83 APInt result = a.uadd_ov(b, overflowed);
84 return overflowed ? Optional<APInt>() : result;
85 };
86 ConstArithFn sadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
87 bool overflowed = false;
88 APInt result = a.sadd_ov(b, overflowed);
89 return overflowed ? Optional<APInt>() : result;
90 };
91
92 ConstantIntRanges urange = computeBoundsBy(
93 uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
94 ConstantIntRanges srange = computeBoundsBy(
95 sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
96 setResultRange(getResult(), urange.intersection(srange));
97 }
98
99 //===----------------------------------------------------------------------===//
100 // SubIOp
101 //===----------------------------------------------------------------------===//
102
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)103 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
104 SetIntRangeFn setResultRange) {
105 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
106
107 ConstArithFn usub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
108 bool overflowed = false;
109 APInt result = a.usub_ov(b, overflowed);
110 return overflowed ? Optional<APInt>() : result;
111 };
112 ConstArithFn ssub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
113 bool overflowed = false;
114 APInt result = a.ssub_ov(b, overflowed);
115 return overflowed ? Optional<APInt>() : result;
116 };
117 ConstantIntRanges urange = computeBoundsBy(
118 usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
119 ConstantIntRanges srange = computeBoundsBy(
120 ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
121 setResultRange(getResult(), urange.intersection(srange));
122 }
123
124 //===----------------------------------------------------------------------===//
125 // MulIOp
126 //===----------------------------------------------------------------------===//
127
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)128 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
129 SetIntRangeFn setResultRange) {
130 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
131
132 ConstArithFn umul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
133 bool overflowed = false;
134 APInt result = a.umul_ov(b, overflowed);
135 return overflowed ? Optional<APInt>() : result;
136 };
137 ConstArithFn smul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
138 bool overflowed = false;
139 APInt result = a.smul_ov(b, overflowed);
140 return overflowed ? Optional<APInt>() : result;
141 };
142
143 ConstantIntRanges urange =
144 minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
145 /*isSigned=*/false);
146 ConstantIntRanges srange =
147 minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
148 /*isSigned=*/true);
149
150 setResultRange(getResult(), urange.intersection(srange));
151 }
152
153 //===----------------------------------------------------------------------===//
154 // DivUIOp
155 //===----------------------------------------------------------------------===//
156
157 /// Fix up division results (ex. for ceiling and floor), returning an APInt
158 /// if there has been no overflow
159 using DivisionFixupFn = function_ref<Optional<APInt>(
160 const APInt &lhs, const APInt &rhs, const APInt &result)>;
161
/// Bound an unsigned division. `fixup` is applied to each corner quotient
/// (dividend bound / divisor bound). If the divisor's unsigned minimum is 0,
/// division by zero is possible and the maximal range is returned.
static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
                                         const ConstantIntRanges &rhs,
                                         DivisionFixupFn fixup) {
  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();

  // The divisor may be zero: no bound can be given.
  if (rhsMin.isZero())
    return ConstantIntRanges::maxRange(rhsMin.getBitWidth());

  auto udiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
    return fixup(a, b, a.udiv(b));
  };
  return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
                  /*isSigned=*/false);
}
178
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)179 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
180 SetIntRangeFn setResultRange) {
181 setResultRange(getResult(),
182 inferDivUIRange(argRanges[0], argRanges[1],
183 [](const APInt &lhs, const APInt &rhs,
184 const APInt &result) { return result; }));
185 }
186
187 //===----------------------------------------------------------------------===//
188 // DivSIOp
189 //===----------------------------------------------------------------------===//
190
/// Bound a signed division. `fixup` is applied to each corner quotient.
/// Bounds are only derived when the divisor range is entirely positive or
/// entirely negative. A corner that overflows (`sdiv_ov`) makes the whole
/// result the maximal range.
static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
                                         const ConstantIntRanges &rhs,
                                         DivisionFixupFn fixup) {
  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax();
  const APInt &rhsMin = rhs.smin(), &rhsMax = rhs.smax();

  // The divisor range straddles (or touches) zero: no bound can be given.
  if (!rhsMin.isStrictlyPositive() && !rhsMax.isNegative())
    return ConstantIntRanges::maxRange(rhsMin.getBitWidth());

  auto sdiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt quotient = a.sdiv_ov(b, overflowed);
    if (overflowed)
      return Optional<APInt>();
    return fixup(a, b, quotient);
  };
  return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
                  /*isSigned=*/true);
}
209
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)210 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
211 SetIntRangeFn setResultRange) {
212 setResultRange(getResult(),
213 inferDivSIRange(argRanges[0], argRanges[1],
214 [](const APInt &lhs, const APInt &rhs,
215 const APInt &result) { return result; }));
216 }
217
218 //===----------------------------------------------------------------------===//
219 // CeilDivUIOp
220 //===----------------------------------------------------------------------===//
221
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)222 void arith::CeilDivUIOp::inferResultRanges(
223 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
224 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
225
226 DivisionFixupFn ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
227 const APInt &result) -> Optional<APInt> {
228 if (!lhs.urem(rhs).isZero()) {
229 bool overflowed = false;
230 APInt corrected =
231 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
232 return overflowed ? Optional<APInt>() : corrected;
233 }
234 return result;
235 };
236 setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
237 }
238
239 //===----------------------------------------------------------------------===//
240 // CeilDivSIOp
241 //===----------------------------------------------------------------------===//
242
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)243 void arith::CeilDivSIOp::inferResultRanges(
244 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
245 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
246
247 DivisionFixupFn ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
248 const APInt &result) -> Optional<APInt> {
249 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
250 bool overflowed = false;
251 APInt corrected =
252 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
253 return overflowed ? Optional<APInt>() : corrected;
254 }
255 return result;
256 };
257 setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
258 }
259
260 //===----------------------------------------------------------------------===//
261 // FloorDivSIOp
262 //===----------------------------------------------------------------------===//
263
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)264 void arith::FloorDivSIOp::inferResultRanges(
265 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
266 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
267
268 DivisionFixupFn floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
269 const APInt &result) -> Optional<APInt> {
270 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
271 bool overflowed = false;
272 APInt corrected =
273 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
274 return overflowed ? Optional<APInt>() : corrected;
275 }
276 return result;
277 };
278 setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
279 }
280
281 //===----------------------------------------------------------------------===//
282 // RemUIOp
283 //===----------------------------------------------------------------------===//
284
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)285 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
286 SetIntRangeFn setResultRange) {
287 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
288 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
289
290 unsigned width = rhsMin.getBitWidth();
291 APInt umin = APInt::getZero(width);
292 APInt umax = APInt::getMaxValue(width);
293
294 if (!rhsMin.isZero()) {
295 umax = rhsMax - 1;
296 // Special case: sweeping out a contiguous range in N/[modulus]
297 if (rhsMin == rhsMax) {
298 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
299 if ((lhsMax - lhsMin).ult(rhsMax)) {
300 APInt minRem = lhsMin.urem(rhsMax);
301 APInt maxRem = lhsMax.urem(rhsMax);
302 if (minRem.ule(maxRem)) {
303 umin = minRem;
304 umax = maxRem;
305 }
306 }
307 }
308 }
309 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
310 }
311
312 //===----------------------------------------------------------------------===//
313 // RemSIOp
314 //===----------------------------------------------------------------------===//
315
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)316 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
317 SetIntRangeFn setResultRange) {
318 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
319 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
320 &rhsMax = rhs.smax();
321
322 unsigned width = rhsMax.getBitWidth();
323 APInt smin = APInt::getSignedMinValue(width);
324 APInt smax = APInt::getSignedMaxValue(width);
325 // No bounds if zero could be a divisor.
326 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
327 if (canBound) {
328 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
329 bool canNegativeDividend = lhsMin.isNegative();
330 bool canPositiveDividend = lhsMax.isStrictlyPositive();
331 APInt zero = APInt::getZero(maxDivisor.getBitWidth());
332 APInt maxPositiveResult = maxDivisor - 1;
333 APInt minNegativeResult = -maxPositiveResult;
334 smin = canNegativeDividend ? minNegativeResult : zero;
335 smax = canPositiveDividend ? maxPositiveResult : zero;
336 // Special case: sweeping out a contiguous range in N/[modulus].
337 if (rhsMin == rhsMax) {
338 if ((lhsMax - lhsMin).ult(maxDivisor)) {
339 APInt minRem = lhsMin.srem(maxDivisor);
340 APInt maxRem = lhsMax.srem(maxDivisor);
341 if (minRem.sle(maxRem)) {
342 smin = minRem;
343 smax = maxRem;
344 }
345 }
346 }
347 }
348 setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
349 }
350
351 //===----------------------------------------------------------------------===//
352 // AndIOp
353 //===----------------------------------------------------------------------===//
354
355 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
356 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
357 /// that both bonuds have in common. This gives us a consertive approximation
358 /// for what values can be passed to bitwise operations.
359 static std::tuple<APInt, APInt>
widenBitwiseBounds(const ConstantIntRanges & bound)360 widenBitwiseBounds(const ConstantIntRanges &bound) {
361 APInt leftVal = bound.umin(), rightVal = bound.umax();
362 unsigned bitwidth = leftVal.getBitWidth();
363 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
364 leftVal.clearLowBits(differingBits);
365 rightVal.setLowBits(differingBits);
366 return std::make_tuple(std::move(leftVal), std::move(rightVal));
367 }
368
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)369 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
370 SetIntRangeFn setResultRange) {
371 APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
372 std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
373 std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
374 auto andi = [](const APInt &a, const APInt &b) -> Optional<APInt> {
375 return a & b;
376 };
377 setResultRange(getResult(),
378 minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
379 /*isSigned=*/false));
380 }
381
382 //===----------------------------------------------------------------------===//
383 // OrIOp
384 //===----------------------------------------------------------------------===//
385
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)386 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
387 SetIntRangeFn setResultRange) {
388 APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
389 std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
390 std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
391 auto ori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
392 return a | b;
393 };
394 setResultRange(getResult(),
395 minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
396 /*isSigned=*/false));
397 }
398
399 //===----------------------------------------------------------------------===//
400 // XOrIOp
401 //===----------------------------------------------------------------------===//
402
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)403 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
404 SetIntRangeFn setResultRange) {
405 APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
406 std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
407 std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
408 auto xori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
409 return a ^ b;
410 };
411 setResultRange(getResult(),
412 minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
413 /*isSigned=*/false));
414 }
415
416 //===----------------------------------------------------------------------===//
417 // MaxSIOp
418 //===----------------------------------------------------------------------===//
419
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)420 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
421 SetIntRangeFn setResultRange) {
422 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
423
424 const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
425 const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
426 setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
427 }
428
429 //===----------------------------------------------------------------------===//
430 // MaxUIOp
431 //===----------------------------------------------------------------------===//
432
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)433 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
434 SetIntRangeFn setResultRange) {
435 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
436
437 const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
438 const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
439 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
440 }
441
442 //===----------------------------------------------------------------------===//
443 // MinSIOp
444 //===----------------------------------------------------------------------===//
445
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)446 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
447 SetIntRangeFn setResultRange) {
448 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
449
450 const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
451 const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
452 setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
453 }
454
455 //===----------------------------------------------------------------------===//
456 // MinUIOp
457 //===----------------------------------------------------------------------===//
458
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)459 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
460 SetIntRangeFn setResultRange) {
461 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
462
463 const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
464 const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
465 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
466 }
467
468 //===----------------------------------------------------------------------===//
469 // ExtUIOp
470 //===----------------------------------------------------------------------===//
471
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)472 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
473 SetIntRangeFn setResultRange) {
474 Type destType = getResult().getType();
475 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
476 APInt umin = argRanges[0].umin().zext(destWidth);
477 APInt umax = argRanges[0].umax().zext(destWidth);
478 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
479 }
480
481 //===----------------------------------------------------------------------===//
482 // ExtSIOp
483 //===----------------------------------------------------------------------===//
484
/// Sign-extend the signed bounds of `range` to the storage width of
/// `destType`. Sign extension preserves signed order, so the extended bounds
/// bracket every extended value.
static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
                                    Type destType) {
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
  return ConstantIntRanges::fromSigned(range.smin().sext(destWidth),
                                       range.smax().sext(destWidth));
}
492
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)493 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
494 SetIntRangeFn setResultRange) {
495 Type destType = getResult().getType();
496 setResultRange(getResult(), extSIRange(argRanges[0], destType));
497 }
498
499 //===----------------------------------------------------------------------===//
500 // TruncIOp
501 //===----------------------------------------------------------------------===//
502
/// Compute the range of `range` truncated to the storage width of `destType`.
///
/// Truncation keeps only the low `destWidth` bits. A bound can therefore only
/// be truncated on its own when every value between the bounds truncates
/// monotonically. If the values roll over a multiple of 2^destWidth, that view
/// becomes the maximal range for the destination width. For example,
/// truncating [256, 258] from i16 to i8 validly gives [0, 2]. But [255, 257]
/// contains values that truncate to 255 and to 0, so its unsigned view must be
/// [0, 255].
static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
                                     Type destType) {
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);

  // Unsigned view: values in [umin, umax] truncate monotonically when they all
  // share the bits at and above position `destWidth`.
  bool sameUnsignedHighBits =
      range.umin().lshr(destWidth) == range.umax().lshr(destWidth);
  APInt umin = sameUnsignedHighBits ? range.umin().trunc(destWidth)
                                    : APInt::getZero(destWidth);
  APInt umax = sameUnsignedHighBits ? range.umax().trunc(destWidth)
                                    : APInt::getMaxValue(destWidth);

  // Signed view: truncation is monotonic when smin and smax share every bit
  // from the destination's sign bit upward. It is also monotonic when both
  // bounds already fit in the destination's signed range, because truncation
  // is then the identity on every value between them.
  bool sameSignedHighBits =
      range.smin().ashr(destWidth - 1) == range.smax().ashr(destWidth - 1);
  bool fitsInDest = range.smin().isSignedIntN(destWidth) &&
                    range.smax().isSignedIntN(destWidth);
  bool signedMonotonic = sameSignedHighBits || fitsInDest;
  APInt smin = signedMonotonic ? range.smin().trunc(destWidth)
                               : APInt::getSignedMinValue(destWidth);
  APInt smax = signedMonotonic ? range.smax().trunc(destWidth)
                               : APInt::getSignedMaxValue(destWidth);
  return {umin, umax, smin, smax};
}
512
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)513 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
514 SetIntRangeFn setResultRange) {
515 Type destType = getResult().getType();
516 setResultRange(getResult(), truncIRange(argRanges[0], destType));
517 }
518
519 //===----------------------------------------------------------------------===//
520 // IndexCastOp
521 //===----------------------------------------------------------------------===//
522
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)523 void arith::IndexCastOp::inferResultRanges(
524 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
525 Type sourceType = getOperand().getType();
526 Type destType = getResult().getType();
527 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
528 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
529
530 if (srcWidth < destWidth)
531 setResultRange(getResult(), extSIRange(argRanges[0], destType));
532 else if (srcWidth > destWidth)
533 setResultRange(getResult(), truncIRange(argRanges[0], destType));
534 else
535 setResultRange(getResult(), argRanges[0]);
536 }
537
538 //===----------------------------------------------------------------------===//
539 // CmpIOp
540 //===----------------------------------------------------------------------===//
541
isStaticallyTrue(arith::CmpIPredicate pred,const ConstantIntRanges & lhs,const ConstantIntRanges & rhs)542 bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
543 const ConstantIntRanges &rhs) {
544 switch (pred) {
545 case arith::CmpIPredicate::sle:
546 case arith::CmpIPredicate::slt:
547 return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
548 case arith::CmpIPredicate::ule:
549 case arith::CmpIPredicate::ult:
550 return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
551 case arith::CmpIPredicate::sge:
552 case arith::CmpIPredicate::sgt:
553 return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
554 case arith::CmpIPredicate::uge:
555 case arith::CmpIPredicate::ugt:
556 return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
557 case arith::CmpIPredicate::eq: {
558 Optional<APInt> lhsConst = lhs.getConstantValue();
559 Optional<APInt> rhsConst = rhs.getConstantValue();
560 return lhsConst && rhsConst && lhsConst == rhsConst;
561 }
562 case arith::CmpIPredicate::ne: {
563 // While equality requires that there is an interpration of the preceeding
564 // computations that produces equal constants, whether that be signed or
565 // unsigned, statically determining inequality requires that neither
566 // interpretation produce potentially overlapping ranges.
567 bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
568 isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
569 bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
570 isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
571 return sne && une;
572 }
573 }
574 return false;
575 }
576
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)577 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
578 SetIntRangeFn setResultRange) {
579 arith::CmpIPredicate pred = getPredicate();
580 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
581
582 APInt min = APInt::getZero(1);
583 APInt max = APInt::getAllOnesValue(1);
584 if (isStaticallyTrue(pred, lhs, rhs))
585 min = max;
586 else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
587 max = min;
588
589 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
590 }
591
592 //===----------------------------------------------------------------------===//
593 // SelectOp
594 //===----------------------------------------------------------------------===//
595
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)596 void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
597 SetIntRangeFn setResultRange) {
598 Optional<APInt> mbCondVal = argRanges[0].getConstantValue();
599
600 if (mbCondVal) {
601 if (mbCondVal->isZero())
602 setResultRange(getResult(), argRanges[2]);
603 else
604 setResultRange(getResult(), argRanges[1]);
605 return;
606 }
607 setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
608 }
609
610 //===----------------------------------------------------------------------===//
611 // ShLIOp
612 //===----------------------------------------------------------------------===//
613
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)614 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
615 SetIntRangeFn setResultRange) {
616 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
617 ConstArithFn shl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
618 return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.shl(r);
619 };
620 ConstantIntRanges urange =
621 minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
622 /*isSigned=*/false);
623 ConstantIntRanges srange =
624 minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
625 /*isSigned=*/true);
626 setResultRange(getResult(), urange.intersection(srange));
627 }
628
629 //===----------------------------------------------------------------------===//
630 // ShRUIOp
631 //===----------------------------------------------------------------------===//
632
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)633 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
634 SetIntRangeFn setResultRange) {
635 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
636
637 ConstArithFn lshr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
638 return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.lshr(r);
639 };
640 setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
641 {rhs.umin(), rhs.umax()},
642 /*isSigned=*/false));
643 }
644
645 //===----------------------------------------------------------------------===//
646 // ShRSIOp
647 //===----------------------------------------------------------------------===//
648
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRange)649 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
650 SetIntRangeFn setResultRange) {
651 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
652
653 ConstArithFn ashr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
654 return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.ashr(r);
655 };
656
657 setResultRange(getResult(),
658 minMaxBy(ashr, {lhs.smin(), lhs.smax()},
659 {rhs.umin(), rhs.umax()}, /*isSigned=*/true));
660 }
661