1 //===- IntegerRelation.cpp - MLIR IntegerRelation Class ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// A class to represent a relation over integer tuples. A relation is
// represented as a constraint system over a space of tuples of integer valued
// variables supporting symbolic identifiers and existential quantification.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Analysis/Presburger/IntegerRelation.h"
16 #include "mlir/Analysis/Presburger/LinearTransform.h"
17 #include "mlir/Analysis/Presburger/PWMAFunction.h"
18 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
19 #include "mlir/Analysis/Presburger/Simplex.h"
20 #include "mlir/Analysis/Presburger/Utils.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "presburger"
26 
27 using namespace mlir;
28 using namespace presburger;
29 
30 using llvm::SmallDenseMap;
31 using llvm::SmallDenseSet;
32 
33 std::unique_ptr<IntegerRelation> IntegerRelation::clone() const {
34   return std::make_unique<IntegerRelation>(*this);
35 }
36 
37 std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
38   return std::make_unique<IntegerPolyhedron>(*this);
39 }
40 
41 void IntegerRelation::append(const IntegerRelation &other) {
42   assert(isSpaceEqual(other) && "Spaces must be equal.");
43 
44   inequalities.reserveRows(inequalities.getNumRows() +
45                            other.getNumInequalities());
46   equalities.reserveRows(equalities.getNumRows() + other.getNumEqualities());
47 
48   for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
49     addInequality(other.getInequality(r));
50   }
51   for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
52     addEquality(other.getEquality(r));
53   }
54 }
55 
56 IntegerRelation IntegerRelation::intersect(IntegerRelation other) const {
57   IntegerRelation result = *this;
58   result.mergeLocalIds(other);
59   result.append(other);
60   return result;
61 }
62 
63 bool IntegerRelation::isEqual(const IntegerRelation &other) const {
64   assert(isSpaceEqual(other) && "Spaces must be equal.");
65   return PresburgerRelation(*this).isEqual(PresburgerRelation(other));
66 }
67 
68 bool IntegerRelation::isSubsetOf(const IntegerRelation &other) const {
69   assert(isSpaceEqual(other) && "Spaces must be equal.");
70   return PresburgerRelation(*this).isSubsetOf(PresburgerRelation(other));
71 }
72 
73 MaybeOptimum<SmallVector<Fraction, 8>>
74 IntegerRelation::findRationalLexMin() const {
75   assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
76   MaybeOptimum<SmallVector<Fraction, 8>> maybeLexMin =
77       LexSimplex(*this).findRationalLexMin();
78 
79   if (!maybeLexMin.isBounded())
80     return maybeLexMin;
81 
82   // The Simplex returns the lexmin over all the variables including locals. But
83   // locals are not actually part of the space and should not be returned in the
84   // result. Since the locals are placed last in the list of identifiers, they
85   // will be minimized last in the lexmin. So simply truncating out the locals
86   // from the end of the answer gives the desired lexmin over the dimensions.
87   assert(maybeLexMin->size() == getNumIds() &&
88          "Incorrect number of vars in lexMin!");
89   maybeLexMin->resize(getNumDimAndSymbolIds());
90   return maybeLexMin;
91 }
92 
93 MaybeOptimum<SmallVector<int64_t, 8>>
94 IntegerRelation::findIntegerLexMin() const {
95   assert(getNumSymbolIds() == 0 && "Symbols are not supported!");
96   MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
97       LexSimplex(*this).findIntegerLexMin();
98 
99   if (!maybeLexMin.isBounded())
100     return maybeLexMin.getKind();
101 
102   // The Simplex returns the lexmin over all the variables including locals. But
103   // locals are not actually part of the space and should not be returned in the
104   // result. Since the locals are placed last in the list of identifiers, they
105   // will be minimized last in the lexmin. So simply truncating out the locals
106   // from the end of the answer gives the desired lexmin over the dimensions.
107   assert(maybeLexMin->size() == getNumIds() &&
108          "Incorrect number of vars in lexMin!");
109   maybeLexMin->resize(getNumDimAndSymbolIds());
110   return maybeLexMin;
111 }
112 
113 static bool rangeIsZero(ArrayRef<int64_t> range) {
114   return llvm::all_of(range, [](int64_t x) { return x == 0; });
115 }
116 
117 void removeConstraintsInvolvingIdRange(IntegerRelation &poly, unsigned begin,
118                                        unsigned count) {
119   // We loop until i > 0 and index into i - 1 to avoid sign issues.
120   //
121   // We iterate backwards so that whether we remove constraint i - 1 or not, the
122   // next constraint to be tested is always i - 2.
123   for (unsigned i = poly.getNumEqualities(); i > 0; i--)
124     if (!rangeIsZero(poly.getEquality(i - 1).slice(begin, count)))
125       poly.removeEquality(i - 1);
126   for (unsigned i = poly.getNumInequalities(); i > 0; i--)
127     if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count)))
128       poly.removeInequality(i - 1);
129 }
130 
131 IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const {
132   return {PresburgerSpace(*this), getNumInequalities(), getNumEqualities()};
133 }
134 
135 void IntegerRelation::truncateIdKind(IdKind kind,
136                                      const CountsSnapshot &counts) {
137   truncateIdKind(kind, counts.getSpace().getNumIdKind(kind));
138 }
139 
140 void IntegerRelation::truncate(const CountsSnapshot &counts) {
141   truncateIdKind(IdKind::Domain, counts);
142   truncateIdKind(IdKind::Range, counts);
143   truncateIdKind(IdKind::Symbol, counts);
144   truncateIdKind(IdKind::Local, counts);
145   removeInequalityRange(counts.getNumIneqs(), getNumInequalities());
146   removeEqualityRange(counts.getNumEqs(), getNumEqualities());
147 }
148 
149 SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
150   // Compute the symbolic lexmin of the dims and locals, with the symbols being
151   // the actual symbols of this set.
152   SymbolicLexMin result =
153       SymbolicLexSimplex(*this, IntegerPolyhedron(PresburgerSpace::getSetSpace(
154                                     /*numDims=*/getNumSymbolIds())))
155           .computeSymbolicIntegerLexMin();
156 
157   // We want to return only the lexmin over the dims, so strip the locals from
158   // the computed lexmin.
159   result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
160                                getNumLocalIds());
161   return result;
162 }
163 
164 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
165   assert(pos <= getNumIdKind(kind));
166 
167   unsigned insertPos = PresburgerSpace::insertId(kind, pos, num);
168   inequalities.insertColumns(insertPos, num);
169   equalities.insertColumns(insertPos, num);
170   return insertPos;
171 }
172 
173 unsigned IntegerRelation::appendId(IdKind kind, unsigned num) {
174   unsigned pos = getNumIdKind(kind);
175   return insertId(kind, pos, num);
176 }
177 
178 void IntegerRelation::addEquality(ArrayRef<int64_t> eq) {
179   assert(eq.size() == getNumCols());
180   unsigned row = equalities.appendExtraRow();
181   for (unsigned i = 0, e = eq.size(); i < e; ++i)
182     equalities(row, i) = eq[i];
183 }
184 
185 void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) {
186   assert(inEq.size() == getNumCols());
187   unsigned row = inequalities.appendExtraRow();
188   for (unsigned i = 0, e = inEq.size(); i < e; ++i)
189     inequalities(row, i) = inEq[i];
190 }
191 
192 void IntegerRelation::removeId(IdKind kind, unsigned pos) {
193   removeIdRange(kind, pos, pos + 1);
194 }
195 
196 void IntegerRelation::removeId(unsigned pos) { removeIdRange(pos, pos + 1); }
197 
198 void IntegerRelation::removeIdRange(IdKind kind, unsigned idStart,
199                                     unsigned idLimit) {
200   assert(idLimit <= getNumIdKind(kind));
201 
202   if (idStart >= idLimit)
203     return;
204 
205   // Remove eliminated identifiers from the constraints.
206   unsigned offset = getIdKindOffset(kind);
207   equalities.removeColumns(offset + idStart, idLimit - idStart);
208   inequalities.removeColumns(offset + idStart, idLimit - idStart);
209 
210   // Remove eliminated identifiers from the space.
211   PresburgerSpace::removeIdRange(kind, idStart, idLimit);
212 }
213 
214 void IntegerRelation::removeIdRange(unsigned idStart, unsigned idLimit) {
215   assert(idLimit <= getNumIds());
216 
217   if (idStart >= idLimit)
218     return;
219 
220   // Helper function to remove ids of the specified kind in the given range
221   // [start, limit), The range is absolute (i.e. it is not relative to the kind
222   // of identifier). Also updates `limit` to reflect the deleted identifiers.
223   auto removeIdKindInRange = [this](IdKind kind, unsigned &start,
224                                     unsigned &limit) {
225     if (start >= limit)
226       return;
227 
228     unsigned offset = getIdKindOffset(kind);
229     unsigned num = getNumIdKind(kind);
230 
231     // Get `start`, `limit` relative to the specified kind.
232     unsigned relativeStart =
233         start <= offset ? 0 : std::min(num, start - offset);
234     unsigned relativeLimit =
235         limit <= offset ? 0 : std::min(num, limit - offset);
236 
237     // Remove ids of the specified kind in the relative range.
238     removeIdRange(kind, relativeStart, relativeLimit);
239 
240     // Update `limit` to reflect deleted identifiers.
241     // `start` does not need to be updated because any identifiers that are
242     // deleted are after position `start`.
243     limit -= relativeLimit - relativeStart;
244   };
245 
246   removeIdKindInRange(IdKind::Domain, idStart, idLimit);
247   removeIdKindInRange(IdKind::Range, idStart, idLimit);
248   removeIdKindInRange(IdKind::Symbol, idStart, idLimit);
249   removeIdKindInRange(IdKind::Local, idStart, idLimit);
250 }
251 
252 void IntegerRelation::removeEquality(unsigned pos) {
253   equalities.removeRow(pos);
254 }
255 
256 void IntegerRelation::removeInequality(unsigned pos) {
257   inequalities.removeRow(pos);
258 }
259 
260 void IntegerRelation::removeEqualityRange(unsigned start, unsigned end) {
261   if (start >= end)
262     return;
263   equalities.removeRows(start, end - start);
264 }
265 
266 void IntegerRelation::removeInequalityRange(unsigned start, unsigned end) {
267   if (start >= end)
268     return;
269   inequalities.removeRows(start, end - start);
270 }
271 
272 void IntegerRelation::swapId(unsigned posA, unsigned posB) {
273   assert(posA < getNumIds() && "invalid position A");
274   assert(posB < getNumIds() && "invalid position B");
275 
276   if (posA == posB)
277     return;
278 
279   inequalities.swapColumns(posA, posB);
280   equalities.swapColumns(posA, posB);
281 }
282 
283 void IntegerRelation::clearConstraints() {
284   equalities.resizeVertically(0);
285   inequalities.resizeVertically(0);
286 }
287 
288 /// Gather all lower and upper bounds of the identifier at `pos`, and
289 /// optionally any equalities on it. In addition, the bounds are to be
290 /// independent of identifiers in position range [`offset`, `offset` + `num`).
291 void IntegerRelation::getLowerAndUpperBoundIndices(
292     unsigned pos, SmallVectorImpl<unsigned> *lbIndices,
293     SmallVectorImpl<unsigned> *ubIndices, SmallVectorImpl<unsigned> *eqIndices,
294     unsigned offset, unsigned num) const {
295   assert(pos < getNumIds() && "invalid position");
296   assert(offset + num < getNumCols() && "invalid range");
297 
298   // Checks for a constraint that has a non-zero coeff for the identifiers in
299   // the position range [offset, offset + num) while ignoring `pos`.
300   auto containsConstraintDependentOnRange = [&](unsigned r, bool isEq) {
301     unsigned c, f;
302     auto cst = isEq ? getEquality(r) : getInequality(r);
303     for (c = offset, f = offset + num; c < f; ++c) {
304       if (c == pos)
305         continue;
306       if (cst[c] != 0)
307         break;
308     }
309     return c < f;
310   };
311 
312   // Gather all lower bounds and upper bounds of the variable. Since the
313   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
314   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
315   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
316     // The bounds are to be independent of [offset, offset + num) columns.
317     if (containsConstraintDependentOnRange(r, /*isEq=*/false))
318       continue;
319     if (atIneq(r, pos) >= 1) {
320       // Lower bound.
321       lbIndices->push_back(r);
322     } else if (atIneq(r, pos) <= -1) {
323       // Upper bound.
324       ubIndices->push_back(r);
325     }
326   }
327 
328   // An equality is both a lower and upper bound. Record any equalities
329   // involving the pos^th identifier.
330   if (!eqIndices)
331     return;
332 
333   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
334     if (atEq(r, pos) == 0)
335       continue;
336     if (containsConstraintDependentOnRange(r, /*isEq=*/true))
337       continue;
338     eqIndices->push_back(r);
339   }
340 }
341 
342 bool IntegerRelation::hasConsistentState() const {
343   if (!inequalities.hasConsistentState())
344     return false;
345   if (!equalities.hasConsistentState())
346     return false;
347   return true;
348 }
349 
350 void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) {
351   if (values.empty())
352     return;
353   assert(pos + values.size() <= getNumIds() &&
354          "invalid position or too many values");
355   // Setting x_j = p in sum_i a_i x_i + c is equivalent to adding p*a_j to the
356   // constant term and removing the id x_j. We do this for all the ids
357   // pos, pos + 1, ... pos + values.size() - 1.
358   unsigned constantColPos = getNumCols() - 1;
359   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
360     inequalities.addToColumn(i + pos, constantColPos, values[i]);
361   for (unsigned i = 0, numVals = values.size(); i < numVals; ++i)
362     equalities.addToColumn(i + pos, constantColPos, values[i]);
363   removeIdRange(pos, pos + values.size());
364 }
365 
366 void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
367   *this = other;
368 }
369 
370 // Searches for a constraint with a non-zero coefficient at `colIdx` in
371 // equality (isEq=true) or inequality (isEq=false) constraints.
372 // Returns true and sets row found in search in `rowIdx`, false otherwise.
373 bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
374                                                   unsigned *rowIdx) const {
375   assert(colIdx < getNumCols() && "position out of bounds");
376   auto at = [&](unsigned rowIdx) -> int64_t {
377     return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
378   };
379   unsigned e = isEq ? getNumEqualities() : getNumInequalities();
380   for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
381     if (at(*rowIdx) != 0) {
382       return true;
383     }
384   }
385   return false;
386 }
387 
388 void IntegerRelation::normalizeConstraintsByGCD() {
389   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
390     equalities.normalizeRow(i);
391   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
392     inequalities.normalizeRow(i);
393 }
394 
395 bool IntegerRelation::hasInvalidConstraint() const {
396   assert(hasConsistentState());
397   auto check = [&](bool isEq) -> bool {
398     unsigned numCols = getNumCols();
399     unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
400     for (unsigned i = 0, e = numRows; i < e; ++i) {
401       unsigned j;
402       for (j = 0; j < numCols - 1; ++j) {
403         int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
404         // Skip rows with non-zero variable coefficients.
405         if (v != 0)
406           break;
407       }
408       if (j < numCols - 1) {
409         continue;
410       }
411       // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
412       // Example invalid constraints include: '1 == 0' or '-1 >= 0'
413       int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
414       if ((isEq && v != 0) || (!isEq && v < 0)) {
415         return true;
416       }
417     }
418     return false;
419   };
420   if (check(/*isEq=*/true))
421     return true;
422   return check(/*isEq=*/false);
423 }
424 
425 /// Eliminate identifier from constraint at `rowIdx` based on coefficient at
426 /// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
427 /// updated as they have already been eliminated.
428 static void eliminateFromConstraint(IntegerRelation *constraints,
429                                     unsigned rowIdx, unsigned pivotRow,
430                                     unsigned pivotCol, unsigned elimColStart,
431                                     bool isEq) {
432   // Skip if equality 'rowIdx' if same as 'pivotRow'.
433   if (isEq && rowIdx == pivotRow)
434     return;
435   auto at = [&](unsigned i, unsigned j) -> int64_t {
436     return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
437   };
438   int64_t leadCoeff = at(rowIdx, pivotCol);
439   // Skip if leading coefficient at 'rowIdx' is already zero.
440   if (leadCoeff == 0)
441     return;
442   int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
443   int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
444   int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
445   int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
446   int64_t rowMultiplier = lcm / std::abs(leadCoeff);
447 
448   unsigned numCols = constraints->getNumCols();
449   for (unsigned j = 0; j < numCols; ++j) {
450     // Skip updating column 'j' if it was just eliminated.
451     if (j >= elimColStart && j < pivotCol)
452       continue;
453     int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
454                 rowMultiplier * at(rowIdx, j);
455     isEq ? constraints->atEq(rowIdx, j) = v
456          : constraints->atIneq(rowIdx, j) = v;
457   }
458 }
459 
460 /// Returns the position of the identifier that has the minimum <number of lower
461 /// bounds> times <number of upper bounds> from the specified range of
462 /// identifiers [start, end). It is often best to eliminate in the increasing
463 /// order of these counts when doing Fourier-Motzkin elimination since FM adds
464 /// that many new constraints.
465 static unsigned getBestIdToEliminate(const IntegerRelation &cst, unsigned start,
466                                      unsigned end) {
467   assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
468 
469   auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
470     unsigned numLb = 0;
471     unsigned numUb = 0;
472     for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
473       if (cst.atIneq(r, pos) > 0) {
474         ++numLb;
475       } else if (cst.atIneq(r, pos) < 0) {
476         ++numUb;
477       }
478     }
479     return numLb * numUb;
480   };
481 
482   unsigned minLoc = start;
483   unsigned min = getProductOfNumLowerUpperBounds(start);
484   for (unsigned c = start + 1; c < end; c++) {
485     unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
486     if (numLbUbProduct < min) {
487       min = numLbUbProduct;
488       minLoc = c;
489     }
490   }
491   return minLoc;
492 }
493 
494 // Checks for emptiness of the set by eliminating identifiers successively and
495 // using the GCD test (on all equality constraints) and checking for trivially
496 // invalid constraints. Returns 'true' if the constraint system is found to be
497 // empty; false otherwise.
498 bool IntegerRelation::isEmpty() const {
499   if (isEmptyByGCDTest() || hasInvalidConstraint())
500     return true;
501 
502   IntegerRelation tmpCst(*this);
503 
504   // First, eliminate as many local variables as possible using equalities.
505   tmpCst.removeRedundantLocalVars();
506   if (tmpCst.isEmptyByGCDTest() || tmpCst.hasInvalidConstraint())
507     return true;
508 
509   // Eliminate as many identifiers as possible using Gaussian elimination.
510   unsigned currentPos = 0;
511   while (currentPos < tmpCst.getNumIds()) {
512     tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
513     ++currentPos;
514     // We check emptiness through trivial checks after eliminating each ID to
515     // detect emptiness early. Since the checks isEmptyByGCDTest() and
516     // hasInvalidConstraint() are linear time and single sweep on the constraint
517     // buffer, this appears reasonable - but can optimize in the future.
518     if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
519       return true;
520   }
521 
522   // Eliminate the remaining using FM.
523   for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
524     tmpCst.fourierMotzkinEliminate(
525         getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
526     // Check for a constraint explosion. This rarely happens in practice, but
527     // this check exists as a safeguard against improperly constructed
528     // constraint systems or artificially created arbitrarily complex systems
529     // that aren't the intended use case for IntegerRelation. This is
530     // needed since FM has a worst case exponential complexity in theory.
531     if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
532       LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
533       return false;
534     }
535 
536     // FM wouldn't have modified the equalities in any way. So no need to again
537     // run GCD test. Check for trivial invalid constraints.
538     if (tmpCst.hasInvalidConstraint())
539       return true;
540   }
541   return false;
542 }
543 
544 // Runs the GCD test on all equality constraints. Returns 'true' if this test
545 // fails on any equality. Returns 'false' otherwise.
546 // This test can be used to disprove the existence of a solution. If it returns
547 // true, no integer solution to the equality constraints can exist.
548 //
549 // GCD test definition:
550 //
551 // The equality constraint:
552 //
553 //  c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
554 //
555 // has an integer solution iff:
556 //
557 //  GCD of c_1, c_2, ..., c_n divides c_0.
558 //
559 bool IntegerRelation::isEmptyByGCDTest() const {
560   assert(hasConsistentState());
561   unsigned numCols = getNumCols();
562   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
563     uint64_t gcd = std::abs(atEq(i, 0));
564     for (unsigned j = 1; j < numCols - 1; ++j) {
565       gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
566     }
567     int64_t v = std::abs(atEq(i, numCols - 1));
568     if (gcd > 0 && (v % gcd != 0)) {
569       return true;
570     }
571   }
572   return false;
573 }
574 
575 // Returns a matrix where each row is a vector along which the polytope is
576 // bounded. The span of the returned vectors is guaranteed to contain all
577 // such vectors. The returned vectors are NOT guaranteed to be linearly
578 // independent. This function should not be called on empty sets.
579 //
580 // It is sufficient to check the perpendiculars of the constraints, as the set
581 // of perpendiculars which are bounded must span all bounded directions.
582 Matrix IntegerRelation::getBoundedDirections() const {
583   // Note that it is necessary to add the equalities too (which the constructor
584   // does) even though we don't need to check if they are bounded; whether an
585   // inequality is bounded or not depends on what other constraints, including
586   // equalities, are present.
587   Simplex simplex(*this);
588 
589   assert(!simplex.isEmpty() && "It is not meaningful to ask whether a "
590                                "direction is bounded in an empty set.");
591 
592   SmallVector<unsigned, 8> boundedIneqs;
593   // The constructor adds the inequalities to the simplex first, so this
594   // processes all the inequalities.
595   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
596     if (simplex.isBoundedAlongConstraint(i))
597       boundedIneqs.push_back(i);
598   }
599 
600   // The direction vector is given by the coefficients and does not include the
601   // constant term, so the matrix has one fewer column.
602   unsigned dirsNumCols = getNumCols() - 1;
603   Matrix dirs(boundedIneqs.size() + getNumEqualities(), dirsNumCols);
604 
605   // Copy the bounded inequalities.
606   unsigned row = 0;
607   for (unsigned i : boundedIneqs) {
608     for (unsigned col = 0; col < dirsNumCols; ++col)
609       dirs(row, col) = atIneq(i, col);
610     ++row;
611   }
612 
613   // Copy the equalities. All the equalities' perpendiculars are bounded.
614   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
615     for (unsigned col = 0; col < dirsNumCols; ++col)
616       dirs(row, col) = atEq(i, col);
617     ++row;
618   }
619 
620   return dirs;
621 }
622 
623 bool IntegerRelation::isIntegerEmpty() const {
624   return !findIntegerSample().hasValue();
625 }
626 
627 /// Let this set be S. If S is bounded then we directly call into the GBR
628 /// sampling algorithm. Otherwise, there are some unbounded directions, i.e.,
629 /// vectors v such that S extends to infinity along v or -v. In this case we
630 /// use an algorithm described in the integer set library (isl) manual and used
631 /// by the isl_set_sample function in that library. The algorithm is:
632 ///
633 /// 1) Apply a unimodular transform T to S to obtain S*T, such that all
634 /// dimensions in which S*T is bounded lie in the linear span of a prefix of the
635 /// dimensions.
636 ///
637 /// 2) Construct a set B by removing all constraints that involve
638 /// the unbounded dimensions and then deleting the unbounded dimensions. Note
639 /// that B is a Bounded set.
640 ///
641 /// 3) Try to obtain a sample from B using the GBR sampling
642 /// algorithm. If no sample is found, return that S is empty.
643 ///
644 /// 4) Otherwise, substitute the obtained sample into S*T to obtain a set
645 /// C. C is a full-dimensional Cone and always contains a sample.
646 ///
647 /// 5) Obtain an integer sample from C.
648 ///
649 /// 6) Return T*v, where v is the concatenation of the samples from B and C.
650 ///
651 /// The following is a sketch of a proof that
652 /// a) If the algorithm returns empty, then S is empty.
653 /// b) If the algorithm returns a sample, it is a valid sample in S.
654 ///
655 /// The algorithm returns empty only if B is empty, in which case S*T is
656 /// certainly empty since B was obtained by removing constraints and then
657 /// deleting unconstrained dimensions from S*T. Since T is unimodular, a vector
658 /// v is in S*T iff T*v is in S. So in this case, since
659 /// S*T is empty, S is empty too.
660 ///
661 /// Otherwise, the algorithm substitutes the sample from B into S*T. All the
662 /// constraints of S*T that did not involve unbounded dimensions are satisfied
663 /// by this substitution. All dimensions in the linear span of the dimensions
664 /// outside the prefix are unbounded in S*T (step 1). Substituting values for
665 /// the bounded dimensions cannot make these dimensions bounded, and these are
666 /// the only remaining dimensions in C, so C is unbounded along every vector (in
667 /// the positive or negative direction, or both). C is hence a full-dimensional
668 /// cone and therefore always contains an integer point.
669 ///
670 /// Concatenating the samples from B and C gives a sample v in S*T, so the
671 /// returned sample T*v is a sample in S.
672 Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
673   // First, try the GCD test heuristic.
674   if (isEmptyByGCDTest())
675     return {};
676 
677   Simplex simplex(*this);
678   if (simplex.isEmpty())
679     return {};
680 
681   // For a bounded set, we directly call into the GBR sampling algorithm.
682   if (!simplex.isUnbounded())
683     return simplex.findIntegerSample();
684 
685   // The set is unbounded. We cannot directly use the GBR algorithm.
686   //
687   // m is a matrix containing, in each row, a vector in which S is
688   // bounded, such that the linear span of all these dimensions contains all
689   // bounded dimensions in S.
690   Matrix m = getBoundedDirections();
691   // In column echelon form, each row of m occupies only the first rank(m)
692   // columns and has zeros on the other columns. The transform T that brings S
693   // to column echelon form is unimodular as well, so this is a suitable
694   // transform to use in step 1 of the algorithm.
695   std::pair<unsigned, LinearTransform> result =
696       LinearTransform::makeTransformToColumnEchelon(std::move(m));
697   const LinearTransform &transform = result.second;
698   // 1) Apply T to S to obtain S*T.
699   IntegerRelation transformedSet = transform.applyTo(*this);
700 
701   // 2) Remove the unbounded dimensions and constraints involving them to
702   // obtain a bounded set.
703   IntegerRelation boundedSet(transformedSet);
704   unsigned numBoundedDims = result.first;
705   unsigned numUnboundedDims = getNumIds() - numBoundedDims;
706   removeConstraintsInvolvingIdRange(boundedSet, numBoundedDims,
707                                     numUnboundedDims);
708   boundedSet.removeIdRange(numBoundedDims, boundedSet.getNumIds());
709 
710   // 3) Try to obtain a sample from the bounded set.
711   Optional<SmallVector<int64_t, 8>> boundedSample =
712       Simplex(boundedSet).findIntegerSample();
713   if (!boundedSample)
714     return {};
715   assert(boundedSet.containsPoint(*boundedSample) &&
716          "Simplex returned an invalid sample!");
717 
718   // 4) Substitute the values of the bounded dimensions into S*T to obtain a
719   // full-dimensional cone, which necessarily contains an integer sample.
720   transformedSet.setAndEliminate(0, *boundedSample);
721   IntegerRelation &cone = transformedSet;
722 
723   // 5) Obtain an integer sample from the cone.
724   //
725   // We shrink the cone such that for any rational point in the shrunken cone,
726   // rounding up each of the point's coordinates produces a point that still
727   // lies in the original cone.
728   //
729   // Rounding up a point x adds a number e_i in [0, 1) to each coordinate x_i.
730   // For each inequality sum_i a_i x_i + c >= 0 in the original cone, the
731   // shrunken cone will have the inequality tightened by some amount s, such
732   // that if x satisfies the shrunken cone's tightened inequality, then x + e
733   // satisfies the original inequality, i.e.,
734   //
735   // sum_i a_i x_i + c + s >= 0 implies sum_i a_i (x_i + e_i) + c >= 0
736   //
737   // for any e_i values in [0, 1). In fact, we will handle the slightly more
738   // general case where e_i can be in [0, 1]. For example, consider the
739   // inequality 2x_1 - 3x_2 - 7x_3 - 6 >= 0, and let x = (3, 0, 0). How low
740   // could the LHS go if we added a number in [0, 1] to each coordinate? The LHS
741   // is minimized when we add 1 to the x_i with negative coefficient a_i and
742   // keep the other x_i the same. In the example, we would get x = (3, 1, 1),
743   // changing the value of the LHS by -3 + -7 = -10.
744   //
745   // In general, the value of the LHS can change by at most the sum of the
746   // negative a_i, so we accomodate this by shifting the inequality by this
747   // amount for the shrunken cone.
748   for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
749     for (unsigned j = 0; j < cone.getNumIds(); ++j) {
750       int64_t coeff = cone.atIneq(i, j);
751       if (coeff < 0)
752         cone.atIneq(i, cone.getNumIds()) += coeff;
753     }
754   }
755 
756   // Obtain an integer sample in the cone by rounding up a rational point from
757   // the shrunken cone. Shrinking the cone amounts to shifting its apex
758   // "inwards" without changing its "shape"; the shrunken cone is still a
759   // full-dimensional cone and is hence non-empty.
760   Simplex shrunkenConeSimplex(cone);
761   assert(!shrunkenConeSimplex.isEmpty() && "Shrunken cone cannot be empty!");
762 
763   // The sample will always exist since the shrunken cone is non-empty.
764   SmallVector<Fraction, 8> shrunkenConeSample =
765       *shrunkenConeSimplex.getRationalSample();
766 
767   SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));
768 
769   // 6) Return transform * concat(boundedSample, coneSample).
770   SmallVector<int64_t, 8> &sample = boundedSample.getValue();
771   sample.append(coneSample.begin(), coneSample.end());
772   return transform.postMultiplyWithColumn(sample);
773 }
774 
775 /// Helper to evaluate an affine expression at a point.
776 /// The expression is a list of coefficients for the dimensions followed by the
777 /// constant term.
778 static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
779   assert(expr.size() == 1 + point.size() &&
780          "Dimensionalities of point and expression don't match!");
781   int64_t value = expr.back();
782   for (unsigned i = 0; i < point.size(); ++i)
783     value += expr[i] * point[i];
784   return value;
785 }
786 
787 /// A point satisfies an equality iff the value of the equality at the
788 /// expression is zero, and it satisfies an inequality iff the value of the
789 /// inequality at that point is non-negative.
790 bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
791   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
792     if (valueAt(getEquality(i), point) != 0)
793       return false;
794   }
795   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
796     if (valueAt(getInequality(i), point) < 0)
797       return false;
798   }
799   return true;
800 }
801 
802 /// Just substitute the values given and check if an integer sample exists for
803 /// the local ids.
804 ///
805 /// TODO: this could be made more efficient by handling divisions separately.
806 /// Instead of finding an integer sample over all the locals, we can first
807 /// compute the values of the locals that have division representations and
808 /// only use the integer emptiness check for the locals that don't have this.
809 /// Handling this correctly requires ordering the divs, though.
810 Optional<SmallVector<int64_t, 8>>
811 IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
812   assert(point.size() == getNumIds() - getNumLocalIds() &&
813          "Point should contain all ids except locals!");
814   assert(getIdKindOffset(IdKind::Local) == getNumIds() - getNumLocalIds() &&
815          "This function depends on locals being stored last!");
816   IntegerRelation copy = *this;
817   copy.setAndEliminate(0, point);
818   return copy.findIntegerSample();
819 }
820 
821 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
822   std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
823   SmallVector<unsigned, 4> denominators(getNumLocalIds());
824   getLocalReprs(dividends, denominators, repr);
825 }
826 
827 void IntegerRelation::getLocalReprs(
828     std::vector<SmallVector<int64_t, 8>> &dividends,
829     SmallVector<unsigned, 4> &denominators) const {
830   std::vector<MaybeLocalRepr> repr(getNumLocalIds());
831   getLocalReprs(dividends, denominators, repr);
832 }
833 
834 void IntegerRelation::getLocalReprs(
835     std::vector<SmallVector<int64_t, 8>> &dividends,
836     SmallVector<unsigned, 4> &denominators,
837     std::vector<MaybeLocalRepr> &repr) const {
838 
839   repr.resize(getNumLocalIds());
840   dividends.resize(getNumLocalIds());
841   denominators.resize(getNumLocalIds());
842 
843   SmallVector<bool, 8> foundRepr(getNumIds(), false);
844   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i)
845     foundRepr[i] = true;
846 
847   unsigned divOffset = getNumDimAndSymbolIds();
848   bool changed;
849   do {
850     // Each time changed is true, at end of this iteration, one or more local
851     // vars have been detected as floor divs.
852     changed = false;
853     for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
854       if (!foundRepr[i + divOffset]) {
855         MaybeLocalRepr res = computeSingleVarRepr(
856             *this, foundRepr, divOffset + i, dividends[i], denominators[i]);
857         if (!res)
858           continue;
859         foundRepr[i + divOffset] = true;
860         repr[i] = res;
861         changed = true;
862       }
863     }
864   } while (changed);
865 
866   // Set 0 denominator for identifiers for which no division representation
867   // could be found.
868   for (unsigned i = 0, e = repr.size(); i < e; ++i)
869     if (!repr[i])
870       denominators[i] = 0;
871 }
872 
873 /// Tightens inequalities given that we are dealing with integer spaces. This is
874 /// analogous to the GCD test but applied to inequalities. The constant term can
875 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
876 ///  64*i - 100 >= 0  =>  64*i - 128 >= 0 (since 'i' is an integer). This is a
877 /// fast method - linear in the number of coefficients.
878 // Example on how this affects practical cases: consider the scenario:
879 // 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
880 // j >= 100 instead of the tighter (exact) j >= 128.
881 void IntegerRelation::gcdTightenInequalities() {
882   unsigned numCols = getNumCols();
883   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
884     // Normalize the constraint and tighten the constant term by the GCD.
885     uint64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1);
886     if (gcd > 1)
887       atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd);
888   }
889 }
890 
891 // Eliminates all identifier variables in column range [posStart, posLimit).
892 // Returns the number of variables eliminated.
893 unsigned IntegerRelation::gaussianEliminateIds(unsigned posStart,
894                                                unsigned posLimit) {
895   // Return if identifier positions to eliminate are out of range.
896   assert(posLimit <= getNumIds());
897   assert(hasConsistentState());
898 
899   if (posStart >= posLimit)
900     return 0;
901 
902   gcdTightenInequalities();
903 
904   unsigned pivotCol = 0;
905   for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
906     // Find a row which has a non-zero coefficient in column 'j'.
907     unsigned pivotRow;
908     if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/true, &pivotRow)) {
909       // No pivot row in equalities with non-zero at 'pivotCol'.
910       if (!findConstraintWithNonZeroAt(pivotCol, /*isEq=*/false, &pivotRow)) {
911         // If inequalities are also non-zero in 'pivotCol', it can be
912         // eliminated.
913         continue;
914       }
915       break;
916     }
917 
918     // Eliminate identifier at 'pivotCol' from each equality row.
919     for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
920       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
921                               /*isEq=*/true);
922       equalities.normalizeRow(i);
923     }
924 
925     // Eliminate identifier at 'pivotCol' from each inequality row.
926     for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
927       eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
928                               /*isEq=*/false);
929       inequalities.normalizeRow(i);
930     }
931     removeEquality(pivotRow);
932     gcdTightenInequalities();
933   }
934   // Update position limit based on number eliminated.
935   posLimit = pivotCol;
936   // Remove eliminated columns from all constraints.
937   removeIdRange(posStart, posLimit);
938   return posLimit - posStart;
939 }
940 
941 // A more complex check to eliminate redundant inequalities. Uses FourierMotzkin
942 // to check if a constraint is redundant.
943 void IntegerRelation::removeRedundantInequalities() {
944   SmallVector<bool, 32> redun(getNumInequalities(), false);
945   // To check if an inequality is redundant, we replace the inequality by its
946   // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting
947   // system is empty. If it is, the inequality is redundant.
948   IntegerRelation tmpCst(*this);
949   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
950     // Change the inequality to its complement.
951     tmpCst.inequalities.negateRow(r);
952     tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
953     if (tmpCst.isEmpty()) {
954       redun[r] = true;
955       // Zero fill the redundant inequality.
956       inequalities.fillRow(r, /*value=*/0);
957       tmpCst.inequalities.fillRow(r, /*value=*/0);
958     } else {
959       // Reverse the change (to avoid recreating tmpCst each time).
960       tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
961       tmpCst.inequalities.negateRow(r);
962     }
963   }
964 
965   unsigned pos = 0;
966   for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
967     if (!redun[r])
968       inequalities.copyRow(r, pos++);
969   }
970   inequalities.resizeVertically(pos);
971 }
972 
973 // A more complex check to eliminate redundant inequalities and equalities. Uses
974 // Simplex to check if a constraint is redundant.
975 void IntegerRelation::removeRedundantConstraints() {
976   // First, we run gcdTightenInequalities. This allows us to catch some
977   // constraints which are not redundant when considering rational solutions
978   // but are redundant in terms of integer solutions.
979   gcdTightenInequalities();
980   Simplex simplex(*this);
981   simplex.detectRedundant();
982 
983   unsigned pos = 0;
984   unsigned numIneqs = getNumInequalities();
985   // Scan to get rid of all inequalities marked redundant, in-place. In Simplex,
986   // the first constraints added are the inequalities.
987   for (unsigned r = 0; r < numIneqs; r++) {
988     if (!simplex.isMarkedRedundant(r))
989       inequalities.copyRow(r, pos++);
990   }
991   inequalities.resizeVertically(pos);
992 
993   // Scan to get rid of all equalities marked redundant, in-place. In Simplex,
994   // after the inequalities, a pair of constraints for each equality is added.
995   // An equality is redundant if both the inequalities in its pair are
996   // redundant.
997   pos = 0;
998   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
999     if (!(simplex.isMarkedRedundant(numIneqs + 2 * r) &&
1000           simplex.isMarkedRedundant(numIneqs + 2 * r + 1)))
1001       equalities.copyRow(r, pos++);
1002   }
1003   equalities.resizeVertically(pos);
1004 }
1005 
1006 Optional<uint64_t> IntegerRelation::computeVolume() const {
1007   assert(getNumSymbolIds() == 0 && "Symbols are not yet supported!");
1008 
1009   Simplex simplex(*this);
1010   // If the polytope is rationally empty, there are certainly no integer
1011   // points.
1012   if (simplex.isEmpty())
1013     return 0;
1014 
1015   // Just find the maximum and minimum integer value of each non-local id
1016   // separately, thus finding the number of integer values each such id can
1017   // take. Multiplying these together gives a valid overapproximation of the
1018   // number of integer points in the relation. The result this gives is
1019   // equivalent to projecting (rationally) the relation onto its non-local ids
1020   // and returning the number of integer points in a minimal axis-parallel
1021   // hyperrectangular overapproximation of that.
1022   //
1023   // We also handle the special case where one dimension is unbounded and
1024   // another dimension can take no integer values. In this case, the volume is
1025   // zero.
1026   //
1027   // If there is no such empty dimension, if any dimension is unbounded we
1028   // just return the result as unbounded.
1029   uint64_t count = 1;
1030   SmallVector<int64_t, 8> dim(getNumIds() + 1);
1031   bool hasUnboundedId = false;
1032   for (unsigned i = 0, e = getNumDimAndSymbolIds(); i < e; ++i) {
1033     dim[i] = 1;
1034     MaybeOptimum<int64_t> min, max;
1035     std::tie(min, max) = simplex.computeIntegerBounds(dim);
1036     dim[i] = 0;
1037 
1038     assert((!min.isEmpty() && !max.isEmpty()) &&
1039            "Polytope should be rationally non-empty!");
1040 
1041     // One of the dimensions is unbounded. Note this fact. We will return
1042     // unbounded if none of the other dimensions makes the volume zero.
1043     if (min.isUnbounded() || max.isUnbounded()) {
1044       hasUnboundedId = true;
1045       continue;
1046     }
1047 
1048     // In this case there are no valid integer points and the volume is
1049     // definitely zero.
1050     if (min.getBoundedOptimum() > max.getBoundedOptimum())
1051       return 0;
1052 
1053     count *= (*max - *min + 1);
1054   }
1055 
1056   if (count == 0)
1057     return 0;
1058   if (hasUnboundedId)
1059     return {};
1060   return count;
1061 }
1062 
1063 void IntegerRelation::eliminateRedundantLocalId(unsigned posA, unsigned posB) {
1064   assert(posA < getNumLocalIds() && "Invalid local id position");
1065   assert(posB < getNumLocalIds() && "Invalid local id position");
1066 
1067   unsigned localOffset = getIdKindOffset(IdKind::Local);
1068   posA += localOffset;
1069   posB += localOffset;
1070   inequalities.addToColumn(posB, posA, 1);
1071   equalities.addToColumn(posB, posA, 1);
1072   removeId(posB);
1073 }
1074 
1075 /// Adds additional local ids to the sets such that they both have the union
1076 /// of the local ids in each set, without changing the set of points that
1077 /// lie in `this` and `other`.
1078 ///
1079 /// To detect local ids that always take the same in both sets, each local id is
1080 /// represented as a floordiv with constant denominator in terms of other ids.
1081 /// After extracting these divisions, local ids with the same division
1082 /// representation are considered duplicate and are merged. It is possible that
1083 /// division representation for some local id cannot be obtained, and thus these
1084 /// local ids are not considered for detecting duplicates.
1085 void IntegerRelation::mergeLocalIds(IntegerRelation &other) {
1086   assert(isSpaceCompatible(other) && "Spaces should be compatible.");
1087 
1088   IntegerRelation &relA = *this;
1089   IntegerRelation &relB = other;
1090 
1091   // Merge local ids of relA and relB without using division information,
1092   // i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
1093   // to `relB` at start of its local ids.
1094   unsigned initLocals = relA.getNumLocalIds();
1095   insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
1096   relB.insertId(IdKind::Local, 0, initLocals);
1097 
1098   // Get division representations from each rel.
1099   std::vector<SmallVector<int64_t, 8>> divsA, divsB;
1100   SmallVector<unsigned, 4> denomsA, denomsB;
1101   relA.getLocalReprs(divsA, denomsA);
1102   relB.getLocalReprs(divsB, denomsB);
1103 
1104   // Copy division information for relB into `divsA` and `denomsA`, so that
1105   // these have the combined division information of both rels. Since newly
1106   // added local variables in relA and relB have no constraints, they will not
1107   // have any division representation.
1108   std::copy(divsB.begin() + initLocals, divsB.end(),
1109             divsA.begin() + initLocals);
1110   std::copy(denomsB.begin() + initLocals, denomsB.end(),
1111             denomsA.begin() + initLocals);
1112 
1113   // Merge function that merges the local variables in both sets by treating
1114   // them as the same identifier.
1115   auto merge = [&relA, &relB](unsigned i, unsigned j) -> bool {
1116     relA.eliminateRedundantLocalId(i, j);
1117     relB.eliminateRedundantLocalId(i, j);
1118     return true;
1119   };
1120 
1121   // Merge all divisions by removing duplicate divisions.
1122   unsigned localOffset = getIdKindOffset(IdKind::Local);
1123   presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
1124 }
1125 
1126 void IntegerRelation::removeDuplicateDivs() {
1127   std::vector<SmallVector<int64_t, 8>> divs;
1128   SmallVector<unsigned, 4> denoms;
1129 
1130   getLocalReprs(divs, denoms);
1131   auto merge = [this](unsigned i, unsigned j) -> bool {
1132     eliminateRedundantLocalId(i, j);
1133     return true;
1134   };
1135   presburger::removeDuplicateDivs(divs, denoms, getIdKindOffset(IdKind::Local),
1136                                   merge);
1137 }
1138 
1139 /// Removes local variables using equalities. Each equality is checked if it
1140 /// can be reduced to the form: `e = affine-expr`, where `e` is a local
1141 /// variable and `affine-expr` is an affine expression not containing `e`.
1142 /// If an equality satisfies this form, the local variable is replaced in
1143 /// each constraint and then removed. The equality used to replace this local
1144 /// variable is also removed.
1145 void IntegerRelation::removeRedundantLocalVars() {
1146   // Normalize the equality constraints to reduce coefficients of local
1147   // variables to 1 wherever possible.
1148   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1149     equalities.normalizeRow(i);
1150 
1151   while (true) {
1152     unsigned i, e, j, f;
1153     for (i = 0, e = getNumEqualities(); i < e; ++i) {
1154       // Find a local variable to eliminate using ith equality.
1155       for (j = getNumDimAndSymbolIds(), f = getNumIds(); j < f; ++j)
1156         if (std::abs(atEq(i, j)) == 1)
1157           break;
1158 
1159       // Local variable can be eliminated using ith equality.
1160       if (j < f)
1161         break;
1162     }
1163 
1164     // No equality can be used to eliminate a local variable.
1165     if (i == e)
1166       break;
1167 
1168     // Use the ith equality to simplify other equalities. If any changes
1169     // are made to an equality constraint, it is normalized by GCD.
1170     for (unsigned k = 0, t = getNumEqualities(); k < t; ++k) {
1171       if (atEq(k, j) != 0) {
1172         eliminateFromConstraint(this, k, i, j, j, /*isEq=*/true);
1173         equalities.normalizeRow(k);
1174       }
1175     }
1176 
1177     // Use the ith equality to simplify inequalities.
1178     for (unsigned k = 0, t = getNumInequalities(); k < t; ++k)
1179       eliminateFromConstraint(this, k, i, j, j, /*isEq=*/false);
1180 
1181     // Remove the ith equality and the found local variable.
1182     removeId(j);
1183     removeEquality(i);
1184   }
1185 }
1186 
1187 void IntegerRelation::convertIdKind(IdKind srcKind, unsigned idStart,
1188                                     unsigned idLimit, IdKind dstKind) {
1189   assert(idLimit <= getNumIdKind(srcKind) && "Invalid id range");
1190 
1191   if (idStart >= idLimit)
1192     return;
1193 
1194   // Append new local variables corresponding to the dimensions to be converted.
1195   unsigned newIdsBegin = getIdKindEnd(dstKind);
1196   unsigned convertCount = idLimit - idStart;
1197   appendId(dstKind, convertCount);
1198 
1199   // Swap the new local variables with dimensions.
1200   //
1201   // Essentially, this moves the information corresponding to the specified ids
1202   // of kind `srcKind` to the `convertCount` newly created ids of kind
1203   // `dstKind`. In particular, this moves the columns in the constraint
1204   // matrices, and zeros out the initially occupied columns (because the newly
1205   // created ids we're swapping with were zero-initialized).
1206   unsigned offset = getIdKindOffset(srcKind);
1207   for (unsigned i = 0; i < convertCount; ++i)
1208     swapId(offset + idStart + i, newIdsBegin + i);
1209 
1210   // Complete the move by deleting the initially occupied columns.
1211   removeIdRange(srcKind, idStart, idLimit);
1212 }
1213 
1214 void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
1215   assert(pos < getNumCols());
1216   if (type == BoundType::EQ) {
1217     unsigned row = equalities.appendExtraRow();
1218     equalities(row, pos) = 1;
1219     equalities(row, getNumCols() - 1) = -value;
1220   } else {
1221     unsigned row = inequalities.appendExtraRow();
1222     inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
1223     inequalities(row, getNumCols() - 1) =
1224         type == BoundType::LB ? -value : value;
1225   }
1226 }
1227 
1228 void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr,
1229                                int64_t value) {
1230   assert(type != BoundType::EQ && "EQ not implemented");
1231   assert(expr.size() == getNumCols());
1232   unsigned row = inequalities.appendExtraRow();
1233   for (unsigned i = 0, e = expr.size(); i < e; ++i)
1234     inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
1235   inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
1236       type == BoundType::LB ? -value : value;
1237 }
1238 
1239 /// Adds a new local identifier as the floordiv of an affine function of other
1240 /// identifiers, the coefficients of which are provided in 'dividend' and with
1241 /// respect to a positive constant 'divisor'. Two constraints are added to the
1242 /// system to capture equivalence with the floordiv.
1243 ///      q = expr floordiv c    <=>   c*q <= expr <= c*q + c - 1.
1244 void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend,
1245                                        int64_t divisor) {
1246   assert(dividend.size() == getNumCols() && "incorrect dividend size");
1247   assert(divisor > 0 && "positive divisor expected");
1248 
1249   appendId(IdKind::Local);
1250 
1251   // Add two constraints for this new identifier 'q'.
1252   SmallVector<int64_t, 8> bound(dividend.size() + 1);
1253 
1254   // dividend - q * divisor >= 0
1255   std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
1256             bound.begin());
1257   bound.back() = dividend.back();
1258   bound[getNumIds() - 1] = -divisor;
1259   addInequality(bound);
1260 
1261   // -dividend +qdivisor * q + divisor - 1 >= 0
1262   std::transform(bound.begin(), bound.end(), bound.begin(),
1263                  std::negate<int64_t>());
1264   bound[bound.size() - 1] += divisor - 1;
1265   addInequality(bound);
1266 }
1267 
1268 /// Finds an equality that equates the specified identifier to a constant.
1269 /// Returns the position of the equality row. If 'symbolic' is set to true,
1270 /// symbols are also treated like a constant, i.e., an affine function of the
1271 /// symbols is also treated like a constant. Returns -1 if such an equality
1272 /// could not be found.
1273 static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
1274                                   bool symbolic = false) {
1275   assert(pos < cst.getNumIds() && "invalid position");
1276   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
1277     int64_t v = cst.atEq(r, pos);
1278     if (v * v != 1)
1279       continue;
1280     unsigned c;
1281     unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
1282     // This checks for zeros in all positions other than 'pos' in [0, f)
1283     for (c = 0; c < f; c++) {
1284       if (c == pos)
1285         continue;
1286       if (cst.atEq(r, c) != 0) {
1287         // Dependent on another identifier.
1288         break;
1289       }
1290     }
1291     if (c == f)
1292       // Equality is free of other identifiers.
1293       return r;
1294   }
1295   return -1;
1296 }
1297 
1298 LogicalResult IntegerRelation::constantFoldId(unsigned pos) {
1299   assert(pos < getNumIds() && "invalid position");
1300   int rowIdx;
1301   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
1302     return failure();
1303 
1304   // atEq(rowIdx, pos) is either -1 or 1.
1305   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
1306   int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
1307   setAndEliminate(pos, constVal);
1308   return success();
1309 }
1310 
1311 void IntegerRelation::constantFoldIdRange(unsigned pos, unsigned num) {
1312   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
1313     if (failed(constantFoldId(t)))
1314       t++;
1315   }
1316 }
1317 
1318 /// Returns a non-negative constant bound on the extent (upper bound - lower
1319 /// bound) of the specified identifier if it is found to be a constant; returns
1320 /// None if it's not a constant. This methods treats symbolic identifiers
1321 /// specially, i.e., it looks for constant differences between affine
1322 /// expressions involving only the symbolic identifiers. See comments at
1323 /// function definition for example. 'lb', if provided, is set to the lower
1324 /// bound associated with the constant difference. Note that 'lb' is purely
1325 /// symbolic and thus will contain the coefficients of the symbolic identifiers
1326 /// and the constant coefficient.
1327 //  Egs: 0 <= i <= 15, return 16.
1328 //       s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
1329 //       s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
1330 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
1331 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
1332 Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
1333     unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
1334     SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
1335     unsigned *minUbPos) const {
1336   assert(pos < getNumDimIds() && "Invalid identifier position");
1337 
1338   // Find an equality for 'pos'^th identifier that equates it to some function
1339   // of the symbolic identifiers (+ constant).
1340   int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
1341   if (eqPos != -1) {
1342     auto eq = getEquality(eqPos);
1343     // If the equality involves a local var, punt for now.
1344     // TODO: this can be handled in the future by using the explicit
1345     // representation of the local vars.
1346     if (!std::all_of(eq.begin() + getNumDimAndSymbolIds(), eq.end() - 1,
1347                      [](int64_t coeff) { return coeff == 0; }))
1348       return None;
1349 
1350     // This identifier can only take a single value.
1351     if (lb) {
1352       // Set lb to that symbolic value.
1353       lb->resize(getNumSymbolIds() + 1);
1354       if (ub)
1355         ub->resize(getNumSymbolIds() + 1);
1356       for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
1357         int64_t v = atEq(eqPos, pos);
1358         // atEq(eqRow, pos) is either -1 or 1.
1359         assert(v * v == 1);
1360         (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
1361                          : -atEq(eqPos, getNumDimIds() + c) / v;
1362         // Since this is an equality, ub = lb.
1363         if (ub)
1364           (*ub)[c] = (*lb)[c];
1365       }
1366       assert(boundFloorDivisor &&
1367              "both lb and divisor or none should be provided");
1368       *boundFloorDivisor = 1;
1369     }
1370     if (minLbPos)
1371       *minLbPos = eqPos;
1372     if (minUbPos)
1373       *minUbPos = eqPos;
1374     return 1;
1375   }
1376 
1377   // Check if the identifier appears at all in any of the inequalities.
1378   unsigned r, e;
1379   for (r = 0, e = getNumInequalities(); r < e; r++) {
1380     if (atIneq(r, pos) != 0)
1381       break;
1382   }
1383   if (r == e)
1384     // If it doesn't, there isn't a bound on it.
1385     return None;
1386 
1387   // Positions of constraints that are lower/upper bounds on the variable.
1388   SmallVector<unsigned, 4> lbIndices, ubIndices;
1389 
1390   // Gather all symbolic lower bounds and upper bounds of the variable, i.e.,
1391   // the bounds can only involve symbolic (and local) identifiers. Since the
1392   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1393   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1394   getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
1395                                /*eqIndices=*/nullptr, /*offset=*/0,
1396                                /*num=*/getNumDimIds());
1397 
1398   Optional<int64_t> minDiff = None;
1399   unsigned minLbPosition = 0, minUbPosition = 0;
1400   for (auto ubPos : ubIndices) {
1401     for (auto lbPos : lbIndices) {
1402       // Look for a lower bound and an upper bound that only differ by a
1403       // constant, i.e., pairs of the form  0 <= c_pos - f(c_i's) <= diffConst.
1404       // For example, if ii is the pos^th variable, we are looking for
1405       // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The
1406       // minimum among all such constant differences is kept since that's the
1407       // constant bounding the extent of the pos^th variable.
1408       unsigned j, e;
1409       for (j = 0, e = getNumCols() - 1; j < e; j++)
1410         if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
1411           break;
1412         }
1413       if (j < getNumCols() - 1)
1414         continue;
1415       int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
1416                                  atIneq(lbPos, getNumCols() - 1) + 1,
1417                              atIneq(lbPos, pos));
1418       // This bound is non-negative by definition.
1419       diff = std::max<int64_t>(diff, 0);
1420       if (minDiff == None || diff < minDiff) {
1421         minDiff = diff;
1422         minLbPosition = lbPos;
1423         minUbPosition = ubPos;
1424       }
1425     }
1426   }
1427   if (lb && minDiff.hasValue()) {
1428     // Set lb to the symbolic lower bound.
1429     lb->resize(getNumSymbolIds() + 1);
1430     if (ub)
1431       ub->resize(getNumSymbolIds() + 1);
1432     // The lower bound is the ceildiv of the lb constraint over the coefficient
1433     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
1434     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
1435     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
1436     *boundFloorDivisor = atIneq(minLbPosition, pos);
1437     assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
1438     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
1439       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
1440     }
1441     if (ub) {
1442       for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
1443         (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
1444     }
1445     // The lower bound leads to a ceildiv while the upper bound is a floordiv
1446     // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
1447     // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
1448     // the constant term for the lower bound.
1449     (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
1450   }
1451   if (minLbPos)
1452     *minLbPos = minLbPosition;
1453   if (minUbPos)
1454     *minUbPos = minUbPosition;
1455   return minDiff;
1456 }
1457 
1458 template <bool isLower>
1459 Optional<int64_t>
1460 IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
1461   assert(pos < getNumIds() && "invalid position");
1462   // Project to 'pos'.
1463   projectOut(0, pos);
1464   projectOut(1, getNumIds() - 1);
1465   // Check if there's an equality equating the '0'^th identifier to a constant.
1466   int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
1467   if (eqRowIdx != -1)
1468     // atEq(rowIdx, 0) is either -1 or 1.
1469     return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
1470 
1471   // Check if the identifier appears at all in any of the inequalities.
1472   unsigned r, e;
1473   for (r = 0, e = getNumInequalities(); r < e; r++) {
1474     if (atIneq(r, 0) != 0)
1475       break;
1476   }
1477   if (r == e)
1478     // If it doesn't, there isn't a bound on it.
1479     return None;
1480 
1481   Optional<int64_t> minOrMaxConst = None;
1482 
1483   // Take the max across all const lower bounds (or min across all constant
1484   // upper bounds).
1485   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1486     if (isLower) {
1487       if (atIneq(r, 0) <= 0)
1488         // Not a lower bound.
1489         continue;
1490     } else if (atIneq(r, 0) >= 0) {
1491       // Not an upper bound.
1492       continue;
1493     }
1494     unsigned c, f;
1495     for (c = 0, f = getNumCols() - 1; c < f; c++)
1496       if (c != 0 && atIneq(r, c) != 0)
1497         break;
1498     if (c < getNumCols() - 1)
1499       // Not a constant bound.
1500       continue;
1501 
1502     int64_t boundConst =
1503         isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
1504                 : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
1505     if (isLower) {
1506       if (minOrMaxConst == None || boundConst > minOrMaxConst)
1507         minOrMaxConst = boundConst;
1508     } else {
1509       if (minOrMaxConst == None || boundConst < minOrMaxConst)
1510         minOrMaxConst = boundConst;
1511     }
1512   }
1513   return minOrMaxConst;
1514 }
1515 
1516 Optional<int64_t> IntegerRelation::getConstantBound(BoundType type,
1517                                                     unsigned pos) const {
1518   if (type == BoundType::LB)
1519     return IntegerRelation(*this)
1520         .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
1521   if (type == BoundType::UB)
1522     return IntegerRelation(*this)
1523         .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1524 
1525   assert(type == BoundType::EQ && "expected EQ");
1526   Optional<int64_t> lb =
1527       IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>(
1528           pos);
1529   Optional<int64_t> ub =
1530       IntegerRelation(*this)
1531           .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
1532   return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
1533 }
1534 
1535 // A simple (naive and conservative) check for hyper-rectangularity.
1536 bool IntegerRelation::isHyperRectangular(unsigned pos, unsigned num) const {
1537   assert(pos < getNumCols() - 1);
1538   // Check for two non-zero coefficients in the range [pos, pos + sum).
1539   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1540     unsigned sum = 0;
1541     for (unsigned c = pos; c < pos + num; c++) {
1542       if (atIneq(r, c) != 0)
1543         sum++;
1544     }
1545     if (sum > 1)
1546       return false;
1547   }
1548   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1549     unsigned sum = 0;
1550     for (unsigned c = pos; c < pos + num; c++) {
1551       if (atEq(r, c) != 0)
1552         sum++;
1553     }
1554     if (sum > 1)
1555       return false;
1556   }
1557   return true;
1558 }
1559 
1560 /// Removes duplicate constraints, trivially true constraints, and constraints
1561 /// that can be detected as redundant as a result of differing only in their
1562 /// constant term part. A constraint of the form <non-negative constant> >= 0 is
1563 /// considered trivially true.
1564 //  Uses a DenseSet to hash and detect duplicates followed by a linear scan to
1565 //  remove duplicates in place.
1566 void IntegerRelation::removeTrivialRedundancy() {
1567   gcdTightenInequalities();
1568   normalizeConstraintsByGCD();
1569 
1570   // A map used to detect redundancy stemming from constraints that only differ
1571   // in their constant term. The value stored is <row position, const term>
1572   // for a given row.
1573   SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
1574       rowsWithoutConstTerm;
1575   // To unique rows.
1576   SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
1577 
1578   // Check if constraint is of the form <non-negative-constant> >= 0.
1579   auto isTriviallyValid = [&](unsigned r) -> bool {
1580     for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
1581       if (atIneq(r, c) != 0)
1582         return false;
1583     }
1584     return atIneq(r, getNumCols() - 1) >= 0;
1585   };
1586 
1587   // Detect and mark redundant constraints.
1588   SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
1589   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1590     int64_t *rowStart = &inequalities(r, 0);
1591     auto row = ArrayRef<int64_t>(rowStart, getNumCols());
1592     if (isTriviallyValid(r) || !rowSet.insert(row).second) {
1593       redunIneq[r] = true;
1594       continue;
1595     }
1596 
1597     // Among constraints that only differ in the constant term part, mark
1598     // everything other than the one with the smallest constant term redundant.
1599     // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
1600     // former two are redundant).
1601     int64_t constTerm = atIneq(r, getNumCols() - 1);
1602     auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
1603     const auto &ret =
1604         rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
1605     if (!ret.second) {
1606       // Check if the other constraint has a higher constant term.
1607       auto &val = ret.first->second;
1608       if (val.second > constTerm) {
1609         // The stored row is redundant. Mark it so, and update with this one.
1610         redunIneq[val.first] = true;
1611         val = {r, constTerm};
1612       } else {
1613         // The one stored makes this one redundant.
1614         redunIneq[r] = true;
1615       }
1616     }
1617   }
1618 
1619   // Scan to get rid of all rows marked redundant, in-place.
1620   unsigned pos = 0;
1621   for (unsigned r = 0, e = getNumInequalities(); r < e; r++)
1622     if (!redunIneq[r])
1623       inequalities.copyRow(r, pos++);
1624 
1625   inequalities.resizeVertically(pos);
1626 
1627   // TODO: consider doing this for equalities as well, but probably not worth
1628   // the savings.
1629 }
1630 
1631 #undef DEBUG_TYPE
1632 #define DEBUG_TYPE "fm"
1633 
1634 /// Eliminates identifier at the specified position using Fourier-Motzkin
1635 /// variable elimination. This technique is exact for rational spaces but
1636 /// conservative (in "rare" cases) for integer spaces. The operation corresponds
1637 /// to a projection operation yielding the (convex) set of integer points
1638 /// contained in the rational shadow of the set. An emptiness test that relies
1639 /// on this method will guarantee emptiness, i.e., it disproves the existence of
1640 /// a solution if it says it's empty.
1641 /// If a non-null isResultIntegerExact is passed, it is set to true if the
1642 /// result is also integer exact. If it's set to false, the obtained solution
1643 /// *may* not be exact, i.e., it may contain integer points that do not have an
1644 /// integer pre-image in the original set.
1645 ///
1646 /// Eg:
1647 /// j >= 0, j <= i + 1
1648 /// i >= 0, i <= N + 1
1649 /// Eliminating i yields,
1650 ///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
1651 ///
1652 /// If darkShadow = true, this method computes the dark shadow on elimination;
1653 /// the dark shadow is a convex integer subset of the exact integer shadow. A
1654 /// non-empty dark shadow proves the existence of an integer solution. The
1655 /// elimination in such a case could however be an under-approximation, and thus
1656 /// should not be used for scanning sets or used by itself for dependence
1657 /// checking.
1658 ///
1659 /// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
1660 ///            ^
1661 ///            |
1662 ///            | * * * * o o
1663 ///         i  | * * o o o o
1664 ///            | o * * * * *
1665 ///            --------------->
1666 ///                 j ->
1667 ///
1668 /// Eliminating i from this system (projecting on the j dimension):
1669 /// rational shadow / integer light shadow:  1 <= j <= 6
1670 /// dark shadow:                             3 <= j <= 6
1671 /// exact integer shadow:                    j = 1 \union  3 <= j <= 6
1672 /// holes/splinters:                         j = 2
1673 ///
1674 /// darkShadow = false, isResultIntegerExact = nullptr are default values.
1675 // TODO: a slight modification to yield dark shadow version of FM (tightened),
1676 // which can prove the existence of a solution if there is one.
1677 void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
1678                                               bool *isResultIntegerExact) {
1679   LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
1680   LLVM_DEBUG(dump());
1681   assert(pos < getNumIds() && "invalid position");
1682   assert(hasConsistentState());
1683 
1684   // Check if this identifier can be eliminated through a substitution.
1685   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1686     if (atEq(r, pos) != 0) {
1687       // Use Gaussian elimination here (since we have an equality).
1688       LogicalResult ret = gaussianEliminateId(pos);
1689       (void)ret;
1690       assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
1691       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
1692       LLVM_DEBUG(dump());
1693       return;
1694     }
1695   }
1696 
1697   // A fast linear time tightening.
1698   gcdTightenInequalities();
1699 
1700   // Check if the identifier appears at all in any of the inequalities.
1701   if (isColZero(pos)) {
1702     // If it doesn't appear, just remove the column and return.
1703     // TODO: refactor removeColumns to use it from here.
1704     removeId(pos);
1705     LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1706     LLVM_DEBUG(dump());
1707     return;
1708   }
1709 
1710   // Positions of constraints that are lower bounds on the variable.
1711   SmallVector<unsigned, 4> lbIndices;
1712   // Positions of constraints that are lower bounds on the variable.
1713   SmallVector<unsigned, 4> ubIndices;
1714   // Positions of constraints that do not involve the variable.
1715   std::vector<unsigned> nbIndices;
1716   nbIndices.reserve(getNumInequalities());
1717 
1718   // Gather all lower bounds and upper bounds of the variable. Since the
1719   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
1720   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
1721   for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
1722     if (atIneq(r, pos) == 0) {
1723       // Id does not appear in bound.
1724       nbIndices.push_back(r);
1725     } else if (atIneq(r, pos) >= 1) {
1726       // Lower bound.
1727       lbIndices.push_back(r);
1728     } else {
1729       // Upper bound.
1730       ubIndices.push_back(r);
1731     }
1732   }
1733 
1734   PresburgerSpace newSpace = getSpace();
1735   IdKind idKindRemove = newSpace.getIdKindAt(pos);
1736   unsigned relativePos = pos - newSpace.getIdKindOffset(idKindRemove);
1737   newSpace.removeIdRange(idKindRemove, relativePos, relativePos + 1);
1738 
1739   /// Create the new system which has one identifier less.
1740   IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
1741                          getNumEqualities(), getNumCols() - 1, newSpace);
1742 
1743   // This will be used to check if the elimination was integer exact.
1744   unsigned lcmProducts = 1;
1745 
1746   // Let x be the variable we are eliminating.
1747   // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
1748   // that c_l, c_u >= 1) we have:
1749   // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
1750   // We thus generate a constraint:
1751   // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
1752   // Note if c_l = c_u = 1, all integer points captured by the resulting
1753   // constraint correspond to integer points in the original system (i.e., they
1754   // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
1755   // integer exact.
1756   for (auto ubPos : ubIndices) {
1757     for (auto lbPos : lbIndices) {
1758       SmallVector<int64_t, 4> ineq;
1759       ineq.reserve(newRel.getNumCols());
1760       int64_t lbCoeff = atIneq(lbPos, pos);
1761       // Note that in the comments above, ubCoeff is the negation of the
1762       // coefficient in the canonical form as the view taken here is that of the
1763       // term being moved to the other size of '>='.
1764       int64_t ubCoeff = -atIneq(ubPos, pos);
1765       // TODO: refactor this loop to avoid all branches inside.
1766       for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1767         if (l == pos)
1768           continue;
1769         assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
1770         int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
1771         ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
1772                        atIneq(lbPos, l) * (lcm / lbCoeff));
1773         lcmProducts *= lcm;
1774       }
1775       if (darkShadow) {
1776         // The dark shadow is a convex subset of the exact integer shadow. If
1777         // there is a point here, it proves the existence of a solution.
1778         ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
1779       }
1780       // TODO: we need to have a way to add inequalities in-place in
1781       // IntegerRelation instead of creating and copying over.
1782       newRel.addInequality(ineq);
1783     }
1784   }
1785 
1786   LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
1787                           << "\n");
1788   if (lcmProducts == 1 && isResultIntegerExact)
1789     *isResultIntegerExact = true;
1790 
1791   // Copy over the constraints not involving this variable.
1792   for (auto nbPos : nbIndices) {
1793     SmallVector<int64_t, 4> ineq;
1794     ineq.reserve(getNumCols() - 1);
1795     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1796       if (l == pos)
1797         continue;
1798       ineq.push_back(atIneq(nbPos, l));
1799     }
1800     newRel.addInequality(ineq);
1801   }
1802 
1803   assert(newRel.getNumConstraints() ==
1804          lbIndices.size() * ubIndices.size() + nbIndices.size());
1805 
1806   // Copy over the equalities.
1807   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
1808     SmallVector<int64_t, 4> eq;
1809     eq.reserve(newRel.getNumCols());
1810     for (unsigned l = 0, e = getNumCols(); l < e; l++) {
1811       if (l == pos)
1812         continue;
1813       eq.push_back(atEq(r, l));
1814     }
1815     newRel.addEquality(eq);
1816   }
1817 
1818   // GCD tightening and normalization allows detection of more trivially
1819   // redundant constraints.
1820   newRel.gcdTightenInequalities();
1821   newRel.normalizeConstraintsByGCD();
1822   newRel.removeTrivialRedundancy();
1823   clearAndCopyFrom(newRel);
1824   LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
1825   LLVM_DEBUG(dump());
1826 }
1827 
1828 #undef DEBUG_TYPE
1829 #define DEBUG_TYPE "presburger"
1830 
1831 void IntegerRelation::projectOut(unsigned pos, unsigned num) {
1832   if (num == 0)
1833     return;
1834 
1835   // 'pos' can be at most getNumCols() - 2 if num > 0.
1836   assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
1837   assert(pos + num < getNumCols() && "invalid range");
1838 
1839   // Eliminate as many identifiers as possible using Gaussian elimination.
1840   unsigned currentPos = pos;
1841   unsigned numToEliminate = num;
1842   unsigned numGaussianEliminated = 0;
1843 
1844   while (currentPos < getNumIds()) {
1845     unsigned curNumEliminated =
1846         gaussianEliminateIds(currentPos, currentPos + numToEliminate);
1847     ++currentPos;
1848     numToEliminate -= curNumEliminated + 1;
1849     numGaussianEliminated += curNumEliminated;
1850   }
1851 
1852   // Eliminate the remaining using Fourier-Motzkin.
1853   for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
1854     unsigned numToEliminate = num - numGaussianEliminated - i;
1855     fourierMotzkinEliminate(
1856         getBestIdToEliminate(*this, pos, pos + numToEliminate));
1857   }
1858 
1859   // Fast/trivial simplifications.
1860   gcdTightenInequalities();
1861   // Normalize constraints after tightening since the latter impacts this, but
1862   // not the other way round.
1863   normalizeConstraintsByGCD();
1864 }
1865 
1866 namespace {
1867 
1868 enum BoundCmpResult { Greater, Less, Equal, Unknown };
1869 
1870 /// Compares two affine bounds whose coefficients are provided in 'first' and
1871 /// 'second'. The last coefficient is the constant term.
1872 static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
1873   assert(a.size() == b.size());
1874 
1875   // For the bounds to be comparable, their corresponding identifier
1876   // coefficients should be equal; the constant terms are then compared to
1877   // determine less/greater/equal.
1878 
1879   if (!std::equal(a.begin(), a.end() - 1, b.begin()))
1880     return Unknown;
1881 
1882   if (a.back() == b.back())
1883     return Equal;
1884 
1885   return a.back() < b.back() ? Less : Greater;
1886 }
1887 } // namespace
1888 
1889 // Returns constraints that are common to both A & B.
1890 static void getCommonConstraints(const IntegerRelation &a,
1891                                  const IntegerRelation &b, IntegerRelation &c) {
1892   c = IntegerRelation(a.getSpace());
1893   // a naive O(n^2) check should be enough here given the input sizes.
1894   for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
1895     for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
1896       if (a.getInequality(r) == b.getInequality(s)) {
1897         c.addInequality(a.getInequality(r));
1898         break;
1899       }
1900     }
1901   }
1902   for (unsigned r = 0, e = a.getNumEqualities(); r < e; ++r) {
1903     for (unsigned s = 0, f = b.getNumEqualities(); s < f; ++s) {
1904       if (a.getEquality(r) == b.getEquality(s)) {
1905         c.addEquality(a.getEquality(r));
1906         break;
1907       }
1908     }
1909   }
1910 }
1911 
1912 // Computes the bounding box with respect to 'other' by finding the min of the
1913 // lower bounds and the max of the upper bounds along each of the dimensions.
1914 LogicalResult
1915 IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
1916   assert(isSpaceEqual(otherCst) && "Spaces should match.");
1917   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
1918 
1919   // Get the constraints common to both systems; these will be added as is to
1920   // the union.
1921   IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
1922   getCommonConstraints(*this, otherCst, commonCst);
1923 
1924   std::vector<SmallVector<int64_t, 8>> boundingLbs;
1925   std::vector<SmallVector<int64_t, 8>> boundingUbs;
1926   boundingLbs.reserve(2 * getNumDimIds());
1927   boundingUbs.reserve(2 * getNumDimIds());
1928 
1929   // To hold lower and upper bounds for each dimension.
1930   SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
1931   // To compute min of lower bounds and max of upper bounds for each dimension.
1932   SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
1933   SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
1934   // To compute final new lower and upper bounds for the union.
1935   SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
1936 
1937   int64_t lbFloorDivisor, otherLbFloorDivisor;
1938   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
1939     auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
1940     if (!extent.hasValue())
1941       // TODO: symbolic extents when necessary.
1942       // TODO: handle union if a dimension is unbounded.
1943       return failure();
1944 
1945     auto otherExtent = otherCst.getConstantBoundOnDimSize(
1946         d, &otherLb, &otherLbFloorDivisor, &otherUb);
1947     if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
1948       // TODO: symbolic extents when necessary.
1949       return failure();
1950 
1951     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
1952 
1953     auto res = compareBounds(lb, otherLb);
1954     // Identify min.
1955     if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
1956       minLb = lb;
1957       // Since the divisor is for a floordiv, we need to convert to ceildiv,
1958       // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
1959       // div * i >= expr - div + 1.
1960       minLb.back() -= lbFloorDivisor - 1;
1961     } else if (res == BoundCmpResult::Greater) {
1962       minLb = otherLb;
1963       minLb.back() -= otherLbFloorDivisor - 1;
1964     } else {
1965       // Uncomparable - check for constant lower/upper bounds.
1966       auto constLb = getConstantBound(BoundType::LB, d);
1967       auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
1968       if (!constLb.hasValue() || !constOtherLb.hasValue())
1969         return failure();
1970       std::fill(minLb.begin(), minLb.end(), 0);
1971       minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
1972     }
1973 
1974     // Do the same for ub's but max of upper bounds. Identify max.
1975     auto uRes = compareBounds(ub, otherUb);
1976     if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
1977       maxUb = ub;
1978     } else if (uRes == BoundCmpResult::Less) {
1979       maxUb = otherUb;
1980     } else {
1981       // Uncomparable - check for constant lower/upper bounds.
1982       auto constUb = getConstantBound(BoundType::UB, d);
1983       auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
1984       if (!constUb.hasValue() || !constOtherUb.hasValue())
1985         return failure();
1986       std::fill(maxUb.begin(), maxUb.end(), 0);
1987       maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
1988     }
1989 
1990     std::fill(newLb.begin(), newLb.end(), 0);
1991     std::fill(newUb.begin(), newUb.end(), 0);
1992 
1993     // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor,
1994     // and so it's the divisor for newLb and newUb as well.
1995     newLb[d] = lbFloorDivisor;
1996     newUb[d] = -lbFloorDivisor;
1997     // Copy over the symbolic part + constant term.
1998     std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
1999     std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
2000                    newLb.begin() + getNumDimIds(), std::negate<int64_t>());
2001     std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
2002 
2003     boundingLbs.push_back(newLb);
2004     boundingUbs.push_back(newUb);
2005   }
2006 
2007   // Clear all constraints and add the lower/upper bounds for the bounding box.
2008   clearConstraints();
2009   for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
2010     addInequality(boundingLbs[d]);
2011     addInequality(boundingUbs[d]);
2012   }
2013 
2014   // Add the constraints that were common to both systems.
2015   append(commonCst);
2016   removeTrivialRedundancy();
2017 
2018   // TODO: copy over pure symbolic constraints from this and 'other' over to the
2019   // union (since the above are just the union along dimensions); we shouldn't
2020   // be discarding any other constraints on the symbols.
2021 
2022   return success();
2023 }
2024 
2025 bool IntegerRelation::isColZero(unsigned pos) const {
2026   unsigned rowPos;
2027   return !findConstraintWithNonZeroAt(pos, /*isEq=*/false, &rowPos) &&
2028          !findConstraintWithNonZeroAt(pos, /*isEq=*/true, &rowPos);
2029 }
2030 
2031 /// Find positions of inequalities and equalities that do not have a coefficient
2032 /// for [pos, pos + num) identifiers.
2033 static void getIndependentConstraints(const IntegerRelation &cst, unsigned pos,
2034                                       unsigned num,
2035                                       SmallVectorImpl<unsigned> &nbIneqIndices,
2036                                       SmallVectorImpl<unsigned> &nbEqIndices) {
2037   assert(pos < cst.getNumIds() && "invalid start position");
2038   assert(pos + num <= cst.getNumIds() && "invalid limit");
2039 
2040   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
2041     // The bounds are to be independent of [offset, offset + num) columns.
2042     unsigned c;
2043     for (c = pos; c < pos + num; ++c) {
2044       if (cst.atIneq(r, c) != 0)
2045         break;
2046     }
2047     if (c == pos + num)
2048       nbIneqIndices.push_back(r);
2049   }
2050 
2051   for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
2052     // The bounds are to be independent of [offset, offset + num) columns.
2053     unsigned c;
2054     for (c = pos; c < pos + num; ++c) {
2055       if (cst.atEq(r, c) != 0)
2056         break;
2057     }
2058     if (c == pos + num)
2059       nbEqIndices.push_back(r);
2060   }
2061 }
2062 
2063 void IntegerRelation::removeIndependentConstraints(unsigned pos, unsigned num) {
2064   assert(pos + num <= getNumIds() && "invalid range");
2065 
2066   // Remove constraints that are independent of these identifiers.
2067   SmallVector<unsigned, 4> nbIneqIndices, nbEqIndices;
2068   getIndependentConstraints(*this, /*pos=*/0, num, nbIneqIndices, nbEqIndices);
2069 
2070   // Iterate in reverse so that indices don't have to be updated.
2071   // TODO: This method can be made more efficient (because removal of each
2072   // inequality leads to much shifting/copying in the underlying buffer).
2073   for (auto nbIndex : llvm::reverse(nbIneqIndices))
2074     removeInequality(nbIndex);
2075   for (auto nbIndex : llvm::reverse(nbEqIndices))
2076     removeEquality(nbIndex);
2077 }
2078 
2079 void IntegerRelation::printSpace(raw_ostream &os) const {
2080   PresburgerSpace::print(os);
2081   os << getNumConstraints() << " constraints\n";
2082 }
2083 
2084 void IntegerRelation::print(raw_ostream &os) const {
2085   assert(hasConsistentState());
2086   printSpace(os);
2087   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
2088     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2089       os << atEq(i, j) << " ";
2090     }
2091     os << "= 0\n";
2092   }
2093   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
2094     for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
2095       os << atIneq(i, j) << " ";
2096     }
2097     os << ">= 0\n";
2098   }
2099   os << '\n';
2100 }
2101 
2102 void IntegerRelation::dump() const { print(llvm::errs()); }
2103 
2104 unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) {
2105   assert((kind != IdKind::Domain || num == 0) &&
2106          "Domain has to be zero in a set");
2107   return IntegerRelation::insertId(kind, pos, num);
2108 }
2109