1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A class to represent a relation over integer tuples. A relation is
// represented as a constraint system over a space of tuples of integer-valued
// variables supporting symbolic identifiers and existential quantification.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/Presburger/IntegerRelation.h"
16 #include "mlir/Analysis/Presburger/LinearTransform.h"
17 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
18 #include "mlir/Analysis/Presburger/Simplex.h"
19 #include "mlir/Analysis/Presburger/Utils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/DenseSet.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "presburger"
25 
26 using namespace mlir;
27 using namespace presburger;
28 
29 using llvm::SmallDenseMap;
30 using llvm::SmallDenseSet;
31 
32 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const {
33   return std::make_unique<IntegerRelation>(*this);
34 }
35 
36 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
37   return std::make_unique<IntegerPolyhedron>(*this);
38 }
39 
40 void IntegerRelation::append(const IntegerRelation &other) {
41   assert(isSpaceEqual(other) && "Spaces must be equal.");
42 
43   inequalities.reserveRows(inequalities.getNumRows() +
44                            other.getNumInequalities());
45   equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
46 
47   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
48     addInequality(other.getInequality(r));
49   }
50   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
51     addEquality(other.getEquality(r));
52   }
53 }
54 
55 IntegerRelation IntegerRelation::intersect(IntegerRelation other) const {
56   IntegerRelation result = *this;
57   result.mergeLocalIds(other);
58   result.append(other);
59   return result;
60 }
61 
62 bool IntegerRelation::isEqual(const IntegerRelation &other) const {
63   assert(isSpaceEqual(other) && "Spaces must be equal.");
64   return PresburgerRelation(*this).isEqual(PresburgerRelation(other));
65 }
66 
67 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const {
68   assert(isSpaceEqual(other) && "Spaces must be equal.");
69   return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other));
70 }
71 
72 MaybeOptimum<SmallVector<Fraction, 8>>
73 IntegerRelation::findRationalLexMin() const {
74   assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
75   MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin =
76       LexSimplex(*this).findRationalLexMin();
77 
78   if (!maybeLexMin.isBounded())
79     return maybeLexMin;
80 
81   // The Simplex returns the lexmin over all the variables including locals. But
82   // locals are not actually part of the space and should not be returned in the
83   // result. Since the locals are placed last in the list of identifiers, they
84   // will be minimized last in the lexmin. So simply truncating out the locals
85   // from the end of the answer gives the desired lexmin over the dimensions.
86   assert(maybeLexMin->size() == getNumIds() &&
87          "Incorrect number of vars in lexMin!");
88   maybeLexMin->resize(getNumDimAndSymbolIds());
89   return maybeLexMin;
90 }
91 
92 MaybeOptimum<SmallVector<int64_t, 8>>
93 IntegerRelation::findIntegerLexMin() const {
94   assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
95   MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
96       LexSimplex(*this).findIntegerLexMin();
97 
98   if (!maybeLexMin.isBounded())
99     return maybeLexMin.getKind();
100 
101   // The Simplex returns the lexmin over all the variables including locals. But
102   // locals are not actually part of the space and should not be returned in the
103   // result. Since the locals are placed last in the list of identifiers, they
104   // will be minimized last in the lexmin. So simply truncating out the locals
105   // from the end of the answer gives the desired lexmin over the dimensions.
106   assert(maybeLexMin->size() == getNumIds() &&
107          "Incorrect number of vars in lexMin!");
108   maybeLexMin->resize(getNumDimAndSymbolIds());
109   return maybeLexMin;
110 }
111 
112 static bool rangeIsZero(ArrayRef<int64_t> range) {
113   return llvm::all_of(range, [](int64_t x) { return x == 0; });
114 }
115 
116 void removeConstraintsInvolvingIdRange(IntegerRelation &poly, unsigned begin,
117                                        unsigned count) {
118   // We loop until i > 0 and index into i - 1 to avoid sign issues.
119   //
120   // We iterate backwards so that whether we remove constraint i - 1 or not, the
121   // next constraint to be tested is always i - 2.
122   for (unsigned i = poly.getNumEqualities(); i > 0; i--)
123     if (!rangeIsZero(poly.getEquality(i - 1).slice(begin, count)))
124       poly.removeEquality(i - 1);
125   for (unsigned i = poly.getNumInequalities(); i > 0; i--)
126     if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count)))
127       poly.removeInequality(i - 1);
128 }
129 
130 IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const {
131   return {PresburgerSpace(*this), getNumInequalities(), getNumEqualities()};
132 }
133 
134 void IntegerRelation::truncateIdKind(IdKind kind,
135                                      const CountsSnapshot &counts) {
136   truncateIdKind(kind, counts.getSpace().getNumIdKind(kind));
137 }
138 
139 void IntegerRelation::truncate(const CountsSnapshot &counts) {
140   truncateIdKind(IdKind::Domain, counts);
141   truncateIdKind(IdKind::Range, counts);
142   truncateIdKind(IdKind::Symbol, counts);
143   truncateIdKind(IdKind::Local, counts);
144   removeInequalityRange(counts.getNumIneqs(), getNumInequalities());
145   removeInequalityRange(counts.getNumEqs(), getNumEqualities());
146 }
147 
148 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
149   assert(pos <= getNumIdKind(kind));
150 
151   unsigned insertPos = PresburgerSpace::insertId(kind, pos, num);
152   inequalities.insertColumns(insertPos, num);
153   equalities.insertColumns(insertPos, num);
154   return insertPos;
155 }
156 
157 unsigned IntegerRelation::appendId(IdKind kind, unsigned num) {
158   unsigned pos = getNumIdKind(kind);
159   return insertId(kind, pos, num);
160 }
161 
162 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) {
163   assert(eq.size() == getNumCols());
164   unsigned row = equalities.appendExtraRow();
165   for (unsigned i = 0, e = eq.size(); i < e; ++i)
166     equalities(row, i) = eq[i];
167 }
168 
169 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) {
170   assert(inEq.size() == getNumCols());
171   unsigned row = inequalities.appendExtraRow();
172   for (unsigned i = 0, e = inEq.size(); i < e; ++i)
173     inequalities(row, i) = inEq[i];
174 }
175 
176 void IntegerRelation::removeId(IdKind kind, unsigned pos) {
177   removeIdRange(kind, pos, pos + 1);
178 }
179 
180 void IntegerRelation::removeId(unsigned pos) { removeIdRange(pos, pos + 1); }
181 
182 void IntegerRelation::removeIdRange(IdKind kind, unsigned idStart,
183                                     unsigned idLimit) {
184   assert(idLimit <= getNumIdKind(kind));
185 
186   if (idStart >= idLimit)
187     return;
188 
189   // Remove eliminated identifiers from the constraints.
190   unsigned offset = getIdKindOffset(kind);
191   equalities.removeColumns(offset + idStart, idLimit - idStart);
192   inequalities.removeColumns(offset + idStart, idLimit - idStart);
193 
194   // Remove eliminated identifiers from the space.
195   PresburgerSpace::removeIdRange(kind, idStart, idLimit);
196 }
197 
198 void IntegerRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
199   assert(idLimit <= getNumIds());
200 
201   if (idStart >= idLimit)
202     return;
203 
204   // Helper function to remove ids of the specified kind in the given range
205   // [start, limit), The range is absolute (i.e. it is not relative to the kind
206   // of identifier). Also updates `limit` to reflect the deleted identifiers.
207   auto removeIdKindInRange = [this](IdKind kind, unsigned &start,
208                                     unsigned &limit) {
209     if (start >= limit)
210       return;
211 
212     unsigned offset = getIdKindOffset(kind);
213     unsigned num = getNumIdKind(kind);
214 
215     // Get `start`, `limit` relative to the specified kind.
216     unsigned relativeStart =
217         start <= offset ? 0 : std::min(num, start - offset);
218     unsigned relativeLimit =
219         limit <= offset ? 0 : std::min(num, limit - offset);
220 
221     // Remove ids of the specified kind in the relative range.
222     removeIdRange(kind, relativeStart, relativeLimit);
223 
224     // Update `limit` to reflect deleted identifiers.
225     // `start` does not need to be updated because any identifiers that are
226     // deleted are after position `start`.
227     limit -= relativeLimit - relativeStart;
228   };
229 
230   removeIdKindInRange(IdKind::Domain, idStart, idLimit);
231   removeIdKindInRange(IdKind::Range, idStart, idLimit);
232   removeIdKindInRange(IdKind::Symbol, idStart, idLimit);
233   removeIdKindInRange(IdKind::Local, idStart, idLimit);
234 }
235 
/// Removes the equality at row `pos` from the equality matrix.
void IntegerRelation::removeEquality(unsigned pos) {
  equalities.removeRow(pos);
}
239 
/// Removes the inequality at row `pos` from the inequality matrix.
void IntegerRelation::removeInequality(unsigned pos) {
  inequalities.removeRow(pos);
}
243 
244 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
245   if (start >= end)
246     return;
247   equalities.removeRows(start, end - start);
248 }
249 
250 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) {
251   if (start >= end)
252     return;
253   inequalities.removeRows(start, end - start);
254 }
255 
256 void IntegerRelation::swapId(unsigned posA, unsigned posB) {
257   assert(posA < getNumIds() && "invalid position A");
258   assert(posB < getNumIds() && "invalid position B");
259 
260   if (posA == posB)
261     return;
262 
263   inequalities.swapColumns(posA, posB);
264   equalities.swapColumns(posA, posB);
265 }
266 
267 void IntegerRelation::clearConstraints() {
268   equalities.resizeVertically(0);
269   inequalities.resizeVertically(0);
270 }
271 
272 /// Gather all lower and upper bounds of the identifier at `pos`, and
273 /// optionally any equalities on it. In addition, the bounds are to be
274 /// independent of identifiers in position range [`offset`, `offset` + `num`).
275 void IntegerRelation::getLowerAndUpperBoundIndices(
276     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
277     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
278     unsigned offset, unsigned num) const {
279   assert(pos < getNumIds() && "invalid position");
280   assert(offset + num < getNumCols() && "invalid range");
281 
282   // Checks for a constraint that has a non-zero coeff for the identifiers in
283   // the position range [offset, offset + num) while ignoring `pos`.
284   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
285     unsigned c, f;
286     auto cst = isEq ? getEquality(r) : getInequality(r);
287     for (c = offset, f = offset + num; c < f; ++c) {
288       if (c == pos)
289         continue;
290       if (cst[c] != 0)
291         break;
292     }
293     return c < f;
294   };
295 
296   // Gather all lower bounds and upper bounds of the variable. Since the
297   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
298   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
299   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
300     // The bounds are to be independent of [offset, offset + num) columns.
301     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
302       continue;
303     if (atIneq(r, pos) >= 1) {
304       // Lower bound.
305       lbIndices->push_back(r);
306     } else if (atIneq(r, pos) <= -1) {
307       // Upper bound.
308       ubIndices->push_back(r);
309     }
310   }
311 
312   // An equality is both a lower and upper bound. Record any equalities
313   // involving the pos^th identifier.
314   if (!eqIndices)
315     return;
316 
317   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
318     if (atEq(r, pos) == 0)
319       continue;
320     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
321       continue;
322     eqIndices->push_back(r);
323   }
324 }
325 
326 bool IntegerRelation::hasConsistentState() const {
327   if (!inequalities.hasConsistentState())
328     return false;
329   if (!equalities.hasConsistentState())
330     return false;
331   return true;
332 }
333 
334 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) {
335   if (values.empty())
336     return;
337   assert(pos + values.size() <= getNumIds() &&
338          "invalid position or too many values");
339   // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the
340   // constant term and removing the id x_j. We do this for all the ids
341   // pos, pos + 1, ... pos + values.size() - 1.
342   unsigned constantColPos = getNumCols() - 1;
343   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
344     inequalities.addToColumn(i + pos, constantColPos, values[i]);
345   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
346     equalities.addToColumn(i + pos, constantColPos, values[i]);
347   removeIdRange(pos, pos + values.size());
348 }
349 
/// Replaces the contents of this relation with a copy of `other`, using copy
/// assignment.
void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
  *this = other;
}
353 
354 // Searches for a constraint with a non-zero coefficient at `colIdx` in
355 // equality (isEq=true) or inequality (isEq=false) constraints.
356 // Returns true and sets row found in search in `rowIdx`, false otherwise.
357 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
358                                                   unsigned *rowIdx) const {
359   assert(colIdx < getNumCols() && "position out of bounds");
360   auto at = [&](unsigned rowIdx) -> int64_t {
361     return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
362   };
363   unsigned e = isEq ? getNumEqualities() : getNumInequalities();
364   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
365     if (at(*rowIdx) != 0) {
366       return true;
367     }
368   }
369   return false;
370 }
371 
372 void IntegerRelation::normalizeConstraintsByGCD() {
373   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
374     equalities.normalizeRow(i);
375   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
376     inequalities.normalizeRow(i);
377 }
378 
379 bool IntegerRelation::hasInvalidConstraint() const {
380   assert(hasConsistentState());
381   auto check = [&](bool isEq) -> bool {
382     unsigned numCols = getNumCols();
383     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
384     for (unsigned i = 0, e = numRows; i < e; ++i) {
385       unsigned j;
386       for (j = 0; j < numCols - 1; ++j) {
387         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
388         // Skip rows with non-zero variable coefficients.
389         if (v != 0)
390           break;
391       }
392       if (j < numCols - 1) {
393         continue;
394       }
395       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
396       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
397       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
398       if ((isEq && v != 0) || (!isEq && v < 0)) {
399         return true;
400       }
401     }
402     return false;
403   };
404   if (check(/*isEq=*/true))
405     return true;
406   return check(/*isEq=*/false);
407 }
408 
409 /// Eliminate identifier from constraint at `rowIdx` based on coefficient at
410 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
411 /// updated as they have already been eliminated.
412 static void eliminateFromConstraint(IntegerRelation *constraints,
413                                     unsigned rowIdx, unsigned pivotRow,
414                                     unsigned pivotCol, unsigned elimColStart,
415                                     bool isEq) {
416   // Skip if equality 'rowIdx' if same as 'pivotRow'.
417   if (isEq && rowIdx == pivotRow)
418     return;
419   auto at = [&](unsigned i, unsigned j) -> int64_t {
420     return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
421   };
422   int64_t leadCoeff = at(rowIdx, pivotCol);
423   // Skip if leading coefficient at 'rowIdx' is already zero.
424   if (leadCoeff == 0)
425     return;
426   int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
427   int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
428   int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
429   int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
430   int64_t rowMultiplier = lcm / std::abs(leadCoeff);
431 
432   unsigned numCols = constraints->getNumCols();
433   for (unsigned j = 0; j < numCols; ++j) {
434     // Skip updating column 'j' if it was just eliminated.
435     if (j >= elimColStart && j < pivotCol)
436       continue;
437     int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
438                 rowMultiplier * at(rowIdx, j);
439     isEq ? constraints->atEq(rowIdx, j) = v
440          : constraints->atIneq(rowIdx, j) = v;
441   }
442 }
443 
444 /// Returns the position of the identifier that has the minimum <number of lower
445 /// bounds> times <number of upper bounds> from the specified range of
446 /// identifiers [start, end). It is often best to eliminate in the increasing
447 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
448 /// that many new constraints.
449 static unsigned getBestIdToEliminate(const IntegerRelation &cst, unsigned start,
450                                      unsigned end) {
451   assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
452 
453   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
454     unsigned numLb = 0;
455     unsigned numUb = 0;
456     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
457       if (cst.atIneq(r, pos) > 0) {
458         ++numLb;
459       } else if (cst.atIneq(r, pos) < 0) {
460         ++numUb;
461       }
462     }
463     return numLb * numUb;
464   };
465 
466   unsigned minLoc = start;
467   unsigned min = getProductOfNumLowerUpperBounds(start);
468   for (unsigned c = start + 1; c < end; c++) {
469     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
470     if (numLbUbProduct < min) {
471       min = numLbUbProduct;
472       minLoc = c;
473     }
474   }
475   return minLoc;
476 }
477 
478 // Checks for emptiness of the set by eliminating identifiers successively and
479 // using the GCD test (on all equality constraints) and checking for trivially
480 // invalid constraints. Returns 'true' if the constraint system is found to be
481 // empty; false otherwise.
482 bool IntegerRelation::isEmpty() const {
483   if (isEmptyByGCDTest() || hasInvalidConstraint())
484     return true;
485 
486   IntegerRelation tmpCst(*this);
487 
488   // First, eliminate as many local variables as possible using equalities.
489   tmpCst.removeRedundantLocalVars();
490   if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
491     return true;
492 
493   // Eliminate as many identifiers as possible using Gaussian elimination.
494   unsigned currentPos = 0;
495   while (currentPos < tmpCst.getNumIds()) {
496     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
497     ++currentPos;
498     // We check emptiness through trivial checks after eliminating each ID to
499     // detect emptiness early. Since the checks isEmptyByGCDTest() and
500     // hasInvalidConstraint() are linear time and single sweep on the constraint
501     // buffer, this appears reasonable - but can optimize in the future.
502     if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
503       return true;
504   }
505 
506   // Eliminate the remaining using FM.
507   for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
508     tmpCst.fourierMotzkinEliminate(
509         getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
510     // Check for a constraint explosion. This rarely happens in practice, but
511     // this check exists as a safeguard against improperly constructed
512     // constraint systems or artificially created arbitrarily complex systems
513     // that aren't the intended use case for IntegerRelation. This is
514     // needed since FM has a worst case exponential complexity in theory.
515     if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
516       LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
517       return false;
518     }
519 
520     // FM wouldn't have modified the equalities in any way. So no need to again
521     // run GCD test. Check for trivial invalid constraints.
522     if (tmpCst.hasInvalidConstraint())
523       return true;
524   }
525   return false;
526 }
527 
528 // Runs the GCD test on all equality constraints. Returns 'true' if this test
529 // fails on any equality. Returns 'false' otherwise.
530 // This test can be used to disprove the existence of a solution. If it returns
531 // true, no integer solution to the equality constraints can exist.
532 //
533 // GCD test definition:
534 //
535 // The equality constraint:
536 //
537 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
538 //
539 // has an integer solution iff:
540 //
541 //  GCD of c_1, c_2, ..., c_n divides c_0.
542 //
543 bool IntegerRelation::isEmptyByGCDTest() const {
544   assert(hasConsistentState());
545   unsigned numCols = getNumCols();
546   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
547     uint64_t gcd = std::abs(atEq(i, 0));
548     for (unsigned j = 1; j < numCols - 1; ++j) {
549       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
550     }
551     int64_t v = std::abs(atEq(i, numCols - 1));
552     if (gcd > 0 && (v % gcd != 0)) {
553       return true;
554     }
555   }
556   return false;
557 }
558 
559 // Returns a matrix where each row is a vector along which the polytope is
560 // bounded. The span of the returned vectors is guaranteed to contain all
561 // such vectors. The returned vectors are NOT guaranteed to be linearly
562 // independent. This function should not be called on empty sets.
563 //
564 // It is sufficient to check the perpendiculars of the constraints, as the set
565 // of perpendiculars which are bounded must span all bounded directions.
566 Matrix IntegerRelation::getBoundedDirections() const {
567   // Note that it is necessary to add the equalities too (which the constructor
568   // does) even though we don't need to check if they are bounded; whether an
569   // inequality is bounded or not depends on what other constraints, including
570   // equalities, are present.
571   Simplex simplex(*this);
572 
573   assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
574                                "direction is bounded in an empty set.");
575 
576   SmallVector<unsigned, 8> boundedIneqs;
577   // The constructor adds the inequalities to the simplex first, so this
578   // processes all the inequalities.
579   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
580     if (simplex.isBoundedAlongConstraint(i))
581       boundedIneqs.push_back(i);
582   }
583 
584   // The direction vector is given by the coefficients and does not include the
585   // constant term, so the matrix has one fewer column.
586   unsigned dirsNumCols = getNumCols() - 1;
587   Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
588 
589   // Copy the bounded inequalities.
590   unsigned row = 0;
591   for (unsigned i : boundedIneqs) {
592     for (unsigned col = 0; col < dirsNumCols; ++col)
593       dirs(row, col) = atIneq(i, col);
594     ++row;
595   }
596 
597   // Copy the equalities. All the equalities' perpendiculars are bounded.
598   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
599     for (unsigned col = 0; col < dirsNumCols; ++col)
600       dirs(row, col) = atEq(i, col);
601     ++row;
602   }
603 
604   return dirs;
605 }
606 
607 bool IntegerRelation::isIntegerEmpty() const {
608   return !findIntegerSample().hasValue();
609 }
610 
/// Returns an integer point contained in this set, or an empty Optional when
/// the set is found to contain no integer point.
///
/// Let this set be S. If S is bounded then we directly call into the GBR
/// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
/// vectors v such that S extends to infinity along v or -v. In this case we
/// use an algorithm described in the integer set library (isl) manual and used
/// by the isl_set_sample function in that library. The algorithm is:
///
/// 1) Apply a unimodular transform T to S to obtain S*T, such that all
/// dimensions in which S*T is bounded lie in the linear span of a prefix of the
/// dimensions.
///
/// 2) Construct a set B by removing all constraints that involve
/// the unbounded dimensions and then deleting the unbounded dimensions. Note
/// that B is a Bounded set.
///
/// 3) Try to obtain a sample from B using the GBR sampling
/// algorithm. If no sample is found, return that S is empty.
///
/// 4) Otherwise, substitute the obtained sample into S*T to obtain a set
/// C. C is a full-dimensional Cone and always contains a sample.
///
/// 5) Obtain an integer sample from C.
///
/// 6) Return T*v, where v is the concatenation of the samples from B and C.
///
/// The following is a sketch of a proof that
/// a) If the algorithm returns empty, then S is empty.
/// b) If the algorithm returns a sample, it is a valid sample in S.
///
/// The algorithm returns empty only if B is empty, in which case S*T is
/// certainly empty since B was obtained by removing constraints and then
/// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector
/// v is in S*T iff T*v is in S. So in this case, since
/// S*T is empty, S is empty too.
///
/// Otherwise, the algorithm substitutes the sample from B into S*T. All the
/// constraints of S*T that did not involve unbounded dimensions are satisfied
/// by this substitution. All dimensions in the linear span of the dimensions
/// outside the prefix are unbounded in S*T (step 1). Substituting values for
/// the bounded dimensions cannot make these dimensions bounded, and these are
/// the only remaining dimensions in C, so C is unbounded along every vector (in
/// the positive or negative direction, or both). C is hence a full-dimensional
/// cone and therefore always contains an integer point.
///
/// Concatenating the samples from B and C gives a sample v in S*T, so the
/// returned sample T*v is a sample in S.
Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
  // First, try the GCD test heuristic.
  if (isEmptyByGCDTest())
    return {};

  // An empty rational relaxation means there is no integer point either.
  Simplex simplex(*this);
  if (simplex.isEmpty())
    return {};

  // For a bounded set, we directly call into the GBR sampling algorithm.
  if (!simplex.isUnbounded())
    return simplex.findIntegerSample();

  // The set is unbounded. We cannot directly use the GBR algorithm.
  //
  // m is a matrix containing, in each row, a vector in which S is
  // bounded, such that the linear span of all these dimensions contains all
  // bounded dimensions in S.
  Matrix m = getBoundedDirections();
  // In column echelon form, each row of m occupies only the first rank(m)
  // columns and has zeros on the other columns. The transform T that brings S
  // to column echelon form is unimodular as well, so this is a suitable
  // transform to use in step 1 of the algorithm.
  std::pair<unsigned, LinearTransform> result =
      LinearTransform::makeTransformToColumnEchelon(std::move(m));
  const LinearTransform &transform = result.second;
  // 1) Apply T to S to obtain S*T.
  IntegerRelation transformedSet = transform.applyTo(*this);

  // 2) Remove the unbounded dimensions and constraints involving them to
  // obtain a bounded set.
  // result.first is the number of leading (bounded) dimensions; every
  // identifier after them is treated as unbounded.
  IntegerRelation boundedSet(transformedSet);
  unsigned numBoundedDims = result.first;
  unsigned numUnboundedDims = getNumIds() - numBoundedDims;
  removeConstraintsInvolvingIdRange(boundedSet, numBoundedDims,
                                    numUnboundedDims);
  boundedSet.removeIdRange(numBoundedDims, boundedSet.getNumIds());

  // 3) Try to obtain a sample from the bounded set.
  Optional<SmallVector<int64_t, 8>> boundedSample =
      Simplex(boundedSet).findIntegerSample();
  if (!boundedSample)
    return {};
  assert(boundedSet.containsPoint(*boundedSample) &&
         "Simplex returned an invalid sample!");

  // 4) Substitute the values of the bounded dimensions into S*T to obtain a
  // full-dimensional cone, which necessarily contains an integer sample.
  transformedSet.setAndEliminate(0, *boundedSample);
  IntegerRelation &cone = transformedSet;

  // 5) Obtain an integer sample from the cone.
  //
  // We shrink the cone such that for any rational point in the shrunken cone,
  // rounding up each of the point's coordinates produces a point that still
  // lies in the original cone.
  //
  // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i.
  // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the
  // shrunken cone will have the inequality tightened by some amount s, such
  // that if x satisfies the shrunken cone's tightened inequality, then x + e
  // satisfies the original inequality, i.e.,
  //
  // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0
  //
  // for any e_i values in [0, 1). In fact, we will handle the slightly more
  // general case where e_i can be in [0, 1]. For example, consider the
  // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low
  // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS
  // is minimized when we add 1 to the x_i with negative coefficient a_i and
  // keep the other x_i the same. In the example, we would get x = (3, 1, 1),
  // changing the value of the LHS by -3 + -7 = -10.
  //
  // In general, the value of the LHS can change by at most the sum of the
  // negative a_i, so we accommodate this by shifting the inequality by this
  // amount for the shrunken cone.
  for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
    for (unsigned j = 0; j < cone.getNumIds(); ++j) {
      int64_t coeff = cone.atIneq(i, j);
      // Column cone.getNumIds() is the constant term; fold each negative
      // coefficient into it.
      if (coeff < 0)
        cone.atIneq(i, cone.getNumIds()) += coeff;
    }
  }

  // Obtain an integer sample in the cone by rounding up a rational point from
  // the shrunken cone. Shrinking the cone amounts to shifting its apex
  // "inwards" without changing its "shape"; the shrunken cone is still a
  // full-dimensional cone and is hence non-empty.
  Simplex shrunkenConeSimplex(cone);
  assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");

  // The sample will always exist since the shrunken cone is non-empty.
  SmallVector<Fraction, 8> shrunkenConeSample =
      *shrunkenConeSimplex.getRationalSample();

  SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));

  // 6) Return transform * concat(boundedSample, coneSample).
  SmallVector<int64_t, 8> &sample = boundedSample.getValue();
  sample.append(coneSample.begin(), coneSample.end());
  return transform.postMultiplyWithColumn(sample);
}
758 
759 /// Helper to evaluate an affine expression at a point.
760 /// The expression is a list of coefficients for the dimensions followed by the
761 /// constant term.
762 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
763   assert(expr.size() == 1 + point.size() &&
764          "Dimensionalities of point and expression don't match!");
765   int64_t value = expr.back();
766   for (unsigned i = 0; i < point.size(); ++i)
767     value += expr[i] * point[i];
768   return value;
769 }
770 
771 /// A point satisfies an equality iff the value of the equality at the
772 /// expression is zero, and it satisfies an inequality iff the value of the
773 /// inequality at that point is non-negative.
774 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
775   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
776     if (valueAt(getEquality(i), point) != 0)
777       return false;
778   }
779   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
780     if (valueAt(getInequality(i), point) < 0)
781       return false;
782   }
783   return true;
784 }
785 
786 /// Just substitute the values given and check if an integer sample exists for
787 /// the local ids.
788 ///
789 /// TODO: this could be made more efficient by handling divisions separately.
790 /// Instead of finding an integer sample over all the locals, we can first
791 /// compute the values of the locals that have division representations and
792 /// only use the integer emptiness check for the locals that don't have this.
793 /// Handling this correctly requires ordering the divs, though.
794 Optional<SmallVector<int64_t, 8>>
795 IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
796   assert(point.size() == getNumIds() - getNumLocalIds() &&
797          "Point should contain all ids except locals!");
798   assert(getIdKindOffset(IdKind::Local) == getNumIds() - getNumLocalIds() &&
799          "This function depends on locals being stored last!");
800   IntegerRelation copy = *this;
801   copy.setAndEliminate(0, point);
802   return copy.findIntegerSample();
803 }
804 
805 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
806   std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
807   SmallVector<unsigned, 4> denominators(getNumLocalIds());
808   getLocalReprs(dividends, denominators, repr);
809 }
810 
811 void IntegerRelation::getLocalReprs(
812     std::vector<SmallVector<int64_t, 8>> &dividends,
813     SmallVector<unsigned, 4> &denominators) const {
814   std::vector<MaybeLocalRepr> repr(getNumLocalIds());
815   getLocalReprs(dividends, denominators, repr);
816 }
817 
818 void IntegerRelation::getLocalReprs(
819     std::vector<SmallVector<int64_t, 8>> &dividends,
820     SmallVector<unsigned, 4> &denominators,
821     std::vector<MaybeLocalRepr> &repr) const {
822 
823   repr.resize(getNumLocalIds());
824   dividends.resize(getNumLocalIds());
825   denominators.resize(getNumLocalIds());
826 
827   SmallVector<bool, 8> foundRepr(getNumIds(), false);
828   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i)
829     foundRepr[i] = true;
830 
831   unsigned divOffset = getNumDimAndSymbolIds();
832   bool changed;
833   do {
834     // Each time changed is true, at end of this iteration, one or more local
835     // vars have been detected as floor divs.
836     changed = false;
837     for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
838       if (!foundRepr[i + divOffset]) {
839         MaybeLocalRepr res = computeSingleVarRepr(
840             *this, foundRepr, divOffset + i, dividends[i], denominators[i]);
841         if (!res)
842           continue;
843         foundRepr[i + divOffset] = true;
844         repr[i] = res;
845         changed = true;
846       }
847     }
848   } while (changed);
849 
850   // Set 0 denominator for identifiers for which no division representation
851   // could be found.
852   for (unsigned i = 0, e = repr.size(); i < e; ++i)
853     if (!repr[i])
854       denominators[i] = 0;
855 }
856 
857 /// Tightens inequalities given that we are dealing with integer spaces. This is
858 /// analogous to the GCD test but applied to inequalities. The constant term can
859 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
860 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
861 /// fast method - linear in the number of coefficients.
862 // Example on how this affects practical cases: consider the scenario:
863 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
864 // j >= 100 instead of the tighter (exact) j >= 128.
865 void IntegerRelation::gcdTightenInequalities() {
866   unsigned numCols = getNumCols();
867   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
868     // Normalize the constraint and tighten the constant term by the GCD.
869     uint64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1);
870     if (gcd > 1)
871       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd);
872   }
873 }
874 
// Eliminates all identifier variables in column range [posStart, posLimit).
// Returns the number of variables eliminated.
//
// Columns are processed left to right. A column with a non-zero coefficient in
// some equality is eliminated from every constraint using that equality as the
// pivot, and the pivot equality is then removed. A column that appears in no
// constraint at all is trivially eliminated. Processing stops at the first
// column that appears only in inequalities; only the columns before it are
// removed, so fewer than posLimit - posStart ids may be eliminated.
unsigned IntegerRelation::gaussianEliminateIds(unsigned posStart,
                                               unsigned posLimit) {
  // Identifier positions to eliminate must be in range.
  assert(posLimit <= getNumIds());
  assert(hasConsistentState());

  if (posStart >= posLimit)
    return 0;

  gcdTightenInequalities();

  unsigned pivotCol = 0;
  for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
    // Find a row which has a non-zero coefficient in column 'pivotCol'.
    unsigned pivotRow;
    if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) {
      // No pivot row in equalities with non-zero at 'pivotCol'.
      if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) {
        // If inequalities are also zero in 'pivotCol', it can be
        // eliminated.
        continue;
      }
      // 'pivotCol' occurs only in inequalities, so it cannot be eliminated
      // with an equality; stop and only remove the columns processed so far.
      break;
    }

    // Eliminate identifier at 'pivotCol' from each equality row.
    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/true);
      equalities.normalizeRow(i);
    }

    // Eliminate identifier at 'pivotCol' from each inequality row.
    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/false);
      inequalities.normalizeRow(i);
    }
    // The pivot equality has been used up.
    removeEquality(pivotRow);
    gcdTightenInequalities();
  }
  // Update position limit based on number eliminated.
  posLimit = pivotCol;
  // Remove eliminated columns from all constraints.
  removeIdRange(posStart, posLimit);
  return posLimit - posStart;
}
924 
925 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
926 // to check if a constraint is redundant.
927 void IntegerRelation::removeRedundantInequalities() {
928   SmallVector<bool, 32> redun(getNumInequalities(), false);
929   // To check if an inequality is redundant, we replace the inequality by its
930   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
931   // system is empty. If it is, the inequality is redundant.
932   IntegerRelation tmpCst(*this);
933   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
934     // Change the inequality to its complement.
935     tmpCst.inequalities.negateRow(r);
936     tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
937     if (tmpCst.isEmpty()) {
938       redun[r] = true;
939       // Zero fill the redundant inequality.
940       inequalities.fillRow(r, /*value=*/0);
941       tmpCst.inequalities.fillRow(r, /*value=*/0);
942     } else {
943       // Reverse the change (to avoid recreating tmpCst each time).
944       tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
945       tmpCst.inequalities.negateRow(r);
946     }
947   }
948 
949   unsigned pos = 0;
950   for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
951     if (!redun[r])
952       inequalities.copyRow(r, pos++);
953   }
954   inequalities.resizeVertically(pos);
955 }
956 
957 // A more complex check to eliminate redundant inequalities and equalities. Uses
958 // Simplex to check if a constraint is redundant.
959 void IntegerRelation::removeRedundantConstraints() {
960   // First, we run gcdTightenInequalities. This allows us to catch some
961   // constraints which are not redundant when considering rational solutions
962   // but are redundant in terms of integer solutions.
963   gcdTightenInequalities();
964   Simplex simplex(*this);
965   simplex.detectRedundant();
966 
967   unsigned pos = 0;
968   unsigned numIneqs = getNumInequalities();
969   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
970   // the first constraints added are the inequalities.
971   for (unsigned r = 0; r < numIneqs; r++) {
972     if (!simplex.isMarkedRedundant(r))
973       inequalities.copyRow(r, pos++);
974   }
975   inequalities.resizeVertically(pos);
976 
977   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
978   // after the inequalities, a pair of constraints for each equality is added.
979   // An equality is redundant if both the inequalities in its pair are
980   // redundant.
981   pos = 0;
982   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
983     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
984           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
985       equalities.copyRow(r, pos++);
986   }
987   equalities.resizeVertically(pos);
988 }
989 
990 Optional<uint64_t> IntegerRelation::computeVolume() const {
991   assert(getNumSymbolIds() == 0 && "Symbols are not yet supported!");
992 
993   Simplex simplex(*this);
994   // If the polytope is rationally empty, there are certainly no integer
995   // points.
996   if (simplex.isEmpty())
997     return 0;
998 
999   // Just find the maximum and minimum integer value of each non-local id
1000   // separately, thus finding the number of integer values each such id can
1001   // take. Multiplying these together gives a valid overapproximation of the
1002   // number of integer points in the relation. The result this gives is
1003   // equivalent to projecting (rationally) the relation onto its non-local ids
1004   // and returning the number of integer points in a minimal axis-parallel
1005   // hyperrectangular overapproximation of that.
1006   //
1007   // We also handle the special case where one dimension is unbounded and
1008   // another dimension can take no integer values. In this case, the volume is
1009   // zero.
1010   //
1011   // If there is no such empty dimension, if any dimension is unbounded we
1012   // just return the result as unbounded.
1013   uint64_t count = 1;
1014   SmallVector<int64_t, 8> dim(getNumIds() + 1);
1015   bool hasUnboundedId = false;
1016   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) {
1017     dim[i] = 1;
1018     MaybeOptimum<int64_t> min, max;
1019     std::tie(min, max) = simplex.computeIntegerBounds(dim);
1020     dim[i] = 0;
1021 
1022     assert((!min.isEmpty() && !max.isEmpty()) &&
1023            "Polytope should be rationally non-empty!");
1024 
1025     // One of the dimensions is unbounded. Note this fact. We will return
1026     // unbounded if none of the other dimensions makes the volume zero.
1027     if (min.isUnbounded() || max.isUnbounded()) {
1028       hasUnboundedId = true;
1029       continue;
1030     }
1031 
1032     // In this case there are no valid integer points and the volume is
1033     // definitely zero.
1034     if (min.getBoundedOptimum() > max.getBoundedOptimum())
1035       return 0;
1036 
1037     count *= (*max - *min + 1);
1038   }
1039 
1040   if (count == 0)
1041     return 0;
1042   if (hasUnboundedId)
1043     return {};
1044   return count;
1045 }
1046 
1047 void IntegerRelation::eliminateRedundantLocalId(unsigned posA, unsigned posB) {
1048   assert(posA < getNumLocalIds() && "Invalid local id position");
1049   assert(posB < getNumLocalIds() && "Invalid local id position");
1050 
1051   unsigned localOffset = getIdKindOffset(IdKind::Local);
1052   posA += localOffset;
1053   posB += localOffset;
1054   inequalities.addToColumn(posB, posA, 1);
1055   equalities.addToColumn(posB, posA, 1);
1056   removeId(posB);
1057 }
1058 
1059 /// Adds additional local ids to the sets such that they both have the union
1060 /// of the local ids in each set, without changing the set of points that
1061 /// lie in `this` and `other`.
1062 ///
1063 /// To detect local ids that always take the same in both sets, each local id is
1064 /// represented as a floordiv with constant denominator in terms of other ids.
1065 /// After extracting these divisions, local ids with the same division
1066 /// representation are considered duplicate and are merged. It is possible that
1067 /// division representation for some local id cannot be obtained, and thus these
1068 /// local ids are not considered for detecting duplicates.
1069 void IntegerRelation::mergeLocalIds(IntegerRelation &other) {
1070   assert(isSpaceCompatible(other) && "Spaces should be compatible.");
1071 
1072   IntegerRelation &relA = *this;
1073   IntegerRelation &relB = other;
1074 
1075   // Merge local ids of relA and relB without using division information,
1076   // i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
1077   // to `relB` at start of its local ids.
1078   unsigned initLocals = relA.getNumLocalIds();
1079   insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
1080   relB.insertId(IdKind::Local, 0, initLocals);
1081 
1082   // Get division representations from each rel.
1083   std::vector<SmallVector<int64_t, 8>> divsA, divsB;
1084   SmallVector<unsigned, 4> denomsA, denomsB;
1085   relA.getLocalReprs(divsA, denomsA);
1086   relB.getLocalReprs(divsB, denomsB);
1087 
1088   // Copy division information for relB into `divsA` and `denomsA`, so that
1089   // these have the combined division information of both rels. Since newly
1090   // added local variables in relA and relB have no constraints, they will not
1091   // have any division representation.
1092   std::copy(divsB.begin() + initLocals, divsB.end(),
1093             divsA.begin() + initLocals);
1094   std::copy(denomsB.begin() + initLocals, denomsB.end(),
1095             denomsA.begin() + initLocals);
1096 
1097   // Merge function that merges the local variables in both sets by treating
1098   // them as the same identifier.
1099   auto merge = [&relA, &relB](unsigned i, unsigned j) -> bool {
1100     relA.eliminateRedundantLocalId(i, j);
1101     relB.eliminateRedundantLocalId(i, j);
1102     return true;
1103   };
1104 
1105   // Merge all divisions by removing duplicate divisions.
1106   unsigned localOffset = getIdKindOffset(IdKind::Local);
1107   removeDuplicateDivs(divsA, denomsA, localOffset, merge);
1108 }
1109 
1110 /// Removes local variables using equalities. Each equality is checked if it
1111 /// can be reduced to the form: `e = affine-expr`, where `e` is a local
1112 /// variable and `affine-expr` is an affine expression not containing `e`.
1113 /// If an equality satisfies this form, the local variable is replaced in
1114 /// each constraint and then removed. The equality used to replace this local
1115 /// variable is also removed.
1116 void IntegerRelation::removeRedundantLocalVars() {
1117   // Normalize the equality constraints to reduce coefficients of local
1118   // variables to 1 wherever possible.
1119   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1120     equalities.normalizeRow(i);
1121 
1122   while (true) {
1123     unsigned i, e, j, f;
1124     for (i = 0, e = getNumEqualities(); i < e; ++i) {
1125       // Find a local variable to eliminate using ith equality.
1126       for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j)
1127         if (std::abs(atEq(i, j)) == 1)
1128           break;
1129 
1130       // Local variable can be eliminated using ith equality.
1131       if (j < f)
1132         break;
1133     }
1134 
1135     // No equality can be used to eliminate a local variable.
1136     if (i == e)
1137       break;
1138 
1139     // Use the ith equality to simplify other equalities. If any changes
1140     // are made to an equality constraint, it is normalized by GCD.
1141     for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
1142       if (atEq(k, j) != 0) {
1143         eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
1144         equalities.normalizeRow(k);
1145       }
1146     }
1147 
1148     // Use the ith equality to simplify inequalities.
1149     for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
1150       eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);
1151 
1152     // Remove the ith equality and the found local variable.
1153     removeId(j);
1154     removeEquality(i);
1155   }
1156 }
1157 
1158 void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart,
1159                                     unsigned idLimit, IdKind dstKind) {
1160   assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range");
1161 
1162   if (idStart >= idLimit)
1163     return;
1164 
1165   // Append new local variables corresponding to the dimensions to be converted.
1166   unsigned newIdsBegin = getIdKindEnd(dstKind);
1167   unsigned convertCount = idLimit - idStart;
1168   appendId(dstKind, convertCount);
1169 
1170   // Swap the new local variables with dimensions.
1171   //
1172   // Essentially, this moves the information corresponding to the specified ids
1173   // of kind `srcKind` to the `convertCount` newly created ids of kind
1174   // `dstKind`. In particular, this moves the columns in the constraint
1175   // matrices, and zeros out the initially occupied columns (because the newly
1176   // created ids we're swapping with were zero-initialized).
1177   unsigned offset = getIdKindOffset(srcKind);
1178   for (unsigned i = 0; i < convertCount; ++i)
1179     swapId(offset + idStart + i, newIdsBegin + i);
1180 
1181   // Complete the move by deleting the initially occupied columns.
1182   removeIdRange(srcKind, idStart, idLimit);
1183 }
1184 
1185 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
1186   assert(pos < getNumCols());
1187   if (type == BoundType::EQ) {
1188     unsigned row = equalities.appendExtraRow();
1189     equalities(row, pos) = 1;
1190     equalities(row, getNumCols() - 1) = -value;
1191   } else {
1192     unsigned row = inequalities.appendExtraRow();
1193     inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
1194     inequalities(row, getNumCols() - 1) =
1195         type == BoundType::LB ? -value : value;
1196   }
1197 }
1198 
1199 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr,
1200                                int64_t value) {
1201   assert(type != BoundType::EQ && "EQ not implemented");
1202   assert(expr.size() == getNumCols());
1203   unsigned row = inequalities.appendExtraRow();
1204   for (unsigned i = 0, e = expr.size(); i < e; ++i)
1205     inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
1206   inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
1207       type == BoundType::LB ? -value : value;
1208 }
1209 
1210 /// Adds a new local identifier as the floordiv of an affine function of other
1211 /// identifiers, the coefficients of which are provided in 'dividend' and with
1212 /// respect to a positive constant 'divisor'. Two constraints are added to the
1213 /// system to capture equivalence with the floordiv.
1214 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
1215 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1216                                        int64_t divisor) {
1217   assert(dividend.size() == getNumCols() && "incorrect dividend size");
1218   assert(divisor > 0 && "positive divisor expected");
1219 
1220   appendId(IdKind::Local);
1221 
1222   // Add two constraints for this new identifier 'q'.
1223   SmallVector<int64_t, 8> bound(dividend.size() + 1);
1224 
1225   // dividend - q * divisor >= 0
1226   std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
1227             bound.begin());
1228   bound.back() = dividend.back();
1229   bound[getNumIds() - 1] = -divisor;
1230   addInequality(bound);
1231 
1232   // -dividend +qdivisor * q + divisor - 1 >= 0
1233   std::transform(bound.begin(), bound.end(), bound.begin(),
1234                  std::negate<int64_t>());
1235   bound[bound.size() - 1] += divisor - 1;
1236   addInequality(bound);
1237 }
1238 
1239 /// Finds an equality that equates the specified identifier to a constant.
1240 /// Returns the position of the equality row. If 'symbolic' is set to true,
1241 /// symbols are also treated like a constant, i.e., an affine function of the
1242 /// symbols is also treated like a constant. Returns -1 if such an equality
1243 /// could not be found.
1244 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
1245                                   bool symbolic = false) {
1246   assert(pos < cst.getNumIds() && "invalid position");
1247   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1248     int64_t v = cst.atEq(r, pos);
1249     if (v * v != 1)
1250       continue;
1251     unsigned c;
1252     unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
1253     // This checks for zeros in all positions other than 'pos' in [0, f)
1254     for (c = 0; c < f; c++) {
1255       if (c == pos)
1256         continue;
1257       if (cst.atEq(r, c) != 0) {
1258         // Dependent on another identifier.
1259         break;
1260       }
1261     }
1262     if (c == f)
1263       // Equality is free of other identifiers.
1264       return r;
1265   }
1266   return -1;
1267 }
1268 
1269 LogicalResult IntegerRelation::constantFoldId(unsigned pos) {
1270   assert(pos < getNumIds() && "invalid position");
1271   int rowIdx;
1272   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
1273     return failure();
1274 
1275   // atEq(rowIdx, pos) is either -1 or 1.
1276   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
1277   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
1278   setAndEliminate(pos, constVal);
1279   return success();
1280 }
1281 
1282 void IntegerRelation::constantFoldIdRange(unsigned pos, unsigned num) {
1283   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
1284     if (failed(constantFoldId(t)))
1285       t++;
1286   }
1287 }
1288 
1289 /// Returns a non-negative constant bound on the extent (upper bound - lower
1290 /// bound) of the specified identifier if it is found to be a constant; returns
1291 /// None if it's not a constant. This methods treats symbolic identifiers
1292 /// specially, i.e., it looks for constant differences between affine
1293 /// expressions involving only the symbolic identifiers. See comments at
1294 /// function definition for example. 'lb', if provided, is set to the lower
1295 /// bound associated with the constant difference. Note that 'lb' is purely
1296 /// symbolic and thus will contain the coefficients of the symbolic identifiers
1297 /// and the constant coefficient.
1298 //  Egs: 0 <= i <= 15, return 16.
1299 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
1300 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
1301 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
1302 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
1303 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
1304     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
1305     SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
1306     unsigned *minUbPos) const {
1307   assert(pos < getNumDimIds() && "Invalid identifier position");
1308 
1309   // Find an equality for 'pos'^th identifier that equates it to some function
1310   // of the symbolic identifiers (+ constant).
1311   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
1312   if (eqPos != -1) {
1313     auto eq = getEquality(eqPos);
1314     // If the equality involves a local var, punt for now.
1315     // TODO: this can be handled in the future by using the explicit
1316     // representation of the local vars.
1317     if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
1318                      [](int64_t coeff) { return coeff == 0; }))
1319       return None;
1320 
1321     // This identifier can only take a single value.
1322     if (lb) {
1323       // Set lb to that symbolic value.
1324       lb->resize(getNumSymbolIds() + 1);
1325       if (ub)
1326         ub->resize(getNumSymbolIds() + 1);
1327       for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
1328         int64_t v = atEq(eqPos, pos);
1329         // atEq(eqRow, pos) is either -1 or 1.
1330         assert(v * v == 1);
1331         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
1332                          : -atEq(eqPos, getNumDimIds() + c) / v;
1333         // Since this is an equality, ub = lb.
1334         if (ub)
1335           (*ub)[c] = (*lb)[c];
1336       }
1337       assert(boundFloorDivisor &&
1338              "both lb and divisor or none should be provided");
1339       *boundFloorDivisor = 1;
1340     }
1341     if (minLbPos)
1342       *minLbPos = eqPos;
1343     if (minUbPos)
1344       *minUbPos = eqPos;
1345     return 1;
1346   }
1347 
1348   // Check if the identifier appears at all in any of the inequalities.
1349   unsigned r, e;
1350   for (r = 0, e = getNumInequalities(); r < e; r++) {
1351     if (atIneq(r, pos) != 0)
1352       break;
1353   }
1354   if (r == e)
1355     // If it doesn't, there isn't a bound on it.
1356     return None;
1357 
1358   // Positions of constraints that are lower/upper bounds on the variable.
1359   SmallVector<unsigned, 4> lbIndices, ubIndices;
1360 
1361   // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
1362   // the bounds can only involve symbolic (and local) identifiers. Since the
1363   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1364   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1365   getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
1366                                /*eqIndices=*/nullptr, /*offset=*/0,
1367                                /*num=*/getNumDimIds());
1368 
1369   Optional<int64_t> minDiff = None;
1370   unsigned minLbPosition = 0, minUbPosition = 0;
1371   for (auto ubPos : ubIndices) {
1372     for (auto lbPos : lbIndices) {
1373       // Look for a lower bound and an upper bound that only differ by a
1374       // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
1375       // For example, if ii is the pos^th variable, we are looking for
1376       // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
1377       // minimum among all such constant differences is kept since that's the
1378       // constant bounding the extent of the pos^th variable.
1379       unsigned j, e;
1380       for (j = 0, e = getNumCols() - 1; j < e; j++)
1381         if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
1382           break;
1383         }
1384       if (j < getNumCols() - 1)
1385         continue;
1386       int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
1387                                  atIneq(lbPos, getNumCols() - 1) + 1,
1388                              atIneq(lbPos, pos));
1389       // This bound is non-negative by definition.
1390       diff = std::max<int64_t>(diff, 0);
1391       if (minDiff == None || diff < minDiff) {
1392         minDiff = diff;
1393         minLbPosition = lbPos;
1394         minUbPosition = ubPos;
1395       }
1396     }
1397   }
1398   if (lb && minDiff.hasValue()) {
1399     // Set lb to the symbolic lower bound.
1400     lb->resize(getNumSymbolIds() + 1);
1401     if (ub)
1402       ub->resize(getNumSymbolIds() + 1);
1403     // The lower bound is the ceildiv of the lb constraint over the coefficient
1404     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
1405     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
1406     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
1407     *boundFloorDivisor = atIneq(minLbPosition, pos);
1408     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
1409     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
1410       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
1411     }
1412     if (ub) {
1413       for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
1414         (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
1415     }
1416     // The lower bound leads to a ceildiv while the upper bound is a floordiv
1417     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
1418     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
1419     // the constant term for the lower bound.
1420     (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
1421   }
1422   if (minLbPos)
1423     *minLbPos = minLbPosition;
1424   if (minUbPos)
1425     *minUbPos = minUbPosition;
1426   return minDiff;
1427 }
1428 
1429 template <bool isLower>
1430 Optional<int64_t>
1431 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
1432   assert(pos < getNumIds() && "invalid position");
1433   // Project to 'pos'.
1434   projectOut(0, pos);
1435   projectOut(1, getNumIds() - 1);
1436   // Check if there's an equality equating the '0'^th identifier to a constant.
1437   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
1438   if (eqRowIdx != -1)
1439     // atEq(rowIdx, 0) is either -1 or 1.
1440     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
1441 
1442   // Check if the identifier appears at all in any of the inequalities.
1443   unsigned r, e;
1444   for (r = 0, e = getNumInequalities(); r < e; r++) {
1445     if (atIneq(r, 0) != 0)
1446       break;
1447   }
1448   if (r == e)
1449     // If it doesn't, there isn't a bound on it.
1450     return None;
1451 
1452   Optional<int64_t> minOrMaxConst = None;
1453 
1454   // Take the max across all const lower bounds (or min across all constant
1455   // upper bounds).
1456   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1457     if (isLower) {
1458       if (atIneq(r, 0) <= 0)
1459         // Not a lower bound.
1460         continue;
1461     } else if (atIneq(r, 0) >= 0) {
1462       // Not an upper bound.
1463       continue;
1464     }
1465     unsigned c, f;
1466     for (c = 0, f = getNumCols() - 1; c < f; c++)
1467       if (c != 0 && atIneq(r, c) != 0)
1468         break;
1469     if (c < getNumCols() - 1)
1470       // Not a constant bound.
1471       continue;
1472 
1473     int64_t boundConst =
1474         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
1475                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
1476     if (isLower) {
1477       if (minOrMaxConst == None || boundConst > minOrMaxConst)
1478         minOrMaxConst = boundConst;
1479     } else {
1480       if (minOrMaxConst == None || boundConst < minOrMaxConst)
1481         minOrMaxConst = boundConst;
1482     }
1483   }
1484   return minOrMaxConst;
1485 }
1486 
1487 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type,
1488                                                     unsigned pos) const {
1489   if (type == BoundType::LB)
1490     return IntegerRelation(*this)
1491         .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
1492   if (type == BoundType::UB)
1493     return IntegerRelation(*this)
1494         .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1495 
1496   assert(type == BoundType::EQ && "expected EQ");
1497   Optional<int64_t> lb =
1498       IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>(
1499           pos);
1500   Optional<int64_t> ub =
1501       IntegerRelation(*this)
1502           .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1503   return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
1504 }
1505 
1506 // A simple (naive and conservative) check for hyper-rectangularity.
1507 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const {
1508   assert(pos < getNumCols() - 1);
1509   // Check for two non-zero coefficients in the range [pos, pos + sum).
1510   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1511     unsigned sum = 0;
1512     for (unsigned c = pos; c < pos + num; c++) {
1513       if (atIneq(r, c) != 0)
1514         sum++;
1515     }
1516     if (sum > 1)
1517       return false;
1518   }
1519   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1520     unsigned sum = 0;
1521     for (unsigned c = pos; c < pos + num; c++) {
1522       if (atEq(r, c) != 0)
1523         sum++;
1524     }
1525     if (sum > 1)
1526       return false;
1527   }
1528   return true;
1529 }
1530 
1531 /// Removes duplicate constraints, trivially true constraints, and constraints
1532 /// that can be detected as redundant as a result of differing only in their
1533 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
1534 /// considered trivially true.
1535 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
1536 //  remove duplicates in place.
1537 void IntegerRelation::removeTrivialRedundancy() {
1538   gcdTightenInequalities();
1539   normalizeConstraintsByGCD();
1540 
1541   // A map used to detect redundancy stemming from constraints that only differ
1542   // in their constant term. The value stored is <row position, const term>
1543   // for a given row.
1544   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
1545       rowsWithoutConstTerm;
1546   // To unique rows.
1547   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
1548 
1549   // Check if constraint is of the form <non-negative-constant> >= 0.
1550   auto isTriviallyValid = [&](unsigned r) -> bool {
1551     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
1552       if (atIneq(r, c) != 0)
1553         return false;
1554     }
1555     return atIneq(r, getNumCols() - 1) >= 0;
1556   };
1557 
1558   // Detect and mark redundant constraints.
1559   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
1560   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1561     int64_t *rowStart = &inequalities(r, 0);
1562     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
1563     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
1564       redunIneq[r] = true;
1565       continue;
1566     }
1567 
1568     // Among constraints that only differ in the constant term part, mark
1569     // everything other than the one with the smallest constant term redundant.
1570     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
1571     // former two are redundant).
1572     int64_t constTerm = atIneq(r, getNumCols() - 1);
1573     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
1574     const auto &ret =
1575         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
1576     if (!ret.second) {
1577       // Check if the other constraint has a higher constant term.
1578       auto &val = ret.first->second;
1579       if (val.second > constTerm) {
1580         // The stored row is redundant. Mark it so, and update with this one.
1581         redunIneq[val.first] = true;
1582         val = {r, constTerm};
1583       } else {
1584         // The one stored makes this one redundant.
1585         redunIneq[r] = true;
1586       }
1587     }
1588   }
1589 
1590   // Scan to get rid of all rows marked redundant, in-place.
1591   unsigned pos = 0;
1592   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
1593     if (!redunIneq[r])
1594       inequalities.copyRow(r, pos++);
1595 
1596   inequalities.resizeVertically(pos);
1597 
1598   // TODO: consider doing this for equalities as well, but probably not worth
1599   // the savings.
1600 }
1601 
1602 #undef DEBUG_TYPE
1603 #define DEBUG_TYPE "fm"
1604 
1605 /// Eliminates identifier at the specified position using Fourier-Motzkin
1606 /// variable elimination. This technique is exact for rational spaces but
1607 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
1608 /// to a projection operation yielding the (convex) set of integer points
1609 /// contained in the rational shadow of the set. An emptiness test that relies
1610 /// on this method will guarantee emptiness, i.e., it disproves the existence of
1611 /// a solution if it says it's empty.
1612 /// If a non-null isResultIntegerExact is passed, it is set to true if the
1613 /// result is also integer exact. If it's set to false, the obtained solution
1614 /// *may* not be exact, i.e., it may contain integer points that do not have an
1615 /// integer pre-image in the original set.
1616 ///
1617 /// Eg:
1618 /// j >= 0, j <= i + 1
1619 /// i >= 0, i <= N + 1
1620 /// Eliminating i yields,
1621 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
1622 ///
1623 /// If darkShadow = true, this method computes the dark shadow on elimination;
1624 /// the dark shadow is a convex integer subset of the exact integer shadow. A
1625 /// non-empty dark shadow proves the existence of an integer solution. The
1626 /// elimination in such a case could however be an under-approximation, and thus
1627 /// should not be used for scanning sets or used by itself for dependence
1628 /// checking.
1629 ///
1630 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
1631 ///            ^
1632 ///            |
1633 ///            | * * * * o o
1634 ///         i  | * * o o o o
1635 ///            | o * * * * *
1636 ///            --------------->
1637 ///                 j ->
1638 ///
1639 /// Eliminating i from this system (projecting on the j dimension):
1640 /// rational shadow / integer light shadow:  1 <= j <= 6
1641 /// dark shadow:                             3 <= j <= 6
1642 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
1643 /// holes/splinters:                         j = 2
1644 ///
1645 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
1646 // TODO: a slight modification to yield dark shadow version of FM (tightened),
1647 // which can prove the existence of a solution if there is one.
1648 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
1649                                               bool *isResultIntegerExact) {
1650   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
1651   LLVM_DEBUG(dump());
1652   assert(pos < getNumIds() && "invalid position");
1653   assert(hasConsistentState());
1654 
1655   // Check if this identifier can be eliminated through a substitution.
1656   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1657     if (atEq(r, pos) != 0) {
1658       // Use Gaussian elimination here (since we have an equality).
1659       LogicalResult ret = gaussianEliminateId(pos);
1660       (void)ret;
1661       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
1662       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
1663       LLVM_DEBUG(dump());
1664       return;
1665     }
1666   }
1667 
1668   // A fast linear time tightening.
1669   gcdTightenInequalities();
1670 
1671   // Check if the identifier appears at all in any of the inequalities.
1672   if (isColZero(pos)) {
1673     // If it doesn't appear, just remove the column and return.
1674     // TODO: refactor removeColumns to use it from here.
1675     removeId(pos);
1676     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1677     LLVM_DEBUG(dump());
1678     return;
1679   }
1680 
1681   // Positions of constraints that are lower bounds on the variable.
1682   SmallVector<unsigned, 4> lbIndices;
1683   // Positions of constraints that are lower bounds on the variable.
1684   SmallVector<unsigned, 4> ubIndices;
1685   // Positions of constraints that do not involve the variable.
1686   std::vector<unsigned> nbIndices;
1687   nbIndices.reserve(getNumInequalities());
1688 
1689   // Gather all lower bounds and upper bounds of the variable. Since the
1690   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1691   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1692   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1693     if (atIneq(r, pos) == 0) {
1694       // Id does not appear in bound.
1695       nbIndices.push_back(r);
1696     } else if (atIneq(r, pos) >= 1) {
1697       // Lower bound.
1698       lbIndices.push_back(r);
1699     } else {
1700       // Upper bound.
1701       ubIndices.push_back(r);
1702     }
1703   }
1704 
1705   // Set the number of dimensions, symbols, locals in the resulting system.
1706   unsigned newNumDomain =
1707       getNumDomainIds() - getIdKindOverlap(IdKind::Domain, pos, pos + 1);
1708   unsigned newNumRange =
1709       getNumRangeIds() - getIdKindOverlap(IdKind::Range, pos, pos + 1);
1710   unsigned newNumSymbols =
1711       getNumSymbolIds() - getIdKindOverlap(IdKind::Symbol, pos, pos + 1);
1712   unsigned newNumLocals =
1713       getNumLocalIds() - getIdKindOverlap(IdKind::Local, pos, pos + 1);
1714 
1715   /// Create the new system which has one identifier less.
1716   IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
1717                          getNumEqualities(), getNumCols() - 1, newNumDomain,
1718                          newNumRange, newNumSymbols, newNumLocals);
1719 
1720   // This will be used to check if the elimination was integer exact.
1721   unsigned lcmProducts = 1;
1722 
1723   // Let x be the variable we are eliminating.
1724   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
1725   // that c_l, c_u >= 1) we have:
1726   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
1727   // We thus generate a constraint:
1728   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
1729   // Note if c_l = c_u = 1, all integer points captured by the resulting
1730   // constraint correspond to integer points in the original system (i.e., they
1731   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
1732   // integer exact.
1733   for (auto ubPos : ubIndices) {
1734     for (auto lbPos : lbIndices) {
1735       SmallVector<int64_t, 4> ineq;
1736       ineq.reserve(newRel.getNumCols());
1737       int64_t lbCoeff = atIneq(lbPos, pos);
1738       // Note that in the comments above, ubCoeff is the negation of the
1739       // coefficient in the canonical form as the view taken here is that of the
1740       // term being moved to the other size of '>='.
1741       int64_t ubCoeff = -atIneq(ubPos, pos);
1742       // TODO: refactor this loop to avoid all branches inside.
1743       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1744         if (l == pos)
1745           continue;
1746         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
1747         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
1748         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
1749                        atIneq(lbPos, l) * (lcm / lbCoeff));
1750         lcmProducts *= lcm;
1751       }
1752       if (darkShadow) {
1753         // The dark shadow is a convex subset of the exact integer shadow. If
1754         // there is a point here, it proves the existence of a solution.
1755         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
1756       }
1757       // TODO: we need to have a way to add inequalities in-place in
1758       // IntegerRelation instead of creating and copying over.
1759       newRel.addInequality(ineq);
1760     }
1761   }
1762 
1763   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
1764                           << "\n");
1765   if (lcmProducts == 1 && isResultIntegerExact)
1766     *isResultIntegerExact = true;
1767 
1768   // Copy over the constraints not involving this variable.
1769   for (auto nbPos : nbIndices) {
1770     SmallVector<int64_t, 4> ineq;
1771     ineq.reserve(getNumCols() - 1);
1772     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1773       if (l == pos)
1774         continue;
1775       ineq.push_back(atIneq(nbPos, l));
1776     }
1777     newRel.addInequality(ineq);
1778   }
1779 
1780   assert(newRel.getNumConstraints() ==
1781          lbIndices.size() * ubIndices.size() + nbIndices.size());
1782 
1783   // Copy over the equalities.
1784   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1785     SmallVector<int64_t, 4> eq;
1786     eq.reserve(newRel.getNumCols());
1787     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1788       if (l == pos)
1789         continue;
1790       eq.push_back(atEq(r, l));
1791     }
1792     newRel.addEquality(eq);
1793   }
1794 
1795   // GCD tightening and normalization allows detection of more trivially
1796   // redundant constraints.
1797   newRel.gcdTightenInequalities();
1798   newRel.normalizeConstraintsByGCD();
1799   newRel.removeTrivialRedundancy();
1800   clearAndCopyFrom(newRel);
1801   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1802   LLVM_DEBUG(dump());
1803 }
1804 
1805 #undef DEBUG_TYPE
1806 #define DEBUG_TYPE "presburger"
1807 
1808 void IntegerRelation::projectOut(unsigned pos, unsigned num) {
1809   if (num == 0)
1810     return;
1811 
1812   // 'pos' can be at most getNumCols() - 2 if num > 0.
1813   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
1814   assert(pos + num < getNumCols() && "invalid range");
1815 
1816   // Eliminate as many identifiers as possible using Gaussian elimination.
1817   unsigned currentPos = pos;
1818   unsigned numToEliminate = num;
1819   unsigned numGaussianEliminated = 0;
1820 
1821   while (currentPos < getNumIds()) {
1822     unsigned curNumEliminated =
1823         gaussianEliminateIds(currentPos, currentPos + numToEliminate);
1824     ++currentPos;
1825     numToEliminate -= curNumEliminated + 1;
1826     numGaussianEliminated += curNumEliminated;
1827   }
1828 
1829   // Eliminate the remaining using Fourier-Motzkin.
1830   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
1831     unsigned numToEliminate = num - numGaussianEliminated - i;
1832     fourierMotzkinEliminate(
1833         getBestIdToEliminate(*this, pos, pos + numToEliminate));
1834   }
1835 
1836   // Fast/trivial simplifications.
1837   gcdTightenInequalities();
1838   // Normalize constraints after tightening since the latter impacts this, but
1839   // not the other way round.
1840   normalizeConstraintsByGCD();
1841 }
1842 
1843 namespace {
1844 
1845 enum BoundCmpResult { Greater, Less, Equal, Unknown };
1846 
1847 /// Compares two affine bounds whose coefficients are provided in 'first' and
1848 /// 'second'. The last coefficient is the constant term.
1849 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
1850   assert(a.size() == b.size());
1851 
1852   // For the bounds to be comparable, their corresponding identifier
1853   // coefficients should be equal; the constant terms are then compared to
1854   // determine less/greater/equal.
1855 
1856   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
1857     return Unknown;
1858 
1859   if (a.back() == b.back())
1860     return Equal;
1861 
1862   return a.back() < b.back() ? Less : Greater;
1863 }
1864 } // namespace
1865 
1866 // Returns constraints that are common to both A & B.
1867 static void getCommonConstraints(const IntegerRelation &a,
1868                                  const IntegerRelation &b, IntegerRelation &c) {
1869   c = IntegerRelation(a.getNumDomainIds(), a.getNumRangeIds(),
1870                       a.getNumSymbolIds(), a.getNumLocalIds());
1871   // a naive O(n^2) check should be enough here given the input sizes.
1872   for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
1873     for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
1874       if (a.getInequality(r) == b.getInequality(s)) {
1875         c.addInequality(a.getInequality(r));
1876         break;
1877       }
1878     }
1879   }
1880   for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
1881     for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) {
1882       if (a.getEquality(r) == b.getEquality(s)) {
1883         c.addEquality(a.getEquality(r));
1884         break;
1885       }
1886     }
1887   }
1888 }
1889 
1890 // Computes the bounding box with respect to 'other' by finding the min of the
1891 // lower bounds and the max of the upper bounds along each of the dimensions.
1892 LogicalResult
1893 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
1894   assert(isSpaceEqual(otherCst) && "Spaces should match.");
1895   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
1896 
1897   // Get the constraints common to both systems; these will be added as is to
1898   // the union.
1899   IntegerRelation commonCst;
1900   getCommonConstraints(*this, otherCst, commonCst);
1901 
1902   std::vector<SmallVector<int64_t, 8>> boundingLbs;
1903   std::vector<SmallVector<int64_t, 8>> boundingUbs;
1904   boundingLbs.reserve(2 * getNumDimIds());
1905   boundingUbs.reserve(2 * getNumDimIds());
1906 
1907   // To hold lower and upper bounds for each dimension.
1908   SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
1909   // To compute min of lower bounds and max of upper bounds for each dimension.
1910   SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
1911   SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
1912   // To compute final new lower and upper bounds for the union.
1913   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
1914 
1915   int64_t lbFloorDivisor, otherLbFloorDivisor;
1916   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
1917     auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
1918     if (!extent.hasValue())
1919       // TODO: symbolic extents when necessary.
1920       // TODO: handle union if a dimension is unbounded.
1921       return failure();
1922 
1923     auto otherExtent = otherCst.getConstantBoundOnDimSize(
1924         d, &otherLb, &otherLbFloorDivisor, &otherUb);
1925     if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
1926       // TODO: symbolic extents when necessary.
1927       return failure();
1928 
1929     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
1930 
1931     auto res = compareBounds(lb, otherLb);
1932     // Identify min.
1933     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
1934       minLb = lb;
1935       // Since the divisor is for a floordiv, we need to convert to ceildiv,
1936       // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
1937       // div * i >= expr - div + 1.
1938       minLb.back() -= lbFloorDivisor - 1;
1939     } else if (res == BoundCmpResult::Greater) {
1940       minLb = otherLb;
1941       minLb.back() -= otherLbFloorDivisor - 1;
1942     } else {
1943       // Uncomparable - check for constant lower/upper bounds.
1944       auto constLb = getConstantBound(BoundType::LB, d);
1945       auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
1946       if (!constLb.hasValue() || !constOtherLb.hasValue())
1947         return failure();
1948       std::fill(minLb.begin(), minLb.end(), 0);
1949       minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
1950     }
1951 
1952     // Do the same for ub's but max of upper bounds. Identify max.
1953     auto uRes = compareBounds(ub, otherUb);
1954     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
1955       maxUb = ub;
1956     } else if (uRes == BoundCmpResult::Less) {
1957       maxUb = otherUb;
1958     } else {
1959       // Uncomparable - check for constant lower/upper bounds.
1960       auto constUb = getConstantBound(BoundType::UB, d);
1961       auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
1962       if (!constUb.hasValue() || !constOtherUb.hasValue())
1963         return failure();
1964       std::fill(maxUb.begin(), maxUb.end(), 0);
1965       maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
1966     }
1967 
1968     std::fill(newLb.begin(), newLb.end(), 0);
1969     std::fill(newUb.begin(), newUb.end(), 0);
1970 
1971     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
1972     // and so it's the divisor for newLb and newUb as well.
1973     newLb[d] = lbFloorDivisor;
1974     newUb[d] = -lbFloorDivisor;
1975     // Copy over the symbolic part + constant term.
1976     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
1977     std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
1978                    newLb.begin() + getNumDimIds(), std::negate<int64_t>());
1979     std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
1980 
1981     boundingLbs.push_back(newLb);
1982     boundingUbs.push_back(newUb);
1983   }
1984 
1985   // Clear all constraints and add the lower/upper bounds for the bounding box.
1986   clearConstraints();
1987   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
1988     addInequality(boundingLbs[d]);
1989     addInequality(boundingUbs[d]);
1990   }
1991 
1992   // Add the constraints that were common to both systems.
1993   append(commonCst);
1994   removeTrivialRedundancy();
1995 
1996   // TODO: copy over pure symbolic constraints from this and 'other' over to the
1997   // union (since the above are just the union along dimensions); we shouldn't
1998   // be discarding any other constraints on the symbols.
1999 
2000   return success();
2001 }
2002 
2003 bool IntegerRelation::isColZero(unsigned pos) const {
2004   unsigned rowPos;
2005   return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) &&
2006          !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos);
2007 }
2008 
2009 /// Find positions of inequalities and equalities that do not have a coefficient
2010 /// for [pos, pos + num) identifiers.
2011 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos,
2012                                       unsigned num,
2013                                       SmallVectorImpl<unsigned> &nbIneqIndices,
2014                                       SmallVectorImpl<unsigned> &nbEqIndices) {
2015   assert(pos < cst.getNumIds() && "invalid start position");
2016   assert(pos + num <= cst.getNumIds() && "invalid limit");
2017 
2018   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
2019     // The bounds are to be independent of [offset, offset + num) columns.
2020     unsigned c;
2021     for (c = pos; c < pos + num; ++c) {
2022       if (cst.atIneq(r, c) != 0)
2023         break;
2024     }
2025     if (c == pos + num)
2026       nbIneqIndices.push_back(r);
2027   }
2028 
2029   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2030     // The bounds are to be independent of [offset, offset + num) columns.
2031     unsigned c;
2032     for (c = pos; c < pos + num; ++c) {
2033       if (cst.atEq(r, c) != 0)
2034         break;
2035     }
2036     if (c == pos + num)
2037       nbEqIndices.push_back(r);
2038   }
2039 }
2040 
2041 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) {
2042   assert(pos + num <= getNumIds() && "invalid range");
2043 
2044   // Remove constraints that are independent of these identifiers.
2045   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
2046   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
2047 
2048   // Iterate in reverse so that indices don't have to be updated.
2049   // TODO: This method can be made more efficient (because removal of each
2050   // inequality leads to much shifting/copying in the underlying buffer).
2051   for (auto nbIndex : llvm::reverse(nbIneqIndices))
2052     removeInequality(nbIndex);
2053   for (auto nbIndex : llvm::reverse(nbEqIndices))
2054     removeEquality(nbIndex);
2055 }
2056 
2057 void IntegerRelation::printSpace(raw_ostream &os) const {
2058   PresburgerSpace::print(os);
2059   os << getNumConstraints() << " constraints\n";
2060 }
2061 
2062 void IntegerRelation::print(raw_ostream &os) const {
2063   assert(hasConsistentState());
2064   printSpace(os);
2065   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2066     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2067       os << atEq(i, j) << " ";
2068     }
2069     os << "= 0\n";
2070   }
2071   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2072     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2073       os << atIneq(i, j) << " ";
2074     }
2075     os << ">= 0\n";
2076   }
2077   os << '\n';
2078 }
2079 
2080 void IntegerRelation::dump() const { print(llvm::errs()); }
2081 
2082 unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) {
2083   assert((kind != IdKind::Domain || num == 0) &&
2084          "Domain has to be zero in a set");
2085   return IntegerRelation::insertId(kind, pos, num);
2086 }
2087