1 //===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/Presburger/SlowMPInt.h" 10 #include "llvm/Support/MathExtras.h" 11 12 using namespace mlir; 13 using namespace presburger; 14 using namespace detail; 15 16 SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {} 17 SlowMPInt::SlowMPInt() : SlowMPInt(0) {} 18 SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {} 19 SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); } 20 SlowMPInt::operator int64_t() const { return val.getSExtValue(); } 21 22 llvm::hash_code detail::hash_value(const SlowMPInt &x) { 23 return hash_value(x.val); 24 } 25 26 /// --------------------------------------------------------------------------- 27 /// Printing. 28 /// --------------------------------------------------------------------------- 29 void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; } 30 31 void SlowMPInt::dump() const { print(llvm::errs()); } 32 33 llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os, 34 const SlowMPInt &x) { 35 x.print(os); 36 return os; 37 } 38 39 /// --------------------------------------------------------------------------- 40 /// Convenience operator overloads for int64_t. 41 /// --------------------------------------------------------------------------- 42 SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) { 43 return a += SlowMPInt(b); 44 } 45 SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) { 46 return a -= SlowMPInt(b); 47 } 48 SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) { 49 return a *= SlowMPInt(b); 50 } 51 SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) { 52 return a /= SlowMPInt(b); 53 } 54 SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) { 55 return a %= SlowMPInt(b); 56 } 57 58 bool detail::operator==(const SlowMPInt &a, int64_t b) { 59 return a == SlowMPInt(b); 60 } 61 bool detail::operator!=(const SlowMPInt &a, int64_t b) { 62 return a != SlowMPInt(b); 63 } 64 bool detail::operator>(const SlowMPInt &a, int64_t b) { 65 return a > SlowMPInt(b); 66 } 67 bool detail::operator<(const SlowMPInt &a, int64_t b) { 68 return a < SlowMPInt(b); 69 } 70 bool detail::operator<=(const SlowMPInt &a, int64_t b) { 71 return a <= SlowMPInt(b); 72 } 73 bool detail::operator>=(const SlowMPInt &a, int64_t b) { 74 return a >= SlowMPInt(b); 75 } 76 SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) { 77 return a + SlowMPInt(b); 78 } 79 SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) { 80 return a - SlowMPInt(b); 81 } 82 SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) { 83 return a * SlowMPInt(b); 84 } 85 SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) { 86 return a / SlowMPInt(b); 87 } 88 SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) { 89 return a % SlowMPInt(b); 90 } 91 92 bool detail::operator==(int64_t a, const SlowMPInt &b) { 93 return SlowMPInt(a) == b; 94 } 95 bool detail::operator!=(int64_t a, const SlowMPInt &b) { 96 return SlowMPInt(a) != b; 97 } 98 bool detail::operator>(int64_t a, const SlowMPInt &b) { 99 return SlowMPInt(a) > b; 100 } 101 bool detail::operator<(int64_t a, const SlowMPInt &b) { 102 return SlowMPInt(a) < b; 103 } 104 bool detail::operator<=(int64_t a, const SlowMPInt &b) { 105 return SlowMPInt(a) <= b; 106 } 107 bool detail::operator>=(int64_t a, const SlowMPInt &b) { 108 return SlowMPInt(a) >= b; 109 } 110 SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) { 111 return SlowMPInt(a) + b; 112 } 113 SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) { 114 return SlowMPInt(a) - b; 115 } 116 SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) { 117 return SlowMPInt(a) * b; 118 } 119 SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) { 120 return SlowMPInt(a) / b; 121 } 122 SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) { 123 return SlowMPInt(a) % b; 124 } 125 126 static unsigned getMaxWidth(const APInt &a, const APInt &b) { 127 return std::max(a.getBitWidth(), b.getBitWidth()); 128 } 129 130 /// --------------------------------------------------------------------------- 131 /// Comparison operators. 132 /// --------------------------------------------------------------------------- 133 134 // TODO: consider instead making APInt::compare available and using that. 135 bool SlowMPInt::operator==(const SlowMPInt &o) const { 136 unsigned width = getMaxWidth(val, o.val); 137 return val.sext(width) == o.val.sext(width); 138 } 139 bool SlowMPInt::operator!=(const SlowMPInt &o) const { 140 unsigned width = getMaxWidth(val, o.val); 141 return val.sext(width) != o.val.sext(width); 142 } 143 bool SlowMPInt::operator>(const SlowMPInt &o) const { 144 unsigned width = getMaxWidth(val, o.val); 145 return val.sext(width).sgt(o.val.sext(width)); 146 } 147 bool SlowMPInt::operator<(const SlowMPInt &o) const { 148 unsigned width = getMaxWidth(val, o.val); 149 return val.sext(width).slt(o.val.sext(width)); 150 } 151 bool SlowMPInt::operator<=(const SlowMPInt &o) const { 152 unsigned width = getMaxWidth(val, o.val); 153 return val.sext(width).sle(o.val.sext(width)); 154 } 155 bool SlowMPInt::operator>=(const SlowMPInt &o) const { 156 unsigned width = getMaxWidth(val, o.val); 157 return val.sext(width).sge(o.val.sext(width)); 158 } 159 160 /// --------------------------------------------------------------------------- 161 /// Arithmetic operators. 162 /// --------------------------------------------------------------------------- 163 164 /// Bring a and b to have the same width and then call op(a, b, overflow). 165 /// If the overflow bit becomes set, resize a and b to double the width and 166 /// call op(a, b, overflow), returning its result. The operation with double 167 /// widths should not also overflow. 168 APInt runOpWithExpandOnOverflow( 169 const APInt &a, const APInt &b, 170 llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)> 171 op) { 172 bool overflow; 173 unsigned width = getMaxWidth(a, b); 174 APInt ret = op(a.sext(width), b.sext(width), overflow); 175 if (!overflow) 176 return ret; 177 178 width *= 2; 179 ret = op(a.sext(width), b.sext(width), overflow); 180 assert(!overflow && "double width should be sufficient to avoid overflow!"); 181 return ret; 182 } 183 184 SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const { 185 return SlowMPInt( 186 runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov))); 187 } 188 SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const { 189 return SlowMPInt( 190 runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov))); 191 } 192 SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const { 193 return SlowMPInt( 194 runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov))); 195 } 196 SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const { 197 return SlowMPInt( 198 runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov))); 199 } 200 SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; } 201 SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { 202 if (rhs == -1) 203 return -lhs; 204 return SlowMPInt( 205 llvm::APIntOps::RoundingSDiv(lhs.val, rhs.val, APInt::Rounding::UP)); 206 } 207 SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { 208 if (rhs == -1) 209 return -lhs; 210 return SlowMPInt( 211 llvm::APIntOps::RoundingSDiv(lhs.val, rhs.val, APInt::Rounding::DOWN)); 212 } 213 // The RHS is always expected to be positive, and the result 214 /// is always non-negative. 215 SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) { 216 assert(rhs >= 1 && "mod is only supported for positive divisors!"); 217 return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs; 218 } 219 220 SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) { 221 return SlowMPInt( 222 llvm::APIntOps::GreatestCommonDivisor(a.val.abs(), b.val.abs())); 223 } 224 225 /// Returns the least common multiple of 'a' and 'b'. 226 SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) { 227 SlowMPInt x = abs(a); 228 SlowMPInt y = abs(b); 229 return (x * y) / gcd(x, y); 230 } 231 232 /// This operation cannot overflow. 233 SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const { 234 unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth()); 235 return SlowMPInt(val.sext(width).srem(o.val.sext(width))); 236 } 237 238 SlowMPInt SlowMPInt::operator-() const { 239 if (val.isMinSignedValue()) { 240 /// Overflow only occurs when the value is the minimum possible value. 241 APInt ret = val.sext(2 * val.getBitWidth()); 242 return SlowMPInt(-ret); 243 } 244 return SlowMPInt(-val); 245 } 246 247 /// --------------------------------------------------------------------------- 248 /// Assignment operators, preincrement, predecrement. 249 /// --------------------------------------------------------------------------- 250 SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) { 251 *this = *this + o; 252 return *this; 253 } 254 SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) { 255 *this = *this - o; 256 return *this; 257 } 258 SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) { 259 *this = *this * o; 260 return *this; 261 } 262 SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) { 263 *this = *this / o; 264 return *this; 265 } 266 SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) { 267 *this = *this % o; 268 return *this; 269 } 270 SlowMPInt &SlowMPInt::operator++() { 271 *this += 1; 272 return *this; 273 } 274 275 SlowMPInt &SlowMPInt::operator--() { 276 *this -= 1; 277 return *this; 278 } 279