1 //===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/Presburger/SlowMPInt.h"
10 #include "llvm/Support/MathExtras.h"
11
12 using namespace mlir;
13 using namespace presburger;
14 using namespace detail;
15
SlowMPInt(int64_t val)16 SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {}
SlowMPInt()17 SlowMPInt::SlowMPInt() : SlowMPInt(0) {}
SlowMPInt(const llvm::APInt & val)18 SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {}
operator =(int64_t val)19 SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); }
operator int64_t() const20 SlowMPInt::operator int64_t() const { return val.getSExtValue(); }
21
hash_value(const SlowMPInt & x)22 llvm::hash_code detail::hash_value(const SlowMPInt &x) {
23 return hash_value(x.val);
24 }
25
26 /// ---------------------------------------------------------------------------
27 /// Printing.
28 /// ---------------------------------------------------------------------------
print(llvm::raw_ostream & os) const29 void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; }
30
dump() const31 void SlowMPInt::dump() const { print(llvm::errs()); }
32
operator <<(llvm::raw_ostream & os,const SlowMPInt & x)33 llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os,
34 const SlowMPInt &x) {
35 x.print(os);
36 return os;
37 }
38
39 /// ---------------------------------------------------------------------------
40 /// Convenience operator overloads for int64_t.
41 /// ---------------------------------------------------------------------------
operator +=(SlowMPInt & a,int64_t b)42 SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) {
43 return a += SlowMPInt(b);
44 }
operator -=(SlowMPInt & a,int64_t b)45 SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) {
46 return a -= SlowMPInt(b);
47 }
operator *=(SlowMPInt & a,int64_t b)48 SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) {
49 return a *= SlowMPInt(b);
50 }
operator /=(SlowMPInt & a,int64_t b)51 SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) {
52 return a /= SlowMPInt(b);
53 }
operator %=(SlowMPInt & a,int64_t b)54 SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) {
55 return a %= SlowMPInt(b);
56 }
57
operator ==(const SlowMPInt & a,int64_t b)58 bool detail::operator==(const SlowMPInt &a, int64_t b) {
59 return a == SlowMPInt(b);
60 }
operator !=(const SlowMPInt & a,int64_t b)61 bool detail::operator!=(const SlowMPInt &a, int64_t b) {
62 return a != SlowMPInt(b);
63 }
operator >(const SlowMPInt & a,int64_t b)64 bool detail::operator>(const SlowMPInt &a, int64_t b) {
65 return a > SlowMPInt(b);
66 }
operator <(const SlowMPInt & a,int64_t b)67 bool detail::operator<(const SlowMPInt &a, int64_t b) {
68 return a < SlowMPInt(b);
69 }
operator <=(const SlowMPInt & a,int64_t b)70 bool detail::operator<=(const SlowMPInt &a, int64_t b) {
71 return a <= SlowMPInt(b);
72 }
operator >=(const SlowMPInt & a,int64_t b)73 bool detail::operator>=(const SlowMPInt &a, int64_t b) {
74 return a >= SlowMPInt(b);
75 }
operator +(const SlowMPInt & a,int64_t b)76 SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) {
77 return a + SlowMPInt(b);
78 }
operator -(const SlowMPInt & a,int64_t b)79 SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) {
80 return a - SlowMPInt(b);
81 }
operator *(const SlowMPInt & a,int64_t b)82 SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) {
83 return a * SlowMPInt(b);
84 }
operator /(const SlowMPInt & a,int64_t b)85 SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) {
86 return a / SlowMPInt(b);
87 }
operator %(const SlowMPInt & a,int64_t b)88 SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) {
89 return a % SlowMPInt(b);
90 }
91
operator ==(int64_t a,const SlowMPInt & b)92 bool detail::operator==(int64_t a, const SlowMPInt &b) {
93 return SlowMPInt(a) == b;
94 }
operator !=(int64_t a,const SlowMPInt & b)95 bool detail::operator!=(int64_t a, const SlowMPInt &b) {
96 return SlowMPInt(a) != b;
97 }
operator >(int64_t a,const SlowMPInt & b)98 bool detail::operator>(int64_t a, const SlowMPInt &b) {
99 return SlowMPInt(a) > b;
100 }
operator <(int64_t a,const SlowMPInt & b)101 bool detail::operator<(int64_t a, const SlowMPInt &b) {
102 return SlowMPInt(a) < b;
103 }
operator <=(int64_t a,const SlowMPInt & b)104 bool detail::operator<=(int64_t a, const SlowMPInt &b) {
105 return SlowMPInt(a) <= b;
106 }
operator >=(int64_t a,const SlowMPInt & b)107 bool detail::operator>=(int64_t a, const SlowMPInt &b) {
108 return SlowMPInt(a) >= b;
109 }
operator +(int64_t a,const SlowMPInt & b)110 SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) {
111 return SlowMPInt(a) + b;
112 }
operator -(int64_t a,const SlowMPInt & b)113 SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) {
114 return SlowMPInt(a) - b;
115 }
operator *(int64_t a,const SlowMPInt & b)116 SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) {
117 return SlowMPInt(a) * b;
118 }
operator /(int64_t a,const SlowMPInt & b)119 SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) {
120 return SlowMPInt(a) / b;
121 }
operator %(int64_t a,const SlowMPInt & b)122 SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) {
123 return SlowMPInt(a) % b;
124 }
125
getMaxWidth(const APInt & a,const APInt & b)126 static unsigned getMaxWidth(const APInt &a, const APInt &b) {
127 return std::max(a.getBitWidth(), b.getBitWidth());
128 }
129
130 /// ---------------------------------------------------------------------------
131 /// Comparison operators.
132 /// ---------------------------------------------------------------------------
133
134 // TODO: consider instead making APInt::compare available and using that.
operator ==(const SlowMPInt & o) const135 bool SlowMPInt::operator==(const SlowMPInt &o) const {
136 unsigned width = getMaxWidth(val, o.val);
137 return val.sext(width) == o.val.sext(width);
138 }
operator !=(const SlowMPInt & o) const139 bool SlowMPInt::operator!=(const SlowMPInt &o) const {
140 unsigned width = getMaxWidth(val, o.val);
141 return val.sext(width) != o.val.sext(width);
142 }
operator >(const SlowMPInt & o) const143 bool SlowMPInt::operator>(const SlowMPInt &o) const {
144 unsigned width = getMaxWidth(val, o.val);
145 return val.sext(width).sgt(o.val.sext(width));
146 }
operator <(const SlowMPInt & o) const147 bool SlowMPInt::operator<(const SlowMPInt &o) const {
148 unsigned width = getMaxWidth(val, o.val);
149 return val.sext(width).slt(o.val.sext(width));
150 }
operator <=(const SlowMPInt & o) const151 bool SlowMPInt::operator<=(const SlowMPInt &o) const {
152 unsigned width = getMaxWidth(val, o.val);
153 return val.sext(width).sle(o.val.sext(width));
154 }
operator >=(const SlowMPInt & o) const155 bool SlowMPInt::operator>=(const SlowMPInt &o) const {
156 unsigned width = getMaxWidth(val, o.val);
157 return val.sext(width).sge(o.val.sext(width));
158 }
159
160 /// ---------------------------------------------------------------------------
161 /// Arithmetic operators.
162 /// ---------------------------------------------------------------------------
163
164 /// Bring a and b to have the same width and then call op(a, b, overflow).
165 /// If the overflow bit becomes set, resize a and b to double the width and
166 /// call op(a, b, overflow), returning its result. The operation with double
167 /// widths should not also overflow.
runOpWithExpandOnOverflow(const APInt & a,const APInt & b,llvm::function_ref<APInt (const APInt &,const APInt &,bool & overflow)> op)168 APInt runOpWithExpandOnOverflow(
169 const APInt &a, const APInt &b,
170 llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)>
171 op) {
172 bool overflow;
173 unsigned width = getMaxWidth(a, b);
174 APInt ret = op(a.sext(width), b.sext(width), overflow);
175 if (!overflow)
176 return ret;
177
178 width *= 2;
179 ret = op(a.sext(width), b.sext(width), overflow);
180 assert(!overflow && "double width should be sufficient to avoid overflow!");
181 return ret;
182 }
183
// Binary arithmetic. Each operator hands the matching overflow-reporting APInt
// primitive to runOpWithExpandOnOverflow, which retries at double width when
// the first attempt overflows.
SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const {
  APInt sum =
      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov));
  return SlowMPInt(sum);
}
SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const {
  APInt difference =
      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov));
  return SlowMPInt(difference);
}
SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const {
  APInt product =
      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov));
  return SlowMPInt(product);
}
SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const {
  APInt quotient =
      runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov));
  return SlowMPInt(quotient);
}
/// Returns the absolute value of `x`. Negation of the minimum representable
/// value is handled by SlowMPInt's unary minus.
SlowMPInt detail::abs(const SlowMPInt &x) {
  if (x < 0)
    return -x;
  return x;
}
ceilDiv(const SlowMPInt & lhs,const SlowMPInt & rhs)201 SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
202 if (rhs == -1)
203 return -lhs;
204 unsigned width = getMaxWidth(lhs.val, rhs.val);
205 return SlowMPInt(llvm::APIntOps::RoundingSDiv(
206 lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::UP));
207 }
floorDiv(const SlowMPInt & lhs,const SlowMPInt & rhs)208 SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
209 if (rhs == -1)
210 return -lhs;
211 unsigned width = getMaxWidth(lhs.val, rhs.val);
212 return SlowMPInt(llvm::APIntOps::RoundingSDiv(
213 lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::DOWN));
214 }
215 // The RHS is always expected to be positive, and the result
216 /// is always non-negative.
mod(const SlowMPInt & lhs,const SlowMPInt & rhs)217 SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
218 assert(rhs >= 1 && "mod is only supported for positive divisors!");
219 return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
220 }
221
/// Returns the greatest common divisor of `a` and `b`, both of which must be
/// non-negative. gcd(0, 0) is 0.
SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) {
  assert(a >= 0 && b >= 0 && "operands must be non-negative!");
  // APIntOps::GreatestCommonDivisor requires both operands to have the same
  // bit width, but SlowMPInt values may have been widened independently (e.g.
  // 64 vs 128 bits). Both operands are non-negative, so sign-extending them to
  // a common width preserves their values, and the unsigned GCD is correct.
  unsigned width = getMaxWidth(a.val, b.val);
  return SlowMPInt(llvm::APIntOps::GreatestCommonDivisor(a.val.sext(width),
                                                          b.val.sext(width)));
}
226
227 /// Returns the least common multiple of 'a' and 'b'.
lcm(const SlowMPInt & a,const SlowMPInt & b)228 SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) {
229 SlowMPInt x = abs(a);
230 SlowMPInt y = abs(b);
231 return (x * y) / gcd(x, y);
232 }
233
234 /// This operation cannot overflow.
operator %(const SlowMPInt & o) const235 SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
236 unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth());
237 return SlowMPInt(val.sext(width).srem(o.val.sext(width)));
238 }
239
operator -() const240 SlowMPInt SlowMPInt::operator-() const {
241 if (val.isMinSignedValue()) {
242 /// Overflow only occurs when the value is the minimum possible value.
243 APInt ret = val.sext(2 * val.getBitWidth());
244 return SlowMPInt(-ret);
245 }
246 return SlowMPInt(-val);
247 }
248
249 /// ---------------------------------------------------------------------------
250 /// Assignment operators, preincrement, predecrement.
251 /// ---------------------------------------------------------------------------
// Compound assignment operators: compute the result with the corresponding
// binary operator, store it back into *this, and return *this.
SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
  return *this = *this + o;
}
SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
  return *this = *this - o;
}
SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
  return *this = *this * o;
}
SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
  return *this = *this / o;
}
SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
  return *this = *this % o;
}

/// Pre-increment: adds one in place and returns *this.
SlowMPInt &SlowMPInt::operator++() { return *this += 1; }

/// Pre-decrement: subtracts one in place and returns *this.
SlowMPInt &SlowMPInt::operator--() { return *this -= 1; }
281