1 //===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Interfaces/InferIntRangeInterface.h" 11 12 #include "llvm/Support/Debug.h" 13 14 #define DEBUG_TYPE "int-range-analysis" 15 16 using namespace mlir; 17 using namespace mlir::arith; 18 19 /// Function that evaluates the result of doing something on arithmetic 20 /// constants and returns None on overflow. 21 using ConstArithFn = 22 function_ref<Optional<APInt>(const APInt &, const APInt &)>; 23 24 /// Return the maxmially wide signed or unsigned range for a given bitwidth. 25 26 /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, 27 /// If either computation overflows, make the result unbounded. 28 static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, 29 const APInt &minRight, 30 const APInt &maxLeft, 31 const APInt &maxRight, bool isSigned) { 32 Optional<APInt> maybeMin = op(minLeft, minRight); 33 Optional<APInt> maybeMax = op(maxLeft, maxRight); 34 if (maybeMin && maybeMax) 35 return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); 36 return ConstantIntRanges::maxRange(minLeft.getBitWidth()); 37 } 38 39 /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, 40 /// ignoring unbounded values. Returns the maximal range if `op` overflows. 41 static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs, 42 ArrayRef<APInt> rhs, bool isSigned) { 43 unsigned width = lhs[0].getBitWidth(); 44 APInt min = 45 isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); 46 APInt max = 47 isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); 48 for (const APInt &left : lhs) { 49 for (const APInt &right : rhs) { 50 Optional<APInt> maybeThisResult = op(left, right); 51 if (!maybeThisResult) 52 return ConstantIntRanges::maxRange(width); 53 APInt result = std::move(*maybeThisResult); 54 min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; 55 max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; 56 } 57 } 58 return ConstantIntRanges::range(min, max, isSigned); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // ConstantOp 63 //===----------------------------------------------------------------------===// 64 65 void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 66 SetIntRangeFn setResultRange) { 67 auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>(); 68 if (constAttr) { 69 const APInt &value = constAttr.getValue(); 70 setResultRange(getResult(), ConstantIntRanges::constant(value)); 71 } 72 } 73 74 //===----------------------------------------------------------------------===// 75 // AddIOp 76 //===----------------------------------------------------------------------===// 77 78 void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 79 SetIntRangeFn setResultRange) { 80 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 81 ConstArithFn uadd = [](const APInt &a, const APInt &b) -> Optional<APInt> { 82 bool overflowed = false; 83 APInt result = a.uadd_ov(b, overflowed); 84 return overflowed ? Optional<APInt>() : result; 85 }; 86 ConstArithFn sadd = [](const APInt &a, const APInt &b) -> Optional<APInt> { 87 bool overflowed = false; 88 APInt result = a.sadd_ov(b, overflowed); 89 return overflowed ? Optional<APInt>() : result; 90 }; 91 92 ConstantIntRanges urange = computeBoundsBy( 93 uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); 94 ConstantIntRanges srange = computeBoundsBy( 95 sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); 96 setResultRange(getResult(), urange.intersection(srange)); 97 } 98 99 //===----------------------------------------------------------------------===// 100 // SubIOp 101 //===----------------------------------------------------------------------===// 102 103 void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 104 SetIntRangeFn setResultRange) { 105 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 106 107 ConstArithFn usub = [](const APInt &a, const APInt &b) -> Optional<APInt> { 108 bool overflowed = false; 109 APInt result = a.usub_ov(b, overflowed); 110 return overflowed ? Optional<APInt>() : result; 111 }; 112 ConstArithFn ssub = [](const APInt &a, const APInt &b) -> Optional<APInt> { 113 bool overflowed = false; 114 APInt result = a.ssub_ov(b, overflowed); 115 return overflowed ? Optional<APInt>() : result; 116 }; 117 ConstantIntRanges urange = computeBoundsBy( 118 usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); 119 ConstantIntRanges srange = computeBoundsBy( 120 ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); 121 setResultRange(getResult(), urange.intersection(srange)); 122 } 123 124 //===----------------------------------------------------------------------===// 125 // MulIOp 126 //===----------------------------------------------------------------------===// 127 128 void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 129 SetIntRangeFn setResultRange) { 130 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 131 132 ConstArithFn umul = [](const APInt &a, const APInt &b) -> Optional<APInt> { 133 bool overflowed = false; 134 APInt result = a.umul_ov(b, overflowed); 135 return overflowed ? Optional<APInt>() : result; 136 }; 137 ConstArithFn smul = [](const APInt &a, const APInt &b) -> Optional<APInt> { 138 bool overflowed = false; 139 APInt result = a.smul_ov(b, overflowed); 140 return overflowed ? Optional<APInt>() : result; 141 }; 142 143 ConstantIntRanges urange = 144 minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 145 /*isSigned=*/false); 146 ConstantIntRanges srange = 147 minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, 148 /*isSigned=*/true); 149 150 setResultRange(getResult(), urange.intersection(srange)); 151 } 152 153 //===----------------------------------------------------------------------===// 154 // DivUIOp 155 //===----------------------------------------------------------------------===// 156 157 /// Fix up division results (ex. for ceiling and floor), returning an APInt 158 /// if there has been no overflow 159 using DivisionFixupFn = function_ref<Optional<APInt>( 160 const APInt &lhs, const APInt &rhs, const APInt &result)>; 161 162 static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, 163 const ConstantIntRanges &rhs, 164 DivisionFixupFn fixup) { 165 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), 166 &rhsMax = rhs.umax(); 167 168 if (!rhsMin.isZero()) { 169 auto udiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> { 170 return fixup(a, b, a.udiv(b)); 171 }; 172 return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 173 /*isSigned=*/false); 174 } 175 // Otherwise, it's possible we might divide by 0. 176 return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); 177 } 178 179 void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 180 SetIntRangeFn setResultRange) { 181 setResultRange(getResult(), 182 inferDivUIRange(argRanges[0], argRanges[1], 183 [](const APInt &lhs, const APInt &rhs, 184 const APInt &result) { return result; })); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // DivSIOp 189 //===----------------------------------------------------------------------===// 190 191 static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, 192 const ConstantIntRanges &rhs, 193 DivisionFixupFn fixup) { 194 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 195 &rhsMax = rhs.smax(); 196 bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); 197 198 if (canDivide) { 199 auto sdiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> { 200 bool overflowed = false; 201 APInt result = a.sdiv_ov(b, overflowed); 202 return overflowed ? Optional<APInt>() : fixup(a, b, result); 203 }; 204 return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, 205 /*isSigned=*/true); 206 } 207 return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); 208 } 209 210 void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 211 SetIntRangeFn setResultRange) { 212 setResultRange(getResult(), 213 inferDivSIRange(argRanges[0], argRanges[1], 214 [](const APInt &lhs, const APInt &rhs, 215 const APInt &result) { return result; })); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // CeilDivUIOp 220 //===----------------------------------------------------------------------===// 221 222 void arith::CeilDivUIOp::inferResultRanges( 223 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 224 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 225 226 DivisionFixupFn ceilDivUIFix = [](const APInt &lhs, const APInt &rhs, 227 const APInt &result) -> Optional<APInt> { 228 if (!lhs.urem(rhs).isZero()) { 229 bool overflowed = false; 230 APInt corrected = 231 result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); 232 return overflowed ? Optional<APInt>() : corrected; 233 } 234 return result; 235 }; 236 setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // CeilDivSIOp 241 //===----------------------------------------------------------------------===// 242 243 void arith::CeilDivSIOp::inferResultRanges( 244 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 245 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 246 247 DivisionFixupFn ceilDivSIFix = [](const APInt &lhs, const APInt &rhs, 248 const APInt &result) -> Optional<APInt> { 249 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { 250 bool overflowed = false; 251 APInt corrected = 252 result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); 253 return overflowed ? Optional<APInt>() : corrected; 254 } 255 return result; 256 }; 257 setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // FloorDivSIOp 262 //===----------------------------------------------------------------------===// 263 264 void arith::FloorDivSIOp::inferResultRanges( 265 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 266 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 267 268 DivisionFixupFn floorDivSIFix = [](const APInt &lhs, const APInt &rhs, 269 const APInt &result) -> Optional<APInt> { 270 if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { 271 bool overflowed = false; 272 APInt corrected = 273 result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); 274 return overflowed ? Optional<APInt>() : corrected; 275 } 276 return result; 277 }; 278 setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); 279 } 280 281 //===----------------------------------------------------------------------===// 282 // RemUIOp 283 //===----------------------------------------------------------------------===// 284 285 void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 286 SetIntRangeFn setResultRange) { 287 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 288 const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); 289 290 unsigned width = rhsMin.getBitWidth(); 291 APInt umin = APInt::getZero(width); 292 APInt umax = APInt::getMaxValue(width); 293 294 if (!rhsMin.isZero()) { 295 umax = rhsMax - 1; 296 // Special case: sweeping out a contiguous range in N/[modulus] 297 if (rhsMin == rhsMax) { 298 const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); 299 if ((lhsMax - lhsMin).ult(rhsMax)) { 300 APInt minRem = lhsMin.urem(rhsMax); 301 APInt maxRem = lhsMax.urem(rhsMax); 302 if (minRem.ule(maxRem)) { 303 umin = minRem; 304 umax = maxRem; 305 } 306 } 307 } 308 } 309 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); 310 } 311 312 //===----------------------------------------------------------------------===// 313 // RemSIOp 314 //===----------------------------------------------------------------------===// 315 316 void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 317 SetIntRangeFn setResultRange) { 318 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 319 const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), 320 &rhsMax = rhs.smax(); 321 322 unsigned width = rhsMax.getBitWidth(); 323 APInt smin = APInt::getSignedMinValue(width); 324 APInt smax = APInt::getSignedMaxValue(width); 325 // No bounds if zero could be a divisor. 326 bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); 327 if (canBound) { 328 APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); 329 bool canNegativeDividend = lhsMin.isNegative(); 330 bool canPositiveDividend = lhsMax.isStrictlyPositive(); 331 APInt zero = APInt::getZero(maxDivisor.getBitWidth()); 332 APInt maxPositiveResult = maxDivisor - 1; 333 APInt minNegativeResult = -maxPositiveResult; 334 smin = canNegativeDividend ? minNegativeResult : zero; 335 smax = canPositiveDividend ? maxPositiveResult : zero; 336 // Special case: sweeping out a contiguous range in N/[modulus]. 337 if (rhsMin == rhsMax) { 338 if ((lhsMax - lhsMin).ult(maxDivisor)) { 339 APInt minRem = lhsMin.srem(maxDivisor); 340 APInt maxRem = lhsMax.srem(maxDivisor); 341 if (minRem.sle(maxRem)) { 342 smin = minRem; 343 smax = maxRem; 344 } 345 } 346 } 347 } 348 setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); 349 } 350 351 //===----------------------------------------------------------------------===// 352 // AndIOp 353 //===----------------------------------------------------------------------===// 354 355 /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, 356 /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits 357 /// that both bonuds have in common. This gives us a consertive approximation 358 /// for what values can be passed to bitwise operations. 359 static std::tuple<APInt, APInt> 360 widenBitwiseBounds(const ConstantIntRanges &bound) { 361 APInt leftVal = bound.umin(), rightVal = bound.umax(); 362 unsigned bitwidth = leftVal.getBitWidth(); 363 unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); 364 leftVal.clearLowBits(differingBits); 365 rightVal.setLowBits(differingBits); 366 return std::make_tuple(std::move(leftVal), std::move(rightVal)); 367 } 368 369 void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 370 SetIntRangeFn setResultRange) { 371 APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes; 372 std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]); 373 std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]); 374 auto andi = [](const APInt &a, const APInt &b) -> Optional<APInt> { 375 return a & b; 376 }; 377 setResultRange(getResult(), 378 minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 379 /*isSigned=*/false)); 380 } 381 382 //===----------------------------------------------------------------------===// 383 // OrIOp 384 //===----------------------------------------------------------------------===// 385 386 void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 387 SetIntRangeFn setResultRange) { 388 APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes; 389 std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]); 390 std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]); 391 auto ori = [](const APInt &a, const APInt &b) -> Optional<APInt> { 392 return a | b; 393 }; 394 setResultRange(getResult(), 395 minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 396 /*isSigned=*/false)); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // XOrIOp 401 //===----------------------------------------------------------------------===// 402 403 void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 404 SetIntRangeFn setResultRange) { 405 APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes; 406 std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]); 407 std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]); 408 auto xori = [](const APInt &a, const APInt &b) -> Optional<APInt> { 409 return a ^ b; 410 }; 411 setResultRange(getResult(), 412 minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 413 /*isSigned=*/false)); 414 } 415 416 //===----------------------------------------------------------------------===// 417 // MaxSIOp 418 //===----------------------------------------------------------------------===// 419 420 void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 421 SetIntRangeFn setResultRange) { 422 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 423 424 const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); 425 const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); 426 setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); 427 } 428 429 //===----------------------------------------------------------------------===// 430 // MaxUIOp 431 //===----------------------------------------------------------------------===// 432 433 void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 434 SetIntRangeFn setResultRange) { 435 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 436 437 const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); 438 const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); 439 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); 440 } 441 442 //===----------------------------------------------------------------------===// 443 // MinSIOp 444 //===----------------------------------------------------------------------===// 445 446 void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 447 SetIntRangeFn setResultRange) { 448 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 449 450 const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); 451 const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); 452 setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); 453 } 454 455 //===----------------------------------------------------------------------===// 456 // MinUIOp 457 //===----------------------------------------------------------------------===// 458 459 void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 460 SetIntRangeFn setResultRange) { 461 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 462 463 const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); 464 const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); 465 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); 466 } 467 468 //===----------------------------------------------------------------------===// 469 // ExtUIOp 470 //===----------------------------------------------------------------------===// 471 472 void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 473 SetIntRangeFn setResultRange) { 474 Type destType = getResult().getType(); 475 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); 476 APInt umin = argRanges[0].umin().zext(destWidth); 477 APInt umax = argRanges[0].umax().zext(destWidth); 478 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); 479 } 480 481 //===----------------------------------------------------------------------===// 482 // ExtSIOp 483 //===----------------------------------------------------------------------===// 484 485 static ConstantIntRanges extSIRange(const ConstantIntRanges &range, 486 Type destType) { 487 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); 488 APInt smin = range.smin().sext(destWidth); 489 APInt smax = range.smax().sext(destWidth); 490 return ConstantIntRanges::fromSigned(smin, smax); 491 } 492 493 void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 494 SetIntRangeFn setResultRange) { 495 Type destType = getResult().getType(); 496 setResultRange(getResult(), extSIRange(argRanges[0], destType)); 497 } 498 499 //===----------------------------------------------------------------------===// 500 // TruncIOp 501 //===----------------------------------------------------------------------===// 502 503 static ConstantIntRanges truncIRange(const ConstantIntRanges &range, 504 Type destType) { 505 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); 506 APInt umin = range.umin().trunc(destWidth); 507 APInt umax = range.umax().trunc(destWidth); 508 APInt smin = range.smin().trunc(destWidth); 509 APInt smax = range.smax().trunc(destWidth); 510 return {umin, umax, smin, smax}; 511 } 512 513 void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 514 SetIntRangeFn setResultRange) { 515 Type destType = getResult().getType(); 516 setResultRange(getResult(), truncIRange(argRanges[0], destType)); 517 } 518 519 //===----------------------------------------------------------------------===// 520 // IndexCastOp 521 //===----------------------------------------------------------------------===// 522 523 void arith::IndexCastOp::inferResultRanges( 524 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { 525 Type sourceType = getOperand().getType(); 526 Type destType = getResult().getType(); 527 unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); 528 unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); 529 530 if (srcWidth < destWidth) 531 setResultRange(getResult(), extSIRange(argRanges[0], destType)); 532 else if (srcWidth > destWidth) 533 setResultRange(getResult(), truncIRange(argRanges[0], destType)); 534 else 535 setResultRange(getResult(), argRanges[0]); 536 } 537 538 //===----------------------------------------------------------------------===// 539 // CmpIOp 540 //===----------------------------------------------------------------------===// 541 542 bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, 543 const ConstantIntRanges &rhs) { 544 switch (pred) { 545 case arith::CmpIPredicate::sle: 546 case arith::CmpIPredicate::slt: 547 return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); 548 case arith::CmpIPredicate::ule: 549 case arith::CmpIPredicate::ult: 550 return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); 551 case arith::CmpIPredicate::sge: 552 case arith::CmpIPredicate::sgt: 553 return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); 554 case arith::CmpIPredicate::uge: 555 case arith::CmpIPredicate::ugt: 556 return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); 557 case arith::CmpIPredicate::eq: { 558 Optional<APInt> lhsConst = lhs.getConstantValue(); 559 Optional<APInt> rhsConst = rhs.getConstantValue(); 560 return lhsConst && rhsConst && lhsConst == rhsConst; 561 } 562 case arith::CmpIPredicate::ne: { 563 // While equality requires that there is an interpration of the preceeding 564 // computations that produces equal constants, whether that be signed or 565 // unsigned, statically determining inequality requires that neither 566 // interpretation produce potentially overlapping ranges. 567 bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || 568 isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); 569 bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || 570 isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); 571 return sne && une; 572 } 573 } 574 return false; 575 } 576 577 void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 578 SetIntRangeFn setResultRange) { 579 arith::CmpIPredicate pred = getPredicate(); 580 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 581 582 APInt min = APInt::getZero(1); 583 APInt max = APInt::getAllOnesValue(1); 584 if (isStaticallyTrue(pred, lhs, rhs)) 585 min = max; 586 else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) 587 max = min; 588 589 setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); 590 } 591 592 //===----------------------------------------------------------------------===// 593 // SelectOp 594 //===----------------------------------------------------------------------===// 595 596 void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 597 SetIntRangeFn setResultRange) { 598 Optional<APInt> mbCondVal = argRanges[0].getConstantValue(); 599 600 if (mbCondVal) { 601 if (mbCondVal->isZero()) 602 setResultRange(getResult(), argRanges[2]); 603 else 604 setResultRange(getResult(), argRanges[1]); 605 return; 606 } 607 setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2])); 608 } 609 610 //===----------------------------------------------------------------------===// 611 // ShLIOp 612 //===----------------------------------------------------------------------===// 613 614 void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 615 SetIntRangeFn setResultRange) { 616 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 617 ConstArithFn shl = [](const APInt &l, const APInt &r) -> Optional<APInt> { 618 return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.shl(r); 619 }; 620 ConstantIntRanges urange = 621 minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, 622 /*isSigned=*/false); 623 ConstantIntRanges srange = 624 minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, 625 /*isSigned=*/true); 626 setResultRange(getResult(), urange.intersection(srange)); 627 } 628 629 //===----------------------------------------------------------------------===// 630 // ShRUIOp 631 //===----------------------------------------------------------------------===// 632 633 void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 634 SetIntRangeFn setResultRange) { 635 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 636 637 ConstArithFn lshr = [](const APInt &l, const APInt &r) -> Optional<APInt> { 638 return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.lshr(r); 639 }; 640 setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, 641 {rhs.umin(), rhs.umax()}, 642 /*isSigned=*/false)); 643 } 644 645 //===----------------------------------------------------------------------===// 646 // ShRSIOp 647 //===----------------------------------------------------------------------===// 648 649 void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 650 SetIntRangeFn setResultRange) { 651 const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; 652 653 ConstArithFn ashr = [](const APInt &l, const APInt &r) -> Optional<APInt> { 654 return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.ashr(r); 655 }; 656 657 setResultRange(getResult(), 658 minMaxBy(ashr, {lhs.smin(), lhs.smax()}, 659 {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); 660 } 661