1 //===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/SlowMPInt.h"
10 #include "llvm/Support/MathExtras.h"
11 
12 using namespace mlir;
13 using namespace presburger;
14 using namespace detail;
15 
16 SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {}
17 SlowMPInt::SlowMPInt() : SlowMPInt(0) {}
18 SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {}
19 SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); }
20 SlowMPInt::operator int64_t() const { return val.getSExtValue(); }
21 
22 llvm::hash_code detail::hash_value(const SlowMPInt &x) {
23   return hash_value(x.val);
24 }
25 
26 /// ---------------------------------------------------------------------------
27 /// Printing.
28 /// ---------------------------------------------------------------------------
29 void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; }
30 
31 void SlowMPInt::dump() const { print(llvm::errs()); }
32 
33 llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os,
34                                       const SlowMPInt &x) {
35   x.print(os);
36   return os;
37 }
38 
39 /// ---------------------------------------------------------------------------
40 /// Convenience operator overloads for int64_t.
41 /// ---------------------------------------------------------------------------
42 SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) {
43   return a += SlowMPInt(b);
44 }
45 SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) {
46   return a -= SlowMPInt(b);
47 }
48 SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) {
49   return a *= SlowMPInt(b);
50 }
51 SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) {
52   return a /= SlowMPInt(b);
53 }
54 SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) {
55   return a %= SlowMPInt(b);
56 }
57 
58 bool detail::operator==(const SlowMPInt &a, int64_t b) {
59   return a == SlowMPInt(b);
60 }
61 bool detail::operator!=(const SlowMPInt &a, int64_t b) {
62   return a != SlowMPInt(b);
63 }
64 bool detail::operator>(const SlowMPInt &a, int64_t b) {
65   return a > SlowMPInt(b);
66 }
67 bool detail::operator<(const SlowMPInt &a, int64_t b) {
68   return a < SlowMPInt(b);
69 }
70 bool detail::operator<=(const SlowMPInt &a, int64_t b) {
71   return a <= SlowMPInt(b);
72 }
73 bool detail::operator>=(const SlowMPInt &a, int64_t b) {
74   return a >= SlowMPInt(b);
75 }
76 SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) {
77   return a + SlowMPInt(b);
78 }
79 SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) {
80   return a - SlowMPInt(b);
81 }
82 SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) {
83   return a * SlowMPInt(b);
84 }
85 SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) {
86   return a / SlowMPInt(b);
87 }
88 SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) {
89   return a % SlowMPInt(b);
90 }
91 
92 bool detail::operator==(int64_t a, const SlowMPInt &b) {
93   return SlowMPInt(a) == b;
94 }
95 bool detail::operator!=(int64_t a, const SlowMPInt &b) {
96   return SlowMPInt(a) != b;
97 }
98 bool detail::operator>(int64_t a, const SlowMPInt &b) {
99   return SlowMPInt(a) > b;
100 }
101 bool detail::operator<(int64_t a, const SlowMPInt &b) {
102   return SlowMPInt(a) < b;
103 }
104 bool detail::operator<=(int64_t a, const SlowMPInt &b) {
105   return SlowMPInt(a) <= b;
106 }
107 bool detail::operator>=(int64_t a, const SlowMPInt &b) {
108   return SlowMPInt(a) >= b;
109 }
110 SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) {
111   return SlowMPInt(a) + b;
112 }
113 SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) {
114   return SlowMPInt(a) - b;
115 }
116 SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) {
117   return SlowMPInt(a) * b;
118 }
119 SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) {
120   return SlowMPInt(a) / b;
121 }
122 SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) {
123   return SlowMPInt(a) % b;
124 }
125 
126 static unsigned getMaxWidth(const APInt &a, const APInt &b) {
127   return std::max(a.getBitWidth(), b.getBitWidth());
128 }
129 
130 /// ---------------------------------------------------------------------------
131 /// Comparison operators.
132 /// ---------------------------------------------------------------------------
133 
134 // TODO: consider instead making APInt::compare available and using that.
135 bool SlowMPInt::operator==(const SlowMPInt &o) const {
136   unsigned width = getMaxWidth(val, o.val);
137   return val.sext(width) == o.val.sext(width);
138 }
139 bool SlowMPInt::operator!=(const SlowMPInt &o) const {
140   unsigned width = getMaxWidth(val, o.val);
141   return val.sext(width) != o.val.sext(width);
142 }
143 bool SlowMPInt::operator>(const SlowMPInt &o) const {
144   unsigned width = getMaxWidth(val, o.val);
145   return val.sext(width).sgt(o.val.sext(width));
146 }
147 bool SlowMPInt::operator<(const SlowMPInt &o) const {
148   unsigned width = getMaxWidth(val, o.val);
149   return val.sext(width).slt(o.val.sext(width));
150 }
151 bool SlowMPInt::operator<=(const SlowMPInt &o) const {
152   unsigned width = getMaxWidth(val, o.val);
153   return val.sext(width).sle(o.val.sext(width));
154 }
155 bool SlowMPInt::operator>=(const SlowMPInt &o) const {
156   unsigned width = getMaxWidth(val, o.val);
157   return val.sext(width).sge(o.val.sext(width));
158 }
159 
160 /// ---------------------------------------------------------------------------
161 /// Arithmetic operators.
162 /// ---------------------------------------------------------------------------
163 
164 /// Bring a and b to have the same width and then call op(a, b, overflow).
165 /// If the overflow bit becomes set, resize a and b to double the width and
166 /// call op(a, b, overflow), returning its result. The operation with double
167 /// widths should not also overflow.
168 APInt runOpWithExpandOnOverflow(
169     const APInt &a, const APInt &b,
170     llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)>
171         op) {
172   bool overflow;
173   unsigned width = getMaxWidth(a, b);
174   APInt ret = op(a.sext(width), b.sext(width), overflow);
175   if (!overflow)
176     return ret;
177 
178   width *= 2;
179   ret = op(a.sext(width), b.sext(width), overflow);
180   assert(!overflow && "double width should be sufficient to avoid overflow!");
181   return ret;
182 }
183 
184 SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const {
185   return SlowMPInt(
186       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov)));
187 }
188 SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const {
189   return SlowMPInt(
190       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov)));
191 }
192 SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const {
193   return SlowMPInt(
194       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov)));
195 }
196 SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const {
197   return SlowMPInt(
198       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov)));
199 }
200 SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; }
201 SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
202   if (rhs == -1)
203     return -lhs;
204   return SlowMPInt(
205       llvm::APIntOps::RoundingSDiv(lhs.val, rhs.val, APInt::Rounding::UP));
206 }
207 SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
208   if (rhs == -1)
209     return -lhs;
210   return SlowMPInt(
211       llvm::APIntOps::RoundingSDiv(lhs.val, rhs.val, APInt::Rounding::DOWN));
212 }
213 // The RHS is always expected to be positive, and the result
214 /// is always non-negative.
215 SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
216   assert(rhs >= 1 && "mod is only supported for positive divisors!");
217   return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
218 }
219 
220 SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) {
221   return SlowMPInt(
222       llvm::APIntOps::GreatestCommonDivisor(a.val.abs(), b.val.abs()));
223 }
224 
225 /// Returns the least common multiple of 'a' and 'b'.
226 SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) {
227   SlowMPInt x = abs(a);
228   SlowMPInt y = abs(b);
229   return (x * y) / gcd(x, y);
230 }
231 
232 /// This operation cannot overflow.
233 SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
234   unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth());
235   return SlowMPInt(val.sext(width).srem(o.val.sext(width)));
236 }
237 
238 SlowMPInt SlowMPInt::operator-() const {
239   if (val.isMinSignedValue()) {
240     /// Overflow only occurs when the value is the minimum possible value.
241     APInt ret = val.sext(2 * val.getBitWidth());
242     return SlowMPInt(-ret);
243   }
244   return SlowMPInt(-val);
245 }
246 
247 /// ---------------------------------------------------------------------------
248 /// Assignment operators, preincrement, predecrement.
249 /// ---------------------------------------------------------------------------
250 SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
251   *this = *this + o;
252   return *this;
253 }
254 SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
255   *this = *this - o;
256   return *this;
257 }
258 SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
259   *this = *this * o;
260   return *this;
261 }
262 SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
263   *this = *this / o;
264   return *this;
265 }
266 SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
267   *this = *this % o;
268   return *this;
269 }
270 SlowMPInt &SlowMPInt::operator++() {
271   *this += 1;
272   return *this;
273 }
274 
275 SlowMPInt &SlowMPInt::operator--() {
276   *this -= 1;
277   return *this;
278 }
279