1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A class to represent a relation over integer tuples. A relation is
// represented as a constraint system over a space of tuples of integer valued
// variables supporting symbolic identifiers and existential quantification.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/Presburger/IntegerRelation.h"
16 #include "mlir/Analysis/Presburger/LinearTransform.h"
17 #include "mlir/Analysis/Presburger/PresburgerSet.h"
18 #include "mlir/Analysis/Presburger/Simplex.h"
19 #include "mlir/Analysis/Presburger/Utils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/DenseSet.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "presburger"
25 
26 using namespace mlir;
27 using namespace presburger;
28 
29 using llvm::SmallDenseMap;
30 using llvm::SmallDenseSet;
31 
32 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const {
33   return std::make_unique<IntegerRelation>(*this);
34 }
35 
36 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
37   return std::make_unique<IntegerPolyhedron>(*this);
38 }
39 
40 void IntegerRelation::append(const IntegerRelation &other) {
41   assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal.");
42 
43   inequalities.reserveRows(inequalities.getNumRows() +
44                            other.getNumInequalities());
45   equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
46 
47   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
48     addInequality(other.getInequality(r));
49   }
50   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
51     addEquality(other.getEquality(r));
52   }
53 }
54 
55 static IntegerPolyhedron createSetFromRelation(const IntegerRelation &rel) {
56   IntegerPolyhedron result(rel.getNumDimIds(), rel.getNumSymbolIds(),
57                            rel.getNumLocalIds());
58 
59   for (unsigned i = 0, e = rel.getNumInequalities(); i < e; ++i)
60     result.addInequality(rel.getInequality(i));
61   for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i)
62     result.addEquality(rel.getEquality(i));
63 
64   return result;
65 }
66 
67 bool IntegerRelation::isEqual(const IntegerRelation &other) const {
68   assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal.");
69   return PresburgerSet(createSetFromRelation(*this))
70       .isEqual(PresburgerSet(createSetFromRelation(other)));
71 }
72 
73 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const {
74   assert(PresburgerLocalSpace::isEqual(other) && "Spaces must be equal.");
75   return PresburgerSet(createSetFromRelation(*this))
76       .isSubsetOf(PresburgerSet(createSetFromRelation(other)));
77 }
78 
79 MaybeOptimum<SmallVector<Fraction, 8>>
80 IntegerRelation::findRationalLexMin() const {
81   assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
82   MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin =
83       LexSimplex(*this).findRationalLexMin();
84 
85   if (!maybeLexMin.isBounded())
86     return maybeLexMin;
87 
88   // The Simplex returns the lexmin over all the variables including locals. But
89   // locals are not actually part of the space and should not be returned in the
90   // result. Since the locals are placed last in the list of identifiers, they
91   // will be minimized last in the lexmin. So simply truncating out the locals
92   // from the end of the answer gives the desired lexmin over the dimensions.
93   assert(maybeLexMin->size() == getNumIds() &&
94          "Incorrect number of vars in lexMin!");
95   maybeLexMin->resize(getNumDimAndSymbolIds());
96   return maybeLexMin;
97 }
98 
99 MaybeOptimum<SmallVector<int64_t, 8>>
100 IntegerRelation::findIntegerLexMin() const {
101   assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
102   MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
103       LexSimplex(*this).findIntegerLexMin();
104 
105   if (!maybeLexMin.isBounded())
106     return maybeLexMin.getKind();
107 
108   // The Simplex returns the lexmin over all the variables including locals. But
109   // locals are not actually part of the space and should not be returned in the
110   // result. Since the locals are placed last in the list of identifiers, they
111   // will be minimized last in the lexmin. So simply truncating out the locals
112   // from the end of the answer gives the desired lexmin over the dimensions.
113   assert(maybeLexMin->size() == getNumIds() &&
114          "Incorrect number of vars in lexMin!");
115   maybeLexMin->resize(getNumDimAndSymbolIds());
116   return maybeLexMin;
117 }
118 
119 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
120   assert(pos <= getNumIdKind(kind));
121 
122   unsigned insertPos = PresburgerLocalSpace::insertId(kind, pos, num);
123   inequalities.insertColumns(insertPos, num);
124   equalities.insertColumns(insertPos, num);
125   return insertPos;
126 }
127 
128 unsigned IntegerRelation::appendId(IdKind kind, unsigned num) {
129   unsigned pos = getNumIdKind(kind);
130   return insertId(kind, pos, num);
131 }
132 
133 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) {
134   assert(eq.size() == getNumCols());
135   unsigned row = equalities.appendExtraRow();
136   for (unsigned i = 0, e = eq.size(); i < e; ++i)
137     equalities(row, i) = eq[i];
138 }
139 
140 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) {
141   assert(inEq.size() == getNumCols());
142   unsigned row = inequalities.appendExtraRow();
143   for (unsigned i = 0, e = inEq.size(); i < e; ++i)
144     inequalities(row, i) = inEq[i];
145 }
146 
147 void IntegerRelation::removeId(IdKind kind, unsigned pos) {
148   removeIdRange(kind, pos, pos + 1);
149 }
150 
151 void IntegerRelation::removeId(unsigned pos) { removeIdRange(pos, pos + 1); }
152 
153 void IntegerRelation::removeIdRange(IdKind kind, unsigned idStart,
154                                     unsigned idLimit) {
155   assert(idLimit <= getNumIdKind(kind));
156   removeIdRange(getIdKindOffset(kind) + idStart,
157                 getIdKindOffset(kind) + idLimit);
158 }
159 
160 void IntegerRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
161   // Update space paramaters.
162   PresburgerLocalSpace::removeIdRange(idStart, idLimit);
163 
164   // Remove eliminated identifiers from the constraints..
165   equalities.removeColumns(idStart, idLimit - idStart);
166   inequalities.removeColumns(idStart, idLimit - idStart);
167 }
168 
169 void IntegerRelation::removeEquality(unsigned pos) {
170   equalities.removeRow(pos);
171 }
172 
173 void IntegerRelation::removeInequality(unsigned pos) {
174   inequalities.removeRow(pos);
175 }
176 
177 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
178   if (start >= end)
179     return;
180   equalities.removeRows(start, end - start);
181 }
182 
183 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) {
184   if (start >= end)
185     return;
186   inequalities.removeRows(start, end - start);
187 }
188 
189 void IntegerRelation::swapId(unsigned posA, unsigned posB) {
190   assert(posA < getNumIds() && "invalid position A");
191   assert(posB < getNumIds() && "invalid position B");
192 
193   if (posA == posB)
194     return;
195 
196   inequalities.swapColumns(posA, posB);
197   equalities.swapColumns(posA, posB);
198 }
199 
200 void IntegerRelation::clearConstraints() {
201   equalities.resizeVertically(0);
202   inequalities.resizeVertically(0);
203 }
204 
205 /// Gather all lower and upper bounds of the identifier at `pos`, and
206 /// optionally any equalities on it. In addition, the bounds are to be
207 /// independent of identifiers in position range [`offset`, `offset` + `num`).
208 void IntegerRelation::getLowerAndUpperBoundIndices(
209     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
210     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
211     unsigned offset, unsigned num) const {
212   assert(pos < getNumIds() && "invalid position");
213   assert(offset + num < getNumCols() && "invalid range");
214 
215   // Checks for a constraint that has a non-zero coeff for the identifiers in
216   // the position range [offset, offset + num) while ignoring `pos`.
217   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
218     unsigned c, f;
219     auto cst = isEq ? getEquality(r) : getInequality(r);
220     for (c = offset, f = offset + num; c < f; ++c) {
221       if (c == pos)
222         continue;
223       if (cst[c] != 0)
224         break;
225     }
226     return c < f;
227   };
228 
229   // Gather all lower bounds and upper bounds of the variable. Since the
230   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
231   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
232   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
233     // The bounds are to be independent of [offset, offset + num) columns.
234     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
235       continue;
236     if (atIneq(r, pos) >= 1) {
237       // Lower bound.
238       lbIndices->push_back(r);
239     } else if (atIneq(r, pos) <= -1) {
240       // Upper bound.
241       ubIndices->push_back(r);
242     }
243   }
244 
245   // An equality is both a lower and upper bound. Record any equalities
246   // involving the pos^th identifier.
247   if (!eqIndices)
248     return;
249 
250   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
251     if (atEq(r, pos) == 0)
252       continue;
253     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
254       continue;
255     eqIndices->push_back(r);
256   }
257 }
258 
259 bool IntegerRelation::hasConsistentState() const {
260   if (!inequalities.hasConsistentState())
261     return false;
262   if (!equalities.hasConsistentState())
263     return false;
264   return true;
265 }
266 
267 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) {
268   if (values.empty())
269     return;
270   assert(pos + values.size() <= getNumIds() &&
271          "invalid position or too many values");
272   // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the
273   // constant term and removing the id x_j. We do this for all the ids
274   // pos, pos + 1, ... pos + values.size() - 1.
275   unsigned constantColPos = getNumCols() - 1;
276   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
277     inequalities.addToColumn(i + pos, constantColPos, values[i]);
278   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
279     equalities.addToColumn(i + pos, constantColPos, values[i]);
280   removeIdRange(pos, pos + values.size());
281 }
282 
283 void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
284   *this = other;
285 }
286 
287 // Searches for a constraint with a non-zero coefficient at `colIdx` in
288 // equality (isEq=true) or inequality (isEq=false) constraints.
289 // Returns true and sets row found in search in `rowIdx`, false otherwise.
290 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
291                                                   unsigned *rowIdx) const {
292   assert(colIdx < getNumCols() && "position out of bounds");
293   auto at = [&](unsigned rowIdx) -> int64_t {
294     return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
295   };
296   unsigned e = isEq ? getNumEqualities() : getNumInequalities();
297   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
298     if (at(*rowIdx) != 0) {
299       return true;
300     }
301   }
302   return false;
303 }
304 
305 void IntegerRelation::normalizeConstraintsByGCD() {
306   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
307     equalities.normalizeRow(i);
308   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
309     inequalities.normalizeRow(i);
310 }
311 
312 bool IntegerRelation::hasInvalidConstraint() const {
313   assert(hasConsistentState());
314   auto check = [&](bool isEq) -> bool {
315     unsigned numCols = getNumCols();
316     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
317     for (unsigned i = 0, e = numRows; i < e; ++i) {
318       unsigned j;
319       for (j = 0; j < numCols - 1; ++j) {
320         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
321         // Skip rows with non-zero variable coefficients.
322         if (v != 0)
323           break;
324       }
325       if (j < numCols - 1) {
326         continue;
327       }
328       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
329       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
330       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
331       if ((isEq && v != 0) || (!isEq && v < 0)) {
332         return true;
333       }
334     }
335     return false;
336   };
337   if (check(/*isEq=*/true))
338     return true;
339   return check(/*isEq=*/false);
340 }
341 
/// Eliminates the identifier at column `pivotCol` from constraint `rowIdx`
/// (an equality if `isEq`, otherwise an inequality) using the *equality* at
/// `pivotRow`: the row is replaced by
///   pivotMultiplier * pivotRow + rowMultiplier * row,
/// with multipliers chosen so that the coefficient at `pivotCol` becomes zero.
/// Columns in range [elimColStart, pivotCol) are not updated as they have
/// already been eliminated.
static void eliminateFromConstraint(IntegerRelation *constraints,
                                    unsigned rowIdx, unsigned pivotRow,
                                    unsigned pivotCol, unsigned elimColStart,
                                    bool isEq) {
  // Skip if equality 'rowIdx' is the same as 'pivotRow'.
  if (isEq && rowIdx == pivotRow)
    return;
  auto at = [&](unsigned i, unsigned j) -> int64_t {
    return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
  };
  int64_t leadCoeff = at(rowIdx, pivotCol);
  // Skip if leading coefficient at 'rowIdx' is already zero.
  if (leadCoeff == 0)
    return;
  // The pivot row is always read from the equalities.
  int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
  // Choose the sign so that the two scaled coefficients at 'pivotCol' cancel:
  // negate the pivot contribution when both coefficients have the same sign.
  int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
  int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
  int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
  // NOTE(review): assuming mlir::lcm returns a positive value, rowMultiplier is
  // positive, so an inequality row keeps its direction after scaling. Only the
  // pivot (an equality) may be scaled by a negative factor.
  int64_t rowMultiplier = lcm / std::abs(leadCoeff);

  unsigned numCols = constraints->getNumCols();
  for (unsigned j = 0; j < numCols; ++j) {
    // Skip updating column 'j' if it was just eliminated.
    if (j >= elimColStart && j < pivotCol)
      continue;
    int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
                rowMultiplier * at(rowIdx, j);
    isEq ? constraints->atEq(rowIdx, j) = v
         : constraints->atIneq(rowIdx, j) = v;
  }
}
376 
377 /// Returns the position of the identifier that has the minimum <number of lower
378 /// bounds> times <number of upper bounds> from the specified range of
379 /// identifiers [start, end). It is often best to eliminate in the increasing
380 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
381 /// that many new constraints.
382 static unsigned getBestIdToEliminate(const IntegerRelation &cst, unsigned start,
383                                      unsigned end) {
384   assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
385 
386   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
387     unsigned numLb = 0;
388     unsigned numUb = 0;
389     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
390       if (cst.atIneq(r, pos) > 0) {
391         ++numLb;
392       } else if (cst.atIneq(r, pos) < 0) {
393         ++numUb;
394       }
395     }
396     return numLb * numUb;
397   };
398 
399   unsigned minLoc = start;
400   unsigned min = getProductOfNumLowerUpperBounds(start);
401   for (unsigned c = start + 1; c < end; c++) {
402     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
403     if (numLbUbProduct < min) {
404       min = numLbUbProduct;
405       minLoc = c;
406     }
407   }
408   return minLoc;
409 }
410 
411 // Checks for emptiness of the set by eliminating identifiers successively and
412 // using the GCD test (on all equality constraints) and checking for trivially
413 // invalid constraints. Returns 'true' if the constraint system is found to be
414 // empty; false otherwise.
415 bool IntegerRelation::isEmpty() const {
416   if (isEmptyByGCDTest() || hasInvalidConstraint())
417     return true;
418 
419   IntegerRelation tmpCst(*this);
420 
421   // First, eliminate as many local variables as possible using equalities.
422   tmpCst.removeRedundantLocalVars();
423   if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
424     return true;
425 
426   // Eliminate as many identifiers as possible using Gaussian elimination.
427   unsigned currentPos = 0;
428   while (currentPos < tmpCst.getNumIds()) {
429     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
430     ++currentPos;
431     // We check emptiness through trivial checks after eliminating each ID to
432     // detect emptiness early. Since the checks isEmptyByGCDTest() and
433     // hasInvalidConstraint() are linear time and single sweep on the constraint
434     // buffer, this appears reasonable - but can optimize in the future.
435     if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
436       return true;
437   }
438 
439   // Eliminate the remaining using FM.
440   for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
441     tmpCst.fourierMotzkinEliminate(
442         getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
443     // Check for a constraint explosion. This rarely happens in practice, but
444     // this check exists as a safeguard against improperly constructed
445     // constraint systems or artificially created arbitrarily complex systems
446     // that aren't the intended use case for IntegerRelation. This is
447     // needed since FM has a worst case exponential complexity in theory.
448     if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
449       LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
450       return false;
451     }
452 
453     // FM wouldn't have modified the equalities in any way. So no need to again
454     // run GCD test. Check for trivial invalid constraints.
455     if (tmpCst.hasInvalidConstraint())
456       return true;
457   }
458   return false;
459 }
460 
461 // Runs the GCD test on all equality constraints. Returns 'true' if this test
462 // fails on any equality. Returns 'false' otherwise.
463 // This test can be used to disprove the existence of a solution. If it returns
464 // true, no integer solution to the equality constraints can exist.
465 //
466 // GCD test definition:
467 //
468 // The equality constraint:
469 //
470 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
471 //
472 // has an integer solution iff:
473 //
474 //  GCD of c_1, c_2, ..., c_n divides c_0.
475 //
476 bool IntegerRelation::isEmptyByGCDTest() const {
477   assert(hasConsistentState());
478   unsigned numCols = getNumCols();
479   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
480     uint64_t gcd = std::abs(atEq(i, 0));
481     for (unsigned j = 1; j < numCols - 1; ++j) {
482       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
483     }
484     int64_t v = std::abs(atEq(i, numCols - 1));
485     if (gcd > 0 && (v % gcd != 0)) {
486       return true;
487     }
488   }
489   return false;
490 }
491 
492 // Returns a matrix where each row is a vector along which the polytope is
493 // bounded. The span of the returned vectors is guaranteed to contain all
494 // such vectors. The returned vectors are NOT guaranteed to be linearly
495 // independent. This function should not be called on empty sets.
496 //
497 // It is sufficient to check the perpendiculars of the constraints, as the set
498 // of perpendiculars which are bounded must span all bounded directions.
499 Matrix IntegerRelation::getBoundedDirections() const {
500   // Note that it is necessary to add the equalities too (which the constructor
501   // does) even though we don't need to check if they are bounded; whether an
502   // inequality is bounded or not depends on what other constraints, including
503   // equalities, are present.
504   Simplex simplex(*this);
505 
506   assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
507                                "direction is bounded in an empty set.");
508 
509   SmallVector<unsigned, 8> boundedIneqs;
510   // The constructor adds the inequalities to the simplex first, so this
511   // processes all the inequalities.
512   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
513     if (simplex.isBoundedAlongConstraint(i))
514       boundedIneqs.push_back(i);
515   }
516 
517   // The direction vector is given by the coefficients and does not include the
518   // constant term, so the matrix has one fewer column.
519   unsigned dirsNumCols = getNumCols() - 1;
520   Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
521 
522   // Copy the bounded inequalities.
523   unsigned row = 0;
524   for (unsigned i : boundedIneqs) {
525     for (unsigned col = 0; col < dirsNumCols; ++col)
526       dirs(row, col) = atIneq(i, col);
527     ++row;
528   }
529 
530   // Copy the equalities. All the equalities' perpendiculars are bounded.
531   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
532     for (unsigned col = 0; col < dirsNumCols; ++col)
533       dirs(row, col) = atEq(i, col);
534     ++row;
535   }
536 
537   return dirs;
538 }
539 
540 bool eqInvolvesSuffixDims(const IntegerRelation &rel, unsigned eqIndex,
541                           unsigned numDims) {
542   for (unsigned e = rel.getNumIds(), j = e - numDims; j < e; ++j)
543     if (rel.atEq(eqIndex, j) != 0)
544       return true;
545   return false;
546 }
547 bool ineqInvolvesSuffixDims(const IntegerRelation &rel, unsigned ineqIndex,
548                             unsigned numDims) {
549   for (unsigned e = rel.getNumIds(), j = e - numDims; j < e; ++j)
550     if (rel.atIneq(ineqIndex, j) != 0)
551       return true;
552   return false;
553 }
554 
555 void removeConstraintsInvolvingSuffixDims(IntegerRelation &rel,
556                                           unsigned unboundedDims) {
557   // We iterate backwards so that whether we remove constraint i - 1 or not, the
558   // next constraint to be tested is always i - 2.
559   for (unsigned i = rel.getNumEqualities(); i > 0; i--)
560     if (eqInvolvesSuffixDims(rel, i - 1, unboundedDims))
561       rel.removeEquality(i - 1);
562   for (unsigned i = rel.getNumInequalities(); i > 0; i--)
563     if (ineqInvolvesSuffixDims(rel, i - 1, unboundedDims))
564       rel.removeInequality(i - 1);
565 }
566 
567 bool IntegerRelation::isIntegerEmpty() const {
568   return !findIntegerSample().hasValue();
569 }
570 
571 /// Let this set be S. If S is bounded then we directly call into the GBR
572 /// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
573 /// vectors v such that S extends to infinity along v or -v. In this case we
574 /// use an algorithm described in the integer set library (isl) manual and used
575 /// by the isl_set_sample function in that library. The algorithm is:
576 ///
577 /// 1) Apply a unimodular transform T to S to obtain S*T, such that all
578 /// dimensions in which S*T is bounded lie in the linear span of a prefix of the
579 /// dimensions.
580 ///
581 /// 2) Construct a set B by removing all constraints that involve
582 /// the unbounded dimensions and then deleting the unbounded dimensions. Note
583 /// that B is a Bounded set.
584 ///
585 /// 3) Try to obtain a sample from B using the GBR sampling
586 /// algorithm. If no sample is found, return that S is empty.
587 ///
588 /// 4) Otherwise, substitute the obtained sample into S*T to obtain a set
589 /// C. C is a full-dimensional Cone and always contains a sample.
590 ///
591 /// 5) Obtain an integer sample from C.
592 ///
593 /// 6) Return T*v, where v is the concatenation of the samples from B and C.
594 ///
595 /// The following is a sketch of a proof that
596 /// a) If the algorithm returns empty, then S is empty.
597 /// b) If the algorithm returns a sample, it is a valid sample in S.
598 ///
599 /// The algorithm returns empty only if B is empty, in which case S*T is
600 /// certainly empty since B was obtained by removing constraints and then
601 /// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector
602 /// v is in S*T iff T*v is in S. So in this case, since
603 /// S*T is empty, S is empty too.
604 ///
605 /// Otherwise, the algorithm substitutes the sample from B into S*T. All the
606 /// constraints of S*T that did not involve unbounded dimensions are satisfied
607 /// by this substitution. All dimensions in the linear span of the dimensions
608 /// outside the prefix are unbounded in S*T (step 1). Substituting values for
609 /// the bounded dimensions cannot make these dimensions bounded, and these are
610 /// the only remaining dimensions in C, so C is unbounded along every vector (in
611 /// the positive or negative direction, or both). C is hence a full-dimensional
612 /// cone and therefore always contains an integer point.
613 ///
614 /// Concatenating the samples from B and C gives a sample v in S*T, so the
615 /// returned sample T*v is a sample in S.
Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
  // First, try the GCD test heuristic.
  if (isEmptyByGCDTest())
    return {};

  Simplex simplex(*this);
  if (simplex.isEmpty())
    return {};

  // For a bounded set, we directly call into the GBR sampling algorithm.
  if (!simplex.isUnbounded())
    return simplex.findIntegerSample();

  // The set is unbounded. We cannot directly use the GBR algorithm.
  //
  // m is a matrix containing, in each row, a vector in which S is
  // bounded, such that the linear span of all these dimensions contains all
  // bounded dimensions in S.
  Matrix m = getBoundedDirections();
  // In column echelon form, each row of m occupies only the first rank(m)
  // columns and has zeros on the other columns. The transform T that brings S
  // to column echelon form is unimodular as well, so this is a suitable
  // transform to use in step 1 of the algorithm.
  // `result.first` is used below as the number of bounded dimensions and
  // `result.second` as the transform T.
  std::pair<unsigned, LinearTransform> result =
      LinearTransform::makeTransformToColumnEchelon(std::move(m));
  const LinearTransform &transform = result.second;
  // 1) Apply T to S to obtain S*T.
  IntegerRelation transformedSet = transform.applyTo(*this);

  // 2) Remove the unbounded dimensions and constraints involving them to
  // obtain a bounded set. The bounded dimensions are the prefix
  // [0, numBoundedDims) and the unbounded ones are the remaining suffix.
  IntegerRelation boundedSet(transformedSet);
  unsigned numBoundedDims = result.first;
  unsigned numUnboundedDims = getNumIds() - numBoundedDims;
  removeConstraintsInvolvingSuffixDims(boundedSet, numUnboundedDims);
  boundedSet.removeIdRange(numBoundedDims, boundedSet.getNumIds());

  // 3) Try to obtain a sample from the bounded set.
  Optional<SmallVector<int64_t, 8>> boundedSample =
      Simplex(boundedSet).findIntegerSample();
  if (!boundedSample)
    return {};
  assert(boundedSet.containsPoint(*boundedSample) &&
         "Simplex returned an invalid sample!");

  // 4) Substitute the values of the bounded dimensions into S*T to obtain a
  // full-dimensional cone, which necessarily contains an integer sample.
  // The bounded dimensions start at position 0, so the sample is substituted
  // there; `cone` aliases `transformedSet`, which is modified in place.
  transformedSet.setAndEliminate(0, *boundedSample);
  IntegerRelation &cone = transformedSet;

  // 5) Obtain an integer sample from the cone.
  //
  // We shrink the cone such that for any rational point in the shrunken cone,
  // rounding up each of the point's coordinates produces a point that still
  // lies in the original cone.
  //
  // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i.
  // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the
  // shrunken cone will have the inequality tightened by some amount s, such
  // that if x satisfies the shrunken cone's tightened inequality, then x + e
  // satisfies the original inequality, i.e.,
  //
  // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0
  //
  // for any e_i values in [0, 1). In fact, we will handle the slightly more
  // general case where e_i can be in [0, 1]. For example, consider the
  // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low
  // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS
  // is minimized when we add 1 to the x_i with negative coefficient a_i and
  // keep the other x_i the same. In the example, we would get x = (3, 1, 1),
  // changing the value of the LHS by -3 + -7 = -10.
  //
  // In general, the value of the LHS can change by at most the sum of the
  // negative a_i, so we accommodate this by shifting the inequality by this
  // amount for the shrunken cone.
  //
  // Column cone.getNumIds() is the last column, i.e. the constant term.
  for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
    for (unsigned j = 0; j < cone.getNumIds(); ++j) {
      int64_t coeff = cone.atIneq(i, j);
      if (coeff < 0)
        cone.atIneq(i, cone.getNumIds()) += coeff;
    }
  }

  // Obtain an integer sample in the cone by rounding up a rational point from
  // the shrunken cone. Shrinking the cone amounts to shifting its apex
  // "inwards" without changing its "shape"; the shrunken cone is still a
  // full-dimensional cone and is hence non-empty.
  Simplex shrunkenConeSimplex(cone);
  assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");

  // The sample will always exist since the shrunken cone is non-empty.
  SmallVector<Fraction, 8> shrunkenConeSample =
      *shrunkenConeSimplex.getRationalSample();

  // Round every rational coordinate up to obtain an integer point in the
  // original cone.
  SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));

  // 6) Return transform * concat(boundedSample, coneSample).
  SmallVector<int64_t, 8> &sample = boundedSample.getValue();
  sample.append(coneSample.begin(), coneSample.end());
  return transform.postMultiplyWithColumn(sample);
}
717 
718 /// Helper to evaluate an affine expression at a point.
719 /// The expression is a list of coefficients for the dimensions followed by the
720 /// constant term.
721 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
722   assert(expr.size() == 1 + point.size() &&
723          "Dimensionalities of point and expression don't match!");
724   int64_t value = expr.back();
725   for (unsigned i = 0; i < point.size(); ++i)
726     value += expr[i] * point[i];
727   return value;
728 }
729 
730 /// A point satisfies an equality iff the value of the equality at the
731 /// expression is zero, and it satisfies an inequality iff the value of the
732 /// inequality at that point is non-negative.
733 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
734   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
735     if (valueAt(getEquality(i), point) != 0)
736       return false;
737   }
738   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
739     if (valueAt(getInequality(i), point) < 0)
740       return false;
741   }
742   return true;
743 }
744 
745 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
746   std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
747   SmallVector<unsigned, 4> denominators(getNumLocalIds());
748   getLocalReprs(dividends, denominators, repr);
749 }
750 
751 void IntegerRelation::getLocalReprs(
752     std::vector<SmallVector<int64_t, 8>> &dividends,
753     SmallVector<unsigned, 4> &denominators) const {
754   std::vector<MaybeLocalRepr> repr(getNumLocalIds());
755   getLocalReprs(dividends, denominators, repr);
756 }
757 
758 void IntegerRelation::getLocalReprs(
759     std::vector<SmallVector<int64_t, 8>> &dividends,
760     SmallVector<unsigned, 4> &denominators,
761     std::vector<MaybeLocalRepr> &repr) const {
762 
763   repr.resize(getNumLocalIds());
764   dividends.resize(getNumLocalIds());
765   denominators.resize(getNumLocalIds());
766 
767   SmallVector<bool, 8> foundRepr(getNumIds(), false);
768   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i)
769     foundRepr[i] = true;
770 
771   unsigned divOffset = getNumDimAndSymbolIds();
772   bool changed;
773   do {
774     // Each time changed is true, at end of this iteration, one or more local
775     // vars have been detected as floor divs.
776     changed = false;
777     for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
778       if (!foundRepr[i + divOffset]) {
779         MaybeLocalRepr res = computeSingleVarRepr(
780             *this, foundRepr, divOffset + i, dividends[i], denominators[i]);
781         if (!res)
782           continue;
783         foundRepr[i + divOffset] = true;
784         repr[i] = res;
785         changed = true;
786       }
787     }
788   } while (changed);
789 
790   // Set 0 denominator for identifiers for which no division representation
791   // could be found.
792   for (unsigned i = 0, e = repr.size(); i < e; ++i)
793     if (!repr[i])
794       denominators[i] = 0;
795 }
796 
797 /// Tightens inequalities given that we are dealing with integer spaces. This is
798 /// analogous to the GCD test but applied to inequalities. The constant term can
799 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
800 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
801 /// fast method - linear in the number of coefficients.
802 // Example on how this affects practical cases: consider the scenario:
803 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
804 // j >= 100 instead of the tighter (exact) j >= 128.
805 void IntegerRelation::gcdTightenInequalities() {
806   unsigned numCols = getNumCols();
807   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
808     // Normalize the constraint and tighten the constant term by the GCD.
809     uint64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1);
810     if (gcd > 1)
811       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd);
812   }
813 }
814 
// Eliminates all identifier variables in column range [posStart, posLimit)
// using Gaussian elimination over the equalities.
// Returns the number of variables eliminated.
//
// Columns are processed left to right:
// - A column with a non-zero coefficient in some equality is eliminated from
//   every constraint using that equality, which is then removed.
// - A column that is zero in every constraint is trivially eliminable.
// - Processing stops at the first column that appears only in inequalities.
//   Only the columns in [posStart, pivotCol) are removed, so that column and
//   every column after it are kept.
unsigned IntegerRelation::gaussianEliminateIds(unsigned posStart,
                                               unsigned posLimit) {
  // Return if identifier positions to eliminate are out of range.
  assert(posLimit <= getNumIds());
  assert(hasConsistentState());

  if (posStart >= posLimit)
    return 0;

  // Tighten inequality constant terms by the GCD of their coefficients before
  // eliminating (and again after each elimination step below).
  gcdTightenInequalities();

  unsigned pivotCol = 0;
  for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
    // Find a row which has a non-zero coefficient in column 'pivotCol'.
    unsigned pivotRow;
    if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) {
      // No pivot row in equalities with non-zero at 'pivotCol'.
      if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) {
        // If inequalities are also zero in 'pivotCol', it can be
        // eliminated.
        continue;
      }
      // The identifier appears only in inequalities, so Gaussian elimination
      // cannot remove it; stop and keep this and all later columns.
      break;
    }

    // Eliminate identifier at 'pivotCol' from each equality row.
    for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/true);
      equalities.normalizeRow(i);
    }

    // Eliminate identifier at 'pivotCol' from each inequality row.
    for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
      eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
                              /*isEq=*/false);
      inequalities.normalizeRow(i);
    }
    // The pivot equality has been used up; drop it.
    removeEquality(pivotRow);
    gcdTightenInequalities();
  }
  // Update position limit based on number eliminated.
  posLimit = pivotCol;
  // Remove eliminated columns from all constraints.
  removeIdRange(posStart, posLimit);
  return posLimit - posStart;
}
864 
865 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
866 // to check if a constraint is redundant.
867 void IntegerRelation::removeRedundantInequalities() {
868   SmallVector<bool, 32> redun(getNumInequalities(), false);
869   // To check if an inequality is redundant, we replace the inequality by its
870   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
871   // system is empty. If it is, the inequality is redundant.
872   IntegerRelation tmpCst(*this);
873   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
874     // Change the inequality to its complement.
875     tmpCst.inequalities.negateRow(r);
876     tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
877     if (tmpCst.isEmpty()) {
878       redun[r] = true;
879       // Zero fill the redundant inequality.
880       inequalities.fillRow(r, /*value=*/0);
881       tmpCst.inequalities.fillRow(r, /*value=*/0);
882     } else {
883       // Reverse the change (to avoid recreating tmpCst each time).
884       tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
885       tmpCst.inequalities.negateRow(r);
886     }
887   }
888 
889   unsigned pos = 0;
890   for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
891     if (!redun[r])
892       inequalities.copyRow(r, pos++);
893   }
894   inequalities.resizeVertically(pos);
895 }
896 
897 // A more complex check to eliminate redundant inequalities and equalities. Uses
898 // Simplex to check if a constraint is redundant.
899 void IntegerRelation::removeRedundantConstraints() {
900   // First, we run gcdTightenInequalities. This allows us to catch some
901   // constraints which are not redundant when considering rational solutions
902   // but are redundant in terms of integer solutions.
903   gcdTightenInequalities();
904   Simplex simplex(*this);
905   simplex.detectRedundant();
906 
907   unsigned pos = 0;
908   unsigned numIneqs = getNumInequalities();
909   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
910   // the first constraints added are the inequalities.
911   for (unsigned r = 0; r < numIneqs; r++) {
912     if (!simplex.isMarkedRedundant(r))
913       inequalities.copyRow(r, pos++);
914   }
915   inequalities.resizeVertically(pos);
916 
917   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
918   // after the inequalities, a pair of constraints for each equality is added.
919   // An equality is redundant if both the inequalities in its pair are
920   // redundant.
921   pos = 0;
922   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
923     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
924           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
925       equalities.copyRow(r, pos++);
926   }
927   equalities.resizeVertically(pos);
928 }
929 
930 Optional<uint64_t> IntegerRelation::computeVolume() const {
931   assert(getNumSymbolIds() == 0 && "Symbols are not yet supported!");
932 
933   Simplex simplex(*this);
934   // If the polytope is rationally empty, there are certainly no integer
935   // points.
936   if (simplex.isEmpty())
937     return 0;
938 
939   // Just find the maximum and minimum integer value of each non-local id
940   // separately, thus finding the number of integer values each such id can
941   // take. Multiplying these together gives a valid overapproximation of the
942   // number of integer points in the relation. The result this gives is
943   // equivalent to projecting (rationally) the relation onto its non-local ids
944   // and returning the number of integer points in a minimal axis-parallel
945   // hyperrectangular overapproximation of that.
946   //
947   // We also handle the special case where one dimension is unbounded and
948   // another dimension can take no integer values. In this case, the volume is
949   // zero.
950   //
951   // If there is no such empty dimension, if any dimension is unbounded we
952   // just return the result as unbounded.
953   uint64_t count = 1;
954   SmallVector<int64_t, 8> dim(getNumIds() + 1);
955   bool hasUnboundedId = false;
956   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) {
957     dim[i] = 1;
958     MaybeOptimum<int64_t> min, max;
959     std::tie(min, max) = simplex.computeIntegerBounds(dim);
960     dim[i] = 0;
961 
962     assert((!min.isEmpty() && !max.isEmpty()) &&
963            "Polytope should be rationally non-empty!");
964 
965     // One of the dimensions is unbounded. Note this fact. We will return
966     // unbounded if none of the other dimensions makes the volume zero.
967     if (min.isUnbounded() || max.isUnbounded()) {
968       hasUnboundedId = true;
969       continue;
970     }
971 
972     // In this case there are no valid integer points and the volume is
973     // definitely zero.
974     if (min.getBoundedOptimum() > max.getBoundedOptimum())
975       return 0;
976 
977     count *= (*max - *min + 1);
978   }
979 
980   if (count == 0)
981     return 0;
982   if (hasUnboundedId)
983     return {};
984   return count;
985 }
986 
987 void IntegerRelation::eliminateRedundantLocalId(unsigned posA, unsigned posB) {
988   assert(posA < getNumLocalIds() && "Invalid local id position");
989   assert(posB < getNumLocalIds() && "Invalid local id position");
990 
991   unsigned localOffset = getIdKindOffset(IdKind::Local);
992   posA += localOffset;
993   posB += localOffset;
994   inequalities.addToColumn(posB, posA, 1);
995   equalities.addToColumn(posB, posA, 1);
996   removeId(posB);
997 }
998 
/// Adds additional local ids to the sets such that they both have the union
/// of the local ids in each set, without changing the set of points that
/// lie in `this` and `other`.
///
/// To detect local ids that always take the same value in both sets, each
/// local id is represented as a floordiv with constant denominator in terms of
/// other ids. After extracting these divisions, local ids with the same
/// division representation are considered duplicate and are merged. It is
/// possible that a division representation for some local id cannot be
/// obtained; such local ids are not considered when detecting duplicates.
///
/// After this call, `this` and `other` have the same number of local ids, with
/// local i of one corresponding to local i of the other.
void IntegerRelation::mergeLocalIds(IntegerRelation &other) {
  assert(PresburgerSpace::isEqual(other) && "Spaces should match.");

  IntegerRelation &relA = *this;
  IntegerRelation &relB = other;

  // Merge local ids of relA and relB without using division information,
  // i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
  // to `relB` at start of its local ids. (The unqualified `insertId` below acts
  // on `*this`, i.e. relA.) Afterwards both have layout [A's locals, B's
  // locals].
  unsigned initLocals = relA.getNumLocalIds();
  insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
  relB.insertId(IdKind::Local, 0, initLocals);

  // Get division representations from each rel.
  std::vector<SmallVector<int64_t, 8>> divsA, divsB;
  SmallVector<unsigned, 4> denomsA, denomsB;
  relA.getLocalReprs(divsA, denomsA);
  relB.getLocalReprs(divsB, denomsB);

  // Copy division information for relB into `divsA` and `denomsA`, so that
  // these have the combined division information of both rels. Since newly
  // added local variables in relA and relB have no constraints, they will not
  // have any division representation.
  std::copy(divsB.begin() + initLocals, divsB.end(),
            divsA.begin() + initLocals);
  std::copy(denomsB.begin() + initLocals, denomsB.end(),
            denomsA.begin() + initLocals);

  // Merge function that merges the local variables in both sets by treating
  // them as the same identifier. Applying the same elimination to both rels
  // keeps their local id lists aligned.
  auto merge = [&relA, &relB](unsigned i, unsigned j) -> bool {
    relA.eliminateRedundantLocalId(i, j);
    relB.eliminateRedundantLocalId(i, j);
    return true;
  };

  // Merge all divisions by removing duplicate divisions.
  // NOTE(review): presumably removeDuplicateDivs invokes `merge` for each
  // pair of locals whose (dividend, denominator) match; confirm in Utils.
  unsigned localOffset = getIdKindOffset(IdKind::Local);
  removeDuplicateDivs(divsA, denomsA, localOffset, merge);
}
1049 
/// Removes local variables using equalities. Each equality is checked to see
/// whether it can be reduced to the form `e = affine-expr`, where `e` is a
/// local variable and `affine-expr` is an affine expression not containing
/// `e`. Concretely, this means `e` has a coefficient of +1 or -1 in the
/// equality. If an equality satisfies this form, the local variable is
/// replaced in each constraint and then removed. The equality used to replace
/// this local variable is also removed. This repeats until no equality can
/// eliminate a local variable.
void IntegerRelation::removeRedundantLocalVars() {
  // Normalize the equality constraints to reduce coefficients of local
  // variables to 1 wherever possible.
  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
    equalities.normalizeRow(i);

  while (true) {
    // After this search, `i` is the usable equality and `j` the column of the
    // local variable it eliminates (if any).
    unsigned i, e, j, f;
    for (i = 0, e = getNumEqualities(); i < e; ++i) {
      // Find a local variable to eliminate using ith equality.
      for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j)
        if (std::abs(atEq(i, j)) == 1)
          break;

      // Local variable can be eliminated using ith equality.
      if (j < f)
        break;
    }

    // No equality can be used to eliminate a local variable.
    if (i == e)
      break;

    // Use the ith equality to simplify other equalities. If any changes
    // are made to an equality constraint, it is normalized by GCD.
    for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
      if (atEq(k, j) != 0) {
        eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
        equalities.normalizeRow(k);
      }
    }

    // Use the ith equality to simplify inequalities.
    for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
      eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);

    // Remove the ith equality and the found local variable. Both removals
    // shift positions, so the search above restarts from scratch.
    removeId(j);
    removeEquality(i);
  }
}
1097 
1098 void IntegerRelation::convertDimToLocal(unsigned dimStart, unsigned dimLimit) {
1099   assert(dimLimit <= getNumDimIds() && "Invalid dim pos range");
1100 
1101   if (dimStart >= dimLimit)
1102     return;
1103 
1104   // Append new local variables corresponding to the dimensions to be converted.
1105   unsigned convertCount = dimLimit - dimStart;
1106   unsigned newLocalIdStart = getNumIds();
1107   appendId(IdKind::Local, convertCount);
1108 
1109   // Swap the new local variables with dimensions.
1110   for (unsigned i = 0; i < convertCount; ++i)
1111     swapId(i + dimStart, i + newLocalIdStart);
1112 
1113   // Remove dimensions converted to local variables.
1114   removeIdRange(dimStart, dimLimit);
1115 }
1116 
1117 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
1118   assert(pos < getNumCols());
1119   if (type == BoundType::EQ) {
1120     unsigned row = equalities.appendExtraRow();
1121     equalities(row, pos) = 1;
1122     equalities(row, getNumCols() - 1) = -value;
1123   } else {
1124     unsigned row = inequalities.appendExtraRow();
1125     inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
1126     inequalities(row, getNumCols() - 1) =
1127         type == BoundType::LB ? -value : value;
1128   }
1129 }
1130 
1131 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr,
1132                                int64_t value) {
1133   assert(type != BoundType::EQ && "EQ not implemented");
1134   assert(expr.size() == getNumCols());
1135   unsigned row = inequalities.appendExtraRow();
1136   for (unsigned i = 0, e = expr.size(); i < e; ++i)
1137     inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
1138   inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
1139       type == BoundType::LB ? -value : value;
1140 }
1141 
1142 /// Adds a new local identifier as the floordiv of an affine function of other
1143 /// identifiers, the coefficients of which are provided in 'dividend' and with
1144 /// respect to a positive constant 'divisor'. Two constraints are added to the
1145 /// system to capture equivalence with the floordiv.
1146 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
1147 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1148                                        int64_t divisor) {
1149   assert(dividend.size() == getNumCols() && "incorrect dividend size");
1150   assert(divisor > 0 && "positive divisor expected");
1151 
1152   appendId(IdKind::Local);
1153 
1154   // Add two constraints for this new identifier 'q'.
1155   SmallVector<int64_t, 8> bound(dividend.size() + 1);
1156 
1157   // dividend - q * divisor >= 0
1158   std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
1159             bound.begin());
1160   bound.back() = dividend.back();
1161   bound[getNumIds() - 1] = -divisor;
1162   addInequality(bound);
1163 
1164   // -dividend +qdivisor * q + divisor - 1 >= 0
1165   std::transform(bound.begin(), bound.end(), bound.begin(),
1166                  std::negate<int64_t>());
1167   bound[bound.size() - 1] += divisor - 1;
1168   addInequality(bound);
1169 }
1170 
1171 /// Finds an equality that equates the specified identifier to a constant.
1172 /// Returns the position of the equality row. If 'symbolic' is set to true,
1173 /// symbols are also treated like a constant, i.e., an affine function of the
1174 /// symbols is also treated like a constant. Returns -1 if such an equality
1175 /// could not be found.
1176 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
1177                                   bool symbolic = false) {
1178   assert(pos < cst.getNumIds() && "invalid position");
1179   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1180     int64_t v = cst.atEq(r, pos);
1181     if (v * v != 1)
1182       continue;
1183     unsigned c;
1184     unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
1185     // This checks for zeros in all positions other than 'pos' in [0, f)
1186     for (c = 0; c < f; c++) {
1187       if (c == pos)
1188         continue;
1189       if (cst.atEq(r, c) != 0) {
1190         // Dependent on another identifier.
1191         break;
1192       }
1193     }
1194     if (c == f)
1195       // Equality is free of other identifiers.
1196       return r;
1197   }
1198   return -1;
1199 }
1200 
1201 LogicalResult IntegerRelation::constantFoldId(unsigned pos) {
1202   assert(pos < getNumIds() && "invalid position");
1203   int rowIdx;
1204   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
1205     return failure();
1206 
1207   // atEq(rowIdx, pos) is either -1 or 1.
1208   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
1209   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
1210   setAndEliminate(pos, constVal);
1211   return success();
1212 }
1213 
1214 void IntegerRelation::constantFoldIdRange(unsigned pos, unsigned num) {
1215   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
1216     if (failed(constantFoldId(t)))
1217       t++;
1218   }
1219 }
1220 
1221 /// Returns a non-negative constant bound on the extent (upper bound - lower
1222 /// bound) of the specified identifier if it is found to be a constant; returns
1223 /// None if it's not a constant. This methods treats symbolic identifiers
1224 /// specially, i.e., it looks for constant differences between affine
1225 /// expressions involving only the symbolic identifiers. See comments at
1226 /// function definition for example. 'lb', if provided, is set to the lower
1227 /// bound associated with the constant difference. Note that 'lb' is purely
1228 /// symbolic and thus will contain the coefficients of the symbolic identifiers
1229 /// and the constant coefficient.
1230 //  Egs: 0 <= i <= 15, return 16.
1231 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
1232 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
1233 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
1234 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
1235 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
1236     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
1237     SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
1238     unsigned *minUbPos) const {
1239   assert(pos < getNumDimIds() && "Invalid identifier position");
1240 
1241   // Find an equality for 'pos'^th identifier that equates it to some function
1242   // of the symbolic identifiers (+ constant).
1243   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
1244   if (eqPos != -1) {
1245     auto eq = getEquality(eqPos);
1246     // If the equality involves a local var, punt for now.
1247     // TODO: this can be handled in the future by using the explicit
1248     // representation of the local vars.
1249     if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
1250                      [](int64_t coeff) { return coeff == 0; }))
1251       return None;
1252 
1253     // This identifier can only take a single value.
1254     if (lb) {
1255       // Set lb to that symbolic value.
1256       lb->resize(getNumSymbolIds() + 1);
1257       if (ub)
1258         ub->resize(getNumSymbolIds() + 1);
1259       for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
1260         int64_t v = atEq(eqPos, pos);
1261         // atEq(eqRow, pos) is either -1 or 1.
1262         assert(v * v == 1);
1263         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
1264                          : -atEq(eqPos, getNumDimIds() + c) / v;
1265         // Since this is an equality, ub = lb.
1266         if (ub)
1267           (*ub)[c] = (*lb)[c];
1268       }
1269       assert(boundFloorDivisor &&
1270              "both lb and divisor or none should be provided");
1271       *boundFloorDivisor = 1;
1272     }
1273     if (minLbPos)
1274       *minLbPos = eqPos;
1275     if (minUbPos)
1276       *minUbPos = eqPos;
1277     return 1;
1278   }
1279 
1280   // Check if the identifier appears at all in any of the inequalities.
1281   unsigned r, e;
1282   for (r = 0, e = getNumInequalities(); r < e; r++) {
1283     if (atIneq(r, pos) != 0)
1284       break;
1285   }
1286   if (r == e)
1287     // If it doesn't, there isn't a bound on it.
1288     return None;
1289 
1290   // Positions of constraints that are lower/upper bounds on the variable.
1291   SmallVector<unsigned, 4> lbIndices, ubIndices;
1292 
1293   // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
1294   // the bounds can only involve symbolic (and local) identifiers. Since the
1295   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1296   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1297   getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
1298                                /*eqIndices=*/nullptr, /*offset=*/0,
1299                                /*num=*/getNumDimIds());
1300 
1301   Optional<int64_t> minDiff = None;
1302   unsigned minLbPosition = 0, minUbPosition = 0;
1303   for (auto ubPos : ubIndices) {
1304     for (auto lbPos : lbIndices) {
1305       // Look for a lower bound and an upper bound that only differ by a
1306       // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
1307       // For example, if ii is the pos^th variable, we are looking for
1308       // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
1309       // minimum among all such constant differences is kept since that's the
1310       // constant bounding the extent of the pos^th variable.
1311       unsigned j, e;
1312       for (j = 0, e = getNumCols() - 1; j < e; j++)
1313         if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
1314           break;
1315         }
1316       if (j < getNumCols() - 1)
1317         continue;
1318       int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
1319                                  atIneq(lbPos, getNumCols() - 1) + 1,
1320                              atIneq(lbPos, pos));
1321       // This bound is non-negative by definition.
1322       diff = std::max<int64_t>(diff, 0);
1323       if (minDiff == None || diff < minDiff) {
1324         minDiff = diff;
1325         minLbPosition = lbPos;
1326         minUbPosition = ubPos;
1327       }
1328     }
1329   }
1330   if (lb && minDiff.hasValue()) {
1331     // Set lb to the symbolic lower bound.
1332     lb->resize(getNumSymbolIds() + 1);
1333     if (ub)
1334       ub->resize(getNumSymbolIds() + 1);
1335     // The lower bound is the ceildiv of the lb constraint over the coefficient
1336     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
1337     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
1338     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
1339     *boundFloorDivisor = atIneq(minLbPosition, pos);
1340     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
1341     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
1342       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
1343     }
1344     if (ub) {
1345       for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
1346         (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
1347     }
1348     // The lower bound leads to a ceildiv while the upper bound is a floordiv
1349     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
1350     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
1351     // the constant term for the lower bound.
1352     (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
1353   }
1354   if (minLbPos)
1355     *minLbPos = minLbPosition;
1356   if (minUbPos)
1357     *minUbPos = minUbPosition;
1358   return minDiff;
1359 }
1360 
1361 template <bool isLower>
1362 Optional<int64_t>
1363 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
1364   assert(pos < getNumIds() && "invalid position");
1365   // Project to 'pos'.
1366   projectOut(0, pos);
1367   projectOut(1, getNumIds() - 1);
1368   // Check if there's an equality equating the '0'^th identifier to a constant.
1369   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
1370   if (eqRowIdx != -1)
1371     // atEq(rowIdx, 0) is either -1 or 1.
1372     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
1373 
1374   // Check if the identifier appears at all in any of the inequalities.
1375   unsigned r, e;
1376   for (r = 0, e = getNumInequalities(); r < e; r++) {
1377     if (atIneq(r, 0) != 0)
1378       break;
1379   }
1380   if (r == e)
1381     // If it doesn't, there isn't a bound on it.
1382     return None;
1383 
1384   Optional<int64_t> minOrMaxConst = None;
1385 
1386   // Take the max across all const lower bounds (or min across all constant
1387   // upper bounds).
1388   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1389     if (isLower) {
1390       if (atIneq(r, 0) <= 0)
1391         // Not a lower bound.
1392         continue;
1393     } else if (atIneq(r, 0) >= 0) {
1394       // Not an upper bound.
1395       continue;
1396     }
1397     unsigned c, f;
1398     for (c = 0, f = getNumCols() - 1; c < f; c++)
1399       if (c != 0 && atIneq(r, c) != 0)
1400         break;
1401     if (c < getNumCols() - 1)
1402       // Not a constant bound.
1403       continue;
1404 
1405     int64_t boundConst =
1406         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
1407                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
1408     if (isLower) {
1409       if (minOrMaxConst == None || boundConst > minOrMaxConst)
1410         minOrMaxConst = boundConst;
1411     } else {
1412       if (minOrMaxConst == None || boundConst < minOrMaxConst)
1413         minOrMaxConst = boundConst;
1414     }
1415   }
1416   return minOrMaxConst;
1417 }
1418 
1419 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type,
1420                                                     unsigned pos) const {
1421   if (type == BoundType::LB)
1422     return IntegerRelation(*this)
1423         .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
1424   if (type == BoundType::UB)
1425     return IntegerRelation(*this)
1426         .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1427 
1428   assert(type == BoundType::EQ && "expected EQ");
1429   Optional<int64_t> lb =
1430       IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>(
1431           pos);
1432   Optional<int64_t> ub =
1433       IntegerRelation(*this)
1434           .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1435   return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
1436 }
1437 
1438 // A simple (naive and conservative) check for hyper-rectangularity.
1439 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const {
1440   assert(pos < getNumCols() - 1);
1441   // Check for two non-zero coefficients in the range [pos, pos + sum).
1442   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1443     unsigned sum = 0;
1444     for (unsigned c = pos; c < pos + num; c++) {
1445       if (atIneq(r, c) != 0)
1446         sum++;
1447     }
1448     if (sum > 1)
1449       return false;
1450   }
1451   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1452     unsigned sum = 0;
1453     for (unsigned c = pos; c < pos + num; c++) {
1454       if (atEq(r, c) != 0)
1455         sum++;
1456     }
1457     if (sum > 1)
1458       return false;
1459   }
1460   return true;
1461 }
1462 
1463 /// Removes duplicate constraints, trivially true constraints, and constraints
1464 /// that can be detected as redundant as a result of differing only in their
1465 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
1466 /// considered trivially true.
1467 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
1468 //  remove duplicates in place.
1469 void IntegerRelation::removeTrivialRedundancy() {
1470   gcdTightenInequalities();
1471   normalizeConstraintsByGCD();
1472 
1473   // A map used to detect redundancy stemming from constraints that only differ
1474   // in their constant term. The value stored is <row position, const term>
1475   // for a given row.
1476   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
1477       rowsWithoutConstTerm;
1478   // To unique rows.
1479   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
1480 
1481   // Check if constraint is of the form <non-negative-constant> >= 0.
1482   auto isTriviallyValid = [&](unsigned r) -> bool {
1483     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
1484       if (atIneq(r, c) != 0)
1485         return false;
1486     }
1487     return atIneq(r, getNumCols() - 1) >= 0;
1488   };
1489 
1490   // Detect and mark redundant constraints.
1491   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
1492   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1493     int64_t *rowStart = &inequalities(r, 0);
1494     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
1495     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
1496       redunIneq[r] = true;
1497       continue;
1498     }
1499 
1500     // Among constraints that only differ in the constant term part, mark
1501     // everything other than the one with the smallest constant term redundant.
1502     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
1503     // former two are redundant).
1504     int64_t constTerm = atIneq(r, getNumCols() - 1);
1505     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
1506     const auto &ret =
1507         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
1508     if (!ret.second) {
1509       // Check if the other constraint has a higher constant term.
1510       auto &val = ret.first->second;
1511       if (val.second > constTerm) {
1512         // The stored row is redundant. Mark it so, and update with this one.
1513         redunIneq[val.first] = true;
1514         val = {r, constTerm};
1515       } else {
1516         // The one stored makes this one redundant.
1517         redunIneq[r] = true;
1518       }
1519     }
1520   }
1521 
1522   // Scan to get rid of all rows marked redundant, in-place.
1523   unsigned pos = 0;
1524   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
1525     if (!redunIneq[r])
1526       inequalities.copyRow(r, pos++);
1527 
1528   inequalities.resizeVertically(pos);
1529 
1530   // TODO: consider doing this for equalities as well, but probably not worth
1531   // the savings.
1532 }
1533 
1534 #undef DEBUG_TYPE
1535 #define DEBUG_TYPE "fm"
1536 
1537 /// Eliminates identifier at the specified position using Fourier-Motzkin
1538 /// variable elimination. This technique is exact for rational spaces but
1539 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
1540 /// to a projection operation yielding the (convex) set of integer points
1541 /// contained in the rational shadow of the set. An emptiness test that relies
1542 /// on this method will guarantee emptiness, i.e., it disproves the existence of
1543 /// a solution if it says it's empty.
1544 /// If a non-null isResultIntegerExact is passed, it is set to true if the
1545 /// result is also integer exact. If it's set to false, the obtained solution
1546 /// *may* not be exact, i.e., it may contain integer points that do not have an
1547 /// integer pre-image in the original set.
1548 ///
1549 /// Eg:
1550 /// j >= 0, j <= i + 1
1551 /// i >= 0, i <= N + 1
1552 /// Eliminating i yields,
1553 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
1554 ///
1555 /// If darkShadow = true, this method computes the dark shadow on elimination;
1556 /// the dark shadow is a convex integer subset of the exact integer shadow. A
1557 /// non-empty dark shadow proves the existence of an integer solution. The
1558 /// elimination in such a case could however be an under-approximation, and thus
1559 /// should not be used for scanning sets or used by itself for dependence
1560 /// checking.
1561 ///
1562 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
1563 ///            ^
1564 ///            |
1565 ///            | * * * * o o
1566 ///         i  | * * o o o o
1567 ///            | o * * * * *
1568 ///            --------------->
1569 ///                 j ->
1570 ///
1571 /// Eliminating i from this system (projecting on the j dimension):
1572 /// rational shadow / integer light shadow:  1 <= j <= 6
1573 /// dark shadow:                             3 <= j <= 6
1574 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
1575 /// holes/splinters:                         j = 2
1576 ///
1577 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
1578 // TODO: a slight modification to yield dark shadow version of FM (tightened),
1579 // which can prove the existence of a solution if there is one.
1580 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
1581                                               bool *isResultIntegerExact) {
1582   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
1583   LLVM_DEBUG(dump());
1584   assert(pos < getNumIds() && "invalid position");
1585   assert(hasConsistentState());
1586 
1587   // Check if this identifier can be eliminated through a substitution.
1588   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1589     if (atEq(r, pos) != 0) {
1590       // Use Gaussian elimination here (since we have an equality).
1591       LogicalResult ret = gaussianEliminateId(pos);
1592       (void)ret;
1593       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
1594       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
1595       LLVM_DEBUG(dump());
1596       return;
1597     }
1598   }
1599 
1600   // A fast linear time tightening.
1601   gcdTightenInequalities();
1602 
1603   // Check if the identifier appears at all in any of the inequalities.
1604   if (isColZero(pos)) {
1605     // If it doesn't appear, just remove the column and return.
1606     // TODO: refactor removeColumns to use it from here.
1607     removeId(pos);
1608     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1609     LLVM_DEBUG(dump());
1610     return;
1611   }
1612 
1613   // Positions of constraints that are lower bounds on the variable.
1614   SmallVector<unsigned, 4> lbIndices;
1615   // Positions of constraints that are lower bounds on the variable.
1616   SmallVector<unsigned, 4> ubIndices;
1617   // Positions of constraints that do not involve the variable.
1618   std::vector<unsigned> nbIndices;
1619   nbIndices.reserve(getNumInequalities());
1620 
1621   // Gather all lower bounds and upper bounds of the variable. Since the
1622   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1623   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1624   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1625     if (atIneq(r, pos) == 0) {
1626       // Id does not appear in bound.
1627       nbIndices.push_back(r);
1628     } else if (atIneq(r, pos) >= 1) {
1629       // Lower bound.
1630       lbIndices.push_back(r);
1631     } else {
1632       // Upper bound.
1633       ubIndices.push_back(r);
1634     }
1635   }
1636 
1637   // Set the number of dimensions, symbols, locals in the resulting system.
1638   unsigned newNumDomain =
1639       getNumDomainIds() - getIdKindOverlap(IdKind::Domain, pos, pos + 1);
1640   unsigned newNumRange =
1641       getNumRangeIds() - getIdKindOverlap(IdKind::Range, pos, pos + 1);
1642   unsigned newNumSymbols =
1643       getNumSymbolIds() - getIdKindOverlap(IdKind::Symbol, pos, pos + 1);
1644   unsigned newNumLocals =
1645       getNumLocalIds() - getIdKindOverlap(IdKind::Local, pos, pos + 1);
1646 
1647   /// Create the new system which has one identifier less.
1648   IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
1649                          getNumEqualities(), getNumCols() - 1, newNumDomain,
1650                          newNumRange, newNumSymbols, newNumLocals);
1651 
1652   // This will be used to check if the elimination was integer exact.
1653   unsigned lcmProducts = 1;
1654 
1655   // Let x be the variable we are eliminating.
1656   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
1657   // that c_l, c_u >= 1) we have:
1658   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
1659   // We thus generate a constraint:
1660   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
1661   // Note if c_l = c_u = 1, all integer points captured by the resulting
1662   // constraint correspond to integer points in the original system (i.e., they
1663   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
1664   // integer exact.
1665   for (auto ubPos : ubIndices) {
1666     for (auto lbPos : lbIndices) {
1667       SmallVector<int64_t, 4> ineq;
1668       ineq.reserve(newRel.getNumCols());
1669       int64_t lbCoeff = atIneq(lbPos, pos);
1670       // Note that in the comments above, ubCoeff is the negation of the
1671       // coefficient in the canonical form as the view taken here is that of the
1672       // term being moved to the other size of '>='.
1673       int64_t ubCoeff = -atIneq(ubPos, pos);
1674       // TODO: refactor this loop to avoid all branches inside.
1675       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1676         if (l == pos)
1677           continue;
1678         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
1679         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
1680         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
1681                        atIneq(lbPos, l) * (lcm / lbCoeff));
1682         lcmProducts *= lcm;
1683       }
1684       if (darkShadow) {
1685         // The dark shadow is a convex subset of the exact integer shadow. If
1686         // there is a point here, it proves the existence of a solution.
1687         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
1688       }
1689       // TODO: we need to have a way to add inequalities in-place in
1690       // IntegerRelation instead of creating and copying over.
1691       newRel.addInequality(ineq);
1692     }
1693   }
1694 
1695   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
1696                           << "\n");
1697   if (lcmProducts == 1 && isResultIntegerExact)
1698     *isResultIntegerExact = true;
1699 
1700   // Copy over the constraints not involving this variable.
1701   for (auto nbPos : nbIndices) {
1702     SmallVector<int64_t, 4> ineq;
1703     ineq.reserve(getNumCols() - 1);
1704     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1705       if (l == pos)
1706         continue;
1707       ineq.push_back(atIneq(nbPos, l));
1708     }
1709     newRel.addInequality(ineq);
1710   }
1711 
1712   assert(newRel.getNumConstraints() ==
1713          lbIndices.size() * ubIndices.size() + nbIndices.size());
1714 
1715   // Copy over the equalities.
1716   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1717     SmallVector<int64_t, 4> eq;
1718     eq.reserve(newRel.getNumCols());
1719     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1720       if (l == pos)
1721         continue;
1722       eq.push_back(atEq(r, l));
1723     }
1724     newRel.addEquality(eq);
1725   }
1726 
1727   // GCD tightening and normalization allows detection of more trivially
1728   // redundant constraints.
1729   newRel.gcdTightenInequalities();
1730   newRel.normalizeConstraintsByGCD();
1731   newRel.removeTrivialRedundancy();
1732   clearAndCopyFrom(newRel);
1733   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1734   LLVM_DEBUG(dump());
1735 }
1736 
1737 #undef DEBUG_TYPE
1738 #define DEBUG_TYPE "presburger"
1739 
1740 void IntegerRelation::projectOut(unsigned pos, unsigned num) {
1741   if (num == 0)
1742     return;
1743 
1744   // 'pos' can be at most getNumCols() - 2 if num > 0.
1745   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
1746   assert(pos + num < getNumCols() && "invalid range");
1747 
1748   // Eliminate as many identifiers as possible using Gaussian elimination.
1749   unsigned currentPos = pos;
1750   unsigned numToEliminate = num;
1751   unsigned numGaussianEliminated = 0;
1752 
1753   while (currentPos < getNumIds()) {
1754     unsigned curNumEliminated =
1755         gaussianEliminateIds(currentPos, currentPos + numToEliminate);
1756     ++currentPos;
1757     numToEliminate -= curNumEliminated + 1;
1758     numGaussianEliminated += curNumEliminated;
1759   }
1760 
1761   // Eliminate the remaining using Fourier-Motzkin.
1762   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
1763     unsigned numToEliminate = num - numGaussianEliminated - i;
1764     fourierMotzkinEliminate(
1765         getBestIdToEliminate(*this, pos, pos + numToEliminate));
1766   }
1767 
1768   // Fast/trivial simplifications.
1769   gcdTightenInequalities();
1770   // Normalize constraints after tightening since the latter impacts this, but
1771   // not the other way round.
1772   normalizeConstraintsByGCD();
1773 }
1774 
1775 namespace {
1776 
1777 enum BoundCmpResult { Greater, Less, Equal, Unknown };
1778 
1779 /// Compares two affine bounds whose coefficients are provided in 'first' and
1780 /// 'second'. The last coefficient is the constant term.
1781 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
1782   assert(a.size() == b.size());
1783 
1784   // For the bounds to be comparable, their corresponding identifier
1785   // coefficients should be equal; the constant terms are then compared to
1786   // determine less/greater/equal.
1787 
1788   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
1789     return Unknown;
1790 
1791   if (a.back() == b.back())
1792     return Equal;
1793 
1794   return a.back() < b.back() ? Less : Greater;
1795 }
1796 } // namespace
1797 
1798 // Returns constraints that are common to both A & B.
1799 static void getCommonConstraints(const IntegerRelation &a,
1800                                  const IntegerRelation &b, IntegerRelation &c) {
1801   c = IntegerRelation(a.getNumDomainIds(), a.getNumRangeIds(),
1802                       a.getNumSymbolIds(), a.getNumLocalIds());
1803   // a naive O(n^2) check should be enough here given the input sizes.
1804   for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
1805     for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
1806       if (a.getInequality(r) == b.getInequality(s)) {
1807         c.addInequality(a.getInequality(r));
1808         break;
1809       }
1810     }
1811   }
1812   for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
1813     for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) {
1814       if (a.getEquality(r) == b.getEquality(s)) {
1815         c.addEquality(a.getEquality(r));
1816         break;
1817       }
1818     }
1819   }
1820 }
1821 
1822 // Computes the bounding box with respect to 'other' by finding the min of the
1823 // lower bounds and the max of the upper bounds along each of the dimensions.
1824 LogicalResult
1825 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
1826   assert(PresburgerLocalSpace::isEqual(otherCst) && "Spaces should match.");
1827   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
1828 
1829   // Get the constraints common to both systems; these will be added as is to
1830   // the union.
1831   IntegerRelation commonCst;
1832   getCommonConstraints(*this, otherCst, commonCst);
1833 
1834   std::vector<SmallVector<int64_t, 8>> boundingLbs;
1835   std::vector<SmallVector<int64_t, 8>> boundingUbs;
1836   boundingLbs.reserve(2 * getNumDimIds());
1837   boundingUbs.reserve(2 * getNumDimIds());
1838 
1839   // To hold lower and upper bounds for each dimension.
1840   SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
1841   // To compute min of lower bounds and max of upper bounds for each dimension.
1842   SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
1843   SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
1844   // To compute final new lower and upper bounds for the union.
1845   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
1846 
1847   int64_t lbFloorDivisor, otherLbFloorDivisor;
1848   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
1849     auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
1850     if (!extent.hasValue())
1851       // TODO: symbolic extents when necessary.
1852       // TODO: handle union if a dimension is unbounded.
1853       return failure();
1854 
1855     auto otherExtent = otherCst.getConstantBoundOnDimSize(
1856         d, &otherLb, &otherLbFloorDivisor, &otherUb);
1857     if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
1858       // TODO: symbolic extents when necessary.
1859       return failure();
1860 
1861     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
1862 
1863     auto res = compareBounds(lb, otherLb);
1864     // Identify min.
1865     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
1866       minLb = lb;
1867       // Since the divisor is for a floordiv, we need to convert to ceildiv,
1868       // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
1869       // div * i >= expr - div + 1.
1870       minLb.back() -= lbFloorDivisor - 1;
1871     } else if (res == BoundCmpResult::Greater) {
1872       minLb = otherLb;
1873       minLb.back() -= otherLbFloorDivisor - 1;
1874     } else {
1875       // Uncomparable - check for constant lower/upper bounds.
1876       auto constLb = getConstantBound(BoundType::LB, d);
1877       auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
1878       if (!constLb.hasValue() || !constOtherLb.hasValue())
1879         return failure();
1880       std::fill(minLb.begin(), minLb.end(), 0);
1881       minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
1882     }
1883 
1884     // Do the same for ub's but max of upper bounds. Identify max.
1885     auto uRes = compareBounds(ub, otherUb);
1886     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
1887       maxUb = ub;
1888     } else if (uRes == BoundCmpResult::Less) {
1889       maxUb = otherUb;
1890     } else {
1891       // Uncomparable - check for constant lower/upper bounds.
1892       auto constUb = getConstantBound(BoundType::UB, d);
1893       auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
1894       if (!constUb.hasValue() || !constOtherUb.hasValue())
1895         return failure();
1896       std::fill(maxUb.begin(), maxUb.end(), 0);
1897       maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
1898     }
1899 
1900     std::fill(newLb.begin(), newLb.end(), 0);
1901     std::fill(newUb.begin(), newUb.end(), 0);
1902 
1903     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
1904     // and so it's the divisor for newLb and newUb as well.
1905     newLb[d] = lbFloorDivisor;
1906     newUb[d] = -lbFloorDivisor;
1907     // Copy over the symbolic part + constant term.
1908     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
1909     std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
1910                    newLb.begin() + getNumDimIds(), std::negate<int64_t>());
1911     std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
1912 
1913     boundingLbs.push_back(newLb);
1914     boundingUbs.push_back(newUb);
1915   }
1916 
1917   // Clear all constraints and add the lower/upper bounds for the bounding box.
1918   clearConstraints();
1919   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
1920     addInequality(boundingLbs[d]);
1921     addInequality(boundingUbs[d]);
1922   }
1923 
1924   // Add the constraints that were common to both systems.
1925   append(commonCst);
1926   removeTrivialRedundancy();
1927 
1928   // TODO: copy over pure symbolic constraints from this and 'other' over to the
1929   // union (since the above are just the union along dimensions); we shouldn't
1930   // be discarding any other constraints on the symbols.
1931 
1932   return success();
1933 }
1934 
1935 bool IntegerRelation::isColZero(unsigned pos) const {
1936   unsigned rowPos;
1937   return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) &&
1938          !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos);
1939 }
1940 
1941 /// Find positions of inequalities and equalities that do not have a coefficient
1942 /// for [pos, pos + num) identifiers.
1943 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos,
1944                                       unsigned num,
1945                                       SmallVectorImpl<unsigned> &nbIneqIndices,
1946                                       SmallVectorImpl<unsigned> &nbEqIndices) {
1947   assert(pos < cst.getNumIds() && "invalid start position");
1948   assert(pos + num <= cst.getNumIds() && "invalid limit");
1949 
1950   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
1951     // The bounds are to be independent of [offset, offset + num) columns.
1952     unsigned c;
1953     for (c = pos; c < pos + num; ++c) {
1954       if (cst.atIneq(r, c) != 0)
1955         break;
1956     }
1957     if (c == pos + num)
1958       nbIneqIndices.push_back(r);
1959   }
1960 
1961   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1962     // The bounds are to be independent of [offset, offset + num) columns.
1963     unsigned c;
1964     for (c = pos; c < pos + num; ++c) {
1965       if (cst.atEq(r, c) != 0)
1966         break;
1967     }
1968     if (c == pos + num)
1969       nbEqIndices.push_back(r);
1970   }
1971 }
1972 
1973 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) {
1974   assert(pos + num <= getNumIds() && "invalid range");
1975 
1976   // Remove constraints that are independent of these identifiers.
1977   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
1978   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
1979 
1980   // Iterate in reverse so that indices don't have to be updated.
1981   // TODO: This method can be made more efficient (because removal of each
1982   // inequality leads to much shifting/copying in the underlying buffer).
1983   for (auto nbIndex : llvm::reverse(nbIneqIndices))
1984     removeInequality(nbIndex);
1985   for (auto nbIndex : llvm::reverse(nbEqIndices))
1986     removeEquality(nbIndex);
1987 }
1988 
1989 void IntegerRelation::printSpace(raw_ostream &os) const {
1990   PresburgerLocalSpace::print(os);
1991   os << getNumConstraints() << " constraints\n";
1992 }
1993 
1994 void IntegerRelation::print(raw_ostream &os) const {
1995   assert(hasConsistentState());
1996   printSpace(os);
1997   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
1998     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
1999       os << atEq(i, j) << " ";
2000     }
2001     os << "= 0\n";
2002   }
2003   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2004     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2005       os << atIneq(i, j) << " ";
2006     }
2007     os << ">= 0\n";
2008   }
2009   os << '\n';
2010 }
2011 
2012 void IntegerRelation::dump() const { print(llvm::errs()); }
2013