//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "int-range-analysis"

using namespace mlir;
using namespace mlir::arith;

/// Function that evaluates the result of doing something on arithmetic
/// constants and returns None on overflow.
using ConstArithFn =
    function_ref<Optional<APInt>(const APInt &, const APInt &)>;

/// Return the maxmially wide signed or unsigned range for a given bitwidth.

/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible,
/// If either computation overflows, make the result unbounded.
static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
                                         const APInt &minRight,
                                         const APInt &maxLeft,
                                         const APInt &maxRight, bool isSigned) {
  Optional<APInt> maybeMin = op(minLeft, minRight);
  Optional<APInt> maybeMax = op(maxLeft, maxRight);
  if (maybeMin && maybeMax)
    return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
}

/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
/// ignoring unbounded values. Returns the maximal range if `op` overflows.
static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
                                  ArrayRef<APInt> rhs, bool isSigned) {
  unsigned width = lhs[0].getBitWidth();
  APInt min =
      isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
  APInt max =
      isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
  for (const APInt &left : lhs) {
    for (const APInt &right : rhs) {
      Optional<APInt> maybeThisResult = op(left, right);
      if (!maybeThisResult)
        return ConstantIntRanges::maxRange(width);
      APInt result = std::move(*maybeThisResult);
      min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
      max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
    }
  }
  return ConstantIntRanges::range(min, max, isSigned);
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                          SetIntRangeFn setResultRange) {
  auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>();
  if (constAttr) {
    const APInt &value = constAttr.getValue();
    setResultRange(getResult(), ConstantIntRanges::constant(value));
  }
}

//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//

void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
  ConstArithFn uadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt result = a.uadd_ov(b, overflowed);
    return overflowed ? Optional<APInt>() : result;
  };
  ConstArithFn sadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt result = a.sadd_ov(b, overflowed);
    return overflowed ? Optional<APInt>() : result;
  };

  ConstantIntRanges urange = computeBoundsBy(
      uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
  ConstantIntRanges srange = computeBoundsBy(
      sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
  setResultRange(getResult(), urange.intersection(srange));
}

//===----------------------------------------------------------------------===//
// SubIOp
//===----------------------------------------------------------------------===//

void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  ConstArithFn usub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt result = a.usub_ov(b, overflowed);
    return overflowed ? Optional<APInt>() : result;
  };
  ConstArithFn ssub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt result = a.ssub_ov(b, overflowed);
    return overflowed ? Optional<APInt>() : result;
  };
  ConstantIntRanges urange = computeBoundsBy(
      usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
  ConstantIntRanges srange = computeBoundsBy(
      ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
  setResultRange(getResult(), urange.intersection(srange));
}

//===----------------------------------------------------------------------===//
// MulIOp
//===----------------------------------------------------------------------===//

void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  ConstArithFn umul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt result = a.umul_ov(b, overflowed);
    return overflowed ? Optional<APInt>() : result;
  };
  ConstArithFn smul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    bool overflowed = false;
    APInt result = a.smul_ov(b, overflowed);
    return overflowed ? Optional<APInt>() : result;
  };

  ConstantIntRanges urange =
      minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
               /*isSigned=*/false);
  ConstantIntRanges srange =
      minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
               /*isSigned=*/true);

  setResultRange(getResult(), urange.intersection(srange));
}

//===----------------------------------------------------------------------===//
// DivUIOp
//===----------------------------------------------------------------------===//

/// Fix up division results (ex. for ceiling and floor), returning an APInt
/// if there has been no overflow
using DivisionFixupFn = function_ref<Optional<APInt>(
    const APInt &lhs, const APInt &rhs, const APInt &result)>;

static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
                                         const ConstantIntRanges &rhs,
                                         DivisionFixupFn fixup) {
  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
              &rhsMax = rhs.umax();

  if (!rhsMin.isZero()) {
    auto udiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
      return fixup(a, b, a.udiv(b));
    };
    return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
                    /*isSigned=*/false);
  }
  // Otherwise, it's possible we might divide by 0.
  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
}

void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferDivUIRange(argRanges[0], argRanges[1],
                                 [](const APInt &lhs, const APInt &rhs,
                                    const APInt &result) { return result; }));
}

//===----------------------------------------------------------------------===//
// DivSIOp
//===----------------------------------------------------------------------===//

static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
                                         const ConstantIntRanges &rhs,
                                         DivisionFixupFn fixup) {
  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
              &rhsMax = rhs.smax();
  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();

  if (canDivide) {
    auto sdiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
      bool overflowed = false;
      APInt result = a.sdiv_ov(b, overflowed);
      return overflowed ? Optional<APInt>() : fixup(a, b, result);
    };
    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
                    /*isSigned=*/true);
  }
  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
}

void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  setResultRange(getResult(),
                 inferDivSIRange(argRanges[0], argRanges[1],
                                 [](const APInt &lhs, const APInt &rhs,
                                    const APInt &result) { return result; }));
}

//===----------------------------------------------------------------------===//
// CeilDivUIOp
//===----------------------------------------------------------------------===//

void arith::CeilDivUIOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  DivisionFixupFn ceilDivUIFix = [](const APInt &lhs, const APInt &rhs,
                                    const APInt &result) -> Optional<APInt> {
    if (!lhs.urem(rhs).isZero()) {
      bool overflowed = false;
      APInt corrected =
          result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
      return overflowed ? Optional<APInt>() : corrected;
    }
    return result;
  };
  setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
}

//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//

void arith::CeilDivSIOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  DivisionFixupFn ceilDivSIFix = [](const APInt &lhs, const APInt &rhs,
                                    const APInt &result) -> Optional<APInt> {
    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) {
      bool overflowed = false;
      APInt corrected =
          result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
      return overflowed ? Optional<APInt>() : corrected;
    }
    return result;
  };
  setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
}

//===----------------------------------------------------------------------===//
// FloorDivSIOp
//===----------------------------------------------------------------------===//

void arith::FloorDivSIOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  DivisionFixupFn floorDivSIFix = [](const APInt &lhs, const APInt &rhs,
                                     const APInt &result) -> Optional<APInt> {
    if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) {
      bool overflowed = false;
      APInt corrected =
          result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
      return overflowed ? Optional<APInt>() : corrected;
    }
    return result;
  };
  setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
}

//===----------------------------------------------------------------------===//
// RemUIOp
//===----------------------------------------------------------------------===//

void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();

  unsigned width = rhsMin.getBitWidth();
  APInt umin = APInt::getZero(width);
  APInt umax = APInt::getMaxValue(width);

  if (!rhsMin.isZero()) {
    umax = rhsMax - 1;
    // Special case: sweeping out a contiguous range in N/[modulus]
    if (rhsMin == rhsMax) {
      const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
      if ((lhsMax - lhsMin).ult(rhsMax)) {
        APInt minRem = lhsMin.urem(rhsMax);
        APInt maxRem = lhsMax.urem(rhsMax);
        if (minRem.ule(maxRem)) {
          umin = minRem;
          umax = maxRem;
        }
      }
    }
  }
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
}

//===----------------------------------------------------------------------===//
// RemSIOp
//===----------------------------------------------------------------------===//

void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
              &rhsMax = rhs.smax();

  unsigned width = rhsMax.getBitWidth();
  APInt smin = APInt::getSignedMinValue(width);
  APInt smax = APInt::getSignedMaxValue(width);
  // No bounds if zero could be a divisor.
  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
  if (canBound) {
    APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
    bool canNegativeDividend = lhsMin.isNegative();
    bool canPositiveDividend = lhsMax.isStrictlyPositive();
    APInt zero = APInt::getZero(maxDivisor.getBitWidth());
    APInt maxPositiveResult = maxDivisor - 1;
    APInt minNegativeResult = -maxPositiveResult;
    smin = canNegativeDividend ? minNegativeResult : zero;
    smax = canPositiveDividend ? maxPositiveResult : zero;
    // Special case: sweeping out a contiguous range in N/[modulus].
    if (rhsMin == rhsMax) {
      if ((lhsMax - lhsMin).ult(maxDivisor)) {
        APInt minRem = lhsMin.srem(maxDivisor);
        APInt maxRem = lhsMax.srem(maxDivisor);
        if (minRem.sle(maxRem)) {
          smin = minRem;
          smax = maxRem;
        }
      }
    }
  }
  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
}

//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//

/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
/// that both bonuds have in common. This gives us a consertive approximation
/// for what values can be passed to bitwise operations.
static std::tuple<APInt, APInt>
widenBitwiseBounds(const ConstantIntRanges &bound) {
  APInt leftVal = bound.umin(), rightVal = bound.umax();
  unsigned bitwidth = leftVal.getBitWidth();
  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
  leftVal.clearLowBits(differingBits);
  rightVal.setLowBits(differingBits);
  return std::make_tuple(std::move(leftVal), std::move(rightVal));
}

void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
  auto andi = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    return a & b;
  };
  setResultRange(getResult(),
                 minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
                          /*isSigned=*/false));
}

//===----------------------------------------------------------------------===//
// OrIOp
//===----------------------------------------------------------------------===//

void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                     SetIntRangeFn setResultRange) {
  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
  auto ori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    return a | b;
  };
  setResultRange(getResult(),
                 minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
                          /*isSigned=*/false));
}

//===----------------------------------------------------------------------===//
// XOrIOp
//===----------------------------------------------------------------------===//

void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
  auto xori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
    return a ^ b;
  };
  setResultRange(getResult(),
                 minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
                          /*isSigned=*/false));
}

//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//

void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
}

//===----------------------------------------------------------------------===//
// MaxUIOp
//===----------------------------------------------------------------------===//

void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
}

//===----------------------------------------------------------------------===//
// MinSIOp
//===----------------------------------------------------------------------===//

void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
}

//===----------------------------------------------------------------------===//
// MinUIOp
//===----------------------------------------------------------------------===//

void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
}

//===----------------------------------------------------------------------===//
// ExtUIOp
//===----------------------------------------------------------------------===//

void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  Type destType = getResult().getType();
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
  APInt umin = argRanges[0].umin().zext(destWidth);
  APInt umax = argRanges[0].umax().zext(destWidth);
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
}

//===----------------------------------------------------------------------===//
// ExtSIOp
//===----------------------------------------------------------------------===//

static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
                                    Type destType) {
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
  APInt smin = range.smin().sext(destWidth);
  APInt smax = range.smax().sext(destWidth);
  return ConstantIntRanges::fromSigned(smin, smax);
}

void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  Type destType = getResult().getType();
  setResultRange(getResult(), extSIRange(argRanges[0], destType));
}

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//

static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
                                     Type destType) {
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
  APInt umin = range.umin().trunc(destWidth);
  APInt umax = range.umax().trunc(destWidth);
  APInt smin = range.smin().trunc(destWidth);
  APInt smax = range.smax().trunc(destWidth);
  return {umin, umax, smin, smax};
}

void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
  Type destType = getResult().getType();
  setResultRange(getResult(), truncIRange(argRanges[0], destType));
}

//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//

void arith::IndexCastOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
  Type sourceType = getOperand().getType();
  Type destType = getResult().getType();
  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);

  if (srcWidth < destWidth)
    setResultRange(getResult(), extSIRange(argRanges[0], destType));
  else if (srcWidth > destWidth)
    setResultRange(getResult(), truncIRange(argRanges[0], destType));
  else
    setResultRange(getResult(), argRanges[0]);
}

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//

bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs,
                      const ConstantIntRanges &rhs) {
  switch (pred) {
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::slt:
    return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::ult:
    return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::sgt:
    return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
  case arith::CmpIPredicate::uge:
  case arith::CmpIPredicate::ugt:
    return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
  case arith::CmpIPredicate::eq: {
    Optional<APInt> lhsConst = lhs.getConstantValue();
    Optional<APInt> rhsConst = rhs.getConstantValue();
    return lhsConst && rhsConst && lhsConst == rhsConst;
  }
  case arith::CmpIPredicate::ne: {
    // While equality requires that there is an interpration of the preceeding
    // computations that produces equal constants, whether that be signed or
    // unsigned, statically determining inequality requires that neither
    // interpretation produce potentially overlapping ranges.
    bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
               isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
    bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
               isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
    return sne && une;
  }
  }
  return false;
}

void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  arith::CmpIPredicate pred = getPredicate();
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  APInt min = APInt::getZero(1);
  APInt max = APInt::getAllOnesValue(1);
  if (isStaticallyTrue(pred, lhs, rhs))
    min = max;
  else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
    max = min;

  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                        SetIntRangeFn setResultRange) {
  Optional<APInt> mbCondVal = argRanges[0].getConstantValue();

  if (mbCondVal) {
    if (mbCondVal->isZero())
      setResultRange(getResult(), argRanges[2]);
    else
      setResultRange(getResult(), argRanges[1]);
    return;
  }
  setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
}

//===----------------------------------------------------------------------===//
// ShLIOp
//===----------------------------------------------------------------------===//

void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                      SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
  ConstArithFn shl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
    return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.shl(r);
  };
  ConstantIntRanges urange =
      minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
               /*isSigned=*/false);
  ConstantIntRanges srange =
      minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
               /*isSigned=*/true);
  setResultRange(getResult(), urange.intersection(srange));
}

//===----------------------------------------------------------------------===//
// ShRUIOp
//===----------------------------------------------------------------------===//

void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  ConstArithFn lshr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
    return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.lshr(r);
  };
  setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
                                       {rhs.umin(), rhs.umax()},
                                       /*isSigned=*/false));
}

//===----------------------------------------------------------------------===//
// ShRSIOp
//===----------------------------------------------------------------------===//

void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];

  ConstArithFn ashr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
    return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.ashr(r);
  };

  setResultRange(getResult(),
                 minMaxBy(ashr, {lhs.smin(), lhs.smax()},
                          {rhs.umin(), rhs.umax()}, /*isSigned=*/true));
}
