//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "int-range-analysis" using namespace mlir; using namespace mlir::arith; /// Function that evaluates the result of doing something on arithmetic /// constants and returns None on overflow. using ConstArithFn = function_ref(const APInt &, const APInt &)>; /// Return the maxmially wide signed or unsigned range for a given bitwidth. /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, /// If either computation overflows, make the result unbounded. static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, const APInt &minRight, const APInt &maxLeft, const APInt &maxRight, bool isSigned) { Optional maybeMin = op(minLeft, minRight); Optional maybeMax = op(maxLeft, maxRight); if (maybeMin && maybeMax) return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); return ConstantIntRanges::maxRange(minLeft.getBitWidth()); } /// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, /// ignoring unbounded values. Returns the maximal range if `op` overflows. static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, ArrayRef rhs, bool isSigned) { unsigned width = lhs[0].getBitWidth(); APInt min = isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); APInt max = isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); for (const APInt &left : lhs) { for (const APInt &right : rhs) { Optional maybeThisResult = op(left, right); if (!maybeThisResult) return ConstantIntRanges::maxRange(width); APInt result = std::move(*maybeThisResult); min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; } } return ConstantIntRanges::range(min, max, isSigned); } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { auto constAttr = getValue().dyn_cast_or_null(); if (constAttr) { const APInt &value = constAttr.getValue(); setResultRange(getResult(), ConstantIntRanges::constant(value)); } } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// void arith::AddIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn uadd = [](const APInt &a, const APInt &b) -> Optional { bool overflowed = false; APInt result = a.uadd_ov(b, overflowed); return overflowed ? Optional() : result; }; ConstArithFn sadd = [](const APInt &a, const APInt &b) -> Optional { bool overflowed = false; APInt result = a.sadd_ov(b, overflowed); return overflowed ? Optional() : result; }; ConstantIntRanges urange = computeBoundsBy( uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); ConstantIntRanges srange = computeBoundsBy( sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// void arith::SubIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn usub = [](const APInt &a, const APInt &b) -> Optional { bool overflowed = false; APInt result = a.usub_ov(b, overflowed); return overflowed ? Optional() : result; }; ConstArithFn ssub = [](const APInt &a, const APInt &b) -> Optional { bool overflowed = false; APInt result = a.ssub_ov(b, overflowed); return overflowed ? Optional() : result; }; ConstantIntRanges urange = computeBoundsBy( usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); ConstantIntRanges srange = computeBoundsBy( ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// void arith::MulIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn umul = [](const APInt &a, const APInt &b) -> Optional { bool overflowed = false; APInt result = a.umul_ov(b, overflowed); return overflowed ? Optional() : result; }; ConstArithFn smul = [](const APInt &a, const APInt &b) -> Optional { bool overflowed = false; APInt result = a.smul_ov(b, overflowed); return overflowed ? Optional() : result; }; ConstantIntRanges urange = minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, /*isSigned=*/false); ConstantIntRanges srange = minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, /*isSigned=*/true); setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// /// Fix up division results (ex. for ceiling and floor), returning an APInt /// if there has been no overflow using DivisionFixupFn = function_ref( const APInt &lhs, const APInt &rhs, const APInt &result)>; static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup) { const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); if (!rhsMin.isZero()) { auto udiv = [&fixup](const APInt &a, const APInt &b) -> Optional { return fixup(a, b, a.udiv(b)); }; return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, /*isSigned=*/false); } // Otherwise, it's possible we might divide by 0. return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); } void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferDivUIRange(argRanges[0], argRanges[1], [](const APInt &lhs, const APInt &rhs, const APInt &result) { return result; })); } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs, DivisionFixupFn fixup) { const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), &rhsMax = rhs.smax(); bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); if (canDivide) { auto sdiv = [&fixup](const APInt &a, const APInt &b) -> Optional { bool overflowed = false; APInt result = a.sdiv_ov(b, overflowed); return overflowed ? Optional() : fixup(a, b, result); }; return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, /*isSigned=*/true); } return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); } void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferDivSIRange(argRanges[0], argRanges[1], [](const APInt &lhs, const APInt &rhs, const APInt &result) { return result; })); } //===----------------------------------------------------------------------===// // CeilDivUIOp //===----------------------------------------------------------------------===// void arith::CeilDivUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; DivisionFixupFn ceilDivUIFix = [](const APInt &lhs, const APInt &rhs, const APInt &result) -> Optional { if (!lhs.urem(rhs).isZero()) { bool overflowed = false; APInt corrected = result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); return overflowed ? Optional() : corrected; } return result; }; setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); } //===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// void arith::CeilDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; DivisionFixupFn ceilDivSIFix = [](const APInt &lhs, const APInt &rhs, const APInt &result) -> Optional { if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { bool overflowed = false; APInt corrected = result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); return overflowed ? Optional() : corrected; } return result; }; setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); } //===----------------------------------------------------------------------===// // FloorDivSIOp //===----------------------------------------------------------------------===// void arith::FloorDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; DivisionFixupFn floorDivSIFix = [](const APInt &lhs, const APInt &rhs, const APInt &result) -> Optional { if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { bool overflowed = false; APInt corrected = result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); return overflowed ? Optional() : corrected; } return result; }; setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); } //===----------------------------------------------------------------------===// // RemUIOp //===----------------------------------------------------------------------===// void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); unsigned width = rhsMin.getBitWidth(); APInt umin = APInt::getZero(width); APInt umax = APInt::getMaxValue(width); if (!rhsMin.isZero()) { umax = rhsMax - 1; // Special case: sweeping out a contiguous range in N/[modulus] if (rhsMin == rhsMax) { const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); if ((lhsMax - lhsMin).ult(rhsMax)) { APInt minRem = lhsMin.urem(rhsMax); APInt maxRem = lhsMax.urem(rhsMax); if (minRem.ule(maxRem)) { umin = minRem; umax = maxRem; } } } } setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), &rhsMax = rhs.smax(); unsigned width = rhsMax.getBitWidth(); APInt smin = APInt::getSignedMinValue(width); APInt smax = APInt::getSignedMaxValue(width); // No bounds if zero could be a divisor. bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); if (canBound) { APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); bool canNegativeDividend = lhsMin.isNegative(); bool canPositiveDividend = lhsMax.isStrictlyPositive(); APInt zero = APInt::getZero(maxDivisor.getBitWidth()); APInt maxPositiveResult = maxDivisor - 1; APInt minNegativeResult = -maxPositiveResult; smin = canNegativeDividend ? minNegativeResult : zero; smax = canPositiveDividend ? maxPositiveResult : zero; // Special case: sweeping out a contiguous range in N/[modulus]. if (rhsMin == rhsMax) { if ((lhsMax - lhsMin).ult(maxDivisor)) { APInt minRem = lhsMin.srem(maxDivisor); APInt maxRem = lhsMax.srem(maxDivisor); if (minRem.sle(maxRem)) { smin = minRem; smax = maxRem; } } } } setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// /// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, /// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits /// that both bonuds have in common. This gives us a consertive approximation /// for what values can be passed to bitwise operations. static std::tuple widenBitwiseBounds(const ConstantIntRanges &bound) { APInt leftVal = bound.umin(), rightVal = bound.umax(); unsigned bitwidth = leftVal.getBitWidth(); unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); leftVal.clearLowBits(differingBits); rightVal.setLowBits(differingBits); return std::make_tuple(std::move(leftVal), std::move(rightVal)); } void arith::AndIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes; std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]); std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]); auto andi = [](const APInt &a, const APInt &b) -> Optional { return a & b; }; setResultRange(getResult(), minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*isSigned=*/false)); } //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// void arith::OrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes; std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]); std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]); auto ori = [](const APInt &a, const APInt &b) -> Optional { return a | b; }; setResultRange(getResult(), minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*isSigned=*/false)); } //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes; std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]); std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]); auto xori = [](const APInt &a, const APInt &b) -> Optional { return a ^ b; }; setResultRange(getResult(), minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*isSigned=*/false)); } //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// // MaxUIOp //===----------------------------------------------------------------------===// void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); } //===----------------------------------------------------------------------===// // MinUIOp //===----------------------------------------------------------------------===// void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin(); const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { Type destType = getResult().getType(); unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); APInt umin = argRanges[0].umin().zext(destWidth); APInt umax = argRanges[0].umax().zext(destWidth); setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// static ConstantIntRanges extSIRange(const ConstantIntRanges &range, Type destType) { unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); APInt smin = range.smin().sext(destWidth); APInt smax = range.smax().sext(destWidth); return ConstantIntRanges::fromSigned(smin, smax); } void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { Type destType = getResult().getType(); setResultRange(getResult(), extSIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// static ConstantIntRanges truncIRange(const ConstantIntRanges &range, Type destType) { unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); APInt umin = range.umin().trunc(destWidth); APInt umax = range.umax().trunc(destWidth); APInt smin = range.smin().trunc(destWidth); APInt smax = range.smax().trunc(destWidth); return {umin, umax, smin, smax}; } void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { Type destType = getResult().getType(); setResultRange(getResult(), truncIRange(argRanges[0], destType)); } //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// void arith::IndexCastOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { Type sourceType = getOperand().getType(); Type destType = getResult().getType(); unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) setResultRange(getResult(), extSIRange(argRanges[0], destType)); else if (srcWidth > destWidth) setResultRange(getResult(), truncIRange(argRanges[0], destType)); else setResultRange(getResult(), argRanges[0]); } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, const ConstantIntRanges &rhs) { switch (pred) { case arith::CmpIPredicate::sle: case arith::CmpIPredicate::slt: return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); case arith::CmpIPredicate::ule: case arith::CmpIPredicate::ult: return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); case arith::CmpIPredicate::sge: case arith::CmpIPredicate::sgt: return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); case arith::CmpIPredicate::uge: case arith::CmpIPredicate::ugt: return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); case arith::CmpIPredicate::eq: { Optional lhsConst = lhs.getConstantValue(); Optional rhsConst = rhs.getConstantValue(); return lhsConst && rhsConst && lhsConst == rhsConst; } case arith::CmpIPredicate::ne: { // While equality requires that there is an interpration of the preceeding // computations that produces equal constants, whether that be signed or // unsigned, statically determining inequality requires that neither // interpretation produce potentially overlapping ranges. bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); return sne && une; } } return false; } void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { arith::CmpIPredicate pred = getPredicate(); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnesValue(1); if (isStaticallyTrue(pred, lhs, rhs)) min = max; else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) max = min; setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// void arith::SelectOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { Optional mbCondVal = argRanges[0].getConstantValue(); if (mbCondVal) { if (mbCondVal->isZero()) setResultRange(getResult(), argRanges[2]); else setResultRange(getResult(), argRanges[1]); return; } setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2])); } //===----------------------------------------------------------------------===// // ShLIOp //===----------------------------------------------------------------------===// void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn shl = [](const APInt &l, const APInt &r) -> Optional { return r.uge(r.getBitWidth()) ? Optional() : l.shl(r); }; ConstantIntRanges urange = minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, /*isSigned=*/false); ConstantIntRanges srange = minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, /*isSigned=*/true); setResultRange(getResult(), urange.intersection(srange)); } //===----------------------------------------------------------------------===// // ShRUIOp //===----------------------------------------------------------------------===// void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn lshr = [](const APInt &l, const APInt &r) -> Optional { return r.uge(r.getBitWidth()) ? Optional() : l.lshr(r); }; setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, /*isSigned=*/false)); } //===----------------------------------------------------------------------===// // ShRSIOp //===----------------------------------------------------------------------===// void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn ashr = [](const APInt &l, const APInt &r) -> Optional { return r.uge(r.getBitWidth()) ? Optional() : l.ashr(r); }; setResultRange(getResult(), minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); }