1628a2c14SArjun P //===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===//
2628a2c14SArjun P //
3628a2c14SArjun P // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4628a2c14SArjun P // See https://llvm.org/LICENSE.txt for license information.
5628a2c14SArjun P // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6628a2c14SArjun P //
7628a2c14SArjun P //===----------------------------------------------------------------------===//
8628a2c14SArjun P 
9628a2c14SArjun P #include "mlir/Analysis/Presburger/SlowMPInt.h"
10628a2c14SArjun P #include "llvm/Support/MathExtras.h"
11628a2c14SArjun P 
12628a2c14SArjun P using namespace mlir;
13628a2c14SArjun P using namespace presburger;
14628a2c14SArjun P using namespace detail;
15628a2c14SArjun P 
SlowMPInt(int64_t val)16628a2c14SArjun P SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {}
SlowMPInt()17628a2c14SArjun P SlowMPInt::SlowMPInt() : SlowMPInt(0) {}
SlowMPInt(const llvm::APInt & val)18628a2c14SArjun P SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {}
operator =(int64_t val)19628a2c14SArjun P SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); }
operator int64_t() const20628a2c14SArjun P SlowMPInt::operator int64_t() const { return val.getSExtValue(); }
21628a2c14SArjun P 
hash_value(const SlowMPInt & x)22628a2c14SArjun P llvm::hash_code detail::hash_value(const SlowMPInt &x) {
23628a2c14SArjun P   return hash_value(x.val);
24628a2c14SArjun P }
25628a2c14SArjun P 
26628a2c14SArjun P /// ---------------------------------------------------------------------------
27628a2c14SArjun P /// Printing.
28628a2c14SArjun P /// ---------------------------------------------------------------------------
print(llvm::raw_ostream & os) const29628a2c14SArjun P void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; }
30628a2c14SArjun P 
dump() const31628a2c14SArjun P void SlowMPInt::dump() const { print(llvm::errs()); }
32628a2c14SArjun P 
operator <<(llvm::raw_ostream & os,const SlowMPInt & x)33628a2c14SArjun P llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os,
34628a2c14SArjun P                                       const SlowMPInt &x) {
35628a2c14SArjun P   x.print(os);
36628a2c14SArjun P   return os;
37628a2c14SArjun P }
38628a2c14SArjun P 
39628a2c14SArjun P /// ---------------------------------------------------------------------------
40628a2c14SArjun P /// Convenience operator overloads for int64_t.
41628a2c14SArjun P /// ---------------------------------------------------------------------------
operator +=(SlowMPInt & a,int64_t b)42628a2c14SArjun P SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) {
43628a2c14SArjun P   return a += SlowMPInt(b);
44628a2c14SArjun P }
operator -=(SlowMPInt & a,int64_t b)45628a2c14SArjun P SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) {
46628a2c14SArjun P   return a -= SlowMPInt(b);
47628a2c14SArjun P }
operator *=(SlowMPInt & a,int64_t b)48628a2c14SArjun P SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) {
49628a2c14SArjun P   return a *= SlowMPInt(b);
50628a2c14SArjun P }
operator /=(SlowMPInt & a,int64_t b)51628a2c14SArjun P SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) {
52628a2c14SArjun P   return a /= SlowMPInt(b);
53628a2c14SArjun P }
operator %=(SlowMPInt & a,int64_t b)54628a2c14SArjun P SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) {
55628a2c14SArjun P   return a %= SlowMPInt(b);
56628a2c14SArjun P }
57628a2c14SArjun P 
operator ==(const SlowMPInt & a,int64_t b)58628a2c14SArjun P bool detail::operator==(const SlowMPInt &a, int64_t b) {
59628a2c14SArjun P   return a == SlowMPInt(b);
60628a2c14SArjun P }
operator !=(const SlowMPInt & a,int64_t b)61628a2c14SArjun P bool detail::operator!=(const SlowMPInt &a, int64_t b) {
62628a2c14SArjun P   return a != SlowMPInt(b);
63628a2c14SArjun P }
operator >(const SlowMPInt & a,int64_t b)64628a2c14SArjun P bool detail::operator>(const SlowMPInt &a, int64_t b) {
65628a2c14SArjun P   return a > SlowMPInt(b);
66628a2c14SArjun P }
operator <(const SlowMPInt & a,int64_t b)67628a2c14SArjun P bool detail::operator<(const SlowMPInt &a, int64_t b) {
68628a2c14SArjun P   return a < SlowMPInt(b);
69628a2c14SArjun P }
operator <=(const SlowMPInt & a,int64_t b)70628a2c14SArjun P bool detail::operator<=(const SlowMPInt &a, int64_t b) {
71628a2c14SArjun P   return a <= SlowMPInt(b);
72628a2c14SArjun P }
operator >=(const SlowMPInt & a,int64_t b)73628a2c14SArjun P bool detail::operator>=(const SlowMPInt &a, int64_t b) {
74628a2c14SArjun P   return a >= SlowMPInt(b);
75628a2c14SArjun P }
operator +(const SlowMPInt & a,int64_t b)76628a2c14SArjun P SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) {
77628a2c14SArjun P   return a + SlowMPInt(b);
78628a2c14SArjun P }
operator -(const SlowMPInt & a,int64_t b)79628a2c14SArjun P SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) {
80628a2c14SArjun P   return a - SlowMPInt(b);
81628a2c14SArjun P }
operator *(const SlowMPInt & a,int64_t b)82628a2c14SArjun P SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) {
83628a2c14SArjun P   return a * SlowMPInt(b);
84628a2c14SArjun P }
operator /(const SlowMPInt & a,int64_t b)85628a2c14SArjun P SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) {
86628a2c14SArjun P   return a / SlowMPInt(b);
87628a2c14SArjun P }
operator %(const SlowMPInt & a,int64_t b)88628a2c14SArjun P SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) {
89628a2c14SArjun P   return a % SlowMPInt(b);
90628a2c14SArjun P }
91628a2c14SArjun P 
operator ==(int64_t a,const SlowMPInt & b)92628a2c14SArjun P bool detail::operator==(int64_t a, const SlowMPInt &b) {
93628a2c14SArjun P   return SlowMPInt(a) == b;
94628a2c14SArjun P }
operator !=(int64_t a,const SlowMPInt & b)95628a2c14SArjun P bool detail::operator!=(int64_t a, const SlowMPInt &b) {
96628a2c14SArjun P   return SlowMPInt(a) != b;
97628a2c14SArjun P }
operator >(int64_t a,const SlowMPInt & b)98628a2c14SArjun P bool detail::operator>(int64_t a, const SlowMPInt &b) {
99628a2c14SArjun P   return SlowMPInt(a) > b;
100628a2c14SArjun P }
operator <(int64_t a,const SlowMPInt & b)101628a2c14SArjun P bool detail::operator<(int64_t a, const SlowMPInt &b) {
102628a2c14SArjun P   return SlowMPInt(a) < b;
103628a2c14SArjun P }
operator <=(int64_t a,const SlowMPInt & b)104628a2c14SArjun P bool detail::operator<=(int64_t a, const SlowMPInt &b) {
105628a2c14SArjun P   return SlowMPInt(a) <= b;
106628a2c14SArjun P }
operator >=(int64_t a,const SlowMPInt & b)107628a2c14SArjun P bool detail::operator>=(int64_t a, const SlowMPInt &b) {
108628a2c14SArjun P   return SlowMPInt(a) >= b;
109628a2c14SArjun P }
operator +(int64_t a,const SlowMPInt & b)110628a2c14SArjun P SlowMPInt detail::operator+(int64_t a, const SlowMPInt &b) {
111628a2c14SArjun P   return SlowMPInt(a) + b;
112628a2c14SArjun P }
operator -(int64_t a,const SlowMPInt & b)113628a2c14SArjun P SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) {
114628a2c14SArjun P   return SlowMPInt(a) - b;
115628a2c14SArjun P }
operator *(int64_t a,const SlowMPInt & b)116628a2c14SArjun P SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) {
117628a2c14SArjun P   return SlowMPInt(a) * b;
118628a2c14SArjun P }
operator /(int64_t a,const SlowMPInt & b)119628a2c14SArjun P SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) {
120628a2c14SArjun P   return SlowMPInt(a) / b;
121628a2c14SArjun P }
operator %(int64_t a,const SlowMPInt & b)122628a2c14SArjun P SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) {
123628a2c14SArjun P   return SlowMPInt(a) % b;
124628a2c14SArjun P }
125628a2c14SArjun P 
getMaxWidth(const APInt & a,const APInt & b)126628a2c14SArjun P static unsigned getMaxWidth(const APInt &a, const APInt &b) {
127628a2c14SArjun P   return std::max(a.getBitWidth(), b.getBitWidth());
128628a2c14SArjun P }
129628a2c14SArjun P 
130628a2c14SArjun P /// ---------------------------------------------------------------------------
131628a2c14SArjun P /// Comparison operators.
132628a2c14SArjun P /// ---------------------------------------------------------------------------
133628a2c14SArjun P 
134628a2c14SArjun P // TODO: consider instead making APInt::compare available and using that.
operator ==(const SlowMPInt & o) const135628a2c14SArjun P bool SlowMPInt::operator==(const SlowMPInt &o) const {
136628a2c14SArjun P   unsigned width = getMaxWidth(val, o.val);
137628a2c14SArjun P   return val.sext(width) == o.val.sext(width);
138628a2c14SArjun P }
operator !=(const SlowMPInt & o) const139628a2c14SArjun P bool SlowMPInt::operator!=(const SlowMPInt &o) const {
140628a2c14SArjun P   unsigned width = getMaxWidth(val, o.val);
141628a2c14SArjun P   return val.sext(width) != o.val.sext(width);
142628a2c14SArjun P }
operator >(const SlowMPInt & o) const143628a2c14SArjun P bool SlowMPInt::operator>(const SlowMPInt &o) const {
144628a2c14SArjun P   unsigned width = getMaxWidth(val, o.val);
145628a2c14SArjun P   return val.sext(width).sgt(o.val.sext(width));
146628a2c14SArjun P }
operator <(const SlowMPInt & o) const147628a2c14SArjun P bool SlowMPInt::operator<(const SlowMPInt &o) const {
148628a2c14SArjun P   unsigned width = getMaxWidth(val, o.val);
149628a2c14SArjun P   return val.sext(width).slt(o.val.sext(width));
150628a2c14SArjun P }
operator <=(const SlowMPInt & o) const151628a2c14SArjun P bool SlowMPInt::operator<=(const SlowMPInt &o) const {
152628a2c14SArjun P   unsigned width = getMaxWidth(val, o.val);
153628a2c14SArjun P   return val.sext(width).sle(o.val.sext(width));
154628a2c14SArjun P }
operator >=(const SlowMPInt & o) const155628a2c14SArjun P bool SlowMPInt::operator>=(const SlowMPInt &o) const {
156628a2c14SArjun P   unsigned width = getMaxWidth(val, o.val);
157628a2c14SArjun P   return val.sext(width).sge(o.val.sext(width));
158628a2c14SArjun P }
159628a2c14SArjun P 
160628a2c14SArjun P /// ---------------------------------------------------------------------------
161628a2c14SArjun P /// Arithmetic operators.
162628a2c14SArjun P /// ---------------------------------------------------------------------------
163628a2c14SArjun P 
164628a2c14SArjun P /// Bring a and b to have the same width and then call op(a, b, overflow).
165628a2c14SArjun P /// If the overflow bit becomes set, resize a and b to double the width and
166628a2c14SArjun P /// call op(a, b, overflow), returning its result. The operation with double
167628a2c14SArjun P /// widths should not also overflow.
runOpWithExpandOnOverflow(const APInt & a,const APInt & b,llvm::function_ref<APInt (const APInt &,const APInt &,bool & overflow)> op)168628a2c14SArjun P APInt runOpWithExpandOnOverflow(
169628a2c14SArjun P     const APInt &a, const APInt &b,
170628a2c14SArjun P     llvm::function_ref<APInt(const APInt &, const APInt &, bool &overflow)>
171628a2c14SArjun P         op) {
172628a2c14SArjun P   bool overflow;
173628a2c14SArjun P   unsigned width = getMaxWidth(a, b);
174628a2c14SArjun P   APInt ret = op(a.sext(width), b.sext(width), overflow);
175628a2c14SArjun P   if (!overflow)
176628a2c14SArjun P     return ret;
177628a2c14SArjun P 
178628a2c14SArjun P   width *= 2;
179628a2c14SArjun P   ret = op(a.sext(width), b.sext(width), overflow);
180628a2c14SArjun P   assert(!overflow && "double width should be sufficient to avoid overflow!");
181628a2c14SArjun P   return ret;
182628a2c14SArjun P }
183628a2c14SArjun P 
operator +(const SlowMPInt & o) const184628a2c14SArjun P SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const {
185628a2c14SArjun P   return SlowMPInt(
186628a2c14SArjun P       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov)));
187628a2c14SArjun P }
operator -(const SlowMPInt & o) const188628a2c14SArjun P SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const {
189628a2c14SArjun P   return SlowMPInt(
190628a2c14SArjun P       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov)));
191628a2c14SArjun P }
operator *(const SlowMPInt & o) const192628a2c14SArjun P SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const {
193628a2c14SArjun P   return SlowMPInt(
194628a2c14SArjun P       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov)));
195628a2c14SArjun P }
operator /(const SlowMPInt & o) const196628a2c14SArjun P SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const {
197628a2c14SArjun P   return SlowMPInt(
198628a2c14SArjun P       runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov)));
199628a2c14SArjun P }
abs(const SlowMPInt & x)200628a2c14SArjun P SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; }
ceilDiv(const SlowMPInt & lhs,const SlowMPInt & rhs)201628a2c14SArjun P SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
202628a2c14SArjun P   if (rhs == -1)
203628a2c14SArjun P     return -lhs;
204*ca6a5afbSArjun P   unsigned width = getMaxWidth(lhs.val, rhs.val);
205*ca6a5afbSArjun P   return SlowMPInt(llvm::APIntOps::RoundingSDiv(
206*ca6a5afbSArjun P       lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::UP));
207628a2c14SArjun P }
floorDiv(const SlowMPInt & lhs,const SlowMPInt & rhs)208628a2c14SArjun P SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) {
209628a2c14SArjun P   if (rhs == -1)
210628a2c14SArjun P     return -lhs;
211*ca6a5afbSArjun P   unsigned width = getMaxWidth(lhs.val, rhs.val);
212*ca6a5afbSArjun P   return SlowMPInt(llvm::APIntOps::RoundingSDiv(
213*ca6a5afbSArjun P       lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::DOWN));
214628a2c14SArjun P }
215628a2c14SArjun P // The RHS is always expected to be positive, and the result
216628a2c14SArjun P /// is always non-negative.
mod(const SlowMPInt & lhs,const SlowMPInt & rhs)217628a2c14SArjun P SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
218628a2c14SArjun P   assert(rhs >= 1 && "mod is only supported for positive divisors!");
219628a2c14SArjun P   return lhs % rhs < 0 ? lhs % rhs + rhs : lhs % rhs;
220628a2c14SArjun P }
221628a2c14SArjun P 
gcd(const SlowMPInt & a,const SlowMPInt & b)222628a2c14SArjun P SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) {
22386d73c11SArjun P   assert(a >= 0 && b >= 0 && "operands must be non-negative!");
22486d73c11SArjun P   return SlowMPInt(llvm::APIntOps::GreatestCommonDivisor(a.val, b.val));
225628a2c14SArjun P }
226628a2c14SArjun P 
227628a2c14SArjun P /// Returns the least common multiple of 'a' and 'b'.
lcm(const SlowMPInt & a,const SlowMPInt & b)228628a2c14SArjun P SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) {
229628a2c14SArjun P   SlowMPInt x = abs(a);
230628a2c14SArjun P   SlowMPInt y = abs(b);
231628a2c14SArjun P   return (x * y) / gcd(x, y);
232628a2c14SArjun P }
233628a2c14SArjun P 
234628a2c14SArjun P /// This operation cannot overflow.
operator %(const SlowMPInt & o) const235628a2c14SArjun P SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
236628a2c14SArjun P   unsigned width = std::max(val.getBitWidth(), o.val.getBitWidth());
237628a2c14SArjun P   return SlowMPInt(val.sext(width).srem(o.val.sext(width)));
238628a2c14SArjun P }
239628a2c14SArjun P 
operator -() const240628a2c14SArjun P SlowMPInt SlowMPInt::operator-() const {
241628a2c14SArjun P   if (val.isMinSignedValue()) {
242628a2c14SArjun P     /// Overflow only occurs when the value is the minimum possible value.
243628a2c14SArjun P     APInt ret = val.sext(2 * val.getBitWidth());
244628a2c14SArjun P     return SlowMPInt(-ret);
245628a2c14SArjun P   }
246628a2c14SArjun P   return SlowMPInt(-val);
247628a2c14SArjun P }
248628a2c14SArjun P 
249628a2c14SArjun P /// ---------------------------------------------------------------------------
250628a2c14SArjun P /// Assignment operators, preincrement, predecrement.
251628a2c14SArjun P /// ---------------------------------------------------------------------------
operator +=(const SlowMPInt & o)252628a2c14SArjun P SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
253628a2c14SArjun P   *this = *this + o;
254628a2c14SArjun P   return *this;
255628a2c14SArjun P }
operator -=(const SlowMPInt & o)256628a2c14SArjun P SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
257628a2c14SArjun P   *this = *this - o;
258628a2c14SArjun P   return *this;
259628a2c14SArjun P }
operator *=(const SlowMPInt & o)260628a2c14SArjun P SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
261628a2c14SArjun P   *this = *this * o;
262628a2c14SArjun P   return *this;
263628a2c14SArjun P }
operator /=(const SlowMPInt & o)264628a2c14SArjun P SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
265628a2c14SArjun P   *this = *this / o;
266628a2c14SArjun P   return *this;
267628a2c14SArjun P }
operator %=(const SlowMPInt & o)268628a2c14SArjun P SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
269628a2c14SArjun P   *this = *this % o;
270628a2c14SArjun P   return *this;
271628a2c14SArjun P }
operator ++()272628a2c14SArjun P SlowMPInt &SlowMPInt::operator++() {
273628a2c14SArjun P   *this += 1;
274628a2c14SArjun P   return *this;
275628a2c14SArjun P }
276628a2c14SArjun P 
operator --()277628a2c14SArjun P SlowMPInt &SlowMPInt::operator--() {
278628a2c14SArjun P   *this -= 1;
279628a2c14SArjun P   return *this;
280628a2c14SArjun P }
281